View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.haproxy;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.ChannelHandlerContext;
20  import io.netty.handler.codec.ByteToMessageDecoder;
21  import io.netty.handler.codec.ProtocolDetectionResult;
22  import io.netty.util.CharsetUtil;
23  
24  import java.util.List;
25  
26  import static io.netty.handler.codec.haproxy.HAProxyConstants.*;
27  
28  /**
29   * Decodes an HAProxy proxy protocol header
30   *
31   * @see <a href="https://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt">Proxy Protocol Specification</a>
32   */
33  public class HAProxyMessageDecoder extends ByteToMessageDecoder {
34      /**
35       * Maximum possible length of a v1 proxy protocol header per spec
36       */
37      private static final int V1_MAX_LENGTH = 108;
38  
39      /**
40       * Maximum possible length of a v2 proxy protocol header (fixed 16 bytes + max unsigned short)
41       */
42      private static final int V2_MAX_LENGTH = 16 + 65535;
43  
44      /**
45       * Minimum possible length of a fully functioning v2 proxy protocol header (fixed 16 bytes + v2 address info space)
46       */
47      private static final int V2_MIN_LENGTH = 16 + 216;
48  
49      /**
50       * Maximum possible length for v2 additional TLV data (max unsigned short - max v2 address info space)
51       */
52      private static final int V2_MAX_TLV = 65535 - 216;
53  
54      /**
55       * Binary header prefix length
56       */
57      private static final int BINARY_PREFIX_LENGTH = BINARY_PREFIX.length;
58  
59      /**
60       * {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V1}.
61       */
62      private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V1 =
63              ProtocolDetectionResult.detected(HAProxyProtocolVersion.V1);
64  
65      /**
66       * {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V2}.
67       */
68      private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V2 =
69              ProtocolDetectionResult.detected(HAProxyProtocolVersion.V2);
70  
71      /**
72       * Used to extract a header frame out of the {@link ByteBuf} and return it.
73       */
74      private HeaderExtractor headerExtractor;
75  
76      /**
77       * {@code true} if we're discarding input because we're already over maxLength
78       */
79      private boolean discarding;
80  
81      /**
82       * Number of discarded bytes
83       */
84      private int discardedBytes;
85  
86      /**
87       * Whether or not to throw an exception as soon as we exceed maxLength.
88       */
89      private final boolean failFast;
90  
91      /**
92       * {@code true} if we're finished decoding the proxy protocol header
93       */
94      private boolean finished;
95  
96      /**
97       * Protocol specification version
98       */
99      private int version = -1;
100 
101     /**
102      * The latest v2 spec (2014/05/18) allows for additional data to be sent in the proxy protocol header beyond the
103      * address information block so now we need a configurable max header size
104      */
105     private final int v2MaxHeaderSize;
106 
107     /**
108      * Creates a new decoder with no additional data (TLV) restrictions, and should throw an exception as soon as
109      * we exceed maxLength.
110      */
111     public HAProxyMessageDecoder() {
112         this(true);
113     }
114 
115     /**
116      * Creates a new decoder with no additional data (TLV) restrictions, whether or not to throw an exception as soon
117      * as we exceed maxLength.
118      *
119      * @param failFast Whether or not to throw an exception as soon as we exceed maxLength
120      */
121     public HAProxyMessageDecoder(boolean failFast) {
122         v2MaxHeaderSize = V2_MAX_LENGTH;
123         this.failFast = failFast;
124     }
125 
126     /**
127      * Creates a new decoder with restricted additional data (TLV) size, and should throw an exception as soon as
128      * we exceed maxLength.
129      * <p>
130      * <b>Note:</b> limiting TLV size only affects processing of v2, binary headers. Also, as allowed by the 1.5 spec
131      * TLV data is currently ignored. For maximum performance it would be best to configure your upstream proxy host to
132      * <b>NOT</b> send TLV data and instantiate with a max TLV size of {@code 0}.
133      * </p>
134      *
135      * @param maxTlvSize maximum number of bytes allowed for additional data (Type-Length-Value vectors) in a v2 header
136      */
137     public HAProxyMessageDecoder(int maxTlvSize) {
138         this(maxTlvSize, true);
139     }
140 
141     /**
142      * Creates a new decoder with restricted additional data (TLV) size, whether or not to throw an exception as soon
143      * as we exceed maxLength.
144      *
145      * @param maxTlvSize maximum number of bytes allowed for additional data (Type-Length-Value vectors) in a v2 header
146      * @param failFast Whether or not to throw an exception as soon as we exceed maxLength
147      */
148     public HAProxyMessageDecoder(int maxTlvSize, boolean failFast) {
149         if (maxTlvSize < 1) {
150             v2MaxHeaderSize = V2_MIN_LENGTH;
151         } else if (maxTlvSize > V2_MAX_TLV) {
152             v2MaxHeaderSize = V2_MAX_LENGTH;
153         } else {
154             int calcMax = maxTlvSize + V2_MIN_LENGTH;
155             if (calcMax > V2_MAX_LENGTH) {
156                 v2MaxHeaderSize = V2_MAX_LENGTH;
157             } else {
158                 v2MaxHeaderSize = calcMax;
159             }
160         }
161         this.failFast = failFast;
162     }
163 
164     /**
165      * Returns the proxy protocol specification version in the buffer if the version is found.
166      * Returns -1 if no version was found in the buffer.
167      */
168     private static int findVersion(final ByteBuf buffer) {
169         final int n = buffer.readableBytes();
170         // per spec, the version number is found in the 13th byte
171         if (n < 13) {
172             return -1;
173         }
174 
175         int idx = buffer.readerIndex();
176         return match(BINARY_PREFIX, buffer, idx) ? buffer.getByte(idx + BINARY_PREFIX_LENGTH) : 1;
177     }
178 
179     /**
180      * Returns the index in the buffer of the end of header if found.
181      * Returns -1 if no end of header was found in the buffer.
182      */
183     private static int findEndOfHeader(final ByteBuf buffer) {
184         final int n = buffer.readableBytes();
185 
186         // per spec, the 15th and 16th bytes contain the address length in bytes
187         if (n < 16) {
188             return -1;
189         }
190 
191         int offset = buffer.readerIndex() + 14;
192 
193         // the total header length will be a fixed 16 byte sequence + the dynamic address information block
194         int totalHeaderBytes = 16 + buffer.getUnsignedShort(offset);
195 
196         // ensure we actually have the full header available
197         if (n >= totalHeaderBytes) {
198             return totalHeaderBytes;
199         } else {
200             return -1;
201         }
202     }
203 
204     /**
205      * Returns the index in the buffer of the end of line found.
206      * Returns -1 if no end of line was found in the buffer.
207      */
208     private static int findEndOfLine(final ByteBuf buffer) {
209         final int n = buffer.writerIndex();
210         for (int i = buffer.readerIndex(); i < n; i++) {
211             final byte b = buffer.getByte(i);
212             if (b == '\r' && i < n - 1 && buffer.getByte(i + 1) == '\n') {
213                 return i;  // \r\n
214             }
215         }
216         return -1;  // Not found.
217     }
218 
219     @Override
220     public boolean isSingleDecode() {
221         // ByteToMessageDecoder uses this method to optionally break out of the decoding loop after each unit of work.
222         // Since we only ever want to decode a single header we always return true to save a bit of work here.
223         return true;
224     }
225 
226     @Override
227     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
228         super.channelRead(ctx, msg);
229         if (finished) {
230             ctx.pipeline().remove(this);
231         }
232     }
233 
234     @Override
235     protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
236         // determine the specification version
237         if (version == -1) {
238             if ((version = findVersion(in)) == -1) {
239                 return;
240             }
241         }
242 
243         ByteBuf decoded;
244 
245         if (version == 1) {
246             decoded = decodeLine(ctx, in);
247         } else {
248             decoded = decodeStruct(ctx, in);
249         }
250 
251         if (decoded != null) {
252             finished = true;
253             try {
254                 if (version == 1) {
255                     out.add(HAProxyMessage.decodeHeader(decoded.toString(CharsetUtil.US_ASCII)));
256                 } else {
257                     out.add(HAProxyMessage.decodeHeader(decoded));
258                 }
259             } catch (HAProxyProtocolException e) {
260                 fail(ctx, null, e);
261             }
262         }
263     }
264 
265     /**
266      * Create a frame out of the {@link ByteBuf} and return it.
267      *
268      * @param ctx     the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to
269      * @param buffer  the {@link ByteBuf} from which to read data
270      * @return frame  the {@link ByteBuf} which represent the frame or {@code null} if no frame could
271      *                be created
272      */
273     private ByteBuf decodeStruct(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
274         if (headerExtractor == null) {
275             headerExtractor = new StructHeaderExtractor(v2MaxHeaderSize);
276         }
277         return headerExtractor.extract(ctx, buffer);
278     }
279 
280     /**
281      * Create a frame out of the {@link ByteBuf} and return it.
282      *
283      * @param ctx     the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to
284      * @param buffer  the {@link ByteBuf} from which to read data
285      * @return frame  the {@link ByteBuf} which represent the frame or {@code null} if no frame could
286      *                be created
287      */
288     private ByteBuf decodeLine(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
289         if (headerExtractor == null) {
290             headerExtractor = new LineHeaderExtractor(V1_MAX_LENGTH);
291         }
292         return headerExtractor.extract(ctx, buffer);
293     }
294 
295     private void failOverLimit(final ChannelHandlerContext ctx, int length) {
296         failOverLimit(ctx, String.valueOf(length));
297     }
298 
299     private void failOverLimit(final ChannelHandlerContext ctx, String length) {
300         int maxLength = version == 1 ? V1_MAX_LENGTH : v2MaxHeaderSize;
301         fail(ctx, "header length (" + length + ") exceeds the allowed maximum (" + maxLength + ')', null);
302     }
303 
304     private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) {
305         finished = true;
306         ctx.close(); // drop connection immediately per spec
307         HAProxyProtocolException ppex;
308         if (errMsg != null && e != null) {
309             ppex = new HAProxyProtocolException(errMsg, e);
310         } else if (errMsg != null) {
311             ppex = new HAProxyProtocolException(errMsg);
312         } else if (e != null) {
313             ppex = new HAProxyProtocolException(e);
314         } else {
315             ppex = new HAProxyProtocolException();
316         }
317         throw ppex;
318     }
319 
320     /**
321      * Returns the {@link ProtocolDetectionResult} for the given {@link ByteBuf}.
322      */
323     public static ProtocolDetectionResult<HAProxyProtocolVersion> detectProtocol(ByteBuf buffer) {
324         if (buffer.readableBytes() < 12) {
325             return ProtocolDetectionResult.needsMoreData();
326         }
327 
328         int idx = buffer.readerIndex();
329 
330         if (match(BINARY_PREFIX, buffer, idx)) {
331             return DETECTION_RESULT_V2;
332         }
333         if (match(TEXT_PREFIX, buffer, idx)) {
334             return DETECTION_RESULT_V1;
335         }
336         return ProtocolDetectionResult.invalid();
337     }
338 
339     private static boolean match(byte[] prefix, ByteBuf buffer, int idx) {
340         for (int i = 0; i < prefix.length; i++) {
341             final byte b = buffer.getByte(idx + i);
342             if (b != prefix[i]) {
343                 return false;
344             }
345         }
346         return true;
347     }
348 
349     /**
350      * HeaderExtractor create a header frame out of the {@link ByteBuf}.
351      */
352     private abstract class HeaderExtractor {
353         /** Header max size */
354         private final int maxHeaderSize;
355 
356         protected HeaderExtractor(int maxHeaderSize) {
357             this.maxHeaderSize = maxHeaderSize;
358         }
359 
360         /**
361          * Create a frame out of the {@link ByteBuf} and return it.
362          *
363          * @param ctx     the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to
364          * @param buffer  the {@link ByteBuf} from which to read data
365          * @return frame  the {@link ByteBuf} which represent the frame or {@code null} if no frame could
366          *                be created
367          * @throws Exception if exceed maxLength
368          */
369         public ByteBuf extract(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
370             final int eoh = findEndOfHeader(buffer);
371             if (!discarding) {
372                 if (eoh >= 0) {
373                     final int length = eoh - buffer.readerIndex();
374                     if (length > maxHeaderSize) {
375                         buffer.readerIndex(eoh + delimiterLength(buffer, eoh));
376                         failOverLimit(ctx, length);
377                         return null;
378                     }
379                     ByteBuf frame = buffer.readSlice(length);
380                     buffer.skipBytes(delimiterLength(buffer, eoh));
381                     return frame;
382                 } else {
383                     final int length = buffer.readableBytes();
384                     if (length > maxHeaderSize) {
385                         discardedBytes = length;
386                         buffer.skipBytes(length);
387                         discarding = true;
388                         if (failFast) {
389                             failOverLimit(ctx, "over " + discardedBytes);
390                         }
391                     }
392                     return null;
393                 }
394             } else {
395                 if (eoh >= 0) {
396                     final int length = discardedBytes + eoh - buffer.readerIndex();
397                     buffer.readerIndex(eoh + delimiterLength(buffer, eoh));
398                     discardedBytes = 0;
399                     discarding = false;
400                     if (!failFast) {
401                         failOverLimit(ctx, "over " + length);
402                     }
403                 } else {
404                     discardedBytes += buffer.readableBytes();
405                     buffer.skipBytes(buffer.readableBytes());
406                 }
407                 return null;
408             }
409         }
410 
411         /**
412          * Find the end of the header from the given {@link ByteBuf},the end may be a CRLF, or the length given by the
413          * header.
414          *
415          * @param buffer the buffer to be searched
416          * @return {@code -1} if can not find the end, otherwise return the buffer index of end
417          */
418         protected abstract int findEndOfHeader(ByteBuf buffer);
419 
420         /**
421          * Get the length of the header delimiter.
422          *
423          * @param buffer the buffer where delimiter is located
424          * @param eoh index of delimiter
425          * @return length of the delimiter
426          */
427         protected abstract int delimiterLength(ByteBuf buffer, int eoh);
428     }
429 
430     private final class LineHeaderExtractor extends HeaderExtractor {
431 
432         LineHeaderExtractor(int maxHeaderSize) {
433             super(maxHeaderSize);
434         }
435 
436         @Override
437         protected int findEndOfHeader(ByteBuf buffer) {
438             return findEndOfLine(buffer);
439         }
440 
441         @Override
442         protected int delimiterLength(ByteBuf buffer, int eoh) {
443             return buffer.getByte(eoh) == '\r' ? 2 : 1;
444         }
445     }
446 
447     private final class StructHeaderExtractor extends HeaderExtractor {
448 
449         StructHeaderExtractor(int maxHeaderSize) {
450             super(maxHeaderSize);
451         }
452 
453         @Override
454         protected int findEndOfHeader(ByteBuf buffer) {
455             return HAProxyMessageDecoder.findEndOfHeader(buffer);
456         }
457 
458         @Override
459         protected int delimiterLength(ByteBuf buffer, int eoh) {
460             return 0;
461         }
462     }
463 }