1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.embedded;
17
18 import io.netty.channel.AbstractChannel;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelConfig;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelHandler;
24 import io.netty.channel.ChannelInitializer;
25 import io.netty.channel.ChannelMetadata;
26 import io.netty.channel.ChannelOutboundBuffer;
27 import io.netty.channel.ChannelPipeline;
28 import io.netty.channel.ChannelPromise;
29 import io.netty.channel.DefaultChannelConfig;
30 import io.netty.channel.DefaultChannelPipeline;
31 import io.netty.channel.EventLoop;
32 import io.netty.util.ReferenceCountUtil;
33 import io.netty.util.internal.ObjectUtil;
34 import io.netty.util.internal.PlatformDependent;
35 import io.netty.util.internal.RecyclableArrayList;
36 import io.netty.util.internal.UnstableApi;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39
40 import java.net.SocketAddress;
41 import java.nio.channels.ClosedChannelException;
42 import java.util.ArrayDeque;
43 import java.util.Queue;
44
45
46
47
48 public class EmbeddedChannel extends AbstractChannel {
49
50 private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class);
51
52 private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false);
53 private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true);
54
55 private final EmbeddedEventLoop loop = new EmbeddedEventLoop();
56 private final ChannelFutureListener recordExceptionListener = new ChannelFutureListener() {
57 @Override
58 public void operationComplete(ChannelFuture future) throws Exception {
59 recordException(future);
60 }
61 };
62
63 private final ChannelMetadata metadata;
64 private final ChannelConfig config;
65 private final SocketAddress localAddress = new EmbeddedSocketAddress();
66 private final SocketAddress remoteAddress = new EmbeddedSocketAddress();
67
68 private Queue<Object> inboundMessages;
69 private Queue<Object> outboundMessages;
70 private Throwable lastException;
71 private int state;
72
73
74
75
76
77
78 public EmbeddedChannel(final ChannelHandler... handlers) {
79 this(false, handlers);
80 }
81
82
83
84
85
86
87
88
89
90 public EmbeddedChannel(boolean hasDisconnect, final ChannelHandler... handlers) {
91 this(true, hasDisconnect, handlers);
92 }
93
94
95
96
97
98
99
100
101
102
103 public EmbeddedChannel(boolean register, boolean hasDisconnect, ChannelHandler... handlers) {
104 super(null);
105 metadata = metadata(hasDisconnect);
106 config = new DefaultChannelConfig(this);
107 setup(register, handlers);
108 }
109
110
111
112
113
114
115
116
117
118
119 public EmbeddedChannel(boolean hasDisconnect, final ChannelConfig config,
120 final ChannelHandler... handlers) {
121 super(null);
122 metadata = metadata(hasDisconnect);
123 this.config = ObjectUtil.checkNotNull(config, "config");
124 setup(true, handlers);
125 }
126
127 private static ChannelMetadata metadata(boolean hasDisconnect) {
128 return hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
129 }
130
131 private void setup(boolean register, final ChannelHandler... handlers) {
132 ObjectUtil.checkNotNull(handlers, "handlers");
133 ChannelPipeline p = pipeline();
134 p.addLast(new ChannelInitializer<Channel>() {
135 @Override
136 protected void initChannel(Channel ch) throws Exception {
137 ChannelPipeline pipeline = ch.pipeline();
138 for (ChannelHandler h: handlers) {
139 if (h == null) {
140 break;
141 }
142 pipeline.addLast(h);
143 }
144 }
145 });
146 if (register) {
147 ChannelFuture future = loop.register(this);
148 assert future.isDone();
149 }
150 }
151
152
153
154
155 public void register() throws Exception {
156 ChannelFuture future = loop.register(this);
157 assert future.isDone();
158 Throwable cause = future.cause();
159 if (cause != null) {
160 PlatformDependent.throwException(cause);
161 }
162 }
163
164 @Override
165 protected final DefaultChannelPipeline newChannelPipeline() {
166 return new EmbeddedChannelPipeline(this);
167 }
168
169 @Override
170 public ChannelMetadata metadata() {
171 return metadata;
172 }
173
174 @Override
175 public ChannelConfig config() {
176 return config;
177 }
178
179 @Override
180 public boolean isOpen() {
181 return state < 2;
182 }
183
184 @Override
185 public boolean isActive() {
186 return state == 1;
187 }
188
189
190
191
192 public Queue<Object> inboundMessages() {
193 if (inboundMessages == null) {
194 inboundMessages = new ArrayDeque<Object>();
195 }
196 return inboundMessages;
197 }
198
199
200
201
202 @Deprecated
203 public Queue<Object> lastInboundBuffer() {
204 return inboundMessages();
205 }
206
207
208
209
210 public Queue<Object> outboundMessages() {
211 if (outboundMessages == null) {
212 outboundMessages = new ArrayDeque<Object>();
213 }
214 return outboundMessages;
215 }
216
217
218
219
220 @Deprecated
221 public Queue<Object> lastOutboundBuffer() {
222 return outboundMessages();
223 }
224
225
226
227
228 public Object readInbound() {
229 return poll(inboundMessages);
230 }
231
232
233
234
235 public Object readOutbound() {
236 return poll(outboundMessages);
237 }
238
239
240
241
242
243
244
245
246 public boolean writeInbound(Object... msgs) {
247 ensureOpen();
248 if (msgs.length == 0) {
249 return isNotEmpty(inboundMessages);
250 }
251
252 ChannelPipeline p = pipeline();
253 for (Object m: msgs) {
254 p.fireChannelRead(m);
255 }
256 p.fireChannelReadComplete();
257 runPendingTasks();
258 checkException();
259 return isNotEmpty(inboundMessages);
260 }
261
262
263
264
265
266
267
268 public boolean writeOutbound(Object... msgs) {
269 ensureOpen();
270 if (msgs.length == 0) {
271 return isNotEmpty(outboundMessages);
272 }
273
274 RecyclableArrayList futures = RecyclableArrayList.newInstance(msgs.length);
275 try {
276 for (Object m: msgs) {
277 if (m == null) {
278 break;
279 }
280 futures.add(write(m));
281 }
282
283
284 runPendingTasks();
285 flush();
286
287 int size = futures.size();
288 for (int i = 0; i < size; i++) {
289 ChannelFuture future = (ChannelFuture) futures.get(i);
290 if (future.isDone()) {
291 recordException(future);
292 } else {
293
294 future.addListener(recordExceptionListener);
295 }
296 }
297
298 checkException();
299 return isNotEmpty(outboundMessages);
300 } finally {
301 futures.recycle();
302 }
303 }
304
305
306
307
308
309
310 public boolean finish() {
311 return finish(false);
312 }
313
314
315
316
317
318
319
320 public boolean finishAndReleaseAll() {
321 return finish(true);
322 }
323
324
325
326
327
328
329
330 private boolean finish(boolean releaseAll) {
331 close();
332 try {
333 checkException();
334 return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages);
335 } finally {
336 if (releaseAll) {
337 releaseAll(inboundMessages);
338 releaseAll(outboundMessages);
339 }
340 }
341 }
342
343
344
345
346
347 public boolean releaseInbound() {
348 return releaseAll(inboundMessages);
349 }
350
351
352
353
354
355 public boolean releaseOutbound() {
356 return releaseAll(outboundMessages);
357 }
358
359 private static boolean releaseAll(Queue<Object> queue) {
360 if (isNotEmpty(queue)) {
361 for (;;) {
362 Object msg = queue.poll();
363 if (msg == null) {
364 break;
365 }
366 ReferenceCountUtil.release(msg);
367 }
368 return true;
369 }
370 return false;
371 }
372
373 private void finishPendingTasks(boolean cancel) {
374 runPendingTasks();
375 if (cancel) {
376
377 loop.cancelScheduledTasks();
378 }
379 }
380
381 @Override
382 public final ChannelFuture close() {
383 return close(newPromise());
384 }
385
386 @Override
387 public final ChannelFuture disconnect() {
388 return disconnect(newPromise());
389 }
390
391 @Override
392 public final ChannelFuture close(ChannelPromise promise) {
393
394
395 runPendingTasks();
396 ChannelFuture future = super.close(promise);
397
398
399 finishPendingTasks(true);
400 return future;
401 }
402
403 @Override
404 public final ChannelFuture disconnect(ChannelPromise promise) {
405 ChannelFuture future = super.disconnect(promise);
406 finishPendingTasks(!metadata.hasDisconnect());
407 return future;
408 }
409
410 private static boolean isNotEmpty(Queue<Object> queue) {
411 return queue != null && !queue.isEmpty();
412 }
413
414 private static Object poll(Queue<Object> queue) {
415 return queue != null ? queue.poll() : null;
416 }
417
418
419
420
421
422 public void runPendingTasks() {
423 try {
424 loop.runTasks();
425 } catch (Exception e) {
426 recordException(e);
427 }
428
429 try {
430 loop.runScheduledTasks();
431 } catch (Exception e) {
432 recordException(e);
433 }
434 }
435
436
437
438
439
440
441 public long runScheduledPendingTasks() {
442 try {
443 return loop.runScheduledTasks();
444 } catch (Exception e) {
445 recordException(e);
446 return loop.nextScheduledTask();
447 }
448 }
449
450 private void recordException(ChannelFuture future) {
451 if (!future.isSuccess()) {
452 recordException(future.cause());
453 }
454 }
455
456 private void recordException(Throwable cause) {
457 if (lastException == null) {
458 lastException = cause;
459 } else {
460 logger.warn(
461 "More than one exception was raised. " +
462 "Will report only the first one and log others.", cause);
463 }
464 }
465
466
467
468
469 public void checkException() {
470 Throwable t = lastException;
471 if (t == null) {
472 return;
473 }
474
475 lastException = null;
476
477 PlatformDependent.throwException(t);
478 }
479
480
481
482
483 protected final void ensureOpen() {
484 if (!isOpen()) {
485 recordException(new ClosedChannelException());
486 checkException();
487 }
488 }
489
490 @Override
491 protected boolean isCompatible(EventLoop loop) {
492 return loop instanceof EmbeddedEventLoop;
493 }
494
495 @Override
496 protected SocketAddress localAddress0() {
497 return isActive()? localAddress : null;
498 }
499
500 @Override
501 protected SocketAddress remoteAddress0() {
502 return isActive()? remoteAddress : null;
503 }
504
505 @Override
506 protected void doRegister() throws Exception {
507 state = 1;
508 }
509
510 @Override
511 protected void doBind(SocketAddress localAddress) throws Exception {
512
513 }
514
515 @Override
516 protected void doDisconnect() throws Exception {
517 if (!metadata.hasDisconnect()) {
518 doClose();
519 }
520 }
521
522 @Override
523 protected void doClose() throws Exception {
524 state = 2;
525 }
526
527 @Override
528 protected void doBeginRead() throws Exception {
529
530 }
531
532 @Override
533 protected AbstractUnsafe newUnsafe() {
534 return new EmbeddedUnsafe();
535 }
536
537 @Override
538 public Unsafe unsafe() {
539 return ((EmbeddedUnsafe) super.unsafe()).wrapped;
540 }
541
542 @Override
543 protected void doWrite(ChannelOutboundBuffer in) throws Exception {
544 for (;;) {
545 Object msg = in.current();
546 if (msg == null) {
547 break;
548 }
549
550 ReferenceCountUtil.retain(msg);
551 outboundMessages().add(msg);
552 in.remove();
553 }
554 }
555
556 private final class EmbeddedUnsafe extends AbstractUnsafe {
557
558
559
560 final Unsafe wrapped = new Unsafe() {
561 @Override
562 public SocketAddress localAddress() {
563 return EmbeddedUnsafe.this.localAddress();
564 }
565
566 @Override
567 public SocketAddress remoteAddress() {
568 return EmbeddedUnsafe.this.remoteAddress();
569 }
570
571 @Override
572 public void register(EventLoop eventLoop, ChannelPromise promise) {
573 EmbeddedUnsafe.this.register(eventLoop, promise);
574 runPendingTasks();
575 }
576
577 @Override
578 public void bind(SocketAddress localAddress, ChannelPromise promise) {
579 EmbeddedUnsafe.this.bind(localAddress, promise);
580 runPendingTasks();
581 }
582
583 @Override
584 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
585 EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise);
586 runPendingTasks();
587 }
588
589 @Override
590 public void disconnect(ChannelPromise promise) {
591 EmbeddedUnsafe.this.disconnect(promise);
592 runPendingTasks();
593 }
594
595 @Override
596 public void close(ChannelPromise promise) {
597 EmbeddedUnsafe.this.close(promise);
598 runPendingTasks();
599 }
600
601 @Override
602 public void closeForcibly() {
603 EmbeddedUnsafe.this.closeForcibly();
604 runPendingTasks();
605 }
606
607 @Override
608 public void deregister(ChannelPromise promise) {
609 EmbeddedUnsafe.this.deregister(promise);
610 runPendingTasks();
611 }
612
613 @Override
614 public void beginRead() {
615 EmbeddedUnsafe.this.beginRead();
616 runPendingTasks();
617 }
618
619 @Override
620 public void write(Object msg, ChannelPromise promise) {
621 EmbeddedUnsafe.this.write(msg, promise);
622 runPendingTasks();
623 }
624
625 @Override
626 public void flush() {
627 EmbeddedUnsafe.this.flush();
628 runPendingTasks();
629 }
630
631 @Override
632 public ChannelPromise voidPromise() {
633 return EmbeddedUnsafe.this.voidPromise();
634 }
635
636 @Override
637 public ChannelOutboundBuffer outboundBuffer() {
638 return EmbeddedUnsafe.this.outboundBuffer();
639 }
640 };
641
642 @Override
643 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
644 safeSetSuccess(promise);
645 }
646 }
647
648 private final class EmbeddedChannelPipeline extends DefaultChannelPipeline {
649 EmbeddedChannelPipeline(EmbeddedChannel channel) {
650 super(channel);
651 }
652
653 @Override
654 protected void onUnhandledInboundException(Throwable cause) {
655 recordException(cause);
656 }
657
658 @Override
659 protected void onUnhandledInboundMessage(Object msg) {
660 inboundMessages().add(msg);
661 }
662 }
663 }