1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.uring;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelOutboundBuffer;
21 import io.netty.channel.socket.ServerSocketChannel;
22 import io.netty.channel.socket.SocketChannel;
23 import io.netty.channel.socket.SocketChannelConfig;
24 import io.netty.channel.unix.IovArray;
25
26 import java.net.InetSocketAddress;
27 import java.net.SocketAddress;
28 import java.util.ArrayDeque;
29 import java.util.Queue;
30
31 import static io.netty.channel.unix.Errors.ioResult;
32
33 public final class IoUringSocketChannel extends AbstractIoUringStreamChannel implements SocketChannel {
34 private final IoUringSocketChannelConfig config;
35
/**
 * Creates a channel without a parent on top of a new socket obtained from
 * {@code LinuxSocket.newSocketStream()}.
 */
// NOTE(review): the trailing 'false' is presumably the "already active" flag of the super constructor - confirm.
public IoUringSocketChannel() {
    super(null, LinuxSocket.newSocketStream(), false);
    config = new IoUringSocketChannelConfig(this);
}
40
/**
 * Creates a channel for an already existing socket.
 *
 * @param parent the channel handed to the super constructor as parent
 * @param fd     the socket backing this channel
 */
// NOTE(review): the trailing 'true' is presumably the "already active" flag of the super constructor - confirm.
IoUringSocketChannel(Channel parent, LinuxSocket fd) {
    super(parent, fd, true);
    config = new IoUringSocketChannelConfig(this);
}
45
/**
 * Creates a channel for an already existing socket whose remote address is supplied by the caller.
 *
 * @param parent the channel handed to the super constructor as parent
 * @param fd     the socket backing this channel
 * @param remote the remote address handed to the super constructor
 */
IoUringSocketChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
    super(parent, fd, remote);
    config = new IoUringSocketChannelConfig(this);
}
50
51 @Override
52 public ServerSocketChannel parent() {
53 return (ServerSocketChannel) super.parent();
54 }
55
56 @Override
57 public SocketChannelConfig config() {
58 return config;
59 }
60
61 @Override
62 public InetSocketAddress remoteAddress() {
63 return (InetSocketAddress) super.remoteAddress();
64 }
65
66 @Override
67 public InetSocketAddress localAddress() {
68 return (InetSocketAddress) super.localAddress();
69 }
70
71 @Override
72 protected AbstractUringUnsafe newUnsafe() {
73 return new IoUringSocketUnsafe();
74 }
75
76
77 private static final Object ZC_BATCH_MARKER = new Object();
78
79 private final class IoUringSocketUnsafe extends IoUringStreamUnsafe {
80
81
82
83 private Queue<Object> zcWriteQueue;
84
85 @Override
86 protected int scheduleWriteSingle(Object msg) {
87 assert writeId == 0;
88
89 if (IoUring.isSendZcSupported() && msg instanceof ByteBuf) {
90 ByteBuf buf = (ByteBuf) msg;
91 int length = buf.readableBytes();
92 if (((IoUringSocketChannelConfig) config()).shouldWriteZeroCopy(length)) {
93 long address = IoUring.memoryAddress(buf) + buf.readerIndex();
94 IoUringIoOps ops = IoUringIoOps.newSendZc(fd().intValue(), address, length, 0, nextOpsId(), 0);
95 byte opCode = ops.opcode();
96 writeId = registration().submit(ops);
97 writeOpCode = opCode;
98 if (writeId == 0) {
99 return 0;
100 }
101 return 1;
102 }
103
104 }
105 return super.scheduleWriteSingle(msg);
106 }
107
108 @Override
109 protected int scheduleWriteMultiple(ChannelOutboundBuffer in) {
110 assert writeId == 0;
111
112 if (IoUring.isSendmsgZcSupported() && (
113 (IoUringSocketChannelConfig) config()).shouldWriteZeroCopy((int) in.totalPendingWriteBytes())) {
114 IoUringIoHandler handler = registration().attachment();
115
116 IovArray iovArray = handler.iovArray();
117 int offset = iovArray.count();
118
119
120 iovArray.maxCount(Native.MAX_SKB_FRAGS);
121 try {
122 in.forEachFlushedMessage(iovArray);
123 } catch (Exception e) {
124
125 return scheduleWriteSingle(in.current());
126 }
127 long iovArrayAddress = iovArray.memoryAddress(offset);
128 int iovArrayLength = iovArray.count() - offset;
129
130 MsgHdrMemoryArray msgHdrArray = handler.msgHdrMemoryArray();
131 MsgHdrMemory hdr = msgHdrArray.nextHdr();
132 assert hdr != null;
133 hdr.set(iovArrayAddress, iovArrayLength);
134 IoUringIoOps ops = IoUringIoOps.newSendmsgZc(fd().intValue(), (byte) 0, 0, hdr.address(), nextOpsId());
135 byte opCode = ops.opcode();
136 writeId = registration().submit(ops);
137 writeOpCode = opCode;
138 if (writeId == 0) {
139 return 0;
140 }
141 return 1;
142 }
143
144 return super.scheduleWriteMultiple(in);
145 }
146
147 @Override
148 boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
149 ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
150 if (op == Native.IORING_OP_SEND_ZC || op == Native.IORING_OP_SENDMSG_ZC) {
151 return handleWriteCompleteZeroCopy(op, channelOutboundBuffer, res, flags);
152 }
153 return super.writeComplete0(op, res, flags, data, outstanding);
154 }
155
156 private boolean handleWriteCompleteZeroCopy(byte op, ChannelOutboundBuffer channelOutboundBuffer,
157 int res, int flags) {
158 if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
159 return true;
160 }
161 if ((flags & Native.IORING_CQE_F_NOTIF) == 0) {
162
163
164
165
166 writeId = 0;
167 writeOpCode = 0;
168
169 boolean more = (flags & Native.IORING_CQE_F_MORE) != 0;
170 if (res >= 0) {
171 if (more) {
172
173
174
175
176 if (zcWriteQueue == null) {
177 zcWriteQueue = new ArrayDeque<>(8);
178 }
179
180
181
182 do {
183 ByteBuf currentBuffer = (ByteBuf) channelOutboundBuffer.current();
184 assert currentBuffer != null;
185 zcWriteQueue.add(currentBuffer);
186 currentBuffer.retain();
187 int readable = currentBuffer.readableBytes();
188 int skip = Math.min(readable, res);
189 currentBuffer.skipBytes(skip);
190 channelOutboundBuffer.progress(readable);
191 if (readable <= res) {
192 boolean removed = channelOutboundBuffer.remove();
193 assert removed;
194 }
195 res -= readable;
196 } while (res > 0);
197
198 zcWriteQueue.add(ZC_BATCH_MARKER);
199 } else {
200
201 channelOutboundBuffer.removeBytes(res);
202 }
203 return true;
204 } else {
205 try {
206 String msg = op == Native.IORING_OP_SEND_ZC ? "io_uring sendzc" : "io_uring sendmsg_zc";
207 if (ioResult(msg, res) == 0) {
208 return false;
209 }
210 } catch (Throwable cause) {
211 handleWriteError(cause);
212 }
213 }
214 } else {
215 if (zcWriteQueue != null) {
216 for (;;) {
217 Object queued = zcWriteQueue.remove();
218 assert queued != null;
219 if (queued == ZC_BATCH_MARKER) {
220
221 break;
222 }
223
224 ((ByteBuf) queued).release();
225 }
226 }
227 }
228 return true;
229 }
230
231 @Override
232 protected boolean canCloseNow0() {
233 return (zcWriteQueue == null || zcWriteQueue.isEmpty()) && super.canCloseNow0();
234 }
235 }
236 }