View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.compression;
17  
18  import com.github.luben.zstd.ZstdIOException;
19  import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.handler.codec.ByteToMessageDecoder;
23  import io.netty.util.internal.ObjectUtil;
24  
25  import java.io.Closeable;
26  import java.io.IOException;
27  import java.io.InputStream;
28  import java.util.List;
29  
30  /**
31   * Decompresses a compressed block {@link ByteBuf} using the Zstandard algorithm.
32   * See <a href="https://facebook.github.io/zstd">Zstandard</a>.
33   */
34  public final class ZstdDecoder extends ByteToMessageDecoder {
35      // Don't use static here as we want to still allow to load the classes.
36      {
37          try {
38              Zstd.ensureAvailability();
39          } catch (Throwable throwable) {
40              throw new ExceptionInInitializerError(throwable);
41          }
42      }
43  
44      private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES;
45  
46      private final int maximumAllocationSize;
47      private final int maxForwardBytes;
48      private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream();
49      private ZstdInputStreamNoFinalizer zstdIs;
50  
51      private boolean needsRead;
52      private State currentState = State.DECOMPRESS_DATA;
53  
54      /**
55       * Current state of stream.
56       */
57      private enum State {
58          DECOMPRESS_DATA,
59          CORRUPTED
60      }
61  
62      public ZstdDecoder() {
63          this(4 * 1024 * 1024);
64      }
65  
66      public ZstdDecoder(int maximumAllocationSize) {
67          this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize");
68          this.maxForwardBytes = maximumAllocationSize > 0 ? maximumAllocationSize : DEFAULT_MAX_FORWARD_BYTES;
69      }
70  
71      @Override
72      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
73          needsRead = true;
74          try {
75              if (currentState == State.CORRUPTED) {
76                  in.skipBytes(in.readableBytes());
77  
78                  return;
79              }
80              inputStream.current = in;
81  
82              ByteBuf outBuffer = null;
83  
84              final int compressedLength = in.readableBytes();
85              try {
86                  long uncompressedLength;
87                  if (in.isDirect()) {
88                      uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
89                              CompressionUtil.safeNioBuffer(in, in.readerIndex(), in.readableBytes()));
90                  } else {
91                      uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
92                              in.array(), in.readerIndex() + in.arrayOffset(), in.readableBytes());
93                  }
94                  if (uncompressedLength <= 0) {
95                      // Let's start with the compressedLength * 2 as often we will not have everything
96                      // we need in the in buffer and don't want to reserve too much memory.
97                      uncompressedLength = compressedLength * 2L;
98                  }
99  
100                 int w;
101                 do {
102                     if (outBuffer == null) {
103                         outBuffer = ctx.alloc().heapBuffer((int) (maximumAllocationSize == 0 ?
104                                 uncompressedLength : Math.min(maximumAllocationSize, uncompressedLength)));
105                     }
106                     do {
107                         w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes());
108                     } while (w > 0 && outBuffer.isWritable());
109                     if (!outBuffer.isWritable() || outBuffer.readableBytes() >= maxForwardBytes) {
110                         needsRead = false;
111                         ctx.fireChannelRead(outBuffer);
112                         outBuffer = null;
113                     }
114                 } while (w > 0);
115                 if (outBuffer != null && outBuffer.isReadable()) {
116                     needsRead = false;
117                     ctx.fireChannelRead(outBuffer);
118                     outBuffer = null;
119                 }
120             } finally {
121                 if (outBuffer != null) {
122                     outBuffer.release();
123                 }
124             }
125         } catch (Exception e) {
126             currentState = State.CORRUPTED;
127             throw new DecompressionException(e);
128         } finally {
129             inputStream.current = null;
130         }
131     }
132 
133     @Override
134     public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
135         // Discard bytes of the cumulation buffer if needed.
136         discardSomeReadBytes();
137 
138         if (needsRead && !ctx.channel().config().isAutoRead()) {
139             ctx.read();
140         }
141         ctx.fireChannelReadComplete();
142     }
143 
144     @Override
145     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
146         super.handlerAdded(ctx);
147         zstdIs = new ZstdInputStreamNoFinalizer(inputStream);
148         zstdIs.setContinuous(true);
149     }
150 
151     @Override
152     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
153         try {
154             closeSilently(zstdIs);
155         } finally {
156             super.handlerRemoved0(ctx);
157         }
158     }
159 
160     private static void closeSilently(Closeable closeable) {
161         if (closeable != null) {
162             try {
163                 closeable.close();
164             } catch (IOException ignore) {
165                 // ignore
166             }
167         }
168     }
169 
170     private static final class MutableByteBufInputStream extends InputStream {
171         ByteBuf current;
172 
173         @Override
174         public int read() {
175             if (current == null || !current.isReadable()) {
176                 return -1;
177             }
178             return current.readByte() & 0xff;
179         }
180 
181         @Override
182         public int read(byte[] b, int off, int len) {
183             int available = available();
184             if (available == 0) {
185                 return -1;
186             }
187 
188             len = Math.min(available, len);
189             current.readBytes(b, off, len);
190             return len;
191         }
192 
193         @Override
194         public int available() {
195             return current == null ? 0 : current.readableBytes();
196         }
197     }
198 }