View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.uring;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelFuture;
21  import io.netty.channel.ChannelFutureListener;
22  import io.netty.channel.ChannelMetadata;
23  import io.netty.channel.ChannelOutboundBuffer;
24  import io.netty.channel.ChannelPipeline;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.channel.EventLoop;
27  import io.netty.channel.socket.DuplexChannel;
28  import io.netty.channel.unix.IovArray;
29  import io.netty.channel.unix.Limits;
30  import io.netty.util.internal.logging.InternalLogger;
31  import io.netty.util.internal.logging.InternalLoggerFactory;
32  
33  import java.net.SocketAddress;
34  import java.io.IOException;
35  
36  import static io.netty.channel.unix.Errors.ioResult;
37  
38  abstract class AbstractIoUringStreamChannel extends AbstractIoUringChannel implements DuplexChannel {
39      private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractIoUringStreamChannel.class);
40      private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16);
41  
42      // Store the opCode so we know if we used WRITE or WRITEV.
43      private byte writeOpCode;
44  
45      // Keep track of the ids used for write and read so we can cancel these when needed.
46      private long writeId;
47      private long readId;
48  
49      AbstractIoUringStreamChannel(Channel parent, LinuxSocket socket, boolean active) {
50          // Use a blocking fd, we can make use of fastpoll.
51          super(parent, LinuxSocket.makeBlocking(socket), active);
52      }
53  
54      AbstractIoUringStreamChannel(Channel parent, LinuxSocket socket, SocketAddress remote) {
55          // Use a blocking fd, we can make use of fastpoll.
56          super(parent, LinuxSocket.makeBlocking(socket), remote);
57      }
58  
59      @Override
60      public ChannelMetadata metadata() {
61          return METADATA;
62      }
63  
64      @Override
65      protected final AbstractUringUnsafe newUnsafe() {
66          return new IoUringStreamUnsafe();
67      }
68  
69      @Override
70      public final ChannelFuture shutdown() {
71          return shutdown(newPromise());
72      }
73  
74      @Override
75      public final ChannelFuture shutdown(final ChannelPromise promise) {
76          ChannelFuture shutdownOutputFuture = shutdownOutput();
77          if (shutdownOutputFuture.isDone()) {
78              shutdownOutputDone(shutdownOutputFuture, promise);
79          } else {
80              shutdownOutputFuture.addListener(new ChannelFutureListener() {
81                  @Override
82                  public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
83                      shutdownOutputDone(shutdownOutputFuture, promise);
84                  }
85              });
86          }
87          return promise;
88      }
89  
90      @Override
91      protected final void doShutdownOutput() throws Exception {
92          socket.shutdown(false, true);
93      }
94  
95      private void shutdownInput0(final ChannelPromise promise) {
96          try {
97              socket.shutdown(true, false);
98              promise.setSuccess();
99          } catch (Throwable cause) {
100             promise.setFailure(cause);
101         }
102     }
103 
104     @Override
105     public final boolean isOutputShutdown() {
106         return socket.isOutputShutdown();
107     }
108 
109     @Override
110     public final boolean isInputShutdown() {
111         return socket.isInputShutdown();
112     }
113 
114     @Override
115     public final boolean isShutdown() {
116         return socket.isShutdown();
117     }
118 
119     @Override
120     public final ChannelFuture shutdownOutput() {
121         return shutdownOutput(newPromise());
122     }
123 
124     @Override
125     public final ChannelFuture shutdownOutput(final ChannelPromise promise) {
126         EventLoop loop = eventLoop();
127         if (loop.inEventLoop()) {
128             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
129         } else {
130             loop.execute(new Runnable() {
131                 @Override
132                 public void run() {
133                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
134                 }
135             });
136         }
137 
138         return promise;
139     }
140 
141     @Override
142     public final ChannelFuture shutdownInput() {
143         return shutdownInput(newPromise());
144     }
145 
146     @Override
147     public final ChannelFuture shutdownInput(final ChannelPromise promise) {
148         EventLoop loop = eventLoop();
149         if (loop.inEventLoop()) {
150             shutdownInput0(promise);
151         } else {
152             loop.execute(new Runnable() {
153                 @Override
154                 public void run() {
155                     shutdownInput0(promise);
156                 }
157             });
158         }
159         return promise;
160     }
161 
162     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
163         ChannelFuture shutdownInputFuture = shutdownInput();
164         if (shutdownInputFuture.isDone()) {
165             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
166         } else {
167             shutdownInputFuture.addListener(new ChannelFutureListener() {
168                 @Override
169                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
170                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
171                 }
172             });
173         }
174     }
175 
176     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
177                                      ChannelFuture shutdownInputFuture,
178                                      ChannelPromise promise) {
179         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
180         Throwable shutdownInputCause = shutdownInputFuture.cause();
181         if (shutdownOutputCause != null) {
182             if (shutdownInputCause != null) {
183                 logger.info("Exception suppressed because a previous exception occurred.",
184                              shutdownInputCause);
185             }
186             promise.setFailure(shutdownOutputCause);
187         } else if (shutdownInputCause != null) {
188             promise.setFailure(shutdownInputCause);
189         } else {
190             promise.setSuccess();
191         }
192     }
193 
194     @Override
195     protected final void doRegister(ChannelPromise promise) {
196         super.doRegister(promise);
197         promise.addListener(f -> {
198             if (f.isSuccess()) {
199                 if (active) {
200                     // Register for POLLRDHUP if this channel is already considered active.
201                     schedulePollRdHup();
202                 }
203             }
204         });
205     }
206 
207     private final class IoUringStreamUnsafe extends AbstractUringUnsafe {
208 
209         private ByteBuf readBuffer;
210         private IovArray iovArray;
211 
212         @Override
213         protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
214             assert iovArray == null;
215             assert writeId == 0;
216             int numElements = Math.min(in.size(), Limits.IOV_MAX);
217             ByteBuf iovArrayBuffer = alloc().directBuffer(numElements * IovArray.IOV_SIZE);
218             iovArray = new IovArray(iovArrayBuffer);
219             try {
220                 int offset = iovArray.count();
221                 in.forEachFlushedMessage(iovArray);
222 
223                 int fd = fd().intValue();
224                 IoUringIoRegistration registration = registration();
225                 IoUringIoOps ops = IoUringIoOps.newWritev(fd, 0, 0, iovArray.memoryAddress(offset),
226                         iovArray.count() - offset, nextOpsId());
227                 byte opCode = ops.opcode();
228                 writeId = registration.submit(ops);
229                 writeOpCode = opCode;
230             } catch (Exception e) {
231                 iovArray.release();
232                 iovArray = null;
233 
234                 // This should never happen, anyway fallback to single write.
235                 scheduleWriteSingle(in.current());
236             }
237             return 1;
238         }
239 
240         @Override
241         protected int scheduleWriteSingle(Object msg) {
242             assert iovArray == null;
243             assert writeId == 0;
244             ByteBuf buf = (ByteBuf) msg;
245 
246             int fd = fd().intValue();
247             IoUringIoRegistration registration = registration();
248             IoUringIoOps ops = IoUringIoOps.newWrite(fd, 0, 0,
249                     buf.memoryAddress() + buf.readerIndex(), buf.readableBytes(), nextOpsId());
250             byte opCode = ops.opcode();
251             writeId = registration.submit(ops);
252             writeOpCode = opCode;
253             return 1;
254         }
255 
256         @Override
257         protected int scheduleRead0(boolean first) {
258             assert readBuffer == null;
259             assert readId == 0;
260 
261             final IoUringRecvByteAllocatorHandle allocHandle = recvBufAllocHandle();
262             ByteBuf byteBuf = allocHandle.allocate(alloc());
263             allocHandle.attemptedBytesRead(byteBuf.writableBytes());
264 
265             readBuffer = byteBuf;
266 
267             int fd = fd().intValue();
268             IoUringIoRegistration registration = registration();
269             // Depending on if this is the first read or not we will use Native.MSG_DONTWAIT.
270             // The idea is that if the socket is blocking we can do the first read in a blocking fashion
271             // and so not need to also register POLLIN. As we can not 100 % sure if reads after the first will
272             // be possible directly we schedule these with Native.MSG_DONTWAIT. This allows us to still be
273             // able to signal the fireChannelReadComplete() in a timely manner and be consistent with other
274             // transports.
275             IoUringIoOps ops = IoUringIoOps.newRecv(fd, 0, first ? 0 : Native.MSG_DONTWAIT,
276                     byteBuf.memoryAddress() + byteBuf.writerIndex(), byteBuf.writableBytes(), nextOpsId());
277             readId = registration.submit(ops);
278             return 1;
279         }
280 
281         @Override
282         protected void readComplete0(int res, int flags, int data, int outstanding) {
283             assert readId != 0;
284             readId = 0;
285             boolean allDataRead = false;
286 
287             final IoUringRecvByteAllocatorHandle allocHandle = recvBufAllocHandle();
288             final ChannelPipeline pipeline = pipeline();
289             ByteBuf byteBuf = this.readBuffer;
290             this.readBuffer = null;
291             assert byteBuf != null;
292 
293             try {
294                 if (res < 0) {
295                     if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
296                         byteBuf.release();
297                         return;
298                     }
299                     // If res is negative we should pass it to ioResult(...) which will either throw
300                     // or convert it to 0 if we could not read because the socket was not readable.
301                     allocHandle.lastBytesRead(ioResult("io_uring read", res));
302                 } else if (res > 0) {
303                     byteBuf.writerIndex(byteBuf.writerIndex() + res);
304                     allocHandle.lastBytesRead(res);
305                 } else {
306                     // EOF which we signal with -1.
307                     allocHandle.lastBytesRead(-1);
308                 }
309                 if (allocHandle.lastBytesRead() <= 0) {
310                     // nothing was read, release the buffer.
311                     byteBuf.release();
312                     byteBuf = null;
313                     allDataRead = allocHandle.lastBytesRead() < 0;
314                     if (allDataRead) {
315                         // There is nothing left to read as we received an EOF.
316                         shutdownInput(true);
317                     }
318                     allocHandle.readComplete();
319                     pipeline.fireChannelReadComplete();
320                     return;
321                 }
322 
323                 allocHandle.incMessagesRead(1);
324                 pipeline.fireChannelRead(byteBuf);
325                 byteBuf = null;
326                 if (allocHandle.continueReading() &&
327                         // If IORING_CQE_F_SOCK_NONEMPTY is supported we should check for it first before
328                         // trying to schedule a read. If it's supported and not part of the flags we know for sure
329                         // that the next read (which would be using Native.MSG_DONTWAIT) will complete without
330                         // be able to read any data. This is useless work and we can skip it.
331                         (!IoUring.isIOUringCqeFSockNonEmptySupported() ||
332                         (flags & Native.IORING_CQE_F_SOCK_NONEMPTY) != 0)) {
333                     // Let's schedule another read.
334                     scheduleRead(false);
335                 } else {
336                     // We did not fill the whole ByteBuf so we should break the "read loop" and try again later.
337                     allocHandle.readComplete();
338                     pipeline.fireChannelReadComplete();
339                 }
340             } catch (Throwable t) {
341                 handleReadException(pipeline, byteBuf, t, allDataRead, allocHandle);
342             }
343         }
344 
345         private void handleReadException(ChannelPipeline pipeline, ByteBuf byteBuf,
346                                          Throwable cause, boolean allDataRead,
347                                          IoUringRecvByteAllocatorHandle allocHandle) {
348             if (byteBuf != null) {
349                 if (byteBuf.isReadable()) {
350                     pipeline.fireChannelRead(byteBuf);
351                 } else {
352                     byteBuf.release();
353                 }
354             }
355             allocHandle.readComplete();
356             pipeline.fireChannelReadComplete();
357             pipeline.fireExceptionCaught(cause);
358             if (allDataRead || cause instanceof IOException) {
359                 shutdownInput(true);
360             }
361         }
362 
363         @Override
364         boolean writeComplete0(int res, int flags, int data, int outstanding) {
365             assert writeId != 0;
366             writeId = 0;
367             writeOpCode = 0;
368             IovArray iovArray = this.iovArray;
369             if (iovArray != null) {
370                 this.iovArray = null;
371                 iovArray.release();
372             }
373             if (res >= 0) {
374                 unsafe().outboundBuffer().removeBytes(res);
375             } else if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
376                 return true;
377             } else {
378                 try {
379                     if (ioResult("io_uring write", res) == 0) {
380                         return false;
381                     }
382                 } catch (Throwable cause) {
383                     handleWriteError(cause);
384                 }
385             }
386             return true;
387         }
388     }
389 
390     @Override
391     protected final void cancelOutstandingReads(IoUringIoRegistration registration, int numOutstandingReads) {
392         if (readId != 0) {
393             // Let's try to cancel outstanding reads as these might be submitted and waiting for data (via fastpoll).
394             assert numOutstandingReads == 1;
395             int fd = fd().intValue();
396             IoUringIoOps ops = IoUringIoOps.newAsyncCancel(fd, 0, readId, Native.IORING_OP_RECV);
397             registration.submit(ops);
398         } else {
399             assert numOutstandingReads == 0;
400         }
401     }
402 
403     @Override
404     protected final void cancelOutstandingWrites(IoUringIoRegistration registration, int numOutstandingWrites) {
405         if (writeId != 0) {
406             // Let's try to cancel outstanding writes as these might be submitted and waiting to finish writing
407             // (via fastpoll).
408             assert numOutstandingWrites == 1;
409             assert writeOpCode != 0;
410             int fd = fd().intValue();
411             registration.submit(IoUringIoOps.newAsyncCancel(fd, 0, writeId, writeOpCode));
412         } else {
413             assert numOutstandingWrites == 0;
414         }
415     }
416 }