View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  
17  package io.netty.handler.codec.mqtt;
18  
19  import io.netty.buffer.ByteBuf;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.handler.codec.DecoderException;
22  import io.netty.handler.codec.ReplayingDecoder;
23  import io.netty.handler.codec.mqtt.MqttDecoder.DecoderState;
24  import io.netty.util.CharsetUtil;
25  
26  import java.util.ArrayList;
27  import java.util.List;
28  
29  import static io.netty.handler.codec.mqtt.MqttCodecUtil.*;
30  
31  /**
32   * Decodes Mqtt messages from bytes, following
33   * <a href="http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html">
34   *     the MQTT protocl specification v3.1</a>
35   */
36  public class MqttDecoder extends ReplayingDecoder<DecoderState> {
37  
38      private static final int DEFAULT_MAX_BYTES_IN_MESSAGE = 8092;
39  
40      /**
41       * States of the decoder.
42       * We start at READ_FIXED_HEADER, followed by
43       * READ_VARIABLE_HEADER and finally READ_PAYLOAD.
44       */
45      enum DecoderState {
46          READ_FIXED_HEADER,
47          READ_VARIABLE_HEADER,
48          READ_PAYLOAD,
49          BAD_MESSAGE,
50      }
51  
52      private MqttFixedHeader mqttFixedHeader;
53      private Object variableHeader;
54      private Object payload;
55      private int bytesRemainingInVariablePart;
56  
57      private final int maxBytesInMessage;
58  
59      public MqttDecoder() {
60        this(DEFAULT_MAX_BYTES_IN_MESSAGE);
61      }
62  
63      public MqttDecoder(int maxBytesInMessage) {
64          super(DecoderState.READ_FIXED_HEADER);
65          this.maxBytesInMessage = maxBytesInMessage;
66      }
67  
68      @Override
69      protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
70          switch (state()) {
71              case READ_FIXED_HEADER:
72                  mqttFixedHeader = decodeFixedHeader(buffer);
73                  bytesRemainingInVariablePart = mqttFixedHeader.remainingLength();
74                  checkpoint(DecoderState.READ_VARIABLE_HEADER);
75                  // fall through
76  
77              case READ_VARIABLE_HEADER:  try {
78                  if (bytesRemainingInVariablePart > maxBytesInMessage) {
79                      throw new DecoderException("too large message: " + bytesRemainingInVariablePart + " bytes");
80                  }
81                  final Result<?> decodedVariableHeader = decodeVariableHeader(buffer, mqttFixedHeader);
82                  variableHeader = decodedVariableHeader.value;
83                  bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed;
84                  checkpoint(DecoderState.READ_PAYLOAD);
85                  // fall through
86              } catch (Exception cause) {
87                  out.add(invalidMessage(cause));
88                  return;
89              }
90  
91              case READ_PAYLOAD: try {
92                  final Result<?> decodedPayload =
93                          decodePayload(
94                                  buffer,
95                                  mqttFixedHeader.messageType(),
96                                  bytesRemainingInVariablePart,
97                                  variableHeader);
98                  payload = decodedPayload.value;
99                  bytesRemainingInVariablePart -= decodedPayload.numberOfBytesConsumed;
100                 if (bytesRemainingInVariablePart != 0) {
101                     throw new DecoderException(
102                             "non-zero remaining payload bytes: " +
103                                     bytesRemainingInVariablePart + " (" + mqttFixedHeader.messageType() + ')');
104                 }
105                 checkpoint(DecoderState.READ_FIXED_HEADER);
106                 MqttMessage message = MqttMessageFactory.newMessage(mqttFixedHeader, variableHeader, payload);
107                 mqttFixedHeader = null;
108                 variableHeader = null;
109                 payload = null;
110                 out.add(message);
111                 break;
112             } catch (Exception cause) {
113                 out.add(invalidMessage(cause));
114                 return;
115             }
116 
117             case BAD_MESSAGE:
118                 // Keep discarding until disconnection.
119                 buffer.skipBytes(actualReadableBytes());
120                 break;
121 
122             default:
123                 // Shouldn't reach here.
124                 throw new Error();
125         }
126     }
127 
128     private MqttMessage invalidMessage(Throwable cause) {
129       checkpoint(DecoderState.BAD_MESSAGE);
130       return MqttMessageFactory.newInvalidMessage(cause);
131     }
132 
133     /**
134      * Decodes the fixed header. It's one byte for the flags and then variable bytes for the remaining length.
135      *
136      * @param buffer the buffer to decode from
137      * @return the fixed header
138      */
139     private static MqttFixedHeader decodeFixedHeader(ByteBuf buffer) {
140         short b1 = buffer.readUnsignedByte();
141 
142         MqttMessageType messageType = MqttMessageType.valueOf(b1 >> 4);
143         boolean dupFlag = (b1 & 0x08) == 0x08;
144         int qosLevel = (b1 & 0x06) >> 1;
145         boolean retain = (b1 & 0x01) != 0;
146 
147         int remainingLength = 0;
148         int multiplier = 1;
149         short digit;
150         int loops = 0;
151         do {
152             digit = buffer.readUnsignedByte();
153             remainingLength += (digit & 127) * multiplier;
154             multiplier *= 128;
155             loops++;
156         } while ((digit & 128) != 0 && loops < 4);
157 
158         // MQTT protocol limits Remaining Length to 4 bytes
159         if (loops == 4 && (digit & 128) != 0) {
160             throw new DecoderException("remaining length exceeds 4 digits (" + messageType + ')');
161         }
162         MqttFixedHeader decodedFixedHeader =
163                 new MqttFixedHeader(messageType, dupFlag, MqttQoS.valueOf(qosLevel), retain, remainingLength);
164         return validateFixedHeader(resetUnusedFields(decodedFixedHeader));
165     }
166 
167     /**
168      * Decodes the variable header (if any)
169      * @param buffer the buffer to decode from
170      * @param mqttFixedHeader MqttFixedHeader of the same message
171      * @return the variable header
172      */
173     private static Result<?> decodeVariableHeader(ByteBuf buffer, MqttFixedHeader mqttFixedHeader) {
174         switch (mqttFixedHeader.messageType()) {
175             case CONNECT:
176                 return decodeConnectionVariableHeader(buffer);
177 
178             case CONNACK:
179                 return decodeConnAckVariableHeader(buffer);
180 
181             case SUBSCRIBE:
182             case UNSUBSCRIBE:
183             case SUBACK:
184             case UNSUBACK:
185             case PUBACK:
186             case PUBREC:
187             case PUBCOMP:
188             case PUBREL:
189                 return decodeMessageIdVariableHeader(buffer);
190 
191             case PUBLISH:
192                 return decodePublishVariableHeader(buffer, mqttFixedHeader);
193 
194             case PINGREQ:
195             case PINGRESP:
196             case DISCONNECT:
197                 // Empty variable header
198                 return new Result<Object>(null, 0);
199         }
200         return new Result<Object>(null, 0); //should never reach here
201     }
202 
203     private static Result<MqttConnectVariableHeader> decodeConnectionVariableHeader(ByteBuf buffer) {
204         final Result<String> protoString = decodeString(buffer);
205         int numberOfBytesConsumed = protoString.numberOfBytesConsumed;
206 
207         final byte protocolLevel = buffer.readByte();
208         numberOfBytesConsumed += 1;
209 
210         final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(protoString.value, protocolLevel);
211 
212         final int b1 = buffer.readUnsignedByte();
213         numberOfBytesConsumed += 1;
214 
215         final Result<Integer> keepAlive = decodeMsbLsb(buffer);
216         numberOfBytesConsumed += keepAlive.numberOfBytesConsumed;
217 
218         final boolean hasUserName = (b1 & 0x80) == 0x80;
219         final boolean hasPassword = (b1 & 0x40) == 0x40;
220         final boolean willRetain = (b1 & 0x20) == 0x20;
221         final int willQos = (b1 & 0x18) >> 3;
222         final boolean willFlag = (b1 & 0x04) == 0x04;
223         final boolean cleanSession = (b1 & 0x02) == 0x02;
224 
225         final MqttConnectVariableHeader mqttConnectVariableHeader = new MqttConnectVariableHeader(
226                 mqttVersion.protocolName(),
227                 mqttVersion.protocolLevel(),
228                 hasUserName,
229                 hasPassword,
230                 willRetain,
231                 willQos,
232                 willFlag,
233                 cleanSession,
234                 keepAlive.value);
235         return new Result<MqttConnectVariableHeader>(mqttConnectVariableHeader, numberOfBytesConsumed);
236     }
237 
238     private static Result<MqttConnAckVariableHeader> decodeConnAckVariableHeader(ByteBuf buffer) {
239         buffer.readUnsignedByte(); // reserved byte
240         byte returnCode = buffer.readByte();
241         final int numberOfBytesConsumed = 2;
242         final MqttConnAckVariableHeader mqttConnAckVariableHeader =
243                 new MqttConnAckVariableHeader(MqttConnectReturnCode.valueOf(returnCode));
244         return new Result<MqttConnAckVariableHeader>(mqttConnAckVariableHeader, numberOfBytesConsumed);
245     }
246 
247     private static Result<MqttMessageIdVariableHeader> decodeMessageIdVariableHeader(ByteBuf buffer) {
248         final Result<Integer> messageId = decodeMessageId(buffer);
249         return new Result<MqttMessageIdVariableHeader>(
250                 MqttMessageIdVariableHeader.from(messageId.value),
251                 messageId.numberOfBytesConsumed);
252     }
253 
254     private static Result<MqttPublishVariableHeader> decodePublishVariableHeader(
255             ByteBuf buffer,
256             MqttFixedHeader mqttFixedHeader) {
257         final Result<String> decodedTopic = decodeString(buffer);
258         if (!isValidPublishTopicName(decodedTopic.value)) {
259             throw new DecoderException("invalid publish topic name: " + decodedTopic.value + " (contains wildcards)");
260         }
261         int numberOfBytesConsumed = decodedTopic.numberOfBytesConsumed;
262 
263         int messageId = -1;
264         if (mqttFixedHeader.qosLevel().value() > 0) {
265             final Result<Integer> decodedMessageId = decodeMessageId(buffer);
266             messageId = decodedMessageId.value;
267             numberOfBytesConsumed += decodedMessageId.numberOfBytesConsumed;
268         }
269         final MqttPublishVariableHeader mqttPublishVariableHeader =
270                 new MqttPublishVariableHeader(decodedTopic.value, messageId);
271         return new Result<MqttPublishVariableHeader>(mqttPublishVariableHeader, numberOfBytesConsumed);
272     }
273 
274     private static Result<Integer> decodeMessageId(ByteBuf buffer) {
275         final Result<Integer> messageId = decodeMsbLsb(buffer);
276         if (!isValidMessageId(messageId.value)) {
277             throw new DecoderException("invalid messageId: " + messageId.value);
278         }
279         return messageId;
280     }
281 
282     /**
283      * Decodes the payload.
284      *
285      * @param buffer the buffer to decode from
286      * @param messageType  type of the message being decoded
287      * @param bytesRemainingInVariablePart bytes remaining
288      * @param variableHeader variable header of the same message
289      * @return the payload
290      */
291     private static Result<?> decodePayload(
292             ByteBuf buffer,
293             MqttMessageType messageType,
294             int bytesRemainingInVariablePart,
295             Object variableHeader) {
296         switch (messageType) {
297             case CONNECT:
298                 return decodeConnectionPayload(buffer, (MqttConnectVariableHeader) variableHeader);
299 
300             case SUBSCRIBE:
301                 return decodeSubscribePayload(buffer, bytesRemainingInVariablePart);
302 
303             case SUBACK:
304                 return decodeSubackPayload(buffer, bytesRemainingInVariablePart);
305 
306             case UNSUBSCRIBE:
307                 return decodeUnsubscribePayload(buffer, bytesRemainingInVariablePart);
308 
309             case PUBLISH:
310                 return decodePublishPayload(buffer, bytesRemainingInVariablePart);
311 
312             default:
313                 // unknown payload , no byte consumed
314                 return new Result<Object>(null, 0);
315         }
316     }
317 
318     private static Result<MqttConnectPayload> decodeConnectionPayload(
319             ByteBuf buffer,
320             MqttConnectVariableHeader mqttConnectVariableHeader) {
321         final Result<String> decodedClientId = decodeString(buffer);
322         final String decodedClientIdValue = decodedClientId.value;
323         final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(mqttConnectVariableHeader.name(),
324                 (byte) mqttConnectVariableHeader.version());
325         if (!isValidClientId(mqttVersion, decodedClientIdValue)) {
326             throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + decodedClientIdValue);
327         }
328         int numberOfBytesConsumed = decodedClientId.numberOfBytesConsumed;
329 
330         Result<String> decodedWillTopic = null;
331         Result<String> decodedWillMessage = null;
332         if (mqttConnectVariableHeader.isWillFlag()) {
333             decodedWillTopic = decodeString(buffer, 0, 32767);
334             numberOfBytesConsumed += decodedWillTopic.numberOfBytesConsumed;
335             decodedWillMessage = decodeAsciiString(buffer);
336             numberOfBytesConsumed += decodedWillMessage.numberOfBytesConsumed;
337         }
338         Result<String> decodedUserName = null;
339         Result<String> decodedPassword = null;
340         if (mqttConnectVariableHeader.hasUserName()) {
341             decodedUserName = decodeString(buffer);
342             numberOfBytesConsumed += decodedUserName.numberOfBytesConsumed;
343         }
344         if (mqttConnectVariableHeader.hasPassword()) {
345             decodedPassword = decodeString(buffer);
346             numberOfBytesConsumed += decodedPassword.numberOfBytesConsumed;
347         }
348 
349         final MqttConnectPayload mqttConnectPayload =
350                 new MqttConnectPayload(
351                         decodedClientId.value,
352                         decodedWillTopic != null ? decodedWillTopic.value : null,
353                         decodedWillMessage != null ? decodedWillMessage.value : null,
354                         decodedUserName != null ? decodedUserName.value : null,
355                         decodedPassword != null ? decodedPassword.value : null);
356         return new Result<MqttConnectPayload>(mqttConnectPayload, numberOfBytesConsumed);
357     }
358 
359     private static Result<MqttSubscribePayload> decodeSubscribePayload(
360             ByteBuf buffer,
361             int bytesRemainingInVariablePart) {
362         final List<MqttTopicSubscription> subscribeTopics = new ArrayList<MqttTopicSubscription>();
363         int numberOfBytesConsumed = 0;
364         while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
365             final Result<String> decodedTopicName = decodeString(buffer);
366             numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
367             int qos = buffer.readUnsignedByte() & 0x03;
368             numberOfBytesConsumed++;
369             subscribeTopics.add(new MqttTopicSubscription(decodedTopicName.value, MqttQoS.valueOf(qos)));
370         }
371         return new Result<MqttSubscribePayload>(new MqttSubscribePayload(subscribeTopics), numberOfBytesConsumed);
372     }
373 
374     private static Result<MqttSubAckPayload> decodeSubackPayload(
375             ByteBuf buffer,
376             int bytesRemainingInVariablePart) {
377         final List<Integer> grantedQos = new ArrayList<Integer>();
378         int numberOfBytesConsumed = 0;
379         while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
380             int qos = buffer.readUnsignedByte() & 0x03;
381             numberOfBytesConsumed++;
382             grantedQos.add(qos);
383         }
384         return new Result<MqttSubAckPayload>(new MqttSubAckPayload(grantedQos), numberOfBytesConsumed);
385     }
386 
387     private static Result<MqttUnsubscribePayload> decodeUnsubscribePayload(
388             ByteBuf buffer,
389             int bytesRemainingInVariablePart) {
390         final List<String> unsubscribeTopics = new ArrayList<String>();
391         int numberOfBytesConsumed = 0;
392         while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
393             final Result<String> decodedTopicName = decodeString(buffer);
394             numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
395             unsubscribeTopics.add(decodedTopicName.value);
396         }
397         return new Result<MqttUnsubscribePayload>(
398                 new MqttUnsubscribePayload(unsubscribeTopics),
399                 numberOfBytesConsumed);
400     }
401 
402     private static Result<ByteBuf> decodePublishPayload(ByteBuf buffer, int bytesRemainingInVariablePart) {
403         ByteBuf b = buffer.readSlice(bytesRemainingInVariablePart).retain();
404         return new Result<ByteBuf>(b, bytesRemainingInVariablePart);
405     }
406 
407     private static Result<String> decodeString(ByteBuf buffer) {
408         return decodeString(buffer, 0, Integer.MAX_VALUE);
409     }
410 
411     private static Result<String> decodeAsciiString(ByteBuf buffer) {
412         Result<String> result = decodeString(buffer, 0, Integer.MAX_VALUE);
413         final String s = result.value;
414         for (int i = 0; i < s.length(); i++) {
415             if (s.charAt(i) > 127) {
416                 return new Result<String>(null, result.numberOfBytesConsumed);
417             }
418         }
419         return new Result<String>(s, result.numberOfBytesConsumed);
420     }
421 
422     private static Result<String> decodeString(ByteBuf buffer, int minBytes, int maxBytes) {
423         final Result<Integer> decodedSize = decodeMsbLsb(buffer);
424         int size = decodedSize.value;
425         int numberOfBytesConsumed = decodedSize.numberOfBytesConsumed;
426         if (size < minBytes || size > maxBytes) {
427             buffer.skipBytes(size);
428             numberOfBytesConsumed += size;
429             return new Result<String>(null, numberOfBytesConsumed);
430         }
431         ByteBuf buf = buffer.readBytes(size);
432         numberOfBytesConsumed += size;
433         return new Result<String>(buf.toString(CharsetUtil.UTF_8), numberOfBytesConsumed);
434     }
435 
436     private static Result<Integer> decodeMsbLsb(ByteBuf buffer) {
437         return decodeMsbLsb(buffer, 0, 65535);
438     }
439 
440     private static Result<Integer> decodeMsbLsb(ByteBuf buffer, int min, int max) {
441         short msbSize = buffer.readUnsignedByte();
442         short lsbSize = buffer.readUnsignedByte();
443         final int numberOfBytesConsumed = 2;
444         int result = msbSize << 8 | lsbSize;
445         if (result < min || result > max) {
446             result = -1;
447         }
448         return new Result<Integer>(result, numberOfBytesConsumed);
449     }
450 
451     private static final class Result<T> {
452 
453         private final T value;
454         private final int numberOfBytesConsumed;
455 
456         Result(T value, int numberOfBytesConsumed) {
457             this.value = value;
458             this.numberOfBytesConsumed = numberOfBytesConsumed;
459         }
460     }
461 }