View Javadoc

1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.haproxy;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.ChannelHandlerContext;
20  import io.netty.handler.codec.ByteToMessageDecoder;
21  import io.netty.handler.codec.LineBasedFrameDecoder;
22  import io.netty.handler.codec.ProtocolDetectionResult;
23  import io.netty.util.CharsetUtil;
24  
25  import java.util.List;
26  
27  /**
28   * Decodes an HAProxy proxy protocol header
29   *
30   * @see <a href="http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt">Proxy Protocol Specification</a>
31   */
32  public class HAProxyMessageDecoder extends ByteToMessageDecoder {
33      /**
34       * Maximum possible length of a v1 proxy protocol header per spec
35       */
36      private static final int V1_MAX_LENGTH = 108;
37  
38      /**
39       * Maximum possible length of a v2 proxy protocol header (fixed 16 bytes + max unsigned short)
40       */
41      private static final int V2_MAX_LENGTH = 16 + 65535;
42  
43      /**
44       * Minimum possible length of a fully functioning v2 proxy protocol header (fixed 16 bytes + v2 address info space)
45       */
46      private static final int V2_MIN_LENGTH = 16 + 216;
47  
48      /**
49       * Maximum possible length for v2 additional TLV data (max unsigned short - max v2 address info space)
50       */
51      private static final int V2_MAX_TLV = 65535 - 216;
52  
53      /**
54       * Version 1 header delimiter is always '\r\n' per spec
55       */
56      private static final int DELIMITER_LENGTH = 2;
57  
58      /**
59       * Binary header prefix
60       */
61      private static final byte[] BINARY_PREFIX = {
62              (byte) 0x0D,
63              (byte) 0x0A,
64              (byte) 0x0D,
65              (byte) 0x0A,
66              (byte) 0x00,
67              (byte) 0x0D,
68              (byte) 0x0A,
69              (byte) 0x51,
70              (byte) 0x55,
71              (byte) 0x49,
72              (byte) 0x54,
73              (byte) 0x0A
74      };
75  
76      private static final byte[] TEXT_PREFIX = {
77              (byte) 'P',
78              (byte) 'R',
79              (byte) 'O',
80              (byte) 'X',
81              (byte) 'Y',
82      };
83  
84      /**
85       * Binary header prefix length
86       */
87      private static final int BINARY_PREFIX_LENGTH = BINARY_PREFIX.length;
88  
89      /**
90       * {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V1}.
91       */
92      private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V1 =
93              ProtocolDetectionResult.detected(HAProxyProtocolVersion.V1);
94  
95      /**
96       * {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V2}.
97       */
98      private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V2 =
99              ProtocolDetectionResult.detected(HAProxyProtocolVersion.V2);
100 
101     /**
102      * {@code true} if we're discarding input because we're already over maxLength
103      */
104     private boolean discarding;
105 
106     /**
107      * Number of discarded bytes
108      */
109     private int discardedBytes;
110 
111     /**
112      * {@code true} if we're finished decoding the proxy protocol header
113      */
114     private boolean finished;
115 
116     /**
117      * Protocol specification version
118      */
119     private int version = -1;
120 
121     /**
122      * The latest v2 spec (2014/05/18) allows for additional data to be sent in the proxy protocol header beyond the
123      * address information block so now we need a configurable max header size
124      */
125     private final int v2MaxHeaderSize;
126 
127     /**
128      * Creates a new decoder with no additional data (TLV) restrictions
129      */
130     public HAProxyMessageDecoder() {
131         v2MaxHeaderSize = V2_MAX_LENGTH;
132     }
133 
134     /**
135      * Creates a new decoder with restricted additional data (TLV) size
136      * <p>
137      * <b>Note:</b> limiting TLV size only affects processing of v2, binary headers. Also, as allowed by the 1.5 spec
138      * TLV data is currently ignored. For maximum performance it would be best to configure your upstream proxy host to
139      * <b>NOT</b> send TLV data and instantiate with a max TLV size of {@code 0}.
140      * </p>
141      *
142      * @param maxTlvSize maximum number of bytes allowed for additional data (Type-Length-Value vectors) in a v2 header
143      */
144     public HAProxyMessageDecoder(int maxTlvSize) {
145         if (maxTlvSize < 1) {
146             v2MaxHeaderSize = V2_MIN_LENGTH;
147         } else if (maxTlvSize > V2_MAX_TLV) {
148             v2MaxHeaderSize = V2_MAX_LENGTH;
149         } else {
150             int calcMax = maxTlvSize + V2_MIN_LENGTH;
151             if (calcMax > V2_MAX_LENGTH) {
152                 v2MaxHeaderSize = V2_MAX_LENGTH;
153             } else {
154                 v2MaxHeaderSize = calcMax;
155             }
156         }
157     }
158 
159     /**
160      * Returns the proxy protocol specification version in the buffer if the version is found.
161      * Returns -1 if no version was found in the buffer.
162      */
163     private static int findVersion(final ByteBuf buffer) {
164         final int n = buffer.readableBytes();
165         // per spec, the version number is found in the 13th byte
166         if (n < 13) {
167             return -1;
168         }
169 
170         int idx = buffer.readerIndex();
171         return match(BINARY_PREFIX, buffer, idx) ? buffer.getByte(idx + BINARY_PREFIX_LENGTH) : 1;
172     }
173 
174     /**
175      * Returns the index in the buffer of the end of header if found.
176      * Returns -1 if no end of header was found in the buffer.
177      */
178     private static int findEndOfHeader(final ByteBuf buffer) {
179         final int n = buffer.readableBytes();
180 
181         // per spec, the 15th and 16th bytes contain the address length in bytes
182         if (n < 16) {
183             return -1;
184         }
185 
186         int offset = buffer.readerIndex() + 14;
187 
188         // the total header length will be a fixed 16 byte sequence + the dynamic address information block
189         int totalHeaderBytes = 16 + buffer.getUnsignedShort(offset);
190 
191         // ensure we actually have the full header available
192         if (n >= totalHeaderBytes) {
193             return totalHeaderBytes;
194         } else {
195             return -1;
196         }
197     }
198 
199     /**
200      * Returns the index in the buffer of the end of line found.
201      * Returns -1 if no end of line was found in the buffer.
202      */
203     private static int findEndOfLine(final ByteBuf buffer) {
204         final int n = buffer.writerIndex();
205         for (int i = buffer.readerIndex(); i < n; i++) {
206             final byte b = buffer.getByte(i);
207             if (b == '\r' && i < n - 1 && buffer.getByte(i + 1) == '\n') {
208                 return i;  // \r\n
209             }
210         }
211         return -1;  // Not found.
212     }
213 
214     @Override
215     public boolean isSingleDecode() {
216         // ByteToMessageDecoder uses this method to optionally break out of the decoding loop after each unit of work.
217         // Since we only ever want to decode a single header we always return true to save a bit of work here.
218         return true;
219     }
220 
221     @Override
222     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
223         super.channelRead(ctx, msg);
224         if (finished) {
225             ctx.pipeline().remove(this);
226         }
227     }
228 
229     @Override
230     protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
231         // determine the specification version
232         if (version == -1) {
233             if ((version = findVersion(in)) == -1) {
234                 return;
235             }
236         }
237 
238         ByteBuf decoded;
239 
240         if (version == 1) {
241             decoded = decodeLine(ctx, in);
242         } else {
243             decoded = decodeStruct(ctx, in);
244         }
245 
246         if (decoded != null) {
247             finished = true;
248             try {
249                 if (version == 1) {
250                     out.add(HAProxyMessage.decodeHeader(decoded.toString(CharsetUtil.US_ASCII)));
251                 } else {
252                     out.add(HAProxyMessage.decodeHeader(decoded));
253                 }
254             } catch (HAProxyProtocolException e) {
255                 fail(ctx, null, e);
256             }
257         }
258     }
259 
260     /**
261      * Create a frame out of the {@link ByteBuf} and return it.
262      * Based on code from {@link LineBasedFrameDecoder#decode(ChannelHandlerContext, ByteBuf)}.
263      *
264      * @param ctx     the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to
265      * @param buffer  the {@link ByteBuf} from which to read data
266      * @return frame  the {@link ByteBuf} which represent the frame or {@code null} if no frame could
267      *                be created
268      */
269     private ByteBuf decodeStruct(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
270         final int eoh = findEndOfHeader(buffer);
271         if (!discarding) {
272             if (eoh >= 0) {
273                 final int length = eoh - buffer.readerIndex();
274                 if (length > v2MaxHeaderSize) {
275                     buffer.readerIndex(eoh);
276                     failOverLimit(ctx, length);
277                     return null;
278                 }
279                 return buffer.readSlice(length);
280             } else {
281                 final int length = buffer.readableBytes();
282                 if (length > v2MaxHeaderSize) {
283                     discardedBytes = length;
284                     buffer.skipBytes(length);
285                     discarding = true;
286                     failOverLimit(ctx, "over " + discardedBytes);
287                 }
288                 return null;
289             }
290         } else {
291             if (eoh >= 0) {
292                 buffer.readerIndex(eoh);
293                 discardedBytes = 0;
294                 discarding = false;
295             } else {
296                 discardedBytes = buffer.readableBytes();
297                 buffer.skipBytes(discardedBytes);
298             }
299             return null;
300         }
301     }
302 
303     /**
304      * Create a frame out of the {@link ByteBuf} and return it.
305      * Based on code from {@link LineBasedFrameDecoder#decode(ChannelHandlerContext, ByteBuf)}.
306      *
307      * @param ctx     the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to
308      * @param buffer  the {@link ByteBuf} from which to read data
309      * @return frame  the {@link ByteBuf} which represent the frame or {@code null} if no frame could
310      *                be created
311      */
312     private ByteBuf decodeLine(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
313         final int eol = findEndOfLine(buffer);
314         if (!discarding) {
315             if (eol >= 0) {
316                 final int length = eol - buffer.readerIndex();
317                 if (length > V1_MAX_LENGTH) {
318                     buffer.readerIndex(eol + DELIMITER_LENGTH);
319                     failOverLimit(ctx, length);
320                     return null;
321                 }
322                 ByteBuf frame = buffer.readSlice(length);
323                 buffer.skipBytes(DELIMITER_LENGTH);
324                 return frame;
325             } else {
326                 final int length = buffer.readableBytes();
327                 if (length > V1_MAX_LENGTH) {
328                     discardedBytes = length;
329                     buffer.skipBytes(length);
330                     discarding = true;
331                     failOverLimit(ctx, "over " + discardedBytes);
332                 }
333                 return null;
334             }
335         } else {
336             if (eol >= 0) {
337                 final int delimLength = buffer.getByte(eol) == '\r' ? 2 : 1;
338                 buffer.readerIndex(eol + delimLength);
339                 discardedBytes = 0;
340                 discarding = false;
341             } else {
342                 discardedBytes = buffer.readableBytes();
343                 buffer.skipBytes(discardedBytes);
344             }
345             return null;
346         }
347     }
348 
349     private void failOverLimit(final ChannelHandlerContext ctx, int length) {
350         failOverLimit(ctx, String.valueOf(length));
351     }
352 
353     private void failOverLimit(final ChannelHandlerContext ctx, String length) {
354         int maxLength = version == 1 ? V1_MAX_LENGTH : v2MaxHeaderSize;
355         fail(ctx, "header length (" + length + ") exceeds the allowed maximum (" + maxLength + ')', null);
356     }
357 
358     private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) {
359         finished = true;
360         ctx.close(); // drop connection immediately per spec
361         HAProxyProtocolException ppex;
362         if (errMsg != null && e != null) {
363             ppex = new HAProxyProtocolException(errMsg, e);
364         } else if (errMsg != null) {
365             ppex = new HAProxyProtocolException(errMsg);
366         } else if (e != null) {
367             ppex = new HAProxyProtocolException(e);
368         } else {
369             ppex = new HAProxyProtocolException();
370         }
371         throw ppex;
372     }
373 
374     /**
375      * Returns the {@link ProtocolDetectionResult} for the given {@link ByteBuf}.
376      */
377     public static ProtocolDetectionResult<HAProxyProtocolVersion> detectProtocol(ByteBuf buffer) {
378         if (buffer.readableBytes() < 12) {
379             return ProtocolDetectionResult.needsMoreData();
380         }
381 
382         int idx = buffer.readerIndex();
383 
384         if (match(BINARY_PREFIX, buffer, idx)) {
385             return DETECTION_RESULT_V2;
386         }
387         if (match(TEXT_PREFIX, buffer, idx)) {
388             return DETECTION_RESULT_V1;
389         }
390         return ProtocolDetectionResult.invalid();
391     }
392 
393     private static boolean match(byte[] prefix, ByteBuf buffer, int idx) {
394         for (int i = 0; i < prefix.length; i++) {
395             final byte b = buffer.getByte(idx + i);
396             if (b != prefix[i]) {
397                 return false;
398             }
399         }
400         return true;
401     }
402 }