1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http2;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.util.ReferenceCountUtil;
24
25 import java.util.ArrayDeque;
26 import java.util.Iterator;
27 import java.util.Map;
28 import java.util.Queue;
29 import java.util.TreeMap;
30
31 import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS;
32 import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
33 import static io.netty.handler.codec.http2.Http2Exception.connectionError;
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57 public class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder {
58
59
60
61
62 public static final class Http2ChannelClosedException extends Http2Exception {
63 private static final long serialVersionUID = 4768543442094476971L;
64
65 public Http2ChannelClosedException() {
66 super(Http2Error.REFUSED_STREAM, "Connection closed");
67 }
68 }
69
70 private static final class GoAwayDetail {
71 private final int lastStreamId;
72 private final long errorCode;
73 private final byte[] debugData;
74
75 GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) {
76 this.lastStreamId = lastStreamId;
77 this.errorCode = errorCode;
78 this.debugData = debugData.clone();
79 }
80 }
81
82
83
84
85
86 public static final class Http2GoAwayException extends Http2Exception {
87 private static final long serialVersionUID = 1326785622777291198L;
88 private final GoAwayDetail goAwayDetail;
89
90 public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) {
91 this(new GoAwayDetail(lastStreamId, errorCode, debugData));
92 }
93
94 Http2GoAwayException(GoAwayDetail goAwayDetail) {
95 super(Http2Error.STREAM_CLOSED);
96 this.goAwayDetail = goAwayDetail;
97 }
98
99 public int lastStreamId() {
100 return goAwayDetail.lastStreamId;
101 }
102
103 public long errorCode() {
104 return goAwayDetail.errorCode;
105 }
106
107 public byte[] debugData() {
108 return goAwayDetail.debugData.clone();
109 }
110 }
111
112
113
114
115
116 private final TreeMap<Integer, PendingStream> pendingStreams = new TreeMap<Integer, PendingStream>();
117 private int maxConcurrentStreams;
118 private boolean closed;
119 private GoAwayDetail goAwayDetail;
120
121 public StreamBufferingEncoder(Http2ConnectionEncoder delegate) {
122 this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS);
123 }
124
125 public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) {
126 super(delegate);
127 maxConcurrentStreams = initialMaxConcurrentStreams;
128 connection().addListener(new Http2ConnectionAdapter() {
129
130 @Override
131 public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) {
132 goAwayDetail = new GoAwayDetail(
133
134 lastStreamId, errorCode,
135 ByteBufUtil.getBytes(debugData, debugData.readerIndex(), debugData.readableBytes(), false));
136 cancelGoAwayStreams(goAwayDetail);
137 }
138
139 @Override
140 public void onStreamClosed(Http2Stream stream) {
141 tryCreatePendingStreams();
142 }
143 });
144 }
145
146
147
148
149 public int numBufferedStreams() {
150 return pendingStreams.size();
151 }
152
153 @Override
154 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
155 boolean endStream, ChannelPromise promise) {
156 return writeHeaders0(ctx, streamId, headers, false, 0, (short) 0,
157 false, padding, endStream, promise);
158 }
159
160 @Override
161 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
162 int streamDependency, short weight, boolean exclusive, int padding,
163 boolean endOfStream, ChannelPromise promise) {
164 return writeHeaders0(ctx, streamId, headers, true, streamDependency, weight, exclusive, padding,
165 endOfStream, promise);
166 }
167
168 private ChannelFuture writeHeaders0(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
169 boolean hasPriority, int streamDependency, short weight, boolean exclusive,
170 int padding, boolean endOfStream, ChannelPromise promise) {
171 if (closed) {
172 return promise.setFailure(new Http2ChannelClosedException());
173 }
174 if (isExistingStream(streamId) || canCreateStream()) {
175 if (hasPriority) {
176 return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
177 exclusive, padding, endOfStream, promise);
178 }
179 return super.writeHeaders(ctx, streamId, headers, padding, endOfStream, promise);
180 }
181 if (goAwayDetail != null) {
182 return promise.setFailure(new Http2GoAwayException(goAwayDetail));
183 }
184 PendingStream pendingStream = pendingStreams.get(streamId);
185 if (pendingStream == null) {
186 pendingStream = new PendingStream(ctx, streamId);
187 pendingStreams.put(streamId, pendingStream);
188 }
189 pendingStream.frames.add(new HeadersFrame(headers, hasPriority, streamDependency, weight, exclusive,
190 padding, endOfStream, promise));
191 return promise;
192 }
193
194 @Override
195 public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode,
196 ChannelPromise promise) {
197 if (isExistingStream(streamId)) {
198 return super.writeRstStream(ctx, streamId, errorCode, promise);
199 }
200
201
202 PendingStream stream = pendingStreams.remove(streamId);
203 if (stream != null) {
204
205
206
207
208 stream.close(null);
209 promise.setSuccess();
210 } else {
211 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
212 }
213 return promise;
214 }
215
216 @Override
217 public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data,
218 int padding, boolean endOfStream, ChannelPromise promise) {
219 if (isExistingStream(streamId)) {
220 return super.writeData(ctx, streamId, data, padding, endOfStream, promise);
221 }
222 PendingStream pendingStream = pendingStreams.get(streamId);
223 if (pendingStream != null) {
224 pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise));
225 } else {
226 ReferenceCountUtil.safeRelease(data);
227 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
228 }
229 return promise;
230 }
231
232 @Override
233 public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) {
234 final ChannelFuture future = super.writeSettingsAck(ctx, promise);
235
236
237
238 updateMaxConcurrentStreams();
239 return future;
240 }
241
242 @Override
243 public void remoteSettings(Http2Settings settings) throws Http2Exception {
244
245
246 super.remoteSettings(settings);
247 updateMaxConcurrentStreams();
248 }
249
250 private void updateMaxConcurrentStreams() {
251
252 maxConcurrentStreams = connection().local().maxActiveStreams();
253
254 tryCreatePendingStreams();
255 }
256
257 @Override
258 public void close() {
259 try {
260 if (!closed) {
261 closed = true;
262
263
264 Http2ChannelClosedException e = new Http2ChannelClosedException();
265 while (!pendingStreams.isEmpty()) {
266 PendingStream stream = pendingStreams.pollFirstEntry().getValue();
267 stream.close(e);
268 }
269 }
270 } finally {
271 super.close();
272 }
273 }
274
275 private void tryCreatePendingStreams() {
276 while (!pendingStreams.isEmpty() && canCreateStream()) {
277 Map.Entry<Integer, PendingStream> entry = pendingStreams.pollFirstEntry();
278 PendingStream pendingStream = entry.getValue();
279 try {
280 pendingStream.sendFrames();
281 } catch (Throwable t) {
282 pendingStream.close(t);
283 }
284 }
285 }
286
287 private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) {
288 Iterator<PendingStream> iter = pendingStreams.values().iterator();
289 Exception e = new Http2GoAwayException(goAwayDetail);
290 while (iter.hasNext()) {
291 PendingStream stream = iter.next();
292 if (stream.streamId > goAwayDetail.lastStreamId) {
293 iter.remove();
294 stream.close(e);
295 }
296 }
297 }
298
299
300
301
302 private boolean canCreateStream() {
303 return connection().local().numActiveStreams() < maxConcurrentStreams;
304 }
305
306 private boolean isExistingStream(int streamId) {
307 return streamId <= connection().local().lastStreamCreated();
308 }
309
310 private static final class PendingStream {
311 final ChannelHandlerContext ctx;
312 final int streamId;
313 final Queue<Frame> frames = new ArrayDeque<Frame>(2);
314
315 PendingStream(ChannelHandlerContext ctx, int streamId) {
316 this.ctx = ctx;
317 this.streamId = streamId;
318 }
319
320 void sendFrames() {
321 for (Frame frame : frames) {
322 frame.send(ctx, streamId);
323 }
324 }
325
326 void close(Throwable t) {
327 for (Frame frame : frames) {
328 frame.release(t);
329 }
330 }
331 }
332
333 private abstract static class Frame {
334 final ChannelPromise promise;
335
336 Frame(ChannelPromise promise) {
337 this.promise = promise;
338 }
339
340
341
342
343 void release(Throwable t) {
344 if (t == null) {
345 promise.setSuccess();
346 } else {
347 promise.setFailure(t);
348 }
349 }
350
351 abstract void send(ChannelHandlerContext ctx, int streamId);
352 }
353
354 private final class HeadersFrame extends Frame {
355 final Http2Headers headers;
356 final int streamDependency;
357 final boolean hasPriority;
358 final short weight;
359 final boolean exclusive;
360 final int padding;
361 final boolean endOfStream;
362
363 HeadersFrame(Http2Headers headers, boolean hasPriority, int streamDependency, short weight, boolean exclusive,
364 int padding, boolean endOfStream, ChannelPromise promise) {
365 super(promise);
366 this.headers = headers;
367 this.hasPriority = hasPriority;
368 this.streamDependency = streamDependency;
369 this.weight = weight;
370 this.exclusive = exclusive;
371 this.padding = padding;
372 this.endOfStream = endOfStream;
373 }
374
375 @Override
376 void send(ChannelHandlerContext ctx, int streamId) {
377 writeHeaders0(ctx, streamId, headers, hasPriority, streamDependency, weight, exclusive, padding,
378 endOfStream,
379 promise);
380 }
381 }
382
383 private final class DataFrame extends Frame {
384 final ByteBuf data;
385 final int padding;
386 final boolean endOfStream;
387
388 DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) {
389 super(promise);
390 this.data = data;
391 this.padding = padding;
392 this.endOfStream = endOfStream;
393 }
394
395 @Override
396 void release(Throwable t) {
397 super.release(t);
398 ReferenceCountUtil.safeRelease(data);
399 }
400
401 @Override
402 void send(ChannelHandlerContext ctx, int streamId) {
403 writeData(ctx, streamId, data, padding, endOfStream, promise);
404 }
405 }
406 }