View Javadoc
1   /*
2    * Copyright 2021 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.handler.codec.compression;
17  
18  import com.github.luben.zstd.Zstd;
19  import io.netty5.buffer.api.Buffer;
20  import io.netty5.buffer.api.BufferAllocator;
21  import io.netty5.handler.codec.EncoderException;
22  import io.netty5.util.internal.ObjectUtil;
23  
24  import java.nio.ByteBuffer;
25  import java.util.function.Supplier;
26  
27  import static io.netty5.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE;
28  import static io.netty5.handler.codec.compression.ZstdConstants.DEFAULT_COMPRESSION_LEVEL;
29  import static io.netty5.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE;
30  import static io.netty5.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL;
31  
32  /**
33   *  Compresses a {@link Buffer} using the Zstandard algorithm.
34   *  See <a href="https://facebook.github.io/zstd">Zstandard</a>.
35   */
36  public final class ZstdCompressor implements Compressor {
37  
38      private final int blockSize;
39      private final int compressionLevel;
40      private final int maxEncodeSize;
41  
42      private enum State {
43          PROCESSING,
44          FINISHED,
45          CLOSED
46      }
47  
48      private State state = State.PROCESSING;
49  
50      /**
51       * Creates a new Zstd compressor factory.
52       *
53       * Please note that if you use the default constructor, the default BLOCK_SIZE and MAX_BLOCK_SIZE
54       * will be used. If you want to specify BLOCK_SIZE and MAX_BLOCK_SIZE yourself,
55       * please use the {@link #newFactory(int,int)} method.
56       *
57       * @return the factory.
58       */
59      public static Supplier<ZstdCompressor> newFactory() {
60          return newFactory(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE);
61      }
62  
63      /**
64       * Creates a new Zstd compressor factory.
65       *
66       *  @param  compressionLevel
67       *            specifies the level of the compression
68       * @return the factory.
69       */
70      public static Supplier<ZstdCompressor> newFactory(int compressionLevel) {
71          return newFactory(compressionLevel, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE);
72      }
73  
74      /**
75       * Creates a new Zstd compressor factory.
76       *
77       *  @param  blockSize
78       *            is used to calculate the compressionLevel
79       *  @param  maxEncodeSize
80       *            specifies the size of the largest compressed object
81       * @return the factory.
82       */
83      public static Supplier<ZstdCompressor> newFactory(int blockSize, int maxEncodeSize) {
84          return newFactory(DEFAULT_COMPRESSION_LEVEL, blockSize, maxEncodeSize);
85      }
86  
87      /**
88       * @param  blockSize
89       *           is used to calculate the compressionLevel
90       * @param  maxEncodeSize
91       *           specifies the size of the largest compressed object
92       * @param  compressionLevel
93       *           specifies the level of the compression
94       * @return the factory.
95       */
96      public static Supplier<ZstdCompressor> newFactory(int compressionLevel, int blockSize, int maxEncodeSize) {
97          ObjectUtil.checkInRange(compressionLevel, 0, MAX_COMPRESSION_LEVEL, "compressionLevel");
98          ObjectUtil.checkPositive(blockSize, "blockSize");
99          ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize");
100         return () -> new ZstdCompressor(compressionLevel, blockSize, maxEncodeSize);
101     }
102     /**
103      * @param  blockSize
104      *           is used to calculate the compressionLevel
105      * @param  maxEncodeSize
106      *           specifies the size of the largest compressed object
107      * @param  compressionLevel
108      *           specifies the level of the compression
109      */
110     private ZstdCompressor(int compressionLevel, int blockSize, int maxEncodeSize) {
111         this.compressionLevel = compressionLevel;
112         this.blockSize = blockSize;
113         this.maxEncodeSize = maxEncodeSize;
114     }
115 
116     private Buffer allocateBuffer(BufferAllocator allocator, Buffer msg) {
117         int remaining = msg.readableBytes();
118 
119         long bufferSize = 0;
120         while (remaining > 0) {
121             int curSize = Math.min(blockSize, remaining);
122             remaining -= curSize;
123             bufferSize += Zstd.compressBound(curSize);
124         }
125 
126         if (bufferSize > maxEncodeSize || 0 > bufferSize) {
127             throw new EncoderException("requested encode buffer size (" + bufferSize + " bytes) exceeds " +
128                     "the maximum allowable size (" + maxEncodeSize + " bytes)");
129         }
130 
131         // TODO: It would be better if we could allocate depending on the input type
132         return allocator.allocate((int) bufferSize);
133     }
134 
135     @Override
136     public Buffer compress(Buffer in, BufferAllocator allocator) throws CompressionException {
137         switch (state) {
138             case CLOSED:
139                 throw new CompressionException("Compressor closed");
140             case FINISHED:
141                 return allocator.allocate(0);
142             case PROCESSING:
143                 if (in.readableBytes() == 0) {
144                     return allocator.allocate(0);
145                 }
146                 Buffer out = allocateBuffer(allocator, in);
147                 try {
148                     compressData(in, out);
149                     return out;
150                 } catch (Throwable cause) {
151                     out.close();
152                     throw cause;
153                 }
154             default:
155                 throw new IllegalStateException();
156         }
157     }
158 
159     @Override
160     public Buffer finish(BufferAllocator allocator) {
161         switch (state) {
162             case CLOSED:
163                 throw new CompressionException("Compressor closed");
164             case FINISHED:
165             case PROCESSING:
166                 state = State.FINISHED;
167                 return allocator.allocate(0);
168             default:
169                 throw new IllegalStateException();
170         }
171     }
172 
173     @Override
174     public boolean isFinished() {
175         return state != State.PROCESSING;
176     }
177 
178     @Override
179     public boolean isClosed() {
180         return state == State.CLOSED;
181     }
182 
183     @Override
184     public void close() {
185         state = State.CLOSED;
186     }
187 
188     private void compressData(Buffer in, Buffer out) {
189         final int flushableBytes = in.readableBytes();
190         if (flushableBytes == 0) {
191             return;
192         }
193 
194         final int bufSize = (int) Zstd.compressBound(flushableBytes);
195         out.ensureWritable(bufSize);
196         try {
197             assert out.countWritableComponents() == 1;
198             try (var writableIteration = out.forEachWritable()) {
199                 var writableComponent = writableIteration.first();
200                 try (var readableIteration = in.forEachReadable()) {
201                     for (var readableComponent = readableIteration.first();
202                          readableComponent != null; readableComponent = readableComponent.next()) {
203                         final int compressedLength;
204                         if (in.isDirect() && out.isDirect()) {
205                             ByteBuffer inNioBuffer = readableComponent.readableBuffer();
206                             compressedLength = Zstd.compress(
207                                     writableComponent.writableBuffer(),
208                                     inNioBuffer,
209                                     compressionLevel);
210                         } else {
211                             final byte[] inArray;
212                             final int inOffset;
213                             final int inLen = readableComponent.readableBytes();
214                             if (readableComponent.hasReadableArray()) {
215                                 inArray = readableComponent.readableArray();
216                                 inOffset = readableComponent.readableArrayOffset();
217                             } else {
218                                 inArray = new byte[inLen];
219                                 readableComponent.readableBuffer().get(inArray);
220                                 inOffset = 0;
221                             }
222 
223                             final byte[] outArray;
224                             final int outOffset;
225                             final int outLen = writableComponent.writableBytes();
226                             if (writableComponent.hasWritableArray()) {
227                                 outArray = writableComponent.writableArray();
228                                 outOffset = writableComponent.writableArrayOffset();
229                             } else {
230                                 outArray = new byte[out.writableBytes()];
231                                 outOffset = 0;
232                             }
233 
234                             compressedLength = (int) Zstd.compressByteArray(
235                                     outArray, outOffset, outLen, inArray, inOffset, inLen, compressionLevel);
236                             if (!writableComponent.hasWritableArray()) {
237                                 writableComponent.writableBuffer().put(outArray);
238                             }
239                         }
240                         writableComponent.skipWritableBytes(compressedLength);
241                         readableComponent.skipReadableBytes(readableComponent.readableBytes());
242                     }
243                 }
244             }
245         } catch (Exception e) {
246             throw new CompressionException(e);
247         }
248     }
249 }