1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  package io.netty.handler.codec.compression;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.ChannelHandlerContext;
20  import io.netty.handler.codec.ByteToMessageDecoder;
21  
22  import java.util.List;
23  
24  import static io.netty.handler.codec.compression.Snappy.validateChecksum;
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  public class SnappyFrameDecoder extends ByteToMessageDecoder {
38  
39      private enum ChunkType {
40          STREAM_IDENTIFIER,
41          COMPRESSED_DATA,
42          UNCOMPRESSED_DATA,
43          RESERVED_UNSKIPPABLE,
44          RESERVED_SKIPPABLE
45      }
46  
47      private static final int SNAPPY_IDENTIFIER_LEN = 6;
48      
49      private static final int MAX_UNCOMPRESSED_DATA_SIZE = 65536 + 4;
50      
51      private static final int MAX_DECOMPRESSED_DATA_SIZE = 65536;
52      
53      private static final int MAX_COMPRESSED_CHUNK_SIZE = 16777216 - 1;
54  
55      private final Snappy snappy = new Snappy();
56      private final boolean validateChecksums;
57  
58      private boolean started;
59      private boolean corrupted;
60      private int numBytesToSkip;
61  
62      
63  
64  
65  
66  
67      public SnappyFrameDecoder() {
68          this(false);
69      }
70  
71      
72  
73  
74  
75  
76  
77  
78  
79  
80      public SnappyFrameDecoder(boolean validateChecksums) {
81          this.validateChecksums = validateChecksums;
82      }
83  
84      @Override
85      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
86          if (corrupted) {
87              in.skipBytes(in.readableBytes());
88              return;
89          }
90  
91          if (numBytesToSkip != 0) {
92              
93              int skipBytes = Math.min(numBytesToSkip, in.readableBytes());
94              in.skipBytes(skipBytes);
95              numBytesToSkip -= skipBytes;
96  
97              
98              return;
99          }
100 
101         try {
102             int idx = in.readerIndex();
103             final int inSize = in.readableBytes();
104             if (inSize < 4) {
105                 
106                 
107                 return;
108             }
109 
110             final int chunkTypeVal = in.getUnsignedByte(idx);
111             final ChunkType chunkType = mapChunkType((byte) chunkTypeVal);
112             final int chunkLength = in.getUnsignedMediumLE(idx + 1);
113 
114             switch (chunkType) {
115                 case STREAM_IDENTIFIER:
116                     if (chunkLength != SNAPPY_IDENTIFIER_LEN) {
117                         throw new DecompressionException("Unexpected length of stream identifier: " + chunkLength);
118                     }
119 
120                     if (inSize < 4 + SNAPPY_IDENTIFIER_LEN) {
121                         break;
122                     }
123 
124                     in.skipBytes(4);
125                     int offset = in.readerIndex();
126                     in.skipBytes(SNAPPY_IDENTIFIER_LEN);
127 
128                     checkByte(in.getByte(offset++), (byte) 's');
129                     checkByte(in.getByte(offset++), (byte) 'N');
130                     checkByte(in.getByte(offset++), (byte) 'a');
131                     checkByte(in.getByte(offset++), (byte) 'P');
132                     checkByte(in.getByte(offset++), (byte) 'p');
133                     checkByte(in.getByte(offset), (byte) 'Y');
134 
135                     started = true;
136                     break;
137                 case RESERVED_SKIPPABLE:
138                     if (!started) {
139                         throw new DecompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER");
140                     }
141 
142                     in.skipBytes(4);
143 
144                     int skipBytes = Math.min(chunkLength, in.readableBytes());
145                     in.skipBytes(skipBytes);
146                     if (skipBytes != chunkLength) {
147                         
148                         
149                         numBytesToSkip = chunkLength - skipBytes;
150                     }
151                     break;
152                 case RESERVED_UNSKIPPABLE:
153                     
154                     
155                     
156                     throw new DecompressionException(
157                             "Found reserved unskippable chunk type: 0x" + Integer.toHexString(chunkTypeVal));
158                 case UNCOMPRESSED_DATA:
159                     if (!started) {
160                         throw new DecompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER");
161                     }
162                     if (chunkLength > MAX_UNCOMPRESSED_DATA_SIZE) {
163                         throw new DecompressionException("Received UNCOMPRESSED_DATA larger than " +
164                                 MAX_UNCOMPRESSED_DATA_SIZE + " bytes");
165                     }
166 
167                     if (inSize < 4 + chunkLength) {
168                         return;
169                     }
170 
171                     in.skipBytes(4);
172                     if (validateChecksums) {
173                         int checksum = in.readIntLE();
174                         validateChecksum(checksum, in, in.readerIndex(), chunkLength - 4);
175                     } else {
176                         in.skipBytes(4);
177                     }
178                     out.add(in.readRetainedSlice(chunkLength - 4));
179                     break;
180                 case COMPRESSED_DATA:
181                     if (!started) {
182                         throw new DecompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER");
183                     }
184 
185                     if (chunkLength > MAX_COMPRESSED_CHUNK_SIZE) {
186                         throw new DecompressionException("Received COMPRESSED_DATA that contains" +
187                                 " chunk that exceeds " + MAX_COMPRESSED_CHUNK_SIZE + " bytes");
188                     }
189 
190                     if (inSize < 4 + chunkLength) {
191                         return;
192                     }
193 
194                     in.skipBytes(4);
195                     int checksum = in.readIntLE();
196 
197                     int uncompressedSize = snappy.getPreamble(in);
198                     if (uncompressedSize > MAX_DECOMPRESSED_DATA_SIZE) {
199                         throw new DecompressionException("Received COMPRESSED_DATA that contains" +
200                                 " uncompressed data that exceeds " + MAX_DECOMPRESSED_DATA_SIZE + " bytes");
201                     }
202 
203                     ByteBuf uncompressed = ctx.alloc().buffer(uncompressedSize, MAX_DECOMPRESSED_DATA_SIZE);
204                     try {
205                         if (validateChecksums) {
206                             int oldWriterIndex = in.writerIndex();
207                             try {
208                                 in.writerIndex(in.readerIndex() + chunkLength - 4);
209                                 snappy.decode(in, uncompressed);
210                             } finally {
211                                 in.writerIndex(oldWriterIndex);
212                             }
213                             validateChecksum(checksum, uncompressed, 0, uncompressed.writerIndex());
214                         } else {
215                             snappy.decode(in.readSlice(chunkLength - 4), uncompressed);
216                         }
217                         out.add(uncompressed);
218                         uncompressed = null;
219                     } finally {
220                         if (uncompressed != null) {
221                             uncompressed.release();
222                         }
223                     }
224                     snappy.reset();
225                     break;
226             }
227         } catch (Exception e) {
228             corrupted = true;
229             throw e;
230         }
231     }
232 
233     private static void checkByte(byte actual, byte expect) {
234         if (actual != expect) {
235             throw new DecompressionException("Unexpected stream identifier contents. Mismatched snappy " +
236                     "protocol version?");
237         }
238     }
239 
240     
241 
242 
243 
244 
245 
246     private static ChunkType mapChunkType(byte type) {
247         if (type == 0) {
248             return ChunkType.COMPRESSED_DATA;
249         } else if (type == 1) {
250             return ChunkType.UNCOMPRESSED_DATA;
251         } else if (type == (byte) 0xff) {
252             return ChunkType.STREAM_IDENTIFIER;
253         } else if ((type & 0x80) == 0x80) {
254             return ChunkType.RESERVED_SKIPPABLE;
255         } else {
256             return ChunkType.RESERVED_UNSKIPPABLE;
257         }
258     }
259 }