1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http2;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.util.ReferenceCountUtil;
24
25 import java.util.ArrayDeque;
26 import java.util.Iterator;
27 import java.util.Map;
28 import java.util.Queue;
29 import java.util.TreeMap;
30
31 import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS;
32 import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
33 import static io.netty.handler.codec.http2.Http2Exception.connectionError;
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57 public class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder {
58
59
60
61
62 public static final class Http2ChannelClosedException extends Http2Exception {
63 private static final long serialVersionUID = 4768543442094476971L;
64
65 public Http2ChannelClosedException() {
66 super(Http2Error.REFUSED_STREAM, "Connection closed");
67 }
68 }
69
70 private static final class GoAwayDetail {
71 private final int lastStreamId;
72 private final long errorCode;
73 private final byte[] debugData;
74
75 GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) {
76 this.lastStreamId = lastStreamId;
77 this.errorCode = errorCode;
78 this.debugData = debugData.clone();
79 }
80 }
81
82
83
84
85
86 public static final class Http2GoAwayException extends Http2Exception {
87 private static final long serialVersionUID = 1326785622777291198L;
88 private final GoAwayDetail goAwayDetail;
89
90 public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) {
91 this(new GoAwayDetail(lastStreamId, errorCode, debugData));
92 }
93
94 Http2GoAwayException(GoAwayDetail goAwayDetail) {
95 super(Http2Error.STREAM_CLOSED);
96 this.goAwayDetail = goAwayDetail;
97 }
98
99 public int lastStreamId() {
100 return goAwayDetail.lastStreamId;
101 }
102
103 public long errorCode() {
104 return goAwayDetail.errorCode;
105 }
106
107 public byte[] debugData() {
108 return goAwayDetail.debugData.clone();
109 }
110 }
111
112
113
114
115
116 private final TreeMap<Integer, PendingStream> pendingStreams = new TreeMap<Integer, PendingStream>();
117 private int maxConcurrentStreams;
118 private boolean closed;
119 private GoAwayDetail goAwayDetail;
120
121 public StreamBufferingEncoder(Http2ConnectionEncoder delegate) {
122 this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS);
123 }
124
125 public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) {
126 super(delegate);
127 maxConcurrentStreams = initialMaxConcurrentStreams;
128 connection().addListener(new Http2ConnectionAdapter() {
129
130 @Override
131 public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) {
132 goAwayDetail = new GoAwayDetail(
133
134 lastStreamId, errorCode,
135 ByteBufUtil.getBytes(debugData, debugData.readerIndex(), debugData.readableBytes(), false));
136 cancelGoAwayStreams(goAwayDetail);
137 }
138
139 @Override
140 public void onStreamClosed(Http2Stream stream) {
141 tryCreatePendingStreams();
142 }
143 });
144 }
145
146
147
148
149 public int numBufferedStreams() {
150 return pendingStreams.size();
151 }
152
153 @Override
154 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
155 boolean endStream, ChannelPromise promise) {
156 return writeHeaders0(ctx, streamId, headers, false, 0, (short) 0,
157 false, padding, endStream, promise);
158 }
159
160 @Override
161 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
162 int streamDependency, short weight, boolean exclusive, int padding,
163 boolean endOfStream, ChannelPromise promise) {
164 return writeHeaders0(ctx, streamId, headers, true, streamDependency, weight, exclusive, padding,
165 endOfStream, promise);
166 }
167
168 private ChannelFuture writeHeaders0(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
169 boolean hasPriority, int streamDependency, short weight, boolean exclusive,
170 int padding, boolean endOfStream, ChannelPromise promise) {
171 if (closed) {
172 return promise.setFailure(new Http2ChannelClosedException());
173 }
174 if (isExistingStream(streamId) || canCreateStream()) {
175 if (hasPriority) {
176 return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
177 exclusive, padding, endOfStream, promise);
178 }
179 return super.writeHeaders(ctx, streamId, headers, padding, endOfStream, promise);
180 }
181 if (goAwayDetail != null) {
182 return promise.setFailure(new Http2GoAwayException(goAwayDetail));
183 }
184 PendingStream pendingStream = pendingStreams.get(streamId);
185 if (pendingStream == null) {
186 pendingStream = new PendingStream(ctx, streamId);
187 pendingStreams.put(streamId, pendingStream);
188 }
189 pendingStream.frames.add(new HeadersFrame(headers, hasPriority, streamDependency, weight, exclusive,
190 padding, endOfStream, promise));
191 return promise;
192 }
193
194 @Override
195 public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode,
196 ChannelPromise promise) {
197 if (isExistingStream(streamId)) {
198 return super.writeRstStream(ctx, streamId, errorCode, promise);
199 }
200
201
202 PendingStream stream = pendingStreams.remove(streamId);
203 if (stream != null) {
204
205
206
207
208 stream.close(null);
209 promise.setSuccess();
210 } else {
211 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
212 }
213 return promise;
214 }
215
216 @Override
217 public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data,
218 int padding, boolean endOfStream, ChannelPromise promise) {
219 if (isExistingStream(streamId)) {
220 return super.writeData(ctx, streamId, data, padding, endOfStream, promise);
221 }
222 PendingStream pendingStream = pendingStreams.get(streamId);
223 if (pendingStream != null) {
224 pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise));
225 } else {
226 ReferenceCountUtil.safeRelease(data);
227 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
228 }
229 return promise;
230 }
231
232 @Override
233 public void remoteSettings(Http2Settings settings) throws Http2Exception {
234
235
236 super.remoteSettings(settings);
237
238
239 maxConcurrentStreams = connection().local().maxActiveStreams();
240
241
242 tryCreatePendingStreams();
243 }
244
245 @Override
246 public void close() {
247 try {
248 if (!closed) {
249 closed = true;
250
251
252 Http2ChannelClosedException e = new Http2ChannelClosedException();
253 while (!pendingStreams.isEmpty()) {
254 PendingStream stream = pendingStreams.pollFirstEntry().getValue();
255 stream.close(e);
256 }
257 }
258 } finally {
259 super.close();
260 }
261 }
262
263 private void tryCreatePendingStreams() {
264 while (!pendingStreams.isEmpty() && canCreateStream()) {
265 Map.Entry<Integer, PendingStream> entry = pendingStreams.pollFirstEntry();
266 PendingStream pendingStream = entry.getValue();
267 try {
268 pendingStream.sendFrames();
269 } catch (Throwable t) {
270 pendingStream.close(t);
271 }
272 }
273 }
274
275 private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) {
276 Iterator<PendingStream> iter = pendingStreams.values().iterator();
277 Exception e = new Http2GoAwayException(goAwayDetail);
278 while (iter.hasNext()) {
279 PendingStream stream = iter.next();
280 if (stream.streamId > goAwayDetail.lastStreamId) {
281 iter.remove();
282 stream.close(e);
283 }
284 }
285 }
286
287
288
289
290 private boolean canCreateStream() {
291 return connection().local().numActiveStreams() < maxConcurrentStreams;
292 }
293
294 private boolean isExistingStream(int streamId) {
295 return streamId <= connection().local().lastStreamCreated();
296 }
297
298 private static final class PendingStream {
299 final ChannelHandlerContext ctx;
300 final int streamId;
301 final Queue<Frame> frames = new ArrayDeque<Frame>(2);
302
303 PendingStream(ChannelHandlerContext ctx, int streamId) {
304 this.ctx = ctx;
305 this.streamId = streamId;
306 }
307
308 void sendFrames() {
309 for (Frame frame : frames) {
310 frame.send(ctx, streamId);
311 }
312 }
313
314 void close(Throwable t) {
315 for (Frame frame : frames) {
316 frame.release(t);
317 }
318 }
319 }
320
321 private abstract static class Frame {
322 final ChannelPromise promise;
323
324 Frame(ChannelPromise promise) {
325 this.promise = promise;
326 }
327
328
329
330
331 void release(Throwable t) {
332 if (t == null) {
333 promise.setSuccess();
334 } else {
335 promise.setFailure(t);
336 }
337 }
338
339 abstract void send(ChannelHandlerContext ctx, int streamId);
340 }
341
342 private final class HeadersFrame extends Frame {
343 final Http2Headers headers;
344 final int streamDependency;
345 final boolean hasPriority;
346 final short weight;
347 final boolean exclusive;
348 final int padding;
349 final boolean endOfStream;
350
351 HeadersFrame(Http2Headers headers, boolean hasPriority, int streamDependency, short weight, boolean exclusive,
352 int padding, boolean endOfStream, ChannelPromise promise) {
353 super(promise);
354 this.headers = headers;
355 this.hasPriority = hasPriority;
356 this.streamDependency = streamDependency;
357 this.weight = weight;
358 this.exclusive = exclusive;
359 this.padding = padding;
360 this.endOfStream = endOfStream;
361 }
362
363 @Override
364 void send(ChannelHandlerContext ctx, int streamId) {
365 writeHeaders0(ctx, streamId, headers, hasPriority, streamDependency, weight, exclusive, padding,
366 endOfStream,
367 promise);
368 }
369 }
370
371 private final class DataFrame extends Frame {
372 final ByteBuf data;
373 final int padding;
374 final boolean endOfStream;
375
376 DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) {
377 super(promise);
378 this.data = data;
379 this.padding = padding;
380 this.endOfStream = endOfStream;
381 }
382
383 @Override
384 void release(Throwable t) {
385 super.release(t);
386 ReferenceCountUtil.safeRelease(data);
387 }
388
389 @Override
390 void send(ChannelHandlerContext ctx, int streamId) {
391 writeData(ctx, streamId, data, padding, endOfStream, promise);
392 }
393 }
394 }