View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.epoll;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufAllocator;
20  import io.netty.buffer.CompositeByteBuf;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelConfig;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelOption;
26  import io.netty.channel.ChannelOutboundBuffer;
27  import io.netty.channel.ChannelPipeline;
28  import io.netty.channel.ChannelPromise;
29  import io.netty.channel.ConnectTimeoutException;
30  import io.netty.channel.DefaultFileRegion;
31  import io.netty.channel.RecvByteBufAllocator;
32  import io.netty.channel.socket.ChannelInputShutdownEvent;
33  import io.netty.channel.unix.FileDescriptor;
34  import io.netty.util.internal.PlatformDependent;
35  import io.netty.util.internal.StringUtil;
36  
37  import java.io.IOException;
38  import java.net.SocketAddress;
39  import java.nio.ByteBuffer;
40  import java.util.concurrent.ScheduledFuture;
41  import java.util.concurrent.TimeUnit;
42  
43  public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel {
44  
45      private static final String EXPECTED_TYPES =
46              " (expected: " + StringUtil.simpleClassName(ByteBuf.class) + ", " +
47                      StringUtil.simpleClassName(DefaultFileRegion.class) + ')';
48  
49      private volatile boolean inputShutdown;
50      private volatile boolean outputShutdown;
51  
52      protected AbstractEpollStreamChannel(Channel parent, int fd) {
53          super(parent, fd, Native.EPOLLIN, true);
54          // Add EPOLLRDHUP so we are notified once the remote peer close the connection.
55          flags |= Native.EPOLLRDHUP;
56      }
57  
58      protected AbstractEpollStreamChannel(int fd) {
59          super(fd, Native.EPOLLIN);
60          // Add EPOLLRDHUP so we are notified once the remote peer close the connection.
61          flags |= Native.EPOLLRDHUP;
62      }
63  
64      protected AbstractEpollStreamChannel(FileDescriptor fd) {
65          super(null, fd, Native.EPOLLIN, Native.getSoError(fd.intValue()) == 0);
66      }
67  
68      @Override
69      protected AbstractEpollUnsafe newUnsafe() {
70          return new EpollStreamUnsafe();
71      }
72  
73      /**
74       * Write bytes form the given {@link ByteBuf} to the underlying {@link java.nio.channels.Channel}.
75       * @param buf           the {@link ByteBuf} from which the bytes should be written
76       */
77      private boolean writeBytes(ChannelOutboundBuffer in, ByteBuf buf, int writeSpinCount) throws Exception {
78          int readableBytes = buf.readableBytes();
79          if (readableBytes == 0) {
80              in.remove();
81              return true;
82          }
83  
84          if (buf.hasMemoryAddress() || buf.nioBufferCount() == 1) {
85              int writtenBytes = doWriteBytes(buf, writeSpinCount);
86              in.removeBytes(writtenBytes);
87              return writtenBytes == readableBytes;
88          } else {
89              ByteBuffer[] nioBuffers = buf.nioBuffers();
90              return writeBytesMultiple(in, nioBuffers, nioBuffers.length, readableBytes, writeSpinCount);
91          }
92      }
93  
94      private boolean writeBytesMultiple(
95              ChannelOutboundBuffer in, IovArray array, int writeSpinCount) throws IOException {
96  
97          long expectedWrittenBytes = array.size();
98          final long initialExpectedWrittenBytes = expectedWrittenBytes;
99  
100         int cnt = array.count();
101 
102         assert expectedWrittenBytes != 0;
103         assert cnt != 0;
104 
105         boolean done = false;
106         int offset = 0;
107         int end = offset + cnt;
108         for (int i = writeSpinCount - 1; i >= 0; i--) {
109             long localWrittenBytes = Native.writevAddresses(fd().intValue(), array.memoryAddress(offset), cnt);
110             if (localWrittenBytes == 0) {
111                 break;
112             }
113             expectedWrittenBytes -= localWrittenBytes;
114 
115             if (expectedWrittenBytes == 0) {
116                 // Written everything, just break out here (fast-path)
117                 done = true;
118                 break;
119             }
120 
121             do {
122                 long bytes = array.processWritten(offset, localWrittenBytes);
123                 if (bytes == -1) {
124                     // incomplete write
125                     break;
126                 } else {
127                     offset++;
128                     cnt--;
129                     localWrittenBytes -= bytes;
130                 }
131             } while (offset < end && localWrittenBytes > 0);
132         }
133         if (!done) {
134             setFlag(Native.EPOLLOUT);
135         }
136         in.removeBytes(initialExpectedWrittenBytes - expectedWrittenBytes);
137         return done;
138     }
139 
140     private boolean writeBytesMultiple(
141             ChannelOutboundBuffer in, ByteBuffer[] nioBuffers,
142             int nioBufferCnt, long expectedWrittenBytes, int writeSpinCount) throws IOException {
143 
144         assert expectedWrittenBytes != 0;
145         final long initialExpectedWrittenBytes = expectedWrittenBytes;
146 
147         boolean done = false;
148         int offset = 0;
149         int end = offset + nioBufferCnt;
150         for (int i = writeSpinCount - 1; i >= 0; i--) {
151             long localWrittenBytes = Native.writev(fd().intValue(), nioBuffers, offset, nioBufferCnt);
152             if (localWrittenBytes == 0) {
153                 break;
154             }
155             expectedWrittenBytes -= localWrittenBytes;
156 
157             if (expectedWrittenBytes == 0) {
158                 // Written everything, just break out here (fast-path)
159                 done = true;
160                 break;
161             }
162             do {
163                 ByteBuffer buffer = nioBuffers[offset];
164                 int pos = buffer.position();
165                 int bytes = buffer.limit() - pos;
166                 if (bytes > localWrittenBytes) {
167                     buffer.position(pos + (int) localWrittenBytes);
168                     // incomplete write
169                     break;
170                 } else {
171                     offset++;
172                     nioBufferCnt--;
173                     localWrittenBytes -= bytes;
174                 }
175             } while (offset < end && localWrittenBytes > 0);
176         }
177 
178         in.removeBytes(initialExpectedWrittenBytes - expectedWrittenBytes);
179         if (!done) {
180             setFlag(Native.EPOLLOUT);
181         }
182         return done;
183     }
184 
185     /**
186      * Write a {@link DefaultFileRegion}
187      *
188      * @param region        the {@link DefaultFileRegion} from which the bytes should be written
189      * @return amount       the amount of written bytes
190      */
191     private boolean writeFileRegion(
192             ChannelOutboundBuffer in, DefaultFileRegion region, int writeSpinCount) throws Exception {
193         final long regionCount = region.count();
194         if (region.transfered() >= regionCount) {
195             in.remove();
196             return true;
197         }
198 
199         final long baseOffset = region.position();
200         boolean done = false;
201         long flushedAmount = 0;
202 
203         for (int i = writeSpinCount - 1; i >= 0; i--) {
204             final long offset = region.transfered();
205             final long localFlushedAmount =
206                     Native.sendfile(fd().intValue(), region, baseOffset, offset, regionCount - offset);
207             if (localFlushedAmount == 0) {
208                 break;
209             }
210 
211             flushedAmount += localFlushedAmount;
212             if (region.transfered() >= regionCount) {
213                 done = true;
214                 break;
215             }
216         }
217 
218         if (flushedAmount > 0) {
219             in.progress(flushedAmount);
220         }
221 
222         if (done) {
223             in.remove();
224         } else {
225             // Returned EAGAIN need to set EPOLLOUT
226             setFlag(Native.EPOLLOUT);
227         }
228         return done;
229     }
230 
231     @Override
232     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
233         int writeSpinCount = config().getWriteSpinCount();
234         for (;;) {
235             final int msgCount = in.size();
236 
237             if (msgCount == 0) {
238                 // Wrote all messages.
239                 clearFlag(Native.EPOLLOUT);
240                 break;
241             }
242 
243             // Do gathering write if the outbounf buffer entries start with more than one ByteBuf.
244             if (msgCount > 1 && in.current() instanceof ByteBuf) {
245                 if (!doWriteMultiple(in, writeSpinCount)) {
246                     break;
247                 }
248 
249                 // We do not break the loop here even if the outbound buffer was flushed completely,
250                 // because a user might have triggered another write and flush when we notify his or her
251                 // listeners.
252             } else { // msgCount == 1
253                 if (!doWriteSingle(in, writeSpinCount)) {
254                     break;
255                 }
256             }
257         }
258     }
259 
260     protected boolean doWriteSingle(ChannelOutboundBuffer in, int writeSpinCount) throws Exception {
261         // The outbound buffer contains only one message or it contains a file region.
262         Object msg = in.current();
263         if (msg instanceof ByteBuf) {
264             ByteBuf buf = (ByteBuf) msg;
265             if (!writeBytes(in, buf, writeSpinCount)) {
266                 // was not able to write everything so break here we will get notified later again once
267                 // the network stack can handle more writes.
268                 return false;
269             }
270         } else if (msg instanceof DefaultFileRegion) {
271             DefaultFileRegion region = (DefaultFileRegion) msg;
272             if (!writeFileRegion(in, region, writeSpinCount)) {
273                 // was not able to write everything so break here we will get notified later again once
274                 // the network stack can handle more writes.
275                 return false;
276             }
277         } else {
278             // Should never reach here.
279             throw new Error();
280         }
281 
282         return true;
283     }
284 
285     private boolean doWriteMultiple(ChannelOutboundBuffer in, int writeSpinCount) throws Exception {
286         if (PlatformDependent.hasUnsafe()) {
287             // this means we can cast to IovArray and write the IovArray directly.
288             IovArray array = IovArrayThreadLocal.get(in);
289             int cnt = array.count();
290             if (cnt >= 1) {
291                 // TODO: Handle the case where cnt == 1 specially.
292                 if (!writeBytesMultiple(in, array, writeSpinCount)) {
293                     // was not able to write everything so break here we will get notified later again once
294                     // the network stack can handle more writes.
295                     return false;
296                 }
297             } else { // cnt == 0, which means the outbound buffer contained empty buffers only.
298                 in.removeBytes(0);
299             }
300         } else {
301             ByteBuffer[] buffers = in.nioBuffers();
302             int cnt = in.nioBufferCount();
303             if (cnt >= 1) {
304                 // TODO: Handle the case where cnt == 1 specially.
305                 if (!writeBytesMultiple(in, buffers, cnt, in.nioBufferSize(), writeSpinCount)) {
306                     // was not able to write everything so break here we will get notified later again once
307                     // the network stack can handle more writes.
308                     return false;
309                 }
310             } else { // cnt == 0, which means the outbound buffer contained empty buffers only.
311                 in.removeBytes(0);
312             }
313         }
314 
315         return true;
316     }
317 
318     @Override
319     protected Object filterOutboundMessage(Object msg) {
320         if (msg instanceof ByteBuf) {
321             ByteBuf buf = (ByteBuf) msg;
322             if (!buf.hasMemoryAddress() && (PlatformDependent.hasUnsafe() || !buf.isDirect())) {
323                 if (buf instanceof CompositeByteBuf) {
324                     // Special handling of CompositeByteBuf to reduce memory copies if some of the Components
325                     // in the CompositeByteBuf are backed by a memoryAddress.
326                     CompositeByteBuf comp = (CompositeByteBuf) buf;
327                     if (!comp.isDirect() || comp.nioBufferCount() > Native.IOV_MAX) {
328                         // more then 1024 buffers for gathering writes so just do a memory copy.
329                         buf = newDirectBuffer(buf);
330                         assert buf.hasMemoryAddress();
331                     }
332                 } else {
333                     // We can only handle buffers with memory address so we need to copy if a non direct is
334                     // passed to write.
335                     buf = newDirectBuffer(buf);
336                     assert buf.hasMemoryAddress();
337                 }
338             }
339             return buf;
340         }
341 
342         if (msg instanceof DefaultFileRegion) {
343             return msg;
344         }
345 
346         throw new UnsupportedOperationException(
347                 "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES);
348     }
349 
350     protected boolean isInputShutdown0() {
351         return inputShutdown;
352     }
353 
354     protected boolean isOutputShutdown0() {
355         return outputShutdown || !isActive();
356     }
357 
358     protected void shutdownOutput0(final ChannelPromise promise) {
359         try {
360             Native.shutdown(fd().intValue(), false, true);
361             outputShutdown = true;
362             promise.setSuccess();
363         } catch (Throwable cause) {
364             promise.setFailure(cause);
365         }
366     }
367 
368     /**
369      * Connect to the remote peer
370      */
371     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
372         if (localAddress != null) {
373             Native.bind(fd().intValue(), localAddress);
374         }
375 
376         boolean success = false;
377         try {
378             boolean connected = Native.connect(fd().intValue(), remoteAddress);
379             if (!connected) {
380                 setFlag(Native.EPOLLOUT);
381             }
382             success = true;
383             return connected;
384         } finally {
385             if (!success) {
386                 doClose();
387             }
388         }
389     }
390 
391     class EpollStreamUnsafe extends AbstractEpollUnsafe {
392         /**
393          * The future of the current connection attempt.  If not null, subsequent
394          * connection attempts will fail.
395          */
396         private ChannelPromise connectPromise;
397         private ScheduledFuture<?> connectTimeoutFuture;
398         private SocketAddress requestedRemoteAddress;
399 
400         private RecvByteBufAllocator.Handle allocHandle;
401 
402         private void closeOnRead(ChannelPipeline pipeline) {
403             inputShutdown = true;
404             if (isOpen()) {
405                 if (Boolean.TRUE.equals(config().getOption(ChannelOption.ALLOW_HALF_CLOSURE))) {
406                     clearEpollIn0();
407                     pipeline.fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE);
408                 } else {
409                     close(voidPromise());
410                 }
411             }
412         }
413 
414         private boolean handleReadException(ChannelPipeline pipeline, ByteBuf byteBuf, Throwable cause, boolean close) {
415             if (byteBuf != null) {
416                 if (byteBuf.isReadable()) {
417                     readPending = false;
418                     pipeline.fireChannelRead(byteBuf);
419                 } else {
420                     byteBuf.release();
421                 }
422             }
423             pipeline.fireChannelReadComplete();
424             pipeline.fireExceptionCaught(cause);
425             if (close || cause instanceof IOException) {
426                 closeOnRead(pipeline);
427                 return true;
428             }
429             return false;
430         }
431 
432         @Override
433         public void connect(
434                 final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise promise) {
435             if (!promise.setUncancellable() || !ensureOpen(promise)) {
436                 return;
437             }
438 
439             try {
440                 if (connectPromise != null) {
441                     throw new IllegalStateException("connection attempt already made");
442                 }
443 
444                 boolean wasActive = isActive();
445                 if (doConnect(remoteAddress, localAddress)) {
446                     fulfillConnectPromise(promise, wasActive);
447                 } else {
448                     connectPromise = promise;
449                     requestedRemoteAddress = remoteAddress;
450 
451                     // Schedule connect timeout.
452                     int connectTimeoutMillis = config().getConnectTimeoutMillis();
453                     if (connectTimeoutMillis > 0) {
454                         connectTimeoutFuture = eventLoop().schedule(new Runnable() {
455                             @Override
456                             public void run() {
457                                 ChannelPromise connectPromise = EpollStreamUnsafe.this.connectPromise;
458                                 ConnectTimeoutException cause =
459                                         new ConnectTimeoutException("connection timed out: " + remoteAddress);
460                                 if (connectPromise != null && connectPromise.tryFailure(cause)) {
461                                     close(voidPromise());
462                                 }
463                             }
464                         }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
465                     }
466 
467                     promise.addListener(new ChannelFutureListener() {
468                         @Override
469                         public void operationComplete(ChannelFuture future) throws Exception {
470                             if (future.isCancelled()) {
471                                 if (connectTimeoutFuture != null) {
472                                     connectTimeoutFuture.cancel(false);
473                                 }
474                                 connectPromise = null;
475                                 close(voidPromise());
476                             }
477                         }
478                     });
479                 }
480             } catch (Throwable t) {
481                 closeIfClosed();
482                 promise.tryFailure(annotateConnectException(t, remoteAddress));
483             }
484         }
485 
486         private void fulfillConnectPromise(ChannelPromise promise, boolean wasActive) {
487             if (promise == null) {
488                 // Closed via cancellation and the promise has been notified already.
489                 return;
490             }
491             active = true;
492 
493             // trySuccess() will return false if a user cancelled the connection attempt.
494             boolean promiseSet = promise.trySuccess();
495 
496             // Regardless if the connection attempt was cancelled, channelActive() event should be triggered,
497             // because what happened is what happened.
498             if (!wasActive && isActive()) {
499                 pipeline().fireChannelActive();
500             }
501 
502             // If a user cancelled the connection attempt, close the channel, which is followed by channelInactive().
503             if (!promiseSet) {
504                 close(voidPromise());
505             }
506         }
507 
508         private void fulfillConnectPromise(ChannelPromise promise, Throwable cause) {
509             if (promise == null) {
510                 // Closed via cancellation and the promise has been notified already.
511                 return;
512             }
513 
514             // Use tryFailure() instead of setFailure() to avoid the race against cancel().
515             promise.tryFailure(cause);
516             closeIfClosed();
517         }
518 
519         private void finishConnect() {
520             // Note this method is invoked by the event loop only if the connection attempt was
521             // neither cancelled nor timed out.
522 
523             assert eventLoop().inEventLoop();
524 
525             boolean connectStillInProgress = false;
526             try {
527                 boolean wasActive = isActive();
528                 if (!doFinishConnect()) {
529                     connectStillInProgress = true;
530                     return;
531                 }
532                 fulfillConnectPromise(connectPromise, wasActive);
533             } catch (Throwable t) {
534                 fulfillConnectPromise(connectPromise, annotateConnectException(t, requestedRemoteAddress));
535             } finally {
536                 if (!connectStillInProgress) {
537                     // Check for null as the connectTimeoutFuture is only created if a connectTimeoutMillis > 0 is used
538                     // See https://github.com/netty/netty/issues/1770
539                     if (connectTimeoutFuture != null) {
540                         connectTimeoutFuture.cancel(false);
541                     }
542                     connectPromise = null;
543                 }
544             }
545         }
546 
547         @Override
548         void epollOutReady() {
549             if (connectPromise != null) {
550                 // pending connect which is now complete so handle it.
551                 finishConnect();
552             } else {
553                 super.epollOutReady();
554             }
555         }
556 
557         /**
558          * Finish the connect
559          */
560         private boolean doFinishConnect() throws Exception {
561             if (Native.finishConnect(fd().intValue())) {
562                 clearFlag(Native.EPOLLOUT);
563                 return true;
564             } else {
565                 setFlag(Native.EPOLLOUT);
566                 return false;
567             }
568         }
569 
570         @Override
571         void epollRdHupReady() {
572             if (isActive()) {
573                 epollInReady();
574             } else {
575                 closeOnRead(pipeline());
576             }
577         }
578 
579         @Override
580         void epollInReady() {
581             final ChannelConfig config = config();
582             boolean edgeTriggered = isFlagSet(Native.EPOLLET);
583 
584             if (!readPending && !edgeTriggered && !config.isAutoRead()) {
585                 // ChannelConfig.setAutoRead(false) was called in the meantime
586                 clearEpollIn0();
587                 return;
588             }
589 
590             final ChannelPipeline pipeline = pipeline();
591             final ByteBufAllocator allocator = config.getAllocator();
592             RecvByteBufAllocator.Handle allocHandle = this.allocHandle;
593             if (allocHandle == null) {
594                 this.allocHandle = allocHandle = config.getRecvByteBufAllocator().newHandle();
595             }
596 
597             ByteBuf byteBuf = null;
598             boolean close = false;
599             try {
600                 // if edgeTriggered is used we need to read all messages as we are not notified again otherwise.
601                 final int maxMessagesPerRead = edgeTriggered
602                         ? Integer.MAX_VALUE : config.getMaxMessagesPerRead();
603                 int messages = 0;
604                 int totalReadAmount = 0;
605                 do {
606                     // we use a direct buffer here as the native implementations only be able
607                     // to handle direct buffers.
608                     byteBuf = allocHandle.allocate(allocator);
609                     int writable = byteBuf.writableBytes();
610                     int localReadAmount = doReadBytes(byteBuf);
611                     if (localReadAmount <= 0) {
612                         // not was read release the buffer
613                         byteBuf.release();
614                         close = localReadAmount < 0;
615                         break;
616                     }
617                     readPending = false;
618                     pipeline.fireChannelRead(byteBuf);
619                     byteBuf = null;
620 
621                     if (totalReadAmount >= Integer.MAX_VALUE - localReadAmount) {
622                         allocHandle.record(totalReadAmount);
623 
624                         // Avoid overflow.
625                         totalReadAmount = localReadAmount;
626                     } else {
627                         totalReadAmount += localReadAmount;
628                     }
629 
630                     if (localReadAmount < writable) {
631                         // Read less than what the buffer can hold,
632                         // which might mean we drained the recv buffer completely.
633                         break;
634                     }
635                     if (!edgeTriggered && !config.isAutoRead()) {
636                         // This is not using EPOLLET so we can stop reading
637                         // ASAP as we will get notified again later with
638                         // pending data
639                         break;
640                     }
641                 } while (++ messages < maxMessagesPerRead);
642 
643                 pipeline.fireChannelReadComplete();
644                 allocHandle.record(totalReadAmount);
645 
646                 if (close) {
647                     closeOnRead(pipeline);
648                     close = false;
649                 }
650             } catch (Throwable t) {
651                 boolean closed = handleReadException(pipeline, byteBuf, t, close);
652                 if (!closed) {
653                     // trigger a read again as there may be something left to read and because of epoll ET we
654                     // will not get notified again until we read everything from the socket
655                     eventLoop().execute(new Runnable() {
656                         @Override
657                         public void run() {
658                             epollInReady();
659                         }
660                     });
661                 }
662             } finally {
663                 // Check if there is a readPending which was not processed yet.
664                 // This could be for two reasons:
665                 // * The user called Channel.read() or ChannelHandlerContext.read() in channelRead(...) method
666                 // * The user called Channel.read() or ChannelHandlerContext.read() in channelReadComplete(...) method
667                 //
668                 // See https://github.com/netty/netty/issues/2254
669                 if (!readPending && !config.isAutoRead()) {
670                     clearEpollIn0();
671                 }
672             }
673         }
674     }
675 }