1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.uring;
17
18 import io.netty.channel.Channel;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelOutboundBuffer;
22 import io.netty.channel.ChannelPipeline;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.channel.IoRegistration;
25 import io.netty.channel.unix.DomainSocketAddress;
26 import io.netty.channel.unix.DomainSocketChannel;
27 import io.netty.channel.unix.DomainSocketChannelConfig;
28 import io.netty.channel.unix.DomainSocketReadMode;
29 import io.netty.channel.unix.Errors;
30 import io.netty.channel.unix.FileDescriptor;
31 import io.netty.channel.unix.PeerCredentials;
32
33 import java.io.IOException;
34 import java.net.SocketAddress;
35
36
37
38
39 public final class IoUringDomainSocketChannel extends AbstractIoUringStreamChannel implements DomainSocketChannel {
40
41 private final IoUringDomainSocketChannelConfig config;
42
43 private volatile DomainSocketAddress local;
44 private volatile DomainSocketAddress remote;
45
46 public IoUringDomainSocketChannel() {
47 super(null, LinuxSocket.newSocketDomain(), false);
48 config = new IoUringDomainSocketChannelConfig(this);
49 }
50
51 IoUringDomainSocketChannel(Channel parent, FileDescriptor fd) {
52 this(parent, new LinuxSocket(fd.intValue()));
53 }
54
55 IoUringDomainSocketChannel(Channel parent, LinuxSocket fd) {
56 super(parent, fd, true);
57 local = fd.localDomainSocketAddress();
58 remote = fd.remoteDomainSocketAddress();
59 config = new IoUringDomainSocketChannelConfig(this);
60 }
61
62 @Override
63 public DomainSocketChannelConfig config() {
64 return config;
65 }
66
67 @Override
68 public DomainSocketAddress localAddress() {
69 return local;
70 }
71
72 @Override
73 public DomainSocketAddress remoteAddress() {
74 return remote;
75 }
76
77
78
79
80
81 public PeerCredentials peerCredentials() throws IOException {
82 return socket.getPeerCredentials();
83 }
84
85 @Override
86 protected Object filterOutboundMessage(Object msg) {
87 if (msg instanceof FileDescriptor) {
88 return msg;
89 }
90 return super.filterOutboundMessage(msg);
91 }
92
93 @Override
94 protected AbstractUringUnsafe newUnsafe() {
95 return new IoUringDomainSocketUnsafe();
96 }
97
98 @Override
99 protected boolean allowMultiShotPollIn() {
100
101
102 return false;
103 }
104
105 private final class IoUringDomainSocketUnsafe extends IoUringStreamUnsafe {
106
107 private MsgHdrMemory writeMsgHdrMemory;
108 private MsgHdrMemory readMsgHdrMemory;
109
110 @Override
111 protected int scheduleWriteSingle(Object msg) {
112 if (msg instanceof FileDescriptor) {
113
114
115 if (writeMsgHdrMemory == null) {
116 writeMsgHdrMemory = new MsgHdrMemory();
117 }
118 IoRegistration registration = registration();
119 IoUringIoOps ioUringIoOps = prepSendFdIoOps((FileDescriptor) msg, writeMsgHdrMemory);
120 writeId = registration.submit(ioUringIoOps);
121 writeOpCode = Native.IORING_OP_SENDMSG;
122 if (writeId == 0) {
123 MsgHdrMemory memory = writeMsgHdrMemory;
124 writeMsgHdrMemory = null;
125 memory.release();
126 return 0;
127 }
128 return 1;
129 }
130 return super.scheduleWriteSingle(msg);
131 }
132
133 @Override
134 boolean writeComplete0(byte op, int res, int flags, short data, int outstanding) {
135 if (op == Native.IORING_OP_SENDMSG) {
136 writeId = 0;
137 writeOpCode = 0;
138 if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
139 return true;
140 }
141 try {
142 int nativeCallResult = res >= 0 ? res : Errors.ioResult("io_uring sendmsg", res);
143 if (nativeCallResult >= 0) {
144 ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer();
145 channelOutboundBuffer.remove();
146 }
147 } catch (Throwable throwable) {
148 handleWriteError(throwable);
149 }
150 return true;
151 }
152 return super.writeComplete0(op, res, flags, data, outstanding);
153 }
154
155 private IoUringIoOps prepSendFdIoOps(FileDescriptor fileDescriptor, MsgHdrMemory msgHdrMemory) {
156 msgHdrMemory.setScmRightsFd(fileDescriptor.intValue());
157 return IoUringIoOps.newSendmsg(
158 fd().intValue(), (byte) 0, 0, msgHdrMemory.address(), msgHdrMemory.idx());
159 }
160
161 @Override
162 protected int scheduleRead0(boolean first, boolean socketIsEmpty) {
163 DomainSocketReadMode readMode = config.getReadMode();
164 switch (readMode) {
165 case FILE_DESCRIPTORS:
166 return scheduleRecvReadFd();
167 case BYTES:
168 return super.scheduleRead0(first, socketIsEmpty);
169 default:
170 throw new Error();
171 }
172 }
173
174 private int scheduleRecvReadFd() {
175
176
177 if (readMsgHdrMemory == null) {
178 readMsgHdrMemory = new MsgHdrMemory();
179 }
180 readMsgHdrMemory.prepRecvReadFd();
181 IoRegistration registration = registration();
182 IoUringIoOps ioUringIoOps = IoUringIoOps.newRecvmsg(
183 fd().intValue(), (byte) 0, 0, readMsgHdrMemory.address(), readMsgHdrMemory.idx());
184 readId = registration.submit(ioUringIoOps);
185 readOpCode = Native.IORING_OP_RECVMSG;
186 if (readId == 0) {
187 MsgHdrMemory memory = readMsgHdrMemory;
188 readMsgHdrMemory = null;
189 memory.release();
190 return 0;
191 }
192 return 1;
193 }
194
195 @Override
196 protected void readComplete0(byte op, int res, int flags, short data, int outstanding) {
197 if (op == Native.IORING_OP_RECVMSG) {
198 readId = 0;
199 if (res == Native.ERRNO_ECANCELED_NEGATIVE) {
200 return;
201 }
202 final IoUringRecvByteAllocatorHandle allocHandle = recvBufAllocHandle();
203 final ChannelPipeline pipeline = pipeline();
204 try {
205 int nativeCallResult = res >= 0 ? res : Errors.ioResult("io_uring recvmsg", res);
206 int nativeFd = readMsgHdrMemory.getScmRightsFd();
207 allocHandle.lastBytesRead(nativeFd);
208 allocHandle.incMessagesRead(1);
209 pipeline.fireChannelRead(new FileDescriptor(nativeFd));
210 } catch (Throwable throwable) {
211 handleReadException(pipeline, null, throwable, false, allocHandle);
212 } finally {
213 allocHandle.readComplete();
214 pipeline.fireChannelReadComplete();
215 }
216 return;
217 }
218 super.readComplete0(op, res, flags, data, outstanding);
219 }
220
221 @Override
222 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
223
224 ChannelPromise channelPromise = newPromise().addListener(new ChannelFutureListener() {
225 @Override
226 public void operationComplete(ChannelFuture future) throws Exception {
227 if (future.isSuccess()) {
228 local = localAddress != null
229 ? (DomainSocketAddress) localAddress
230 : socket.localDomainSocketAddress();
231 remote = (DomainSocketAddress) remoteAddress;
232 promise.setSuccess();
233 } else {
234 promise.setFailure(future.cause());
235 }
236 }
237 });
238 super.connect(remoteAddress, localAddress, channelPromise);
239 }
240
241 @Override
242 protected void freeResourcesNow(IoRegistration reg) {
243 super.freeResourcesNow(reg);
244 if (readMsgHdrMemory != null) {
245 readMsgHdrMemory.release();
246 readMsgHdrMemory = null;
247 }
248 if (writeMsgHdrMemory != null) {
249 writeMsgHdrMemory.release();
250 writeMsgHdrMemory = null;
251 }
252 }
253 }
254
255 @Override
256 boolean isPollInFirst() {
257 DomainSocketReadMode readMode = config.getReadMode();
258 switch (readMode) {
259 case BYTES:
260 return super.isPollInFirst();
261 case FILE_DESCRIPTORS:
262 return false;
263 default:
264 throw new Error();
265 }
266 }
267 }