View Javadoc
1   /*
2    * Copyright 2021 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.compression;
17  
18  import com.github.luben.zstd.Zstd;
19  import io.netty.buffer.ByteBuf;
20  import io.netty.buffer.Unpooled;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.handler.codec.EncoderException;
23  import io.netty.handler.codec.MessageToByteEncoder;
24  import io.netty.util.internal.ObjectUtil;
25  import java.nio.ByteBuffer;
26  
27  import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_COMPRESSION_LEVEL;
28  import static io.netty.handler.codec.compression.ZstdConstants.MIN_COMPRESSION_LEVEL;
29  import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL;
30  import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE;
31  import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE;
32  
33  /**
34   *  Compresses a {@link ByteBuf} using the Zstandard algorithm.
35   *  See <a href="https://facebook.github.io/zstd">Zstandard</a>.
36   */
37  public final class ZstdEncoder extends MessageToByteEncoder<ByteBuf> {
38      // Don't use static here as we want to still allow to load the classes.
39      {
40          try {
41              io.netty.handler.codec.compression.Zstd.ensureAvailability();
42          } catch (Throwable throwable) {
43              throw new ExceptionInInitializerError(throwable);
44          }
45      }
46      private final int blockSize;
47      private final int compressionLevel;
48      private final int maxEncodeSize;
49      private ByteBuf buffer;
50  
51      /**
52       * Creates a new Zstd encoder.
53       *
54       * Please note that if you use the default constructor, the default BLOCK_SIZE and MAX_BLOCK_SIZE
55       * will be used. If you want to specify BLOCK_SIZE and MAX_BLOCK_SIZE yourself,
56       * please use {@link ZstdEncoder(int,int)} constructor
57       */
58      public ZstdEncoder() {
59          this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE);
60      }
61  
62      /**
63       * Creates a new Zstd encoder.
64       *  @param  compressionLevel
65       *            specifies the level of the compression
66       */
67      public ZstdEncoder(int compressionLevel) {
68          this(compressionLevel, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE);
69      }
70  
71      /**
72       * Creates a new Zstd encoder.
73       *  @param  blockSize
74       *            is used to calculate the compressionLevel
75       *  @param  maxEncodeSize
76       *            specifies the size of the largest compressed object
77       */
78      public ZstdEncoder(int blockSize, int maxEncodeSize) {
79          this(DEFAULT_COMPRESSION_LEVEL, blockSize, maxEncodeSize);
80      }
81  
82      /**
83       * @param  blockSize
84       *           is used to calculate the compressionLevel
85       * @param  maxEncodeSize
86       *           specifies the size of the largest compressed object
87       * @param  compressionLevel
88       *           specifies the level of the compression
89       */
90      public ZstdEncoder(int compressionLevel, int blockSize, int maxEncodeSize) {
91          super(true);
92          this.compressionLevel = ObjectUtil.checkInRange(compressionLevel,
93                  MIN_COMPRESSION_LEVEL, MAX_COMPRESSION_LEVEL, "compressionLevel");
94          this.blockSize = ObjectUtil.checkPositive(blockSize, "blockSize");
95          this.maxEncodeSize = ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize");
96      }
97  
98      @Override
99      protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean preferDirect) {
100         if (buffer == null) {
101             throw new IllegalStateException("not added to a pipeline," +
102                     "or has been removed,buffer is null");
103         }
104 
105         int remaining = msg.readableBytes() + buffer.readableBytes();
106 
107         // quick overflow check
108         if (remaining < 0) {
109             throw new EncoderException("too much data to allocate a buffer for compression");
110         }
111 
112         long bufferSize = 0;
113         while (remaining > 0) {
114             int curSize = Math.min(blockSize, remaining);
115             remaining -= curSize;
116             bufferSize += Zstd.compressBound(curSize);
117         }
118 
119         if (bufferSize > maxEncodeSize || 0 > bufferSize) {
120             throw new EncoderException("requested encode buffer size (" + bufferSize + " bytes) exceeds " +
121                     "the maximum allowable size (" + maxEncodeSize + " bytes)");
122         }
123 
124         return ctx.alloc().directBuffer((int) bufferSize);
125     }
126 
127     @Override
128     protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) {
129         if (buffer == null) {
130             throw new IllegalStateException("not added to a pipeline," +
131                     "or has been removed,buffer is null");
132         }
133 
134         final ByteBuf buffer = this.buffer;
135         int length;
136         while ((length = in.readableBytes()) > 0) {
137             final int nextChunkSize = Math.min(length, buffer.writableBytes());
138             in.readBytes(buffer, nextChunkSize);
139 
140             if (!buffer.isWritable()) {
141                 flushBufferedData(out);
142             }
143         }
144     }
145 
146     private void flushBufferedData(ByteBuf out) {
147         final int flushableBytes = buffer.readableBytes();
148         if (flushableBytes == 0) {
149             return;
150         }
151 
152         final int bufSize = (int) Zstd.compressBound(flushableBytes);
153         out.ensureWritable(bufSize);
154         final int idx = out.writerIndex();
155         int compressedLength;
156         try {
157             ByteBuffer outNioBuffer = out.internalNioBuffer(idx, out.writableBytes());
158             compressedLength = Zstd.compress(
159                     outNioBuffer,
160                     buffer.internalNioBuffer(buffer.readerIndex(), flushableBytes),
161                     compressionLevel);
162         } catch (Exception e) {
163             throw new CompressionException(e);
164         }
165 
166         out.writerIndex(idx + compressedLength);
167         buffer.clear();
168     }
169 
170     @Override
171     public void flush(final ChannelHandlerContext ctx) {
172         if (buffer != null && buffer.isReadable()) {
173             final ByteBuf buf = allocateBuffer(ctx, Unpooled.EMPTY_BUFFER, isPreferDirect());
174             flushBufferedData(buf);
175             ctx.write(buf);
176         }
177         ctx.flush();
178     }
179 
180     @Override
181     public void handlerAdded(ChannelHandlerContext ctx) {
182         buffer = ctx.alloc().directBuffer(blockSize);
183         buffer.clear();
184     }
185 
186     @Override
187     public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
188         super.handlerRemoved(ctx);
189         if (buffer != null) {
190             buffer.release();
191             buffer = null;
192         }
193     }
194 }