1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.haproxy;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.handler.codec.ByteToMessageDecoder;
21 import io.netty.handler.codec.LineBasedFrameDecoder;
22 import io.netty.handler.codec.ProtocolDetectionResult;
23 import io.netty.util.CharsetUtil;
24
25 import java.util.List;
26
27
28
29
30
31
32 public class HAProxyMessageDecoder extends ByteToMessageDecoder {
33
34
35
36 private static final int V1_MAX_LENGTH = 108;
37
38
39
40
41 private static final int V2_MAX_LENGTH = 16 + 65535;
42
43
44
45
46 private static final int V2_MIN_LENGTH = 16 + 216;
47
48
49
50
51 private static final int V2_MAX_TLV = 65535 - 216;
52
53
54
55
56 private static final int DELIMITER_LENGTH = 2;
57
58
59
60
61 private static final byte[] BINARY_PREFIX = {
62 (byte) 0x0D,
63 (byte) 0x0A,
64 (byte) 0x0D,
65 (byte) 0x0A,
66 (byte) 0x00,
67 (byte) 0x0D,
68 (byte) 0x0A,
69 (byte) 0x51,
70 (byte) 0x55,
71 (byte) 0x49,
72 (byte) 0x54,
73 (byte) 0x0A
74 };
75
76 private static final byte[] TEXT_PREFIX = {
77 (byte) 'P',
78 (byte) 'R',
79 (byte) 'O',
80 (byte) 'X',
81 (byte) 'Y',
82 };
83
84
85
86
87 private static final int BINARY_PREFIX_LENGTH = BINARY_PREFIX.length;
88
89
90
91
92 private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V1 =
93 ProtocolDetectionResult.detected(HAProxyProtocolVersion.V1);
94
95
96
97
98 private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V2 =
99 ProtocolDetectionResult.detected(HAProxyProtocolVersion.V2);
100
101
102
103
104 private boolean discarding;
105
106
107
108
109 private int discardedBytes;
110
111
112
113
114 private boolean finished;
115
116
117
118
119 private int version = -1;
120
121
122
123
124
125 private final int v2MaxHeaderSize;
126
127
128
129
130 public HAProxyMessageDecoder() {
131 v2MaxHeaderSize = V2_MAX_LENGTH;
132 }
133
134
135
136
137
138
139
140
141
142
143
144 public HAProxyMessageDecoder(int maxTlvSize) {
145 if (maxTlvSize < 1) {
146 v2MaxHeaderSize = V2_MIN_LENGTH;
147 } else if (maxTlvSize > V2_MAX_TLV) {
148 v2MaxHeaderSize = V2_MAX_LENGTH;
149 } else {
150 int calcMax = maxTlvSize + V2_MIN_LENGTH;
151 if (calcMax > V2_MAX_LENGTH) {
152 v2MaxHeaderSize = V2_MAX_LENGTH;
153 } else {
154 v2MaxHeaderSize = calcMax;
155 }
156 }
157 }
158
159
160
161
162
163 private static int findVersion(final ByteBuf buffer) {
164 final int n = buffer.readableBytes();
165
166 if (n < 13) {
167 return -1;
168 }
169
170 int idx = buffer.readerIndex();
171 return match(BINARY_PREFIX, buffer, idx) ? buffer.getByte(idx + BINARY_PREFIX_LENGTH) : 1;
172 }
173
174
175
176
177
178 private static int findEndOfHeader(final ByteBuf buffer) {
179 final int n = buffer.readableBytes();
180
181
182 if (n < 16) {
183 return -1;
184 }
185
186 int offset = buffer.readerIndex() + 14;
187
188
189 int totalHeaderBytes = 16 + buffer.getUnsignedShort(offset);
190
191
192 if (n >= totalHeaderBytes) {
193 return totalHeaderBytes;
194 } else {
195 return -1;
196 }
197 }
198
199
200
201
202
203 private static int findEndOfLine(final ByteBuf buffer) {
204 final int n = buffer.writerIndex();
205 for (int i = buffer.readerIndex(); i < n; i++) {
206 final byte b = buffer.getByte(i);
207 if (b == '\r' && i < n - 1 && buffer.getByte(i + 1) == '\n') {
208 return i;
209 }
210 }
211 return -1;
212 }
213
214 @Override
215 public boolean isSingleDecode() {
216
217
218 return true;
219 }
220
221 @Override
222 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
223 super.channelRead(ctx, msg);
224 if (finished) {
225 ctx.pipeline().remove(this);
226 }
227 }
228
229 @Override
230 protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
231
232 if (version == -1) {
233 if ((version = findVersion(in)) == -1) {
234 return;
235 }
236 }
237
238 ByteBuf decoded;
239
240 if (version == 1) {
241 decoded = decodeLine(ctx, in);
242 } else {
243 decoded = decodeStruct(ctx, in);
244 }
245
246 if (decoded != null) {
247 finished = true;
248 try {
249 if (version == 1) {
250 out.add(HAProxyMessage.decodeHeader(decoded.toString(CharsetUtil.US_ASCII)));
251 } else {
252 out.add(HAProxyMessage.decodeHeader(decoded));
253 }
254 } catch (HAProxyProtocolException e) {
255 fail(ctx, null, e);
256 }
257 }
258 }
259
260
261
262
263
264
265
266
267
268
269 private ByteBuf decodeStruct(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
270 final int eoh = findEndOfHeader(buffer);
271 if (!discarding) {
272 if (eoh >= 0) {
273 final int length = eoh - buffer.readerIndex();
274 if (length > v2MaxHeaderSize) {
275 buffer.readerIndex(eoh);
276 failOverLimit(ctx, length);
277 return null;
278 }
279 return buffer.readSlice(length);
280 } else {
281 final int length = buffer.readableBytes();
282 if (length > v2MaxHeaderSize) {
283 discardedBytes = length;
284 buffer.skipBytes(length);
285 discarding = true;
286 failOverLimit(ctx, "over " + discardedBytes);
287 }
288 return null;
289 }
290 } else {
291 if (eoh >= 0) {
292 buffer.readerIndex(eoh);
293 discardedBytes = 0;
294 discarding = false;
295 } else {
296 discardedBytes = buffer.readableBytes();
297 buffer.skipBytes(discardedBytes);
298 }
299 return null;
300 }
301 }
302
303
304
305
306
307
308
309
310
311
312 private ByteBuf decodeLine(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
313 final int eol = findEndOfLine(buffer);
314 if (!discarding) {
315 if (eol >= 0) {
316 final int length = eol - buffer.readerIndex();
317 if (length > V1_MAX_LENGTH) {
318 buffer.readerIndex(eol + DELIMITER_LENGTH);
319 failOverLimit(ctx, length);
320 return null;
321 }
322 ByteBuf frame = buffer.readSlice(length);
323 buffer.skipBytes(DELIMITER_LENGTH);
324 return frame;
325 } else {
326 final int length = buffer.readableBytes();
327 if (length > V1_MAX_LENGTH) {
328 discardedBytes = length;
329 buffer.skipBytes(length);
330 discarding = true;
331 failOverLimit(ctx, "over " + discardedBytes);
332 }
333 return null;
334 }
335 } else {
336 if (eol >= 0) {
337 final int delimLength = buffer.getByte(eol) == '\r' ? 2 : 1;
338 buffer.readerIndex(eol + delimLength);
339 discardedBytes = 0;
340 discarding = false;
341 } else {
342 discardedBytes = buffer.readableBytes();
343 buffer.skipBytes(discardedBytes);
344 }
345 return null;
346 }
347 }
348
349 private void failOverLimit(final ChannelHandlerContext ctx, int length) {
350 failOverLimit(ctx, String.valueOf(length));
351 }
352
353 private void failOverLimit(final ChannelHandlerContext ctx, String length) {
354 int maxLength = version == 1 ? V1_MAX_LENGTH : v2MaxHeaderSize;
355 fail(ctx, "header length (" + length + ") exceeds the allowed maximum (" + maxLength + ')', null);
356 }
357
358 private void fail(final ChannelHandlerContext ctx, String errMsg, Exception e) {
359 finished = true;
360 ctx.close();
361 HAProxyProtocolException ppex;
362 if (errMsg != null && e != null) {
363 ppex = new HAProxyProtocolException(errMsg, e);
364 } else if (errMsg != null) {
365 ppex = new HAProxyProtocolException(errMsg);
366 } else if (e != null) {
367 ppex = new HAProxyProtocolException(e);
368 } else {
369 ppex = new HAProxyProtocolException();
370 }
371 throw ppex;
372 }
373
374
375
376
377 public static ProtocolDetectionResult<HAProxyProtocolVersion> detectProtocol(ByteBuf buffer) {
378 if (buffer.readableBytes() < 12) {
379 return ProtocolDetectionResult.needsMoreData();
380 }
381
382 int idx = buffer.readerIndex();
383
384 if (match(BINARY_PREFIX, buffer, idx)) {
385 return DETECTION_RESULT_V2;
386 }
387 if (match(TEXT_PREFIX, buffer, idx)) {
388 return DETECTION_RESULT_V1;
389 }
390 return ProtocolDetectionResult.invalid();
391 }
392
393 private static boolean match(byte[] prefix, ByteBuf buffer, int idx) {
394 for (int i = 0; i < prefix.length; i++) {
395 final byte b = buffer.getByte(idx + i);
396 if (b != prefix[i]) {
397 return false;
398 }
399 }
400 return true;
401 }
402 }