View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.socket.nio;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufAllocator;
20  import io.netty.channel.Channel;
21  import io.netty.channel.ChannelConfig;
22  import io.netty.channel.ChannelException;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelOption;
26  import io.netty.channel.ChannelOutboundBuffer;
27  import io.netty.channel.ChannelPromise;
28  import io.netty.channel.DefaultChannelConfig;
29  import io.netty.channel.EventLoop;
30  import io.netty.channel.FileRegion;
31  import io.netty.channel.MessageSizeEstimator;
32  import io.netty.channel.RecvByteBufAllocator;
33  import io.netty.channel.ServerChannel;
34  import io.netty.channel.WriteBufferWaterMark;
35  import io.netty.channel.nio.AbstractNioByteChannel;
36  import io.netty.channel.socket.DuplexChannel;
37  import io.netty.channel.socket.DuplexChannelConfig;
38  import io.netty.util.internal.PlatformDependent;
39  import io.netty.util.internal.SocketUtils;
40  import io.netty.util.internal.logging.InternalLogger;
41  import io.netty.util.internal.logging.InternalLoggerFactory;
42  
43  import java.io.IOException;
44  import java.lang.reflect.Method;
45  import java.net.SocketAddress;
46  import java.net.StandardSocketOptions;
47  import java.nio.ByteBuffer;
48  import java.nio.channels.SelectionKey;
49  import java.nio.channels.SocketChannel;
50  import java.nio.channels.spi.SelectorProvider;
51  import java.util.ArrayList;
52  import java.util.List;
53  import java.util.Map;
54  
55  import static io.netty.channel.ChannelOption.SO_RCVBUF;
56  import static io.netty.channel.ChannelOption.SO_SNDBUF;
57  import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
58  
59  /**
60   * {@link DuplexChannel} which uses NIO selector based implementation to support
61   * UNIX Domain Sockets. This is only supported when using Java 16+.
62   */
63  public final class NioDomainSocketChannel extends AbstractNioByteChannel
64          implements DuplexChannel {
65      private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioDomainSocketChannel.class);
66      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
67  
68      private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
69              SelectorProviderUtil.findOpenMethod("openSocketChannel");
70  
71      private final ChannelConfig config;
72      private volatile boolean isInputShutdown;
73      private volatile boolean isOutputShutdown;
74  
75      static SocketChannel newChannel(SelectorProvider provider) {
76          if (PlatformDependent.javaVersion() < 16) {
77              throw new UnsupportedOperationException("Only supported on java 16+");
78          }
79          try {
80              SocketChannel channel = SelectorProviderUtil.newDomainSocketChannel(
81                      OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider);
82              if (channel == null) {
83                  throw new ChannelException("Failed to open a socket.");
84              }
85              return channel;
86          } catch (IOException e) {
87              throw new ChannelException("Failed to open a socket.", e);
88          }
89      }
90  
91      /**
92       * Create a new instance
93       */
94      public NioDomainSocketChannel() {
95          this(DEFAULT_SELECTOR_PROVIDER);
96      }
97  
98      /**
99       * Create a new instance using the given {@link SelectorProvider}.
100      */
101     public NioDomainSocketChannel(SelectorProvider provider) {
102         this(newChannel(provider));
103     }
104 
105     /**
106      * Create a new instance using the given {@link SocketChannel}.
107      */
108     public NioDomainSocketChannel(SocketChannel socket) {
109         this(null, socket);
110     }
111 
112     /**
113      * Create a new instance
114      *
115      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
116      * @param socket    the {@link SocketChannel} which will be used
117      */
118     public NioDomainSocketChannel(Channel parent, SocketChannel socket) {
119         super(parent, socket);
120         if (PlatformDependent.javaVersion() < 16) {
121             throw new UnsupportedOperationException("Only supported on java 16+");
122         }
123         config = new NioDomainSocketChannelConfig(this, socket);
124     }
125 
126     @Override
127     public ServerChannel parent() {
128         return (ServerChannel) super.parent();
129     }
130 
131     @Override
132     public ChannelConfig config() {
133         return config;
134     }
135 
136     @Override
137     protected SocketChannel javaChannel() {
138         return (SocketChannel) super.javaChannel();
139     }
140 
141     @Override
142     public boolean isActive() {
143         SocketChannel ch = javaChannel();
144         return ch.isOpen() && ch.isConnected();
145     }
146 
147     @Override
148     public boolean isOutputShutdown() {
149         return isOutputShutdown || !isActive();
150     }
151 
152     @Override
153     public boolean isInputShutdown() {
154         return isInputShutdown || !isActive();
155     }
156 
157     @Override
158     public boolean isShutdown() {
159         return isInputShutdown() && isOutputShutdown() || !isActive();
160     }
161 
162     @Override
163     protected void doShutdownOutput() throws Exception {
164         javaChannel().shutdownOutput();
165         isOutputShutdown = true;
166     }
167 
168     @Override
169     public ChannelFuture shutdownOutput() {
170         return shutdownOutput(newPromise());
171     }
172 
173     @Override
174     public ChannelFuture shutdownOutput(final ChannelPromise promise) {
175         final EventLoop loop = eventLoop();
176         if (loop.inEventLoop()) {
177             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
178         } else {
179             loop.execute(new Runnable() {
180                 @Override
181                 public void run() {
182                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
183                 }
184             });
185         }
186         return promise;
187     }
188 
189     @Override
190     public ChannelFuture shutdownInput() {
191         return shutdownInput(newPromise());
192     }
193 
194     @Override
195     protected boolean isInputShutdown0() {
196         return isInputShutdown();
197     }
198 
199     @Override
200     public ChannelFuture shutdownInput(final ChannelPromise promise) {
201         EventLoop loop = eventLoop();
202         if (loop.inEventLoop()) {
203             shutdownInput0(promise);
204         } else {
205             loop.execute(new Runnable() {
206                 @Override
207                 public void run() {
208                     shutdownInput0(promise);
209                 }
210             });
211         }
212         return promise;
213     }
214 
215     @Override
216     public ChannelFuture shutdown() {
217         return shutdown(newPromise());
218     }
219 
220     @Override
221     public ChannelFuture shutdown(final ChannelPromise promise) {
222         ChannelFuture shutdownOutputFuture = shutdownOutput();
223         if (shutdownOutputFuture.isDone()) {
224             shutdownOutputDone(shutdownOutputFuture, promise);
225         } else {
226             shutdownOutputFuture.addListener(new ChannelFutureListener() {
227                 @Override
228                 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
229                     shutdownOutputDone(shutdownOutputFuture, promise);
230                 }
231             });
232         }
233         return promise;
234     }
235 
236     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
237         ChannelFuture shutdownInputFuture = shutdownInput();
238         if (shutdownInputFuture.isDone()) {
239             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
240         } else {
241             shutdownInputFuture.addListener(new ChannelFutureListener() {
242                 @Override
243                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
244                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
245                 }
246             });
247         }
248     }
249 
250     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
251                                      ChannelFuture shutdownInputFuture,
252                                      ChannelPromise promise) {
253         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
254         Throwable shutdownInputCause = shutdownInputFuture.cause();
255         if (shutdownOutputCause != null) {
256             if (shutdownInputCause != null) {
257                 logger.debug("Exception suppressed because a previous exception occurred.",
258                         shutdownInputCause);
259             }
260             promise.setFailure(shutdownOutputCause);
261         } else if (shutdownInputCause != null) {
262             promise.setFailure(shutdownInputCause);
263         } else {
264             promise.setSuccess();
265         }
266     }
267     private void shutdownInput0(final ChannelPromise promise) {
268         try {
269             shutdownInput0();
270             promise.setSuccess();
271         } catch (Throwable t) {
272             promise.setFailure(t);
273         }
274     }
275 
276     private void shutdownInput0() throws Exception {
277         javaChannel().shutdownInput();
278         isInputShutdown = true;
279     }
280 
281     @Override
282     protected SocketAddress localAddress0() {
283         try {
284             return javaChannel().getLocalAddress();
285         } catch (Exception ignore) {
286             // ignore
287         }
288         return null;
289     }
290 
291     @Override
292     protected SocketAddress remoteAddress0() {
293         try {
294             return javaChannel().getRemoteAddress();
295         } catch (Exception ignore) {
296             // ignore
297         }
298         return null;
299     }
300 
301     @Override
302     protected void doBind(SocketAddress localAddress) throws Exception {
303         SocketUtils.bind(javaChannel(), localAddress);
304     }
305 
306     @Override
307     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
308         if (localAddress != null) {
309             doBind(localAddress);
310         }
311 
312         boolean success = false;
313         try {
314             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
315             if (!connected) {
316                 selectionKey().interestOps(SelectionKey.OP_CONNECT);
317             }
318             success = true;
319             return connected;
320         } finally {
321             if (!success) {
322                 doClose();
323             }
324         }
325     }
326 
327     @Override
328     protected void doFinishConnect() throws Exception {
329         if (!javaChannel().finishConnect()) {
330             throw new Error();
331         }
332     }
333 
334     @Override
335     protected void doDisconnect() throws Exception {
336         doClose();
337     }
338 
339     @Override
340     protected void doClose() throws Exception {
341         try {
342             super.doClose();
343         } finally {
344             javaChannel().close();
345         }
346     }
347 
348     @Override
349     protected int doReadBytes(ByteBuf byteBuf) throws Exception {
350         final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
351         allocHandle.attemptedBytesRead(byteBuf.writableBytes());
352         return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
353     }
354 
355     @Override
356     protected int doWriteBytes(ByteBuf buf) throws Exception {
357         final int expectedWrittenBytes = buf.readableBytes();
358         return buf.readBytes(javaChannel(), expectedWrittenBytes);
359     }
360 
361     @Override
362     protected long doWriteFileRegion(FileRegion region) throws Exception {
363         final long position = region.transferred();
364         return region.transferTo(javaChannel(), position);
365     }
366 
367     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
368         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
369         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
370         // make a best effort to adjust as OS behavior changes.
371         if (attempted == written) {
372             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
373                 ((NioDomainSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
374             }
375         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
376             ((NioDomainSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
377         }
378     }
379 
380     @Override
381     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
382         SocketChannel ch = javaChannel();
383         int writeSpinCount = config().getWriteSpinCount();
384         do {
385             if (in.isEmpty()) {
386                 // All written so clear OP_WRITE
387                 clearOpWrite();
388                 // Directly return here so incompleteWrite(...) is not called.
389                 return;
390             }
391 
392             // Ensure the pending writes are made of ByteBufs only.
393             int maxBytesPerGatheringWrite = ((NioDomainSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
394             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
395             int nioBufferCnt = in.nioBufferCount();
396 
397             // Always use nioBuffers() to workaround data-corruption.
398             // See https://github.com/netty/netty/issues/2761
399             switch (nioBufferCnt) {
400                 case 0:
401                     // We have something else beside ByteBuffers to write so fallback to normal writes.
402                     writeSpinCount -= doWrite0(in);
403                     break;
404                 case 1: {
405                     // Only one ByteBuf so use non-gathering write
406                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
407                     // to check if the total size of all the buffers is non-zero.
408                     ByteBuffer buffer = nioBuffers[0];
409                     int attemptedBytes = buffer.remaining();
410                     final int localWrittenBytes = ch.write(buffer);
411                     if (localWrittenBytes <= 0) {
412                         incompleteWrite(true);
413                         return;
414                     }
415                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
416                     in.removeBytes(localWrittenBytes);
417                     --writeSpinCount;
418                     break;
419                 }
420                 default: {
421                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
422                     // to check if the total size of all the buffers is non-zero.
423                     // We limit the max amount to int above so cast is safe
424                     long attemptedBytes = in.nioBufferSize();
425                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
426                     if (localWrittenBytes <= 0) {
427                         incompleteWrite(true);
428                         return;
429                     }
430                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
431                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
432                             maxBytesPerGatheringWrite);
433                     in.removeBytes(localWrittenBytes);
434                     --writeSpinCount;
435                     break;
436                 }
437             }
438         } while (writeSpinCount > 0);
439 
440         incompleteWrite(writeSpinCount < 0);
441     }
442 
443     @Override
444     protected AbstractNioUnsafe newUnsafe() {
445         return new NioSocketChannelUnsafe();
446     }
447 
448     private final class NioSocketChannelUnsafe extends NioByteUnsafe {
449         // Only extending it so we create a new instance in newUnsafe() and return it.
450     }
451 
452     private final class NioDomainSocketChannelConfig extends DefaultChannelConfig
453             implements DuplexChannelConfig {
454         private volatile boolean allowHalfClosure;
455         private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
456         private final SocketChannel javaChannel;
457         private NioDomainSocketChannelConfig(NioDomainSocketChannel channel, SocketChannel javaChannel) {
458             super(channel);
459             this.javaChannel = javaChannel;
460             calculateMaxBytesPerGatheringWrite();
461         }
462 
463         @Override
464         public boolean isAllowHalfClosure() {
465             return allowHalfClosure;
466         }
467 
468         @Override
469         public NioDomainSocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure) {
470             this.allowHalfClosure = allowHalfClosure;
471             return this;
472         }
473         @Override
474         public Map<ChannelOption<?>, Object> getOptions() {
475             List<ChannelOption<?>> options = new ArrayList<ChannelOption<?>>();
476             options.add(SO_RCVBUF);
477             options.add(SO_SNDBUF);
478             for (ChannelOption<?> opt : NioChannelOption.getOptions(jdkChannel())) {
479                 options.add(opt);
480             }
481             return getOptions(super.getOptions(), options.toArray(new ChannelOption[0]));
482         }
483 
484         @SuppressWarnings("unchecked")
485         @Override
486         public <T> T getOption(ChannelOption<T> option) {
487             if (option == SO_RCVBUF) {
488                 return (T) Integer.valueOf(getReceiveBufferSize());
489             }
490             if (option == SO_SNDBUF) {
491                 return (T) Integer.valueOf(getSendBufferSize());
492             }
493             if (option instanceof NioChannelOption) {
494                 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
495             }
496 
497             return super.getOption(option);
498         }
499 
500         @Override
501         public <T> boolean setOption(ChannelOption<T> option, T value) {
502             if (option == SO_RCVBUF) {
503                 validate(option, value);
504                 setReceiveBufferSize((Integer) value);
505             } else if (option == SO_SNDBUF) {
506                 validate(option, value);
507                 setSendBufferSize((Integer) value);
508             } else if (option instanceof NioChannelOption) {
509                 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
510             } else {
511                 return super.setOption(option, value);
512             }
513 
514             return true;
515         }
516 
517         private int getReceiveBufferSize() {
518             try {
519                 return javaChannel.getOption(StandardSocketOptions.SO_RCVBUF);
520             } catch (IOException e) {
521                 throw new ChannelException(e);
522             }
523         }
524 
525         private NioDomainSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) {
526             try {
527                 javaChannel.setOption(StandardSocketOptions.SO_RCVBUF, receiveBufferSize);
528             } catch (IOException e) {
529                 throw new ChannelException(e);
530             }
531             return this;
532         }
533 
534         private int getSendBufferSize() {
535             try {
536                 return javaChannel.getOption(StandardSocketOptions.SO_SNDBUF);
537             } catch (IOException e) {
538                 throw new ChannelException(e);
539             }
540         }
541         private NioDomainSocketChannelConfig setSendBufferSize(int sendBufferSize) {
542             try {
543                 javaChannel.setOption(StandardSocketOptions.SO_SNDBUF, sendBufferSize);
544             } catch (IOException e) {
545                 throw new ChannelException(e);
546             }
547             return this;
548         }
549 
550         @Override
551         public NioDomainSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) {
552             super.setConnectTimeoutMillis(connectTimeoutMillis);
553             return this;
554         }
555 
556         @Override
557         @Deprecated
558         public NioDomainSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) {
559             super.setMaxMessagesPerRead(maxMessagesPerRead);
560             return this;
561         }
562 
563         @Override
564         public NioDomainSocketChannelConfig setWriteSpinCount(int writeSpinCount) {
565             super.setWriteSpinCount(writeSpinCount);
566             return this;
567         }
568 
569         @Override
570         public NioDomainSocketChannelConfig setAllocator(ByteBufAllocator allocator) {
571             super.setAllocator(allocator);
572             return this;
573         }
574 
575         @Override
576         public NioDomainSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) {
577             super.setRecvByteBufAllocator(allocator);
578             return this;
579         }
580 
581         @Override
582         public NioDomainSocketChannelConfig setAutoRead(boolean autoRead) {
583             super.setAutoRead(autoRead);
584             return this;
585         }
586 
587         @Override
588         public NioDomainSocketChannelConfig setAutoClose(boolean autoClose) {
589             super.setAutoClose(autoClose);
590             return this;
591         }
592 
593         @Override
594         public NioDomainSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) {
595             super.setWriteBufferHighWaterMark(writeBufferHighWaterMark);
596             return this;
597         }
598 
599         @Override
600         public NioDomainSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) {
601             super.setWriteBufferLowWaterMark(writeBufferLowWaterMark);
602             return this;
603         }
604 
605         @Override
606         public NioDomainSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) {
607             super.setWriteBufferWaterMark(writeBufferWaterMark);
608             return this;
609         }
610 
611         @Override
612         public NioDomainSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) {
613             super.setMessageSizeEstimator(estimator);
614             return this;
615         }
616 
617         @Override
618         protected void autoReadCleared() {
619             clearReadPending();
620         }
621 
622         void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
623             this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
624         }
625 
626         int getMaxBytesPerGatheringWrite() {
627             return maxBytesPerGatheringWrite;
628         }
629 
630         private void calculateMaxBytesPerGatheringWrite() {
631             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
632             int newSendBufferSize = getSendBufferSize() << 1;
633             if (newSendBufferSize > 0) {
634                 setMaxBytesPerGatheringWrite(newSendBufferSize);
635             }
636         }
637 
638         private SocketChannel jdkChannel() {
639             return javaChannel;
640         }
641     }
642 }