1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.uring;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelOutboundBuffer;
21 import io.netty.channel.socket.ServerSocketChannel;
22 import io.netty.channel.socket.SocketChannel;
23 import io.netty.channel.socket.SocketChannelConfig;
24 import io.netty.channel.unix.IovArray;
25
26 import java.net.InetSocketAddress;
27 import java.net.SocketAddress;
28 import java.util.ArrayDeque;
29 import java.util.Queue;
30
31 import static io.netty.channel.unix.Errors.ioResult;
32
33 public final class IoUringSocketChannel extends AbstractIoUringStreamChannel implements SocketChannel {
34 private final IoUringSocketChannelConfig config;
35
36 public IoUringSocketChannel() {
37 super(null, LinuxSocket.newSocketStream(), false);
38 this.config = new IoUringSocketChannelConfig(this);
39 }
40
41 IoUringSocketChannel(Channel parent, LinuxSocket fd) {
42 super(parent, fd, true);
43 this.config = new IoUringSocketChannelConfig(this);
44 }
45
46 IoUringSocketChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
47 super(parent, fd, remote);
48 this.config = new IoUringSocketChannelConfig(this);
49 }
50
51 @Override
52 public ServerSocketChannel parent() {
53 return (ServerSocketChannel) super.parent();
54 }
55
56 @Override
57 public SocketChannelConfig config() {
58 return config;
59 }
60
61 @Override
62 public InetSocketAddress remoteAddress() {
63 return (InetSocketAddress) super.remoteAddress();
64 }
65
66 @Override
67 public InetSocketAddress localAddress() {
68 return (InetSocketAddress) super.localAddress();
69 }
70
71 @Override
72 protected AbstractUringUnsafe newUnsafe() {
73 return new IoUringSocketUnsafe();
74 }
75
76
77 private static final Object ZC_BATCH_MARKER = new Object();
78
79 private final class IoUringSocketUnsafe extends IoUringStreamUnsafe {
80
81
82
83 private Queue<Object> zcWriteQueue;
84
85 @Override
86 protected int scheduleWriteSingle(Object msg) {
87 assert writeId == 0;
88
89 if (IoUring.isSendZcSupported() && msg instanceof ByteBuf) {
90 ByteBuf buf = (ByteBuf) msg;
91 int length = buf.readableBytes();
92 if (((IoUringSocketChannelConfig) config()).shouldWriteZeroCopy(length)) {
93 long address = IoUring.memoryAddress(buf) + buf.readerIndex();
94 IoUringIoOps ops = IoUringIoOps.newSendZc(fd().intValue(), address, length, 0, nextOpsId(), 0);
95 byte opCode = ops.opcode();
96 writeId = registration().submit(ops);
97 writeOpCode = opCode;
98 if (writeId == 0) {
99 return 0;
100 }
101 return 1;
102 }
103
104 }
105 return super.scheduleWriteSingle(msg);
106 }
107
108 @Override
109 protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
110 assert writeId == 0;
111
112 if (IoUring.isSendmsgZcSupported() && (
113 (IoUringSocketChannelConfig) config()).shouldWriteZeroCopy((int) in.totalPendingWriteBytes())) {
114 IoUringIoHandler handler = registration().attachment();
115
116 IovArray iovArray = handler.iovArray();
117 int offset = iovArray.count();
118
119
120 iovArray.maxCount(Native.MAX_SKB_FRAGS);
121 try {
122 in.forEachFlushedMessage(iovArray);
123 } catch (Exception e) {
124
125 return scheduleWriteSingle(in.current());
126 }
127 long iovArrayAddress = iovArray.memoryAddress(offset);
128 int iovArrayLength = iovArray.count() - offset;
129
130 MsgHdrMemoryArray msgHdrArray = handler.msgHdrMemoryArray();
131 MsgHdrMemory hdr = msgHdrArray.nextHdr();
132 assert hdr != null;
133 hdr.set(iovArrayAddress, iovArrayLength);
134 IoUringIoOps ops = IoUringIoOps.newSendmsgZc(fd().intValue(), (byte) 0, 0, hdr.address(), nextOpsId());
135 byte opCode = ops.opcode();
136 writeId = registration().submit(ops);
137 writeOpCode = opCode;
138 if (writeId == 0) {
139 return 0;
140 }
141 return 1;
142 }
143
144 return super.scheduleWriteMultiple(in);
145 }
146
147 @Override
148 boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
149 ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
150 if (op == Native.IORING_OP_SEND_ZC || op == Native.IORING_OP_SENDMSG_ZC) {
151 return handleWriteCompleteZeroCopy(op, channelOutboundBuffer, res, flags);
152 }
153 return super.writeComplete0(op, res, flags, data, outstanding);
154 }
155
156 private boolean handleWriteCompleteZeroCopy(byte op, ChannelOutboundBuffer channelOutboundBuffer,
157 int res, int flags) {
158 if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
159 return true;
160 }
161 if ((flags & Native.IORING_CQE_F_NOTIF) == 0) {
162
163
164
165
166 writeId = 0;
167 writeOpCode = 0;
168
169 boolean more = (flags & Native.IORING_CQE_F_MORE) != 0;
170 if (more) {
171
172
173
174
175 if (zcWriteQueue == null) {
176 zcWriteQueue = new ArrayDeque<>(8);
177 }
178 }
179 if (res >= 0) {
180 if (more) {
181
182
183
184 do {
185 ByteBuf currentBuffer = (ByteBuf) channelOutboundBuffer.current();
186 assert currentBuffer != null;
187 zcWriteQueue.add(currentBuffer);
188 currentBuffer.retain();
189 int readable = currentBuffer.readableBytes();
190 int skip = Math.min(readable, res);
191 currentBuffer.skipBytes(skip);
192 channelOutboundBuffer.progress(readable);
193 if (readable <= res) {
194 boolean removed = channelOutboundBuffer.remove();
195 assert removed;
196 }
197 res -= readable;
198 } while (res > 0);
199
200 zcWriteQueue.add(ZC_BATCH_MARKER);
201 } else {
202
203 channelOutboundBuffer.removeBytes(res);
204 }
205 return true;
206 } else {
207 try {
208 String msg = op == Native.IORING_OP_SEND_ZC ? "io_uring sendzc" : "io_uring sendmsg_zc";
209 int result = ioResult(msg, res);
210 if (more) {
211 try {
212
213
214 addFlushedToZcWriteQueue(channelOutboundBuffer);
215 } catch (Exception e) {
216
217 handleWriteError(e);
218 }
219 }
220 if (result == 0) {
221 return false;
222 }
223 } catch (Throwable cause) {
224 if (more) {
225 try {
226
227
228
229 addFlushedToZcWriteQueue(channelOutboundBuffer);
230 } catch (Exception e) {
231
232 cause.addSuppressed(e);
233 }
234 }
235 handleWriteError(cause);
236 }
237 }
238 } else {
239 if (zcWriteQueue != null) {
240 for (;;) {
241 Object queued = zcWriteQueue.remove();
242 assert queued != null;
243 if (queued == ZC_BATCH_MARKER) {
244
245 break;
246 }
247
248 ((ByteBuf) queued).release();
249 }
250 }
251 }
252 return true;
253 }
254
255 private void addFlushedToZcWriteQueue(ChannelOutboundBuffer channelOutboundBuffer) throws Exception {
256
257
258
259 try {
260 channelOutboundBuffer.forEachFlushedMessage(m -> {
261 if (!(m instanceof ByteBuf)) {
262 return false;
263 }
264 zcWriteQueue.add(m);
265 ((ByteBuf) m).retain();
266 return true;
267 });
268 } finally {
269 zcWriteQueue.add(ZC_BATCH_MARKER);
270 }
271 }
272 @Override
273 protected boolean canCloseNow0() {
274 return (zcWriteQueue == null || zcWriteQueue.isEmpty()) && super.canCloseNow0();
275 }
276 }
277 }