View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.uring;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelOutboundBuffer;
21  import io.netty.channel.socket.ServerSocketChannel;
22  import io.netty.channel.socket.SocketChannel;
23  import io.netty.channel.socket.SocketChannelConfig;
24  import io.netty.channel.unix.IovArray;
25  
26  import java.net.InetSocketAddress;
27  import java.net.SocketAddress;
28  import java.util.ArrayDeque;
29  import java.util.Queue;
30  
31  import static io.netty.channel.unix.Errors.ioResult;
32  
33  public final class IoUringSocketChannel extends AbstractIoUringStreamChannel implements SocketChannel {
34      private final IoUringSocketChannelConfig config;
35  
36      public IoUringSocketChannel() {
37         super(null, LinuxSocket.newSocketStream(), false);
38         this.config = new IoUringSocketChannelConfig(this);
39      }
40  
41      IoUringSocketChannel(Channel parent, LinuxSocket fd) {
42          super(parent, fd, true);
43          this.config = new IoUringSocketChannelConfig(this);
44      }
45  
46      IoUringSocketChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
47          super(parent, fd, remote);
48          this.config = new IoUringSocketChannelConfig(this);
49      }
50  
51      @Override
52      public ServerSocketChannel parent() {
53          return (ServerSocketChannel) super.parent();
54      }
55  
56      @Override
57      public SocketChannelConfig config() {
58          return config;
59      }
60  
61      @Override
62      public InetSocketAddress remoteAddress() {
63          return (InetSocketAddress) super.remoteAddress();
64      }
65  
66      @Override
67      public InetSocketAddress localAddress() {
68          return (InetSocketAddress) super.localAddress();
69      }
70  
71      @Override
72      protected AbstractUringUnsafe newUnsafe() {
73          return new IoUringSocketUnsafe();
74      }
75  
76      // Marker object that is used to mark a batch of buffers that were used with zero-copy write operations.
77      private static final Object ZC_BATCH_MARKER = new Object();
78  
79      private final class IoUringSocketUnsafe extends IoUringStreamUnsafe {
80          /**
81           * Queue that holds buffers that we can't release yet as the kernel still holds a reference to these.
82           */
83          private Queue<Object> zcWriteQueue;
84  
85          @Override
86          protected int scheduleWriteSingle(Object msg) {
87              assert writeId == 0;
88  
89              if (IoUring.isSendZcSupported() && msg instanceof ByteBuf) {
90                  ByteBuf buf = (ByteBuf) msg;
91                  int length = buf.readableBytes();
92                  if (((IoUringSocketChannelConfig) config()).shouldWriteZeroCopy(length)) {
93                      long address = IoUring.memoryAddress(buf) + buf.readerIndex();
94                      IoUringIoOps ops = IoUringIoOps.newSendZc(fd().intValue(), address, length, 0, nextOpsId(), 0);
95                      byte opCode = ops.opcode();
96                      writeId = registration().submit(ops);
97                      writeOpCode = opCode;
98                      if (writeId == 0) {
99                          return 0;
100                     }
101                     return 1;
102                 }
103                 // Should not use send_zc, just use normal write.
104             }
105             return super.scheduleWriteSingle(msg);
106         }
107 
108         @Override
109         protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
110             assert writeId == 0;
111 
112             IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
113             //at least one buffer in the batch exceeds `IO_URING_WRITE_ZERO_COPY_THRESHOLD`.
114             if (IoUring.isSendmsgZcSupported()
115                     && (ioUringSocketChannelConfig.shouldWriteZeroCopy(((ByteBuf) in.current()).readableBytes()))) {
116                 IoUringIoHandler handler = registration().attachment();
117 
118                 IovArray iovArray = handler.iovArray();
119                 int offset = iovArray.count();
120                 // Limit to the maximum number of fragments to ensure we don't get an error when we have too many
121                 // buffers.
122                 iovArray.maxCount(Native.MAX_SKB_FRAGS);
123                 try {
124                     in.forEachFlushedMessage(new ChannelOutboundBuffer.MessageProcessor() {
125                         @Override
126                         public boolean processMessage(Object msg) throws Exception {
127                             if (msg instanceof ByteBuf) {
128                                 ByteBuf buf = (ByteBuf) msg;
129                                 int length = buf.readableBytes();
130                                 if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
131                                     return iovArray.processMessage(msg);
132                                 }
133                             }
134                             return false;
135                         }
136                     });
137                 } catch (Exception e) {
138                     // This should never happen, anyway fallback to single write.
139                     return scheduleWriteSingle(in.current());
140                 }
141                 long iovArrayAddress = iovArray.memoryAddress(offset);
142                 int iovArrayLength = iovArray.count() - offset;
143 
144                 MsgHdrMemoryArray msgHdrArray = handler.msgHdrMemoryArray();
145                 MsgHdrMemory hdr = msgHdrArray.nextHdr();
146                 assert hdr != null;
147                 hdr.set(iovArrayAddress, iovArrayLength);
148                 IoUringIoOps ops = IoUringIoOps.newSendmsgZc(fd().intValue(), (byte) 0, 0, hdr.address(), nextOpsId());
149                 byte opCode = ops.opcode();
150                 writeId = registration().submit(ops);
151                 writeOpCode = opCode;
152                 if (writeId == 0) {
153                     return 0;
154                 }
155                 return 1;
156             }
157             // Should not use sendmsg_zc, just use normal writev.
158             return super.scheduleWriteMultiple(in);
159         }
160 
161         @Override
162         protected ChannelOutboundBuffer.MessageProcessor filterWriteMultiple(IovArray iovArray) {
163             IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
164             return new ChannelOutboundBuffer.MessageProcessor() {
165                 @Override
166                 public boolean processMessage(Object msg) throws Exception {
167                     if (msg instanceof ByteBuf) {
168                         ByteBuf buf = (ByteBuf) msg;
169                         int length = buf.readableBytes();
170                         if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
171                             return false;
172                         }
173                     }
174                     return iovArray.processMessage(msg);
175                 }
176             };
177         }
178 
179         @Override
180         boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
181             ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
182             if (op == Native.IORING_OP_SEND_ZC || op == Native.IORING_OP_SENDMSG_ZC) {
183                 return handleWriteCompleteZeroCopy(op, channelOutboundBuffer, res, flags);
184             }
185             return super.writeComplete0(op, res, flags, data, outstanding);
186         }
187 
188         private boolean handleWriteCompleteZeroCopy(byte op, ChannelOutboundBuffer channelOutboundBuffer,
189                                                     int res, int flags) {
190             if ((flags & Native.IORING_CQE_F_NOTIF) == 0) {
191                 // We only want to reset these if IORING_CQE_F_NOTIF is not set.
192                 // If it's set we know this is only an extra notification for a write but we already handled
193                 // the write completions before.
194                 // See https://man7.org/linux/man-pages/man2/io_uring_enter.2.html section: IORING_OP_SEND_ZC
195                 writeId = 0;
196                 writeOpCode = 0;
197 
198                 boolean more = (flags & Native.IORING_CQE_F_MORE) != 0;
199                 if (more) {
200                     // This is the result of send_sz or sendmsg_sc but there will also be another notification
201                     // which will let us know that we can release the buffer(s). In this case let's retain the
202                     // buffer(s) once and store it in an internal queue. Once we receive the notification we will
203                     // call release() on the buffer(s) as it's not used by the kernel anymore.
204                     if (zcWriteQueue == null) {
205                         zcWriteQueue = new ArrayDeque<>(8);
206                     }
207                 }
208                 if (res >= 0) {
209                     if (more) {
210 
211                         // Loop through all the buffers that were part of the operation so we can add them to our
212                         // internal queue to release later.
213                         do {
214                             ByteBuf currentBuffer = (ByteBuf) channelOutboundBuffer.current();
215                             assert currentBuffer != null;
216                             zcWriteQueue.add(currentBuffer);
217                             currentBuffer.retain();
218                             int readable = currentBuffer.readableBytes();
219                             int skip = Math.min(readable, res);
220                             currentBuffer.skipBytes(skip);
221                             channelOutboundBuffer.progress(readable);
222                             if (readable <= res) {
223                                 boolean removed = channelOutboundBuffer.remove();
224                                 assert removed;
225                             }
226                             res -= readable;
227                         } while (res > 0);
228                         // Add the marker so we know when we need to stop releasing
229                         zcWriteQueue.add(ZC_BATCH_MARKER);
230                     } else {
231                         // We don't expect any extra notification, just directly let the buffer be released.
232                         channelOutboundBuffer.removeBytes(res);
233                     }
234                     return true;
235                 } else {
236                     if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
237                         if (more) {
238                             // The send was cancelled but we expect another notification. Just add the marker to the
239                             // queue so we don't get into trouble once the final notification for this operation is
240                             // received.
241                             zcWriteQueue.add(ZC_BATCH_MARKER);
242                         }
243                         return true;
244                     }
245                     try {
246                         String msg = op == Native.IORING_OP_SEND_ZC ? "io_uring sendzc" : "io_uring sendmsg_zc";
247                         int result = ioResult(msg, res);
248                         if (more) {
249                             try {
250                                 // We expect another notification so we need to ensure we retain these buffers
251                                 // so we can release these once we see IORING_CQE_F_NOTIF set.
252                                 addFlushedToZcWriteQueue(channelOutboundBuffer);
253                             } catch (Exception e) {
254                                 // should never happen but let's handle it anyway.
255                                 handleWriteError(e);
256                             }
257                         }
258                         if (result == 0) {
259                             return false;
260                         }
261                     } catch (Throwable cause) {
262                         if (more) {
263                             try {
264                                 // We expect another notification as handleWriteError(...) will fail all flushed writes
265                                 // and also release any buffers we need to ensure we retain these buffers
266                                 // so we can release these once we see IORING_CQE_F_NOTIF set.
267                                 addFlushedToZcWriteQueue(channelOutboundBuffer);
268                             } catch (Exception e) {
269                                 // should never happen but let's handle it anyway.
270                                 cause.addSuppressed(e);
271                             }
272                         }
273                         handleWriteError(cause);
274                     }
275                 }
276             } else {
277                 if (zcWriteQueue != null) {
278                     for (;;) {
279                         Object queued = zcWriteQueue.remove();
280                         assert queued != null;
281                         if (queued == ZC_BATCH_MARKER) {
282                             // Done releasing the buffers of the zero-copy batch.
283                             break;
284                         }
285                         // The buffer can now be released.
286                         ((ByteBuf) queued).release();
287                     }
288                 }
289             }
290             return true;
291         }
292 
293         private void addFlushedToZcWriteQueue(ChannelOutboundBuffer channelOutboundBuffer) throws Exception {
294             // We expect another notification as handleWriteError(...) will fail all flushed writes
295             // and also release any buffers we need to ensure we retain these buffers
296             // so we can release these once we see IORING_CQE_F_NOTIF set.
297             try {
298                 channelOutboundBuffer.forEachFlushedMessage(m -> {
299                     if (!(m instanceof ByteBuf)) {
300                         return false;
301                     }
302                     zcWriteQueue.add(m);
303                     ((ByteBuf) m).retain();
304                     return true;
305                 });
306             } finally {
307                 zcWriteQueue.add(ZC_BATCH_MARKER);
308             }
309         }
310     }
311 }