View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.compression;
17  
18  import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
19  import io.netty.buffer.ByteBuf;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.handler.codec.ByteToMessageDecoder;
22  import io.netty.util.internal.ObjectUtil;
23  
24  import java.io.Closeable;
25  import java.io.IOException;
26  import java.io.InputStream;
27  import java.util.List;
28  
29  /**
30   * Decompresses a compressed block {@link ByteBuf} using the Zstandard algorithm.
31   * See <a href="https://facebook.github.io/zstd">Zstandard</a>.
32   */
33  public final class ZstdDecoder extends ByteToMessageDecoder {
34      // Don't use static here as we want to still allow to load the classes.
35      {
36          try {
37              Zstd.ensureAvailability();
38          } catch (Throwable throwable) {
39              throw new ExceptionInInitializerError(throwable);
40          }
41      }
42  
43      private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES;
44      /**
45       * Default maximum size of a single output buffer, in bytes (4 MiB).
46       */
47      public static final int DEFAULT_MAXIMUM_ALLOCATION_SIZE = 4 * 1024 * 1024;
48      /**
49       * Default upper bound on the {@code Window_Log} accepted by the decoder.
50       * {@code 27} corresponds to a 128 MiB decompression window.
51       */
52      public static final int DEFAULT_MAX_WINDOW_LOG = 27;
53      private static final int MIN_WINDOW_LOG = 10;
54      private static final int MAX_WINDOW_LOG = 31;
55      private final int maximumAllocationSize;
56      private final int maxForwardBytes;
57      private final int maxWindowLog;
58      private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream();
59      private ZstdInputStreamNoFinalizer zstdIs;
60  
61      private boolean needsRead;
62      private State currentState = State.DECOMPRESS_DATA;
63  
64      /**
65       * Current state of stream.
66       */
67      private enum State {
68          DECOMPRESS_DATA,
69          CORRUPTED
70      }
71  
72      /**
73       * Creates a new decoder with the {@link #DEFAULT_MAXIMUM_ALLOCATION_SIZE},
74       * and the {@link #DEFAULT_MAX_WINDOW_LOG} window log size.
75       * <p>
76       * The window log size bounds the memory usage of the sliding window for ZSTD frame decompression.
77       * Frames declaring a larger window will be rejected to bound the memory the decoder may allocate per stream.
78       *
79       */
80      public ZstdDecoder() {
81          this(DEFAULT_MAXIMUM_ALLOCATION_SIZE, DEFAULT_MAX_WINDOW_LOG);
82      }
83  
84      /**
85       * Creates a new decoder with the given maximum allocation size,
86       * and the {@link #DEFAULT_MAX_WINDOW_LOG} window log size.
87       * <p>
88       * The window log size bounds the memory usage of the sliding window for ZSTD frame decompression.
89       * Frames declaring a larger window will be rejected to bound the memory the decoder may allocate per stream.
90       *
91       * @param maximumAllocationSize maximum size of a single output buffer.
92       */
93      public ZstdDecoder(int maximumAllocationSize) {
94          this(maximumAllocationSize, DEFAULT_MAX_WINDOW_LOG);
95      }
96  
97      /**
98       * Creates a new decoder with an explicit upper bound on the accepted {@code Window_Log}.
99       *
100      * @param maximumAllocationSize maximum size of a single output buffer.
101      * @param maxWindowLog          upper bound on the {@code Window_Log} field of incoming
102      *                              frames; must be in {@code [10, 31]}. Frames declaring a
103      *                              larger window will be rejected to bound the memory the
104      *                              decoder may allocate per stream.
105      */
106     public ZstdDecoder(int maximumAllocationSize, int maxWindowLog) {
107         this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize");
108         this.maxForwardBytes = maximumAllocationSize > 0 ? maximumAllocationSize : DEFAULT_MAX_FORWARD_BYTES;
109         this.maxWindowLog = ObjectUtil.checkInRange(maxWindowLog, MIN_WINDOW_LOG, MAX_WINDOW_LOG, "maxWindowLog");
110     }
111 
112     @Override
113     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
114         needsRead = true;
115         try {
116             if (currentState == State.CORRUPTED) {
117                 in.skipBytes(in.readableBytes());
118 
119                 return;
120             }
121             inputStream.current = in;
122 
123             ByteBuf outBuffer = null;
124 
125             final int compressedLength = in.readableBytes();
126             try {
127                 long uncompressedLength;
128                 if (in.isDirect()) {
129                     uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
130                             CompressionUtil.safeNioBuffer(in, in.readerIndex(), in.readableBytes()));
131                 } else {
132                     uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
133                             in.array(), in.readerIndex() + in.arrayOffset(), in.readableBytes());
134                 }
135                 if (uncompressedLength <= 0) {
136                     // Let's start with the compressedLength * 2 as often we will not have everything
137                     // we need in the in buffer and don't want to reserve too much memory.
138                     uncompressedLength = compressedLength * 2L;
139                 }
140 
141                 int w;
142                 do {
143                     if (outBuffer == null) {
144                         outBuffer = ctx.alloc().heapBuffer((int) (maximumAllocationSize == 0 ?
145                                 uncompressedLength : Math.min(maximumAllocationSize, uncompressedLength)));
146                     }
147                     do {
148                         w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes());
149                     } while (w > 0 && outBuffer.isWritable());
150                     if (!outBuffer.isWritable() || outBuffer.readableBytes() >= maxForwardBytes) {
151                         needsRead = false;
152                         ctx.fireChannelRead(outBuffer);
153                         outBuffer = null;
154                     }
155                 } while (w > 0);
156                 if (outBuffer != null && outBuffer.isReadable()) {
157                     needsRead = false;
158                     ctx.fireChannelRead(outBuffer);
159                     outBuffer = null;
160                 }
161             } finally {
162                 if (outBuffer != null) {
163                     outBuffer.release();
164                 }
165             }
166         } catch (Exception e) {
167             currentState = State.CORRUPTED;
168             throw new DecompressionException(e);
169         } finally {
170             inputStream.current = null;
171         }
172     }
173 
174     @Override
175     public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
176         // Discard bytes of the cumulation buffer if needed.
177         discardSomeReadBytes();
178 
179         if (needsRead && !ctx.channel().config().isAutoRead()) {
180             ctx.read();
181         }
182         ctx.fireChannelReadComplete();
183     }
184 
185     @Override
186     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
187         super.handlerAdded(ctx);
188         zstdIs = new ZstdInputStreamNoFinalizer(inputStream);
189         zstdIs.setContinuous(true);
190         // Bound the decompression window to mitigate memory amplification from frames that
191         // declare an oversized Window_Size.
192         zstdIs.setLongMax(maxWindowLog);
193     }
194 
195     @Override
196     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
197         try {
198             closeSilently(zstdIs);
199         } finally {
200             super.handlerRemoved0(ctx);
201         }
202     }
203 
204     private static void closeSilently(Closeable closeable) {
205         if (closeable != null) {
206             try {
207                 closeable.close();
208             } catch (IOException ignore) {
209                 // ignore
210             }
211         }
212     }
213 
214     private static final class MutableByteBufInputStream extends InputStream {
215         ByteBuf current;
216 
217         @Override
218         public int read() {
219             if (current == null || !current.isReadable()) {
220                 return -1;
221             }
222             return current.readByte() & 0xff;
223         }
224 
225         @Override
226         public int read(byte[] b, int off, int len) {
227             int available = available();
228             if (available == 0) {
229                 return -1;
230             }
231 
232             len = Math.min(available, len);
233             current.readBytes(b, off, len);
234             return len;
235         }
236 
237         @Override
238         public int available() {
239             return current == null ? 0 : current.readableBytes();
240         }
241     }
242 }