View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.compression;
17  
18  import com.github.luben.zstd.BaseZstdBufferDecompressingStreamNoFinalizer;
19  import com.github.luben.zstd.ZstdBufferDecompressingStreamNoFinalizer;
20  import com.github.luben.zstd.ZstdDirectBufferDecompressingStreamNoFinalizer;
21  import io.netty.buffer.ByteBuf;
22  import io.netty.buffer.ByteBufAllocator;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.handler.codec.ByteToMessageDecoder;
25  
26  import java.io.IOException;
27  import java.nio.ByteBuffer;
28  import java.util.List;
29  
30  /**
31   * Decompresses a compressed block {@link ByteBuf} using the Zstandard algorithm.
32   * See <a href="https://facebook.github.io/zstd">Zstandard</a>.
33   */
34  public final class ZstdDecoder extends ByteToMessageDecoder {
35      // Don't use static here as we want to still allow to load the classes.
36      {
37          try {
38              Zstd.ensureAvailability();
39              outCapacity = ZstdBufferDecompressingStreamNoFinalizer.recommendedTargetBufferSize();
40          } catch (Throwable throwable) {
41              throw new ExceptionInInitializerError(throwable);
42          }
43      }
44      private final int outCapacity;
45  
46      private State currentState = State.DECOMPRESS_DATA;
47      private ZstdStream stream;
48  
49      /**
50       * Current state of stream.
51       */
52      private enum State {
53          DECOMPRESS_DATA,
54          CORRUPTED
55      }
56  
57      @Override
58      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
59          try {
60              if (currentState == State.CORRUPTED) {
61                  in.skipBytes(in.readableBytes());
62                  return;
63              }
64              final int compressedLength = in.readableBytes();
65              if (compressedLength == 0) {
66                  // Nothing to decompress, try again later.
67                  return;
68              }
69              if (stream == null) {
70                  // We assume that if the first buffer is direct the next buffer will also most likely be direct.
71                  stream = new ZstdStream(in.isDirect(), outCapacity);
72              }
73  
74              do  {
75                  ByteBuf decompressed = stream.decompress(ctx.alloc(), in);
76                  if (decompressed == null) {
77                      return;
78                  }
79                  out.add(decompressed);
80              } while (in.isReadable());
81          } catch (DecompressionException e) {
82              currentState = State.CORRUPTED;
83              throw e;
84          }
85      }
86  
87      @Override
88      protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
89          try {
90              if (stream != null) {
91                  stream.close();
92                  stream = null;
93              }
94          } finally {
95              super.handlerRemoved0(ctx);
96          }
97      }
98  
99      private static final class ZstdStream {
100         private static final ByteBuffer EMPTY_HEAP_BUFFER = ByteBuffer.allocate(0);
101         private static final ByteBuffer EMPTY_DIRECT_BUFFER = ByteBuffer.allocateDirect(0);
102 
103         private final boolean direct;
104         private final int outCapacity;
105         private final BaseZstdBufferDecompressingStreamNoFinalizer decompressingStream;
106         private ByteBuffer current;
107 
108         ZstdStream(boolean direct, int outCapacity) {
109             this.direct = direct;
110             this.outCapacity = outCapacity;
111             if (direct) {
112                 decompressingStream = new ZstdDirectBufferDecompressingStreamNoFinalizer(EMPTY_DIRECT_BUFFER) {
113                     @Override
114                     protected ByteBuffer refill(ByteBuffer toRefill) {
115                         return ZstdStream.this.refill(toRefill);
116                     }
117                 };
118             } else {
119                 decompressingStream = new ZstdBufferDecompressingStreamNoFinalizer(EMPTY_HEAP_BUFFER) {
120                     @Override
121                     protected ByteBuffer refill(ByteBuffer toRefill) {
122                         return ZstdStream.this.refill(toRefill);
123                     }
124                 };
125             }
126         }
127 
128         ByteBuf decompress(ByteBufAllocator alloc, ByteBuf in) throws DecompressionException {
129             final ByteBuf source;
130             // Ensure we use the correct input buffer type.
131             if (direct && !in.isDirect()) {
132                 source = alloc.directBuffer(in.readableBytes());
133                 source.writeBytes(in, in.readerIndex(), in.readableBytes());
134             } else if (!direct && !in.hasArray()) {
135                 source = alloc.heapBuffer(in.readableBytes());
136                 source.writeBytes(in, in.readerIndex(), in.readableBytes());
137             } else {
138                 source = in;
139             }
140             int inPosition = -1;
141             ByteBuf outBuffer = null;
142             try {
143                 ByteBuffer inNioBuffer = CompressionUtil.safeNioBuffer(
144                         source, source.readerIndex(), source.readableBytes());
145                 inPosition = inNioBuffer.position();
146                 assert inNioBuffer.hasRemaining();
147                 current = inNioBuffer;
148 
149                 // allocate the outBuffer based on what we expect from the decompressingStream.
150                 if (direct) {
151                     outBuffer = alloc.directBuffer(outCapacity);
152                 } else {
153                     outBuffer = alloc.heapBuffer(outCapacity);
154                 }
155                 ByteBuffer target = outBuffer.internalNioBuffer(outBuffer.writerIndex(), outBuffer.writableBytes());
156                 int position = target.position();
157                 do {
158                     do {
159                         if (decompressingStream.read(target) == 0) {
160                             break;
161                         }
162                     } while (decompressingStream.hasRemaining() && target.hasRemaining() && current.hasRemaining());
163                     int written = target.position() - position;
164                     if (written > 0) {
165                         outBuffer.writerIndex(outBuffer.writerIndex() + written);
166                         ByteBuf out = outBuffer;
167                         outBuffer = null;
168                         return out;
169                     }
170                 } while (decompressingStream.hasRemaining() && current.hasRemaining());
171             } catch (IOException e) {
172                 throw new DecompressionException(e);
173             } finally {
174                 if (outBuffer != null) {
175                     outBuffer.release();
176                 }
177                 // Release in case of copy
178                 if (source != in) {
179                     source.release();
180                 }
181                 ByteBuffer buffer = current;
182                 current = null;
183                 if (inPosition != -1) {
184                     int read = buffer.position() - inPosition;
185                     if (read > 0) {
186                         in.skipBytes(read);
187                     }
188                 }
189             }
190             return null;
191         }
192 
193         private ByteBuffer refill(@SuppressWarnings("unused") ByteBuffer toRefill) {
194             return current;
195         }
196 
197         void close() {
198             decompressingStream.close();
199         }
200     }
201 }