1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  package io.netty5.handler.traffic;
17  
import io.netty5.channel.Channel;
import io.netty5.channel.ChannelHandlerContext;
import io.netty5.channel.ChannelOption;
import io.netty5.util.Attribute;
import io.netty5.util.Resource;
import io.netty5.util.concurrent.EventExecutor;
import io.netty5.util.concurrent.EventExecutorGroup;
import io.netty5.util.concurrent.Future;
import io.netty5.util.concurrent.Promise;
import io.netty5.util.internal.logging.InternalLogger;
import io.netty5.util.internal.logging.InternalLoggerFactory;

import java.nio.channels.ClosedChannelException;
import java.util.AbstractCollection;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import static io.netty5.util.internal.ObjectUtil.checkNotNullWithIAE;
import static io.netty5.util.internal.ObjectUtil.checkPositive;
import static io.netty5.util.internal.ObjectUtil.checkPositiveOrZero;
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  public class GlobalChannelTrafficShapingHandler extends AbstractTrafficShapingHandler {
91      private static final InternalLogger logger =
92              InternalLoggerFactory.getInstance(GlobalChannelTrafficShapingHandler.class);
93      
94  
95  
96      final ConcurrentMap<Integer, PerChannel> channelQueues = new ConcurrentHashMap<>();
97  
98      
99  
100 
101     private final AtomicLong queuesSize = new AtomicLong();
102 
103     
104 
105 
106     private final AtomicLong cumulativeWrittenBytes = new AtomicLong();
107 
108     
109 
110 
111     private final AtomicLong cumulativeReadBytes = new AtomicLong();
112 
113     
114 
115 
116 
117     volatile long maxGlobalWriteSize = DEFAULT_MAX_SIZE * 100; 
118 
119     
120 
121 
122     private volatile long writeChannelLimit;
123 
124     
125 
126 
127     private volatile long readChannelLimit;
128 
129     private static final float DEFAULT_DEVIATION = 0.1F;
130     private static final float MAX_DEVIATION = 0.4F;
131     private static final float DEFAULT_SLOWDOWN = 0.4F;
132     private static final float DEFAULT_ACCELERATION = -0.1F;
133     private volatile float maxDeviation;
134     private volatile float accelerationFactor;
135     private volatile float slowDownFactor;
136     private volatile boolean readDeviationActive;
137     private volatile boolean writeDeviationActive;
138 
139     static final class PerChannel {
140         ArrayDeque<ToSend> messagesQueue;
141         TrafficCounter channelTrafficCounter;
142         long queueSize;
143         long lastWriteTimestamp;
144         long lastReadTimestamp;
145     }
146 
147     
148 
149 
150     void createGlobalTrafficCounter(EventExecutorGroup executor) {
151         
152         setMaxDeviation(DEFAULT_DEVIATION, DEFAULT_SLOWDOWN, DEFAULT_ACCELERATION);
153         checkNotNullWithIAE(executor, "executor");
154         TrafficCounter tc = new GlobalChannelTrafficCounter(this, executor, "GlobalChannelTC", checkInterval);
155         setTrafficCounter(tc);
156         tc.start();
157     }
158 
159     @Override
160     protected int userDefinedWritabilityIndex() {
161         return AbstractTrafficShapingHandler.GLOBALCHANNEL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX;
162     }
163 
164     
165 
166 
167 
168 
169 
170 
171 
172 
173 
174 
175 
176 
177 
178 
179 
180 
181 
182 
183     public GlobalChannelTrafficShapingHandler(EventExecutorGroup executor,
184             long writeGlobalLimit, long readGlobalLimit,
185             long writeChannelLimit, long readChannelLimit,
186             long checkInterval, long maxTime) {
187         super(writeGlobalLimit, readGlobalLimit, checkInterval, maxTime);
188         createGlobalTrafficCounter(executor);
189         this.writeChannelLimit = writeChannelLimit;
190         this.readChannelLimit = readChannelLimit;
191     }
192 
193     
194 
195 
196 
197 
198 
199 
200 
201 
202 
203 
204 
205 
206 
207 
208 
209 
210     public GlobalChannelTrafficShapingHandler(EventExecutorGroup executor,
211             long writeGlobalLimit, long readGlobalLimit,
212             long writeChannelLimit, long readChannelLimit,
213             long checkInterval) {
214         super(writeGlobalLimit, readGlobalLimit, checkInterval);
215         this.writeChannelLimit = writeChannelLimit;
216         this.readChannelLimit = readChannelLimit;
217         createGlobalTrafficCounter(executor);
218     }
219 
220     
221 
222 
223 
224 
225 
226 
227 
228 
229 
230 
231 
232 
233 
234     public GlobalChannelTrafficShapingHandler(EventExecutorGroup executor,
235             long writeGlobalLimit, long readGlobalLimit,
236             long writeChannelLimit, long readChannelLimit) {
237         super(writeGlobalLimit, readGlobalLimit);
238         this.writeChannelLimit = writeChannelLimit;
239         this.readChannelLimit = readChannelLimit;
240         createGlobalTrafficCounter(executor);
241     }
242 
243     
244 
245 
246 
247 
248 
249 
250 
251 
252     public GlobalChannelTrafficShapingHandler(EventExecutorGroup executor, long checkInterval) {
253         super(checkInterval);
254         createGlobalTrafficCounter(executor);
255     }
256 
257     
258 
259 
260 
261 
262 
263     public GlobalChannelTrafficShapingHandler(EventExecutorGroup executor) {
264         createGlobalTrafficCounter(executor);
265     }
266 
267     @Override
268     public boolean isSharable() {
269         return true;
270     }
271 
272     
273 
274 
275     public float maxDeviation() {
276         return maxDeviation;
277     }
278 
279     
280 
281 
282     public float accelerationFactor() {
283         return accelerationFactor;
284     }
285 
286     
287 
288 
289     public float slowDownFactor() {
290         return slowDownFactor;
291     }
292 
293     
294 
295 
296 
297 
298 
299 
300 
301 
302 
303 
304     public void setMaxDeviation(float maxDeviation, float slowDownFactor, float accelerationFactor) {
305         if (maxDeviation > MAX_DEVIATION) {
306             throw new IllegalArgumentException("maxDeviation must be <= " + MAX_DEVIATION);
307         }
308         checkPositiveOrZero(slowDownFactor, "slowDownFactor");
309         if (accelerationFactor > 0) {
310             throw new IllegalArgumentException("accelerationFactor must be <= 0");
311         }
312         this.maxDeviation = maxDeviation;
313         this.accelerationFactor = 1 + accelerationFactor;
314         this.slowDownFactor = 1 + slowDownFactor;
315     }
316 
317     private void computeDeviationCumulativeBytes() {
318         
319         long maxWrittenBytes = 0;
320         long maxReadBytes = 0;
321         long minWrittenBytes = Long.MAX_VALUE;
322         long minReadBytes = Long.MAX_VALUE;
323         for (PerChannel perChannel : channelQueues.values()) {
324             long value = perChannel.channelTrafficCounter.cumulativeWrittenBytes();
325             if (maxWrittenBytes < value) {
326                 maxWrittenBytes = value;
327             }
328             if (minWrittenBytes > value) {
329                 minWrittenBytes = value;
330             }
331             value = perChannel.channelTrafficCounter.cumulativeReadBytes();
332             if (maxReadBytes < value) {
333                 maxReadBytes = value;
334             }
335             if (minReadBytes > value) {
336                 minReadBytes = value;
337             }
338         }
339         boolean multiple = channelQueues.size() > 1;
340         readDeviationActive = multiple && minReadBytes < maxReadBytes / 2;
341         writeDeviationActive = multiple && minWrittenBytes < maxWrittenBytes / 2;
342         cumulativeWrittenBytes.set(maxWrittenBytes);
343         cumulativeReadBytes.set(maxReadBytes);
344     }
345 
346     @Override
347     protected void doAccounting(TrafficCounter counter) {
348         computeDeviationCumulativeBytes();
349         super.doAccounting(counter);
350     }
351 
352     private long computeBalancedWait(float maxLocal, float maxGlobal, long wait) {
353         if (maxGlobal == 0) {
354             
355             return wait;
356         }
357         float ratio = maxLocal / maxGlobal;
358         
359         if (ratio > maxDeviation) {
360             if (ratio < 1 - maxDeviation) {
361                 return wait;
362             } else {
363                 ratio = slowDownFactor;
364                 if (wait < MINIMAL_WAIT) {
365                     wait = MINIMAL_WAIT;
366                 }
367             }
368         } else {
369             ratio = accelerationFactor;
370         }
371         return (long) (wait * ratio);
372     }
373 
374     
375 
376 
377     public long getMaxGlobalWriteSize() {
378         return maxGlobalWriteSize;
379     }
380 
381     
382 
383 
384 
385 
386 
387 
388 
389 
390 
391     public void setMaxGlobalWriteSize(long maxGlobalWriteSize) {
392         this.maxGlobalWriteSize = checkPositive(maxGlobalWriteSize, "maxGlobalWriteSize");
393     }
394 
395     
396 
397 
398     public long queuesSize() {
399         return queuesSize.get();
400     }
401 
402     
403 
404 
405 
406     public void configureChannel(long newWriteLimit, long newReadLimit) {
407         writeChannelLimit = newWriteLimit;
408         readChannelLimit = newReadLimit;
409         long now = TrafficCounter.milliSecondFromNano();
410         for (PerChannel perChannel : channelQueues.values()) {
411             perChannel.channelTrafficCounter.resetAccounting(now);
412         }
413     }
414 
415     
416 
417 
418     public long getWriteChannelLimit() {
419         return writeChannelLimit;
420     }
421 
422     
423 
424 
425     public void setWriteChannelLimit(long writeLimit) {
426         writeChannelLimit = writeLimit;
427         long now = TrafficCounter.milliSecondFromNano();
428         for (PerChannel perChannel : channelQueues.values()) {
429             perChannel.channelTrafficCounter.resetAccounting(now);
430         }
431     }
432 
433     
434 
435 
436     public long getReadChannelLimit() {
437         return readChannelLimit;
438     }
439 
440     
441 
442 
443     public void setReadChannelLimit(long readLimit) {
444         readChannelLimit = readLimit;
445         long now = TrafficCounter.milliSecondFromNano();
446         for (PerChannel perChannel : channelQueues.values()) {
447             perChannel.channelTrafficCounter.resetAccounting(now);
448         }
449     }
450 
451     
452 
453 
454     public final void release() {
455         trafficCounter.stop();
456     }
457 
458     private PerChannel getOrSetPerChannel(ChannelHandlerContext ctx) {
459         
460         Channel channel = ctx.channel();
461         Integer key = channel.hashCode();
462         PerChannel perChannel = channelQueues.get(key);
463         if (perChannel == null) {
464             perChannel = new PerChannel();
465             perChannel.messagesQueue = new ArrayDeque<>();
466             
467             perChannel.channelTrafficCounter = new TrafficCounter(this, null, "ChannelTC" +
468                     ctx.channel().hashCode(), checkInterval);
469             perChannel.queueSize = 0L;
470             perChannel.lastReadTimestamp = TrafficCounter.milliSecondFromNano();
471             perChannel.lastWriteTimestamp = perChannel.lastReadTimestamp;
472             channelQueues.put(key, perChannel);
473         }
474         return perChannel;
475     }
476 
477     @Override
478     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
479         getOrSetPerChannel(ctx);
480         trafficCounter.resetCumulativeTime();
481         super.handlerAdded(ctx);
482     }
483 
484     @Override
485     public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
486         trafficCounter.resetCumulativeTime();
487         Channel channel = ctx.channel();
488         Integer key = channel.hashCode();
489         PerChannel perChannel = channelQueues.remove(key);
490         if (perChannel != null) {
491             
492             synchronized (perChannel) {
493                 if (channel.isActive()) {
494                     for (ToSend toSend : perChannel.messagesQueue) {
495                         long size = calculateSize(toSend.toSend);
496                         trafficCounter.bytesRealWriteFlowControl(size);
497                         perChannel.channelTrafficCounter.bytesRealWriteFlowControl(size);
498                         perChannel.queueSize -= size;
499                         queuesSize.addAndGet(-size);
500                         ctx.write(toSend.toSend).cascadeTo(toSend.promise);
501                     }
502                 } else {
503                     queuesSize.addAndGet(-perChannel.queueSize);
504                     for (ToSend toSend : perChannel.messagesQueue) {
505                         if (Resource.isAccessible(toSend.toSend, false)) {
506                             Resource.dispose(toSend.toSend);
507                         }
508                     }
509                 }
510                 perChannel.messagesQueue.clear();
511             }
512         }
513         releaseWriteSuspended(ctx);
514         releaseReadSuspended(ctx);
515         super.handlerRemoved(ctx);
516     }
517 
518     @Override
519     public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
520         long size = calculateSize(msg);
521         long now = TrafficCounter.milliSecondFromNano();
522         if (size > 0) {
523             
524             long waitGlobal = trafficCounter.readTimeToWait(size, getReadLimit(), maxTime, now);
525             Integer key = ctx.channel().hashCode();
526             PerChannel perChannel = channelQueues.get(key);
527             long wait = 0;
528             if (perChannel != null) {
529                 wait = perChannel.channelTrafficCounter.readTimeToWait(size, readChannelLimit, maxTime, now);
530                 if (readDeviationActive) {
531                     
532                     long maxLocalRead;
533                     maxLocalRead = perChannel.channelTrafficCounter.cumulativeReadBytes();
534                     long maxGlobalRead = cumulativeReadBytes.get();
535                     if (maxLocalRead <= 0) {
536                         maxLocalRead = 0;
537                     }
538                     if (maxGlobalRead < maxLocalRead) {
539                         maxGlobalRead = maxLocalRead;
540                     }
541                     wait = computeBalancedWait(maxLocalRead, maxGlobalRead, wait);
542                 }
543             }
544             if (wait < waitGlobal) {
545                 wait = waitGlobal;
546             }
547             wait = checkWaitReadTime(ctx, wait, now);
548             if (wait >= MINIMAL_WAIT) { 
549                 
550                 
551                 Channel channel = ctx.channel();
552                 if (logger.isDebugEnabled()) {
553                     logger.debug("Read Suspend: " + wait + ':' + channel.getOption(ChannelOption.AUTO_READ) + ':'
554                             + isHandlerActive(ctx));
555                 }
556                 if (channel.getOption(ChannelOption.AUTO_READ) && isHandlerActive(ctx)) {
557                     channel.setOption(ChannelOption.AUTO_READ, false);
558                     channel.attr(READ_SUSPENDED).set(true);
559                     
560                     
561                     Attribute<Runnable> attr = channel.attr(REOPEN_TASK);
562                     Runnable reopenTask = attr.get();
563                     if (reopenTask == null) {
564                         reopenTask = new ReopenReadTimerTask(ctx);
565                         attr.set(reopenTask);
566                     }
567                     ctx.executor().schedule(reopenTask, wait, TimeUnit.MILLISECONDS);
568                     if (logger.isDebugEnabled()) {
569                         logger.debug("Suspend final status => " + channel.getOption(ChannelOption.AUTO_READ) + ':'
570                                 + isHandlerActive(ctx) + " will reopened at: " + wait);
571                     }
572                 }
573             }
574         }
575         informReadOperation(ctx, now);
576         ctx.fireChannelRead(msg);
577     }
578 
579     @Override
580     protected long checkWaitReadTime(final ChannelHandlerContext ctx, long wait, final long now) {
581         Integer key = ctx.channel().hashCode();
582         PerChannel perChannel = channelQueues.get(key);
583         if (perChannel != null) {
584             if (wait > maxTime && now + wait - perChannel.lastReadTimestamp > maxTime) {
585                 wait = maxTime;
586             }
587         }
588         return wait;
589     }
590 
591     @Override
592     protected void informReadOperation(final ChannelHandlerContext ctx, final long now) {
593         Integer key = ctx.channel().hashCode();
594         PerChannel perChannel = channelQueues.get(key);
595         if (perChannel != null) {
596             perChannel.lastReadTimestamp = now;
597         }
598     }
599 
600     private static final class ToSend {
601         final long relativeTimeAction;
602         final Object toSend;
603         final Promise<Void> promise;
604         final long size;
605 
606         private ToSend(final long delay, final Object toSend, final long size, final Promise<Void> promise) {
607             relativeTimeAction = delay;
608             this.toSend = toSend;
609             this.size = size;
610             this.promise = promise;
611         }
612     }
613 
614     protected long maximumCumulativeWrittenBytes() {
615         return cumulativeWrittenBytes.get();
616     }
617 
618     protected long maximumCumulativeReadBytes() {
619         return cumulativeReadBytes.get();
620     }
621 
622     
623 
624 
625 
626     public Collection<TrafficCounter> channelTrafficCounters() {
627         return new AbstractCollection<>() {
628             @Override
629             public Iterator<TrafficCounter> iterator() {
630                 return new Iterator<>() {
631                     final Iterator<PerChannel> iter = channelQueues.values().iterator();
632 
633                     @Override
634                     public boolean hasNext() {
635                         return iter.hasNext();
636                     }
637 
638                     @Override
639                     public TrafficCounter next() {
640                         return iter.next().channelTrafficCounter;
641                     }
642 
643                     @Override
644                     public void remove() {
645                         throw new UnsupportedOperationException();
646                     }
647                 };
648             }
649 
650             @Override
651             public int size() {
652                 return channelQueues.size();
653             }
654         };
655     }
656 
657     @Override
658     public Future<Void> write(final ChannelHandlerContext ctx, final Object msg) {
659         long size = calculateSize(msg);
660         long now = TrafficCounter.milliSecondFromNano();
661         if (size > 0) {
662             
663             long waitGlobal = trafficCounter.writeTimeToWait(size, getWriteLimit(), maxTime, now);
664             Integer key = ctx.channel().hashCode();
665             PerChannel perChannel = channelQueues.get(key);
666             long wait = 0;
667             if (perChannel != null) {
668                 wait = perChannel.channelTrafficCounter.writeTimeToWait(size, writeChannelLimit, maxTime, now);
669                 if (writeDeviationActive) {
670                     
671                     long maxLocalWrite;
672                     maxLocalWrite = perChannel.channelTrafficCounter.cumulativeWrittenBytes();
673                     long maxGlobalWrite = cumulativeWrittenBytes.get();
674                     if (maxLocalWrite <= 0) {
675                         maxLocalWrite = 0;
676                     }
677                     if (maxGlobalWrite < maxLocalWrite) {
678                         maxGlobalWrite = maxLocalWrite;
679                     }
680                     wait = computeBalancedWait(maxLocalWrite, maxGlobalWrite, wait);
681                 }
682             }
683             if (wait < waitGlobal) {
684                 wait = waitGlobal;
685             }
686             if (wait >= MINIMAL_WAIT) {
687                 if (logger.isDebugEnabled()) {
688                     logger.debug("Write suspend: " + wait + ':' + ctx.channel().getOption(ChannelOption.AUTO_READ) + ':'
689                             + isHandlerActive(ctx));
690                 }
691                 Promise<Void> promise = ctx.newPromise();
692                 submitWrite(ctx, msg, size, wait, now, promise);
693                 return promise.asFuture();
694             }
695         }
696         Promise<Void> promise = ctx.newPromise();
697         
698         submitWrite(ctx, msg, size, 0, now, promise);
699         return promise.asFuture();
700     }
701 
702     @Override
703     protected void submitWrite(final ChannelHandlerContext ctx, final Object msg,
704             final long size, final long writedelay, final long now,
705             final Promise<Void> promise) {
706         Channel channel = ctx.channel();
707         Integer key = channel.hashCode();
708         PerChannel perChannel = channelQueues.get(key);
709         if (perChannel == null) {
710             
711             
712             perChannel = getOrSetPerChannel(ctx);
713         }
714         final ToSend newToSend;
715         long delay = writedelay;
716         boolean globalSizeExceeded = false;
717         
718         synchronized (perChannel) {
719             if (writedelay == 0 && perChannel.messagesQueue.isEmpty()) {
720                 trafficCounter.bytesRealWriteFlowControl(size);
721                 perChannel.channelTrafficCounter.bytesRealWriteFlowControl(size);
722                 ctx.write(msg).cascadeTo(promise);
723                 perChannel.lastWriteTimestamp = now;
724                 return;
725             }
726             if (delay > maxTime && now + delay - perChannel.lastWriteTimestamp > maxTime) {
727                 delay = maxTime;
728             }
729             newToSend = new ToSend(delay + now, msg, size, promise);
730             perChannel.messagesQueue.addLast(newToSend);
731             perChannel.queueSize += size;
732             queuesSize.addAndGet(size);
733             checkWriteSuspend(ctx, delay, perChannel.queueSize);
734             if (queuesSize.get() > maxGlobalWriteSize) {
735                 globalSizeExceeded = true;
736             }
737         }
738         if (globalSizeExceeded) {
739             setUserDefinedWritability(ctx, false);
740         }
741         final long futureNow = newToSend.relativeTimeAction;
742         final PerChannel forSchedule = perChannel;
743         ctx.executor().schedule(() -> sendAllValid(ctx, forSchedule, futureNow), delay, TimeUnit.MILLISECONDS);
744     }
745 
746     private void sendAllValid(final ChannelHandlerContext ctx, final PerChannel perChannel, final long now) {
747         
748         synchronized (perChannel) {
749             ToSend newToSend = perChannel.messagesQueue.pollFirst();
750             for (; newToSend != null; newToSend = perChannel.messagesQueue.pollFirst()) {
751                 if (newToSend.relativeTimeAction <= now) {
752                     long size = newToSend.size;
753                     trafficCounter.bytesRealWriteFlowControl(size);
754                     perChannel.channelTrafficCounter.bytesRealWriteFlowControl(size);
755                     perChannel.queueSize -= size;
756                     queuesSize.addAndGet(-size);
757                     ctx.write(newToSend.toSend).cascadeTo(newToSend.promise);
758                     perChannel.lastWriteTimestamp = now;
759                 } else {
760                     perChannel.messagesQueue.addFirst(newToSend);
761                     break;
762                 }
763             }
764             if (perChannel.messagesQueue.isEmpty()) {
765                 releaseWriteSuspended(ctx);
766             }
767         }
768         ctx.flush();
769     }
770 
771     @Override
772     public String toString() {
773         return new StringBuilder(340).append(super.toString())
774             .append(" Write Channel Limit: ").append(writeChannelLimit)
775             .append(" Read Channel Limit: ").append(readChannelLimit).toString();
776     }
777 }