View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.uring;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelFuture;
21  import io.netty.channel.ChannelFutureListener;
22  import io.netty.channel.ChannelMetadata;
23  import io.netty.channel.ChannelOutboundBuffer;
24  import io.netty.channel.ChannelPipeline;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.channel.DefaultFileRegion;
27  import io.netty.channel.EventLoop;
28  import io.netty.channel.socket.DuplexChannel;
29  import io.netty.channel.unix.IovArray;
30  import io.netty.channel.unix.Limits;
31  import io.netty.util.internal.logging.InternalLogger;
32  import io.netty.util.internal.logging.InternalLoggerFactory;
33  
34  import java.io.IOException;
35  import java.net.SocketAddress;
36  
37  import static io.netty.channel.unix.Errors.ioResult;
38  
39  abstract class AbstractIoUringStreamChannel extends AbstractIoUringChannel implements DuplexChannel {
40      private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractIoUringStreamChannel.class);
41      private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16);
42  
43      // Store the opCode so we know if we used WRITE or WRITEV.
44      private byte writeOpCode;
45  
46      // Keep track of the ids used for write and read so we can cancel these when needed.
47      private long writeId;
48      private long readId;
49  
50      AbstractIoUringStreamChannel(Channel parent, LinuxSocket socket, boolean active) {
51          // Use a blocking fd, we can make use of fastpoll.
52          super(parent, LinuxSocket.makeBlocking(socket), active);
53      }
54  
55      AbstractIoUringStreamChannel(Channel parent, LinuxSocket socket, SocketAddress remote) {
56          // Use a blocking fd, we can make use of fastpoll.
57          super(parent, LinuxSocket.makeBlocking(socket), remote);
58      }
59  
60      @Override
61      public ChannelMetadata metadata() {
62          return METADATA;
63      }
64  
65      @Override
66      protected final AbstractUringUnsafe newUnsafe() {
67          return new IoUringStreamUnsafe();
68      }
69  
70      @Override
71      public final ChannelFuture shutdown() {
72          return shutdown(newPromise());
73      }
74  
75      @Override
76      public final ChannelFuture shutdown(final ChannelPromise promise) {
77          ChannelFuture shutdownOutputFuture = shutdownOutput();
78          if (shutdownOutputFuture.isDone()) {
79              shutdownOutputDone(shutdownOutputFuture, promise);
80          } else {
81              shutdownOutputFuture.addListener(new ChannelFutureListener() {
82                  @Override
83                  public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
84                      shutdownOutputDone(shutdownOutputFuture, promise);
85                  }
86              });
87          }
88          return promise;
89      }
90  
91      @Override
92      protected final void doShutdownOutput() throws Exception {
93          socket.shutdown(false, true);
94      }
95  
96      private void shutdownInput0(final ChannelPromise promise) {
97          try {
98              socket.shutdown(true, false);
99              promise.setSuccess();
100         } catch (Throwable cause) {
101             promise.setFailure(cause);
102         }
103     }
104 
105     @Override
106     public final boolean isOutputShutdown() {
107         return socket.isOutputShutdown();
108     }
109 
110     @Override
111     public final boolean isInputShutdown() {
112         return socket.isInputShutdown();
113     }
114 
115     @Override
116     public final boolean isShutdown() {
117         return socket.isShutdown();
118     }
119 
120     @Override
121     public final ChannelFuture shutdownOutput() {
122         return shutdownOutput(newPromise());
123     }
124 
125     @Override
126     public final ChannelFuture shutdownOutput(final ChannelPromise promise) {
127         EventLoop loop = eventLoop();
128         if (loop.inEventLoop()) {
129             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
130         } else {
131             loop.execute(new Runnable() {
132                 @Override
133                 public void run() {
134                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
135                 }
136             });
137         }
138 
139         return promise;
140     }
141 
142     @Override
143     public final ChannelFuture shutdownInput() {
144         return shutdownInput(newPromise());
145     }
146 
147     @Override
148     public final ChannelFuture shutdownInput(final ChannelPromise promise) {
149         EventLoop loop = eventLoop();
150         if (loop.inEventLoop()) {
151             shutdownInput0(promise);
152         } else {
153             loop.execute(new Runnable() {
154                 @Override
155                 public void run() {
156                     shutdownInput0(promise);
157                 }
158             });
159         }
160         return promise;
161     }
162 
163     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
164         ChannelFuture shutdownInputFuture = shutdownInput();
165         if (shutdownInputFuture.isDone()) {
166             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
167         } else {
168             shutdownInputFuture.addListener(new ChannelFutureListener() {
169                 @Override
170                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
171                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
172                 }
173             });
174         }
175     }
176 
177     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
178                                      ChannelFuture shutdownInputFuture,
179                                      ChannelPromise promise) {
180         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
181         Throwable shutdownInputCause = shutdownInputFuture.cause();
182         if (shutdownOutputCause != null) {
183             if (shutdownInputCause != null) {
184                 logger.info("Exception suppressed because a previous exception occurred.",
185                              shutdownInputCause);
186             }
187             promise.setFailure(shutdownOutputCause);
188         } else if (shutdownInputCause != null) {
189             promise.setFailure(shutdownInputCause);
190         } else {
191             promise.setSuccess();
192         }
193     }
194 
195     @Override
196     protected final void doRegister(ChannelPromise promise) {
197         super.doRegister(promise);
198         promise.addListener(f -> {
199             if (f.isSuccess()) {
200                 if (active) {
201                     // Register for POLLRDHUP if this channel is already considered active.
202                     schedulePollRdHup();
203                 }
204             }
205         });
206     }
207 
208     @Override
209     protected Object filterOutboundMessage(Object msg) {
210         // Since we cannot use synchronous sendfile,
211         // the channel can only support DefaultFileRegion instead of FileRegion.
212         if (IoUring.isIOUringSpliceSupported() && msg instanceof DefaultFileRegion) {
213             return new IoUringFileRegion((DefaultFileRegion) msg);
214         }
215 
216         return super.filterOutboundMessage(msg);
217     }
218 
219     private final class IoUringStreamUnsafe extends AbstractUringUnsafe {
220 
221         private ByteBuf readBuffer;
222         private IovArray iovArray;
223 
224         @Override
225         protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
226             assert iovArray == null;
227             assert writeId == 0;
228             int numElements = Math.min(in.size(), Limits.IOV_MAX);
229             ByteBuf iovArrayBuffer = alloc().directBuffer(numElements * IovArray.IOV_SIZE);
230             iovArray = new IovArray(iovArrayBuffer);
231             try {
232                 int offset = iovArray.count();
233                 in.forEachFlushedMessage(iovArray);
234 
235                 int fd = fd().intValue();
236                 IoUringIoRegistration registration = registration();
237                 IoUringIoOps ops = IoUringIoOps.newWritev(fd, (byte) 0, 0, iovArray.memoryAddress(offset),
238                         iovArray.count() - offset, nextOpsId());
239                 byte opCode = ops.opcode();
240                 writeId = registration.submit(ops);
241                 writeOpCode = opCode;
242                 if (writeId == 0) {
243                     return 0;
244                 }
245             } catch (Exception e) {
246                 iovArray.release();
247                 iovArray = null;
248 
249                 // This should never happen, anyway fallback to single write.
250                 scheduleWriteSingle(in.current());
251             }
252             return 1;
253         }
254 
255         @Override
256         protected int scheduleWriteSingle(Object msg) {
257             assert iovArray == null;
258             assert writeId == 0;
259 
260             int fd = fd().intValue();
261             IoUringIoRegistration registration = registration();
262             final IoUringIoOps ops;
263             if (msg instanceof IoUringFileRegion) {
264                 IoUringFileRegion fileRegion = (IoUringFileRegion) msg;
265                 try {
266                     fileRegion.open();
267                 } catch (IOException e) {
268                     this.handleWriteError(e);
269                     return 0;
270                 }
271                 ops = fileRegion.splice(fd);
272             } else {
273                 ByteBuf buf = (ByteBuf) msg;
274                 ops = IoUringIoOps.newWrite(fd, (byte) 0, 0,
275                         buf.memoryAddress() + buf.readerIndex(), buf.readableBytes(), nextOpsId());
276             }
277 
278             byte opCode = ops.opcode();
279             writeId = registration.submit(ops);
280             writeOpCode = opCode;
281             if (writeId == 0) {
282                 return 0;
283             }
284             return 1;
285         }
286 
287         @Override
288         protected int scheduleRead0(boolean first) {
289             assert readBuffer == null;
290             assert readId == 0;
291 
292             final IoUringRecvByteAllocatorHandle allocHandle = recvBufAllocHandle();
293             ByteBuf byteBuf = allocHandle.allocate(alloc());
294             try {
295                 allocHandle.attemptedBytesRead(byteBuf.writableBytes());
296                 int fd = fd().intValue();
297                 IoUringIoRegistration registration = registration();
298                 // Depending on if this is the first read or not we will use Native.MSG_DONTWAIT.
299                 // The idea is that if the socket is blocking we can do the first read in a blocking fashion
300                 // and so not need to also register POLLIN. As we can not 100 % sure if reads after the first will
301                 // be possible directly we schedule these with Native.MSG_DONTWAIT. This allows us to still be
302                 // able to signal the fireChannelReadComplete() in a timely manner and be consistent with other
303                 // transports.
304                 IoUringIoOps ops = IoUringIoOps.newRecv(fd, (byte) 0, first ? 0 : Native.MSG_DONTWAIT,
305                         byteBuf.memoryAddress() + byteBuf.writerIndex(), byteBuf.writableBytes(), nextOpsId());
306                 readId = registration.submit(ops);
307                 if (readId == 0) {
308                     return 0;
309                 }
310                 readBuffer = byteBuf;
311                 byteBuf = null;
312                 return 1;
313             } finally {
314                 if (byteBuf != null) {
315                     byteBuf.release();
316                 }
317             }
318         }
319 
320         @Override
321         protected void readComplete0(byte op, int res, int flags, short data, int outstanding) {
322             assert readId != 0;
323             readId = 0;
324             boolean allDataRead = false;
325 
326             final IoUringRecvByteAllocatorHandle allocHandle = recvBufAllocHandle();
327             final ChannelPipeline pipeline = pipeline();
328             ByteBuf byteBuf = this.readBuffer;
329             this.readBuffer = null;
330             assert byteBuf != null;
331 
332             try {
333                 if (res < 0) {
334                     if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
335                         byteBuf.release();
336                         return;
337                     }
338                     // If res is negative we should pass it to ioResult(...) which will either throw
339                     // or convert it to 0 if we could not read because the socket was not readable.
340                     allocHandle.lastBytesRead(ioResult("io_uring read", res));
341                 } else if (res > 0) {
342                     byteBuf.writerIndex(byteBuf.writerIndex() + res);
343                     allocHandle.lastBytesRead(res);
344                 } else {
345                     // EOF which we signal with -1.
346                     allocHandle.lastBytesRead(-1);
347                 }
348                 if (allocHandle.lastBytesRead() <= 0) {
349                     // nothing was read, release the buffer.
350                     byteBuf.release();
351                     byteBuf = null;
352                     allDataRead = allocHandle.lastBytesRead() < 0;
353                     if (allDataRead) {
354                         // There is nothing left to read as we received an EOF.
355                         shutdownInput(true);
356                     }
357                     allocHandle.readComplete();
358                     pipeline.fireChannelReadComplete();
359                     return;
360                 }
361 
362                 allocHandle.incMessagesRead(1);
363                 pipeline.fireChannelRead(byteBuf);
364                 byteBuf = null;
365                 if (allocHandle.continueReading() &&
366                         // If IORING_CQE_F_SOCK_NONEMPTY is supported we should check for it first before
367                         // trying to schedule a read. If it's supported and not part of the flags we know for sure
368                         // that the next read (which would be using Native.MSG_DONTWAIT) will complete without
369                         // be able to read any data. This is useless work and we can skip it.
370                         (!IoUring.isIOUringCqeFSockNonEmptySupported() ||
371                         (flags & Native.IORING_CQE_F_SOCK_NONEMPTY) != 0)) {
372                     // Let's schedule another read.
373                     scheduleRead(false);
374                 } else {
375                     // We did not fill the whole ByteBuf so we should break the "read loop" and try again later.
376                     allocHandle.readComplete();
377                     pipeline.fireChannelReadComplete();
378                 }
379             } catch (Throwable t) {
380                 handleReadException(pipeline, byteBuf, t, allDataRead, allocHandle);
381             }
382         }
383 
384         private void handleReadException(ChannelPipeline pipeline, ByteBuf byteBuf,
385                                          Throwable cause, boolean allDataRead,
386                                          IoUringRecvByteAllocatorHandle allocHandle) {
387             if (byteBuf != null) {
388                 if (byteBuf.isReadable()) {
389                     pipeline.fireChannelRead(byteBuf);
390                 } else {
391                     byteBuf.release();
392                 }
393             }
394             allocHandle.readComplete();
395             pipeline.fireChannelReadComplete();
396             pipeline.fireExceptionCaught(cause);
397             if (allDataRead || cause instanceof IOException) {
398                 shutdownInput(true);
399             }
400         }
401 
402         @Override
403         boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
404             assert writeId != 0;
405             writeId = 0;
406             writeOpCode = 0;
407 
408             ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
409             Object current = channelOutboundBuffer.current();
410             if (current instanceof IoUringFileRegion) {
411                 IoUringFileRegion fileRegion = (IoUringFileRegion) current;
412                 try {
413                     int result = res >= 0 ? res : ioResult("io_uring splice", res);
414                     if (result == 0 && fileRegion.count() > 0) {
415                         validateFileRegion(fileRegion.fileRegion, fileRegion.transfered());
416                         return false;
417                     }
418                     int progress = fileRegion.handleResult(result, data);
419                     if (progress == -1) {
420                         // Done with writing
421                         channelOutboundBuffer.remove();
422                     } else if (progress > 0) {
423                         channelOutboundBuffer.progress(progress);
424                     }
425                 } catch (Throwable cause) {
426                     handleWriteError(cause);
427                 }
428                 return true;
429             }
430 
431             IovArray iovArray = this.iovArray;
432             if (iovArray != null) {
433                 this.iovArray = null;
434                 iovArray.release();
435             }
436             if (res >= 0) {
437                 channelOutboundBuffer.removeBytes(res);
438             } else if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
439                 return true;
440             } else {
441                 try {
442                     if (ioResult("io_uring write", res) == 0) {
443                         return false;
444                     }
445                 } catch (Throwable cause) {
446                     handleWriteError(cause);
447                 }
448             }
449             return true;
450         }
451 
452         @Override
453         protected void freeResourcesNow(IoUringIoRegistration reg) {
454             super.freeResourcesNow(reg);
455             assert readBuffer == null;
456         }
457     }
458 
459     @Override
460     protected final void cancelOutstandingReads(IoUringIoRegistration registration, int numOutstandingReads) {
461         if (readId != 0) {
462             // Let's try to cancel outstanding reads as these might be submitted and waiting for data (via fastpoll).
463             assert numOutstandingReads == 1;
464             int fd = fd().intValue();
465             IoUringIoOps ops = IoUringIoOps.newAsyncCancel(fd, (byte) 0, readId, Native.IORING_OP_RECV);
466             registration.submit(ops);
467         } else {
468             assert numOutstandingReads == 0;
469         }
470     }
471 
472     @Override
473     protected final void cancelOutstandingWrites(IoUringIoRegistration registration, int numOutstandingWrites) {
474         if (writeId != 0) {
475             // Let's try to cancel outstanding writes as these might be submitted and waiting to finish writing
476             // (via fastpoll).
477             assert numOutstandingWrites == 1;
478             assert writeOpCode != 0;
479             int fd = fd().intValue();
480             registration.submit(IoUringIoOps.newAsyncCancel(fd, (byte) 0, writeId, writeOpCode));
481         } else {
482             assert numOutstandingWrites == 0;
483         }
484     }
485 }