1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http2;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.util.ReferenceCountUtil;
24
25 import java.util.ArrayDeque;
26 import java.util.Iterator;
27 import java.util.Map;
28 import java.util.Queue;
29 import java.util.TreeMap;
30
31 import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS;
32 import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
33 import static io.netty.handler.codec.http2.Http2Exception.connectionError;
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57 public class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder {
58
59
60
61
62 public static final class Http2ChannelClosedException extends Http2Exception {
63 private static final long serialVersionUID = 4768543442094476971L;
64
65 public Http2ChannelClosedException() {
66 super(Http2Error.REFUSED_STREAM, "Connection closed");
67 }
68 }
69
70 private static final class GoAwayDetail {
71 private final int lastStreamId;
72 private final long errorCode;
73 private final byte[] debugData;
74
75 GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) {
76 this.lastStreamId = lastStreamId;
77 this.errorCode = errorCode;
78 this.debugData = debugData.clone();
79 }
80 }
81
82
83
84
85
86 public static final class Http2GoAwayException extends Http2Exception {
87 private static final long serialVersionUID = 1326785622777291198L;
88 private final GoAwayDetail goAwayDetail;
89
90 public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) {
91 this(new GoAwayDetail(lastStreamId, errorCode, debugData));
92 }
93
94 Http2GoAwayException(GoAwayDetail goAwayDetail) {
95 super(Http2Error.STREAM_CLOSED);
96 this.goAwayDetail = goAwayDetail;
97 }
98
99 public int lastStreamId() {
100 return goAwayDetail.lastStreamId;
101 }
102
103 public long errorCode() {
104 return goAwayDetail.errorCode;
105 }
106
107 public byte[] debugData() {
108 return goAwayDetail.debugData.clone();
109 }
110 }
111
112
113
114
115
116 private final TreeMap<Integer, PendingStream> pendingStreams = new TreeMap<Integer, PendingStream>();
117 private int maxConcurrentStreams;
118 private boolean closed;
119 private GoAwayDetail goAwayDetail;
120
121 public StreamBufferingEncoder(Http2ConnectionEncoder delegate) {
122 this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS);
123 }
124
125 public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) {
126 super(delegate);
127 maxConcurrentStreams = initialMaxConcurrentStreams;
128 connection().addListener(new Http2ConnectionAdapter() {
129
130 @Override
131 public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) {
132 goAwayDetail = new GoAwayDetail(
133
134 lastStreamId, errorCode,
135 ByteBufUtil.getBytes(debugData, debugData.readerIndex(), debugData.readableBytes(), false));
136 cancelGoAwayStreams(goAwayDetail);
137 }
138
139 @Override
140 public void onStreamClosed(Http2Stream stream) {
141 tryCreatePendingStreams();
142 }
143 });
144 }
145
146
147
148
149 public int numBufferedStreams() {
150 return pendingStreams.size();
151 }
152
153 @Override
154 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
155 int padding, boolean endStream, ChannelPromise promise) {
156 return writeHeaders(ctx, streamId, headers, 0, Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT,
157 false, padding, endStream, promise);
158 }
159
160 @Override
161 public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
162 int streamDependency, short weight, boolean exclusive,
163 int padding, boolean endOfStream, ChannelPromise promise) {
164 if (closed) {
165 return promise.setFailure(new Http2ChannelClosedException());
166 }
167 if (isExistingStream(streamId) || canCreateStream()) {
168 return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
169 exclusive, padding, endOfStream, promise);
170 }
171 if (goAwayDetail != null) {
172 return promise.setFailure(new Http2GoAwayException(goAwayDetail));
173 }
174 PendingStream pendingStream = pendingStreams.get(streamId);
175 if (pendingStream == null) {
176 pendingStream = new PendingStream(ctx, streamId);
177 pendingStreams.put(streamId, pendingStream);
178 }
179 pendingStream.frames.add(new HeadersFrame(headers, streamDependency, weight, exclusive,
180 padding, endOfStream, promise));
181 return promise;
182 }
183
184 @Override
185 public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode,
186 ChannelPromise promise) {
187 if (isExistingStream(streamId)) {
188 return super.writeRstStream(ctx, streamId, errorCode, promise);
189 }
190
191
192 PendingStream stream = pendingStreams.remove(streamId);
193 if (stream != null) {
194
195
196
197
198 stream.close(null);
199 promise.setSuccess();
200 } else {
201 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
202 }
203 return promise;
204 }
205
206 @Override
207 public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data,
208 int padding, boolean endOfStream, ChannelPromise promise) {
209 if (isExistingStream(streamId)) {
210 return super.writeData(ctx, streamId, data, padding, endOfStream, promise);
211 }
212 PendingStream pendingStream = pendingStreams.get(streamId);
213 if (pendingStream != null) {
214 pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise));
215 } else {
216 ReferenceCountUtil.safeRelease(data);
217 promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
218 }
219 return promise;
220 }
221
222 @Override
223 public void remoteSettings(Http2Settings settings) throws Http2Exception {
224
225
226 super.remoteSettings(settings);
227
228
229 maxConcurrentStreams = connection().local().maxActiveStreams();
230
231
232 tryCreatePendingStreams();
233 }
234
235 @Override
236 public void close() {
237 try {
238 if (!closed) {
239 closed = true;
240
241
242 Http2ChannelClosedException e = new Http2ChannelClosedException();
243 while (!pendingStreams.isEmpty()) {
244 PendingStream stream = pendingStreams.pollFirstEntry().getValue();
245 stream.close(e);
246 }
247 }
248 } finally {
249 super.close();
250 }
251 }
252
253 private void tryCreatePendingStreams() {
254 while (!pendingStreams.isEmpty() && canCreateStream()) {
255 Map.Entry<Integer, PendingStream> entry = pendingStreams.pollFirstEntry();
256 PendingStream pendingStream = entry.getValue();
257 try {
258 pendingStream.sendFrames();
259 } catch (Throwable t) {
260 pendingStream.close(t);
261 }
262 }
263 }
264
265 private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) {
266 Iterator<PendingStream> iter = pendingStreams.values().iterator();
267 Exception e = new Http2GoAwayException(goAwayDetail);
268 while (iter.hasNext()) {
269 PendingStream stream = iter.next();
270 if (stream.streamId > goAwayDetail.lastStreamId) {
271 iter.remove();
272 stream.close(e);
273 }
274 }
275 }
276
277
278
279
280 private boolean canCreateStream() {
281 return connection().local().numActiveStreams() < maxConcurrentStreams;
282 }
283
284 private boolean isExistingStream(int streamId) {
285 return streamId <= connection().local().lastStreamCreated();
286 }
287
288 private static final class PendingStream {
289 final ChannelHandlerContext ctx;
290 final int streamId;
291 final Queue<Frame> frames = new ArrayDeque<Frame>(2);
292
293 PendingStream(ChannelHandlerContext ctx, int streamId) {
294 this.ctx = ctx;
295 this.streamId = streamId;
296 }
297
298 void sendFrames() {
299 for (Frame frame : frames) {
300 frame.send(ctx, streamId);
301 }
302 }
303
304 void close(Throwable t) {
305 for (Frame frame : frames) {
306 frame.release(t);
307 }
308 }
309 }
310
311 private abstract static class Frame {
312 final ChannelPromise promise;
313
314 Frame(ChannelPromise promise) {
315 this.promise = promise;
316 }
317
318
319
320
321 void release(Throwable t) {
322 if (t == null) {
323 promise.setSuccess();
324 } else {
325 promise.setFailure(t);
326 }
327 }
328
329 abstract void send(ChannelHandlerContext ctx, int streamId);
330 }
331
332 private final class HeadersFrame extends Frame {
333 final Http2Headers headers;
334 final int streamDependency;
335 final short weight;
336 final boolean exclusive;
337 final int padding;
338 final boolean endOfStream;
339
340 HeadersFrame(Http2Headers headers, int streamDependency, short weight, boolean exclusive,
341 int padding, boolean endOfStream, ChannelPromise promise) {
342 super(promise);
343 this.headers = headers;
344 this.streamDependency = streamDependency;
345 this.weight = weight;
346 this.exclusive = exclusive;
347 this.padding = padding;
348 this.endOfStream = endOfStream;
349 }
350
351 @Override
352 void send(ChannelHandlerContext ctx, int streamId) {
353 writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise);
354 }
355 }
356
357 private final class DataFrame extends Frame {
358 final ByteBuf data;
359 final int padding;
360 final boolean endOfStream;
361
362 DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) {
363 super(promise);
364 this.data = data;
365 this.padding = padding;
366 this.endOfStream = endOfStream;
367 }
368
369 @Override
370 void release(Throwable t) {
371 super.release(t);
372 ReferenceCountUtil.safeRelease(data);
373 }
374
375 @Override
376 void send(ChannelHandlerContext ctx, int streamId) {
377 writeData(ctx, streamId, data, padding, endOfStream, promise);
378 }
379 }
380 }