View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.socket.nio;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelException;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelOption;
24  import io.netty.channel.ChannelOutboundBuffer;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.channel.EventLoop;
27  import io.netty.channel.FileRegion;
28  import io.netty.channel.RecvByteBufAllocator;
29  import io.netty.channel.nio.AbstractNioByteChannel;
30  import io.netty.channel.socket.DefaultSocketChannelConfig;
31  import io.netty.channel.socket.InternetProtocolFamily;
32  import io.netty.channel.socket.ServerSocketChannel;
33  import io.netty.channel.socket.SocketChannelConfig;
34  import io.netty.util.concurrent.GlobalEventExecutor;
35  import io.netty.util.internal.PlatformDependent;
36  import io.netty.util.internal.SocketUtils;
37  import io.netty.util.internal.SuppressJava6Requirement;
38  import io.netty.util.internal.logging.InternalLogger;
39  import io.netty.util.internal.logging.InternalLoggerFactory;
40  
41  
42  import java.io.IOException;
43  import java.lang.reflect.Method;
44  import java.net.InetSocketAddress;
45  import java.net.Socket;
46  import java.net.SocketAddress;
47  import java.nio.ByteBuffer;
48  import java.nio.channels.SelectionKey;
49  import java.nio.channels.SocketChannel;
50  import java.nio.channels.spi.SelectorProvider;
51  import java.util.Map;
52  import java.util.concurrent.Executor;
53  
54  import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
55  
56  /**
57   * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
58   */
59  public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel {
60      private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
61      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
62  
63      private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
64              SelectorProviderUtil.findOpenMethod("openSocketChannel");
65  
66      private final SocketChannelConfig config;
67  
68      private static SocketChannel newChannel(SelectorProvider provider, InternetProtocolFamily family) {
69          try {
70              SocketChannel channel = SelectorProviderUtil.newChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider, family);
71              return channel == null ? provider.openSocketChannel() : channel;
72          } catch (IOException e) {
73              throw new ChannelException("Failed to open a socket.", e);
74          }
75      }
76  
77      /**
78       * Create a new instance
79       */
80      public NioSocketChannel() {
81          this(DEFAULT_SELECTOR_PROVIDER);
82      }
83  
84      /**
85       * Create a new instance using the given {@link SelectorProvider}.
86       */
87      public NioSocketChannel(SelectorProvider provider) {
88          this(provider, null);
89      }
90  
91      /**
92       * Create a new instance using the given {@link SelectorProvider} and protocol family (supported only since JDK 15).
93       */
94      public NioSocketChannel(SelectorProvider provider, InternetProtocolFamily family) {
95          this(newChannel(provider, family));
96      }
97  
98      /**
99       * Create a new instance using the given {@link SocketChannel}.
100      */
101     public NioSocketChannel(SocketChannel socket) {
102         this(null, socket);
103     }
104 
105     /**
106      * Create a new instance
107      *
108      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
109      * @param socket    the {@link SocketChannel} which will be used
110      */
111     public NioSocketChannel(Channel parent, SocketChannel socket) {
112         super(parent, socket);
113         config = new NioSocketChannelConfig(this, socket.socket());
114     }
115 
116     @Override
117     public ServerSocketChannel parent() {
118         return (ServerSocketChannel) super.parent();
119     }
120 
121     @Override
122     public SocketChannelConfig config() {
123         return config;
124     }
125 
126     @Override
127     protected SocketChannel javaChannel() {
128         return (SocketChannel) super.javaChannel();
129     }
130 
131     @Override
132     public boolean isActive() {
133         SocketChannel ch = javaChannel();
134         return ch.isOpen() && ch.isConnected();
135     }
136 
137     @Override
138     public boolean isOutputShutdown() {
139         return javaChannel().socket().isOutputShutdown() || !isActive();
140     }
141 
142     @Override
143     public boolean isInputShutdown() {
144         return javaChannel().socket().isInputShutdown() || !isActive();
145     }
146 
147     @Override
148     public boolean isShutdown() {
149         Socket socket = javaChannel().socket();
150         return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive();
151     }
152 
153     @Override
154     public InetSocketAddress localAddress() {
155         return (InetSocketAddress) super.localAddress();
156     }
157 
158     @Override
159     public InetSocketAddress remoteAddress() {
160         return (InetSocketAddress) super.remoteAddress();
161     }
162 
163     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
164     @Override
165     protected final void doShutdownOutput() throws Exception {
166         if (PlatformDependent.javaVersion() >= 7) {
167             javaChannel().shutdownOutput();
168         } else {
169             javaChannel().socket().shutdownOutput();
170         }
171     }
172 
173     @Override
174     public ChannelFuture shutdownOutput() {
175         return shutdownOutput(newPromise());
176     }
177 
178     @Override
179     public ChannelFuture shutdownOutput(final ChannelPromise promise) {
180         final EventLoop loop = eventLoop();
181         if (loop.inEventLoop()) {
182             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
183         } else {
184             loop.execute(new Runnable() {
185                 @Override
186                 public void run() {
187                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
188                 }
189             });
190         }
191         return promise;
192     }
193 
194     @Override
195     public ChannelFuture shutdownInput() {
196         return shutdownInput(newPromise());
197     }
198 
199     @Override
200     protected boolean isInputShutdown0() {
201         return isInputShutdown();
202     }
203 
204     @Override
205     public ChannelFuture shutdownInput(final ChannelPromise promise) {
206         EventLoop loop = eventLoop();
207         if (loop.inEventLoop()) {
208             shutdownInput0(promise);
209         } else {
210             loop.execute(new Runnable() {
211                 @Override
212                 public void run() {
213                     shutdownInput0(promise);
214                 }
215             });
216         }
217         return promise;
218     }
219 
220     @Override
221     public ChannelFuture shutdown() {
222         return shutdown(newPromise());
223     }
224 
225     @Override
226     public ChannelFuture shutdown(final ChannelPromise promise) {
227         ChannelFuture shutdownOutputFuture = shutdownOutput();
228         if (shutdownOutputFuture.isDone()) {
229             shutdownOutputDone(shutdownOutputFuture, promise);
230         } else {
231             shutdownOutputFuture.addListener(new ChannelFutureListener() {
232                 @Override
233                 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
234                     shutdownOutputDone(shutdownOutputFuture, promise);
235                 }
236             });
237         }
238         return promise;
239     }
240 
241     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
242         ChannelFuture shutdownInputFuture = shutdownInput();
243         if (shutdownInputFuture.isDone()) {
244             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
245         } else {
246             shutdownInputFuture.addListener(new ChannelFutureListener() {
247                 @Override
248                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
249                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
250                 }
251             });
252         }
253     }
254 
255     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
256                                      ChannelFuture shutdownInputFuture,
257                                      ChannelPromise promise) {
258         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
259         Throwable shutdownInputCause = shutdownInputFuture.cause();
260         if (shutdownOutputCause != null) {
261             if (shutdownInputCause != null) {
262                 logger.debug("Exception suppressed because a previous exception occurred.",
263                         shutdownInputCause);
264             }
265             promise.setFailure(shutdownOutputCause);
266         } else if (shutdownInputCause != null) {
267             promise.setFailure(shutdownInputCause);
268         } else {
269             promise.setSuccess();
270         }
271     }
272     private void shutdownInput0(final ChannelPromise promise) {
273         try {
274             shutdownInput0();
275             promise.setSuccess();
276         } catch (Throwable t) {
277             promise.setFailure(t);
278         }
279     }
280 
281     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
282     private void shutdownInput0() throws Exception {
283         if (PlatformDependent.javaVersion() >= 7) {
284             javaChannel().shutdownInput();
285         } else {
286             javaChannel().socket().shutdownInput();
287         }
288     }
289 
290     @Override
291     protected SocketAddress localAddress0() {
292         return javaChannel().socket().getLocalSocketAddress();
293     }
294 
295     @Override
296     protected SocketAddress remoteAddress0() {
297         return javaChannel().socket().getRemoteSocketAddress();
298     }
299 
300     @Override
301     protected void doBind(SocketAddress localAddress) throws Exception {
302         doBind0(localAddress);
303     }
304 
305     private void doBind0(SocketAddress localAddress) throws Exception {
306         if (PlatformDependent.javaVersion() >= 7) {
307             SocketUtils.bind(javaChannel(), localAddress);
308         } else {
309             SocketUtils.bind(javaChannel().socket(), localAddress);
310         }
311     }
312 
313     @Override
314     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
315         if (localAddress != null) {
316             doBind0(localAddress);
317         }
318 
319         boolean success = false;
320         try {
321             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
322             if (!connected) {
323                 selectionKey().interestOps(SelectionKey.OP_CONNECT);
324             }
325             success = true;
326             return connected;
327         } finally {
328             if (!success) {
329                 doClose();
330             }
331         }
332     }
333 
334     @Override
335     protected void doFinishConnect() throws Exception {
336         if (!javaChannel().finishConnect()) {
337             throw new Error();
338         }
339     }
340 
341     @Override
342     protected void doDisconnect() throws Exception {
343         doClose();
344     }
345 
346     @Override
347     protected void doClose() throws Exception {
348         super.doClose();
349         javaChannel().close();
350     }
351 
352     @Override
353     protected int doReadBytes(ByteBuf byteBuf) throws Exception {
354         final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
355         allocHandle.attemptedBytesRead(byteBuf.writableBytes());
356         return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
357     }
358 
359     @Override
360     protected int doWriteBytes(ByteBuf buf) throws Exception {
361         final int expectedWrittenBytes = buf.readableBytes();
362         return buf.readBytes(javaChannel(), expectedWrittenBytes);
363     }
364 
365     @Override
366     protected long doWriteFileRegion(FileRegion region) throws Exception {
367         final long position = region.transferred();
368         return region.transferTo(javaChannel(), position);
369     }
370 
371     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
372         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
373         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
374         // make a best effort to adjust as OS behavior changes.
375         if (attempted == written) {
376             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
377                 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
378             }
379         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
380             ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
381         }
382     }
383 
384     @Override
385     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
386         SocketChannel ch = javaChannel();
387         int writeSpinCount = config().getWriteSpinCount();
388         do {
389             if (in.isEmpty()) {
390                 // All written so clear OP_WRITE
391                 clearOpWrite();
392                 // Directly return here so incompleteWrite(...) is not called.
393                 return;
394             }
395 
396             // Ensure the pending writes are made of ByteBufs only.
397             int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
398             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
399             int nioBufferCnt = in.nioBufferCount();
400 
401             // Always use nioBuffers() to workaround data-corruption.
402             // See https://github.com/netty/netty/issues/2761
403             switch (nioBufferCnt) {
404                 case 0:
405                     // We have something else beside ByteBuffers to write so fallback to normal writes.
406                     writeSpinCount -= doWrite0(in);
407                     break;
408                 case 1: {
409                     // Only one ByteBuf so use non-gathering write
410                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
411                     // to check if the total size of all the buffers is non-zero.
412                     ByteBuffer buffer = nioBuffers[0];
413                     int attemptedBytes = buffer.remaining();
414                     final int localWrittenBytes = ch.write(buffer);
415                     if (localWrittenBytes <= 0) {
416                         incompleteWrite(true);
417                         return;
418                     }
419                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
420                     in.removeBytes(localWrittenBytes);
421                     --writeSpinCount;
422                     break;
423                 }
424                 default: {
425                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
426                     // to check if the total size of all the buffers is non-zero.
427                     // We limit the max amount to int above so cast is safe
428                     long attemptedBytes = in.nioBufferSize();
429                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
430                     if (localWrittenBytes <= 0) {
431                         incompleteWrite(true);
432                         return;
433                     }
434                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
435                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
436                             maxBytesPerGatheringWrite);
437                     in.removeBytes(localWrittenBytes);
438                     --writeSpinCount;
439                     break;
440                 }
441             }
442         } while (writeSpinCount > 0);
443 
444         incompleteWrite(writeSpinCount < 0);
445     }
446 
447     @Override
448     protected AbstractNioUnsafe newUnsafe() {
449         return new NioSocketChannelUnsafe();
450     }
451 
452     private final class NioSocketChannelUnsafe extends NioByteUnsafe {
453         @Override
454         protected Executor prepareToClose() {
455             try {
456                 if (javaChannel().isOpen() && config().getSoLinger() > 0) {
457                     // We need to cancel this key of the channel so we may not end up in a eventloop spin
458                     // because we try to read or write until the actual close happens which may be later due
459                     // SO_LINGER handling.
460                     // See https://github.com/netty/netty/issues/4449
461                     doDeregister();
462                     return GlobalEventExecutor.INSTANCE;
463                 }
464             } catch (Throwable ignore) {
465                 // Ignore the error as the underlying channel may be closed in the meantime and so
466                 // getSoLinger() may produce an exception. In this case we just return null.
467                 // See https://github.com/netty/netty/issues/4449
468             }
469             return null;
470         }
471     }
472 
473     private final class NioSocketChannelConfig extends DefaultSocketChannelConfig {
474         private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
475         private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) {
476             super(channel, javaSocket);
477             calculateMaxBytesPerGatheringWrite();
478         }
479 
480         @Override
481         protected void autoReadCleared() {
482             clearReadPending();
483         }
484 
485         @Override
486         public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) {
487             super.setSendBufferSize(sendBufferSize);
488             calculateMaxBytesPerGatheringWrite();
489             return this;
490         }
491 
492         @Override
493         public <T> boolean setOption(ChannelOption<T> option, T value) {
494             if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
495                 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
496             }
497             return super.setOption(option, value);
498         }
499 
500         @Override
501         public <T> T getOption(ChannelOption<T> option) {
502             if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
503                 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
504             }
505             return super.getOption(option);
506         }
507 
508         @Override
509         public Map<ChannelOption<?>, Object> getOptions() {
510             if (PlatformDependent.javaVersion() >= 7) {
511                 return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel()));
512             }
513             return super.getOptions();
514         }
515 
516         void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
517             this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
518         }
519 
520         int getMaxBytesPerGatheringWrite() {
521             return maxBytesPerGatheringWrite;
522         }
523 
524         private void calculateMaxBytesPerGatheringWrite() {
525             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
526             int newSendBufferSize = getSendBufferSize() << 1;
527             if (newSendBufferSize > 0) {
528                 setMaxBytesPerGatheringWrite(newSendBufferSize);
529             }
530         }
531 
532         private SocketChannel jdkChannel() {
533             return ((NioSocketChannel) channel).javaChannel();
534         }
535     }
536 }