1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.uring;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelOutboundBuffer;
21 import io.netty.channel.socket.ServerSocketChannel;
22 import io.netty.channel.socket.SocketChannel;
23 import io.netty.channel.socket.SocketChannelConfig;
24 import io.netty.channel.unix.IovArray;
25
26 import java.net.InetSocketAddress;
27 import java.net.SocketAddress;
28 import java.util.ArrayDeque;
29 import java.util.Queue;
30
31 import static io.netty.channel.unix.Errors.ioResult;
32
33 public final class IoUringSocketChannel extends AbstractIoUringStreamChannel implements SocketChannel {
34 private final IoUringSocketChannelConfig config;
35
/**
 * Creates a channel without a parent, backed by the socket returned from
 * {@code LinuxSocket.newSocketStream()}.
 * The {@code false} flag is forwarded to the superclass unchanged
 * (presumably "not yet active" — confirm against the superclass constructor).
 */
public IoUringSocketChannel() {
    super(null, LinuxSocket.newSocketStream(), false);
    config = new IoUringSocketChannelConfig(this);
}
40
/**
 * Creates a channel that wraps an already existing socket.
 * The {@code true} flag is forwarded to the superclass unchanged
 * (presumably "already active" — confirm against the superclass constructor).
 *
 * @param parent the channel handed to the superclass as this channel's parent
 * @param fd     the socket this channel operates on
 */
IoUringSocketChannel(Channel parent, LinuxSocket fd) {
    super(parent, fd, true);
    config = new IoUringSocketChannelConfig(this);
}
45
/**
 * Creates a channel that wraps an already existing socket whose remote address is known.
 *
 * @param parent the channel handed to the superclass as this channel's parent
 * @param fd     the socket this channel operates on
 * @param remote the remote address handed to the superclass
 */
IoUringSocketChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
    super(parent, fd, remote);
    config = new IoUringSocketChannelConfig(this);
}
50
51 @Override
52 public ServerSocketChannel parent() {
53 return (ServerSocketChannel) super.parent();
54 }
55
56 @Override
57 public SocketChannelConfig config() {
58 return config;
59 }
60
61 @Override
62 public InetSocketAddress remoteAddress() {
63 return (InetSocketAddress) super.remoteAddress();
64 }
65
66 @Override
67 public InetSocketAddress localAddress() {
68 return (InetSocketAddress) super.localAddress();
69 }
70
71 @Override
72 protected AbstractUringUnsafe newUnsafe() {
73 return new IoUringSocketUnsafe();
74 }
75
76
77 private static final Object ZC_BATCH_MARKER = new Object();
78
79 private final class IoUringSocketUnsafe extends IoUringStreamUnsafe {
80
81
82
83 private Queue<Object> zcWriteQueue;
84
85 @Override
86 protected int scheduleWriteSingle(Object msg) {
87 assert writeId == 0;
88
89 if (IoUring.isSendZcSupported() && msg instanceof ByteBuf) {
90 ByteBuf buf = (ByteBuf) msg;
91 int length = buf.readableBytes();
92 if (((IoUringSocketChannelConfig) config()).shouldWriteZeroCopy(length)) {
93 long address = IoUring.memoryAddress(buf) + buf.readerIndex();
94 IoUringIoOps ops = IoUringIoOps.newSendZc(fd().intValue(), address, length, 0, nextOpsId(), 0);
95 byte opCode = ops.opcode();
96 writeId = registration().submit(ops);
97 writeOpCode = opCode;
98 if (writeId == 0) {
99 return 0;
100 }
101 return 1;
102 }
103
104 }
105 return super.scheduleWriteSingle(msg);
106 }
107
108 @Override
109 protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
110 assert writeId == 0;
111
112 IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
113
114 if (IoUring.isSendmsgZcSupported()
115 && (ioUringSocketChannelConfig.shouldWriteZeroCopy(((ByteBuf) in.current()).readableBytes()))) {
116 IoUringIoHandler handler = registration().attachment();
117
118 IovArray iovArray = handler.iovArray();
119 int offset = iovArray.count();
120
121
122 iovArray.maxCount(Native.MAX_SKB_FRAGS);
123 try {
124 in.forEachFlushedMessage(new ChannelOutboundBuffer.MessageProcessor() {
125 @Override
126 public boolean processMessage(Object msg) throws Exception {
127 if (msg instanceof ByteBuf) {
128 ByteBuf buf = (ByteBuf) msg;
129 int length = buf.readableBytes();
130 if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
131 return iovArray.processMessage(msg);
132 }
133 }
134 return false;
135 }
136 });
137 } catch (Exception e) {
138
139 return scheduleWriteSingle(in.current());
140 }
141 long iovArrayAddress = iovArray.memoryAddress(offset);
142 int iovArrayLength = iovArray.count() - offset;
143
144 MsgHdrMemoryArray msgHdrArray = handler.msgHdrMemoryArray();
145 MsgHdrMemory hdr = msgHdrArray.nextHdr();
146 assert hdr != null;
147 hdr.set(iovArrayAddress, iovArrayLength);
148 IoUringIoOps ops = IoUringIoOps.newSendmsgZc(fd().intValue(), (byte) 0, 0, hdr.address(), nextOpsId());
149 byte opCode = ops.opcode();
150 writeId = registration().submit(ops);
151 writeOpCode = opCode;
152 if (writeId == 0) {
153 return 0;
154 }
155 return 1;
156 }
157
158 return super.scheduleWriteMultiple(in);
159 }
160
161 @Override
162 protected ChannelOutboundBuffer.MessageProcessor filterWriteMultiple(IovArray iovArray) {
163 IoUringSocketChannelConfig ioUringSocketChannelConfig = (IoUringSocketChannelConfig) config();
164 return new ChannelOutboundBuffer.MessageProcessor() {
165 @Override
166 public boolean processMessage(Object msg) throws Exception {
167 if (msg instanceof ByteBuf) {
168 ByteBuf buf = (ByteBuf) msg;
169 int length = buf.readableBytes();
170 if (ioUringSocketChannelConfig.shouldWriteZeroCopy(length)) {
171 return false;
172 }
173 }
174 return iovArray.processMessage(msg);
175 }
176 };
177 }
178
179 @Override
180 boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
181 ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
182 if (op == Native.IORING_OP_SEND_ZC || op == Native.IORING_OP_SENDMSG_ZC) {
183 return handleWriteCompleteZeroCopy(op, channelOutboundBuffer, res, flags);
184 }
185 return super.writeComplete0(op, res, flags, data, outstanding);
186 }
187
188 private boolean handleWriteCompleteZeroCopy(byte op, ChannelOutboundBuffer channelOutboundBuffer,
189 int res, int flags) {
190 if ((flags & Native.IORING_CQE_F_NOTIF) == 0) {
191
192
193
194
195 writeId = 0;
196 writeOpCode = 0;
197
198 boolean more = (flags & Native.IORING_CQE_F_MORE) != 0;
199 if (more) {
200
201
202
203
204 if (zcWriteQueue == null) {
205 zcWriteQueue = new ArrayDeque<>(8);
206 }
207 }
208 if (res >= 0) {
209 if (more) {
210
211
212
213 do {
214 ByteBuf currentBuffer = (ByteBuf) channelOutboundBuffer.current();
215 assert currentBuffer != null;
216 zcWriteQueue.add(currentBuffer);
217 currentBuffer.retain();
218 int readable = currentBuffer.readableBytes();
219 int skip = Math.min(readable, res);
220 currentBuffer.skipBytes(skip);
221 channelOutboundBuffer.progress(readable);
222 if (readable <= res) {
223 boolean removed = channelOutboundBuffer.remove();
224 assert removed;
225 }
226 res -= readable;
227 } while (res > 0);
228
229 zcWriteQueue.add(ZC_BATCH_MARKER);
230 } else {
231
232 channelOutboundBuffer.removeBytes(res);
233 }
234 return true;
235 } else {
236 if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
237 if (more) {
238
239
240
241 zcWriteQueue.add(ZC_BATCH_MARKER);
242 }
243 return true;
244 }
245 try {
246 String msg = op == Native.IORING_OP_SEND_ZC ? "io_uring sendzc" : "io_uring sendmsg_zc";
247 int result = ioResult(msg, res);
248 if (more) {
249 try {
250
251
252 addFlushedToZcWriteQueue(channelOutboundBuffer);
253 } catch (Exception e) {
254
255 handleWriteError(e);
256 }
257 }
258 if (result == 0) {
259 return false;
260 }
261 } catch (Throwable cause) {
262 if (more) {
263 try {
264
265
266
267 addFlushedToZcWriteQueue(channelOutboundBuffer);
268 } catch (Exception e) {
269
270 cause.addSuppressed(e);
271 }
272 }
273 handleWriteError(cause);
274 }
275 }
276 } else {
277 if (zcWriteQueue != null) {
278 for (;;) {
279 Object queued = zcWriteQueue.remove();
280 assert queued != null;
281 if (queued == ZC_BATCH_MARKER) {
282
283 break;
284 }
285
286 ((ByteBuf) queued).release();
287 }
288 }
289 }
290 return true;
291 }
292
293 private void addFlushedToZcWriteQueue(ChannelOutboundBuffer channelOutboundBuffer) throws Exception {
294
295
296
297 try {
298 channelOutboundBuffer.forEachFlushedMessage(m -> {
299 if (!(m instanceof ByteBuf)) {
300 return false;
301 }
302 zcWriteQueue.add(m);
303 ((ByteBuf) m).retain();
304 return true;
305 });
306 } finally {
307 zcWriteQueue.add(ZC_BATCH_MARKER);
308 }
309 }
310 }
311 }