1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  package io.netty.handler.codec.http2;
16  
17  import io.netty.util.collection.IntCollections;
18  import io.netty.util.collection.IntObjectHashMap;
19  import io.netty.util.collection.IntObjectMap;
20  import io.netty.util.internal.DefaultPriorityQueue;
21  import io.netty.util.internal.EmptyPriorityQueue;
22  import io.netty.util.internal.MathUtil;
23  import io.netty.util.internal.PriorityQueue;
24  import io.netty.util.internal.PriorityQueueNode;
25  import io.netty.util.internal.SystemPropertyUtil;
26  
27  import java.io.Serializable;
28  import java.util.ArrayList;
29  import java.util.Comparator;
30  import java.util.Iterator;
31  import java.util.List;
32  
33  import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID;
34  import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MIN_ALLOCATION_CHUNK;
35  import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
36  import static io.netty.handler.codec.http2.Http2CodecUtil.streamableBytes;
37  import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
38  import static io.netty.handler.codec.http2.Http2Exception.connectionError;
39  import static io.netty.util.internal.ObjectUtil.checkPositive;
40  import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
41  import static java.lang.Integer.MAX_VALUE;
42  import static java.lang.Math.max;
43  import static java.lang.Math.min;
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  public final class WeightedFairQueueByteDistributor implements StreamByteDistributor {
59      
60  
61  
62  
63  
64  
65  
66  
67      static final int INITIAL_CHILDREN_MAP_SIZE =
68              max(1, SystemPropertyUtil.getInt("io.netty.http2.childrenMapSize", 2));
69      
70  
71  
72      private static final int DEFAULT_MAX_STATE_ONLY_SIZE = 5;
73  
74      private final Http2Connection.PropertyKey stateKey;
75      
76  
77  
78  
79      private final IntObjectMap<State> stateOnlyMap;
80      
81  
82  
83  
84      private final PriorityQueue<State> stateOnlyRemovalQueue;
85      private final Http2Connection connection;
86      private final State connectionState;
87      
88  
89  
90  
91      private int allocationQuantum = DEFAULT_MIN_ALLOCATION_CHUNK;
92      private final int maxStateOnlySize;
93  
94      public WeightedFairQueueByteDistributor(Http2Connection connection) {
95          this(connection, DEFAULT_MAX_STATE_ONLY_SIZE);
96      }
97  
98      public WeightedFairQueueByteDistributor(Http2Connection connection, int maxStateOnlySize) {
99          checkPositiveOrZero(maxStateOnlySize, "maxStateOnlySize");
100         if (maxStateOnlySize == 0) {
101             stateOnlyMap = IntCollections.emptyMap();
102             stateOnlyRemovalQueue = EmptyPriorityQueue.instance();
103         } else {
104             stateOnlyMap = new IntObjectHashMap<State>(maxStateOnlySize);
105             
106             
107             stateOnlyRemovalQueue = new DefaultPriorityQueue<State>(StateOnlyComparator.INSTANCE, maxStateOnlySize + 2);
108         }
109         this.maxStateOnlySize = maxStateOnlySize;
110 
111         this.connection = connection;
112         stateKey = connection.newKey();
113         final Http2Stream connectionStream = connection.connectionStream();
114         connectionStream.setProperty(stateKey, connectionState = new State(connectionStream, 16));
115 
116         
117         connection.addListener(new Http2ConnectionAdapter() {
118             @Override
119             public void onStreamAdded(Http2Stream stream) {
120                 State state = stateOnlyMap.remove(stream.id());
121                 if (state == null) {
122                     state = new State(stream);
123                     
124                     List<ParentChangedEvent> events = new ArrayList<ParentChangedEvent>(1);
125                     connectionState.takeChild(state, false, events);
126                     notifyParentChanged(events);
127                 } else {
128                     stateOnlyRemovalQueue.removeTyped(state);
129                     state.stream = stream;
130                 }
131                 switch (stream.state()) {
132                     case RESERVED_REMOTE:
133                     case RESERVED_LOCAL:
134                         state.setStreamReservedOrActivated();
135                         
136                         
137                         break;
138                     default:
139                         break;
140                 }
141                 stream.setProperty(stateKey, state);
142             }
143 
144             @Override
145             public void onStreamActive(Http2Stream stream) {
146                 state(stream).setStreamReservedOrActivated();
147                 
148                 
149             }
150 
151             @Override
152             public void onStreamClosed(Http2Stream stream) {
153                 state(stream).close();
154             }
155 
156             @Override
157             public void onStreamRemoved(Http2Stream stream) {
158                 
159                 
160                 
161                 State state = state(stream);
162 
163                 
164                 
165                 
166                 state.stream = null;
167 
168                 if (WeightedFairQueueByteDistributor.this.maxStateOnlySize == 0) {
169                     state.parent.removeChild(state);
170                     return;
171                 }
172                 if (stateOnlyRemovalQueue.size() == WeightedFairQueueByteDistributor.this.maxStateOnlySize) {
173                     State stateToRemove = stateOnlyRemovalQueue.peek();
174                     if (StateOnlyComparator.INSTANCE.compare(stateToRemove, state) >= 0) {
175                         
176                         
177                         state.parent.removeChild(state);
178                         return;
179                     }
180                     stateOnlyRemovalQueue.poll();
181                     stateToRemove.parent.removeChild(stateToRemove);
182                     stateOnlyMap.remove(stateToRemove.streamId);
183                 }
184                 stateOnlyRemovalQueue.add(state);
185                 stateOnlyMap.put(state.streamId, state);
186             }
187         });
188     }
189 
190     @Override
191     public void updateStreamableBytes(StreamState state) {
192         state(state.stream()).updateStreamableBytes(streamableBytes(state),
193                                                     state.hasFrame() && state.windowSize() >= 0);
194     }
195 
196     @Override
197     public void updateDependencyTree(int childStreamId, int parentStreamId, short weight, boolean exclusive) {
198         State state = state(childStreamId);
199         if (state == null) {
200             
201             
202             
203             if (maxStateOnlySize == 0) {
204                 return;
205             }
206             state = new State(childStreamId);
207             stateOnlyRemovalQueue.add(state);
208             stateOnlyMap.put(childStreamId, state);
209         }
210 
211         State newParent = state(parentStreamId);
212         if (newParent == null) {
213             
214             
215             
216             if (maxStateOnlySize == 0) {
217                 return;
218             }
219             newParent = new State(parentStreamId);
220             stateOnlyRemovalQueue.add(newParent);
221             stateOnlyMap.put(parentStreamId, newParent);
222             
223             List<ParentChangedEvent> events = new ArrayList<ParentChangedEvent>(1);
224             connectionState.takeChild(newParent, false, events);
225             notifyParentChanged(events);
226         }
227 
228         
229         
230         if (state.activeCountForTree != 0 && state.parent != null) {
231             state.parent.totalQueuedWeights += weight - state.weight;
232         }
233         state.weight = weight;
234 
235         if (newParent != state.parent || exclusive && newParent.children.size() != 1) {
236             final List<ParentChangedEvent> events;
237             if (newParent.isDescendantOf(state)) {
238                 events = new ArrayList<ParentChangedEvent>(2 + (exclusive ? newParent.children.size() : 0));
239                 state.parent.takeChild(newParent, false, events);
240             } else {
241                 events = new ArrayList<ParentChangedEvent>(1 + (exclusive ? newParent.children.size() : 0));
242             }
243             newParent.takeChild(state, exclusive, events);
244             notifyParentChanged(events);
245         }
246 
247         
248         
249         
250         while (stateOnlyRemovalQueue.size() > maxStateOnlySize) {
251             State stateToRemove = stateOnlyRemovalQueue.poll();
252             stateToRemove.parent.removeChild(stateToRemove);
253             stateOnlyMap.remove(stateToRemove.streamId);
254         }
255     }
256 
257     @Override
258     public boolean distribute(int maxBytes, Writer writer) throws Http2Exception {
259         
260         if (connectionState.activeCountForTree == 0) {
261             return false;
262         }
263 
264         
265         
266         
267         int oldIsActiveCountForTree;
268         do {
269             oldIsActiveCountForTree = connectionState.activeCountForTree;
270             
271             maxBytes -= distributeToChildren(maxBytes, writer, connectionState);
272         } while (connectionState.activeCountForTree != 0 &&
273                 (maxBytes > 0 || oldIsActiveCountForTree != connectionState.activeCountForTree));
274 
275         return connectionState.activeCountForTree != 0;
276     }
277 
278     
279 
280 
281 
282     public void allocationQuantum(int allocationQuantum) {
283         checkPositive(allocationQuantum, "allocationQuantum");
284         this.allocationQuantum = allocationQuantum;
285     }
286 
287     private int distribute(int maxBytes, Writer writer, State state) throws Http2Exception {
288         if (state.isActive()) {
289             int nsent = min(maxBytes, state.streamableBytes);
290             state.write(nsent, writer);
291             if (nsent == 0 && maxBytes != 0) {
292                 
293                 
294                 
295                 
296                 state.updateStreamableBytes(state.streamableBytes, false);
297             }
298             return nsent;
299         }
300 
301         return distributeToChildren(maxBytes, writer, state);
302     }
303 
304     
305 
306 
307 
308 
309 
310 
311 
312 
313 
314     private int distributeToChildren(int maxBytes, Writer writer, State state) throws Http2Exception {
315         long oldTotalQueuedWeights = state.totalQueuedWeights;
316         State childState = state.pollPseudoTimeQueue();
317         State nextChildState = state.peekPseudoTimeQueue();
318         childState.setDistributing();
319         try {
320             assert nextChildState == null || nextChildState.pseudoTimeToWrite >= childState.pseudoTimeToWrite :
321                 "nextChildState[" + nextChildState.streamId + "].pseudoTime(" + nextChildState.pseudoTimeToWrite +
322                 ") < " + " childState[" + childState.streamId + "].pseudoTime(" + childState.pseudoTimeToWrite + ')';
323             int nsent = distribute(nextChildState == null ? maxBytes :
324                             min(maxBytes, (int) min((nextChildState.pseudoTimeToWrite - childState.pseudoTimeToWrite) *
325                                                childState.weight / oldTotalQueuedWeights + allocationQuantum, MAX_VALUE)
326                                ),
327                                writer,
328                                childState);
329             state.pseudoTime += nsent;
330             childState.updatePseudoTime(state, nsent, oldTotalQueuedWeights);
331             return nsent;
332         } finally {
333             childState.unsetDistributing();
334             
335             
336             
337             if (childState.activeCountForTree != 0) {
338                 state.offerPseudoTimeQueue(childState);
339             }
340         }
341     }
342 
343     private State state(Http2Stream stream) {
344         return stream.getProperty(stateKey);
345     }
346 
347     private State state(int streamId) {
348         Http2Stream stream = connection.stream(streamId);
349         return stream != null ? state(stream) : stateOnlyMap.get(streamId);
350     }
351 
352     
353 
354 
355     boolean isChild(int childId, int parentId, short weight) {
356         State parent = state(parentId);
357         State child;
358         return parent.children.containsKey(childId) &&
359                 (child = state(childId)).parent == parent && child.weight == weight;
360     }
361 
362     
363 
364 
365     int numChildren(int streamId) {
366         State state = state(streamId);
367         return state == null ? 0 : state.children.size();
368     }
369 
370     
371 
372 
373 
374     void notifyParentChanged(List<ParentChangedEvent> events) {
375         for (int i = 0; i < events.size(); ++i) {
376             ParentChangedEvent event = events.get(i);
377             stateOnlyRemovalQueue.priorityChanged(event.state);
378             if (event.state.parent != null && event.state.activeCountForTree != 0) {
379                 event.state.parent.offerAndInitializePseudoTime(event.state);
380                 event.state.parent.activeCountChangeForTree(event.state.activeCountForTree);
381             }
382         }
383     }
384 
385     
386 
387 
388 
389 
390 
391 
392 
393     private static final class StateOnlyComparator implements Comparator<State>, Serializable {
394         private static final long serialVersionUID = -4806936913002105966L;
395 
396         static final StateOnlyComparator INSTANCE = new StateOnlyComparator();
397 
398         @Override
399         public int compare(State o1, State o2) {
400             
401             boolean o1Actived = o1.wasStreamReservedOrActivated();
402             if (o1Actived != o2.wasStreamReservedOrActivated()) {
403                 return o1Actived ? -1 : 1;
404             }
405             
406             int x = o2.dependencyTreeDepth - o1.dependencyTreeDepth;
407 
408             
409             
410             
411             
412             
413 
414             
415             return x != 0 ? x : o1.streamId - o2.streamId;
416         }
417     }
418 
419     private static final class StatePseudoTimeComparator implements Comparator<State>, Serializable {
420         private static final long serialVersionUID = -1437548640227161828L;
421 
422         static final StatePseudoTimeComparator INSTANCE = new StatePseudoTimeComparator();
423 
424         @Override
425         public int compare(State o1, State o2) {
426             return MathUtil.compare(o1.pseudoTimeToWrite, o2.pseudoTimeToWrite);
427         }
428     }
429 
430     
431 
432 
433     private final class State implements PriorityQueueNode {
434         private static final byte STATE_IS_ACTIVE = 0x1;
435         private static final byte STATE_IS_DISTRIBUTING = 0x2;
436         private static final byte STATE_STREAM_ACTIVATED = 0x4;
437 
438         
439 
440 
441         Http2Stream stream;
442         State parent;
443         IntObjectMap<State> children = IntCollections.emptyMap();
444         private final PriorityQueue<State> pseudoTimeQueue;
445         final int streamId;
446         int streamableBytes;
447         int dependencyTreeDepth;
448         
449 
450 
451         int activeCountForTree;
452         private int pseudoTimeQueueIndex = INDEX_NOT_IN_QUEUE;
453         private int stateOnlyQueueIndex = INDEX_NOT_IN_QUEUE;
454         
455 
456 
457         long pseudoTimeToWrite;
458         
459 
460 
461         long pseudoTime;
462         long totalQueuedWeights;
463         private byte flags;
464         short weight = DEFAULT_PRIORITY_WEIGHT;
465 
466         State(int streamId) {
467             this(streamId, null, 0);
468         }
469 
470         State(Http2Stream stream) {
471             this(stream, 0);
472         }
473 
474         State(Http2Stream stream, int initialSize) {
475             this(stream.id(), stream, initialSize);
476         }
477 
478         State(int streamId, Http2Stream stream, int initialSize) {
479             this.stream = stream;
480             this.streamId = streamId;
481             pseudoTimeQueue = new DefaultPriorityQueue<State>(StatePseudoTimeComparator.INSTANCE, initialSize);
482         }
483 
484         boolean isDescendantOf(State state) {
485             State next = parent;
486             while (next != null) {
487                 if (next == state) {
488                     return true;
489                 }
490                 next = next.parent;
491             }
492             return false;
493         }
494 
495         void takeChild(State child, boolean exclusive, List<ParentChangedEvent> events) {
496             takeChild(null, child, exclusive, events);
497         }
498 
499         
500 
501 
502 
503         void takeChild(Iterator<IntObjectMap.PrimitiveEntry<State>> childItr, State child, boolean exclusive,
504                        List<ParentChangedEvent> events) {
505             State oldParent = child.parent;
506 
507             if (oldParent != this) {
508                 events.add(new ParentChangedEvent(child, oldParent));
509                 child.setParent(this);
510                 
511                 
512                 
513                 if (childItr != null) {
514                     childItr.remove();
515                 } else if (oldParent != null) {
516                     oldParent.children.remove(child.streamId);
517                 }
518 
519                 
520                 initChildrenIfEmpty();
521 
522                 final State oldChild = children.put(child.streamId, child);
523                 assert oldChild == null : "A stream with the same stream ID was already in the child map.";
524             }
525 
526             if (exclusive && !children.isEmpty()) {
527                 
528                 
529                 Iterator<IntObjectMap.PrimitiveEntry<State>> itr = removeAllChildrenExcept(child).entries().iterator();
530                 while (itr.hasNext()) {
531                     child.takeChild(itr, itr.next().value(), false, events);
532                 }
533             }
534         }
535 
536         
537 
538 
539         void removeChild(State child) {
540             if (children.remove(child.streamId) != null) {
541                 List<ParentChangedEvent> events = new ArrayList<ParentChangedEvent>(1 + child.children.size());
542                 events.add(new ParentChangedEvent(child, child.parent));
543                 child.setParent(null);
544 
545                 if (!child.children.isEmpty()) {
546                     
547                     Iterator<IntObjectMap.PrimitiveEntry<State>> itr = child.children.entries().iterator();
548                     long totalWeight = child.getTotalWeight();
549                     do {
550                         
551                         State dependency = itr.next().value();
552                         dependency.weight = (short) max(1, dependency.weight * child.weight / totalWeight);
553                         takeChild(itr, dependency, false, events);
554                     } while (itr.hasNext());
555                 }
556 
557                 notifyParentChanged(events);
558             }
559         }
560 
561         private long getTotalWeight() {
562             long totalWeight = 0L;
563             for (State state : children.values()) {
564                 totalWeight += state.weight;
565             }
566             return totalWeight;
567         }
568 
569         
570 
571 
572 
573 
574         private IntObjectMap<State> removeAllChildrenExcept(State stateToRetain) {
575             stateToRetain = children.remove(stateToRetain.streamId);
576             IntObjectMap<State> prevChildren = children;
577             
578             
579             initChildren();
580             if (stateToRetain != null) {
581                 children.put(stateToRetain.streamId, stateToRetain);
582             }
583             return prevChildren;
584         }
585 
586         private void setParent(State newParent) {
587             
588             if (activeCountForTree != 0 && parent != null) {
589                 parent.removePseudoTimeQueue(this);
590                 parent.activeCountChangeForTree(-activeCountForTree);
591             }
592             parent = newParent;
593             
594             dependencyTreeDepth = newParent == null ? MAX_VALUE : newParent.dependencyTreeDepth + 1;
595         }
596 
597         private void initChildrenIfEmpty() {
598             if (children == IntCollections.<State>emptyMap()) {
599                 initChildren();
600             }
601         }
602 
603         private void initChildren() {
604             children = new IntObjectHashMap<State>(INITIAL_CHILDREN_MAP_SIZE);
605         }
606 
607         void write(int numBytes, Writer writer) throws Http2Exception {
608             assert stream != null;
609             try {
610                 writer.write(stream, numBytes);
611             } catch (Throwable t) {
612                 throw connectionError(INTERNAL_ERROR, t, "byte distribution write error");
613             }
614         }
615 
616         void activeCountChangeForTree(int increment) {
617             assert activeCountForTree + increment >= 0;
618             activeCountForTree += increment;
619             if (parent != null) {
620                 assert activeCountForTree != increment ||
621                        pseudoTimeQueueIndex == INDEX_NOT_IN_QUEUE ||
622                        parent.pseudoTimeQueue.containsTyped(this) :
623                      "State[" + streamId + "].activeCountForTree changed from 0 to " + increment + " is in a " +
624                      "pseudoTimeQueue, but not in parent[ " + parent.streamId + "]'s pseudoTimeQueue";
625                 if (activeCountForTree == 0) {
626                     parent.removePseudoTimeQueue(this);
627                 } else if (activeCountForTree == increment && !isDistributing()) {
628                     
629                     
630                     
631                     
632                     
633                     
634                     
635                     
636                     parent.offerAndInitializePseudoTime(this);
637                 }
638                 parent.activeCountChangeForTree(increment);
639             }
640         }
641 
642         void updateStreamableBytes(int newStreamableBytes, boolean isActive) {
643             if (isActive() != isActive) {
644                 if (isActive) {
645                     activeCountChangeForTree(1);
646                     setActive();
647                 } else {
648                     activeCountChangeForTree(-1);
649                     unsetActive();
650                 }
651             }
652 
653             streamableBytes = newStreamableBytes;
654         }
655 
656         
657 
658 
659         void updatePseudoTime(State parentState, int nsent, long totalQueuedWeights) {
660             assert streamId != CONNECTION_STREAM_ID && nsent >= 0;
661             
662             
663             pseudoTimeToWrite = min(pseudoTimeToWrite, parentState.pseudoTime) + nsent * totalQueuedWeights / weight;
664         }
665 
666         
667 
668 
669 
670 
671         void offerAndInitializePseudoTime(State state) {
672             state.pseudoTimeToWrite = pseudoTime;
673             offerPseudoTimeQueue(state);
674         }
675 
676         void offerPseudoTimeQueue(State state) {
677             pseudoTimeQueue.offer(state);
678             totalQueuedWeights += state.weight;
679         }
680 
681         
682 
683 
684         State pollPseudoTimeQueue() {
685             State state = pseudoTimeQueue.poll();
686             
687             totalQueuedWeights -= state.weight;
688             return state;
689         }
690 
691         void removePseudoTimeQueue(State state) {
692             if (pseudoTimeQueue.removeTyped(state)) {
693                 totalQueuedWeights -= state.weight;
694             }
695         }
696 
697         State peekPseudoTimeQueue() {
698             return pseudoTimeQueue.peek();
699         }
700 
701         void close() {
702             updateStreamableBytes(0, false);
703             stream = null;
704         }
705 
706         boolean wasStreamReservedOrActivated() {
707             return (flags & STATE_STREAM_ACTIVATED) != 0;
708         }
709 
710         void setStreamReservedOrActivated() {
711             flags |= STATE_STREAM_ACTIVATED;
712         }
713 
714         boolean isActive() {
715             return (flags & STATE_IS_ACTIVE) != 0;
716         }
717 
718         private void setActive() {
719             flags |= STATE_IS_ACTIVE;
720         }
721 
722         private void unsetActive() {
723             flags &= ~STATE_IS_ACTIVE;
724         }
725 
726         boolean isDistributing() {
727             return (flags & STATE_IS_DISTRIBUTING) != 0;
728         }
729 
730         void setDistributing() {
731             flags |= STATE_IS_DISTRIBUTING;
732         }
733 
734         void unsetDistributing() {
735             flags &= ~STATE_IS_DISTRIBUTING;
736         }
737 
738         @Override
739         public int priorityQueueIndex(DefaultPriorityQueue<?> queue) {
740             return queue == stateOnlyRemovalQueue ? stateOnlyQueueIndex : pseudoTimeQueueIndex;
741         }
742 
743         @Override
744         public void priorityQueueIndex(DefaultPriorityQueue<?> queue, int i) {
745             if (queue == stateOnlyRemovalQueue) {
746                 stateOnlyQueueIndex = i;
747             } else {
748                 pseudoTimeQueueIndex = i;
749             }
750         }
751 
752         @Override
753         public String toString() {
754             
755             StringBuilder sb = new StringBuilder(256 * (activeCountForTree > 0 ? activeCountForTree : 1));
756             toString(sb);
757             return sb.toString();
758         }
759 
760         private void toString(StringBuilder sb) {
761             sb.append("{streamId ").append(streamId)
762                     .append(" streamableBytes ").append(streamableBytes)
763                     .append(" activeCountForTree ").append(activeCountForTree)
764                     .append(" pseudoTimeQueueIndex ").append(pseudoTimeQueueIndex)
765                     .append(" pseudoTimeToWrite ").append(pseudoTimeToWrite)
766                     .append(" pseudoTime ").append(pseudoTime)
767                     .append(" flags ").append(flags)
768                     .append(" pseudoTimeQueue.size() ").append(pseudoTimeQueue.size())
769                     .append(" stateOnlyQueueIndex ").append(stateOnlyQueueIndex)
770                     .append(" parent.streamId ").append(parent == null ? -1 : parent.streamId).append("} [");
771 
772             if (!pseudoTimeQueue.isEmpty()) {
773                 for (State s : pseudoTimeQueue) {
774                     s.toString(sb);
775                     sb.append(", ");
776                 }
777                 
778                 sb.setLength(sb.length() - 2);
779             }
780             sb.append(']');
781         }
782     }
783 
784     
785 
786 
787     private static final class ParentChangedEvent {
788         final State state;
789         final State oldParent;
790 
791         
792 
793 
794 
795 
796         ParentChangedEvent(State state, State oldParent) {
797             this.state = state;
798             this.oldParent = oldParent;
799         }
800     }
801 }