1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.compression;
17
18 import static io.netty.handler.codec.compression.Snappy.validateChecksum;
19 import io.netty.buffer.ByteBuf;
20 import io.netty.buffer.ByteBufUtil;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.handler.codec.ByteToMessageDecoder;
23
24 import java.util.Arrays;
25 import java.util.List;
26
27
28
29
30
31
32
33
34
35
36
37
38 public class SnappyFramedDecoder extends ByteToMessageDecoder {
39
40 private enum ChunkType {
41 STREAM_IDENTIFIER,
42 COMPRESSED_DATA,
43 UNCOMPRESSED_DATA,
44 RESERVED_UNSKIPPABLE,
45 RESERVED_SKIPPABLE
46 }
47
48 private static final byte[] SNAPPY = { 's', 'N', 'a', 'P', 'p', 'Y' };
49 private static final int MAX_UNCOMPRESSED_DATA_SIZE = 65536 + 4;
50
51 private final Snappy snappy = new Snappy();
52 private final boolean validateChecksums;
53
54 private boolean started;
55 private boolean corrupted;
56
57
58
59
60
61
62 public SnappyFramedDecoder() {
63 this(false);
64 }
65
66
67
68
69
70
71
72
73
74
75 public SnappyFramedDecoder(boolean validateChecksums) {
76 this.validateChecksums = validateChecksums;
77 }
78
79 @Override
80 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
81 if (corrupted) {
82 in.skipBytes(in.readableBytes());
83 return;
84 }
85
86 try {
87 int idx = in.readerIndex();
88 final int inSize = in.readableBytes();
89 if (inSize < 4) {
90
91
92 return;
93 }
94
95 final int chunkTypeVal = in.getUnsignedByte(idx);
96 final ChunkType chunkType = mapChunkType((byte) chunkTypeVal);
97 final int chunkLength = ByteBufUtil.swapMedium(in.getUnsignedMedium(idx + 1));
98
99 switch (chunkType) {
100 case STREAM_IDENTIFIER:
101 if (chunkLength != SNAPPY.length) {
102 throw new DecompressionException("Unexpected length of stream identifier: " + chunkLength);
103 }
104
105 if (inSize < 4 + SNAPPY.length) {
106 break;
107 }
108
109 byte[] identifier = new byte[chunkLength];
110 in.skipBytes(4).readBytes(identifier);
111
112 if (!Arrays.equals(identifier, SNAPPY)) {
113 throw new DecompressionException("Unexpected stream identifier contents. Mismatched snappy " +
114 "protocol version?");
115 }
116
117 started = true;
118 break;
119 case RESERVED_SKIPPABLE:
120 if (!started) {
121 throw new DecompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER");
122 }
123
124 if (inSize < 4 + chunkLength) {
125
126 return;
127 }
128
129 in.skipBytes(4 + chunkLength);
130 break;
131 case RESERVED_UNSKIPPABLE:
132
133
134
135 throw new DecompressionException(
136 "Found reserved unskippable chunk type: 0x" + Integer.toHexString(chunkTypeVal));
137 case UNCOMPRESSED_DATA:
138 if (!started) {
139 throw new DecompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER");
140 }
141 if (chunkLength > MAX_UNCOMPRESSED_DATA_SIZE) {
142 throw new DecompressionException("Received UNCOMPRESSED_DATA larger than 65540 bytes");
143 }
144
145 if (inSize < 4 + chunkLength) {
146 return;
147 }
148
149 in.skipBytes(4);
150 if (validateChecksums) {
151 int checksum = ByteBufUtil.swapInt(in.readInt());
152 validateChecksum(checksum, in, in.readerIndex(), chunkLength - 4);
153 } else {
154 in.skipBytes(4);
155 }
156 out.add(in.readSlice(chunkLength - 4).retain());
157 break;
158 case COMPRESSED_DATA:
159 if (!started) {
160 throw new DecompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER");
161 }
162
163 if (inSize < 4 + chunkLength) {
164 return;
165 }
166
167 in.skipBytes(4);
168 int checksum = ByteBufUtil.swapInt(in.readInt());
169 ByteBuf uncompressed = ctx.alloc().buffer(0);
170 if (validateChecksums) {
171 int oldWriterIndex = in.writerIndex();
172 try {
173 in.writerIndex(in.readerIndex() + chunkLength - 4);
174 snappy.decode(in, uncompressed);
175 } finally {
176 in.writerIndex(oldWriterIndex);
177 }
178 validateChecksum(checksum, uncompressed, 0, uncompressed.writerIndex());
179 } else {
180 snappy.decode(in.readSlice(chunkLength - 4), uncompressed);
181 }
182 out.add(uncompressed);
183 snappy.reset();
184 break;
185 }
186 } catch (Exception e) {
187 corrupted = true;
188 throw e;
189 }
190 }
191
192
193
194
195
196
197
198 private static ChunkType mapChunkType(byte type) {
199 if (type == 0) {
200 return ChunkType.COMPRESSED_DATA;
201 } else if (type == 1) {
202 return ChunkType.UNCOMPRESSED_DATA;
203 } else if (type == (byte) 0xff) {
204 return ChunkType.STREAM_IDENTIFIER;
205 } else if ((type & 0x80) == 0x80) {
206 return ChunkType.RESERVED_SKIPPABLE;
207 } else {
208 return ChunkType.RESERVED_UNSKIPPABLE;
209 }
210 }
211 }