View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.socket.nio;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelException;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelOption;
24  import io.netty.channel.ChannelOutboundBuffer;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.channel.EventLoop;
27  import io.netty.channel.FileRegion;
28  import io.netty.channel.RecvByteBufAllocator;
29  import io.netty.channel.nio.AbstractNioByteChannel;
30  import io.netty.channel.nio.NioIoOps;
31  import io.netty.channel.socket.DefaultSocketChannelConfig;
32  import io.netty.channel.socket.InternetProtocolFamily;
33  import io.netty.channel.socket.ServerSocketChannel;
34  import io.netty.channel.socket.SocketChannelConfig;
35  import io.netty.channel.socket.SocketProtocolFamily;
36  import io.netty.util.concurrent.GlobalEventExecutor;
37  import io.netty.util.internal.SocketUtils;
38  import io.netty.util.internal.logging.InternalLogger;
39  import io.netty.util.internal.logging.InternalLoggerFactory;
40  
41  import java.io.IOException;
42  import java.lang.reflect.Method;
43  import java.net.InetSocketAddress;
44  import java.net.Socket;
45  import java.net.SocketAddress;
46  import java.nio.ByteBuffer;
47  import java.nio.channels.SocketChannel;
48  import java.nio.channels.spi.SelectorProvider;
49  import java.util.Map;
50  import java.util.concurrent.Executor;
51  
52  import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
53  
54  /**
55   * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
56   */
57  public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel {
58      private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
59      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
60  
61      private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
62              SelectorProviderUtil.findOpenMethod("openSocketChannel");
63  
64      private final SocketChannelConfig config;
65  
66      private static SocketChannel newChannel(SelectorProvider provider, SocketProtocolFamily family) {
67          try {
68              SocketChannel channel = SelectorProviderUtil.newChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider, family);
69              return channel == null ? provider.openSocketChannel() : channel;
70          } catch (IOException e) {
71              throw new ChannelException("Failed to open a socket.", e);
72          }
73      }
74  
75      /**
76       * Create a new instance
77       */
78      public NioSocketChannel() {
79          this(DEFAULT_SELECTOR_PROVIDER);
80      }
81  
82      /**
83       * Create a new instance using the given {@link SelectorProvider}.
84       */
85      public NioSocketChannel(SelectorProvider provider) {
86          this(provider, (SocketProtocolFamily) null);
87      }
88  
89      /**
90       * Create a new instance using the given {@link SelectorProvider} and protocol family (supported only since JDK 15).
91       *
92       * @deprecated use {@link NioSocketChannel#NioSocketChannel(SelectorProvider, SocketProtocolFamily)}
93       */
94      @Deprecated
95      public NioSocketChannel(SelectorProvider provider, InternetProtocolFamily family) {
96          this(provider, family == null ? null : family.toSocketProtocolFamily());
97      }
98  
99      /**
100      * Create a new instance using the given {@link SelectorProvider} and protocol family (supported only since JDK 15).
101      */
102     public NioSocketChannel(SelectorProvider provider, SocketProtocolFamily family) {
103         this(newChannel(provider, family));
104     }
105 
106     /**
107      * Create a new instance using the given {@link SocketChannel}.
108      */
109     public NioSocketChannel(SocketChannel socket) {
110         this(null, socket);
111     }
112 
113     /**
114      * Create a new instance
115      *
116      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
117      * @param socket    the {@link SocketChannel} which will be used
118      */
119     public NioSocketChannel(Channel parent, SocketChannel socket) {
120         super(parent, socket);
121         config = new NioSocketChannelConfig(this, socket.socket());
122     }
123 
124     @Override
125     public ServerSocketChannel parent() {
126         return (ServerSocketChannel) super.parent();
127     }
128 
129     @Override
130     public SocketChannelConfig config() {
131         return config;
132     }
133 
134     @Override
135     protected SocketChannel javaChannel() {
136         return (SocketChannel) super.javaChannel();
137     }
138 
139     @Override
140     public boolean isActive() {
141         SocketChannel ch = javaChannel();
142         return ch.isOpen() && ch.isConnected();
143     }
144 
145     @Override
146     public boolean isOutputShutdown() {
147         return javaChannel().socket().isOutputShutdown() || !isActive();
148     }
149 
150     @Override
151     public boolean isInputShutdown() {
152         return javaChannel().socket().isInputShutdown() || !isActive();
153     }
154 
155     @Override
156     public boolean isShutdown() {
157         Socket socket = javaChannel().socket();
158         return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive();
159     }
160 
161     @Override
162     public InetSocketAddress localAddress() {
163         return (InetSocketAddress) super.localAddress();
164     }
165 
166     @Override
167     public InetSocketAddress remoteAddress() {
168         return (InetSocketAddress) super.remoteAddress();
169     }
170 
171     @Override
172     protected final void doShutdownOutput() throws Exception {
173         javaChannel().shutdownOutput();
174     }
175 
176     @Override
177     public ChannelFuture shutdownOutput() {
178         return shutdownOutput(newPromise());
179     }
180 
181     @Override
182     public ChannelFuture shutdownOutput(final ChannelPromise promise) {
183         final EventLoop loop = eventLoop();
184         if (loop.inEventLoop()) {
185             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
186         } else {
187             loop.execute(new Runnable() {
188                 @Override
189                 public void run() {
190                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
191                 }
192             });
193         }
194         return promise;
195     }
196 
197     @Override
198     public ChannelFuture shutdownInput() {
199         return shutdownInput(newPromise());
200     }
201 
202     @Override
203     protected boolean isInputShutdown0() {
204         return isInputShutdown();
205     }
206 
207     @Override
208     public ChannelFuture shutdownInput(final ChannelPromise promise) {
209         EventLoop loop = eventLoop();
210         if (loop.inEventLoop()) {
211             shutdownInput0(promise);
212         } else {
213             loop.execute(new Runnable() {
214                 @Override
215                 public void run() {
216                     shutdownInput0(promise);
217                 }
218             });
219         }
220         return promise;
221     }
222 
223     @Override
224     public ChannelFuture shutdown() {
225         return shutdown(newPromise());
226     }
227 
228     @Override
229     public ChannelFuture shutdown(final ChannelPromise promise) {
230         ChannelFuture shutdownOutputFuture = shutdownOutput();
231         if (shutdownOutputFuture.isDone()) {
232             shutdownOutputDone(shutdownOutputFuture, promise);
233         } else {
234             shutdownOutputFuture.addListener(new ChannelFutureListener() {
235                 @Override
236                 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
237                     shutdownOutputDone(shutdownOutputFuture, promise);
238                 }
239             });
240         }
241         return promise;
242     }
243 
244     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
245         ChannelFuture shutdownInputFuture = shutdownInput();
246         if (shutdownInputFuture.isDone()) {
247             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
248         } else {
249             shutdownInputFuture.addListener(new ChannelFutureListener() {
250                 @Override
251                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
252                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
253                 }
254             });
255         }
256     }
257 
258     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
259                                      ChannelFuture shutdownInputFuture,
260                                      ChannelPromise promise) {
261         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
262         Throwable shutdownInputCause = shutdownInputFuture.cause();
263         if (shutdownOutputCause != null) {
264             if (shutdownInputCause != null) {
265                 logger.debug("Exception suppressed because a previous exception occurred.",
266                         shutdownInputCause);
267             }
268             promise.setFailure(shutdownOutputCause);
269         } else if (shutdownInputCause != null) {
270             promise.setFailure(shutdownInputCause);
271         } else {
272             promise.setSuccess();
273         }
274     }
275     private void shutdownInput0(final ChannelPromise promise) {
276         try {
277             shutdownInput0();
278             promise.setSuccess();
279         } catch (Throwable t) {
280             promise.setFailure(t);
281         }
282     }
283 
284     private void shutdownInput0() throws Exception {
285         javaChannel().shutdownInput();
286     }
287 
288     @Override
289     protected SocketAddress localAddress0() {
290         return javaChannel().socket().getLocalSocketAddress();
291     }
292 
293     @Override
294     protected SocketAddress remoteAddress0() {
295         return javaChannel().socket().getRemoteSocketAddress();
296     }
297 
298     @Override
299     protected void doBind(SocketAddress localAddress) throws Exception {
300         doBind0(localAddress);
301     }
302 
303     private void doBind0(SocketAddress localAddress) throws Exception {
304         SocketUtils.bind(javaChannel(), localAddress);
305     }
306 
307     @Override
308     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
309         if (localAddress != null) {
310             doBind0(localAddress);
311         }
312 
313         boolean success = false;
314         try {
315             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
316             if (!connected) {
317                 addAndSubmit(NioIoOps.CONNECT);
318             }
319             success = true;
320             return connected;
321         } finally {
322             if (!success) {
323                 doClose();
324             }
325         }
326     }
327 
328     @Override
329     protected void doFinishConnect() throws Exception {
330         if (!javaChannel().finishConnect()) {
331             throw new Error();
332         }
333     }
334 
335     @Override
336     protected void doDisconnect() throws Exception {
337         doClose();
338     }
339 
340     @Override
341     protected void doClose() throws Exception {
342         super.doClose();
343         javaChannel().close();
344     }
345 
346     @Override
347     protected int doReadBytes(ByteBuf byteBuf) throws Exception {
348         final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
349         allocHandle.attemptedBytesRead(byteBuf.writableBytes());
350         return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
351     }
352 
353     @Override
354     protected int doWriteBytes(ByteBuf buf) throws Exception {
355         final int expectedWrittenBytes = buf.readableBytes();
356         return buf.readBytes(javaChannel(), expectedWrittenBytes);
357     }
358 
359     @Override
360     protected long doWriteFileRegion(FileRegion region) throws Exception {
361         final long position = region.transferred();
362         return region.transferTo(javaChannel(), position);
363     }
364 
365     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
366         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
367         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
368         // make a best effort to adjust as OS behavior changes.
369         if (attempted == written) {
370             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
371                 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
372             }
373         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
374             ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
375         }
376     }
377 
378     @Override
379     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
380         SocketChannel ch = javaChannel();
381         int writeSpinCount = config().getWriteSpinCount();
382         do {
383             if (in.isEmpty()) {
384                 // All written so clear OP_WRITE
385                 clearOpWrite();
386                 // Directly return here so incompleteWrite(...) is not called.
387                 return;
388             }
389 
390             // Ensure the pending writes are made of ByteBufs only.
391             int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
392             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
393             int nioBufferCnt = in.nioBufferCount();
394 
395             // Always use nioBuffers() to workaround data-corruption.
396             // See https://github.com/netty/netty/issues/2761
397             switch (nioBufferCnt) {
398                 case 0:
399                     // We have something else beside ByteBuffers to write so fallback to normal writes.
400                     writeSpinCount -= doWrite0(in);
401                     break;
402                 case 1: {
403                     // Only one ByteBuf so use non-gathering write
404                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
405                     // to check if the total size of all the buffers is non-zero.
406                     ByteBuffer buffer = nioBuffers[0];
407                     int attemptedBytes = buffer.remaining();
408                     final int localWrittenBytes = ch.write(buffer);
409                     if (localWrittenBytes <= 0) {
410                         incompleteWrite(true);
411                         return;
412                     }
413                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
414                     in.removeBytes(localWrittenBytes);
415                     --writeSpinCount;
416                     break;
417                 }
418                 default: {
419                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
420                     // to check if the total size of all the buffers is non-zero.
421                     // We limit the max amount to int above so cast is safe
422                     long attemptedBytes = in.nioBufferSize();
423                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
424                     if (localWrittenBytes <= 0) {
425                         incompleteWrite(true);
426                         return;
427                     }
428                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
429                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
430                             maxBytesPerGatheringWrite);
431                     in.removeBytes(localWrittenBytes);
432                     --writeSpinCount;
433                     break;
434                 }
435             }
436         } while (writeSpinCount > 0);
437 
438         incompleteWrite(writeSpinCount < 0);
439     }
440 
441     @Override
442     protected AbstractNioUnsafe newUnsafe() {
443         return new NioSocketChannelUnsafe();
444     }
445 
446     private final class NioSocketChannelUnsafe extends NioByteUnsafe {
447         @Override
448         protected Executor prepareToClose() {
449             try {
450                 if (javaChannel().isOpen() && config().getSoLinger() > 0) {
451                     // We need to cancel this key of the channel so we may not end up in a eventloop spin
452                     // because we try to read or write until the actual close happens which may be later due
453                     // SO_LINGER handling.
454                     // See https://github.com/netty/netty/issues/4449
455                     doDeregister();
456                     return GlobalEventExecutor.INSTANCE;
457                 }
458             } catch (Throwable ignore) {
459                 // Ignore the error as the underlying channel may be closed in the meantime and so
460                 // getSoLinger() may produce an exception. In this case we just return null.
461                 // See https://github.com/netty/netty/issues/4449
462             }
463             return null;
464         }
465     }
466 
467     private final class NioSocketChannelConfig extends DefaultSocketChannelConfig {
468         private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
469         private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) {
470             super(channel, javaSocket);
471             calculateMaxBytesPerGatheringWrite();
472         }
473 
474         @Override
475         protected void autoReadCleared() {
476             clearReadPending();
477         }
478 
479         @Override
480         public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) {
481             super.setSendBufferSize(sendBufferSize);
482             calculateMaxBytesPerGatheringWrite();
483             return this;
484         }
485 
486         @Override
487         public <T> boolean setOption(ChannelOption<T> option, T value) {
488             if (option instanceof NioChannelOption) {
489                 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
490             }
491             return super.setOption(option, value);
492         }
493 
494         @Override
495         public <T> T getOption(ChannelOption<T> option) {
496             if (option instanceof NioChannelOption) {
497                 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
498             }
499             return super.getOption(option);
500         }
501 
502         @Override
503         public Map<ChannelOption<?>, Object> getOptions() {
504             return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel()));
505         }
506 
507         void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
508             this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
509         }
510 
511         int getMaxBytesPerGatheringWrite() {
512             return maxBytesPerGatheringWrite;
513         }
514 
515         private void calculateMaxBytesPerGatheringWrite() {
516             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
517             int newSendBufferSize = getSendBufferSize() << 1;
518             if (newSendBufferSize > 0) {
519                 setMaxBytesPerGatheringWrite(newSendBufferSize);
520             }
521         }
522 
523         private SocketChannel jdkChannel() {
524             return ((NioSocketChannel) channel).javaChannel();
525         }
526     }
527 }