1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.compression;
17
18 import io.netty5.buffer.api.Buffer;
19 import io.netty5.buffer.api.BufferAllocator;
20 import net.jpountz.lz4.LZ4Exception;
21 import net.jpountz.lz4.LZ4Factory;
22 import net.jpountz.lz4.LZ4FastDecompressor;
23
24 import java.util.function.Supplier;
25 import java.util.zip.Checksum;
26
27 import static io.netty5.handler.codec.compression.Lz4Constants.BLOCK_TYPE_COMPRESSED;
28 import static io.netty5.handler.codec.compression.Lz4Constants.BLOCK_TYPE_NON_COMPRESSED;
29 import static io.netty5.handler.codec.compression.Lz4Constants.COMPRESSION_LEVEL_BASE;
30 import static io.netty5.handler.codec.compression.Lz4Constants.DEFAULT_SEED;
31 import static io.netty5.handler.codec.compression.Lz4Constants.HEADER_LENGTH;
32 import static io.netty5.handler.codec.compression.Lz4Constants.MAGIC_NUMBER;
33 import static io.netty5.handler.codec.compression.Lz4Constants.MAX_BLOCK_SIZE;
34 import static java.util.Objects.requireNonNull;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 public final class Lz4Decompressor implements Decompressor {
53
54
55
56 private enum State {
57 INIT_BLOCK,
58 DECOMPRESS_DATA,
59 FINISHED,
60 CORRUPTED,
61 CLOSED
62 }
63
64 private State currentState = State.INIT_BLOCK;
65
66
67
68
69 private LZ4FastDecompressor decompressor;
70
71
72
73
74 private BufferChecksum checksum;
75
76
77
78
79 private int blockType;
80
81
82
83
84 private int compressedLength;
85
86
87
88
89 private int decompressedLength;
90
91
92
93
94 private int currentChecksum;
95
96
97
98
99
100
101
102
103
104
105 private Lz4Decompressor(LZ4Factory factory, Checksum checksum) {
106 decompressor = factory.fastDecompressor();
107 this.checksum = checksum == null ? null : checksum instanceof Lz4XXHash32 ? (Lz4XXHash32) checksum :
108 new BufferChecksum(checksum);
109 }
110
111
112
113
114
115
116
117
118
119
120
121
122 public static Supplier<Lz4Decompressor> newFactory() {
123 return newFactory(false);
124 }
125
126
127
128
129
130
131
132
133
134 public static Supplier<Lz4Decompressor> newFactory(boolean validateChecksums) {
135 return newFactory(LZ4Factory.fastestInstance(), validateChecksums);
136 }
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151 public static Supplier<Lz4Decompressor> newFactory(LZ4Factory factory, boolean validateChecksums) {
152 return newFactory(factory, validateChecksums ? new Lz4XXHash32(DEFAULT_SEED) : null);
153 }
154
155
156
157
158
159
160
161
162
163
164
165 public static Supplier<Lz4Decompressor> newFactory(LZ4Factory factory, Checksum checksum) {
166 requireNonNull(factory, "factory");
167 return () -> new Lz4Decompressor(factory, checksum);
168 }
169
170 private void decompress(Buffer compressed, Buffer uncompressed) {
171 assert compressed.countReadableComponents() == 1;
172 try (var writableIteration = uncompressed.forEachWritable()) {
173 var writableComponent = writableIteration.first();
174 try (var readableIteration = compressed.forEachReadable()) {
175 var readableComponent = readableIteration.first();
176 decompressor.decompress(
177 readableComponent.readableBuffer(), writableComponent.writableBuffer());
178 }
179 }
180 }
181
182 @Override
183 public Buffer decompress(Buffer in, BufferAllocator allocator) throws DecompressionException {
184 try {
185 switch (currentState) {
186 case CORRUPTED:
187 case FINISHED:
188 return allocator.allocate(0);
189 case CLOSED:
190 throw new DecompressionException("Decompressor closed");
191 case INIT_BLOCK:
192 if (in.readableBytes() < HEADER_LENGTH) {
193 return null;
194 }
195 final long magic = in.readLong();
196 if (magic != MAGIC_NUMBER) {
197 streamCorrupted("unexpected block identifier");
198 }
199
200 final int token = in.readByte();
201 final int compressionLevel = (token & 0x0F) + COMPRESSION_LEVEL_BASE;
202 int blockType = token & 0xF0;
203
204 int compressedLength = Integer.reverseBytes(in.readInt());
205 if (compressedLength < 0 || compressedLength > MAX_BLOCK_SIZE) {
206 streamCorrupted(String.format(
207 "invalid compressedLength: %d (expected: 0-%d)",
208 compressedLength, MAX_BLOCK_SIZE));
209 }
210
211 int decompressedLength = Integer.reverseBytes(in.readInt());
212 final int maxDecompressedLength = 1 << compressionLevel;
213 if (decompressedLength < 0 || decompressedLength > maxDecompressedLength) {
214 streamCorrupted(String.format(
215 "invalid decompressedLength: %d (expected: 0-%d)",
216 decompressedLength, maxDecompressedLength));
217 }
218 if (decompressedLength == 0 && compressedLength != 0
219 || decompressedLength != 0 && compressedLength == 0
220 || blockType == BLOCK_TYPE_NON_COMPRESSED && decompressedLength != compressedLength) {
221 streamCorrupted(String.format(
222 "stream corrupted: compressedLength(%d) and decompressedLength(%d) mismatch",
223 compressedLength, decompressedLength));
224 }
225
226 int currentChecksum = Integer.reverseBytes(in.readInt());
227 if (decompressedLength == 0 && compressedLength == 0) {
228 if (currentChecksum != 0) {
229 streamCorrupted("stream corrupted: checksum error");
230 }
231 currentState = State.FINISHED;
232 decompressor = null;
233 checksum = null;
234 return null;
235 }
236
237 this.blockType = blockType;
238 this.compressedLength = compressedLength;
239 this.decompressedLength = decompressedLength;
240 this.currentChecksum = currentChecksum;
241
242 currentState = State.DECOMPRESS_DATA;
243
244 case DECOMPRESS_DATA:
245 blockType = this.blockType;
246 compressedLength = this.compressedLength;
247 decompressedLength = this.decompressedLength;
248 currentChecksum = this.currentChecksum;
249
250 if (in.readableBytes() < compressedLength) {
251 return null;
252 }
253
254 final BufferChecksum checksum = this.checksum;
255 Buffer uncompressed = null;
256
257 try {
258 switch (blockType) {
259 case BLOCK_TYPE_NON_COMPRESSED:
260
261 assert compressedLength == decompressedLength;
262 uncompressed = in.readSplit(decompressedLength);
263 break;
264 case BLOCK_TYPE_COMPRESSED:
265 uncompressed = allocator.allocate(decompressedLength);
266
267 assert uncompressed.countWritableComponents() == 1;
268 if (in.countReadableComponents() > 1) {
269
270
271 try (Buffer inBuffer = allocator.allocate(compressedLength)) {
272 in.copyInto(in.readerOffset(), inBuffer,
273 inBuffer.writerOffset(), compressedLength);
274 inBuffer.skipWritableBytes(compressedLength);
275 decompress(inBuffer, uncompressed);
276 }
277 } else {
278 decompress(in, uncompressed);
279 }
280 in.skipReadableBytes(compressedLength);
281
282 uncompressed.skipWritableBytes(decompressedLength);
283 break;
284 default:
285 streamCorrupted(String.format(
286 "unexpected blockType: %d (expected: %d or %d)",
287 blockType, BLOCK_TYPE_NON_COMPRESSED, BLOCK_TYPE_COMPRESSED));
288 }
289
290 if (checksum != null) {
291 CompressionUtil.checkChecksum(checksum, uncompressed, currentChecksum);
292 }
293 Buffer buffer = uncompressed;
294 uncompressed = null;
295 currentState = State.INIT_BLOCK;
296 return buffer;
297 } catch (LZ4Exception e) {
298 streamCorrupted(e);
299 } finally {
300 if (uncompressed != null) {
301 uncompressed.close();
302 }
303 }
304 default:
305 throw new IllegalStateException();
306 }
307 } catch (Exception e) {
308 currentState = State.CORRUPTED;
309 throw e;
310 }
311 }
312
313 @Override
314 public boolean isFinished() {
315 switch (currentState) {
316 case FINISHED:
317 case CLOSED:
318 case CORRUPTED:
319 return true;
320 default:
321 return false;
322 }
323 }
324
325 @Override
326 public boolean isClosed() {
327 return currentState == State.CLOSED;
328 }
329
330 @Override
331 public void close() {
332 currentState = State.CLOSED;
333 }
334
335 private void streamCorrupted(String message) {
336 currentState = State.CORRUPTED;
337 throw new DecompressionException(message);
338 }
339
340 private void streamCorrupted(Exception cause) {
341 currentState = State.CORRUPTED;
342 throw new DecompressionException(cause);
343 }
344 }