1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.uring;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelOutboundBuffer;
21 import io.netty.channel.socket.ServerSocketChannel;
22 import io.netty.channel.socket.SocketChannel;
23 import io.netty.channel.socket.SocketChannelConfig;
24 import io.netty.channel.unix.IovArray;
25
26 import java.net.InetSocketAddress;
27 import java.net.SocketAddress;
28 import java.util.ArrayDeque;
29 import java.util.Queue;
30
31 import static io.netty.channel.unix.Errors.ioResult;
32
33 public final class IoUringSocketChannel extends AbstractIoUringStreamChannel implements SocketChannel {
34 private final IoUringSocketChannelConfig config;
35
36 public IoUringSocketChannel() {
37 super(null, LinuxSocket.newSocketStream(), false);
38 this.config = new IoUringSocketChannelConfig(this);
39 }
40
41 IoUringSocketChannel(Channel parent, LinuxSocket fd) {
42 super(parent, fd, true);
43 this.config = new IoUringSocketChannelConfig(this);
44 }
45
46 IoUringSocketChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
47 super(parent, fd, remote);
48 this.config = new IoUringSocketChannelConfig(this);
49 }
50
51 @Override
52 public ServerSocketChannel parent() {
53 return (ServerSocketChannel) super.parent();
54 }
55
56 @Override
57 public SocketChannelConfig config() {
58 return config;
59 }
60
61 @Override
62 public InetSocketAddress remoteAddress() {
63 return (InetSocketAddress) super.remoteAddress();
64 }
65
66 @Override
67 public InetSocketAddress localAddress() {
68 return (InetSocketAddress) super.localAddress();
69 }
70
71 @Override
72 protected AbstractUringUnsafe newUnsafe() {
73 return new IoUringSocketUnsafe();
74 }
75
76
77 private static final Object ZC_BATCH_MARKER = new Object();
78
79 private final class IoUringSocketUnsafe extends IoUringStreamUnsafe {
80
81
82
83 private Queue<Object> zcWriteQueue;
84
85 @Override
86 protected int scheduleWriteSingle(Object msg) {
87 assert writeId == 0;
88
89 if (IoUring.isSendZcSupported() && msg instanceof ByteBuf) {
90 ByteBuf buf = (ByteBuf) msg;
91 int length = buf.readableBytes();
92 if (((IoUringSocketChannelConfig) config()).shouldWriteZeroCopy(length)) {
93 long address = IoUring.memoryAddress(buf) + buf.readerIndex();
94 IoUringIoOps ops = IoUringIoOps.newSendZc(fd().intValue(), address, length, 0, nextOpsId(), 0);
95 byte opCode = ops.opcode();
96 writeId = registration().submit(ops);
97 writeOpCode = opCode;
98 if (writeId == 0) {
99 return 0;
100 }
101 return 1;
102 }
103
104 }
105 return super.scheduleWriteSingle(msg);
106 }
107
108 @Override
109 protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
110 assert writeId == 0;
111
112 IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
113
114 if (IoUring.isSendmsgZcSupported()
115 && (ioUringSocketChannelConfig.shouldWriteZeroCopy(((ByteBuf) in.current()).readableBytes()))) {
116 IoUringIoHandler handler = registration().attachment();
117
118 IovArray iovArray = handler.iovArray();
119 int offset = iovArray.count();
120
121
122 iovArray.maxCount(Native.MAX_SKB_FRAGS);
123 try {
124 in.forEachFlushedMessage(new ChannelOutboundBuffer.MessageProcessor() {
125 @Override
126 public boolean processMessage(Object msg) throws Exception {
127 if (msg instanceof ByteBuf) {
128 ByteBuf buf = (ByteBuf) msg;
129 int length = buf.readableBytes();
130 if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
131 return iovArray.processMessage(msg);
132 }
133 }
134 return false;
135 }
136 });
137 } catch (Exception e) {
138
139 return scheduleWriteSingle(in.current());
140 }
141 long iovArrayAddress = iovArray.memoryAddress(offset);
142 int iovArrayLength = iovArray.count() - offset;
143
144 MsgHdrMemoryArray msgHdrArray = handler.msgHdrMemoryArray();
145 MsgHdrMemory hdr = msgHdrArray.nextHdr();
146 assert hdr != null;
147 hdr.set(iovArrayAddress, iovArrayLength);
148 IoUringIoOps ops = IoUringIoOps.newSendmsgZc(fd().intValue(), (byte) 0, 0, hdr.address(), nextOpsId());
149 byte opCode = ops.opcode();
150 writeId = registration().submit(ops);
151 writeOpCode = opCode;
152 if (writeId == 0) {
153 return 0;
154 }
155 return 1;
156 }
157
158 return super.scheduleWriteMultiple(in);
159 }
160
161 @Override
162 protected ChannelOutboundBuffer.MessageProcessor filterWriteMultiple(IovArray iovArray) {
163 if (!IoUring.isSendmsgZcSupported()) {
164 return super.filterWriteMultiple(iovArray);
165 }
166 IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
167 return new ChannelOutboundBuffer.MessageProcessor() {
168 @Override
169 public boolean processMessage(Object msg) throws Exception {
170 if (msg instanceof ByteBuf) {
171 ByteBuf buf = (ByteBuf) msg;
172 int length = buf.readableBytes();
173 if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
174 return false;
175 }
176 }
177 return iovArray.processMessage(msg);
178 }
179 };
180 }
181
182 @Override
183 boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
184 ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
185 if (op == Native.IORING_OP_SEND_ZC || op == Native.IORING_OP_SENDMSG_ZC) {
186 return handleWriteCompleteZeroCopy(op, channelOutboundBuffer, res, flags);
187 }
188 return super.writeComplete0(op, res, flags, data, outstanding);
189 }
190
191 private boolean handleWriteCompleteZeroCopy(byte op, ChannelOutboundBuffer channelOutboundBuffer,
192 int res, int flags) {
193 if ((flags & Native.IORING_CQE_F_NOTIF) == 0) {
194
195
196
197
198 writeId = 0;
199 writeOpCode = 0;
200
201 boolean more = (flags & Native.IORING_CQE_F_MORE) != 0;
202 if (more) {
203
204
205
206
207 if (zcWriteQueue == null) {
208 zcWriteQueue = new ArrayDeque<>(8);
209 }
210 }
211 if (res >= 0) {
212 if (more) {
213
214
215
216 do {
217 ByteBuf currentBuffer = (ByteBuf) channelOutboundBuffer.current();
218 assert currentBuffer != null;
219 zcWriteQueue.add(currentBuffer);
220 currentBuffer.retain();
221 int readable = currentBuffer.readableBytes();
222 int skip = Math.min(readable, res);
223 currentBuffer.skipBytes(skip);
224 channelOutboundBuffer.progress(readable);
225 if (readable <= res) {
226 boolean removed = channelOutboundBuffer.remove();
227 assert removed;
228 }
229 res -= readable;
230 } while (res > 0);
231
232 zcWriteQueue.add(ZC_BATCH_MARKER);
233 } else {
234
235 channelOutboundBuffer.removeBytes(res);
236 }
237 return true;
238 } else {
239 if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
240 if (more) {
241
242
243
244 zcWriteQueue.add(ZC_BATCH_MARKER);
245 }
246 return true;
247 }
248 try {
249 String msg = op == Native.IORING_OP_SEND_ZC ? "io_uring sendzc" : "io_uring sendmsg_zc";
250 int result = ioResult(msg, res);
251 if (more) {
252 try {
253
254
255 addFlushedToZcWriteQueue(channelOutboundBuffer);
256 } catch (Exception e) {
257
258 handleWriteError(e);
259 }
260 }
261 if (result == 0) {
262 return false;
263 }
264 } catch (Throwable cause) {
265 if (more) {
266 try {
267
268
269
270 addFlushedToZcWriteQueue(channelOutboundBuffer);
271 } catch (Exception e) {
272
273 cause.addSuppressed(e);
274 }
275 }
276 handleWriteError(cause);
277 }
278 }
279 } else {
280 if (zcWriteQueue != null) {
281 for (;;) {
282 Object queued = zcWriteQueue.remove();
283 assert queued != null;
284 if (queued == ZC_BATCH_MARKER) {
285
286 break;
287 }
288
289 ((ByteBuf) queued).release();
290 }
291 }
292 }
293 return true;
294 }
295
296 private void addFlushedToZcWriteQueue(ChannelOutboundBuffer channelOutboundBuffer) throws Exception {
297
298
299
300 try {
301 channelOutboundBuffer.forEachFlushedMessage(m -> {
302 if (!(m instanceof ByteBuf)) {
303 return false;
304 }
305 zcWriteQueue.add(m);
306 ((ByteBuf) m).retain();
307 return true;
308 });
309 } finally {
310 zcWriteQueue.add(ZC_BATCH_MARKER);
311 }
312 }
313 }
314 }