View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.channel.socket.nio;
17  
18  import io.netty5.buffer.api.Buffer;
19  import io.netty5.channel.Channel;
20  import io.netty5.channel.ChannelException;
21  import io.netty5.channel.ChannelOption;
22  import io.netty5.channel.ChannelOutboundBuffer;
23  import io.netty5.channel.ChannelShutdownDirection;
24  import io.netty5.channel.EventLoop;
25  import io.netty5.channel.FileRegion;
26  import io.netty5.channel.RecvBufferAllocator;
27  import io.netty5.channel.nio.AbstractNioByteChannel;
28  import io.netty5.util.concurrent.Future;
29  import io.netty5.util.concurrent.GlobalEventExecutor;
30  import io.netty5.util.internal.PlatformDependent;
31  import io.netty5.util.internal.SocketUtils;
32  
33  import java.io.IOException;
34  import java.lang.reflect.Method;
35  import java.net.ProtocolFamily;
36  import java.net.SocketAddress;
37  import java.net.SocketOption;
38  import java.net.StandardSocketOptions;
39  import java.nio.ByteBuffer;
40  import java.nio.channels.SelectionKey;
41  import java.nio.channels.SocketChannel;
42  import java.nio.channels.spi.SelectorProvider;
43  import java.util.concurrent.Executor;
44  
45  import static io.netty5.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
46  import static io.netty5.channel.socket.nio.NioChannelUtil.isDomainSocket;
47  import static io.netty5.channel.socket.nio.NioChannelUtil.toDomainSocketAddress;
48  import static io.netty5.channel.socket.nio.NioChannelUtil.toJdkFamily;
49  import static io.netty5.channel.socket.nio.NioChannelUtil.toUnixDomainSocketAddress;
50  
51  /**
52   * {@link io.netty5.channel.socket.SocketChannel} which uses NIO selector based implementation.
53   *
54   * <h3>Available options</h3>
55   *
56   * In addition to the options provided by {@link io.netty5.channel.socket.SocketChannel},
57   * {@link NioSocketChannel} allows the following options in the option map:
58   *
59   * <table border="1" cellspacing="0" cellpadding="6">
60   * <tr>
61   * <th>{@link ChannelOption}</th>
62   * <th>{@code INET}</th>
63   * <th>{@code INET6}</th>
64   * <th>{@code UNIX</th>
65   * </tr><tr>
66   * <td>{@link NioChannelOption}</td><td>X</td><td>X</td><td>X</td>
67   * </tr>
68   * </table>
69   */
70  public class NioSocketChannel
71          extends AbstractNioByteChannel<NioServerSocketChannel, SocketAddress, SocketAddress>
72          implements io.netty5.channel.socket.SocketChannel {
73      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
74  
75      private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
76              NioChannelUtil.findOpenMethod("openSocketChannel");
77  
78      private static SocketChannel newChannel(SelectorProvider provider, ProtocolFamily family) {
79          try {
80              SocketChannel channel = NioChannelUtil.newChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider, family);
81              return channel == null ? provider.openSocketChannel() : channel;
82          } catch (IOException e) {
83              throw new ChannelException("Failed to open a socket.", e);
84          }
85      }
86  
87      private final ProtocolFamily family;
88      private volatile boolean inputShutdown;
89      private volatile boolean outputShutdown;
90  
91      /**
92       * Create a new instance
93       */
94      public NioSocketChannel(EventLoop eventLoop) {
95          this(eventLoop, DEFAULT_SELECTOR_PROVIDER);
96      }
97  
98      /**
99       * Create a new instance using the given {@link SelectorProvider}.
100      */
101     public NioSocketChannel(EventLoop eventLoop, SelectorProvider provider) {
102         this(eventLoop, provider, null);
103     }
104 
105     /**
106      * Create a new instance using the given {@link SelectorProvider} and protocol family (supported only since JDK 15).
107      */
108     public NioSocketChannel(EventLoop eventLoop, SelectorProvider provider, ProtocolFamily family) {
109         this(null, eventLoop, newChannel(provider, toJdkFamily(family)), family);
110     }
111 
112     /**
113      * Create a new instance using the given {@link SocketChannel}.
114      */
115     public NioSocketChannel(EventLoop eventLoop, SocketChannel socket) {
116         this(null, eventLoop, socket, null);
117     }
118 
119     /**
120      * Create a new instance
121      *
122      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
123      * @param eventLoop the {@link EventLoop} to use for IO.
124      * @param socket    the {@link SocketChannel} which will be used
125      */
126     public NioSocketChannel(NioServerSocketChannel parent, EventLoop eventLoop, SocketChannel socket) {
127         this(parent, eventLoop, socket, null);
128     }
129 
130     /**
131      * Create a new instance
132      *
133      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
134      * @param eventLoop the {@link EventLoop} to use for IO.
135      * @param socket    the {@link SocketChannel} which will be used
136      * @param family    the {@link ProtocolFamily} that was used to create th {@link SocketChannel}
137      */
138     public NioSocketChannel(NioServerSocketChannel parent, EventLoop eventLoop, SocketChannel socket,
139                             ProtocolFamily family) {
140         super(parent, eventLoop, socket);
141         this.family = toJdkFamily(family);
142         // Enable TCP_NODELAY by default if possible.
143         if (!isDomainSocket(family) && PlatformDependent.canEnableTcpNoDelayByDefault()) {
144             try {
145                 javaChannel().setOption(StandardSocketOptions.TCP_NODELAY, true);
146             } catch (Exception e) {
147                 // Ignore.
148             }
149         }
150         calculateMaxBytesPerGatheringWrite();
151     }
152 
153     @Override
154     protected SocketChannel javaChannel() {
155         return (SocketChannel) super.javaChannel();
156     }
157 
158     @Override
159     public boolean isActive() {
160         SocketChannel ch = javaChannel();
161         return ch.isOpen() && ch.isConnected();
162     }
163 
164     @Override
165     public boolean isShutdown(ChannelShutdownDirection direction) {
166         if (!isActive()) {
167             return true;
168         }
169         switch (direction) {
170             case Outbound:
171                 return outputShutdown;
172             case Inbound:
173                 return inputShutdown;
174             default:
175                 throw new AssertionError();
176         }
177     }
178 
179     @Override
180     protected void doShutdown(ChannelShutdownDirection direction) throws Exception {
181         switch (direction) {
182             case Inbound:
183                 javaChannel().shutdownInput();
184                 inputShutdown = true;
185                 break;
186             case Outbound:
187                 javaChannel().shutdownOutput();
188                 outputShutdown = true;
189                 break;
190             default:
191                 throw new AssertionError();
192         }
193     }
194 
195     @Override
196     protected SocketAddress localAddress0() {
197         try {
198             SocketAddress address = javaChannel().getLocalAddress();
199             if (isDomainSocket(family)) {
200                 return toDomainSocketAddress(address);
201             }
202             return address;
203         } catch (IOException e) {
204             // Just return null
205             return null;
206         }
207     }
208 
209     @Override
210     protected SocketAddress remoteAddress0() {
211         try {
212             SocketAddress address = javaChannel().getRemoteAddress();
213             if (isDomainSocket(family)) {
214                 return toDomainSocketAddress(address);
215             }
216             return address;
217         } catch (IOException e) {
218             // Just return null
219             return null;
220         }
221     }
222 
223     @Override
224     protected void doBind(SocketAddress localAddress) throws Exception {
225         doBind0(localAddress);
226     }
227 
228     private void doBind0(SocketAddress localAddress) throws Exception {
229         if (isDomainSocket(family)) {
230             localAddress = toUnixDomainSocketAddress(localAddress);
231         }
232         SocketUtils.bind(javaChannel(), localAddress);
233     }
234 
235     @Override
236     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
237         if (localAddress != null) {
238             doBind0(localAddress);
239         }
240 
241         boolean success = false;
242         try {
243             if (isDomainSocket(family)) {
244                 remoteAddress = toUnixDomainSocketAddress(remoteAddress);
245             }
246             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
247             if (!connected) {
248                 selectionKey().interestOps(SelectionKey.OP_CONNECT);
249             }
250             success = true;
251             return connected;
252         } finally {
253             if (!success) {
254                 doClose();
255             }
256         }
257     }
258 
259     @Override
260     protected boolean doFinishConnect(SocketAddress requestedRemoteAddress) throws Exception {
261         return javaChannel().finishConnect();
262     }
263 
264     @Override
265     protected void doDisconnect() throws Exception {
266         doClose();
267     }
268 
269     @Override
270     protected int doReadBytes(Buffer buffer) throws Exception {
271         final RecvBufferAllocator.Handle allocHandle = recvBufAllocHandle();
272         allocHandle.attemptedBytesRead(buffer.writableBytes());
273         return buffer.transferFrom(javaChannel(), allocHandle.attemptedBytesRead());
274     }
275 
276     @Override
277     protected int doWriteBytes(Buffer buf) throws Exception {
278         final int expectedWrittenBytes = buf.readableBytes();
279         return buf.transferTo(javaChannel(), expectedWrittenBytes);
280     }
281 
282     @Override
283     protected long doWriteFileRegion(FileRegion region) throws Exception {
284         final long position = region.transferred();
285         return region.transferTo(javaChannel(), position);
286     }
287 
288     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
289         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
290         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
291         // make a best effort to adjust as OS behavior changes.
292         if (attempted == written) {
293             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
294                 setMaxBytesPerGatheringWrite(attempted << 1);
295             }
296         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
297             setMaxBytesPerGatheringWrite(attempted >>> 1);
298         }
299     }
300 
301     @Override
302     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
303         SocketChannel ch = javaChannel();
304         int writeSpinCount = getWriteSpinCount();
305         do {
306             if (in.isEmpty()) {
307                 // All written so clear OP_WRITE
308                 clearOpWrite();
309                 // Directly return here so incompleteWrite(...) is not called.
310                 return;
311             }
312 
313             // Ensure the pending writes are made of ByteBufs only.
314             int maxBytesPerGatheringWrite = getMaxBytesPerGatheringWrite();
315             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
316             int nioBufferCnt = in.nioBufferCount();
317 
318             // Always use nioBuffers() to workaround data-corruption.
319             // See https://github.com/netty/netty/issues/2761
320             switch (nioBufferCnt) {
321                 case 0:
322                     // We have something else beside ByteBuffers to write so fallback to normal writes.
323                     writeSpinCount -= doWrite0(in);
324                     break;
325                 case 1: {
326                     // Only one ByteBuf so use non-gathering write
327                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
328                     // to check if the total size of all the buffers is non-zero.
329                     ByteBuffer buffer = nioBuffers[0];
330                     int attemptedBytes = buffer.remaining();
331                     final int localWrittenBytes = ch.write(buffer);
332                     if (localWrittenBytes <= 0) {
333                         incompleteWrite(true);
334                         return;
335                     }
336                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
337                     in.removeBytes(localWrittenBytes);
338                     --writeSpinCount;
339                     break;
340                 }
341                 default: {
342                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
343                     // to check if the total size of all the buffers is non-zero.
344                     // We limit the max amount to int above so cast is safe
345                     long attemptedBytes = in.nioBufferSize();
346                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
347                     if (localWrittenBytes <= 0) {
348                         incompleteWrite(true);
349                         return;
350                     }
351                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
352                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
353                             maxBytesPerGatheringWrite);
354                     in.removeBytes(localWrittenBytes);
355                     --writeSpinCount;
356                     break;
357                 }
358             }
359         } while (writeSpinCount > 0);
360 
361         incompleteWrite(writeSpinCount < 0);
362     }
363 
364     @Override
365     protected Future<Executor> prepareToClose() {
366         if (!isDomainSocket(family)) {
367             try {
368                 if (javaChannel().isOpen() && getOption(ChannelOption.SO_LINGER) > 0) {
369                     // We need to cancel this key of the channel so we may not end up in a eventloop spin
370                     // because we try to read or write until the actual close happens which may be later due
371                     // SO_LINGER handling.
372                     // See https://github.com/netty/netty/issues/4449
373                     return executor().deregisterForIo(this).map(v -> GlobalEventExecutor.INSTANCE);
374                 }
375             } catch (Throwable ignore) {
376                 // Ignore the error as the underlying channel may be closed in the meantime and so
377                 // getSoLinger() may produce an exception. In this case we just return null.
378                 // See https://github.com/netty/netty/issues/4449
379             }
380         }
381         return null;
382     }
383 
384     @SuppressWarnings("unchecked")
385     @Override
386     protected <T> T getExtendedOption(ChannelOption<T> option) {
387         SocketOption<T> socketOption = NioChannelOption.toSocketOption(option);
388         if (socketOption != null) {
389             return NioChannelOption.getOption(javaChannel(), socketOption);
390         } else {
391             return super.getExtendedOption(option);
392         }
393     }
394 
395     @Override
396     protected <T> void setExtendedOption(ChannelOption<T> option, T value) {
397         SocketOption<T> socketOption = NioChannelOption.toSocketOption(option);
398         if (socketOption != null) {
399             NioChannelOption.setOption(javaChannel(), socketOption, value);
400         } else {
401             super.setExtendedOption(option, value);
402         }
403     }
404 
405     @Override
406     protected boolean isExtendedOptionSupported(ChannelOption<?> option) {
407         SocketOption<?> socketOption = NioChannelOption.toSocketOption(option);
408         if (socketOption != null) {
409             return NioChannelOption.isOptionSupported(javaChannel(), socketOption);
410         }
411         return super.isExtendedOptionSupported(option);
412     }
413 
414     @Override
415     protected void autoReadCleared() {
416         clearReadPending();
417     }
418 
419     private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
420 
421     void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
422         this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
423     }
424 
425     int getMaxBytesPerGatheringWrite() {
426         return maxBytesPerGatheringWrite;
427     }
428 
429     private void calculateMaxBytesPerGatheringWrite() {
430         try {
431             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
432             int newSendBufferSize = javaChannel().getOption(StandardSocketOptions.SO_SNDBUF) << 1;
433             if (newSendBufferSize > 0) {
434                 setMaxBytesPerGatheringWrite(newSendBufferSize);
435             }
436         } catch (IOException e) {
437             throw new ChannelException(e);
438         }
439     }
440 }