View Javadoc
1   /*
2    * Copyright 2021 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.compression;
17  
18  import com.github.luben.zstd.Zstd;
19  import io.netty.buffer.ByteBuf;
20  import io.netty.buffer.Unpooled;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.handler.codec.EncoderException;
23  import io.netty.handler.codec.MessageToByteEncoder;
24  import io.netty.util.internal.ObjectUtil;
25  import java.nio.ByteBuffer;
26  
27  import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_COMPRESSION_LEVEL;
28  import static io.netty.handler.codec.compression.ZstdConstants.MIN_COMPRESSION_LEVEL;
29  import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL;
30  import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE;
31  import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE;
32  
33  /**
34   *  Compresses a {@link ByteBuf} using the Zstandard algorithm.
35   *  See <a href="https://facebook.github.io/zstd">Zstandard</a>.
36   */
37  public final class ZstdEncoder extends MessageToByteEncoder<ByteBuf> {
38      // Don't use static here as we want to still allow to load the classes.
39      {
40          try {
41              io.netty.handler.codec.compression.Zstd.ensureAvailability();
42          } catch (Throwable throwable) {
43              throw new ExceptionInInitializerError(throwable);
44          }
45      }
46      private final int blockSize;
47      private final int compressionLevel;
48      private final int maxEncodeSize;
49      private ByteBuf buffer;
50  
51      /**
52       * Creates a new Zstd encoder.
53       *
54       * Please note that if you use the default constructor, the default BLOCK_SIZE and MAX_BLOCK_SIZE
55       * will be used. If you want to specify BLOCK_SIZE and MAX_BLOCK_SIZE yourself,
56       * please use {@link ZstdEncoder(int,int)} constructor
57       */
58      public ZstdEncoder() {
59          this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE);
60      }
61  
62      /**
63       * Creates a new Zstd encoder.
64       *  @param  compressionLevel
65       *            specifies the level of the compression
66       */
67      public ZstdEncoder(int compressionLevel) {
68          this(compressionLevel, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE);
69      }
70  
71      /**
72       * Creates a new Zstd encoder.
73       *  @param  blockSize
74       *            is used to calculate the compressionLevel
75       *  @param  maxEncodeSize
76       *            specifies the size of the largest compressed object
77       */
78      public ZstdEncoder(int blockSize, int maxEncodeSize) {
79          this(DEFAULT_COMPRESSION_LEVEL, blockSize, maxEncodeSize);
80      }
81  
82      /**
83       * @param  blockSize
84       *           is used to calculate the compressionLevel
85       * @param  maxEncodeSize
86       *           specifies the size of the largest compressed object
87       * @param  compressionLevel
88       *           specifies the level of the compression
89       */
90      public ZstdEncoder(int compressionLevel, int blockSize, int maxEncodeSize) {
91          super(ByteBuf.class, true);
92          this.compressionLevel = ObjectUtil.checkInRange(compressionLevel,
93                  MIN_COMPRESSION_LEVEL, MAX_COMPRESSION_LEVEL, "compressionLevel");
94          this.blockSize = ObjectUtil.checkPositive(blockSize, "blockSize");
95          this.maxEncodeSize = ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize");
96      }
97  
98      @Override
99      protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean preferDirect) {
100         if (buffer == null) {
101             throw new IllegalStateException("not added to a pipeline," +
102                     "or has been removed,buffer is null");
103         }
104 
105         int remaining = msg.readableBytes() + buffer.readableBytes();
106 
107         // quick overflow check
108         if (remaining < 0) {
109             throw new EncoderException("too much data to allocate a buffer for compression");
110         }
111 
112         long bufferSize = 0;
113         while (remaining > 0) {
114             int curSize = Math.min(blockSize, remaining);
115             remaining -= curSize;
116             // calculate the max compressed size with Zstd.compressBound since
117             // it returns the maximum size of the compressed data
118             bufferSize = Math.max(bufferSize, Zstd.compressBound(curSize));
119         }
120 
121         if (bufferSize > maxEncodeSize || 0 > bufferSize) {
122             throw new EncoderException("requested encode buffer size (" + bufferSize + " bytes) exceeds " +
123                     "the maximum allowable size (" + maxEncodeSize + " bytes)");
124         }
125 
126         return ctx.alloc().directBuffer((int) bufferSize);
127     }
128 
129     @Override
130     protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) {
131         if (buffer == null) {
132             throw new IllegalStateException("not added to a pipeline," +
133                     "or has been removed,buffer is null");
134         }
135 
136         final ByteBuf buffer = this.buffer;
137         int length;
138         while ((length = in.readableBytes()) > 0) {
139             final int nextChunkSize = Math.min(length, buffer.writableBytes());
140             in.readBytes(buffer, nextChunkSize);
141 
142             if (!buffer.isWritable()) {
143                 flushBufferedData(out);
144             }
145         }
146         // return the remaining data in the buffer
147         // when buffer size is smaller than the block size
148         if (buffer.isReadable()) {
149             flushBufferedData(out);
150         }
151     }
152 
153     private void flushBufferedData(ByteBuf out) {
154         final int flushableBytes = buffer.readableBytes();
155         if (flushableBytes == 0) {
156             return;
157         }
158 
159         final int bufSize = (int) Zstd.compressBound(flushableBytes);
160         out.ensureWritable(bufSize);
161         final int idx = out.writerIndex();
162         int compressedLength;
163         try {
164             ByteBuffer outNioBuffer = out.internalNioBuffer(idx, out.writableBytes());
165             compressedLength = Zstd.compress(
166                     outNioBuffer,
167                     buffer.internalNioBuffer(buffer.readerIndex(), flushableBytes),
168                     compressionLevel);
169         } catch (Exception e) {
170             throw new CompressionException(e);
171         }
172 
173         out.writerIndex(idx + compressedLength);
174         buffer.clear();
175     }
176 
177     @Override
178     public void flush(final ChannelHandlerContext ctx) {
179         if (buffer != null && buffer.isReadable()) {
180             final ByteBuf buf = allocateBuffer(ctx, Unpooled.EMPTY_BUFFER, isPreferDirect());
181             flushBufferedData(buf);
182             ctx.write(buf);
183         }
184         ctx.flush();
185     }
186 
187     @Override
188     public void handlerAdded(ChannelHandlerContext ctx) {
189         buffer = ctx.alloc().directBuffer(blockSize);
190         buffer.clear();
191     }
192 
193     @Override
194     public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
195         super.handlerRemoved(ctx);
196         if (buffer != null) {
197             buffer.release();
198             buffer = null;
199         }
200     }
201 }