View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    * The Netty Project licenses this file to you under the Apache License,
4    * version 2.0 (the "License"); you may not use this file except in compliance
5    * with the License. You may obtain a copy of the License at:
6    * https://www.apache.org/licenses/LICENSE-2.0
7    * Unless required by applicable law or agreed to in writing, software
8    * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9    * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10   * License for the specific language governing permissions and limitations
11   * under the License.
12   */
13  package io.netty5.testsuite.transport.socket;
14  
15  import io.netty5.bootstrap.Bootstrap;
16  import io.netty5.bootstrap.ServerBootstrap;
17  import io.netty5.buffer.api.Buffer;
18  import io.netty5.channel.Channel;
19  import io.netty5.channel.ChannelHandlerContext;
20  import io.netty5.channel.ChannelInitializer;
21  import io.netty5.channel.ChannelOption;
22  import io.netty5.channel.SimpleChannelInboundHandler;
23  import io.netty5.channel.socket.SocketChannel;
24  import io.netty5.handler.traffic.AbstractTrafficShapingHandler;
25  import io.netty5.handler.traffic.ChannelTrafficShapingHandler;
26  import io.netty5.handler.traffic.GlobalTrafficShapingHandler;
27  import io.netty5.handler.traffic.TrafficCounter;
28  import io.netty5.util.concurrent.DefaultEventExecutorGroup;
29  import io.netty5.util.concurrent.EventExecutorGroup;
30  import io.netty5.util.concurrent.Promise;
31  import io.netty5.util.internal.logging.InternalLogger;
32  import io.netty5.util.internal.logging.InternalLoggerFactory;
33  import org.junit.jupiter.api.AfterAll;
34  import org.junit.jupiter.api.BeforeAll;
35  import org.junit.jupiter.api.Test;
36  import org.junit.jupiter.api.TestInfo;
37  import org.junit.jupiter.api.Timeout;
38  
39  import java.io.IOException;
40  import java.util.Arrays;
41  import java.util.Random;
42  import java.util.concurrent.Executors;
43  import java.util.concurrent.ScheduledExecutorService;
44  import java.util.concurrent.TimeUnit;
45  import java.util.concurrent.atomic.AtomicReference;
46  
47  import static org.junit.jupiter.api.Assertions.assertTrue;
48  
49  public class TrafficShapingHandlerTest extends AbstractSocketTest {
50      private static final InternalLogger logger = InternalLoggerFactory.getInstance(TrafficShapingHandlerTest.class);
51      private static final InternalLogger loggerServer = InternalLoggerFactory.getInstance("ServerTSH");
52      private static final InternalLogger loggerClient = InternalLoggerFactory.getInstance("ClientTSH");
53  
54      static final int messageSize = 1024;
55      static final int bandwidthFactor = 12;
56      static final int minfactor = 3;
57      static final int maxfactor = bandwidthFactor + bandwidthFactor / 2;
58      static final long stepms = (1000 / bandwidthFactor - 10) / 10 * 10;
59      static final long minimalms = Math.max(stepms / 2, 20) / 10 * 10;
60      static final long check = 10;
61      private static final Random random = new Random();
62      static final byte[] data = new byte[messageSize];
63  
64      private static final String TRAFFIC = "traffic";
65      private static String currentTestName;
66      private static int currentTestRun;
67  
68      private static EventExecutorGroup group;
69      private static EventExecutorGroup groupForGlobal;
70      private static final ScheduledExecutorService executor = Executors.newScheduledThreadPool(10);
71      static {
72          random.nextBytes(data);
73      }
74  
75      @BeforeAll
76      public static void createGroup() {
77          logger.info("Bandwidth: " + minfactor + " <= " + bandwidthFactor + " <= " + maxfactor +
78                      " StepMs: " + stepms + " MinMs: " + minimalms + " CheckMs: " + check);
79          group = new DefaultEventExecutorGroup(8);
80          groupForGlobal = new DefaultEventExecutorGroup(8);
81      }
82  
83      @AfterAll
84      public static void destroyGroup() throws Exception {
85          group.shutdownGracefully().asStage().sync();
86          groupForGlobal.shutdownGracefully().asStage().sync();
87          executor.shutdown();
88      }
89  
90      private static long[] computeWaitRead(int[] multipleMessage) {
91          long[] minimalWaitBetween = new long[multipleMessage.length + 1];
92          minimalWaitBetween[0] = 0;
93          for (int i = 0; i < multipleMessage.length; i++) {
94              if (multipleMessage[i] > 1) {
95                  minimalWaitBetween[i + 1] = (multipleMessage[i] - 1) * stepms + minimalms;
96              } else {
97                  minimalWaitBetween[i + 1] = 10;
98              }
99          }
100         return minimalWaitBetween;
101     }
102 
103     private static long[] computeWaitWrite(int[] multipleMessage) {
104         long[] minimalWaitBetween = new long[multipleMessage.length + 1];
105         for (int i = 0; i < multipleMessage.length; i++) {
106             if (multipleMessage[i] > 1) {
107                 minimalWaitBetween[i] = (multipleMessage[i] - 1) * stepms + minimalms;
108             } else {
109                 minimalWaitBetween[i] = 10;
110             }
111         }
112         return minimalWaitBetween;
113     }
114 
115     private static long[] computeWaitAutoRead(int []autoRead) {
116         long [] minimalWaitBetween = new long[autoRead.length + 1];
117         minimalWaitBetween[0] = 0;
118         for (int i = 0; i < autoRead.length; i++) {
119             if (autoRead[i] != 0) {
120                 if (autoRead[i] > 0) {
121                     minimalWaitBetween[i + 1] = -1;
122                 } else {
123                     minimalWaitBetween[i + 1] = check;
124                 }
125             } else {
126                 minimalWaitBetween[i + 1] = 0;
127             }
128         }
129         return minimalWaitBetween;
130     }
131 
132     @Test
133     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
134     public void testNoTrafficShaping(TestInfo testInfo) throws Throwable {
135         currentTestName = "TEST NO TRAFFIC";
136         currentTestRun = 0;
137         run(testInfo, this::testNoTrafficShaping);
138     }
139 
140     public void testNoTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
141         int[] autoRead = null;
142         int[] multipleMessage = { 1, 2, 1 };
143         long[] minimalWaitBetween = null;
144         testTrafficShaping0(sb, cb, false, false, false, false, autoRead, minimalWaitBetween, multipleMessage);
145     }
146 
147     @Test
148     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
149     public void testWriteTrafficShaping(TestInfo testInfo) throws Throwable {
150         currentTestName = "TEST WRITE";
151         currentTestRun = 0;
152         run(testInfo, this::testWriteTrafficShaping);
153     }
154 
155     public void testWriteTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
156         int[] autoRead = null;
157         int[] multipleMessage = { 1, 2, 1, 1 };
158         long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
159         testTrafficShaping0(sb, cb, false, false, true, false, autoRead, minimalWaitBetween, multipleMessage);
160     }
161 
162     @Test
163     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
164     public void testReadTrafficShaping(TestInfo testInfo) throws Throwable {
165         currentTestName = "TEST READ";
166         currentTestRun = 0;
167         run(testInfo, this::testReadTrafficShaping);
168     }
169 
170     public void testReadTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
171         int[] autoRead = null;
172         int[] multipleMessage = { 1, 2, 1, 1 };
173         long[] minimalWaitBetween = computeWaitRead(multipleMessage);
174         testTrafficShaping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
175     }
176 
177     @Test
178     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
179     public void testWrite1TrafficShaping(TestInfo testInfo) throws Throwable {
180         currentTestName = "TEST WRITE";
181         currentTestRun = 0;
182         run(testInfo, this::testWrite1TrafficShaping);
183     }
184 
185     public void testWrite1TrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
186         int[] autoRead = null;
187         int[] multipleMessage = { 1, 1, 1 };
188         long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
189         testTrafficShaping0(sb, cb, false, false, true, false, autoRead, minimalWaitBetween, multipleMessage);
190     }
191 
192     @Test
193     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
194     public void testRead1TrafficShaping(TestInfo testInfo) throws Throwable {
195         currentTestName = "TEST READ";
196         currentTestRun = 0;
197         run(testInfo, this::testRead1TrafficShaping);
198     }
199 
200     public void testRead1TrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
201         int[] autoRead = null;
202         int[] multipleMessage = { 1, 1, 1 };
203         long[] minimalWaitBetween = computeWaitRead(multipleMessage);
204         testTrafficShaping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
205     }
206 
207     @Test
208     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
209     public void testWriteGlobalTrafficShaping(TestInfo testInfo) throws Throwable {
210         currentTestName = "TEST GLOBAL WRITE";
211         currentTestRun = 0;
212         run(testInfo, this::testWriteGlobalTrafficShaping);
213     }
214 
215     public void testWriteGlobalTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
216         int[] autoRead = null;
217         int[] multipleMessage = { 1, 2, 1, 1 };
218         long[] minimalWaitBetween = computeWaitWrite(multipleMessage);
219         testTrafficShaping0(sb, cb, false, false, true, true, autoRead, minimalWaitBetween, multipleMessage);
220     }
221 
222     @Test
223     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
224     public void testReadGlobalTrafficShaping(TestInfo testInfo) throws Throwable {
225         currentTestName = "TEST GLOBAL READ";
226         currentTestRun = 0;
227         run(testInfo, this::testReadGlobalTrafficShaping);
228     }
229 
230     public void testReadGlobalTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
231         int[] autoRead = null;
232         int[] multipleMessage = { 1, 2, 1, 1 };
233         long[] minimalWaitBetween = computeWaitRead(multipleMessage);
234         testTrafficShaping0(sb, cb, false, true, false, true, autoRead, minimalWaitBetween, multipleMessage);
235     }
236 
237     @Test
238     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
239     public void testAutoReadTrafficShaping(TestInfo testInfo) throws Throwable {
240         currentTestName = "TEST AUTO READ";
241         currentTestRun = 0;
242         run(testInfo, this::testAutoReadTrafficShaping);
243     }
244 
245     public void testAutoReadTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
246         int[] autoRead = { 1, -1, -1, 1, -2, 0, 1, 0, -3, 0, 1, 2, 0 };
247         int[] multipleMessage = new int[autoRead.length];
248         Arrays.fill(multipleMessage, 1);
249         long[] minimalWaitBetween = computeWaitAutoRead(autoRead);
250         testTrafficShaping0(sb, cb, false, true, false, false, autoRead, minimalWaitBetween, multipleMessage);
251     }
252 
253     @Test
254     @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS)
255     public void testAutoReadGlobalTrafficShaping(TestInfo testInfo) throws Throwable {
256         currentTestName = "TEST AUTO READ GLOBAL";
257         currentTestRun = 0;
258         run(testInfo, this::testAutoReadGlobalTrafficShaping);
259     }
260 
261     public void testAutoReadGlobalTrafficShaping(ServerBootstrap sb, Bootstrap cb) throws Throwable {
262         int[] autoRead = { 1, -1, -1, 1, -2, 0, 1, 0, -3, 0, 1, 2, 0 };
263         int[] multipleMessage = new int[autoRead.length];
264         Arrays.fill(multipleMessage, 1);
265         long[] minimalWaitBetween = computeWaitAutoRead(autoRead);
266         testTrafficShaping0(sb, cb, false, true, false, true, autoRead, minimalWaitBetween, multipleMessage);
267     }
268 
269     /**
270      *
271      * @param additionalExecutor
272      *            shall the pipeline add the handler using an additional executor
273      * @param limitRead
274      *            True to set Read Limit on Server side
275      * @param limitWrite
276      *            True to set Write Limit on Client side
277      * @param globalLimit
278      *            True to change Channel to Global TrafficShapping
279      * @param minimalWaitBetween
280      *            time in ms that should be waited before getting the final result (note: for READ the values are
281      *            right shifted once, the first value being 0)
282      * @param multipleMessage
283      *            how many message to send at each step (for READ: the first should be 1, as the two last steps to
284      *            ensure correct testing)
285      * @throws Throwable if something goes wrong, and the test fails.
286      */
287     private static void testTrafficShaping0(
288             ServerBootstrap sb, Bootstrap cb, final boolean additionalExecutor,
289             final boolean limitRead, final boolean limitWrite, final boolean globalLimit, int[] autoRead,
290             long[] minimalWaitBetween, int[] multipleMessage) throws Throwable {
291 
292         currentTestRun++;
293         logger.info("TEST: " + currentTestName + " RUN: " + currentTestRun +
294                     " Exec: " + additionalExecutor + " Read: " + limitRead + " Write: " + limitWrite + " Global: "
295                     + globalLimit);
296         final ServerHandler sh = new ServerHandler(autoRead, multipleMessage);
297         Promise<Boolean> promise = group.next().newPromise();
298         final ClientHandler ch = new ClientHandler(promise, minimalWaitBetween, multipleMessage,
299                                                    autoRead);
300 
301         final AbstractTrafficShapingHandler handler;
302         if (limitRead) {
303             if (globalLimit) {
304                 handler = new GlobalTrafficShapingHandler(groupForGlobal, 0, bandwidthFactor * messageSize, check);
305             } else {
306                 handler = new ChannelTrafficShapingHandler(0, bandwidthFactor * messageSize, check);
307             }
308         } else if (limitWrite) {
309             if (globalLimit) {
310                 handler = new GlobalTrafficShapingHandler(groupForGlobal, bandwidthFactor * messageSize, 0, check);
311             } else {
312                 handler = new ChannelTrafficShapingHandler(bandwidthFactor * messageSize, 0, check);
313             }
314         } else {
315             handler = null;
316         }
317 
318         sb.childHandler(new ChannelInitializer<SocketChannel>() {
319             @Override
320             protected void initChannel(SocketChannel c) throws Exception {
321                 if (limitRead) {
322                     c.pipeline().addLast(TRAFFIC, handler);
323                 }
324                 c.pipeline().addLast(sh);
325             }
326         });
327         cb.handler(new ChannelInitializer<SocketChannel>() {
328             @Override
329             protected void initChannel(SocketChannel c) throws Exception {
330                 if (limitWrite) {
331                     c.pipeline().addLast(TRAFFIC, handler);
332                 }
333                 c.pipeline().addLast(ch);
334             }
335         });
336 
337         Channel sc = sb.bind().asStage().get();
338         Channel cc = cb.connect(sc.localAddress()).asStage().get();
339 
340         int totalNb = 0;
341         for (int i = 1; i < multipleMessage.length; i++) {
342             totalNb += multipleMessage[i];
343         }
344         Long start = TrafficCounter.milliSecondFromNano();
345         int nb = multipleMessage[0];
346         for (int i = 0; i < nb; i++) {
347             cc.write(cc.bufferAllocator().copyOf(data));
348         }
349         cc.flush();
350 
351         promise.asFuture().asStage().await();
352         Long stop = TrafficCounter.milliSecondFromNano();
353         assertTrue(promise.isSuccess(), "Error during execution of TrafficShapping: " + promise.cause());
354 
355         float average = (totalNb * messageSize) / (float) (stop - start);
356         logger.info("TEST: " + currentTestName + " RUN: " + currentTestRun +
357                     " Average of traffic: " + average + " compare to " + bandwidthFactor);
358         sh.channel.close().asStage().sync();
359         ch.channel.close().asStage().sync();
360         sc.close().asStage().sync();
361         if (autoRead != null) {
362             // for extra release call in AutoRead
363             Thread.sleep(minimalms);
364         }
365 
366         if (autoRead == null && minimalWaitBetween != null) {
367             assertTrue(average <= maxfactor,
368                 "Overall Traffic not ok since > " + maxfactor + ": " + average);
369             if (additionalExecutor) {
370                 // Oio is not as good when using additionalExecutor
371                 assertTrue(average >= 0.25, "Overall Traffic not ok since < 0.25: " + average);
372             } else {
373                 assertTrue(average >= minfactor,
374                     "Overall Traffic not ok since < " + minfactor + ": " + average);
375             }
376         }
377         if (handler != null && globalLimit) {
378             ((GlobalTrafficShapingHandler) handler).release();
379         }
380 
381         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
382             throw sh.exception.get();
383         }
384         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
385             throw ch.exception.get();
386         }
387         if (sh.exception.get() != null) {
388             throw sh.exception.get();
389         }
390         if (ch.exception.get() != null) {
391             throw ch.exception.get();
392         }
393     }
394 
395     private static class ClientHandler extends SimpleChannelInboundHandler<Buffer> {
396         volatile Channel channel;
397         final AtomicReference<Throwable> exception = new AtomicReference<>();
398         volatile int step;
399         // first message will always be validated
400         private long currentLastTime = TrafficCounter.milliSecondFromNano();
401         private final long[] minimalWaitBetween;
402         private final int[] multipleMessage;
403         private final int[] autoRead;
404         final Promise<Boolean> promise;
405 
406         ClientHandler(Promise<Boolean> promise, long[] minimalWaitBetween, int[] multipleMessage,
407                       int[] autoRead) {
408             this.minimalWaitBetween = minimalWaitBetween;
409             this.multipleMessage = Arrays.copyOf(multipleMessage, multipleMessage.length);
410             this.promise = promise;
411             this.autoRead = autoRead;
412         }
413 
414         @Override
415         public void channelActive(ChannelHandlerContext ctx) throws Exception {
416             channel = ctx.channel();
417         }
418 
419         @Override
420         public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
421             long lastTimestamp = 0;
422             loggerClient.debug("Step: " + step + " Read: " + in.readableBytes() / 8 + " blocks");
423             while (in.readableBytes() > 0) {
424                 lastTimestamp = in.readLong();
425                 multipleMessage[step]--;
426             }
427             if (multipleMessage[step] > 0) {
428                 // still some message to get
429                 return;
430             }
431             long minimalWait = minimalWaitBetween != null? minimalWaitBetween[step] : 0;
432             int ar = 0;
433             if (autoRead != null) {
434                 if (step > 0 && autoRead[step - 1] != 0) {
435                     ar = autoRead[step - 1];
436                 }
437             }
438             loggerClient.info("Step: " + step + " Interval: " + (lastTimestamp - currentLastTime) + " compareTo "
439                               + minimalWait + " (" + ar + ')');
440             assertTrue(lastTimestamp - currentLastTime >= minimalWait,
441                     "The interval of time is incorrect:" + (lastTimestamp - currentLastTime) + " not> " + minimalWait);
442             currentLastTime = lastTimestamp;
443             step++;
444             if (multipleMessage.length > step) {
445                 int nb = multipleMessage[step];
446                 for (int i = 0; i < nb; i++) {
447                     channel.write(channel.bufferAllocator().copyOf(data));
448                 }
449                 channel.flush();
450             } else {
451                 promise.setSuccess(true);
452             }
453         }
454 
455         @Override
456         public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
457             if (exception.compareAndSet(null, cause)) {
458                 cause.printStackTrace();
459                 promise.setFailure(cause);
460                 ctx.close();
461             }
462         }
463     }
464 
465     private static class ServerHandler extends SimpleChannelInboundHandler<Buffer> {
466         private final int[] autoRead;
467         private final int[] multipleMessage;
468         volatile Channel channel;
469         volatile int step;
470         final AtomicReference<Throwable> exception = new AtomicReference<>();
471 
472         ServerHandler(int[] autoRead, int[] multipleMessage) {
473             this.autoRead = autoRead;
474             this.multipleMessage = Arrays.copyOf(multipleMessage, multipleMessage.length);
475         }
476 
477         @Override
478         public void channelActive(ChannelHandlerContext ctx) throws Exception {
479             channel = ctx.channel();
480         }
481 
482         @Override
483         public void messageReceived(final ChannelHandlerContext ctx, Buffer in) throws Exception {
484             byte[] actual = new byte[in.readableBytes()];
485             int nb = actual.length / messageSize;
486             loggerServer.info("Step: " + step + " Read: " + nb + " blocks");
487             in.readBytes(actual, 0, actual.length);
488             long timestamp = TrafficCounter.milliSecondFromNano();
489             int isAutoRead = 0;
490             int laststep = step;
491             for (int i = 0; i < nb; i++) {
492                 multipleMessage[step]--;
493                 if (multipleMessage[step] == 0) {
494                     // setAutoRead test
495                     if (autoRead != null) {
496                         isAutoRead = autoRead[step];
497                     }
498                     step++;
499                 }
500             }
501             if (laststep != step) {
502                 // setAutoRead test
503                 if (autoRead != null && isAutoRead != 2) {
504                     if (isAutoRead != 0) {
505                         loggerServer.info("Step: " + step + " Set AutoRead: " + (isAutoRead > 0));
506                         channel.setOption(ChannelOption.AUTO_READ, isAutoRead > 0);
507                     } else {
508                         loggerServer.info("Step: " + step + " AutoRead: NO");
509                     }
510                 }
511             }
512             Thread.sleep(10);
513             loggerServer.debug("Step: " + step + " Write: " + nb);
514             for (int i = 0; i < nb; i++) {
515                 channel.write(ctx.bufferAllocator().allocate(8).writeLong(timestamp));
516             }
517             channel.flush();
518             if (laststep != step) {
519                 // setAutoRead test
520                 if (isAutoRead != 0) {
521                     if (isAutoRead < 0) {
522                         final int exactStep = step;
523                         long wait = isAutoRead == -1? minimalms : stepms + minimalms;
524                         if (isAutoRead == -3) {
525                             wait = stepms * 3;
526                         }
527                         executor.schedule(() -> {
528                             loggerServer.info("Step: " + exactStep + " Reset AutoRead");
529                             channel.setOption(ChannelOption.AUTO_READ, true);
530                         }, wait, TimeUnit.MILLISECONDS);
531                     } else {
532                         if (isAutoRead > 1) {
533                             loggerServer.debug("Step: " + step + " Will Set AutoRead: True");
534                             final int exactStep = step;
535                             executor.schedule(() -> {
536                                 loggerServer.info("Step: " + exactStep + " Set AutoRead: True");
537                                 channel.setOption(ChannelOption.AUTO_READ, true);
538                             }, stepms + minimalms, TimeUnit.MILLISECONDS);
539                         }
540                     }
541                 }
542             }
543         }
544 
545         @Override
546         public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
547             if (exception.compareAndSet(null, cause)) {
548                 cause.printStackTrace();
549                 ctx.close();
550             }
551         }
552     }
553 }