View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.socket.nio;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufAllocator;
20  import io.netty.channel.Channel;
21  import io.netty.channel.ChannelConfig;
22  import io.netty.channel.ChannelException;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelOption;
26  import io.netty.channel.ChannelOutboundBuffer;
27  import io.netty.channel.ChannelPromise;
28  import io.netty.channel.DefaultChannelConfig;
29  import io.netty.channel.EventLoop;
30  import io.netty.channel.FileRegion;
31  import io.netty.channel.MessageSizeEstimator;
32  import io.netty.channel.RecvByteBufAllocator;
33  import io.netty.channel.WriteBufferWaterMark;
34  import io.netty.channel.nio.AbstractNioByteChannel;
35  import io.netty.channel.socket.DuplexChannel;
36  import io.netty.channel.socket.DuplexChannelConfig;
37  import io.netty.channel.socket.ServerSocketChannel;
38  import io.netty.util.internal.PlatformDependent;
39  import io.netty.util.internal.SocketUtils;
40  import io.netty.util.internal.SuppressJava6Requirement;
41  import io.netty.util.internal.logging.InternalLogger;
42  import io.netty.util.internal.logging.InternalLoggerFactory;
43  
44  import java.io.IOException;
45  import java.lang.reflect.Method;
46  import java.net.SocketAddress;
47  import java.net.StandardSocketOptions;
48  import java.nio.ByteBuffer;
49  import java.nio.channels.SelectionKey;
50  import java.nio.channels.SocketChannel;
51  import java.nio.channels.spi.SelectorProvider;
52  import java.util.ArrayList;
53  import java.util.List;
54  import java.util.Map;
55  
56  import static io.netty.channel.ChannelOption.SO_RCVBUF;
57  import static io.netty.channel.ChannelOption.SO_SNDBUF;
58  import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
59  
60  /**
61   * {@link DuplexChannel} which uses NIO selector based implementation to support
62   * UNIX Domain Sockets. This is only supported when using Java 16+.
63   */
64  public final class NioDomainSocketChannel extends AbstractNioByteChannel
65          implements DuplexChannel {
66      private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioDomainSocketChannel.class);
67      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
68  
69      private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
70              SelectorProviderUtil.findOpenMethod("openSocketChannel");
71  
72      private final ChannelConfig config;
73      private volatile boolean isInputShutdown;
74      private volatile boolean isOutputShutdown;
75  
76      private static SocketChannel newChannel(SelectorProvider provider) {
77          if (PlatformDependent.javaVersion() < 16) {
78              throw new UnsupportedOperationException("Only supported on java 16+");
79          }
80          try {
81              SocketChannel channel = SelectorProviderUtil.newDomainSocketChannel(
82                      OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider);
83              if (channel == null) {
84                  throw new ChannelException("Failed to open a socket.");
85              }
86              return channel;
87          } catch (IOException e) {
88              throw new ChannelException("Failed to open a socket.", e);
89          }
90      }
91  
92      /**
93       * Create a new instance
94       */
95      public NioDomainSocketChannel() {
96          this(DEFAULT_SELECTOR_PROVIDER);
97      }
98  
99      /**
100      * Create a new instance using the given {@link SelectorProvider}.
101      */
102     public NioDomainSocketChannel(SelectorProvider provider) {
103         this(newChannel(provider));
104     }
105 
106     /**
107      * Create a new instance using the given {@link SocketChannel}.
108      */
109     public NioDomainSocketChannel(SocketChannel socket) {
110         this(null, socket);
111     }
112 
113     /**
114      * Create a new instance
115      *
116      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
117      * @param socket    the {@link SocketChannel} which will be used
118      */
119     public NioDomainSocketChannel(Channel parent, SocketChannel socket) {
120         super(parent, socket);
121         if (PlatformDependent.javaVersion() < 16) {
122             throw new UnsupportedOperationException("Only supported on java 16+");
123         }
124         config = new NioDomainSocketChannelConfig(this, socket);
125     }
126 
127     @Override
128     public ServerSocketChannel parent() {
129         return (ServerSocketChannel) super.parent();
130     }
131 
132     @Override
133     public ChannelConfig config() {
134         return config;
135     }
136 
137     @Override
138     protected SocketChannel javaChannel() {
139         return (SocketChannel) super.javaChannel();
140     }
141 
142     @Override
143     public boolean isActive() {
144         SocketChannel ch = javaChannel();
145         return ch.isOpen() && ch.isConnected();
146     }
147 
148     @Override
149     public boolean isOutputShutdown() {
150         return isOutputShutdown || !isActive();
151     }
152 
153     @Override
154     public boolean isInputShutdown() {
155         return isInputShutdown || !isActive();
156     }
157 
158     @Override
159     public boolean isShutdown() {
160         return isInputShutdown() && isOutputShutdown() || !isActive();
161     }
162 
163     @Override
164     protected void doShutdownOutput() throws Exception {
165         javaChannel().shutdownOutput();
166         isOutputShutdown = true;
167     }
168 
169     @Override
170     public ChannelFuture shutdownOutput() {
171         return shutdownOutput(newPromise());
172     }
173 
174     @Override
175     public ChannelFuture shutdownOutput(final ChannelPromise promise) {
176         final EventLoop loop = eventLoop();
177         if (loop.inEventLoop()) {
178             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
179         } else {
180             loop.execute(new Runnable() {
181                 @Override
182                 public void run() {
183                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
184                 }
185             });
186         }
187         return promise;
188     }
189 
190     @Override
191     public ChannelFuture shutdownInput() {
192         return shutdownInput(newPromise());
193     }
194 
195     @Override
196     protected boolean isInputShutdown0() {
197         return isInputShutdown();
198     }
199 
200     @Override
201     public ChannelFuture shutdownInput(final ChannelPromise promise) {
202         EventLoop loop = eventLoop();
203         if (loop.inEventLoop()) {
204             shutdownInput0(promise);
205         } else {
206             loop.execute(new Runnable() {
207                 @Override
208                 public void run() {
209                     shutdownInput0(promise);
210                 }
211             });
212         }
213         return promise;
214     }
215 
216     @Override
217     public ChannelFuture shutdown() {
218         return shutdown(newPromise());
219     }
220 
221     @Override
222     public ChannelFuture shutdown(final ChannelPromise promise) {
223         ChannelFuture shutdownOutputFuture = shutdownOutput();
224         if (shutdownOutputFuture.isDone()) {
225             shutdownOutputDone(shutdownOutputFuture, promise);
226         } else {
227             shutdownOutputFuture.addListener(new ChannelFutureListener() {
228                 @Override
229                 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
230                     shutdownOutputDone(shutdownOutputFuture, promise);
231                 }
232             });
233         }
234         return promise;
235     }
236 
237     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
238         ChannelFuture shutdownInputFuture = shutdownInput();
239         if (shutdownInputFuture.isDone()) {
240             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
241         } else {
242             shutdownInputFuture.addListener(new ChannelFutureListener() {
243                 @Override
244                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
245                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
246                 }
247             });
248         }
249     }
250 
251     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
252                                      ChannelFuture shutdownInputFuture,
253                                      ChannelPromise promise) {
254         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
255         Throwable shutdownInputCause = shutdownInputFuture.cause();
256         if (shutdownOutputCause != null) {
257             if (shutdownInputCause != null) {
258                 logger.debug("Exception suppressed because a previous exception occurred.",
259                         shutdownInputCause);
260             }
261             promise.setFailure(shutdownOutputCause);
262         } else if (shutdownInputCause != null) {
263             promise.setFailure(shutdownInputCause);
264         } else {
265             promise.setSuccess();
266         }
267     }
268     private void shutdownInput0(final ChannelPromise promise) {
269         try {
270             shutdownInput0();
271             promise.setSuccess();
272         } catch (Throwable t) {
273             promise.setFailure(t);
274         }
275     }
276 
277     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
278     private void shutdownInput0() throws Exception {
279         javaChannel().shutdownInput();
280         isInputShutdown = true;
281     }
282 
283     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
284     @Override
285     protected SocketAddress localAddress0() {
286         try {
287             return javaChannel().getLocalAddress();
288         } catch (Exception ignore) {
289             // ignore
290         }
291         return null;
292     }
293 
294     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
295     @Override
296     protected SocketAddress remoteAddress0() {
297         try {
298             return javaChannel().getRemoteAddress();
299         } catch (Exception ignore) {
300             // ignore
301         }
302         return null;
303     }
304 
305     @Override
306     protected void doBind(SocketAddress localAddress) throws Exception {
307         SocketUtils.bind(javaChannel(), localAddress);
308     }
309 
310     @Override
311     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
312         if (localAddress != null) {
313             doBind(localAddress);
314         }
315 
316         boolean success = false;
317         try {
318             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
319             if (!connected) {
320                 selectionKey().interestOps(SelectionKey.OP_CONNECT);
321             }
322             success = true;
323             return connected;
324         } finally {
325             if (!success) {
326                 doClose();
327             }
328         }
329     }
330 
331     @Override
332     protected void doFinishConnect() throws Exception {
333         if (!javaChannel().finishConnect()) {
334             throw new Error();
335         }
336     }
337 
338     @Override
339     protected void doDisconnect() throws Exception {
340         doClose();
341     }
342 
343     @Override
344     protected void doClose() throws Exception {
345         try {
346             super.doClose();
347         } finally {
348             javaChannel().close();
349         }
350     }
351 
352     @Override
353     protected int doReadBytes(ByteBuf byteBuf) throws Exception {
354         final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
355         allocHandle.attemptedBytesRead(byteBuf.writableBytes());
356         return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
357     }
358 
359     @Override
360     protected int doWriteBytes(ByteBuf buf) throws Exception {
361         final int expectedWrittenBytes = buf.readableBytes();
362         return buf.readBytes(javaChannel(), expectedWrittenBytes);
363     }
364 
365     @Override
366     protected long doWriteFileRegion(FileRegion region) throws Exception {
367         final long position = region.transferred();
368         return region.transferTo(javaChannel(), position);
369     }
370 
371     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
372         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
373         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
374         // make a best effort to adjust as OS behavior changes.
375         if (attempted == written) {
376             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
377                 ((NioDomainSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
378             }
379         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
380             ((NioDomainSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
381         }
382     }
383 
384     @Override
385     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
386         SocketChannel ch = javaChannel();
387         int writeSpinCount = config().getWriteSpinCount();
388         do {
389             if (in.isEmpty()) {
390                 // All written so clear OP_WRITE
391                 clearOpWrite();
392                 // Directly return here so incompleteWrite(...) is not called.
393                 return;
394             }
395 
396             // Ensure the pending writes are made of ByteBufs only.
397             int maxBytesPerGatheringWrite = ((NioDomainSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
398             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
399             int nioBufferCnt = in.nioBufferCount();
400 
401             // Always use nioBuffers() to workaround data-corruption.
402             // See https://github.com/netty/netty/issues/2761
403             switch (nioBufferCnt) {
404                 case 0:
405                     // We have something else beside ByteBuffers to write so fallback to normal writes.
406                     writeSpinCount -= doWrite0(in);
407                     break;
408                 case 1: {
409                     // Only one ByteBuf so use non-gathering write
410                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
411                     // to check if the total size of all the buffers is non-zero.
412                     ByteBuffer buffer = nioBuffers[0];
413                     int attemptedBytes = buffer.remaining();
414                     final int localWrittenBytes = ch.write(buffer);
415                     if (localWrittenBytes <= 0) {
416                         incompleteWrite(true);
417                         return;
418                     }
419                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
420                     in.removeBytes(localWrittenBytes);
421                     --writeSpinCount;
422                     break;
423                 }
424                 default: {
425                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
426                     // to check if the total size of all the buffers is non-zero.
427                     // We limit the max amount to int above so cast is safe
428                     long attemptedBytes = in.nioBufferSize();
429                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
430                     if (localWrittenBytes <= 0) {
431                         incompleteWrite(true);
432                         return;
433                     }
434                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
435                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
436                             maxBytesPerGatheringWrite);
437                     in.removeBytes(localWrittenBytes);
438                     --writeSpinCount;
439                     break;
440                 }
441             }
442         } while (writeSpinCount > 0);
443 
444         incompleteWrite(writeSpinCount < 0);
445     }
446 
447     @Override
448     protected AbstractNioUnsafe newUnsafe() {
449         return new NioSocketChannelUnsafe();
450     }
451 
452     private final class NioSocketChannelUnsafe extends NioByteUnsafe {
453         // Only extending it so we create a new instance in newUnsafe() and return it.
454     }
455 
456     private final class NioDomainSocketChannelConfig extends DefaultChannelConfig
457             implements DuplexChannelConfig {
458         private volatile boolean allowHalfClosure;
459         private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
460         private final SocketChannel javaChannel;
461         private NioDomainSocketChannelConfig(NioDomainSocketChannel channel, SocketChannel javaChannel) {
462             super(channel);
463             this.javaChannel = javaChannel;
464             calculateMaxBytesPerGatheringWrite();
465         }
466 
467         @Override
468         public boolean isAllowHalfClosure() {
469             return allowHalfClosure;
470         }
471 
472         @Override
473         public NioDomainSocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure) {
474             this.allowHalfClosure = allowHalfClosure;
475             return this;
476         }
477         @Override
478         public Map<ChannelOption<?>, Object> getOptions() {
479             List<ChannelOption<?>> options = new ArrayList<ChannelOption<?>>();
480             options.add(SO_RCVBUF);
481             options.add(SO_SNDBUF);
482             for (ChannelOption<?> opt : NioChannelOption.getOptions(jdkChannel())) {
483                 options.add(opt);
484             }
485             return getOptions(super.getOptions(), options.toArray(new ChannelOption[0]));
486         }
487 
488         @SuppressWarnings("unchecked")
489         @Override
490         public <T> T getOption(ChannelOption<T> option) {
491             if (option == SO_RCVBUF) {
492                 return (T) Integer.valueOf(getReceiveBufferSize());
493             }
494             if (option == SO_SNDBUF) {
495                 return (T) Integer.valueOf(getSendBufferSize());
496             }
497             if (option instanceof NioChannelOption) {
498                 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
499             }
500 
501             return super.getOption(option);
502         }
503 
504         @Override
505         public <T> boolean setOption(ChannelOption<T> option, T value) {
506             if (option == SO_RCVBUF) {
507                 validate(option, value);
508                 setReceiveBufferSize((Integer) value);
509             } else if (option == SO_SNDBUF) {
510                 validate(option, value);
511                 setSendBufferSize((Integer) value);
512             } else if (option instanceof NioChannelOption) {
513                 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
514             } else {
515                 return super.setOption(option, value);
516             }
517 
518             return true;
519         }
520 
521         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
522         private int getReceiveBufferSize() {
523             try {
524                 return javaChannel.getOption(StandardSocketOptions.SO_RCVBUF);
525             } catch (IOException e) {
526                 throw new ChannelException(e);
527             }
528         }
529 
530         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
531         private NioDomainSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) {
532             try {
533                 javaChannel.<Integer>setOption(StandardSocketOptions.SO_RCVBUF, receiveBufferSize);
534             } catch (IOException e) {
535                 throw new ChannelException(e);
536             }
537             return this;
538         }
539 
540         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
541         private int getSendBufferSize() {
542             try {
543                 return javaChannel.getOption(StandardSocketOptions.SO_SNDBUF);
544             } catch (IOException e) {
545                 throw new ChannelException(e);
546             }
547         }
548         @SuppressJava6Requirement(reason = "Usage guarded by java version check")
549         private NioDomainSocketChannelConfig setSendBufferSize(int sendBufferSize) {
550             try {
551                 javaChannel.<Integer>setOption(StandardSocketOptions.SO_SNDBUF, sendBufferSize);
552             } catch (IOException e) {
553                 throw new ChannelException(e);
554             }
555             return this;
556         }
557 
558         @Override
559         public NioDomainSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) {
560             super.setConnectTimeoutMillis(connectTimeoutMillis);
561             return this;
562         }
563 
564         @Override
565         @Deprecated
566         public NioDomainSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) {
567             super.setMaxMessagesPerRead(maxMessagesPerRead);
568             return this;
569         }
570 
571         @Override
572         public NioDomainSocketChannelConfig setWriteSpinCount(int writeSpinCount) {
573             super.setWriteSpinCount(writeSpinCount);
574             return this;
575         }
576 
577         @Override
578         public NioDomainSocketChannelConfig setAllocator(ByteBufAllocator allocator) {
579             super.setAllocator(allocator);
580             return this;
581         }
582 
583         @Override
584         public NioDomainSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) {
585             super.setRecvByteBufAllocator(allocator);
586             return this;
587         }
588 
589         @Override
590         public NioDomainSocketChannelConfig setAutoRead(boolean autoRead) {
591             super.setAutoRead(autoRead);
592             return this;
593         }
594 
595         @Override
596         public NioDomainSocketChannelConfig setAutoClose(boolean autoClose) {
597             super.setAutoClose(autoClose);
598             return this;
599         }
600 
601         @Override
602         public NioDomainSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) {
603             super.setWriteBufferHighWaterMark(writeBufferHighWaterMark);
604             return this;
605         }
606 
607         @Override
608         public NioDomainSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) {
609             super.setWriteBufferLowWaterMark(writeBufferLowWaterMark);
610             return this;
611         }
612 
613         @Override
614         public NioDomainSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) {
615             super.setWriteBufferWaterMark(writeBufferWaterMark);
616             return this;
617         }
618 
619         @Override
620         public NioDomainSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) {
621             super.setMessageSizeEstimator(estimator);
622             return this;
623         }
624 
625         @Override
626         protected void autoReadCleared() {
627             clearReadPending();
628         }
629 
630         void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
631             this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
632         }
633 
634         int getMaxBytesPerGatheringWrite() {
635             return maxBytesPerGatheringWrite;
636         }
637 
638         private void calculateMaxBytesPerGatheringWrite() {
639             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
640             int newSendBufferSize = getSendBufferSize() << 1;
641             if (newSendBufferSize > 0) {
642                 setMaxBytesPerGatheringWrite(newSendBufferSize);
643             }
644         }
645 
646         private SocketChannel jdkChannel() {
647             return javaChannel;
648         }
649     }
650 }