View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.socket.nio;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufAllocator;
20  import io.netty.channel.Channel;
21  import io.netty.channel.ChannelConfig;
22  import io.netty.channel.ChannelException;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelOption;
26  import io.netty.channel.ChannelOutboundBuffer;
27  import io.netty.channel.ChannelPromise;
28  import io.netty.channel.DefaultChannelConfig;
29  import io.netty.channel.EventLoop;
30  import io.netty.channel.FileRegion;
31  import io.netty.channel.MessageSizeEstimator;
32  import io.netty.channel.RecvByteBufAllocator;
33  import io.netty.channel.WriteBufferWaterMark;
34  import io.netty.channel.nio.AbstractNioByteChannel;
35  import io.netty.channel.socket.DuplexChannel;
36  import io.netty.channel.socket.DuplexChannelConfig;
37  import io.netty.channel.socket.ServerSocketChannel;
38  import io.netty.util.internal.PlatformDependent;
39  import io.netty.util.internal.SocketUtils;
40  import io.netty.util.internal.SuppressJava6Requirement;
41  import io.netty.util.internal.logging.InternalLogger;
42  import io.netty.util.internal.logging.InternalLoggerFactory;
43  
44  import java.io.IOException;
45  import java.lang.reflect.Method;
46  import java.net.SocketAddress;
47  import java.net.StandardSocketOptions;
48  import java.nio.ByteBuffer;
49  import java.nio.channels.SelectionKey;
50  import java.nio.channels.SocketChannel;
51  import java.nio.channels.spi.SelectorProvider;
52  import java.util.ArrayList;
53  import java.util.List;
54  import java.util.Map;
55  
56  import static io.netty.channel.ChannelOption.SO_RCVBUF;
57  import static io.netty.channel.ChannelOption.SO_SNDBUF;
58  import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
59  
60  /**
61   * {@link DuplexChannel} which uses NIO selector based implementation to support
62   * UNIX Domain Sockets. This is only supported when using Java 16+.
63   */
64  public final class NioDomainSocketChannel extends AbstractNioByteChannel
65          implements DuplexChannel {
66      private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioDomainSocketChannel.class);
67      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
68  
69      private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
70              SelectorProviderUtil.findOpenMethod("openSocketChannel");
71  
72      private final ChannelConfig config;
73      private volatile boolean isInputShutdown;
74      private volatile boolean isOutputShutdown;
75  
76      private static SocketChannel newChannel(SelectorProvider provider) {
77          if (PlatformDependent.javaVersion() < 16) {
78              throw new UnsupportedOperationException("Only supported on java 16+");
79          }
80          try {
81              SocketChannel channel = SelectorProviderUtil.newDomainSocketChannel(
82                      OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider);
83              if (channel == null) {
84                  throw new ChannelException("Failed to open a socket.");
85              }
86              return channel;
87          } catch (IOException e) {
88              throw new ChannelException("Failed to open a socket.", e);
89          }
90      }
91  
92      /**
93       * Create a new instance
94       */
95      public NioDomainSocketChannel() {
96          this(DEFAULT_SELECTOR_PROVIDER);
97      }
98  
99      /**
100      * Create a new instance using the given {@link SelectorProvider}.
101      */
102     public NioDomainSocketChannel(SelectorProvider provider) {
103         this(newChannel(provider));
104     }
105 
106     /**
107      * Create a new instance using the given {@link SocketChannel}.
108      */
109     public NioDomainSocketChannel(SocketChannel socket) {
110         this(null, socket);
111     }
112 
113     /**
114      * Create a new instance
115      *
116      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
117      * @param socket    the {@link SocketChannel} which will be used
118      */
119     public NioDomainSocketChannel(Channel parent, SocketChannel socket) {
120         super(parent, socket);
121         if (PlatformDependent.javaVersion() < 16) {
122             throw new UnsupportedOperationException("Only supported on java 16+");
123         }
124         config = new NioDomainSocketChannelConfig(this, socket);
125     }
126 
127     @Override
128     public ServerSocketChannel parent() {
129         return (ServerSocketChannel) super.parent();
130     }
131 
132     @Override
133     public ChannelConfig config() {
134         return config;
135     }
136 
137     @Override
138     protected SocketChannel javaChannel() {
139         return (SocketChannel) super.javaChannel();
140     }
141 
142     @Override
143     public boolean isActive() {
144         SocketChannel ch = javaChannel();
145         return ch.isOpen() && ch.isConnected();
146     }
147 
148     @Override
149     public boolean isOutputShutdown() {
150         return isOutputShutdown || !isActive();
151     }
152 
153     @Override
154     public boolean isInputShutdown() {
155         return isInputShutdown || !isActive();
156     }
157 
158     @Override
159     public boolean isShutdown() {
160         return isInputShutdown() && isOutputShutdown() || !isActive();
161     }
162 
163     @SuppressJava6Requirement(reason = "guarded by version check")
164     @Override
165     protected void doShutdownOutput() throws Exception {
166         javaChannel().shutdownOutput();
167         isOutputShutdown = true;
168     }
169 
170     @Override
171     public ChannelFuture shutdownOutput() {
172         return shutdownOutput(newPromise());
173     }
174 
175     @Override
176     public ChannelFuture shutdownOutput(final ChannelPromise promise) {
177         final EventLoop loop = eventLoop();
178         if (loop.inEventLoop()) {
179             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
180         } else {
181             loop.execute(new Runnable() {
182                 @Override
183                 public void run() {
184                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
185                 }
186             });
187         }
188         return promise;
189     }
190 
191     @Override
192     public ChannelFuture shutdownInput() {
193         return shutdownInput(newPromise());
194     }
195 
196     @Override
197     protected boolean isInputShutdown0() {
198         return isInputShutdown();
199     }
200 
201     @Override
202     public ChannelFuture shutdownInput(final ChannelPromise promise) {
203         EventLoop loop = eventLoop();
204         if (loop.inEventLoop()) {
205             shutdownInput0(promise);
206         } else {
207             loop.execute(new Runnable() {
208                 @Override
209                 public void run() {
210                     shutdownInput0(promise);
211                 }
212             });
213         }
214         return promise;
215     }
216 
217     @Override
218     public ChannelFuture shutdown() {
219         return shutdown(newPromise());
220     }
221 
222     @Override
223     public ChannelFuture shutdown(final ChannelPromise promise) {
224         ChannelFuture shutdownOutputFuture = shutdownOutput();
225         if (shutdownOutputFuture.isDone()) {
226             shutdownOutputDone(shutdownOutputFuture, promise);
227         } else {
228             shutdownOutputFuture.addListener(new ChannelFutureListener() {
229                 @Override
230                 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
231                     shutdownOutputDone(shutdownOutputFuture, promise);
232                 }
233             });
234         }
235         return promise;
236     }
237 
238     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
239         ChannelFuture shutdownInputFuture = shutdownInput();
240         if (shutdownInputFuture.isDone()) {
241             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
242         } else {
243             shutdownInputFuture.addListener(new ChannelFutureListener() {
244                 @Override
245                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
246                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
247                 }
248             });
249         }
250     }
251 
252     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
253                                      ChannelFuture shutdownInputFuture,
254                                      ChannelPromise promise) {
255         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
256         Throwable shutdownInputCause = shutdownInputFuture.cause();
257         if (shutdownOutputCause != null) {
258             if (shutdownInputCause != null) {
259                 logger.debug("Exception suppressed because a previous exception occurred.",
260                         shutdownInputCause);
261             }
262             promise.setFailure(shutdownOutputCause);
263         } else if (shutdownInputCause != null) {
264             promise.setFailure(shutdownInputCause);
265         } else {
266             promise.setSuccess();
267         }
268     }
269 
270     private void shutdownInput0(final ChannelPromise promise) {
271         try {
272             shutdownInput0();
273             promise.setSuccess();
274         } catch (Throwable t) {
275             promise.setFailure(t);
276         }
277     }
278 
279     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
280     private void shutdownInput0() throws Exception {
281         javaChannel().shutdownInput();
282         isInputShutdown = true;
283     }
284 
285     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
286     @Override
287     protected SocketAddress localAddress0() {
288         try {
289             return javaChannel().getLocalAddress();
290         } catch (Exception ignore) {
291             // ignore
292         }
293         return null;
294     }
295 
296     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
297     @Override
298     protected SocketAddress remoteAddress0() {
299         try {
300             return javaChannel().getRemoteAddress();
301         } catch (Exception ignore) {
302             // ignore
303         }
304         return null;
305     }
306 
307     @Override
308     protected void doBind(SocketAddress localAddress) throws Exception {
309         SocketUtils.bind(javaChannel(), localAddress);
310     }
311 
312     @Override
313     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
314         if (localAddress != null) {
315             doBind(localAddress);
316         }
317 
318         boolean success = false;
319         try {
320             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
321             if (!connected) {
322                 selectionKey().interestOps(SelectionKey.OP_CONNECT);
323             }
324             success = true;
325             return connected;
326         } finally {
327             if (!success) {
328                 doClose();
329             }
330         }
331     }
332 
333     @Override
334     protected void doFinishConnect() throws Exception {
335         if (!javaChannel().finishConnect()) {
336             throw new Error();
337         }
338     }
339 
340     @Override
341     protected void doDisconnect() throws Exception {
342         doClose();
343     }
344 
345     @Override
346     protected void doClose() throws Exception {
347         try {
348             super.doClose();
349         } finally {
350             javaChannel().close();
351         }
352     }
353 
354     @Override
355     protected int doReadBytes(ByteBuf byteBuf) throws Exception {
356         final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
357         allocHandle.attemptedBytesRead(byteBuf.writableBytes());
358         return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
359     }
360 
361     @Override
362     protected int doWriteBytes(ByteBuf buf) throws Exception {
363         final int expectedWrittenBytes = buf.readableBytes();
364         return buf.readBytes(javaChannel(), expectedWrittenBytes);
365     }
366 
367     @Override
368     protected long doWriteFileRegion(FileRegion region) throws Exception {
369         final long position = region.transferred();
370         return region.transferTo(javaChannel(), position);
371     }
372 
373     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
374         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
375         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
376         // make a best effort to adjust as OS behavior changes.
377         if (attempted == written) {
378             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
379                 ((NioDomainSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
380             }
381         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
382             ((NioDomainSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
383         }
384     }
385 
386     @Override
387     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
388         SocketChannel ch = javaChannel();
389         int writeSpinCount = config().getWriteSpinCount();
390         do {
391             if (in.isEmpty()) {
392                 // All written so clear OP_WRITE
393                 clearOpWrite();
394                 // Directly return here so incompleteWrite(...) is not called.
395                 return;
396             }
397 
398             // Ensure the pending writes are made of ByteBufs only.
399             int maxBytesPerGatheringWrite = ((NioDomainSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
400             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
401             int nioBufferCnt = in.nioBufferCount();
402 
403             // Always use nioBuffers() to workaround data-corruption.
404             // See https://github.com/netty/netty/issues/2761
405             switch (nioBufferCnt) {
406                 case 0:
407                     // We have something else beside ByteBuffers to write so fallback to normal writes.
408                     writeSpinCount -= doWrite0(in);
409                     break;
410                 case 1: {
411                     // Only one ByteBuf so use non-gathering write
412                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
413                     // to check if the total size of all the buffers is non-zero.
414                     ByteBuffer buffer = nioBuffers[0];
415                     int attemptedBytes = buffer.remaining();
416                     final int localWrittenBytes = ch.write(buffer);
417                     if (localWrittenBytes <= 0) {
418                         incompleteWrite(true);
419                         return;
420                     }
421                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
422                     in.removeBytes(localWrittenBytes);
423                     --writeSpinCount;
424                     break;
425                 }
426                 default: {
427                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
428                     // to check if the total size of all the buffers is non-zero.
429                     // We limit the max amount to int above so cast is safe
430                     long attemptedBytes = in.nioBufferSize();
431                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
432                     if (localWrittenBytes <= 0) {
433                         incompleteWrite(true);
434                         return;
435                     }
436                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
437                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
438                             maxBytesPerGatheringWrite);
439                     in.removeBytes(localWrittenBytes);
440                     --writeSpinCount;
441                     break;
442                 }
443             }
444         } while (writeSpinCount > 0);
445 
446         incompleteWrite(writeSpinCount < 0);
447     }
448 
449     @Override
450     protected AbstractNioUnsafe newUnsafe() {
451         return new NioSocketChannelUnsafe();
452     }
453 
454     private final class NioSocketChannelUnsafe extends NioByteUnsafe {
455         // Only extending it so we create a new instance in newUnsafe() and return it.
456     }
457 
458     private final class NioDomainSocketChannelConfig extends DefaultChannelConfig
459             implements DuplexChannelConfig {
460         private volatile boolean allowHalfClosure;
461         private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
462         private final SocketChannel javaChannel;
463         private NioDomainSocketChannelConfig(NioDomainSocketChannel channel, SocketChannel javaChannel) {
464             super(channel);
465             this.javaChannel = javaChannel;
466             calculateMaxBytesPerGatheringWrite();
467         }
468 
469         @Override
470         public boolean isAllowHalfClosure() {
471             return allowHalfClosure;
472         }
473 
474         @Override
475         public NioDomainSocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure) {
476             this.allowHalfClosure = allowHalfClosure;
477             return this;
478         }
479         @Override
480         public Map<ChannelOption<?>, Object> getOptions() {
481             List<ChannelOption<?>> options = new ArrayList<ChannelOption<?>>();
482             options.add(SO_RCVBUF);
483             options.add(SO_SNDBUF);
484             for (ChannelOption<?> opt : NioChannelOption.getOptions(jdkChannel())) {
485                 options.add(opt);
486             }
487             return getOptions(super.getOptions(), options.toArray(new ChannelOption[0]));
488         }
489 
490         @SuppressWarnings("unchecked")
491         @Override
492         public <T> T getOption(ChannelOption<T> option) {
493             if (option == SO_RCVBUF) {
494                 return (T) Integer.valueOf(getReceiveBufferSize());
495             }
496             if (option == SO_SNDBUF) {
497                 return (T) Integer.valueOf(getSendBufferSize());
498             }
499             if (option instanceof NioChannelOption) {
500                 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
501             }
502 
503             return super.getOption(option);
504         }
505 
506         @Override
507         public <T> boolean setOption(ChannelOption<T> option, T value) {
508             if (option == SO_RCVBUF) {
509                 validate(option, value);
510                 setReceiveBufferSize((Integer) value);
511             } else if (option == SO_SNDBUF) {
512                 validate(option, value);
513                 setSendBufferSize((Integer) value);
514             } else if (option instanceof NioChannelOption) {
515                 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
516             } else {
517                 return super.setOption(option, value);
518             }
519 
520             return true;
521         }
522 
523         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
524         private int getReceiveBufferSize() {
525             try {
526                 return javaChannel.getOption(StandardSocketOptions.SO_RCVBUF);
527             } catch (IOException e) {
528                 throw new ChannelException(e);
529             }
530         }
531 
532         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
533         private NioDomainSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) {
534             try {
535                 javaChannel.<Integer>setOption(StandardSocketOptions.SO_RCVBUF, receiveBufferSize);
536             } catch (IOException e) {
537                 throw new ChannelException(e);
538             }
539             return this;
540         }
541 
542         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
543         private int getSendBufferSize() {
544             try {
545                 return javaChannel.getOption(StandardSocketOptions.SO_SNDBUF);
546             } catch (IOException e) {
547                 throw new ChannelException(e);
548             }
549         }
550         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
551         private NioDomainSocketChannelConfig setSendBufferSize(int sendBufferSize) {
552             try {
553                 javaChannel.<Integer>setOption(StandardSocketOptions.SO_SNDBUF, sendBufferSize);
554             } catch (IOException e) {
555                 throw new ChannelException(e);
556             }
557             return this;
558         }
559 
560         @Override
561         public NioDomainSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) {
562             super.setConnectTimeoutMillis(connectTimeoutMillis);
563             return this;
564         }
565 
566         @Override
567         @Deprecated
568         public NioDomainSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) {
569             super.setMaxMessagesPerRead(maxMessagesPerRead);
570             return this;
571         }
572 
573         @Override
574         public NioDomainSocketChannelConfig setWriteSpinCount(int writeSpinCount) {
575             super.setWriteSpinCount(writeSpinCount);
576             return this;
577         }
578 
579         @Override
580         public NioDomainSocketChannelConfig setAllocator(ByteBufAllocator allocator) {
581             super.setAllocator(allocator);
582             return this;
583         }
584 
585         @Override
586         public NioDomainSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) {
587             super.setRecvByteBufAllocator(allocator);
588             return this;
589         }
590 
591         @Override
592         public NioDomainSocketChannelConfig setAutoRead(boolean autoRead) {
593             super.setAutoRead(autoRead);
594             return this;
595         }
596 
597         @Override
598         public NioDomainSocketChannelConfig setAutoClose(boolean autoClose) {
599             super.setAutoClose(autoClose);
600             return this;
601         }
602 
603         @Override
604         public NioDomainSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) {
605             super.setWriteBufferHighWaterMark(writeBufferHighWaterMark);
606             return this;
607         }
608 
609         @Override
610         public NioDomainSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) {
611             super.setWriteBufferLowWaterMark(writeBufferLowWaterMark);
612             return this;
613         }
614 
615         @Override
616         public NioDomainSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) {
617             super.setWriteBufferWaterMark(writeBufferWaterMark);
618             return this;
619         }
620 
621         @Override
622         public NioDomainSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) {
623             super.setMessageSizeEstimator(estimator);
624             return this;
625         }
626 
627         @Override
628         protected void autoReadCleared() {
629             clearReadPending();
630         }
631 
632         void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
633             this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
634         }
635 
636         int getMaxBytesPerGatheringWrite() {
637             return maxBytesPerGatheringWrite;
638         }
639 
640         private void calculateMaxBytesPerGatheringWrite() {
641             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
642             int newSendBufferSize = getSendBufferSize() << 1;
643             if (newSendBufferSize > 0) {
644                 setMaxBytesPerGatheringWrite(newSendBufferSize);
645             }
646         }
647 
648         private SocketChannel jdkChannel() {
649             return javaChannel;
650         }
651     }
652 }