View Javadoc
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
16  package io.netty.channel.socket.nio;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelException;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelOutboundBuffer;
24  import io.netty.channel.ChannelPromise;
25  import io.netty.channel.EventLoop;
26  import io.netty.channel.FileRegion;
27  import io.netty.channel.RecvByteBufAllocator;
28  import io.netty.channel.nio.AbstractNioByteChannel;
29  import io.netty.channel.socket.DefaultSocketChannelConfig;
30  import io.netty.channel.socket.ServerSocketChannel;
31  import io.netty.channel.socket.SocketChannelConfig;
32  import io.netty.util.concurrent.GlobalEventExecutor;
33  import io.netty.util.internal.PlatformDependent;
34  import io.netty.util.internal.SocketUtils;
35  import io.netty.util.internal.UnstableApi;
36  import io.netty.util.internal.logging.InternalLogger;
37  import io.netty.util.internal.logging.InternalLoggerFactory;
38  
39  import java.io.IOException;
40  import java.net.InetSocketAddress;
41  import java.net.Socket;
42  import java.net.SocketAddress;
43  import java.nio.ByteBuffer;
44  import java.nio.channels.SelectionKey;
45  import java.nio.channels.SocketChannel;
46  import java.nio.channels.spi.SelectorProvider;
47  import java.util.concurrent.Executor;
48  
49  import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
50  
51  /**
52   * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
53   */
54  public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel {
55      private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
56      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
57  
58      private static SocketChannel newSocket(SelectorProvider provider) {
59          try {
60              /**
61               *  Use the {@link SelectorProvider} to open {@link SocketChannel} and so remove condition in
62               *  {@link SelectorProvider#provider()} which is called by each SocketChannel.open() otherwise.
63               *
64               *  See <a href="https://github.com/netty/netty/issues/2308">#2308</a>.
65               */
66              return provider.openSocketChannel();
67          } catch (IOException e) {
68              throw new ChannelException("Failed to open a socket.", e);
69          }
70      }
71  
72      private final SocketChannelConfig config;
73  
74      /**
75       * Create a new instance
76       */
77      public NioSocketChannel() {
78          this(DEFAULT_SELECTOR_PROVIDER);
79      }
80  
81      /**
82       * Create a new instance using the given {@link SelectorProvider}.
83       */
84      public NioSocketChannel(SelectorProvider provider) {
85          this(newSocket(provider));
86      }
87  
88      /**
89       * Create a new instance using the given {@link SocketChannel}.
90       */
91      public NioSocketChannel(SocketChannel socket) {
92          this(null, socket);
93      }
94  
95      /**
96       * Create a new instance
97       *
98       * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
99       * @param socket    the {@link SocketChannel} which will be used
100      */
101     public NioSocketChannel(Channel parent, SocketChannel socket) {
102         super(parent, socket);
103         config = new NioSocketChannelConfig(this, socket.socket());
104     }
105 
106     @Override
107     public ServerSocketChannel parent() {
108         return (ServerSocketChannel) super.parent();
109     }
110 
111     @Override
112     public SocketChannelConfig config() {
113         return config;
114     }
115 
116     @Override
117     protected SocketChannel javaChannel() {
118         return (SocketChannel) super.javaChannel();
119     }
120 
121     @Override
122     public boolean isActive() {
123         SocketChannel ch = javaChannel();
124         return ch.isOpen() && ch.isConnected();
125     }
126 
127     @Override
128     public boolean isOutputShutdown() {
129         return javaChannel().socket().isOutputShutdown() || !isActive();
130     }
131 
132     @Override
133     public boolean isInputShutdown() {
134         return javaChannel().socket().isInputShutdown() || !isActive();
135     }
136 
137     @Override
138     public boolean isShutdown() {
139         Socket socket = javaChannel().socket();
140         return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive();
141     }
142 
143     @Override
144     public InetSocketAddress localAddress() {
145         return (InetSocketAddress) super.localAddress();
146     }
147 
148     @Override
149     public InetSocketAddress remoteAddress() {
150         return (InetSocketAddress) super.remoteAddress();
151     }
152 
153     @UnstableApi
154     @Override
155     protected final void doShutdownOutput() throws Exception {
156         if (PlatformDependent.javaVersion() >= 7) {
157             javaChannel().shutdownOutput();
158         } else {
159             javaChannel().socket().shutdownOutput();
160         }
161     }
162 
163     @Override
164     public ChannelFuture shutdownOutput() {
165         return shutdownOutput(newPromise());
166     }
167 
168     @Override
169     public ChannelFuture shutdownOutput(final ChannelPromise promise) {
170         final EventLoop loop = eventLoop();
171         if (loop.inEventLoop()) {
172             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
173         } else {
174             loop.execute(new Runnable() {
175                 @Override
176                 public void run() {
177                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
178                 }
179             });
180         }
181         return promise;
182     }
183 
184     @Override
185     public ChannelFuture shutdownInput() {
186         return shutdownInput(newPromise());
187     }
188 
189     @Override
190     protected boolean isInputShutdown0() {
191         return isInputShutdown();
192     }
193 
194     @Override
195     public ChannelFuture shutdownInput(final ChannelPromise promise) {
196         EventLoop loop = eventLoop();
197         if (loop.inEventLoop()) {
198             shutdownInput0(promise);
199         } else {
200             loop.execute(new Runnable() {
201                 @Override
202                 public void run() {
203                     shutdownInput0(promise);
204                 }
205             });
206         }
207         return promise;
208     }
209 
210     @Override
211     public ChannelFuture shutdown() {
212         return shutdown(newPromise());
213     }
214 
215     @Override
216     public ChannelFuture shutdown(final ChannelPromise promise) {
217         ChannelFuture shutdownOutputFuture = shutdownOutput();
218         if (shutdownOutputFuture.isDone()) {
219             shutdownOutputDone(shutdownOutputFuture, promise);
220         } else {
221             shutdownOutputFuture.addListener(new ChannelFutureListener() {
222                 @Override
223                 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
224                     shutdownOutputDone(shutdownOutputFuture, promise);
225                 }
226             });
227         }
228         return promise;
229     }
230 
231     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
232         ChannelFuture shutdownInputFuture = shutdownInput();
233         if (shutdownInputFuture.isDone()) {
234             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
235         } else {
236             shutdownInputFuture.addListener(new ChannelFutureListener() {
237                 @Override
238                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
239                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
240                 }
241             });
242         }
243     }
244 
245     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
246                                      ChannelFuture shutdownInputFuture,
247                                      ChannelPromise promise) {
248         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
249         Throwable shutdownInputCause = shutdownInputFuture.cause();
250         if (shutdownOutputCause != null) {
251             if (shutdownInputCause != null) {
252                 logger.debug("Exception suppressed because a previous exception occurred.",
253                         shutdownInputCause);
254             }
255             promise.setFailure(shutdownOutputCause);
256         } else if (shutdownInputCause != null) {
257             promise.setFailure(shutdownInputCause);
258         } else {
259             promise.setSuccess();
260         }
261     }
262     private void shutdownInput0(final ChannelPromise promise) {
263         try {
264             shutdownInput0();
265             promise.setSuccess();
266         } catch (Throwable t) {
267             promise.setFailure(t);
268         }
269     }
270 
271     private void shutdownInput0() throws Exception {
272         if (PlatformDependent.javaVersion() >= 7) {
273             javaChannel().shutdownInput();
274         } else {
275             javaChannel().socket().shutdownInput();
276         }
277     }
278 
279     @Override
280     protected SocketAddress localAddress0() {
281         return javaChannel().socket().getLocalSocketAddress();
282     }
283 
284     @Override
285     protected SocketAddress remoteAddress0() {
286         return javaChannel().socket().getRemoteSocketAddress();
287     }
288 
289     @Override
290     protected void doBind(SocketAddress localAddress) throws Exception {
291         doBind0(localAddress);
292     }
293 
294     private void doBind0(SocketAddress localAddress) throws Exception {
295         if (PlatformDependent.javaVersion() >= 7) {
296             SocketUtils.bind(javaChannel(), localAddress);
297         } else {
298             SocketUtils.bind(javaChannel().socket(), localAddress);
299         }
300     }
301 
302     @Override
303     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
304         if (localAddress != null) {
305             doBind0(localAddress);
306         }
307 
308         boolean success = false;
309         try {
310             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
311             if (!connected) {
312                 selectionKey().interestOps(SelectionKey.OP_CONNECT);
313             }
314             success = true;
315             return connected;
316         } finally {
317             if (!success) {
318                 doClose();
319             }
320         }
321     }
322 
323     @Override
324     protected void doFinishConnect() throws Exception {
325         if (!javaChannel().finishConnect()) {
326             throw new Error();
327         }
328     }
329 
330     @Override
331     protected void doDisconnect() throws Exception {
332         doClose();
333     }
334 
335     @Override
336     protected void doClose() throws Exception {
337         super.doClose();
338         javaChannel().close();
339     }
340 
341     @Override
342     protected int doReadBytes(ByteBuf byteBuf) throws Exception {
343         final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
344         allocHandle.attemptedBytesRead(byteBuf.writableBytes());
345         return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
346     }
347 
348     @Override
349     protected int doWriteBytes(ByteBuf buf) throws Exception {
350         final int expectedWrittenBytes = buf.readableBytes();
351         return buf.readBytes(javaChannel(), expectedWrittenBytes);
352     }
353 
354     @Override
355     protected long doWriteFileRegion(FileRegion region) throws Exception {
356         final long position = region.transferred();
357         return region.transferTo(javaChannel(), position);
358     }
359 
360     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
361         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
362         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
363         // make a best effort to adjust as OS behavior changes.
364         if (attempted == written) {
365             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
366                 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
367             }
368         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
369             ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
370         }
371     }
372 
373     @Override
374     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
375         SocketChannel ch = javaChannel();
376         int writeSpinCount = config().getWriteSpinCount();
377         do {
378             if (in.isEmpty()) {
379                 // All written so clear OP_WRITE
380                 clearOpWrite();
381                 // Directly return here so incompleteWrite(...) is not called.
382                 return;
383             }
384 
385             // Ensure the pending writes are made of ByteBufs only.
386             int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
387             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
388             int nioBufferCnt = in.nioBufferCount();
389 
390             // Always us nioBuffers() to workaround data-corruption.
391             // See https://github.com/netty/netty/issues/2761
392             switch (nioBufferCnt) {
393                 case 0:
394                     // We have something else beside ByteBuffers to write so fallback to normal writes.
395                     writeSpinCount -= doWrite0(in);
396                     break;
397                 case 1: {
398                     // Only one ByteBuf so use non-gathering write
399                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
400                     // to check if the total size of all the buffers is non-zero.
401                     ByteBuffer buffer = nioBuffers[0];
402                     int attemptedBytes = buffer.remaining();
403                     final int localWrittenBytes = ch.write(buffer);
404                     if (localWrittenBytes <= 0) {
405                         incompleteWrite(true);
406                         return;
407                     }
408                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
409                     in.removeBytes(localWrittenBytes);
410                     --writeSpinCount;
411                     break;
412                 }
413                 default: {
414                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
415                     // to check if the total size of all the buffers is non-zero.
416                     // We limit the max amount to int above so cast is safe
417                     long attemptedBytes = in.nioBufferSize();
418                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
419                     if (localWrittenBytes <= 0) {
420                         incompleteWrite(true);
421                         return;
422                     }
423                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
424                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
425                             maxBytesPerGatheringWrite);
426                     in.removeBytes(localWrittenBytes);
427                     --writeSpinCount;
428                     break;
429                 }
430             }
431         } while (writeSpinCount > 0);
432 
433         incompleteWrite(writeSpinCount < 0);
434     }
435 
436     @Override
437     protected AbstractNioUnsafe newUnsafe() {
438         return new NioSocketChannelUnsafe();
439     }
440 
441     private final class NioSocketChannelUnsafe extends NioByteUnsafe {
442         @Override
443         protected Executor prepareToClose() {
444             try {
445                 if (javaChannel().isOpen() && config().getSoLinger() > 0) {
446                     // We need to cancel this key of the channel so we may not end up in a eventloop spin
447                     // because we try to read or write until the actual close happens which may be later due
448                     // SO_LINGER handling.
449                     // See https://github.com/netty/netty/issues/4449
450                     doDeregister();
451                     return GlobalEventExecutor.INSTANCE;
452                 }
453             } catch (Throwable ignore) {
454                 // Ignore the error as the underlying channel may be closed in the meantime and so
455                 // getSoLinger() may produce an exception. In this case we just return null.
456                 // See https://github.com/netty/netty/issues/4449
457             }
458             return null;
459         }
460     }
461 
462     private final class NioSocketChannelConfig extends DefaultSocketChannelConfig {
463         private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
464 
465         private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) {
466             super(channel, javaSocket);
467             calculateMaxBytesPerGatheringWrite();
468         }
469 
470         @Override
471         protected void autoReadCleared() {
472             clearReadPending();
473         }
474 
475         @Override
476         public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) {
477             super.setSendBufferSize(sendBufferSize);
478             calculateMaxBytesPerGatheringWrite();
479             return this;
480         }
481 
482         void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
483             this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
484         }
485 
486         int getMaxBytesPerGatheringWrite() {
487             return maxBytesPerGatheringWrite;
488         }
489 
490         private void calculateMaxBytesPerGatheringWrite() {
491             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
492             int newSendBufferSize = getSendBufferSize() << 1;
493             if (newSendBufferSize > 0) {
494                 setMaxBytesPerGatheringWrite(getSendBufferSize() << 1);
495             }
496         }
497     }
498 }