1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.embedded;
17
18 import java.net.SocketAddress;
19 import java.nio.channels.ClosedChannelException;
20 import java.util.ArrayDeque;
21 import java.util.Queue;
22 import java.util.concurrent.TimeUnit;
23
24 import io.netty.channel.AbstractChannel;
25 import io.netty.channel.Channel;
26 import io.netty.channel.ChannelConfig;
27 import io.netty.channel.ChannelFuture;
28 import io.netty.channel.ChannelFutureListener;
29 import io.netty.channel.ChannelHandler;
30 import io.netty.channel.ChannelHandlerContext;
31 import io.netty.channel.ChannelId;
32 import io.netty.channel.ChannelInitializer;
33 import io.netty.channel.ChannelMetadata;
34 import io.netty.channel.ChannelOutboundBuffer;
35 import io.netty.channel.ChannelPipeline;
36 import io.netty.channel.ChannelPromise;
37 import io.netty.channel.DefaultChannelConfig;
38 import io.netty.channel.DefaultChannelPipeline;
39 import io.netty.channel.EventLoop;
40 import io.netty.channel.RecvByteBufAllocator;
41 import io.netty.util.ReferenceCountUtil;
42 import io.netty.util.internal.ObjectUtil;
43 import io.netty.util.internal.PlatformDependent;
44 import io.netty.util.internal.RecyclableArrayList;
45 import io.netty.util.internal.logging.InternalLogger;
46 import io.netty.util.internal.logging.InternalLoggerFactory;
47
48
49
50
51 public class EmbeddedChannel extends AbstractChannel {
52
53 private static final SocketAddress LOCAL_ADDRESS = new EmbeddedSocketAddress();
54 private static final SocketAddress REMOTE_ADDRESS = new EmbeddedSocketAddress();
55
56 private static final ChannelHandler[] EMPTY_HANDLERS = new ChannelHandler[0];
57 private enum State { OPEN, ACTIVE, CLOSED }
58
59 private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class);
60
61 private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false);
62 private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true);
63
64 private final EmbeddedEventLoop loop = new EmbeddedEventLoop();
65 private final ChannelFutureListener recordExceptionListener = new ChannelFutureListener() {
66 @Override
67 public void operationComplete(ChannelFuture future) throws Exception {
68 recordException(future);
69 }
70 };
71
72 private final ChannelMetadata metadata;
73 private final ChannelConfig config;
74
75 private Queue<Object> inboundMessages;
76 private Queue<Object> outboundMessages;
77 private Throwable lastException;
78 private State state;
79
80
81
82
83 public EmbeddedChannel() {
84 this(EMPTY_HANDLERS);
85 }
86
87
88
89
90
91
92 public EmbeddedChannel(ChannelId channelId) {
93 this(channelId, EMPTY_HANDLERS);
94 }
95
96
97
98
99
100
101 public EmbeddedChannel(ChannelHandler... handlers) {
102 this(EmbeddedChannelId.INSTANCE, handlers);
103 }
104
105
106
107
108
109
110
111
112 public EmbeddedChannel(boolean hasDisconnect, ChannelHandler... handlers) {
113 this(EmbeddedChannelId.INSTANCE, hasDisconnect, handlers);
114 }
115
116
117
118
119
120
121
122
123
124
125 public EmbeddedChannel(boolean register, boolean hasDisconnect, ChannelHandler... handlers) {
126 this(EmbeddedChannelId.INSTANCE, register, hasDisconnect, handlers);
127 }
128
129
130
131
132
133
134
135
136 public EmbeddedChannel(ChannelId channelId, ChannelHandler... handlers) {
137 this(channelId, false, handlers);
138 }
139
140
141
142
143
144
145
146
147
148
149 public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, ChannelHandler... handlers) {
150 this(channelId, true, hasDisconnect, handlers);
151 }
152
153
154
155
156
157
158
159
160
161
162
163
164 public EmbeddedChannel(ChannelId channelId, boolean register, boolean hasDisconnect,
165 ChannelHandler... handlers) {
166 this(null, channelId, register, hasDisconnect, handlers);
167 }
168
169
170
171
172
173
174
175
176
177
178
179
180
181 public EmbeddedChannel(Channel parent, ChannelId channelId, boolean register, boolean hasDisconnect,
182 final ChannelHandler... handlers) {
183 super(parent, channelId);
184 metadata = metadata(hasDisconnect);
185 config = new DefaultChannelConfig(this);
186 setup(register, handlers);
187 }
188
189
190
191
192
193
194
195
196
197
198
199 public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelConfig config,
200 final ChannelHandler... handlers) {
201 super(null, channelId);
202 metadata = metadata(hasDisconnect);
203 this.config = ObjectUtil.checkNotNull(config, "config");
204 setup(true, handlers);
205 }
206
207 private static ChannelMetadata metadata(boolean hasDisconnect) {
208 return hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
209 }
210
211 private void setup(boolean register, final ChannelHandler... handlers) {
212 ObjectUtil.checkNotNull(handlers, "handlers");
213 ChannelPipeline p = pipeline();
214 p.addLast(new ChannelInitializer<Channel>() {
215 @Override
216 protected void initChannel(Channel ch) throws Exception {
217 ChannelPipeline pipeline = ch.pipeline();
218 for (ChannelHandler h: handlers) {
219 if (h == null) {
220 break;
221 }
222 pipeline.addLast(h);
223 }
224 }
225 });
226 if (register) {
227 ChannelFuture future = loop.register(this);
228 assert future.isDone();
229 }
230 }
231
232
233
234
235 public void register() throws Exception {
236 ChannelFuture future = loop.register(this);
237 assert future.isDone();
238 Throwable cause = future.cause();
239 if (cause != null) {
240 PlatformDependent.throwException(cause);
241 }
242 }
243
244 @Override
245 protected final DefaultChannelPipeline newChannelPipeline() {
246 return new EmbeddedChannelPipeline(this);
247 }
248
249 @Override
250 public ChannelMetadata metadata() {
251 return metadata;
252 }
253
254 @Override
255 public ChannelConfig config() {
256 return config;
257 }
258
259 @Override
260 public boolean isOpen() {
261 return state != State.CLOSED;
262 }
263
264 @Override
265 public boolean isActive() {
266 return state == State.ACTIVE;
267 }
268
269
270
271
272 public Queue<Object> inboundMessages() {
273 if (inboundMessages == null) {
274 inboundMessages = new ArrayDeque<Object>();
275 }
276 return inboundMessages;
277 }
278
279
280
281
282 @Deprecated
283 public Queue<Object> lastInboundBuffer() {
284 return inboundMessages();
285 }
286
287
288
289
290 public Queue<Object> outboundMessages() {
291 if (outboundMessages == null) {
292 outboundMessages = new ArrayDeque<Object>();
293 }
294 return outboundMessages;
295 }
296
297
298
299
300 @Deprecated
301 public Queue<Object> lastOutboundBuffer() {
302 return outboundMessages();
303 }
304
305
306
307
308 @SuppressWarnings("unchecked")
309 public <T> T readInbound() {
310 T message = (T) poll(inboundMessages);
311 if (message != null) {
312 ReferenceCountUtil.touch(message, "Caller of readInbound() will handle the message from this point");
313 }
314 return message;
315 }
316
317
318
319
320 @SuppressWarnings("unchecked")
321 public <T> T readOutbound() {
322 T message = (T) poll(outboundMessages);
323 if (message != null) {
324 ReferenceCountUtil.touch(message, "Caller of readOutbound() will handle the message from this point.");
325 }
326 return message;
327 }
328
329
330
331
332
333
334
335
336 public boolean writeInbound(Object... msgs) {
337 ensureOpen();
338 if (msgs.length == 0) {
339 return isNotEmpty(inboundMessages);
340 }
341
342 ChannelPipeline p = pipeline();
343 for (Object m: msgs) {
344 p.fireChannelRead(m);
345 }
346
347 flushInbound(false, voidPromise());
348 return isNotEmpty(inboundMessages);
349 }
350
351
352
353
354
355
356
357 public ChannelFuture writeOneInbound(Object msg) {
358 return writeOneInbound(msg, newPromise());
359 }
360
361
362
363
364
365
366
367 public ChannelFuture writeOneInbound(Object msg, ChannelPromise promise) {
368 if (checkOpen(true)) {
369 pipeline().fireChannelRead(msg);
370 }
371 return checkException(promise);
372 }
373
374
375
376
377
378
379 public EmbeddedChannel flushInbound() {
380 flushInbound(true, voidPromise());
381 return this;
382 }
383
384 private ChannelFuture flushInbound(boolean recordException, ChannelPromise promise) {
385 if (checkOpen(recordException)) {
386 pipeline().fireChannelReadComplete();
387 runPendingTasks();
388 }
389
390 return checkException(promise);
391 }
392
393
394
395
396
397
398
399 public boolean writeOutbound(Object... msgs) {
400 ensureOpen();
401 if (msgs.length == 0) {
402 return isNotEmpty(outboundMessages);
403 }
404
405 RecyclableArrayList futures = RecyclableArrayList.newInstance(msgs.length);
406 try {
407 for (Object m: msgs) {
408 if (m == null) {
409 break;
410 }
411 futures.add(write(m));
412 }
413
414 flushOutbound0();
415
416 int size = futures.size();
417 for (int i = 0; i < size; i++) {
418 ChannelFuture future = (ChannelFuture) futures.get(i);
419 if (future.isDone()) {
420 recordException(future);
421 } else {
422
423 future.addListener(recordExceptionListener);
424 }
425 }
426
427 checkException();
428 return isNotEmpty(outboundMessages);
429 } finally {
430 futures.recycle();
431 }
432 }
433
434
435
436
437
438
439
440 public ChannelFuture writeOneOutbound(Object msg) {
441 return writeOneOutbound(msg, newPromise());
442 }
443
444
445
446
447
448
449
450 public ChannelFuture writeOneOutbound(Object msg, ChannelPromise promise) {
451 if (checkOpen(true)) {
452 return write(msg, promise);
453 }
454 return checkException(promise);
455 }
456
457
458
459
460
461
462 public EmbeddedChannel flushOutbound() {
463 if (checkOpen(true)) {
464 flushOutbound0();
465 }
466 checkException(voidPromise());
467 return this;
468 }
469
470 private void flushOutbound0() {
471
472
473 runPendingTasks();
474
475 flush();
476 }
477
478
479
480
481
482
483 public boolean finish() {
484 return finish(false);
485 }
486
487
488
489
490
491
492
493 public boolean finishAndReleaseAll() {
494 return finish(true);
495 }
496
497
498
499
500
501
502
503 private boolean finish(boolean releaseAll) {
504 close();
505 try {
506 checkException();
507 return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages);
508 } finally {
509 if (releaseAll) {
510 releaseAll(inboundMessages);
511 releaseAll(outboundMessages);
512 }
513 }
514 }
515
516
517
518
519
520 public boolean releaseInbound() {
521 return releaseAll(inboundMessages);
522 }
523
524
525
526
527
528 public boolean releaseOutbound() {
529 return releaseAll(outboundMessages);
530 }
531
532 private static boolean releaseAll(Queue<Object> queue) {
533 if (isNotEmpty(queue)) {
534 for (;;) {
535 Object msg = queue.poll();
536 if (msg == null) {
537 break;
538 }
539 ReferenceCountUtil.release(msg);
540 }
541 return true;
542 }
543 return false;
544 }
545
546 private void finishPendingTasks(boolean cancel) {
547 runPendingTasks();
548 if (cancel) {
549
550 embeddedEventLoop().cancelScheduledTasks();
551 }
552 }
553
554 @Override
555 public final ChannelFuture close() {
556 return close(newPromise());
557 }
558
559 @Override
560 public final ChannelFuture disconnect() {
561 return disconnect(newPromise());
562 }
563
564 @Override
565 public final ChannelFuture close(ChannelPromise promise) {
566
567
568 runPendingTasks();
569 ChannelFuture future = super.close(promise);
570
571
572 finishPendingTasks(true);
573 return future;
574 }
575
576 @Override
577 public final ChannelFuture disconnect(ChannelPromise promise) {
578 ChannelFuture future = super.disconnect(promise);
579 finishPendingTasks(!metadata.hasDisconnect());
580 return future;
581 }
582
583 private static boolean isNotEmpty(Queue<Object> queue) {
584 return queue != null && !queue.isEmpty();
585 }
586
587 private static Object poll(Queue<Object> queue) {
588 return queue != null ? queue.poll() : null;
589 }
590
591
592
593
594
595 public void runPendingTasks() {
596 try {
597 embeddedEventLoop().runTasks();
598 } catch (Exception e) {
599 recordException(e);
600 }
601
602 try {
603 embeddedEventLoop().runScheduledTasks();
604 } catch (Exception e) {
605 recordException(e);
606 }
607 }
608
609
610
611
612
613
614 public long runScheduledPendingTasks() {
615 try {
616 return embeddedEventLoop().runScheduledTasks();
617 } catch (Exception e) {
618 recordException(e);
619 return embeddedEventLoop().nextScheduledTask();
620 }
621 }
622
623 private void recordException(ChannelFuture future) {
624 if (!future.isSuccess()) {
625 recordException(future.cause());
626 }
627 }
628
629 private void recordException(Throwable cause) {
630 if (lastException == null) {
631 lastException = cause;
632 } else {
633 logger.warn(
634 "More than one exception was raised. " +
635 "Will report only the first one and log others.", cause);
636 }
637 }
638
639
640
641
642
643 public void advanceTimeBy(long duration, TimeUnit unit) {
644 embeddedEventLoop().advanceTimeBy(unit.toNanos(duration));
645 }
646
647
648
649
650
651
652 public void freezeTime() {
653 embeddedEventLoop().freezeTime();
654 }
655
656
657
658
659
660
661
662
663 public void unfreezeTime() {
664 embeddedEventLoop().unfreezeTime();
665 }
666
667
668
669
670 private ChannelFuture checkException(ChannelPromise promise) {
671 Throwable t = lastException;
672 if (t != null) {
673 lastException = null;
674
675 if (promise.isVoid()) {
676 PlatformDependent.throwException(t);
677 }
678
679 return promise.setFailure(t);
680 }
681
682 return promise.setSuccess();
683 }
684
685
686
687
688 public void checkException() {
689 checkException(voidPromise());
690 }
691
692
693
694
695
696 private boolean checkOpen(boolean recordException) {
697 if (!isOpen()) {
698 if (recordException) {
699 recordException(new ClosedChannelException());
700 }
701 return false;
702 }
703
704 return true;
705 }
706
707 private EmbeddedEventLoop embeddedEventLoop() {
708 if (isRegistered()) {
709 return (EmbeddedEventLoop) super.eventLoop();
710 }
711
712 return loop;
713 }
714
715
716
717
718 protected final void ensureOpen() {
719 if (!checkOpen(true)) {
720 checkException();
721 }
722 }
723
724 @Override
725 protected boolean isCompatible(EventLoop loop) {
726 return loop instanceof EmbeddedEventLoop;
727 }
728
729 @Override
730 protected SocketAddress localAddress0() {
731 return isActive()? LOCAL_ADDRESS : null;
732 }
733
734 @Override
735 protected SocketAddress remoteAddress0() {
736 return isActive()? REMOTE_ADDRESS : null;
737 }
738
739 @Override
740 protected void doRegister() throws Exception {
741 state = State.ACTIVE;
742 }
743
744 @Override
745 protected void doBind(SocketAddress localAddress) throws Exception {
746
747 }
748
749 @Override
750 protected void doDisconnect() throws Exception {
751 if (!metadata.hasDisconnect()) {
752 doClose();
753 }
754 }
755
756 @Override
757 protected void doClose() throws Exception {
758 state = State.CLOSED;
759 }
760
761 @Override
762 protected void doBeginRead() throws Exception {
763
764 }
765
766 @Override
767 protected AbstractUnsafe newUnsafe() {
768 return new EmbeddedUnsafe();
769 }
770
771 @Override
772 public Unsafe unsafe() {
773 return ((EmbeddedUnsafe) super.unsafe()).wrapped;
774 }
775
776 @Override
777 protected void doWrite(ChannelOutboundBuffer in) throws Exception {
778 for (;;) {
779 Object msg = in.current();
780 if (msg == null) {
781 break;
782 }
783
784 ReferenceCountUtil.retain(msg);
785 handleOutboundMessage(msg);
786 in.remove();
787 }
788 }
789
790
791
792
793
794
795 protected void handleOutboundMessage(Object msg) {
796 outboundMessages().add(msg);
797 }
798
799
800
801
802 protected void handleInboundMessage(Object msg) {
803 inboundMessages().add(msg);
804 }
805
806 private final class EmbeddedUnsafe extends AbstractUnsafe {
807
808
809
810 final Unsafe wrapped = new Unsafe() {
811 @Override
812 public RecvByteBufAllocator.Handle recvBufAllocHandle() {
813 return EmbeddedUnsafe.this.recvBufAllocHandle();
814 }
815
816 @Override
817 public SocketAddress localAddress() {
818 return EmbeddedUnsafe.this.localAddress();
819 }
820
821 @Override
822 public SocketAddress remoteAddress() {
823 return EmbeddedUnsafe.this.remoteAddress();
824 }
825
826 @Override
827 public void register(EventLoop eventLoop, ChannelPromise promise) {
828 EmbeddedUnsafe.this.register(eventLoop, promise);
829 runPendingTasks();
830 }
831
832 @Override
833 public void bind(SocketAddress localAddress, ChannelPromise promise) {
834 EmbeddedUnsafe.this.bind(localAddress, promise);
835 runPendingTasks();
836 }
837
838 @Override
839 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
840 EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise);
841 runPendingTasks();
842 }
843
844 @Override
845 public void disconnect(ChannelPromise promise) {
846 EmbeddedUnsafe.this.disconnect(promise);
847 runPendingTasks();
848 }
849
850 @Override
851 public void close(ChannelPromise promise) {
852 EmbeddedUnsafe.this.close(promise);
853 runPendingTasks();
854 }
855
856 @Override
857 public void closeForcibly() {
858 EmbeddedUnsafe.this.closeForcibly();
859 runPendingTasks();
860 }
861
862 @Override
863 public void deregister(ChannelPromise promise) {
864 EmbeddedUnsafe.this.deregister(promise);
865 runPendingTasks();
866 }
867
868 @Override
869 public void beginRead() {
870 EmbeddedUnsafe.this.beginRead();
871 runPendingTasks();
872 }
873
874 @Override
875 public void write(Object msg, ChannelPromise promise) {
876 EmbeddedUnsafe.this.write(msg, promise);
877 runPendingTasks();
878 }
879
880 @Override
881 public void flush() {
882 EmbeddedUnsafe.this.flush();
883 runPendingTasks();
884 }
885
886 @Override
887 public ChannelPromise voidPromise() {
888 return EmbeddedUnsafe.this.voidPromise();
889 }
890
891 @Override
892 public ChannelOutboundBuffer outboundBuffer() {
893 return EmbeddedUnsafe.this.outboundBuffer();
894 }
895 };
896
897 @Override
898 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
899 safeSetSuccess(promise);
900 }
901 }
902
903 private final class EmbeddedChannelPipeline extends DefaultChannelPipeline {
904 EmbeddedChannelPipeline(EmbeddedChannel channel) {
905 super(channel);
906 }
907
908 @Override
909 protected void onUnhandledInboundException(Throwable cause) {
910 recordException(cause);
911 }
912
913 @Override
914 protected void onUnhandledInboundMessage(ChannelHandlerContext ctx, Object msg) {
915 handleInboundMessage(msg);
916 }
917 }
918 }