View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.compression;
17  
18  import com.github.luben.zstd.ZstdIOException;
19  import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.handler.codec.ByteToMessageDecoder;
23  import io.netty.util.internal.ObjectUtil;
24  
25  import java.io.Closeable;
26  import java.io.IOException;
27  import java.io.InputStream;
28  import java.util.List;
29  
30  /**
31   * Decompresses a compressed block {@link ByteBuf} using the Zstandard algorithm.
32   * See <a href="https://facebook.github.io/zstd">Zstandard</a>.
33   */
34  public final class ZstdDecoder extends ByteToMessageDecoder {
35      // Don't use static here as we want to still allow to load the classes.
36      {
37          try {
38              Zstd.ensureAvailability();
39          } catch (Throwable throwable) {
40              throw new ExceptionInInitializerError(throwable);
41          }
42      }
43  
44      private final int maximumAllocationSize;
45      private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream();
46      private ZstdInputStreamNoFinalizer zstdIs;
47  
48      private boolean needsRead;
49      private State currentState = State.DECOMPRESS_DATA;
50  
51      /**
52       * Current state of stream.
53       */
54      private enum State {
55          DECOMPRESS_DATA,
56          CORRUPTED
57      }
58  
59      public ZstdDecoder() {
60          this(4 * 1024 * 1024);
61      }
62  
63      public ZstdDecoder(int maximumAllocationSize) {
64          this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize");
65      }
66  
67      @Override
68      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69          needsRead = true;
70          try {
71              if (currentState == State.CORRUPTED) {
72                  in.skipBytes(in.readableBytes());
73  
74                  return;
75              }
76              inputStream.current = in;
77  
78              ByteBuf outBuffer = null;
79  
80              final int compressedLength = in.readableBytes();
81              try {
82                  long uncompressedLength;
83                  if (in.isDirect()) {
84                      uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
85                              CompressionUtil.safeNioBuffer(in, in.readerIndex(), in.readableBytes()));
86                  } else {
87                      uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
88                              in.array(), in.readerIndex() + in.arrayOffset(), in.readableBytes());
89                  }
90                  if (uncompressedLength <= 0) {
91                      // Let's start with the compressedLength * 2 as often we will not have everything
92                      // we need in the in buffer and don't want to reserve too much memory.
93                      uncompressedLength = compressedLength * 2L;
94                  }
95  
96                  int w;
97                  do {
98                      if (outBuffer == null) {
99                          outBuffer = ctx.alloc().heapBuffer((int) (maximumAllocationSize == 0 ?
100                                 uncompressedLength : Math.min(maximumAllocationSize, uncompressedLength)));
101                     }
102                     do {
103                         w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes());
104                     } while (w != -1 && outBuffer.isWritable());
105                     if (outBuffer.isReadable()) {
106                         needsRead = false;
107                         ctx.fireChannelRead(outBuffer);
108                         outBuffer = null;
109                     }
110                 } while (w != -1);
111             } finally {
112                 if (outBuffer != null) {
113                     outBuffer.release();
114                 }
115             }
116         } catch (Exception e) {
117             currentState = State.CORRUPTED;
118             throw new DecompressionException(e);
119         } finally {
120             inputStream.current = null;
121         }
122     }
123 
124     @Override
125     public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
126         // Discard bytes of the cumulation buffer if needed.
127         discardSomeReadBytes();
128 
129         if (needsRead && !ctx.channel().config().isAutoRead()) {
130             ctx.read();
131         }
132         ctx.fireChannelReadComplete();
133     }
134 
135     @Override
136     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
137         super.handlerAdded(ctx);
138         zstdIs = new ZstdInputStreamNoFinalizer(inputStream);
139         zstdIs.setContinuous(true);
140     }
141 
142     @Override
143     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
144         try {
145             closeSilently(zstdIs);
146         } finally {
147             super.handlerRemoved0(ctx);
148         }
149     }
150 
151     private static void closeSilently(Closeable closeable) {
152         if (closeable != null) {
153             try {
154                 closeable.close();
155             } catch (IOException ignore) {
156                 // ignore
157             }
158         }
159     }
160 
161     private static final class MutableByteBufInputStream extends InputStream {
162         ByteBuf current;
163 
164         @Override
165         public int read() {
166             if (current == null || !current.isReadable()) {
167                 return -1;
168             }
169             return current.readByte() & 0xff;
170         }
171 
172         @Override
173         public int read(byte[] b, int off, int len) {
174             int available = available();
175             if (available == 0) {
176                 return -1;
177             }
178 
179             len = Math.min(available, len);
180             current.readBytes(b, off, len);
181             return len;
182         }
183 
184         @Override
185         public int available() {
186             return current == null ? 0 : current.readableBytes();
187         }
188     }
189 }