View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.uring;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelOutboundBuffer;
21  import io.netty.channel.socket.ServerSocketChannel;
22  import io.netty.channel.socket.SocketChannel;
23  import io.netty.channel.socket.SocketChannelConfig;
24  import io.netty.channel.unix.IovArray;
25  
26  import java.net.InetSocketAddress;
27  import java.net.SocketAddress;
28  import java.util.ArrayDeque;
29  import java.util.Queue;
30  
31  import static io.netty.channel.unix.Errors.ioResult;
32  
33  public final class IoUringSocketChannel extends AbstractIoUringStreamChannel implements SocketChannel {
34      private final IoUringSocketChannelConfig config;
35  
36      public IoUringSocketChannel() {
37         super(null, LinuxSocket.newSocketStream(), false);
38         this.config = new IoUringSocketChannelConfig(this);
39      }
40  
41      IoUringSocketChannel(Channel parent, LinuxSocket fd) {
42          super(parent, fd, true);
43          this.config = new IoUringSocketChannelConfig(this);
44      }
45  
46      IoUringSocketChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
47          super(parent, fd, remote);
48          this.config = new IoUringSocketChannelConfig(this);
49      }
50  
51      @Override
52      public ServerSocketChannel parent() {
53          return (ServerSocketChannel) super.parent();
54      }
55  
56      @Override
57      public SocketChannelConfig config() {
58          return config;
59      }
60  
61      @Override
62      public InetSocketAddress remoteAddress() {
63          return (InetSocketAddress) super.remoteAddress();
64      }
65  
66      @Override
67      public InetSocketAddress localAddress() {
68          return (InetSocketAddress) super.localAddress();
69      }
70  
71      @Override
72      protected AbstractUringUnsafe newUnsafe() {
73          return new IoUringSocketUnsafe();
74      }
75  
76      // Marker object that is used to mark a batch of buffers that were used with zero-copy write operations.
77      private static final Object ZC_BATCH_MARKER = new Object();
78  
79      private final class IoUringSocketUnsafe extends IoUringStreamUnsafe {
80          /**
81           * Queue that holds buffers that we can't release yet as the kernel still holds a reference to these.
82           */
83          private Queue<Object> zcWriteQueue;
84  
85          @Override
86          protected int scheduleWriteSingle(Object msg) {
87              assert writeId == 0;
88  
89              if (IoUring.isSendZcSupported() && msg instanceof ByteBuf) {
90                  ByteBuf buf = (ByteBuf) msg;
91                  int length = buf.readableBytes();
92                  if (((IoUringSocketChannelConfig) config()).shouldWriteZeroCopy(length)) {
93                      long address = IoUring.memoryAddress(buf) + buf.readerIndex();
94                      IoUringIoOps ops = IoUringIoOps.newSendZc(fd().intValue(), address, length, 0, nextOpsId(), 0);
95                      byte opCode = ops.opcode();
96                      writeId = registration().submit(ops);
97                      writeOpCode = opCode;
98                      if (writeId == 0) {
99                          return 0;
100                     }
101                     return 1;
102                 }
103                 // Should not use send_zc, just use normal write.
104             }
105             return super.scheduleWriteSingle(msg);
106         }
107 
108         @Override
109         protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
110             assert writeId == 0;
111 
112             if (IoUring.isSendmsgZcSupported() && (
113                     (IoUringSocketChannelConfig) config()).shouldWriteZeroCopy((int) in.totalPendingWriteBytes())) {
114                 IoUringIoHandler handler = registration().attachment();
115 
116                 IovArray iovArray = handler.iovArray();
117                 int offset = iovArray.count();
118                 // Limit to the maximum number of fragments to ensure we don't get an error when we have too many
119                 // buffers.
120                 iovArray.maxCount(Native.MAX_SKB_FRAGS);
121                 try {
122                     in.forEachFlushedMessage(iovArray);
123                 } catch (Exception e) {
124                     // This should never happen, anyway fallback to single write.
125                     return scheduleWriteSingle(in.current());
126                 }
127                 long iovArrayAddress = iovArray.memoryAddress(offset);
128                 int iovArrayLength = iovArray.count() - offset;
129 
130                 MsgHdrMemoryArray msgHdrArray = handler.msgHdrMemoryArray();
131                 MsgHdrMemory hdr = msgHdrArray.nextHdr();
132                 assert hdr != null;
133                 hdr.set(iovArrayAddress, iovArrayLength);
134                 IoUringIoOps ops = IoUringIoOps.newSendmsgZc(fd().intValue(), (byte) 0, 0, hdr.address(), nextOpsId());
135                 byte opCode = ops.opcode();
136                 writeId = registration().submit(ops);
137                 writeOpCode = opCode;
138                 if (writeId == 0) {
139                     return 0;
140                 }
141                 return 1;
142             }
143             // Should not use sendmsg_zc, just use normal writev.
144             return super.scheduleWriteMultiple(in);
145         }
146 
147         @Override
148         boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
149             ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
150             if (op == Native.IORING_OP_SEND_ZC || op == Native.IORING_OP_SENDMSG_ZC) {
151                 return handleWriteCompleteZeroCopy(op, channelOutboundBuffer, res, flags);
152             }
153             return super.writeComplete0(op, res, flags, data, outstanding);
154         }
155 
156         private boolean handleWriteCompleteZeroCopy(byte op, ChannelOutboundBuffer channelOutboundBuffer,
157                                                     int res, int flags) {
158             if ((flags & Native.IORING_CQE_F_NOTIF) == 0) {
159                 // We only want to reset these if IORING_CQE_F_NOTIF is not set.
160                 // If it's set we know this is only an extra notification for a write but we already handled
161                 // the write completions before.
162                 // See https://man7.org/linux/man-pages/man2/io_uring_enter.2.html section: IORING_OP_SEND_ZC
163                 writeId = 0;
164                 writeOpCode = 0;
165 
166                 boolean more = (flags & Native.IORING_CQE_F_MORE) != 0;
167                 if (more) {
168                     // This is the result of send_sz or sendmsg_sc but there will also be another notification
169                     // which will let us know that we can release the buffer(s). In this case let's retain the
170                     // buffer(s) once and store it in an internal queue. Once we receive the notification we will
171                     // call release() on the buffer(s) as it's not used by the kernel anymore.
172                     if (zcWriteQueue == null) {
173                         zcWriteQueue = new ArrayDeque<>(8);
174                     }
175                 }
176                 if (res >= 0) {
177                     if (more) {
178 
179                         // Loop through all the buffers that were part of the operation so we can add them to our
180                         // internal queue to release later.
181                         do {
182                             ByteBuf currentBuffer = (ByteBuf) channelOutboundBuffer.current();
183                             assert currentBuffer != null;
184                             zcWriteQueue.add(currentBuffer);
185                             currentBuffer.retain();
186                             int readable = currentBuffer.readableBytes();
187                             int skip = Math.min(readable, res);
188                             currentBuffer.skipBytes(skip);
189                             channelOutboundBuffer.progress(readable);
190                             if (readable <= res) {
191                                 boolean removed = channelOutboundBuffer.remove();
192                                 assert removed;
193                             }
194                             res -= readable;
195                         } while (res > 0);
196                         // Add the marker so we know when we need to stop releasing
197                         zcWriteQueue.add(ZC_BATCH_MARKER);
198                     } else {
199                         // We don't expect any extra notification, just directly let the buffer be released.
200                         channelOutboundBuffer.removeBytes(res);
201                     }
202                     return true;
203                 } else {
204                     if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
205                         if (more) {
206                             // The send was cancelled but we expect another notification. Just add the marker to the
207                             // queue so we don't get into trouble once the final notification for this operation is
208                             // received.
209                             zcWriteQueue.add(ZC_BATCH_MARKER);
210                         }
211                         return true;
212                     }
213                     try {
214                         String msg = op == Native.IORING_OP_SEND_ZC ? "io_uring sendzc" : "io_uring sendmsg_zc";
215                         int result = ioResult(msg, res);
216                         if (more) {
217                             try {
218                                 // We expect another notification so we need to ensure we retain these buffers
219                                 // so we can release these once we see IORING_CQE_F_NOTIF set.
220                                 addFlushedToZcWriteQueue(channelOutboundBuffer);
221                             } catch (Exception e) {
222                                 // should never happen but let's handle it anyway.
223                                 handleWriteError(e);
224                             }
225                         }
226                         if (result == 0) {
227                             return false;
228                         }
229                     } catch (Throwable cause) {
230                         if (more) {
231                             try {
232                                 // We expect another notification as handleWriteError(...) will fail all flushed writes
233                                 // and also release any buffers we need to ensure we retain these buffers
234                                 // so we can release these once we see IORING_CQE_F_NOTIF set.
235                                 addFlushedToZcWriteQueue(channelOutboundBuffer);
236                             } catch (Exception e) {
237                                 // should never happen but let's handle it anyway.
238                                 cause.addSuppressed(e);
239                             }
240                         }
241                         handleWriteError(cause);
242                     }
243                 }
244             } else {
245                 if (zcWriteQueue != null) {
246                     for (;;) {
247                         Object queued = zcWriteQueue.remove();
248                         assert queued != null;
249                         if (queued == ZC_BATCH_MARKER) {
250                             // Done releasing the buffers of the zero-copy batch.
251                             break;
252                         }
253                         // The buffer can now be released.
254                         ((ByteBuf) queued).release();
255                     }
256                 }
257             }
258             return true;
259         }
260 
261         private void addFlushedToZcWriteQueue(ChannelOutboundBuffer channelOutboundBuffer) throws Exception {
262             // We expect another notification as handleWriteError(...) will fail all flushed writes
263             // and also release any buffers we need to ensure we retain these buffers
264             // so we can release these once we see IORING_CQE_F_NOTIF set.
265             try {
266                 channelOutboundBuffer.forEachFlushedMessage(m -> {
267                     if (!(m instanceof ByteBuf)) {
268                         return false;
269                     }
270                     zcWriteQueue.add(m);
271                     ((ByteBuf) m).retain();
272                     return true;
273                 });
274             } finally {
275                 zcWriteQueue.add(ZC_BATCH_MARKER);
276             }
277         }
278     }
279 }