1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.socket.nio;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelException;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelOption;
24 import io.netty.channel.ChannelOutboundBuffer;
25 import io.netty.channel.ChannelPromise;
26 import io.netty.channel.EventLoop;
27 import io.netty.channel.FileRegion;
28 import io.netty.channel.RecvByteBufAllocator;
29 import io.netty.channel.nio.AbstractNioByteChannel;
30 import io.netty.channel.socket.DefaultSocketChannelConfig;
31 import io.netty.channel.socket.InternetProtocolFamily;
32 import io.netty.channel.socket.ServerSocketChannel;
33 import io.netty.channel.socket.SocketChannelConfig;
34 import io.netty.util.concurrent.GlobalEventExecutor;
35 import io.netty.util.internal.PlatformDependent;
36 import io.netty.util.internal.SocketUtils;
37 import io.netty.util.internal.SuppressJava6Requirement;
38 import io.netty.util.internal.logging.InternalLogger;
39 import io.netty.util.internal.logging.InternalLoggerFactory;
40
41
42 import java.io.IOException;
43 import java.lang.reflect.Method;
44 import java.net.InetSocketAddress;
45 import java.net.Socket;
46 import java.net.SocketAddress;
47 import java.nio.ByteBuffer;
48 import java.nio.channels.SelectionKey;
49 import java.nio.channels.SocketChannel;
50 import java.nio.channels.spi.SelectorProvider;
51 import java.util.Map;
52 import java.util.concurrent.Executor;
53
54 import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
55
56
57
58
59 public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel {
60 private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
61 private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
62
63 private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
64 SelectorProviderUtil.findOpenMethod("openSocketChannel");
65
66 private final SocketChannelConfig config;
67
68 private static SocketChannel newChannel(SelectorProvider provider, InternetProtocolFamily family) {
69 try {
70 SocketChannel channel = SelectorProviderUtil.newChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider, family);
71 return channel == null ? provider.openSocketChannel() : channel;
72 } catch (IOException e) {
73 throw new ChannelException("Failed to open a socket.", e);
74 }
75 }
76
77
78
79
80 public NioSocketChannel() {
81 this(DEFAULT_SELECTOR_PROVIDER);
82 }
83
84
85
86
87 public NioSocketChannel(SelectorProvider provider) {
88 this(provider, null);
89 }
90
91
92
93
94 public NioSocketChannel(SelectorProvider provider, InternetProtocolFamily family) {
95 this(newChannel(provider, family));
96 }
97
98
99
100
101 public NioSocketChannel(SocketChannel socket) {
102 this(null, socket);
103 }
104
105
106
107
108
109
110
111 public NioSocketChannel(Channel parent, SocketChannel socket) {
112 super(parent, socket);
113 config = new NioSocketChannelConfig(this, socket.socket());
114 }
115
116 @Override
117 public ServerSocketChannel parent() {
118 return (ServerSocketChannel) super.parent();
119 }
120
121 @Override
122 public SocketChannelConfig config() {
123 return config;
124 }
125
126 @Override
127 protected SocketChannel javaChannel() {
128 return (SocketChannel) super.javaChannel();
129 }
130
131 @Override
132 public boolean isActive() {
133 SocketChannel ch = javaChannel();
134 return ch.isOpen() && ch.isConnected();
135 }
136
137 @Override
138 public boolean isOutputShutdown() {
139 return javaChannel().socket().isOutputShutdown() || !isActive();
140 }
141
142 @Override
143 public boolean isInputShutdown() {
144 return javaChannel().socket().isInputShutdown() || !isActive();
145 }
146
147 @Override
148 public boolean isShutdown() {
149 Socket socket = javaChannel().socket();
150 return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive();
151 }
152
153 @Override
154 public InetSocketAddress localAddress() {
155 return (InetSocketAddress) super.localAddress();
156 }
157
158 @Override
159 public InetSocketAddress remoteAddress() {
160 return (InetSocketAddress) super.remoteAddress();
161 }
162
163 @SuppressJava6Requirement(reason = "Usage guarded by java version check")
164 @Override
165 protected final void doShutdownOutput() throws Exception {
166 if (PlatformDependent.javaVersion() >= 7) {
167 javaChannel().shutdownOutput();
168 } else {
169 javaChannel().socket().shutdownOutput();
170 }
171 }
172
173 @Override
174 public ChannelFuture shutdownOutput() {
175 return shutdownOutput(newPromise());
176 }
177
178 @Override
179 public ChannelFuture shutdownOutput(final ChannelPromise promise) {
180 final EventLoop loop = eventLoop();
181 if (loop.inEventLoop()) {
182 ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
183 } else {
184 loop.execute(new Runnable() {
185 @Override
186 public void run() {
187 ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
188 }
189 });
190 }
191 return promise;
192 }
193
194 @Override
195 public ChannelFuture shutdownInput() {
196 return shutdownInput(newPromise());
197 }
198
199 @Override
200 protected boolean isInputShutdown0() {
201 return isInputShutdown();
202 }
203
204 @Override
205 public ChannelFuture shutdownInput(final ChannelPromise promise) {
206 EventLoop loop = eventLoop();
207 if (loop.inEventLoop()) {
208 shutdownInput0(promise);
209 } else {
210 loop.execute(new Runnable() {
211 @Override
212 public void run() {
213 shutdownInput0(promise);
214 }
215 });
216 }
217 return promise;
218 }
219
220 @Override
221 public ChannelFuture shutdown() {
222 return shutdown(newPromise());
223 }
224
225 @Override
226 public ChannelFuture shutdown(final ChannelPromise promise) {
227 ChannelFuture shutdownOutputFuture = shutdownOutput();
228 if (shutdownOutputFuture.isDone()) {
229 shutdownOutputDone(shutdownOutputFuture, promise);
230 } else {
231 shutdownOutputFuture.addListener(new ChannelFutureListener() {
232 @Override
233 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
234 shutdownOutputDone(shutdownOutputFuture, promise);
235 }
236 });
237 }
238 return promise;
239 }
240
241 private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
242 ChannelFuture shutdownInputFuture = shutdownInput();
243 if (shutdownInputFuture.isDone()) {
244 shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
245 } else {
246 shutdownInputFuture.addListener(new ChannelFutureListener() {
247 @Override
248 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
249 shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
250 }
251 });
252 }
253 }
254
255 private static void shutdownDone(ChannelFuture shutdownOutputFuture,
256 ChannelFuture shutdownInputFuture,
257 ChannelPromise promise) {
258 Throwable shutdownOutputCause = shutdownOutputFuture.cause();
259 Throwable shutdownInputCause = shutdownInputFuture.cause();
260 if (shutdownOutputCause != null) {
261 if (shutdownInputCause != null) {
262 logger.debug("Exception suppressed because a previous exception occurred.",
263 shutdownInputCause);
264 }
265 promise.setFailure(shutdownOutputCause);
266 } else if (shutdownInputCause != null) {
267 promise.setFailure(shutdownInputCause);
268 } else {
269 promise.setSuccess();
270 }
271 }
272 private void shutdownInput0(final ChannelPromise promise) {
273 try {
274 shutdownInput0();
275 promise.setSuccess();
276 } catch (Throwable t) {
277 promise.setFailure(t);
278 }
279 }
280
281 @SuppressJava6Requirement(reason = "Usage guarded by java version check")
282 private void shutdownInput0() throws Exception {
283 if (PlatformDependent.javaVersion() >= 7) {
284 javaChannel().shutdownInput();
285 } else {
286 javaChannel().socket().shutdownInput();
287 }
288 }
289
290 @Override
291 protected SocketAddress localAddress0() {
292 return javaChannel().socket().getLocalSocketAddress();
293 }
294
295 @Override
296 protected SocketAddress remoteAddress0() {
297 return javaChannel().socket().getRemoteSocketAddress();
298 }
299
300 @Override
301 protected void doBind(SocketAddress localAddress) throws Exception {
302 doBind0(localAddress);
303 }
304
305 private void doBind0(SocketAddress localAddress) throws Exception {
306 if (PlatformDependent.javaVersion() >= 7) {
307 SocketUtils.bind(javaChannel(), localAddress);
308 } else {
309 SocketUtils.bind(javaChannel().socket(), localAddress);
310 }
311 }
312
313 @Override
314 protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
315 if (localAddress != null) {
316 doBind0(localAddress);
317 }
318
319 boolean success = false;
320 try {
321 boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
322 if (!connected) {
323 selectionKey().interestOps(SelectionKey.OP_CONNECT);
324 }
325 success = true;
326 return connected;
327 } finally {
328 if (!success) {
329 doClose();
330 }
331 }
332 }
333
334 @Override
335 protected void doFinishConnect() throws Exception {
336 if (!javaChannel().finishConnect()) {
337 throw new Error();
338 }
339 }
340
341 @Override
342 protected void doDisconnect() throws Exception {
343 doClose();
344 }
345
346 @Override
347 protected void doClose() throws Exception {
348 super.doClose();
349 javaChannel().close();
350 }
351
352 @Override
353 protected int doReadBytes(ByteBuf byteBuf) throws Exception {
354 final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
355 allocHandle.attemptedBytesRead(byteBuf.writableBytes());
356 return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
357 }
358
359 @Override
360 protected int doWriteBytes(ByteBuf buf) throws Exception {
361 final int expectedWrittenBytes = buf.readableBytes();
362 return buf.readBytes(javaChannel(), expectedWrittenBytes);
363 }
364
365 @Override
366 protected long doWriteFileRegion(FileRegion region) throws Exception {
367 final long position = region.transferred();
368 return region.transferTo(javaChannel(), position);
369 }
370
371 private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
372
373
374
375 if (attempted == written) {
376 if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
377 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
378 }
379 } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
380 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
381 }
382 }
383
384 @Override
385 protected void doWrite(ChannelOutboundBuffer in) throws Exception {
386 SocketChannel ch = javaChannel();
387 int writeSpinCount = config().getWriteSpinCount();
388 do {
389 if (in.isEmpty()) {
390
391 clearOpWrite();
392
393 return;
394 }
395
396
397 int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
398 ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
399 int nioBufferCnt = in.nioBufferCount();
400
401
402
403 switch (nioBufferCnt) {
404 case 0:
405
406 writeSpinCount -= doWrite0(in);
407 break;
408 case 1: {
409
410
411
412 ByteBuffer buffer = nioBuffers[0];
413 int attemptedBytes = buffer.remaining();
414 final int localWrittenBytes = ch.write(buffer);
415 if (localWrittenBytes <= 0) {
416 incompleteWrite(true);
417 return;
418 }
419 adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
420 in.removeBytes(localWrittenBytes);
421 --writeSpinCount;
422 break;
423 }
424 default: {
425
426
427
428 long attemptedBytes = in.nioBufferSize();
429 final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
430 if (localWrittenBytes <= 0) {
431 incompleteWrite(true);
432 return;
433 }
434
435 adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
436 maxBytesPerGatheringWrite);
437 in.removeBytes(localWrittenBytes);
438 --writeSpinCount;
439 break;
440 }
441 }
442 } while (writeSpinCount > 0);
443
444 incompleteWrite(writeSpinCount < 0);
445 }
446
447 @Override
448 protected AbstractNioUnsafe newUnsafe() {
449 return new NioSocketChannelUnsafe();
450 }
451
452 private final class NioSocketChannelUnsafe extends NioByteUnsafe {
453 @Override
454 protected Executor prepareToClose() {
455 try {
456 if (javaChannel().isOpen() && config().getSoLinger() > 0) {
457
458
459
460
461 doDeregister();
462 return GlobalEventExecutor.INSTANCE;
463 }
464 } catch (Throwable ignore) {
465
466
467
468 }
469 return null;
470 }
471 }
472
473 private final class NioSocketChannelConfig extends DefaultSocketChannelConfig {
474 private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
475 private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) {
476 super(channel, javaSocket);
477 calculateMaxBytesPerGatheringWrite();
478 }
479
480 @Override
481 protected void autoReadCleared() {
482 clearReadPending();
483 }
484
485 @Override
486 public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) {
487 super.setSendBufferSize(sendBufferSize);
488 calculateMaxBytesPerGatheringWrite();
489 return this;
490 }
491
492 @Override
493 public <T> boolean setOption(ChannelOption<T> option, T value) {
494 if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
495 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
496 }
497 return super.setOption(option, value);
498 }
499
500 @Override
501 public <T> T getOption(ChannelOption<T> option) {
502 if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
503 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
504 }
505 return super.getOption(option);
506 }
507
508 @Override
509 public Map<ChannelOption<?>, Object> getOptions() {
510 if (PlatformDependent.javaVersion() >= 7) {
511 return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel()));
512 }
513 return super.getOptions();
514 }
515
516 void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
517 this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
518 }
519
520 int getMaxBytesPerGatheringWrite() {
521 return maxBytesPerGatheringWrite;
522 }
523
524 private void calculateMaxBytesPerGatheringWrite() {
525
526 int newSendBufferSize = getSendBufferSize() << 1;
527 if (newSendBufferSize > 0) {
528 setMaxBytesPerGatheringWrite(newSendBufferSize);
529 }
530 }
531
532 private SocketChannel jdkChannel() {
533 return ((NioSocketChannel) channel).javaChannel();
534 }
535 }
536 }