View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.uring;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelOutboundBuffer;
21  import io.netty.channel.socket.ServerSocketChannel;
22  import io.netty.channel.socket.SocketChannel;
23  import io.netty.channel.socket.SocketChannelConfig;
24  import io.netty.channel.unix.IovArray;
25  
26  import java.net.InetSocketAddress;
27  import java.net.SocketAddress;
28  import java.util.ArrayDeque;
29  import java.util.Queue;
30  
31  import static io.netty.channel.unix.Errors.ioResult;
32  
33  public final class IoUringSocketChannel extends AbstractIoUringStreamChannel implements SocketChannel {
34      private final IoUringSocketChannelConfig config;
35  
36      public IoUringSocketChannel() {
37         super(null, LinuxSocket.newSocketStream(), false);
38         this.config = new IoUringSocketChannelConfig(this);
39      }
40  
41      IoUringSocketChannel(Channel parent, LinuxSocket fd) {
42          super(parent, fd, true);
43          this.config = new IoUringSocketChannelConfig(this);
44      }
45  
46      IoUringSocketChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
47          super(parent, fd, remote);
48          this.config = new IoUringSocketChannelConfig(this);
49      }
50  
51      @Override
52      public ServerSocketChannel parent() {
53          return (ServerSocketChannel) super.parent();
54      }
55  
56      @Override
57      public SocketChannelConfig config() {
58          return config;
59      }
60  
61      @Override
62      public InetSocketAddress remoteAddress() {
63          return (InetSocketAddress) super.remoteAddress();
64      }
65  
66      @Override
67      public InetSocketAddress localAddress() {
68          return (InetSocketAddress) super.localAddress();
69      }
70  
71      @Override
72      protected AbstractUringUnsafe newUnsafe() {
73          return new IoUringSocketUnsafe();
74      }
75  
76      // Marker object that is used to mark a batch of buffers that were used with zero-copy write operations.
77      private static final Object ZC_BATCH_MARKER = new Object();
78  
79      private final class IoUringSocketUnsafe extends IoUringStreamUnsafe {
80          /**
81           * Queue that holds buffers that we can't release yet as the kernel still holds a reference to these.
82           */
83          private Queue<Object> zcWriteQueue;
84  
85          @Override
86          protected int scheduleWriteSingle(Object msg) {
87              assert writeId == 0;
88  
89              if (IoUring.isSendZcSupported() && msg instanceof ByteBuf) {
90                  ByteBuf buf = (ByteBuf) msg;
91                  int length = buf.readableBytes();
92                  if (((IoUringSocketChannelConfig) config()).shouldWriteZeroCopy(length)) {
93                      long address = IoUring.memoryAddress(buf) + buf.readerIndex();
94                      IoUringIoOps ops = IoUringIoOps.newSendZc(fd().intValue(), address, length, 0, nextOpsId(), 0);
95                      byte opCode = ops.opcode();
96                      writeId = registration().submit(ops);
97                      writeOpCode = opCode;
98                      if (writeId == 0) {
99                          return 0;
100                     }
101                     return 1;
102                 }
103                 // Should not use send_zc, just use normal write.
104             }
105             return super.scheduleWriteSingle(msg);
106         }
107 
108         @Override
109         protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
110             assert writeId == 0;
111 
112             IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
113             //at least one buffer in the batch exceeds `IO_URING_WRITE_ZERO_COPY_THRESHOLD`.
114             if (IoUring.isSendmsgZcSupported()
115                     && (ioUringSocketChannelConfig.shouldWriteZeroCopy(((ByteBuf) in.current()).readableBytes()))) {
116                 IoUringIoHandler handler = registration().attachment();
117 
118                 IovArray iovArray = handler.iovArray();
119                 int offset = iovArray.count();
120                 // Limit to the maximum number of fragments to ensure we don't get an error when we have too many
121                 // buffers.
122                 iovArray.maxCount(Native.MAX_SKB_FRAGS);
123                 try {
124                     in.forEachFlushedMessage(new ChannelOutboundBuffer.MessageProcessor() {
125                         @Override
126                         public boolean processMessage(Object msg) throws Exception {
127                             if (msg instanceof ByteBuf) {
128                                 ByteBuf buf = (ByteBuf) msg;
129                                 int length = buf.readableBytes();
130                                 if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
131                                     return iovArray.processMessage(msg);
132                                 }
133                             }
134                             return false;
135                         }
136                     });
137                 } catch (Exception e) {
138                     // This should never happen, anyway fallback to single write.
139                     return scheduleWriteSingle(in.current());
140                 }
141                 long iovArrayAddress = iovArray.memoryAddress(offset);
142                 int iovArrayLength = iovArray.count() - offset;
143 
144                 MsgHdrMemoryArray msgHdrArray = handler.msgHdrMemoryArray();
145                 MsgHdrMemory hdr = msgHdrArray.nextHdr();
146                 assert hdr != null;
147                 hdr.set(iovArrayAddress, iovArrayLength);
148                 IoUringIoOps ops = IoUringIoOps.newSendmsgZc(fd().intValue(), (byte) 0, 0, hdr.address(), nextOpsId());
149                 byte opCode = ops.opcode();
150                 writeId = registration().submit(ops);
151                 writeOpCode = opCode;
152                 if (writeId == 0) {
153                     return 0;
154                 }
155                 return 1;
156             }
157             // Should not use sendmsg_zc, just use normal writev.
158             return super.scheduleWriteMultiple(in);
159         }
160 
161         @Override
162         protected ChannelOutboundBuffer.MessageProcessor filterWriteMultiple(IovArray iovArray) {
163             if (!IoUring.isSendmsgZcSupported()) {
164                 return super.filterWriteMultiple(iovArray);
165             }
166             IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
167             return new ChannelOutboundBuffer.MessageProcessor() {
168                 @Override
169                 public boolean processMessage(Object msg) throws Exception {
170                     if (msg instanceof ByteBuf) {
171                         ByteBuf buf = (ByteBuf) msg;
172                         int length = buf.readableBytes();
173                         if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
174                             return false;
175                         }
176                     }
177                     return iovArray.processMessage(msg);
178                 }
179             };
180         }
181 
182         @Override
183         boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
184             ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
185             if (op == Native.IORING_OP_SEND_ZC || op == Native.IORING_OP_SENDMSG_ZC) {
186                 return handleWriteCompleteZeroCopy(op, channelOutboundBuffer, res, flags);
187             }
188             return super.writeComplete0(op, res, flags, data, outstanding);
189         }
190 
191         private boolean handleWriteCompleteZeroCopy(byte op, ChannelOutboundBuffer channelOutboundBuffer,
192                                                     int res, int flags) {
193             if ((flags & Native.IORING_CQE_F_NOTIF) == 0) {
194                 // We only want to reset these if IORING_CQE_F_NOTIF is not set.
195                 // If it's set we know this is only an extra notification for a write but we already handled
196                 // the write completions before.
197                 // See https://man7.org/linux/man-pages/man2/io_uring_enter.2.html section: IORING_OP_SEND_ZC
198                 writeId = 0;
199                 writeOpCode = 0;
200 
201                 boolean more = (flags & Native.IORING_CQE_F_MORE) != 0;
202                 if (more) {
203                     // This is the result of send_sz or sendmsg_sc but there will also be another notification
204                     // which will let us know that we can release the buffer(s). In this case let's retain the
205                     // buffer(s) once and store it in an internal queue. Once we receive the notification we will
206                     // call release() on the buffer(s) as it's not used by the kernel anymore.
207                     if (zcWriteQueue == null) {
208                         zcWriteQueue = new ArrayDeque<>(8);
209                     }
210                 }
211                 if (res >= 0) {
212                     if (more) {
213 
214                         // Loop through all the buffers that were part of the operation so we can add them to our
215                         // internal queue to release later.
216                         do {
217                             ByteBuf currentBuffer = (ByteBuf) channelOutboundBuffer.current();
218                             assert currentBuffer != null;
219                             zcWriteQueue.add(currentBuffer);
220                             currentBuffer.retain();
221                             int readable = currentBuffer.readableBytes();
222                             int skip = Math.min(readable, res);
223                             currentBuffer.skipBytes(skip);
224                             channelOutboundBuffer.progress(readable);
225                             if (readable <= res) {
226                                 boolean removed = channelOutboundBuffer.remove();
227                                 assert removed;
228                             }
229                             res -= readable;
230                         } while (res > 0);
231                         // Add the marker so we know when we need to stop releasing
232                         zcWriteQueue.add(ZC_BATCH_MARKER);
233                     } else {
234                         // We don't expect any extra notification, just directly let the buffer be released.
235                         channelOutboundBuffer.removeBytes(res);
236                     }
237                     return true;
238                 } else {
239                     if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
240                         if (more) {
241                             // The send was cancelled but we expect another notification. Just add the marker to the
242                             // queue so we don't get into trouble once the final notification for this operation is
243                             // received.
244                             zcWriteQueue.add(ZC_BATCH_MARKER);
245                         }
246                         return true;
247                     }
248                     try {
249                         String msg = op == Native.IORING_OP_SEND_ZC ? "io_uring sendzc" : "io_uring sendmsg_zc";
250                         int result = ioResult(msg, res);
251                         if (more) {
252                             try {
253                                 // We expect another notification so we need to ensure we retain these buffers
254                                 // so we can release these once we see IORING_CQE_F_NOTIF set.
255                                 addFlushedToZcWriteQueue(channelOutboundBuffer);
256                             } catch (Exception e) {
257                                 // should never happen but let's handle it anyway.
258                                 handleWriteError(e);
259                             }
260                         }
261                         if (result == 0) {
262                             return false;
263                         }
264                     } catch (Throwable cause) {
265                         if (more) {
266                             try {
267                                 // We expect another notification as handleWriteError(...) will fail all flushed writes
268                                 // and also release any buffers we need to ensure we retain these buffers
269                                 // so we can release these once we see IORING_CQE_F_NOTIF set.
270                                 addFlushedToZcWriteQueue(channelOutboundBuffer);
271                             } catch (Exception e) {
272                                 // should never happen but let's handle it anyway.
273                                 cause.addSuppressed(e);
274                             }
275                         }
276                         handleWriteError(cause);
277                     }
278                 }
279             } else {
280                 if (zcWriteQueue != null) {
281                     for (;;) {
282                         Object queued = zcWriteQueue.remove();
283                         assert queued != null;
284                         if (queued == ZC_BATCH_MARKER) {
285                             // Done releasing the buffers of the zero-copy batch.
286                             break;
287                         }
288                         // The buffer can now be released.
289                         ((ByteBuf) queued).release();
290                     }
291                 }
292             }
293             return true;
294         }
295 
296         private void addFlushedToZcWriteQueue(ChannelOutboundBuffer channelOutboundBuffer) throws Exception {
297             // We expect another notification as handleWriteError(...) will fail all flushed writes
298             // and also release any buffers we need to ensure we retain these buffers
299             // so we can release these once we see IORING_CQE_F_NOTIF set.
300             try {
301                 channelOutboundBuffer.forEachFlushedMessage(m -> {
302                     if (!(m instanceof ByteBuf)) {
303                         return false;
304                     }
305                     zcWriteQueue.add(m);
306                     ((ByteBuf) m).retain();
307                     return true;
308                 });
309             } finally {
310                 zcWriteQueue.add(ZC_BATCH_MARKER);
311             }
312         }
313     }
314 }