1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.SimpleChannelInboundHandler;
26 import java.util.concurrent.TimeUnit;
27 import org.junit.jupiter.api.Test;
28 import org.junit.jupiter.api.TestInfo;
29 import org.junit.jupiter.api.Timeout;
30
31 import java.io.IOException;
32 import java.util.concurrent.atomic.AtomicReference;
33
34 import static org.junit.jupiter.api.Assertions.assertEquals;
35 import static org.junit.jupiter.api.Assertions.assertTrue;
36
37 public class SocketCancelWriteTest extends AbstractSocketTest {
38
39 @Test
40 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
41 public void testCancelWrite(TestInfo testInfo) throws Throwable {
42 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
43 @Override
44 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
45 testCancelWrite(serverBootstrap, bootstrap);
46 }
47 });
48 }
49
50 public void testCancelWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
51 final TestHandler sh = new TestHandler();
52 final TestHandler ch = new TestHandler();
53 final ByteBuf a = Unpooled.buffer().writeByte('a');
54 final ByteBuf b = Unpooled.buffer().writeByte('b');
55 final ByteBuf c = Unpooled.buffer().writeByte('c');
56 final ByteBuf d = Unpooled.buffer().writeByte('d');
57 final ByteBuf e = Unpooled.buffer().writeByte('e');
58
59 cb.handler(ch);
60 sb.childHandler(sh);
61
62 Channel sc = sb.bind().sync().channel();
63 Channel cc = cb.connect(sc.localAddress()).sync().channel();
64
65 ChannelFuture f = cc.write(a);
66 assertTrue(f.cancel(false));
67 cc.writeAndFlush(b);
68 cc.write(c);
69 ChannelFuture f2 = cc.write(d);
70 assertTrue(f2.cancel(false));
71 cc.writeAndFlush(e);
72
73 while (sh.counter < 3) {
74 if (sh.exception.get() != null) {
75 break;
76 }
77 if (ch.exception.get() != null) {
78 break;
79 }
80 try {
81 Thread.sleep(50);
82 } catch (InterruptedException ignore) {
83
84 }
85 }
86 sh.channel.close().sync();
87 ch.channel.close().sync();
88 sc.close().sync();
89
90 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
91 throw sh.exception.get();
92 }
93 if (sh.exception.get() != null) {
94 throw sh.exception.get();
95 }
96 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
97 throw ch.exception.get();
98 }
99 if (ch.exception.get() != null) {
100 throw ch.exception.get();
101 }
102 assertEquals(0, ch.counter);
103 assertEquals(Unpooled.wrappedBuffer(new byte[]{'b', 'c', 'e'}), sh.received);
104 }
105
106 private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
107 volatile Channel channel;
108 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
109 volatile int counter;
110 final ByteBuf received = Unpooled.buffer();
111 @Override
112 public void channelActive(ChannelHandlerContext ctx)
113 throws Exception {
114 channel = ctx.channel();
115 }
116
117 @Override
118 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
119 counter += in.readableBytes();
120 received.writeBytes(in);
121 }
122
123 @Override
124 public void exceptionCaught(ChannelHandlerContext ctx,
125 Throwable cause) throws Exception {
126 if (exception.compareAndSet(null, cause)) {
127 ctx.close();
128 }
129 }
130 }
131 }