1
2
3
4
5
6
7
8
9
10
11
12
13 package io.netty5.testsuite.transport.socket;
14
15 import io.netty5.bootstrap.Bootstrap;
16 import io.netty5.bootstrap.ServerBootstrap;
17 import io.netty5.buffer.api.Buffer;
18 import io.netty5.channel.Channel;
19 import io.netty5.channel.ChannelHandlerContext;
20 import io.netty5.channel.ChannelInitializer;
21 import io.netty5.channel.ChannelOption;
22 import io.netty5.channel.SimpleChannelInboundHandler;
23 import io.netty5.channel.socket.SocketChannel;
24 import io.netty5.handler.traffic.AbstractTrafficShapingHandler;
25 import io.netty5.handler.traffic.ChannelTrafficShapingHandler;
26 import io.netty5.handler.traffic.GlobalTrafficShapingHandler;
27 import io.netty5.handler.traffic.TrafficCounter;
28 import io.netty5.util.concurrent.DefaultEventExecutorGroup;
29 import io.netty5.util.concurrent.EventExecutorGroup;
30 import io.netty5.util.concurrent.Promise;
31 import io.netty5.util.internal.logging.InternalLogger;
32 import io.netty5.util.internal.logging.InternalLoggerFactory;
33 import org.junit.jupiter.api.AfterAll;
34 import org.junit.jupiter.api.BeforeAll;
35 import org.junit.jupiter.api.Test;
36 import org.junit.jupiter.api.TestInfo;
37 import org.junit.jupiter.api.Timeout;
38
39 import java.io.IOException;
40 import java.util.Arrays;
41 import java.util.Random;
42 import java.util.concurrent.Executors;
43 import java.util.concurrent.ScheduledExecutorService;
44 import java.util.concurrent.TimeUnit;
45 import java.util.concurrent.atomic.AtomicReference;
46
47 import static org.junit.jupiter.api.Assertions.assertTrue;
48
49 public class TrafficShapingHandlerTest extends AbstractSocketTest {
50 private static final InternalLogger logger = InternalLoggerFactory.getInstance(TrafficShapingHandlerTest.class);
51 private static final InternalLogger loggerServer = InternalLoggerFactory.getInstance("ServerTSH");
52 private static final InternalLogger loggerClient = InternalLoggerFactory.getInstance("ClientTSH");
53
54 static final int messageSize = 1024;
55 static final int bandwidthFactor = 12;
56 static final int minfactor = 3;
57 static final int maxfactor = bandwidthFactor + bandwidthFactor / 2;
58 static final long stepms = (1000 / bandwidthFactor - 10) / 10 * 10;
59 static final long minimalms = Math.max(stepms / 2, 20) / 10 * 10;
60 static final long check = 10;
61 private static final Random random = new Random();
62 static final byte[] data = new byte[messageSize];
63
64 private static final String TRAFFIC = "traffic";
65 private static String currentTestName;
66 private static int currentTestRun;
67
68 private static EventExecutorGroup group;
69 private static EventExecutorGroup groupForGlobal;
70 private static final ScheduledExecutorService executor = Executors.newScheduledThreadPool(10);
71 static {
72 random.nextBytes(data);
73 }
74
75 @BeforeAll
76 public static void createGroup() {
77 logger.info("Bandwidth: " + minfactor + " <= " + bandwidthFactor + " <= " + maxfactor +
78 " StepMs: " + stepms + " MinMs: " + minimalms + " CheckMs: " + check);
79 group = new DefaultEventExecutorGroup(8);
80 groupForGlobal = new DefaultEventExecutorGroup(8);
81 }
82
83 @AfterAll
84 public static void destroyGroup() throws Exception {
85 group.shutdownGracefully().asStage().sync();
86 groupForGlobal.shutdownGracefully().asStage().sync();
87 executor.shutdown();
88 }
89
90 private static long[] computeWaitRead(int[] multipleMessage) {
91 long[] minimalWaitBetween = new long[multipleMessage.length + 1];
92 minimalWaitBetween[0] = 0;
93 for (int i = 0; i < multipleMessage.length; i++) {
94 if (multipleMessage[i] > 1) {
95 minimalWaitBetween[i + 1] = (multipleMessage[i] - 1) * stepms + minimalms;
96 } else {
97 minimalWaitBetween[i + 1] = 10;
98 }
99 }
100 return minimalWaitBetween;
101 }
102
103 private static long[] computeWaitWrite(int[] multipleMessage) {
104 long[] minimalWaitBetween = new long[multipleMessage.length + 1];
105 for (int i = 0; i < multipleMessage.length; i++) {
106 if (multipleMessage[i] > 1) {
107 minimalWaitBetween[i] = (multipleMessage[i] - 1) * stepms + minimalms;
108 } else {
109 minimalWaitBetween[i] = 10;
110 }
111 }
112 return minimalWaitBetween;
113 }
114
115 private static long[] computeWaitAutoRead(int []autoRead) {
116 long [] minimalWaitBetween = new long[autoRead.length + 1];
117 minimalWaitBetween[0] = 0;
118 for (int i = 0; i < autoRead.length; i++) {
119 if (autoRead[i] != 0) {
120 if (autoRead[i] > 0) {
121 minimalWaitBetween[i + 1] = -1;
122 } else {
123 minimalWaitBetween[i + 1] = check;
124 }
125 } else {
126 minimalWaitBetween[i + 1] = 0;
127 }
128 }
129 return minimalWaitBetween;
130 }
131
132 @Test
133 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
134 public void testNoTrafficShaping(TestInfo testInfo) throws Throwable {
135 currentTestName = "TEST NO TRAFFIC";
136 currentTestRun = 0;
137 run(testInfo, this::testNoTrafficShaping);
138 }
139
140 public void testNoTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
141 int[] autoRead = null;
142 int[] multipleMessage = { 1, 2, 1 };
143 long[] minimalWaitBetween = null;
144 testTrafficShaping0(sb, cb, false, false, false, false, autoRead, minimalWaitBetween, multipleMessage);
145 }
146
147 @Test
148 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
149 public void testWriteTrafficShaping(TestInfo testInfo) throws Throwable {
150 currentTestName = "TEST WRITE";
151 currentTestRun = 0;
152 run(testInfo, this::testWriteTrafficShaping);
153 }
154
155 public void testWriteTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
156 int[] autoRead = null;
157 int[] multipleMessage = { 1, 2, 1, 1 };
158 long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
159 testTrafficShaping0(sb, cb, false, false, true, false, autoRead, minimalWaitBetween, multipleMessage);
160 }
161
162 @Test
163 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
164 public void testReadTrafficShaping(TestInfo testInfo) throws Throwable {
165 currentTestName = "TEST READ";
166 currentTestRun = 0;
167 run(testInfo, this::testReadTrafficShaping);
168 }
169
170 public void testReadTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
171 int[] autoRead = null;
172 int[] multipleMessage = { 1, 2, 1, 1 };
173 long[] minimalWaitBetween = computeWaitRead(multipleMessage);
174 testTrafficShaping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
175 }
176
177 @Test
178 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
179 public void testWrite1TrafficShaping(TestInfo testInfo) throws Throwable {
180 currentTestName = "TEST WRITE";
181 currentTestRun = 0;
182 run(testInfo, this::testWrite1TrafficShaping);
183 }
184
185 public void testWrite1TrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
186 int[] autoRead = null;
187 int[] multipleMessage = { 1, 1, 1 };
188 long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
189 testTrafficShaping0(sb, cb, false, false, true, false, autoRead, minimalWaitBetween, multipleMessage);
190 }
191
192 @Test
193 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
194 public void testRead1TrafficShaping(TestInfo testInfo) throws Throwable {
195 currentTestName = "TEST READ";
196 currentTestRun = 0;
197 run(testInfo, this::testRead1TrafficShaping);
198 }
199
200 public void testRead1TrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
201 int[] autoRead = null;
202 int[] multipleMessage = { 1, 1, 1 };
203 long[] minimalWaitBetween = computeWaitRead(multipleMessage);
204 testTrafficShaping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
205 }
206
207 @Test
208 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
209 public void testWriteGlobalTrafficShaping(TestInfo testInfo) throws Throwable {
210 currentTestName = "TEST GLOBAL WRITE";
211 currentTestRun = 0;
212 run(testInfo, this::testWriteGlobalTrafficShaping);
213 }
214
215 public void testWriteGlobalTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
216 int[] autoRead = null;
217 int[] multipleMessage = { 1, 2, 1, 1 };
218 long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
219 testTrafficShaping0(sb, cb, false, false, true, true, autoRead, minimalWaitBetween, multipleMessage);
220 }
221
222 @Test
223 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
224 public void testReadGlobalTrafficShaping(TestInfo testInfo) throws Throwable {
225 currentTestName = "TEST GLOBAL READ";
226 currentTestRun = 0;
227 run(testInfo, this::testReadGlobalTrafficShaping);
228 }
229
230 public void testReadGlobalTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
231 int[] autoRead = null;
232 int[] multipleMessage = { 1, 2, 1, 1 };
233 long[] minimalWaitBetween = computeWaitRead(multipleMessage);
234 testTrafficShaping0(sb, cb, false, true, false, true, autoRead, minimalWaitBetween, multipleMessage);
235 }
236
237 @Test
238 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
239 public void testAutoReadTrafficShaping(TestInfo testInfo) throws Throwable {
240 currentTestName = "TEST AUTO READ";
241 currentTestRun = 0;
242 run(testInfo, this::testAutoReadTrafficShaping);
243 }
244
245 public void testAutoReadTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
246 int[] autoRead = { 1, -1, -1, 1, -2, 0, 1, 0, -3, 0, 1, 2, 0 };
247 int[] multipleMessage = new int[autoRead.length];
248 Arrays.fill(multipleMessage, 1);
249 long[] minimalWaitBetween = computeWaitAutoRead(autoRead);
250 testTrafficShaping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
251 }
252
253 @Test
254 @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
255 public void testAutoReadGlobalTrafficShaping(TestInfo testInfo) throws Throwable {
256 currentTestName = "TEST AUTO READ GLOBAL";
257 currentTestRun = 0;
258 run(testInfo, this::testAutoReadGlobalTrafficShaping);
259 }
260
261 public void testAutoReadGlobalTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
262 int[] autoRead = { 1, -1, -1, 1, -2, 0, 1, 0, -3, 0, 1, 2, 0 };
263 int[] multipleMessage = new int[autoRead.length];
264 Arrays.fill(multipleMessage, 1);
265 long[] minimalWaitBetween = computeWaitAutoRead(autoRead);
266 testTrafficShaping0(sb, cb, false, true, false, true, autoRead, minimalWaitBetween, multipleMessage);
267 }
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287 private static void testTrafficShaping0(
288 ServerBootstrap sb, Bootstrap cb, final boolean additionalExecutor,
289 final boolean limitRead, final boolean limitWrite, final boolean globalLimit, int[] autoRead,
290 long[] minimalWaitBetween, int[] multipleMessage) throws Throwable {
291
292 currentTestRun++;
293 logger.info("TEST: " + currentTestName + " RUN: " + currentTestRun +
294 " Exec: " + additionalExecutor + " Read: " + limitRead + " Write: " + limitWrite + " Global: "
295 + globalLimit);
296 final ServerHandler sh = new ServerHandler(autoRead, multipleMessage);
297 Promise<Boolean> promise = group.next().newPromise();
298 final ClientHandler ch = new ClientHandler(promise, minimalWaitBetween, multipleMessage,
299 autoRead);
300
301 final AbstractTrafficShapingHandler handler;
302 if (limitRead) {
303 if (globalLimit) {
304 handler = new GlobalTrafficShapingHandler(groupForGlobal, 0, bandwidthFactor * messageSize, check);
305 } else {
306 handler = new ChannelTrafficShapingHandler(0, bandwidthFactor * messageSize, check);
307 }
308 } else if (limitWrite) {
309 if (globalLimit) {
310 handler = new GlobalTrafficShapingHandler(groupForGlobal, bandwidthFactor * messageSize, 0, check);
311 } else {
312 handler = new ChannelTrafficShapingHandler(bandwidthFactor * messageSize, 0, check);
313 }
314 } else {
315 handler = null;
316 }
317
318 sb.childHandler(new ChannelInitializer<SocketChannel>() {
319 @Override
320 protected void initChannel(SocketChannel c) throws Exception {
321 if (limitRead) {
322 c.pipeline().addLast(TRAFFIC, handler);
323 }
324 c.pipeline().addLast(sh);
325 }
326 });
327 cb.handler(new ChannelInitializer<SocketChannel>() {
328 @Override
329 protected void initChannel(SocketChannel c) throws Exception {
330 if (limitWrite) {
331 c.pipeline().addLast(TRAFFIC, handler);
332 }
333 c.pipeline().addLast(ch);
334 }
335 });
336
337 Channel sc = sb.bind().asStage().get();
338 Channel cc = cb.connect(sc.localAddress()).asStage().get();
339
340 int totalNb = 0;
341 for (int i = 1; i < multipleMessage.length; i++) {
342 totalNb += multipleMessage[i];
343 }
344 Long start = TrafficCounter.milliSecondFromNano();
345 int nb = multipleMessage[0];
346 for (int i = 0; i < nb; i++) {
347 cc.write(cc.bufferAllocator().copyOf(data));
348 }
349 cc.flush();
350
351 promise.asFuture().asStage().await();
352 Long stop = TrafficCounter.milliSecondFromNano();
353 assertTrue(promise.isSuccess(), "Error during execution of TrafficShapping: " + promise.cause());
354
355 float average = (totalNb * messageSize) / (float) (stop - start);
356 logger.info("TEST: " + currentTestName + " RUN: " + currentTestRun +
357 " Average of traffic: " + average + " compare to " + bandwidthFactor);
358 sh.channel.close().asStage().sync();
359 ch.channel.close().asStage().sync();
360 sc.close().asStage().sync();
361 if (autoRead != null) {
362
363 Thread.sleep(minimalms);
364 }
365
366 if (autoRead == null && minimalWaitBetween != null) {
367 assertTrue(average <= maxfactor,
368 "Overall Traffic not ok since > " + maxfactor + ": " + average);
369 if (additionalExecutor) {
370
371 assertTrue(average >= 0.25, "Overall Traffic not ok since < 0.25: " + average);
372 } else {
373 assertTrue(average >= minfactor,
374 "Overall Traffic not ok since < " + minfactor + ": " + average);
375 }
376 }
377 if (handler != null && globalLimit) {
378 ((GlobalTrafficShapingHandler) handler).release();
379 }
380
381 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
382 throw sh.exception.get();
383 }
384 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
385 throw ch.exception.get();
386 }
387 if (sh.exception.get() != null) {
388 throw sh.exception.get();
389 }
390 if (ch.exception.get() != null) {
391 throw ch.exception.get();
392 }
393 }
394
395 private static class ClientHandler extends SimpleChannelInboundHandler<Buffer> {
396 volatile Channel channel;
397 final AtomicReference<Throwable> exception = new AtomicReference<>();
398 volatile int step;
399
400 private long currentLastTime = TrafficCounter.milliSecondFromNano();
401 private final long[] minimalWaitBetween;
402 private final int[] multipleMessage;
403 private final int[] autoRead;
404 final Promise<Boolean> promise;
405
406 ClientHandler(Promise<Boolean> promise, long[] minimalWaitBetween, int[] multipleMessage,
407 int[] autoRead) {
408 this.minimalWaitBetween = minimalWaitBetween;
409 this.multipleMessage = Arrays.copyOf(multipleMessage, multipleMessage.length);
410 this.promise = promise;
411 this.autoRead = autoRead;
412 }
413
414 @Override
415 public void channelActive(ChannelHandlerContext ctx) throws Exception {
416 channel = ctx.channel();
417 }
418
419 @Override
420 public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
421 long lastTimestamp = 0;
422 loggerClient.debug("Step: " + step + " Read: " + in.readableBytes() / 8 + " blocks");
423 while (in.readableBytes() > 0) {
424 lastTimestamp = in.readLong();
425 multipleMessage[step]--;
426 }
427 if (multipleMessage[step] > 0) {
428
429 return;
430 }
431 long minimalWait = minimalWaitBetween != null? minimalWaitBetween[step] : 0;
432 int ar = 0;
433 if (autoRead != null) {
434 if (step > 0 && autoRead[step - 1] != 0) {
435 ar = autoRead[step - 1];
436 }
437 }
438 loggerClient.info("Step: " + step + " Interval: " + (lastTimestamp - currentLastTime) + " compareTo "
439 + minimalWait + " (" + ar + ')');
440 assertTrue(lastTimestamp - currentLastTime >= minimalWait,
441 "The interval of time is incorrect:" + (lastTimestamp - currentLastTime) + " not> " + minimalWait);
442 currentLastTime = lastTimestamp;
443 step++;
444 if (multipleMessage.length > step) {
445 int nb = multipleMessage[step];
446 for (int i = 0; i < nb; i++) {
447 channel.write(channel.bufferAllocator().copyOf(data));
448 }
449 channel.flush();
450 } else {
451 promise.setSuccess(true);
452 }
453 }
454
455 @Override
456 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
457 if (exception.compareAndSet(null, cause)) {
458 cause.printStackTrace();
459 promise.setFailure(cause);
460 ctx.close();
461 }
462 }
463 }
464
465 private static class ServerHandler extends SimpleChannelInboundHandler<Buffer> {
466 private final int[] autoRead;
467 private final int[] multipleMessage;
468 volatile Channel channel;
469 volatile int step;
470 final AtomicReference<Throwable> exception = new AtomicReference<>();
471
472 ServerHandler(int[] autoRead, int[] multipleMessage) {
473 this.autoRead = autoRead;
474 this.multipleMessage = Arrays.copyOf(multipleMessage, multipleMessage.length);
475 }
476
477 @Override
478 public void channelActive(ChannelHandlerContext ctx) throws Exception {
479 channel = ctx.channel();
480 }
481
482 @Override
483 public void messageReceived(final ChannelHandlerContext ctx, Buffer in) throws Exception {
484 byte[] actual = new byte[in.readableBytes()];
485 int nb = actual.length / messageSize;
486 loggerServer.info("Step: " + step + " Read: " + nb + " blocks");
487 in.readBytes(actual, 0, actual.length);
488 long timestamp = TrafficCounter.milliSecondFromNano();
489 int isAutoRead = 0;
490 int laststep = step;
491 for (int i = 0; i < nb; i++) {
492 multipleMessage[step]--;
493 if (multipleMessage[step] == 0) {
494
495 if (autoRead != null) {
496 isAutoRead = autoRead[step];
497 }
498 step++;
499 }
500 }
501 if (laststep != step) {
502
503 if (autoRead != null && isAutoRead != 2) {
504 if (isAutoRead != 0) {
505 loggerServer.info("Step: " + step + " Set AutoRead: " + (isAutoRead > 0));
506 channel.setOption(ChannelOption.AUTO_READ, isAutoRead > 0);
507 } else {
508 loggerServer.info("Step: " + step + " AutoRead: NO");
509 }
510 }
511 }
512 Thread.sleep(10);
513 loggerServer.debug("Step: " + step + " Write: " + nb);
514 for (int i = 0; i < nb; i++) {
515 channel.write(ctx.bufferAllocator().allocate(8).writeLong(timestamp));
516 }
517 channel.flush();
518 if (laststep != step) {
519
520 if (isAutoRead != 0) {
521 if (isAutoRead < 0) {
522 final int exactStep = step;
523 long wait = isAutoRead == -1? minimalms : stepms + minimalms;
524 if (isAutoRead == -3) {
525 wait = stepms * 3;
526 }
527 executor.schedule(() -> {
528 loggerServer.info("Step: " + exactStep + " Reset AutoRead");
529 channel.setOption(ChannelOption.AUTO_READ, true);
530 }, wait, TimeUnit.MILLISECONDS);
531 } else {
532 if (isAutoRead > 1) {
533 loggerServer.debug("Step: " + step + " Will Set AutoRead: True");
534 final int exactStep = step;
535 executor.schedule(() -> {
536 loggerServer.info("Step: " + exactStep + " Set AutoRead: True");
537 channel.setOption(ChannelOption.AUTO_READ, true);
538 }, stepms + minimalms, TimeUnit.MILLISECONDS);
539 }
540 }
541 }
542 }
543 }
544
545 @Override
546 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
547 if (exception.compareAndSet(null, cause)) {
548 cause.printStackTrace();
549 ctx.close();
550 }
551 }
552 }
553 }