1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.buffer.CompositeByteBuf;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelConfig;
25 import io.netty.channel.ChannelFutureListener;
26 import io.netty.channel.ChannelHandlerContext;
27 import io.netty.channel.ChannelInboundHandlerAdapter;
28 import io.netty.channel.ChannelInitializer;
29 import io.netty.channel.ChannelOption;
30 import io.netty.util.ReferenceCountUtil;
31 import org.junit.jupiter.api.Test;
32 import org.junit.jupiter.api.TestInfo;
33 import org.junit.jupiter.api.Timeout;
34
35 import java.io.IOException;
36 import java.util.Random;
37 import java.util.concurrent.CountDownLatch;
38 import java.util.concurrent.TimeUnit;
39 import java.util.concurrent.atomic.AtomicReference;
40
41 import static org.junit.jupiter.api.Assertions.assertEquals;
42
43 public class CompositeBufferGatheringWriteTest extends AbstractSocketTest {
44 private static final int EXPECTED_BYTES = 20;
45
46 @Test
47 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
48 public void testSingleCompositeBufferWrite(TestInfo testInfo) throws Throwable {
49 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
50 @Override
51 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
52 testSingleCompositeBufferWrite(serverBootstrap, bootstrap);
53 }
54 });
55 }
56
57 public void testSingleCompositeBufferWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
58 Channel serverChannel = null;
59 Channel clientChannel = null;
60 try {
61 final CountDownLatch latch = new CountDownLatch(1);
62 final AtomicReference<Object> clientReceived = new AtomicReference<Object>();
63 sb.childHandler(new ChannelInitializer<Channel>() {
64 @Override
65 protected void initChannel(Channel ch) throws Exception {
66 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
67 @Override
68 public void channelActive(ChannelHandlerContext ctx) throws Exception {
69 ctx.writeAndFlush(newCompositeBuffer(ctx.alloc()))
70 .addListener(ChannelFutureListener.CLOSE);
71 }
72 });
73 }
74 });
75 cb.handler(new ChannelInitializer<Channel>() {
76 @Override
77 protected void initChannel(Channel ch) throws Exception {
78 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
79 private ByteBuf aggregator;
80 @Override
81 public void handlerAdded(ChannelHandlerContext ctx) {
82 aggregator = ctx.alloc().buffer(EXPECTED_BYTES);
83 }
84
85 @Override
86 public void channelRead(ChannelHandlerContext ctx, Object msg) {
87 try {
88 if (msg instanceof ByteBuf) {
89 aggregator.writeBytes((ByteBuf) msg);
90 }
91 } finally {
92 ReferenceCountUtil.release(msg);
93 }
94 }
95
96 @Override
97 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
98
99 if (!(cause instanceof IOException)) {
100 clientReceived.set(cause);
101 latch.countDown();
102 } else if (!cause.getMessage().contains("reset")) {
103 logger.warn("{} client got weird exception",
104 CompositeBufferGatheringWriteTest.this.getClass(), cause);
105 }
106 }
107
108 @Override
109 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
110 if (clientReceived.compareAndSet(null, aggregator)) {
111 try {
112 assertEquals(EXPECTED_BYTES, aggregator.readableBytes());
113 } catch (Throwable cause) {
114 aggregator.release();
115 aggregator = null;
116 clientReceived.set(cause);
117 } finally {
118 latch.countDown();
119 }
120 }
121 }
122 });
123 }
124 });
125
126 serverChannel = sb.bind().syncUninterruptibly().channel();
127 clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
128
129 ByteBuf expected = newCompositeBuffer(clientChannel.alloc());
130 latch.await();
131 Object received = clientReceived.get();
132 if (received instanceof ByteBuf) {
133 ByteBuf actual = (ByteBuf) received;
134 assertEquals(expected, actual);
135 expected.release();
136 actual.release();
137 } else {
138 expected.release();
139 throw (Throwable) received;
140 }
141 } finally {
142 if (clientChannel != null) {
143 clientChannel.close().sync();
144 }
145 if (serverChannel != null) {
146 serverChannel.close().sync();
147 }
148 }
149 }
150
151 @Test
152 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
153 public void testCompositeBufferPartialWriteDoesNotCorruptData(TestInfo testInfo) throws Throwable {
154 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
155 @Override
156 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
157 testCompositeBufferPartialWriteDoesNotCorruptData(serverBootstrap, bootstrap);
158 }
159 });
160 }
161
162 protected void compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ChannelConfig config,
163 int soSndBuf) {
164 }
165
166 public void testCompositeBufferPartialWriteDoesNotCorruptData(ServerBootstrap sb, Bootstrap cb) throws Throwable {
167
168
169
170
171 Channel serverChannel = null;
172 Channel clientChannel = null;
173 try {
174 Random r = new Random();
175 final int soSndBuf = 1024;
176 ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
177 final ByteBuf expectedContent = alloc.buffer(soSndBuf * 2);
178 expectedContent.writeBytes(newRandomBytes(expectedContent.writableBytes(), r));
179 final CountDownLatch latch = new CountDownLatch(1);
180 final AtomicReference<Object> clientReceived = new AtomicReference<Object>();
181 sb.childOption(ChannelOption.SO_SNDBUF, soSndBuf)
182 .childHandler(new ChannelInitializer<Channel>() {
183 @Override
184 protected void initChannel(Channel ch) throws Exception {
185 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
186 @Override
187 public void channelActive(ChannelHandlerContext ctx) throws Exception {
188 compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ctx.channel().config(),
189 soSndBuf);
190
191 int offset = soSndBuf - 100;
192 ctx.write(expectedContent.retainedSlice(expectedContent.readerIndex(), offset));
193
194
195 CompositeByteBuf compositeByteBuf = ctx.alloc().compositeBuffer();
196 compositeByteBuf.addComponent(true,
197 expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 50));
198 offset += 50;
199 compositeByteBuf.addComponent(true,
200 expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 200));
201 offset += 200;
202 ctx.write(compositeByteBuf);
203
204
205
206 ctx.write(expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 50));
207 offset += 50;
208
209
210 ctx.writeAndFlush(expectedContent.retainedSlice(expectedContent.readerIndex() + offset,
211 expectedContent.readableBytes() - expectedContent.readerIndex() - offset))
212 .addListener(ChannelFutureListener.CLOSE);
213 }
214
215 @Override
216 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
217
218 if (!(cause instanceof IOException)) {
219 clientReceived.set(cause);
220 latch.countDown();
221 }
222 }
223 });
224 }
225 });
226 cb.handler(new ChannelInitializer<Channel>() {
227 @Override
228 protected void initChannel(Channel ch) throws Exception {
229 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
230 private ByteBuf aggregator;
231 @Override
232 public void handlerAdded(ChannelHandlerContext ctx) {
233 aggregator = ctx.alloc().buffer(expectedContent.readableBytes());
234 }
235
236 @Override
237 public void channelRead(ChannelHandlerContext ctx, Object msg) {
238 try {
239 if (msg instanceof ByteBuf) {
240 aggregator.writeBytes((ByteBuf) msg);
241 }
242 } finally {
243 ReferenceCountUtil.release(msg);
244 }
245 }
246
247 @Override
248 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
249
250 if (!(cause instanceof IOException)) {
251 clientReceived.set(cause);
252 latch.countDown();
253 }
254 }
255
256 @Override
257 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
258 if (clientReceived.compareAndSet(null, aggregator)) {
259 try {
260 assertEquals(expectedContent.readableBytes(), aggregator.readableBytes());
261 } catch (Throwable cause) {
262 aggregator.release();
263 aggregator = null;
264 clientReceived.set(cause);
265 } finally {
266 latch.countDown();
267 }
268 }
269 }
270 });
271 }
272 });
273
274 serverChannel = sb.bind().syncUninterruptibly().channel();
275 clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
276
277 latch.await();
278 Object received = clientReceived.get();
279 if (received instanceof ByteBuf) {
280 ByteBuf actual = (ByteBuf) received;
281 assertEquals(expectedContent, actual);
282 expectedContent.release();
283 actual.release();
284 } else {
285 expectedContent.release();
286 throw (Throwable) received;
287 }
288 } finally {
289 if (clientChannel != null) {
290 clientChannel.close().sync();
291 }
292 if (serverChannel != null) {
293 serverChannel.close().sync();
294 }
295 }
296 }
297
298 private static ByteBuf newCompositeBuffer(ByteBufAllocator alloc) {
299 CompositeByteBuf compositeByteBuf = alloc.compositeBuffer();
300 compositeByteBuf.addComponent(true, alloc.directBuffer(4).writeInt(100));
301 compositeByteBuf.addComponent(true, alloc.directBuffer(8).writeLong(123));
302 compositeByteBuf.addComponent(true, alloc.directBuffer(8).writeLong(456));
303 assertEquals(EXPECTED_BYTES, compositeByteBuf.readableBytes());
304 return compositeByteBuf;
305 }
306
307 private static byte[] newRandomBytes(int size, Random r) {
308 byte[] bytes = new byte[size];
309 r.nextBytes(bytes);
310 return bytes;
311 }
312 }