1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package io.netty.handler.codec.http2;
16
17 import io.netty.buffer.ByteBuf;
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.channel.embedded.EmbeddedChannel;
21 import io.netty.handler.codec.ByteToMessageDecoder;
22 import io.netty.handler.codec.compression.Brotli;
23 import io.netty.handler.codec.compression.BrotliDecoder;
24 import io.netty.handler.codec.compression.Zstd;
25 import io.netty.handler.codec.compression.ZstdDecoder;
26 import io.netty.handler.codec.compression.ZlibCodecFactory;
27 import io.netty.handler.codec.compression.ZlibWrapper;
28 import io.netty.handler.codec.compression.SnappyFrameDecoder;
29
30 import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_ENCODING;
31 import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
32 import static io.netty.handler.codec.http.HttpHeaderValues.BR;
33 import static io.netty.handler.codec.http.HttpHeaderValues.DEFLATE;
34 import static io.netty.handler.codec.http.HttpHeaderValues.GZIP;
35 import static io.netty.handler.codec.http.HttpHeaderValues.IDENTITY;
36 import static io.netty.handler.codec.http.HttpHeaderValues.X_DEFLATE;
37 import static io.netty.handler.codec.http.HttpHeaderValues.X_GZIP;
38 import static io.netty.handler.codec.http.HttpHeaderValues.SNAPPY;
39 import static io.netty.handler.codec.http.HttpHeaderValues.ZSTD;
40 import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
41 import static io.netty.handler.codec.http2.Http2Exception.streamError;
42 import static io.netty.util.internal.ObjectUtil.checkNotNull;
43 import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
44
45
46
47
48
49 public class DelegatingDecompressorFrameListener extends Http2FrameListenerDecorator {
50
51 private final Http2Connection connection;
52 private final boolean strict;
53 private boolean flowControllerInitialized;
54 private final Http2Connection.PropertyKey propertyKey;
55
56 public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener) {
57 this(connection, listener, true);
58 }
59
60 public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener,
61 boolean strict) {
62 super(listener);
63 this.connection = connection;
64 this.strict = strict;
65
66 propertyKey = connection.newKey();
67 connection.addListener(new Http2ConnectionAdapter() {
68 @Override
69 public void onStreamRemoved(Http2Stream stream) {
70 final Http2Decompressor decompressor = decompressor(stream);
71 if (decompressor != null) {
72 cleanup(decompressor);
73 }
74 }
75 });
76 }
77
78 @Override
79 public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
80 throws Http2Exception {
81 final Http2Stream stream = connection.stream(streamId);
82 final Http2Decompressor decompressor = decompressor(stream);
83 if (decompressor == null) {
84
85 return listener.onDataRead(ctx, streamId, data, padding, endOfStream);
86 }
87
88 final EmbeddedChannel channel = decompressor.decompressor();
89 final int compressedBytes = data.readableBytes() + padding;
90 decompressor.incrementCompressedBytes(compressedBytes);
91 try {
92
93 channel.writeInbound(data.retain());
94 ByteBuf buf = nextReadableBuf(channel);
95 if (buf == null && endOfStream && channel.finish()) {
96 buf = nextReadableBuf(channel);
97 }
98 if (buf == null) {
99 if (endOfStream) {
100 listener.onDataRead(ctx, streamId, Unpooled.EMPTY_BUFFER, padding, true);
101 }
102
103
104
105
106 decompressor.incrementDecompressedBytes(compressedBytes);
107 return compressedBytes;
108 }
109 try {
110 Http2LocalFlowController flowController = connection.local().flowController();
111 decompressor.incrementDecompressedBytes(padding);
112 for (;;) {
113 ByteBuf nextBuf = nextReadableBuf(channel);
114 boolean decompressedEndOfStream = nextBuf == null && endOfStream;
115 if (decompressedEndOfStream && channel.finish()) {
116 nextBuf = nextReadableBuf(channel);
117 decompressedEndOfStream = nextBuf == null;
118 }
119
120 decompressor.incrementDecompressedBytes(buf.readableBytes());
121
122
123
124 flowController.consumeBytes(stream,
125 listener.onDataRead(ctx, streamId, buf, padding, decompressedEndOfStream));
126 if (nextBuf == null) {
127 break;
128 }
129
130 padding = 0;
131 buf.release();
132 buf = nextBuf;
133 }
134
135
136
137 return 0;
138 } finally {
139 buf.release();
140 }
141 } catch (Http2Exception e) {
142 throw e;
143 } catch (Throwable t) {
144 throw streamError(stream.id(), INTERNAL_ERROR, t,
145 "Decompressor error detected while delegating data read on streamId %d", stream.id());
146 }
147 }
148
149 @Override
150 public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
151 boolean endStream) throws Http2Exception {
152 initDecompressor(ctx, streamId, headers, endStream);
153 listener.onHeadersRead(ctx, streamId, headers, padding, endStream);
154 }
155
156 @Override
157 public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
158 short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception {
159 initDecompressor(ctx, streamId, headers, endStream);
160 listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream);
161 }
162
163
164
165
166
167
168
169
170
171
172 protected EmbeddedChannel newContentDecompressor(final ChannelHandlerContext ctx, CharSequence contentEncoding)
173 throws Http2Exception {
174 if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
175 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
176 ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP));
177 }
178 if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
179 final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE;
180
181 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
182 ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(wrapper));
183 }
184 if (Brotli.isAvailable() && BR.contentEqualsIgnoreCase(contentEncoding)) {
185 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
186 ctx.channel().config(), new BrotliDecoder());
187 }
188 if (SNAPPY.contentEqualsIgnoreCase(contentEncoding)) {
189 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
190 ctx.channel().config(), new SnappyFrameDecoder());
191 }
192 if (Zstd.isAvailable() && ZSTD.contentEqualsIgnoreCase(contentEncoding)) {
193 return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
194 ctx.channel().config(), new ZstdDecoder());
195 }
196
197 return null;
198 }
199
200
201
202
203
204
205
206
207
208 protected CharSequence getTargetContentEncoding(@SuppressWarnings("UnusedParameters") CharSequence contentEncoding)
209 throws Http2Exception {
210 return IDENTITY;
211 }
212
213
214
215
216
217
218
219
220
221
222
223 private void initDecompressor(ChannelHandlerContext ctx, int streamId, Http2Headers headers, boolean endOfStream)
224 throws Http2Exception {
225 final Http2Stream stream = connection.stream(streamId);
226 if (stream == null) {
227 return;
228 }
229
230 Http2Decompressor decompressor = decompressor(stream);
231 if (decompressor == null && !endOfStream) {
232
233 CharSequence contentEncoding = headers.get(CONTENT_ENCODING);
234 if (contentEncoding == null) {
235 contentEncoding = IDENTITY;
236 }
237 final EmbeddedChannel channel = newContentDecompressor(ctx, contentEncoding);
238 if (channel != null) {
239 decompressor = new Http2Decompressor(channel);
240 stream.setProperty(propertyKey, decompressor);
241
242
243 CharSequence targetContentEncoding = getTargetContentEncoding(contentEncoding);
244 if (IDENTITY.contentEqualsIgnoreCase(targetContentEncoding)) {
245 headers.remove(CONTENT_ENCODING);
246 } else {
247 headers.set(CONTENT_ENCODING, targetContentEncoding);
248 }
249 }
250 }
251
252 if (decompressor != null) {
253
254
255
256 headers.remove(CONTENT_LENGTH);
257
258
259
260 if (!flowControllerInitialized) {
261 flowControllerInitialized = true;
262 connection.local().flowController(new ConsumedBytesConverter(connection.local().flowController()));
263 }
264 }
265 }
266
267 Http2Decompressor decompressor(Http2Stream stream) {
268 return stream == null ? null : (Http2Decompressor) stream.getProperty(propertyKey);
269 }
270
271
272
273
274
275
276 private static void cleanup(Http2Decompressor decompressor) {
277 decompressor.decompressor().finishAndReleaseAll();
278 }
279
280
281
282
283
284
285
286
287 private static ByteBuf nextReadableBuf(EmbeddedChannel decompressor) {
288 for (;;) {
289 final ByteBuf buf = decompressor.readInbound();
290 if (buf == null) {
291 return null;
292 }
293 if (!buf.isReadable()) {
294 buf.release();
295 continue;
296 }
297 return buf;
298 }
299 }
300
301
302
303
304 private final class ConsumedBytesConverter implements Http2LocalFlowController {
305 private final Http2LocalFlowController flowController;
306
307 ConsumedBytesConverter(Http2LocalFlowController flowController) {
308 this.flowController = checkNotNull(flowController, "flowController");
309 }
310
311 @Override
312 public Http2LocalFlowController frameWriter(Http2FrameWriter frameWriter) {
313 return flowController.frameWriter(frameWriter);
314 }
315
316 @Override
317 public void channelHandlerContext(ChannelHandlerContext ctx) throws Http2Exception {
318 flowController.channelHandlerContext(ctx);
319 }
320
321 @Override
322 public void initialWindowSize(int newWindowSize) throws Http2Exception {
323 flowController.initialWindowSize(newWindowSize);
324 }
325
326 @Override
327 public int initialWindowSize() {
328 return flowController.initialWindowSize();
329 }
330
331 @Override
332 public int windowSize(Http2Stream stream) {
333 return flowController.windowSize(stream);
334 }
335
336 @Override
337 public void incrementWindowSize(Http2Stream stream, int delta) throws Http2Exception {
338 flowController.incrementWindowSize(stream, delta);
339 }
340
341 @Override
342 public void receiveFlowControlledFrame(Http2Stream stream, ByteBuf data, int padding,
343 boolean endOfStream) throws Http2Exception {
344 flowController.receiveFlowControlledFrame(stream, data, padding, endOfStream);
345 }
346
347 @Override
348 public boolean consumeBytes(Http2Stream stream, int numBytes) throws Http2Exception {
349 Http2Decompressor decompressor = decompressor(stream);
350 if (decompressor != null) {
351
352 numBytes = decompressor.consumeBytes(stream.id(), numBytes);
353 }
354 try {
355 return flowController.consumeBytes(stream, numBytes);
356 } catch (Http2Exception e) {
357 throw e;
358 } catch (Throwable t) {
359
360
361 throw streamError(stream.id(), INTERNAL_ERROR, t, "Error while returning bytes to flow control window");
362 }
363 }
364
365 @Override
366 public int unconsumedBytes(Http2Stream stream) {
367 return flowController.unconsumedBytes(stream);
368 }
369
370 @Override
371 public int initialWindowSize(Http2Stream stream) {
372 return flowController.initialWindowSize(stream);
373 }
374 }
375
376
377
378
379 private static final class Http2Decompressor {
380 private final EmbeddedChannel decompressor;
381 private int compressed;
382 private int decompressed;
383
384 Http2Decompressor(EmbeddedChannel decompressor) {
385 this.decompressor = decompressor;
386 }
387
388
389
390
391 EmbeddedChannel decompressor() {
392 return decompressor;
393 }
394
395
396
397
398 void incrementCompressedBytes(int delta) {
399 assert delta >= 0;
400 compressed += delta;
401 }
402
403
404
405
406 void incrementDecompressedBytes(int delta) {
407 assert delta >= 0;
408 decompressed += delta;
409 }
410
411
412
413
414
415
416
417
418
419 int consumeBytes(int streamId, int decompressedBytes) throws Http2Exception {
420 checkPositiveOrZero(decompressedBytes, "decompressedBytes");
421 if (decompressed - decompressedBytes < 0) {
422 throw streamError(streamId, INTERNAL_ERROR,
423 "Attempting to return too many bytes for stream %d. decompressed: %d " +
424 "decompressedBytes: %d", streamId, decompressed, decompressedBytes);
425 }
426 double consumedRatio = decompressedBytes / (double) decompressed;
427 int consumedCompressed = Math.min(compressed, (int) Math.ceil(compressed * consumedRatio));
428 if (compressed - consumedCompressed < 0) {
429 throw streamError(streamId, INTERNAL_ERROR,
430 "overflow when converting decompressed bytes to compressed bytes for stream %d." +
431 "decompressedBytes: %d decompressed: %d compressed: %d consumedCompressed: %d",
432 streamId, decompressedBytes, decompressed, compressed, consumedCompressed);
433 }
434 decompressed -= decompressedBytes;
435 compressed -= consumedCompressed;
436
437 return consumedCompressed;
438 }
439 }
440 }