1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.compression;
17
18 import com.github.luben.zstd.Zstd;
19 import io.netty5.buffer.api.Buffer;
20 import io.netty5.buffer.api.BufferAllocator;
21 import io.netty5.handler.codec.EncoderException;
22 import io.netty5.util.internal.ObjectUtil;
23
24 import java.nio.ByteBuffer;
25 import java.util.function.Supplier;
26
27 import static io.netty5.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE;
28 import static io.netty5.handler.codec.compression.ZstdConstants.DEFAULT_COMPRESSION_LEVEL;
29 import static io.netty5.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE;
30 import static io.netty5.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL;
31
32
33
34
35
36 public final class ZstdCompressor implements Compressor {
37
38 private final int blockSize;
39 private final int compressionLevel;
40 private final int maxEncodeSize;
41
42 private enum State {
43 PROCESSING,
44 FINISHED,
45 CLOSED
46 }
47
48 private State state = State.PROCESSING;
49
50
51
52
53
54
55
56
57
58
59 public static Supplier<ZstdCompressor> newFactory() {
60 return newFactory(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE);
61 }
62
63
64
65
66
67
68
69
70 public static Supplier<ZstdCompressor> newFactory(int compressionLevel) {
71 return newFactory(compressionLevel, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE);
72 }
73
74
75
76
77
78
79
80
81
82
83 public static Supplier<ZstdCompressor> newFactory(int blockSize, int maxEncodeSize) {
84 return newFactory(DEFAULT_COMPRESSION_LEVEL, blockSize, maxEncodeSize);
85 }
86
87
88
89
90
91
92
93
94
95
96 public static Supplier<ZstdCompressor> newFactory(int compressionLevel, int blockSize, int maxEncodeSize) {
97 ObjectUtil.checkInRange(compressionLevel, 0, MAX_COMPRESSION_LEVEL, "compressionLevel");
98 ObjectUtil.checkPositive(blockSize, "blockSize");
99 ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize");
100 return () -> new ZstdCompressor(compressionLevel, blockSize, maxEncodeSize);
101 }
102
103
104
105
106
107
108
109
110 private ZstdCompressor(int compressionLevel, int blockSize, int maxEncodeSize) {
111 this.compressionLevel = compressionLevel;
112 this.blockSize = blockSize;
113 this.maxEncodeSize = maxEncodeSize;
114 }
115
116 private Buffer allocateBuffer(BufferAllocator allocator, Buffer msg) {
117 int remaining = msg.readableBytes();
118
119 long bufferSize = 0;
120 while (remaining > 0) {
121 int curSize = Math.min(blockSize, remaining);
122 remaining -= curSize;
123 bufferSize += Zstd.compressBound(curSize);
124 }
125
126 if (bufferSize > maxEncodeSize || 0 > bufferSize) {
127 throw new EncoderException("requested encode buffer size (" + bufferSize + " bytes) exceeds " +
128 "the maximum allowable size (" + maxEncodeSize + " bytes)");
129 }
130
131
132 return allocator.allocate((int) bufferSize);
133 }
134
135 @Override
136 public Buffer compress(Buffer in, BufferAllocator allocator) throws CompressionException {
137 switch (state) {
138 case CLOSED:
139 throw new CompressionException("Compressor closed");
140 case FINISHED:
141 return allocator.allocate(0);
142 case PROCESSING:
143 if (in.readableBytes() == 0) {
144 return allocator.allocate(0);
145 }
146 Buffer out = allocateBuffer(allocator, in);
147 try {
148 compressData(in, out);
149 return out;
150 } catch (Throwable cause) {
151 out.close();
152 throw cause;
153 }
154 default:
155 throw new IllegalStateException();
156 }
157 }
158
159 @Override
160 public Buffer finish(BufferAllocator allocator) {
161 switch (state) {
162 case CLOSED:
163 throw new CompressionException("Compressor closed");
164 case FINISHED:
165 case PROCESSING:
166 state = State.FINISHED;
167 return allocator.allocate(0);
168 default:
169 throw new IllegalStateException();
170 }
171 }
172
173 @Override
174 public boolean isFinished() {
175 return state != State.PROCESSING;
176 }
177
178 @Override
179 public boolean isClosed() {
180 return state == State.CLOSED;
181 }
182
183 @Override
184 public void close() {
185 state = State.CLOSED;
186 }
187
188 private void compressData(Buffer in, Buffer out) {
189 final int flushableBytes = in.readableBytes();
190 if (flushableBytes == 0) {
191 return;
192 }
193
194 final int bufSize = (int) Zstd.compressBound(flushableBytes);
195 out.ensureWritable(bufSize);
196 try {
197 assert out.countWritableComponents() == 1;
198 try (var writableIteration = out.forEachWritable()) {
199 var writableComponent = writableIteration.first();
200 try (var readableIteration = in.forEachReadable()) {
201 for (var readableComponent = readableIteration.first();
202 readableComponent != null; readableComponent = readableComponent.next()) {
203 final int compressedLength;
204 if (in.isDirect() && out.isDirect()) {
205 ByteBuffer inNioBuffer = readableComponent.readableBuffer();
206 compressedLength = Zstd.compress(
207 writableComponent.writableBuffer(),
208 inNioBuffer,
209 compressionLevel);
210 } else {
211 final byte[] inArray;
212 final int inOffset;
213 final int inLen = readableComponent.readableBytes();
214 if (readableComponent.hasReadableArray()) {
215 inArray = readableComponent.readableArray();
216 inOffset = readableComponent.readableArrayOffset();
217 } else {
218 inArray = new byte[inLen];
219 readableComponent.readableBuffer().get(inArray);
220 inOffset = 0;
221 }
222
223 final byte[] outArray;
224 final int outOffset;
225 final int outLen = writableComponent.writableBytes();
226 if (writableComponent.hasWritableArray()) {
227 outArray = writableComponent.writableArray();
228 outOffset = writableComponent.writableArrayOffset();
229 } else {
230 outArray = new byte[out.writableBytes()];
231 outOffset = 0;
232 }
233
234 compressedLength = (int) Zstd.compressByteArray(
235 outArray, outOffset, outLen, inArray, inOffset, inLen, compressionLevel);
236 if (!writableComponent.hasWritableArray()) {
237 writableComponent.writableBuffer().put(outArray);
238 }
239 }
240 writableComponent.skipWritableBytes(compressedLength);
241 readableComponent.skipReadableBytes(readableComponent.readableBytes());
242 }
243 }
244 }
245 } catch (Exception e) {
246 throw new CompressionException(e);
247 }
248 }
249 }