1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package io.netty.handler.codec.http2;
16
17 import io.netty.buffer.ByteBuf;
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.channel.embedded.EmbeddedChannel;
21 import io.netty.handler.codec.ByteToMessageDecoder;
22 import io.netty.handler.codec.compression.Brotli;
23 import io.netty.handler.codec.compression.BrotliDecoder;
24 import io.netty.handler.codec.compression.Zstd;
25 import io.netty.handler.codec.compression.ZstdDecoder;
26 import io.netty.handler.codec.compression.ZlibCodecFactory;
27 import io.netty.handler.codec.compression.ZlibWrapper;
28 import io.netty.handler.codec.compression.SnappyFrameDecoder;
29
30 import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_ENCODING;
31 import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
32 import static io.netty.handler.codec.http.HttpHeaderValues.BR;
33 import static io.netty.handler.codec.http.HttpHeaderValues.DEFLATE;
34 import static io.netty.handler.codec.http.HttpHeaderValues.GZIP;
35 import static io.netty.handler.codec.http.HttpHeaderValues.IDENTITY;
36 import static io.netty.handler.codec.http.HttpHeaderValues.X_DEFLATE;
37 import static io.netty.handler.codec.http.HttpHeaderValues.X_GZIP;
38 import static io.netty.handler.codec.http.HttpHeaderValues.SNAPPY;
39 import static io.netty.handler.codec.http.HttpHeaderValues.ZSTD;
40 import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
41 import static io.netty.handler.codec.http2.Http2Exception.streamError;
42 import static io.netty.util.internal.ObjectUtil.checkNotNull;
43 import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
44
45
46
47
48
49 public class DelegatingDecompressorFrameListener extends Http2FrameListenerDecorator {
50
51 private final Http2Connection connection;
52 private final boolean strict;
53 private boolean flowControllerInitialized;
54 private final Http2Connection.PropertyKey propertyKey;
55 private final int maxAllocation;
56
57
58
59
60
61
62
63
64
65
66 @Deprecated
67 public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener) {
68 this(connection, listener, 0);
69 }
70
71
72
73
74
75
76
77
78
79 public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener,
80 int maxAllocation) {
81 this(connection, listener, true, maxAllocation);
82 }
83
84
85
86
87
88
89
90
91
92
93
94
95 @Deprecated
96 public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener,
97 boolean strict) {
98 this(connection, listener, strict, 0);
99 }
100
101
102
103
104
105
106
107
108
109
110
111 public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener,
112 boolean strict, int maxAllocation) {
113 super(listener);
114 this.connection = connection;
115 this.strict = strict;
116 this.maxAllocation = checkPositiveOrZero(maxAllocation, "maxAllocation");
117
118 propertyKey = connection.newKey();
119 connection.addListener(new Http2ConnectionAdapter() {
120 @Override
121 public void onStreamRemoved(Http2Stream stream) {
122 final Http2Decompressor decompressor = decompressor(stream);
123 if (decompressor != null) {
124 cleanup(decompressor);
125 }
126 }
127 });
128 }
129
130 @Override
131 public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
132 throws Http2Exception {
133 final Http2Stream stream = connection.stream(streamId);
134 final Http2Decompressor decompressor = decompressor(stream);
135 if (decompressor == null) {
136
137 return listener.onDataRead(ctx, streamId, data, padding, endOfStream);
138 }
139
140 final EmbeddedChannel channel = decompressor.decompressor();
141 final int compressedBytes = data.readableBytes() + padding;
142 decompressor.incrementCompressedBytes(compressedBytes);
143 try {
144
145 channel.writeInbound(data.retain());
146 ByteBuf buf = nextReadableBuf(channel);
147 if (buf == null && endOfStream && channel.finish()) {
148 buf = nextReadableBuf(channel);
149 }
150 if (buf == null) {
151 if (endOfStream) {
152 listener.onDataRead(ctx, streamId, Unpooled.EMPTY_BUFFER, padding, true);
153 }
154
155
156
157
158 decompressor.incrementDecompressedBytes(compressedBytes);
159 return compressedBytes;
160 }
161 try {
162 Http2LocalFlowController flowController = connection.local().flowController();
163 decompressor.incrementDecompressedBytes(padding);
164 for (;;) {
165 ByteBuf nextBuf = nextReadableBuf(channel);
166 boolean decompressedEndOfStream = nextBuf == null && endOfStream;
167 if (decompressedEndOfStream && channel.finish()) {
168 nextBuf = nextReadableBuf(channel);
169 decompressedEndOfStream = nextBuf == null;
170 }
171
172 decompressor.incrementDecompressedBytes(buf.readableBytes());
173
174
175
176 flowController.consumeBytes(stream,
177 listener.onDataRead(ctx, streamId, buf, padding, decompressedEndOfStream));
178 if (nextBuf == null) {
179 break;
180 }
181
182 padding = 0;
183 buf.release();
184 buf = nextBuf;
185 }
186
187
188
189 return 0;
190 } finally {
191 buf.release();
192 }
193 } catch (Http2Exception e) {
194 throw e;
195 } catch (Throwable t) {
196 throw streamError(stream.id(), INTERNAL_ERROR, t,
197 "Decompressor error detected while delegating data read on streamId %d", stream.id());
198 }
199 }
200
201 @Override
202 public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
203 boolean endStream) throws Http2Exception {
204 initDecompressor(ctx, streamId, headers, endStream);
205 listener.onHeadersRead(ctx, streamId, headers, padding, endStream);
206 }
207
208 @Override
209 public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
210 short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception {
211 initDecompressor(ctx, streamId, headers, endStream);
212 listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream);
213 }
214
215
216
217
218
219
220
221
222
223
224 protected EmbeddedChannel newContentDecompressor(final ChannelHandlerContext ctx, CharSequence contentEncoding)
225 throws Http2Exception {
226 if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
227 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
228 ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP, maxAllocation));
229 }
230 if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
231 final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE;
232
233 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
234 ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(wrapper, maxAllocation));
235 }
236 if (Brotli.isAvailable() && BR.contentEqualsIgnoreCase(contentEncoding)) {
237 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
238 ctx.channel().config(), new BrotliDecoder());
239 }
240 if (SNAPPY.contentEqualsIgnoreCase(contentEncoding)) {
241 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
242 ctx.channel().config(), new SnappyFrameDecoder());
243 }
244 if (Zstd.isAvailable() && ZSTD.contentEqualsIgnoreCase(contentEncoding)) {
245 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
246 ctx.channel().config(), new ZstdDecoder());
247 }
248
249 return null;
250 }
251
252
253
254
255
256
257
258
259
260 protected CharSequence getTargetContentEncoding(@SuppressWarnings("UnusedParameters") CharSequence contentEncoding)
261 throws Http2Exception {
262 return IDENTITY;
263 }
264
265
266
267
268
269
270
271
272
273
274
275 private void initDecompressor(ChannelHandlerContext ctx, int streamId, Http2Headers headers, boolean endOfStream)
276 throws Http2Exception {
277 final Http2Stream stream = connection.stream(streamId);
278 if (stream == null) {
279 return;
280 }
281
282 Http2Decompressor decompressor = decompressor(stream);
283 if (decompressor == null && !endOfStream) {
284
285 CharSequence contentEncoding = headers.get(CONTENT_ENCODING);
286 if (contentEncoding == null) {
287 contentEncoding = IDENTITY;
288 }
289 final EmbeddedChannel channel = newContentDecompressor(ctx, contentEncoding);
290 if (channel != null) {
291 decompressor = new Http2Decompressor(channel);
292 stream.setProperty(propertyKey, decompressor);
293
294
295 CharSequence targetContentEncoding = getTargetContentEncoding(contentEncoding);
296 if (IDENTITY.contentEqualsIgnoreCase(targetContentEncoding)) {
297 headers.remove(CONTENT_ENCODING);
298 } else {
299 headers.set(CONTENT_ENCODING, targetContentEncoding);
300 }
301 }
302 }
303
304 if (decompressor != null) {
305
306
307
308 headers.remove(CONTENT_LENGTH);
309
310
311
312 if (!flowControllerInitialized) {
313 flowControllerInitialized = true;
314 connection.local().flowController(new ConsumedBytesConverter(connection.local().flowController()));
315 }
316 }
317 }
318
319 Http2Decompressor decompressor(Http2Stream stream) {
320 return stream == null ? null : (Http2Decompressor) stream.getProperty(propertyKey);
321 }
322
323
324
325
326
327
328 private static void cleanup(Http2Decompressor decompressor) {
329 decompressor.decompressor().finishAndReleaseAll();
330 }
331
332
333
334
335
336
337
338
339 private static ByteBuf nextReadableBuf(EmbeddedChannel decompressor) {
340 for (;;) {
341 final ByteBuf buf = decompressor.readInbound();
342 if (buf == null) {
343 return null;
344 }
345 if (!buf.isReadable()) {
346 buf.release();
347 continue;
348 }
349 return buf;
350 }
351 }
352
353
354
355
356 private final class ConsumedBytesConverter implements Http2LocalFlowController {
357 private final Http2LocalFlowController flowController;
358
359 ConsumedBytesConverter(Http2LocalFlowController flowController) {
360 this.flowController = checkNotNull(flowController, "flowController");
361 }
362
363 @Override
364 public Http2LocalFlowController frameWriter(Http2FrameWriter frameWriter) {
365 return flowController.frameWriter(frameWriter);
366 }
367
368 @Override
369 public void channelHandlerContext(ChannelHandlerContext ctx) throws Http2Exception {
370 flowController.channelHandlerContext(ctx);
371 }
372
373 @Override
374 public void initialWindowSize(int newWindowSize) throws Http2Exception {
375 flowController.initialWindowSize(newWindowSize);
376 }
377
378 @Override
379 public int initialWindowSize() {
380 return flowController.initialWindowSize();
381 }
382
383 @Override
384 public int windowSize(Http2Stream stream) {
385 return flowController.windowSize(stream);
386 }
387
388 @Override
389 public void incrementWindowSize(Http2Stream stream, int delta) throws Http2Exception {
390 flowController.incrementWindowSize(stream, delta);
391 }
392
393 @Override
394 public void receiveFlowControlledFrame(Http2Stream stream, ByteBuf data, int padding,
395 boolean endOfStream) throws Http2Exception {
396 flowController.receiveFlowControlledFrame(stream, data, padding, endOfStream);
397 }
398
399 @Override
400 public boolean consumeBytes(Http2Stream stream, int numBytes) throws Http2Exception {
401 Http2Decompressor decompressor = decompressor(stream);
402 if (decompressor != null) {
403
404 numBytes = decompressor.consumeBytes(stream.id(), numBytes);
405 }
406 try {
407 return flowController.consumeBytes(stream, numBytes);
408 } catch (Http2Exception e) {
409 throw e;
410 } catch (Throwable t) {
411
412
413 throw streamError(stream.id(), INTERNAL_ERROR, t, "Error while returning bytes to flow control window");
414 }
415 }
416
417 @Override
418 public int unconsumedBytes(Http2Stream stream) {
419 return flowController.unconsumedBytes(stream);
420 }
421
422 @Override
423 public int initialWindowSize(Http2Stream stream) {
424 return flowController.initialWindowSize(stream);
425 }
426 }
427
428
429
430
431 private static final class Http2Decompressor {
432 private final EmbeddedChannel decompressor;
433 private int compressed;
434 private int decompressed;
435
436 Http2Decompressor(EmbeddedChannel decompressor) {
437 this.decompressor = decompressor;
438 }
439
440
441
442
443 EmbeddedChannel decompressor() {
444 return decompressor;
445 }
446
447
448
449
450 void incrementCompressedBytes(int delta) {
451 assert delta >= 0;
452 compressed += delta;
453 }
454
455
456
457
458 void incrementDecompressedBytes(int delta) {
459 assert delta >= 0;
460 decompressed += delta;
461 }
462
463
464
465
466
467
468
469
470
471 int consumeBytes(int streamId, int decompressedBytes) throws Http2Exception {
472 checkPositiveOrZero(decompressedBytes, "decompressedBytes");
473 if (decompressed - decompressedBytes < 0) {
474 throw streamError(streamId, INTERNAL_ERROR,
475 "Attempting to return too many bytes for stream %d. decompressed: %d " +
476 "decompressedBytes: %d", streamId, decompressed, decompressedBytes);
477 }
478 double consumedRatio = decompressedBytes / (double) decompressed;
479 int consumedCompressed = Math.min(compressed, (int) Math.ceil(compressed * consumedRatio));
480 if (compressed - consumedCompressed < 0) {
481 throw streamError(streamId, INTERNAL_ERROR,
482 "overflow when converting decompressed bytes to compressed bytes for stream %d." +
483 "decompressedBytes: %d decompressed: %d compressed: %d consumedCompressed: %d",
484 streamId, decompressedBytes, decompressed, compressed, consumedCompressed);
485 }
486 decompressed -= decompressedBytes;
487 compressed -= consumedCompressed;
488
489 return consumedCompressed;
490 }
491 }
492 }