1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.compression;
17
18 import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
19 import io.netty.buffer.ByteBuf;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.handler.codec.ByteToMessageDecoder;
22 import io.netty.util.internal.ObjectUtil;
23
24 import java.io.Closeable;
25 import java.io.IOException;
26 import java.io.InputStream;
27 import java.util.List;
28
29
30
31
32
33 public final class ZstdDecoder extends ByteToMessageDecoder {
34
35 {
36 try {
37 Zstd.ensureAvailability();
38 } catch (Throwable throwable) {
39 throw new ExceptionInInitializerError(throwable);
40 }
41 }
42
43 private static final int DEFAULT_MAX_FORWARD_BYTES = CompressionUtil.DEFAULT_MAX_FORWARD_BYTES;
44
45
46
47 public static final int DEFAULT_MAXIMUM_ALLOCATION_SIZE = 4 * 1024 * 1024;
48
49
50
51
52 public static final int DEFAULT_MAX_WINDOW_LOG = 27;
53 private static final int MIN_WINDOW_LOG = 10;
54 private static final int MAX_WINDOW_LOG = 31;
55 private final int maximumAllocationSize;
56 private final int maxForwardBytes;
57 private final int maxWindowLog;
58 private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream();
59 private ZstdInputStreamNoFinalizer zstdIs;
60
61 private boolean needsRead;
62 private State currentState = State.DECOMPRESS_DATA;
63
64
65
66
67 private enum State {
68 DECOMPRESS_DATA,
69 CORRUPTED
70 }
71
72
73
74
75
76
77
78
79
80 public ZstdDecoder() {
81 this(DEFAULT_MAXIMUM_ALLOCATION_SIZE, DEFAULT_MAX_WINDOW_LOG);
82 }
83
84
85
86
87
88
89
90
91
92
93 public ZstdDecoder(int maximumAllocationSize) {
94 this(maximumAllocationSize, DEFAULT_MAX_WINDOW_LOG);
95 }
96
97
98
99
100
101
102
103
104
105
106 public ZstdDecoder(int maximumAllocationSize, int maxWindowLog) {
107 this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize");
108 this.maxForwardBytes = maximumAllocationSize > 0 ? maximumAllocationSize : DEFAULT_MAX_FORWARD_BYTES;
109 this.maxWindowLog = ObjectUtil.checkInRange(maxWindowLog, MIN_WINDOW_LOG, MAX_WINDOW_LOG, "maxWindowLog");
110 }
111
112 @Override
113 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
114 needsRead = true;
115 try {
116 if (currentState == State.CORRUPTED) {
117 in.skipBytes(in.readableBytes());
118
119 return;
120 }
121 inputStream.current = in;
122
123 ByteBuf outBuffer = null;
124
125 final int compressedLength = in.readableBytes();
126 try {
127 long uncompressedLength;
128 if (in.isDirect()) {
129 uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
130 CompressionUtil.safeNioBuffer(in, in.readerIndex(), in.readableBytes()));
131 } else {
132 uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
133 in.array(), in.readerIndex() + in.arrayOffset(), in.readableBytes());
134 }
135 if (uncompressedLength <= 0) {
136
137
138 uncompressedLength = compressedLength * 2L;
139 }
140
141 int w;
142 do {
143 if (outBuffer == null) {
144 outBuffer = ctx.alloc().heapBuffer((int) (maximumAllocationSize == 0 ?
145 uncompressedLength : Math.min(maximumAllocationSize, uncompressedLength)));
146 }
147 do {
148 w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes());
149 } while (w > 0 && outBuffer.isWritable());
150 if (!outBuffer.isWritable() || outBuffer.readableBytes() >= maxForwardBytes) {
151 needsRead = false;
152 ctx.fireChannelRead(outBuffer);
153 outBuffer = null;
154 }
155 } while (w > 0);
156 if (outBuffer != null && outBuffer.isReadable()) {
157 needsRead = false;
158 ctx.fireChannelRead(outBuffer);
159 outBuffer = null;
160 }
161 } finally {
162 if (outBuffer != null) {
163 outBuffer.release();
164 }
165 }
166 } catch (Exception e) {
167 currentState = State.CORRUPTED;
168 throw new DecompressionException(e);
169 } finally {
170 inputStream.current = null;
171 }
172 }
173
174 @Override
175 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
176
177 discardSomeReadBytes();
178
179 if (needsRead && !ctx.channel().config().isAutoRead()) {
180 ctx.read();
181 }
182 ctx.fireChannelReadComplete();
183 }
184
185 @Override
186 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
187 super.handlerAdded(ctx);
188 zstdIs = new ZstdInputStreamNoFinalizer(inputStream);
189 zstdIs.setContinuous(true);
190
191
192 zstdIs.setLongMax(maxWindowLog);
193 }
194
195 @Override
196 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
197 try {
198 closeSilently(zstdIs);
199 } finally {
200 super.handlerRemoved0(ctx);
201 }
202 }
203
204 private static void closeSilently(Closeable closeable) {
205 if (closeable != null) {
206 try {
207 closeable.close();
208 } catch (IOException ignore) {
209
210 }
211 }
212 }
213
214 private static final class MutableByteBufInputStream extends InputStream {
215 ByteBuf current;
216
217 @Override
218 public int read() {
219 if (current == null || !current.isReadable()) {
220 return -1;
221 }
222 return current.readByte() & 0xff;
223 }
224
225 @Override
226 public int read(byte[] b, int off, int len) {
227 int available = available();
228 if (available == 0) {
229 return -1;
230 }
231
232 len = Math.min(available, len);
233 current.readBytes(b, off, len);
234 return len;
235 }
236
237 @Override
238 public int available() {
239 return current == null ? 0 : current.readableBytes();
240 }
241 }
242 }