View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  
17  package io.netty.handler.codec.mqtt;
18  
19  import io.netty.buffer.ByteBuf;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.handler.codec.DecoderException;
22  import io.netty.handler.codec.ReplayingDecoder;
23  import io.netty.handler.codec.mqtt.MqttDecoder.DecoderState;
24  import io.netty.util.CharsetUtil;
25  
26  import java.util.ArrayList;
27  import java.util.List;
28  
29  import static io.netty.handler.codec.mqtt.MqttCodecUtil.isValidClientId;
30  import static io.netty.handler.codec.mqtt.MqttCodecUtil.isValidMessageId;
31  import static io.netty.handler.codec.mqtt.MqttCodecUtil.isValidPublishTopicName;
32  import static io.netty.handler.codec.mqtt.MqttCodecUtil.resetUnusedFields;
33  import static io.netty.handler.codec.mqtt.MqttCodecUtil.validateFixedHeader;
34  
35  /**
36   * Decodes Mqtt messages from bytes, following
37   * <a href="http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html">
38   *     the MQTT protocol specification v3.1</a>
39   */
40  public final class MqttDecoder extends ReplayingDecoder<DecoderState> {
41  
42      private static final int DEFAULT_MAX_BYTES_IN_MESSAGE = 8092;
43  
44      /**
45       * States of the decoder.
46       * We start at READ_FIXED_HEADER, followed by
47       * READ_VARIABLE_HEADER and finally READ_PAYLOAD.
48       */
49      enum DecoderState {
50          READ_FIXED_HEADER,
51          READ_VARIABLE_HEADER,
52          READ_PAYLOAD,
53          BAD_MESSAGE,
54      }
55  
56      private MqttFixedHeader mqttFixedHeader;
57      private Object variableHeader;
58      private int bytesRemainingInVariablePart;
59  
60      private final int maxBytesInMessage;
61  
62      public MqttDecoder() {
63        this(DEFAULT_MAX_BYTES_IN_MESSAGE);
64      }
65  
66      public MqttDecoder(int maxBytesInMessage) {
67          super(DecoderState.READ_FIXED_HEADER);
68          this.maxBytesInMessage = maxBytesInMessage;
69      }
70  
71      @Override
72      protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
73          switch (state()) {
74              case READ_FIXED_HEADER: try {
75                  mqttFixedHeader = decodeFixedHeader(buffer);
76                  bytesRemainingInVariablePart = mqttFixedHeader.remainingLength();
77                  checkpoint(DecoderState.READ_VARIABLE_HEADER);
78                  // fall through
79              } catch (Exception cause) {
80                  out.add(invalidMessage(cause));
81                  return;
82              }
83  
84              case READ_VARIABLE_HEADER:  try {
85                  if (bytesRemainingInVariablePart > maxBytesInMessage) {
86                      throw new DecoderException("too large message: " + bytesRemainingInVariablePart + " bytes");
87                  }
88                  final Result<?> decodedVariableHeader = decodeVariableHeader(buffer, mqttFixedHeader);
89                  variableHeader = decodedVariableHeader.value;
90                  bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed;
91                  checkpoint(DecoderState.READ_PAYLOAD);
92                  // fall through
93              } catch (Exception cause) {
94                  out.add(invalidMessage(cause));
95                  return;
96              }
97  
98              case READ_PAYLOAD: try {
99                  final Result<?> decodedPayload =
100                         decodePayload(
101                                 buffer,
102                                 mqttFixedHeader.messageType(),
103                                 bytesRemainingInVariablePart,
104                                 variableHeader);
105                 bytesRemainingInVariablePart -= decodedPayload.numberOfBytesConsumed;
106                 if (bytesRemainingInVariablePart != 0) {
107                     throw new DecoderException(
108                             "non-zero remaining payload bytes: " +
109                                     bytesRemainingInVariablePart + " (" + mqttFixedHeader.messageType() + ')');
110                 }
111                 checkpoint(DecoderState.READ_FIXED_HEADER);
112                 MqttMessage message = MqttMessageFactory.newMessage(
113                         mqttFixedHeader, variableHeader, decodedPayload.value);
114                 mqttFixedHeader = null;
115                 variableHeader = null;
116                 out.add(message);
117                 break;
118             } catch (Exception cause) {
119                 out.add(invalidMessage(cause));
120                 return;
121             }
122 
123             case BAD_MESSAGE:
124                 // Keep discarding until disconnection.
125                 buffer.skipBytes(actualReadableBytes());
126                 break;
127 
128             default:
129                 // Shouldn't reach here.
130                 throw new Error();
131         }
132     }
133 
134     private MqttMessage invalidMessage(Throwable cause) {
135       checkpoint(DecoderState.BAD_MESSAGE);
136       return MqttMessageFactory.newInvalidMessage(cause);
137     }
138 
139     /**
140      * Decodes the fixed header. It's one byte for the flags and then variable bytes for the remaining length.
141      *
142      * @param buffer the buffer to decode from
143      * @return the fixed header
144      */
145     private static MqttFixedHeader decodeFixedHeader(ByteBuf buffer) {
146         short b1 = buffer.readUnsignedByte();
147 
148         MqttMessageType messageType = MqttMessageType.valueOf(b1 >> 4);
149         boolean dupFlag = (b1 & 0x08) == 0x08;
150         int qosLevel = (b1 & 0x06) >> 1;
151         boolean retain = (b1 & 0x01) != 0;
152 
153         int remainingLength = 0;
154         int multiplier = 1;
155         short digit;
156         int loops = 0;
157         do {
158             digit = buffer.readUnsignedByte();
159             remainingLength += (digit & 127) * multiplier;
160             multiplier *= 128;
161             loops++;
162         } while ((digit & 128) != 0 && loops < 4);
163 
164         // MQTT protocol limits Remaining Length to 4 bytes
165         if (loops == 4 && (digit & 128) != 0) {
166             throw new DecoderException("remaining length exceeds 4 digits (" + messageType + ')');
167         }
168         MqttFixedHeader decodedFixedHeader =
169                 new MqttFixedHeader(messageType, dupFlag, MqttQoS.valueOf(qosLevel), retain, remainingLength);
170         return validateFixedHeader(resetUnusedFields(decodedFixedHeader));
171     }
172 
173     /**
174      * Decodes the variable header (if any)
175      * @param buffer the buffer to decode from
176      * @param mqttFixedHeader MqttFixedHeader of the same message
177      * @return the variable header
178      */
179     private static Result<?> decodeVariableHeader(ByteBuf buffer, MqttFixedHeader mqttFixedHeader) {
180         switch (mqttFixedHeader.messageType()) {
181             case CONNECT:
182                 return decodeConnectionVariableHeader(buffer);
183 
184             case CONNACK:
185                 return decodeConnAckVariableHeader(buffer);
186 
187             case SUBSCRIBE:
188             case UNSUBSCRIBE:
189             case SUBACK:
190             case UNSUBACK:
191             case PUBACK:
192             case PUBREC:
193             case PUBCOMP:
194             case PUBREL:
195                 return decodeMessageIdVariableHeader(buffer);
196 
197             case PUBLISH:
198                 return decodePublishVariableHeader(buffer, mqttFixedHeader);
199 
200             case PINGREQ:
201             case PINGRESP:
202             case DISCONNECT:
203                 // Empty variable header
204                 return new Result<Object>(null, 0);
205         }
206         return new Result<Object>(null, 0); //should never reach here
207     }
208 
209     private static Result<MqttConnectVariableHeader> decodeConnectionVariableHeader(ByteBuf buffer) {
210         final Result<String> protoString = decodeString(buffer);
211         int numberOfBytesConsumed = protoString.numberOfBytesConsumed;
212 
213         final byte protocolLevel = buffer.readByte();
214         numberOfBytesConsumed += 1;
215 
216         final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(protoString.value, protocolLevel);
217 
218         final int b1 = buffer.readUnsignedByte();
219         numberOfBytesConsumed += 1;
220 
221         final Result<Integer> keepAlive = decodeMsbLsb(buffer);
222         numberOfBytesConsumed += keepAlive.numberOfBytesConsumed;
223 
224         final boolean hasUserName = (b1 & 0x80) == 0x80;
225         final boolean hasPassword = (b1 & 0x40) == 0x40;
226         final boolean willRetain = (b1 & 0x20) == 0x20;
227         final int willQos = (b1 & 0x18) >> 3;
228         final boolean willFlag = (b1 & 0x04) == 0x04;
229         final boolean cleanSession = (b1 & 0x02) == 0x02;
230         if (mqttVersion == MqttVersion.MQTT_3_1_1) {
231             final boolean zeroReservedFlag = (b1 & 0x01) == 0x0;
232             if (!zeroReservedFlag) {
233                 // MQTT v3.1.1: The Server MUST validate that the reserved flag in the CONNECT Control Packet is
234                 // set to zero and disconnect the Client if it is not zero.
235                 // See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349230
236                 throw new DecoderException("non-zero reserved flag");
237             }
238         }
239 
240         final MqttConnectVariableHeader mqttConnectVariableHeader = new MqttConnectVariableHeader(
241                 mqttVersion.protocolName(),
242                 mqttVersion.protocolLevel(),
243                 hasUserName,
244                 hasPassword,
245                 willRetain,
246                 willQos,
247                 willFlag,
248                 cleanSession,
249                 keepAlive.value);
250         return new Result<MqttConnectVariableHeader>(mqttConnectVariableHeader, numberOfBytesConsumed);
251     }
252 
253     private static Result<MqttConnAckVariableHeader> decodeConnAckVariableHeader(ByteBuf buffer) {
254         final boolean sessionPresent = (buffer.readUnsignedByte() & 0x01) == 0x01;
255         byte returnCode = buffer.readByte();
256         final int numberOfBytesConsumed = 2;
257         final MqttConnAckVariableHeader mqttConnAckVariableHeader =
258                 new MqttConnAckVariableHeader(MqttConnectReturnCode.valueOf(returnCode), sessionPresent);
259         return new Result<MqttConnAckVariableHeader>(mqttConnAckVariableHeader, numberOfBytesConsumed);
260     }
261 
262     private static Result<MqttMessageIdVariableHeader> decodeMessageIdVariableHeader(ByteBuf buffer) {
263         final Result<Integer> messageId = decodeMessageId(buffer);
264         return new Result<MqttMessageIdVariableHeader>(
265                 MqttMessageIdVariableHeader.from(messageId.value),
266                 messageId.numberOfBytesConsumed);
267     }
268 
269     private static Result<MqttPublishVariableHeader> decodePublishVariableHeader(
270             ByteBuf buffer,
271             MqttFixedHeader mqttFixedHeader) {
272         final Result<String> decodedTopic = decodeString(buffer);
273         if (!isValidPublishTopicName(decodedTopic.value)) {
274             throw new DecoderException("invalid publish topic name: " + decodedTopic.value + " (contains wildcards)");
275         }
276         int numberOfBytesConsumed = decodedTopic.numberOfBytesConsumed;
277 
278         int messageId = -1;
279         if (mqttFixedHeader.qosLevel().value() > 0) {
280             final Result<Integer> decodedMessageId = decodeMessageId(buffer);
281             messageId = decodedMessageId.value;
282             numberOfBytesConsumed += decodedMessageId.numberOfBytesConsumed;
283         }
284         final MqttPublishVariableHeader mqttPublishVariableHeader =
285                 new MqttPublishVariableHeader(decodedTopic.value, messageId);
286         return new Result<MqttPublishVariableHeader>(mqttPublishVariableHeader, numberOfBytesConsumed);
287     }
288 
289     private static Result<Integer> decodeMessageId(ByteBuf buffer) {
290         final Result<Integer> messageId = decodeMsbLsb(buffer);
291         if (!isValidMessageId(messageId.value)) {
292             throw new DecoderException("invalid messageId: " + messageId.value);
293         }
294         return messageId;
295     }
296 
297     /**
298      * Decodes the payload.
299      *
300      * @param buffer the buffer to decode from
301      * @param messageType  type of the message being decoded
302      * @param bytesRemainingInVariablePart bytes remaining
303      * @param variableHeader variable header of the same message
304      * @return the payload
305      */
306     private static Result<?> decodePayload(
307             ByteBuf buffer,
308             MqttMessageType messageType,
309             int bytesRemainingInVariablePart,
310             Object variableHeader) {
311         switch (messageType) {
312             case CONNECT:
313                 return decodeConnectionPayload(buffer, (MqttConnectVariableHeader) variableHeader);
314 
315             case SUBSCRIBE:
316                 return decodeSubscribePayload(buffer, bytesRemainingInVariablePart);
317 
318             case SUBACK:
319                 return decodeSubackPayload(buffer, bytesRemainingInVariablePart);
320 
321             case UNSUBSCRIBE:
322                 return decodeUnsubscribePayload(buffer, bytesRemainingInVariablePart);
323 
324             case PUBLISH:
325                 return decodePublishPayload(buffer, bytesRemainingInVariablePart);
326 
327             default:
328                 // unknown payload , no byte consumed
329                 return new Result<Object>(null, 0);
330         }
331     }
332 
333     private static Result<MqttConnectPayload> decodeConnectionPayload(
334             ByteBuf buffer,
335             MqttConnectVariableHeader mqttConnectVariableHeader) {
336         final Result<String> decodedClientId = decodeString(buffer);
337         final String decodedClientIdValue = decodedClientId.value;
338         final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(mqttConnectVariableHeader.name(),
339                 (byte) mqttConnectVariableHeader.version());
340         if (!isValidClientId(mqttVersion, decodedClientIdValue)) {
341             throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + decodedClientIdValue);
342         }
343         int numberOfBytesConsumed = decodedClientId.numberOfBytesConsumed;
344 
345         Result<String> decodedWillTopic = null;
346         Result<byte[]> decodedWillMessage = null;
347         if (mqttConnectVariableHeader.isWillFlag()) {
348             decodedWillTopic = decodeString(buffer, 0, 32767);
349             numberOfBytesConsumed += decodedWillTopic.numberOfBytesConsumed;
350             decodedWillMessage = decodeByteArray(buffer);
351             numberOfBytesConsumed += decodedWillMessage.numberOfBytesConsumed;
352         }
353         Result<String> decodedUserName = null;
354         Result<byte[]> decodedPassword = null;
355         if (mqttConnectVariableHeader.hasUserName()) {
356             decodedUserName = decodeString(buffer);
357             numberOfBytesConsumed += decodedUserName.numberOfBytesConsumed;
358         }
359         if (mqttConnectVariableHeader.hasPassword()) {
360             decodedPassword = decodeByteArray(buffer);
361             numberOfBytesConsumed += decodedPassword.numberOfBytesConsumed;
362         }
363 
364         final MqttConnectPayload mqttConnectPayload =
365                 new MqttConnectPayload(
366                         decodedClientId.value,
367                         decodedWillTopic != null ? decodedWillTopic.value : null,
368                         decodedWillMessage != null ? decodedWillMessage.value : null,
369                         decodedUserName != null ? decodedUserName.value : null,
370                         decodedPassword != null ? decodedPassword.value : null);
371         return new Result<MqttConnectPayload>(mqttConnectPayload, numberOfBytesConsumed);
372     }
373 
374     private static Result<MqttSubscribePayload> decodeSubscribePayload(
375             ByteBuf buffer,
376             int bytesRemainingInVariablePart) {
377         final List<MqttTopicSubscription> subscribeTopics = new ArrayList<MqttTopicSubscription>();
378         int numberOfBytesConsumed = 0;
379         while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
380             final Result<String> decodedTopicName = decodeString(buffer);
381             numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
382             int qos = buffer.readUnsignedByte() & 0x03;
383             numberOfBytesConsumed++;
384             subscribeTopics.add(new MqttTopicSubscription(decodedTopicName.value, MqttQoS.valueOf(qos)));
385         }
386         return new Result<MqttSubscribePayload>(new MqttSubscribePayload(subscribeTopics), numberOfBytesConsumed);
387     }
388 
389     private static Result<MqttSubAckPayload> decodeSubackPayload(
390             ByteBuf buffer,
391             int bytesRemainingInVariablePart) {
392         final List<Integer> grantedQos = new ArrayList<Integer>();
393         int numberOfBytesConsumed = 0;
394         while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
395             int qos = buffer.readUnsignedByte() & 0x03;
396             numberOfBytesConsumed++;
397             grantedQos.add(qos);
398         }
399         return new Result<MqttSubAckPayload>(new MqttSubAckPayload(grantedQos), numberOfBytesConsumed);
400     }
401 
402     private static Result<MqttUnsubscribePayload> decodeUnsubscribePayload(
403             ByteBuf buffer,
404             int bytesRemainingInVariablePart) {
405         final List<String> unsubscribeTopics = new ArrayList<String>();
406         int numberOfBytesConsumed = 0;
407         while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
408             final Result<String> decodedTopicName = decodeString(buffer);
409             numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
410             unsubscribeTopics.add(decodedTopicName.value);
411         }
412         return new Result<MqttUnsubscribePayload>(
413                 new MqttUnsubscribePayload(unsubscribeTopics),
414                 numberOfBytesConsumed);
415     }
416 
417     private static Result<ByteBuf> decodePublishPayload(ByteBuf buffer, int bytesRemainingInVariablePart) {
418         ByteBuf b = buffer.readRetainedSlice(bytesRemainingInVariablePart);
419         return new Result<ByteBuf>(b, bytesRemainingInVariablePart);
420     }
421 
422     private static Result<String> decodeString(ByteBuf buffer) {
423         return decodeString(buffer, 0, Integer.MAX_VALUE);
424     }
425 
426     private static Result<String> decodeString(ByteBuf buffer, int minBytes, int maxBytes) {
427         final Result<Integer> decodedSize = decodeMsbLsb(buffer);
428         int size = decodedSize.value;
429         int numberOfBytesConsumed = decodedSize.numberOfBytesConsumed;
430         if (size < minBytes || size > maxBytes) {
431             buffer.skipBytes(size);
432             numberOfBytesConsumed += size;
433             return new Result<String>(null, numberOfBytesConsumed);
434         }
435         String s = buffer.toString(buffer.readerIndex(), size, CharsetUtil.UTF_8);
436         buffer.skipBytes(size);
437         numberOfBytesConsumed += size;
438         return new Result<String>(s, numberOfBytesConsumed);
439     }
440 
441     private static Result<byte[]> decodeByteArray(ByteBuf buffer) {
442         final Result<Integer> decodedSize = decodeMsbLsb(buffer);
443         int size = decodedSize.value;
444         byte[] bytes = new byte[size];
445         buffer.readBytes(bytes);
446         return new Result<byte[]>(bytes, decodedSize.numberOfBytesConsumed + size);
447     }
448 
449     private static Result<Integer> decodeMsbLsb(ByteBuf buffer) {
450         return decodeMsbLsb(buffer, 0, 65535);
451     }
452 
453     private static Result<Integer> decodeMsbLsb(ByteBuf buffer, int min, int max) {
454         short msbSize = buffer.readUnsignedByte();
455         short lsbSize = buffer.readUnsignedByte();
456         final int numberOfBytesConsumed = 2;
457         int result = msbSize << 8 | lsbSize;
458         if (result < min || result > max) {
459             result = -1;
460         }
461         return new Result<Integer>(result, numberOfBytesConsumed);
462     }
463 
464     private static final class Result<T> {
465 
466         private final T value;
467         private final int numberOfBytesConsumed;
468 
469         Result(T value, int numberOfBytesConsumed) {
470             this.value = value;
471             this.numberOfBytesConsumed = numberOfBytesConsumed;
472         }
473     }
474 }