1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.compression;
17
18 import com.github.luben.zstd.Zstd;
19 import io.netty.buffer.ByteBuf;
20 import io.netty.buffer.Unpooled;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.handler.codec.EncoderException;
23 import io.netty.handler.codec.MessageToByteEncoder;
24 import io.netty.util.internal.ObjectUtil;
25 import java.nio.ByteBuffer;
26
27 import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_COMPRESSION_LEVEL;
28 import static io.netty.handler.codec.compression.ZstdConstants.MIN_COMPRESSION_LEVEL;
29 import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL;
30 import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE;
31 import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_MAX_ENCODE_SIZE;
32
33
34
35
36
37 public final class ZstdEncoder extends MessageToByteEncoder<ByteBuf> {
38
39 {
40 try {
41 io.netty.handler.codec.compression.Zstd.ensureAvailability();
42 } catch (Throwable throwable) {
43 throw new ExceptionInInitializerError(throwable);
44 }
45 }
46 private final int blockSize;
47 private final int compressionLevel;
48 private final int maxEncodeSize;
49 private ByteBuf buffer;
50
51
52
53
54
55
56
57
58 public ZstdEncoder() {
59 this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE);
60 }
61
62
63
64
65
66
67 public ZstdEncoder(int compressionLevel) {
68 this(compressionLevel, DEFAULT_BLOCK_SIZE, DEFAULT_MAX_ENCODE_SIZE);
69 }
70
71
72
73
74
75
76
77
78 public ZstdEncoder(int blockSize, int maxEncodeSize) {
79 this(DEFAULT_COMPRESSION_LEVEL, blockSize, maxEncodeSize);
80 }
81
82
83
84
85
86
87
88
89
90 public ZstdEncoder(int compressionLevel, int blockSize, int maxEncodeSize) {
91 super(ByteBuf.class, true);
92 this.compressionLevel = ObjectUtil.checkInRange(compressionLevel,
93 MIN_COMPRESSION_LEVEL, MAX_COMPRESSION_LEVEL, "compressionLevel");
94 this.blockSize = ObjectUtil.checkPositive(blockSize, "blockSize");
95 this.maxEncodeSize = ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize");
96 }
97
98 @Override
99 protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean preferDirect) {
100 if (buffer == null) {
101 throw new IllegalStateException("not added to a pipeline," +
102 "or has been removed,buffer is null");
103 }
104
105 int remaining = msg.readableBytes() + buffer.readableBytes();
106
107
108 if (remaining < 0) {
109 throw new EncoderException("too much data to allocate a buffer for compression");
110 }
111
112 long bufferSize = 0;
113 while (remaining > 0) {
114 int curSize = Math.min(blockSize, remaining);
115 remaining -= curSize;
116
117
118 bufferSize = Math.max(bufferSize, Zstd.compressBound(curSize));
119 }
120
121 if (bufferSize > maxEncodeSize || 0 > bufferSize) {
122 throw new EncoderException("requested encode buffer size (" + bufferSize + " bytes) exceeds " +
123 "the maximum allowable size (" + maxEncodeSize + " bytes)");
124 }
125
126 return ctx.alloc().directBuffer((int) bufferSize);
127 }
128
129 @Override
130 protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) {
131 if (buffer == null) {
132 throw new IllegalStateException("not added to a pipeline," +
133 "or has been removed,buffer is null");
134 }
135
136 final ByteBuf buffer = this.buffer;
137 int length;
138 while ((length = in.readableBytes()) > 0) {
139 final int nextChunkSize = Math.min(length, buffer.writableBytes());
140 in.readBytes(buffer, nextChunkSize);
141
142 if (!buffer.isWritable()) {
143 flushBufferedData(out);
144 }
145 }
146
147
148 if (buffer.isReadable()) {
149 flushBufferedData(out);
150 }
151 }
152
153 private void flushBufferedData(ByteBuf out) {
154 final int flushableBytes = buffer.readableBytes();
155 if (flushableBytes == 0) {
156 return;
157 }
158
159 final int bufSize = (int) Zstd.compressBound(flushableBytes);
160 out.ensureWritable(bufSize);
161 final int idx = out.writerIndex();
162 int compressedLength;
163 try {
164 ByteBuffer outNioBuffer = out.internalNioBuffer(idx, out.writableBytes());
165 compressedLength = Zstd.compress(
166 outNioBuffer,
167 buffer.internalNioBuffer(buffer.readerIndex(), flushableBytes),
168 compressionLevel);
169 } catch (Exception e) {
170 throw new CompressionException(e);
171 }
172
173 out.writerIndex(idx + compressedLength);
174 buffer.clear();
175 }
176
177 @Override
178 public void flush(final ChannelHandlerContext ctx) {
179 if (buffer != null && buffer.isReadable()) {
180 final ByteBuf buf = allocateBuffer(ctx, Unpooled.EMPTY_BUFFER, isPreferDirect());
181 flushBufferedData(buf);
182 ctx.write(buf);
183 }
184 ctx.flush();
185 }
186
187 @Override
188 public void handlerAdded(ChannelHandlerContext ctx) {
189 buffer = ctx.alloc().directBuffer(blockSize);
190 buffer.clear();
191 }
192
193 @Override
194 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
195 super.handlerRemoved(ctx);
196 if (buffer != null) {
197 buffer.release();
198 buffer = null;
199 }
200 }
201 }