1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.http2;
17
18 import io.netty5.buffer.BufferUtil;
19 import io.netty5.buffer.api.Buffer;
20 import io.netty5.channel.ChannelHandlerContext;
21 import io.netty5.util.concurrent.Future;
22 import io.netty5.util.concurrent.Promise;
23 import io.netty5.util.internal.SilentDispose;
24 import io.netty5.util.internal.UnstableApi;
25 import io.netty5.util.internal.logging.InternalLogger;
26 import io.netty5.util.internal.logging.InternalLoggerFactory;
27
28 import java.util.ArrayDeque;
29 import java.util.Iterator;
30 import java.util.Map;
31 import java.util.Queue;
32 import java.util.TreeMap;
33
34 import static io.netty5.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS;
35 import static io.netty5.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
36 import static io.netty5.handler.codec.http2.Http2Exception.connectionError;
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60 @UnstableApi
61 public class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder {
62 private static final InternalLogger logger = InternalLoggerFactory.getInstance(StreamBufferingEncoder.class);
63
64
65
66
67 public static final class Http2ChannelClosedException extends Http2Exception {
68 private static final long serialVersionUID = 4768543442094476971L;
69
70 public Http2ChannelClosedException() {
71 super(Http2Error.REFUSED_STREAM, "Connection closed");
72 }
73 }
74
75 private static final class GoAwayDetail {
76 private final int lastStreamId;
77 private final long errorCode;
78 private final byte[] debugData;
79
80 GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) {
81 this.lastStreamId = lastStreamId;
82 this.errorCode = errorCode;
83 this.debugData = debugData.clone();
84 }
85 }
86
87
88
89
90
91 public static final class Http2GoAwayException extends Http2Exception {
92 private static final long serialVersionUID = 1326785622777291198L;
93 private final GoAwayDetail goAwayDetail;
94
95 public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) {
96 this(new GoAwayDetail(lastStreamId, errorCode, debugData));
97 }
98
99 Http2GoAwayException(GoAwayDetail goAwayDetail) {
100 super(Http2Error.STREAM_CLOSED);
101 this.goAwayDetail = goAwayDetail;
102 }
103
104 public int lastStreamId() {
105 return goAwayDetail.lastStreamId;
106 }
107
108 public long errorCode() {
109 return goAwayDetail.errorCode;
110 }
111
112 public byte[] debugData() {
113 return goAwayDetail.debugData.clone();
114 }
115 }
116
117
118
119
120
121 private final TreeMap<Integer, PendingStream> pendingStreams = new TreeMap<>();
122 private int maxConcurrentStreams;
123 private boolean closed;
124 private GoAwayDetail goAwayDetail;
125
126 public StreamBufferingEncoder(Http2ConnectionEncoder delegate) {
127 this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS);
128 }
129
130 public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) {
131 super(delegate);
132 maxConcurrentStreams = initialMaxConcurrentStreams;
133 connection().addListener(new Http2ConnectionAdapter() {
134
135 @Override
136 public void onGoAwayReceived(int lastStreamId, long errorCode, Buffer debugData) {
137 goAwayDetail = new GoAwayDetail(
138
139 lastStreamId, errorCode,
140 BufferUtil.getBytes(debugData, debugData.readerOffset(), debugData.readableBytes()));
141 cancelGoAwayStreams(goAwayDetail);
142 }
143
144 @Override
145 public void onStreamClosed(Http2Stream stream) {
146 tryCreatePendingStreams();
147 }
148 });
149 }
150
151
152
153
154 public int numBufferedStreams() {
155 return pendingStreams.size();
156 }
157
158 @Override
159 public Future<Void> writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
160 int padding, boolean endStream) {
161 return writeHeaders(ctx, streamId, headers, 0, Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT,
162 false, padding, endStream);
163 }
164
165 @Override
166 public Future<Void> writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
167 int streamDependency, short weight, boolean exclusive,
168 int padding, boolean endOfStream) {
169 if (closed) {
170 return ctx.newFailedFuture(new Http2ChannelClosedException());
171 }
172 if (isExistingStream(streamId) || canCreateStream()) {
173 return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
174 exclusive, padding, endOfStream);
175 }
176 if (goAwayDetail != null) {
177 return ctx.newFailedFuture(new Http2GoAwayException(goAwayDetail));
178 }
179 PendingStream pendingStream = pendingStreams.get(streamId);
180 if (pendingStream == null) {
181 pendingStream = new PendingStream(ctx, streamId);
182 pendingStreams.put(streamId, pendingStream);
183 }
184 Promise<Void> promise = ctx.newPromise();
185
186 pendingStream.frames.add(new HeadersFrame(headers, streamDependency, weight, exclusive,
187 padding, endOfStream, promise));
188 return promise.asFuture();
189 }
190
191 @Override
192 public Future<Void> writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode) {
193 if (isExistingStream(streamId)) {
194 return super.writeRstStream(ctx, streamId, errorCode);
195 }
196
197
198 PendingStream stream = pendingStreams.remove(streamId);
199 if (stream != null) {
200
201
202
203
204 stream.close(null);
205 return ctx.newSucceededFuture();
206 } else {
207 return ctx.newFailedFuture(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
208 }
209 }
210
211 @Override
212 public Future<Void> writeData(ChannelHandlerContext ctx, int streamId, Buffer data,
213 int padding, boolean endOfStream) {
214 if (isExistingStream(streamId)) {
215 return super.writeData(ctx, streamId, data, padding, endOfStream);
216 }
217 PendingStream pendingStream = pendingStreams.get(streamId);
218 if (pendingStream != null) {
219 Promise<Void> promise = ctx.newPromise();
220 pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise));
221 return promise.asFuture();
222 } else {
223 SilentDispose.dispose(data, logger);
224 return ctx.newFailedFuture(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId));
225 }
226 }
227
228 @Override
229 public void remoteSettings(Http2Settings settings) throws Http2Exception {
230
231
232 super.remoteSettings(settings);
233
234
235 maxConcurrentStreams = connection().local().maxActiveStreams();
236
237
238 tryCreatePendingStreams();
239 }
240
241 @Override
242 public void close() {
243 try {
244 if (!closed) {
245 closed = true;
246
247
248 Http2ChannelClosedException e = new Http2ChannelClosedException();
249 while (!pendingStreams.isEmpty()) {
250 PendingStream stream = pendingStreams.pollFirstEntry().getValue();
251 stream.close(e);
252 }
253 }
254 } finally {
255 super.close();
256 }
257 }
258
259 private void tryCreatePendingStreams() {
260 while (!pendingStreams.isEmpty() && canCreateStream()) {
261 Map.Entry<Integer, PendingStream> entry = pendingStreams.pollFirstEntry();
262 PendingStream pendingStream = entry.getValue();
263 try {
264 pendingStream.sendFrames();
265 } catch (Throwable t) {
266 pendingStream.close(t);
267 }
268 }
269 }
270
271 private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) {
272 Iterator<PendingStream> iter = pendingStreams.values().iterator();
273 Exception e = new Http2GoAwayException(goAwayDetail);
274 while (iter.hasNext()) {
275 PendingStream stream = iter.next();
276 if (stream.streamId > goAwayDetail.lastStreamId) {
277 iter.remove();
278 stream.close(e);
279 }
280 }
281 }
282
283
284
285
286 private boolean canCreateStream() {
287 return connection().local().numActiveStreams() < maxConcurrentStreams;
288 }
289
290 private boolean isExistingStream(int streamId) {
291 return streamId <= connection().local().lastStreamCreated();
292 }
293
294 private static final class PendingStream {
295 final ChannelHandlerContext ctx;
296 final int streamId;
297 final Queue<Frame> frames = new ArrayDeque<>(2);
298
299 PendingStream(ChannelHandlerContext ctx, int streamId) {
300 this.ctx = ctx;
301 this.streamId = streamId;
302 }
303
304 void sendFrames() {
305 for (Frame frame : frames) {
306 frame.send(ctx, streamId);
307 }
308 }
309
310 void close(Throwable t) {
311 for (Frame frame : frames) {
312 frame.release(t);
313 }
314 }
315 }
316
317 private abstract static class Frame {
318 final Promise<Void> promise;
319
320 Frame(Promise<Void> promise) {
321 this.promise = promise;
322 }
323
324
325
326
327 void release(Throwable t) {
328 if (t == null) {
329 promise.setSuccess(null);
330 } else {
331 promise.setFailure(t);
332 }
333 }
334
335 abstract void send(ChannelHandlerContext ctx, int streamId);
336 }
337
338 private final class HeadersFrame extends Frame {
339 final Http2Headers headers;
340 final int streamDependency;
341 final short weight;
342 final boolean exclusive;
343 final int padding;
344 final boolean endOfStream;
345
346 HeadersFrame(Http2Headers headers, int streamDependency, short weight, boolean exclusive,
347 int padding, boolean endOfStream, Promise<Void> promise) {
348 super(promise);
349 this.headers = headers;
350 this.streamDependency = streamDependency;
351 this.weight = weight;
352 this.exclusive = exclusive;
353 this.padding = padding;
354 this.endOfStream = endOfStream;
355 }
356
357 @Override
358 void send(ChannelHandlerContext ctx, int streamId) {
359 writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream)
360 .cascadeTo(promise);
361 }
362 }
363
364 private final class DataFrame extends Frame {
365 final Buffer data;
366 final int padding;
367 final boolean endOfStream;
368
369 DataFrame(Buffer data, int padding, boolean endOfStream, Promise<Void> promise) {
370 super(promise);
371 this.data = data;
372 this.padding = padding;
373 this.endOfStream = endOfStream;
374 }
375
376 @Override
377 void release(Throwable t) {
378 super.release(t);
379 SilentDispose.dispose(data, logger);
380 }
381
382 @Override
383 void send(ChannelHandlerContext ctx, int streamId) {
384 writeData(ctx, streamId, data, padding, endOfStream).cascadeTo(promise);
385 }
386 }
387 }