1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.compression;
17
18 import com.github.luben.zstd.ZstdIOException;
19 import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.handler.codec.ByteToMessageDecoder;
23 import io.netty.util.internal.ObjectUtil;
24
25 import java.io.Closeable;
26 import java.io.IOException;
27 import java.io.InputStream;
28 import java.util.List;
29
30
31
32
33
34 public final class ZstdDecoder extends ByteToMessageDecoder {
35
36 {
37 try {
38 Zstd.ensureAvailability();
39 } catch (Throwable throwable) {
40 throw new ExceptionInInitializerError(throwable);
41 }
42 }
43
44 private final int maximumAllocationSize;
45 private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream();
46 private ZstdInputStreamNoFinalizer zstdIs;
47
48 private boolean needsRead;
49 private State currentState = State.DECOMPRESS_DATA;
50
51
52
53
54 private enum State {
55 DECOMPRESS_DATA,
56 CORRUPTED
57 }
58
59 public ZstdDecoder() {
60 this(4 * 1024 * 1024);
61 }
62
63 public ZstdDecoder(int maximumAllocationSize) {
64 this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize");
65 }
66
67 @Override
68 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69 needsRead = true;
70 try {
71 if (currentState == State.CORRUPTED) {
72 in.skipBytes(in.readableBytes());
73
74 return;
75 }
76 inputStream.current = in;
77
78 ByteBuf outBuffer = null;
79
80 final int compressedLength = in.readableBytes();
81 try {
82 long uncompressedLength;
83 if (in.isDirect()) {
84 uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
85 CompressionUtil.safeNioBuffer(in, in.readerIndex(), in.readableBytes()));
86 } else {
87 uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
88 in.array(), in.readerIndex() + in.arrayOffset(), in.readableBytes());
89 }
90 if (uncompressedLength <= 0) {
91
92
93 uncompressedLength = compressedLength * 2L;
94 }
95
96 int w;
97 do {
98 if (outBuffer == null) {
99 outBuffer = ctx.alloc().heapBuffer((int) (maximumAllocationSize == 0 ?
100 uncompressedLength : Math.min(maximumAllocationSize, uncompressedLength)));
101 }
102 do {
103 w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes());
104 } while (w != -1 && outBuffer.isWritable());
105 if (outBuffer.isReadable()) {
106 needsRead = false;
107 ctx.fireChannelRead(outBuffer);
108 outBuffer = null;
109 }
110 } while (w != -1);
111 } finally {
112 if (outBuffer != null) {
113 outBuffer.release();
114 }
115 }
116 } catch (Exception e) {
117 currentState = State.CORRUPTED;
118 throw new DecompressionException(e);
119 } finally {
120 inputStream.current = null;
121 }
122 }
123
124 @Override
125 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
126
127 discardSomeReadBytes();
128
129 if (needsRead && !ctx.channel().config().isAutoRead()) {
130 ctx.read();
131 }
132 ctx.fireChannelReadComplete();
133 }
134
135 @Override
136 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
137 super.handlerAdded(ctx);
138 zstdIs = new ZstdInputStreamNoFinalizer(inputStream);
139 zstdIs.setContinuous(true);
140 }
141
142 @Override
143 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
144 try {
145 closeSilently(zstdIs);
146 } finally {
147 super.handlerRemoved0(ctx);
148 }
149 }
150
151 private static void closeSilently(Closeable closeable) {
152 if (closeable != null) {
153 try {
154 closeable.close();
155 } catch (IOException ignore) {
156
157 }
158 }
159 }
160
161 private static final class MutableByteBufInputStream extends InputStream {
162 ByteBuf current;
163
164 @Override
165 public int read() {
166 if (current == null || !current.isReadable()) {
167 return -1;
168 }
169 return current.readByte() & 0xff;
170 }
171
172 @Override
173 public int read(byte[] b, int off, int len) {
174 int available = available();
175 if (available == 0) {
176 return -1;
177 }
178
179 len = Math.min(available, len);
180 current.readBytes(b, off, len);
181 return len;
182 }
183
184 @Override
185 public int available() {
186 return current == null ? 0 : current.readableBytes();
187 }
188 }
189 }