1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  package io.netty.handler.codec.haproxy;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily;
20  import io.netty.util.AbstractReferenceCounted;
21  import io.netty.util.ByteProcessor;
22  import io.netty.util.CharsetUtil;
23  import io.netty.util.NetUtil;
24  import io.netty.util.ResourceLeakDetector;
25  import io.netty.util.ResourceLeakDetectorFactory;
26  import io.netty.util.ResourceLeakTracker;
27  import io.netty.util.internal.ObjectUtil;
28  import io.netty.util.internal.StringUtil;
29  
30  import java.util.ArrayList;
31  import java.util.Collections;
32  import java.util.List;
33  
34  
35  
36  
37  public final class HAProxyMessage extends AbstractReferenceCounted {
38  
39      
40      private static final int MAX_NESTING_LEVEL = 128;
41      private static final ResourceLeakDetector<HAProxyMessage> leakDetector =
42              ResourceLeakDetectorFactory.instance().newResourceLeakDetector(HAProxyMessage.class);
43  
44      private final ResourceLeakTracker<HAProxyMessage> leak;
45      private final HAProxyProtocolVersion protocolVersion;
46      private final HAProxyCommand command;
47      private final HAProxyProxiedProtocol proxiedProtocol;
48      private final String sourceAddress;
49      private final String destinationAddress;
50      private final int sourcePort;
51      private final int destinationPort;
52      private final List<HAProxyTLV> tlvs;
53  
54      
55  
56  
57      private HAProxyMessage(
58              HAProxyProtocolVersion protocolVersion, HAProxyCommand command, HAProxyProxiedProtocol proxiedProtocol,
59              String sourceAddress, String destinationAddress, String sourcePort, String destinationPort) {
60          this(
61                  protocolVersion, command, proxiedProtocol,
62                  sourceAddress, destinationAddress, portStringToInt(sourcePort), portStringToInt(destinationPort));
63      }
64  
65      
66  
67  
68  
69  
70  
71  
72  
73  
74  
75      public HAProxyMessage(
76              HAProxyProtocolVersion protocolVersion, HAProxyCommand command, HAProxyProxiedProtocol proxiedProtocol,
77              String sourceAddress, String destinationAddress, int sourcePort, int destinationPort) {
78  
79          this(protocolVersion, command, proxiedProtocol,
80               sourceAddress, destinationAddress, sourcePort, destinationPort, Collections.<HAProxyTLV>emptyList());
81      }
82  
83      
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94      public HAProxyMessage(
95              HAProxyProtocolVersion protocolVersion, HAProxyCommand command, HAProxyProxiedProtocol proxiedProtocol,
96              String sourceAddress, String destinationAddress, int sourcePort, int destinationPort,
97              List<? extends HAProxyTLV> tlvs) {
98  
99          ObjectUtil.checkNotNull(protocolVersion, "protocolVersion");
100         ObjectUtil.checkNotNull(proxiedProtocol, "proxiedProtocol");
101         ObjectUtil.checkNotNull(tlvs, "tlvs");
102         AddressFamily addrFamily = proxiedProtocol.addressFamily();
103 
104         checkAddress(sourceAddress, addrFamily);
105         checkAddress(destinationAddress, addrFamily);
106         checkPort(sourcePort, addrFamily);
107         checkPort(destinationPort, addrFamily);
108 
109         this.protocolVersion = protocolVersion;
110         this.command = command;
111         this.proxiedProtocol = proxiedProtocol;
112         this.sourceAddress = sourceAddress;
113         this.destinationAddress = destinationAddress;
114         this.sourcePort = sourcePort;
115         this.destinationPort = destinationPort;
116         this.tlvs = Collections.unmodifiableList(tlvs);
117 
118         leak = leakDetector.track(this);
119     }
120 
121     
122 
123 
124 
125 
126 
127 
128     static HAProxyMessage decodeHeader(ByteBuf header) {
129         ObjectUtil.checkNotNull(header, "header");
130 
131         if (header.readableBytes() < 16) {
132             throw new HAProxyProtocolException(
133                     "incomplete header: " + header.readableBytes() + " bytes (expected: 16+ bytes)");
134         }
135 
136         
137         header.skipBytes(12);
138         final byte verCmdByte = header.readByte();
139 
140         HAProxyProtocolVersion ver;
141         try {
142             ver = HAProxyProtocolVersion.valueOf(verCmdByte);
143         } catch (IllegalArgumentException e) {
144             throw new HAProxyProtocolException(e);
145         }
146 
147         if (ver != HAProxyProtocolVersion.V2) {
148             throw new HAProxyProtocolException("version 1 unsupported: 0x" + Integer.toHexString(verCmdByte));
149         }
150 
151         HAProxyCommand cmd;
152         try {
153             cmd = HAProxyCommand.valueOf(verCmdByte);
154         } catch (IllegalArgumentException e) {
155             throw new HAProxyProtocolException(e);
156         }
157 
158         if (cmd == HAProxyCommand.LOCAL) {
159             return unknownMsg(HAProxyProtocolVersion.V2, HAProxyCommand.LOCAL);
160         }
161 
162         
163         HAProxyProxiedProtocol protAndFam;
164         try {
165             protAndFam = HAProxyProxiedProtocol.valueOf(header.readByte());
166         } catch (IllegalArgumentException e) {
167             throw new HAProxyProtocolException(e);
168         }
169 
170         if (protAndFam == HAProxyProxiedProtocol.UNKNOWN) {
171             return unknownMsg(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY);
172         }
173 
174         int addressInfoLen = header.readUnsignedShort();
175 
176         String srcAddress;
177         String dstAddress;
178         int addressLen;
179         int srcPort = 0;
180         int dstPort = 0;
181 
182         AddressFamily addressFamily = protAndFam.addressFamily();
183 
184         if (addressFamily == AddressFamily.AF_UNIX) {
185             
186             if (addressInfoLen < 216 || header.readableBytes() < 216) {
187                 throw new HAProxyProtocolException(
188                     "incomplete UNIX socket address information: " +
189                             Math.min(addressInfoLen, header.readableBytes()) + " bytes (expected: 216+ bytes)");
190             }
191             int startIdx = header.readerIndex();
192             int addressEnd = header.indexOf(startIdx, startIdx + 108, (byte) 0); 
193             if (addressEnd == -1) {
194                 addressLen = 108;
195             } else {
196                 addressLen = addressEnd - startIdx;
197             }
198             srcAddress = header.toString(startIdx, addressLen, CharsetUtil.US_ASCII);
199 
200             startIdx += 108;
201 
202             addressEnd = header.indexOf(startIdx, startIdx + 108, (byte) 0); 
203             if (addressEnd == -1) {
204                 addressLen = 108;
205             } else {
206                 addressLen = addressEnd - startIdx;
207             }
208             dstAddress = header.toString(startIdx, addressLen, CharsetUtil.US_ASCII);
209             
210             
211             header.readerIndex(startIdx + 108);
212         } else {
213             if (addressFamily == AddressFamily.AF_IPv4) {
214                 
215                 if (addressInfoLen < 12 || header.readableBytes() < 12) {
216                     throw new HAProxyProtocolException(
217                         "incomplete IPv4 address information: " +
218                                 Math.min(addressInfoLen, header.readableBytes()) + " bytes (expected: 12+ bytes)");
219                 }
220                 addressLen = 4;
221             } else if (addressFamily == AddressFamily.AF_IPv6) {
222                 
223                 if (addressInfoLen < 36 || header.readableBytes() < 36) {
224                     throw new HAProxyProtocolException(
225                         "incomplete IPv6 address information: " +
226                                 Math.min(addressInfoLen, header.readableBytes()) + " bytes (expected: 36+ bytes)");
227                 }
228                 addressLen = 16;
229             } else {
230                 throw new HAProxyProtocolException(
231                     "unable to parse address information (unknown address family: " + addressFamily + ')');
232             }
233 
234             
235             srcAddress = ipBytesToString(header, addressLen);
236             dstAddress = ipBytesToString(header, addressLen);
237             srcPort = header.readUnsignedShort();
238             dstPort = header.readUnsignedShort();
239         }
240 
241         final List<HAProxyTLV> tlvs = readTlvs(header);
242 
243         return new HAProxyMessage(ver, cmd, protAndFam, srcAddress, dstAddress, srcPort, dstPort, tlvs);
244     }
245 
246     private static List<HAProxyTLV> readTlvs(final ByteBuf header) {
247         HAProxyTLV haProxyTLV = readNextTLV(header, 0);
248         if (haProxyTLV == null) {
249             return Collections.emptyList();
250         }
251         
252         List<HAProxyTLV> haProxyTLVs = new ArrayList<HAProxyTLV>(4);
253 
254         do {
255             haProxyTLVs.add(haProxyTLV);
256             if (haProxyTLV instanceof HAProxySSLTLV) {
257                 haProxyTLVs.addAll(((HAProxySSLTLV) haProxyTLV).encapsulatedTLVs());
258             }
259         } while ((haProxyTLV = readNextTLV(header, 0)) != null);
260         return haProxyTLVs;
261     }
262 
263     private static HAProxyTLV readNextTLV(final ByteBuf header, int nestingLevel) {
264         if (nestingLevel > MAX_NESTING_LEVEL) {
265             throw new HAProxyProtocolException(
266                     "Maximum TLV nesting level reached: " + nestingLevel + " (expected: < " + MAX_NESTING_LEVEL + ')');
267         }
268         
269         if (header.readableBytes() < 4) {
270             return null;
271         }
272 
273         final byte typeAsByte = header.readByte();
274         final HAProxyTLV.Type type = HAProxyTLV.Type.typeForByteValue(typeAsByte);
275 
276         final int length = header.readUnsignedShort();
277         switch (type) {
278         case PP2_TYPE_SSL:
279             final ByteBuf rawContent = header.retainedSlice(header.readerIndex(), length);
280             final ByteBuf byteBuf = header.readSlice(length);
281             final byte client = byteBuf.readByte();
282             final int verify = byteBuf.readInt();
283 
284             if (byteBuf.readableBytes() >= 4) {
285 
286                 final List<HAProxyTLV> encapsulatedTlvs = new ArrayList<HAProxyTLV>(4);
287                 do {
288                     final HAProxyTLV haProxyTLV = readNextTLV(byteBuf, nestingLevel + 1);
289                     if (haProxyTLV == null) {
290                         break;
291                     }
292                     encapsulatedTlvs.add(haProxyTLV);
293                 } while (byteBuf.readableBytes() >= 4);
294 
295                 return new HAProxySSLTLV(verify, client, encapsulatedTlvs, rawContent);
296             }
297             return new HAProxySSLTLV(verify, client, Collections.<HAProxyTLV>emptyList(), rawContent);
298         
299         case PP2_TYPE_ALPN:
300         case PP2_TYPE_AUTHORITY:
301         case PP2_TYPE_SSL_VERSION:
302         case PP2_TYPE_SSL_CN:
303         case PP2_TYPE_NETNS:
304         case OTHER:
305             return new HAProxyTLV(type, typeAsByte, header.readRetainedSlice(length));
306         default:
307             return null;
308         }
309     }
310 
311     
312 
313 
314 
315 
316 
317 
318     static HAProxyMessage decodeHeader(String header) {
319         if (header == null) {
320             throw new HAProxyProtocolException("header");
321         }
322 
323         String[] parts = header.split(" ");
324         int numParts = parts.length;
325 
326         if (numParts < 2) {
327             throw new HAProxyProtocolException(
328                     "invalid header: " + header + " (expected: 'PROXY' and proxied protocol values)");
329         }
330 
331         if (!"PROXY".equals(parts[0])) {
332             throw new HAProxyProtocolException("unknown identifier: " + parts[0]);
333         }
334 
335         HAProxyProxiedProtocol protAndFam;
336         try {
337             protAndFam = HAProxyProxiedProtocol.valueOf(parts[1]);
338         } catch (IllegalArgumentException e) {
339             throw new HAProxyProtocolException(e);
340         }
341 
342         if (protAndFam != HAProxyProxiedProtocol.TCP4 &&
343                 protAndFam != HAProxyProxiedProtocol.TCP6 &&
344                 protAndFam != HAProxyProxiedProtocol.UNKNOWN) {
345             throw new HAProxyProtocolException("unsupported v1 proxied protocol: " + parts[1]);
346         }
347 
348         if (protAndFam == HAProxyProxiedProtocol.UNKNOWN) {
349             return unknownMsg(HAProxyProtocolVersion.V1, HAProxyCommand.PROXY);
350         }
351 
352         if (numParts != 6) {
353             throw new HAProxyProtocolException("invalid TCP4/6 header: " + header + " (expected: 6 parts)");
354         }
355 
356         try {
357             return new HAProxyMessage(
358                     HAProxyProtocolVersion.V1, HAProxyCommand.PROXY,
359                     protAndFam, parts[2], parts[3], parts[4], parts[5]);
360         } catch (RuntimeException e) {
361             throw new HAProxyProtocolException("invalid HAProxy message", e);
362         }
363     }
364 
365     
366 
367 
368 
369     private static HAProxyMessage unknownMsg(HAProxyProtocolVersion version, HAProxyCommand command) {
370         return new HAProxyMessage(version, command, HAProxyProxiedProtocol.UNKNOWN, null, null, 0, 0);
371     }
372 
373     
374 
375 
376 
377 
378 
379 
380     private static String ipBytesToString(ByteBuf header, int addressLen) {
381         StringBuilder sb = new StringBuilder();
382         final int ipv4Len = 4;
383         final int ipv6Len = 8;
384         if (addressLen == ipv4Len) {
385             for (int i = 0; i < ipv4Len; i++) {
386                 sb.append(header.readByte() & 0xff);
387                 sb.append('.');
388             }
389         } else {
390             for (int i = 0; i < ipv6Len; i++) {
391                 sb.append(Integer.toHexString(header.readUnsignedShort()));
392                 sb.append(':');
393             }
394         }
395         sb.setLength(sb.length() - 1);
396         return sb.toString();
397     }
398 
399     
400 
401 
402 
403 
404 
405 
406     private static int portStringToInt(String value) {
407         int port;
408         try {
409             port = Integer.parseInt(value);
410         } catch (NumberFormatException e) {
411             throw new IllegalArgumentException("invalid port: " + value, e);
412         }
413 
414         if (port <= 0 || port > 65535) {
415             throw new IllegalArgumentException("invalid port: " + value + " (expected: 1 ~ 65535)");
416         }
417 
418         return port;
419     }
420 
421     
422 
423 
424 
425 
426 
427 
428     private static void checkAddress(String address, AddressFamily addrFamily) {
429         ObjectUtil.checkNotNull(addrFamily, "addrFamily");
430 
431         switch (addrFamily) {
432             case AF_UNSPEC:
433                 if (address != null) {
434                     throw new IllegalArgumentException("unable to validate an AF_UNSPEC address: " + address);
435                 }
436                 return;
437             case AF_UNIX:
438                 ObjectUtil.checkNotNull(address, "address");
439                 if (address.getBytes(CharsetUtil.US_ASCII).length > 108) {
440                     throw new IllegalArgumentException("invalid AF_UNIX address: " + address);
441                 }
442                 return;
443         }
444 
445         ObjectUtil.checkNotNull(address, "address");
446 
447         switch (addrFamily) {
448             case AF_IPv4:
449                 if (!NetUtil.isValidIpV4Address(address)) {
450                     throw new IllegalArgumentException("invalid IPv4 address: " + address);
451                 }
452                 break;
453             case AF_IPv6:
454                 if (!NetUtil.isValidIpV6Address(address)) {
455                     throw new IllegalArgumentException("invalid IPv6 address: " + address);
456                 }
457                 break;
458             default:
459                 throw new IllegalArgumentException("unexpected addrFamily: " + addrFamily);
460         }
461     }
462 
463     
464 
465 
466 
467 
468 
469     private static void checkPort(int port, AddressFamily addrFamily) {
470         switch (addrFamily) {
471         case AF_IPv6:
472         case AF_IPv4:
473             if (port < 0 || port > 65535) {
474                 throw new IllegalArgumentException("invalid port: " + port + " (expected: 0 ~ 65535)");
475             }
476             break;
477         case AF_UNIX:
478         case AF_UNSPEC:
479             if (port != 0) {
480                 throw new IllegalArgumentException("port cannot be specified with addrFamily: " + addrFamily);
481             }
482             break;
483         default:
484             throw new IllegalArgumentException("unexpected addrFamily: " + addrFamily);
485         }
486     }
487 
488     
489 
490 
491     public HAProxyProtocolVersion protocolVersion() {
492         return protocolVersion;
493     }
494 
495     
496 
497 
498     public HAProxyCommand command() {
499         return command;
500     }
501 
502     
503 
504 
505     public HAProxyProxiedProtocol proxiedProtocol() {
506         return proxiedProtocol;
507     }
508 
509     
510 
511 
512 
513     public String sourceAddress() {
514         return sourceAddress;
515     }
516 
517     
518 
519 
520     public String destinationAddress() {
521         return destinationAddress;
522     }
523 
524     
525 
526 
527     public int sourcePort() {
528         return sourcePort;
529     }
530 
531     
532 
533 
534     public int destinationPort() {
535         return destinationPort;
536     }
537 
538     
539 
540 
541 
542 
543     public List<HAProxyTLV> tlvs() {
544         return tlvs;
545     }
546 
547     int tlvNumBytes() {
548         int tlvNumBytes = 0;
549         for (int i = 0; i < tlvs.size(); i++) {
550             tlvNumBytes += tlvs.get(i).totalNumBytes();
551         }
552         return tlvNumBytes;
553     }
554 
555     @Override
556     public HAProxyMessage touch() {
557         tryRecord();
558         return (HAProxyMessage) super.touch();
559     }
560 
561     @Override
562     public HAProxyMessage touch(Object hint) {
563         if (leak != null) {
564             leak.record(hint);
565         }
566         return this;
567     }
568 
569     @Override
570     public HAProxyMessage retain() {
571         tryRecord();
572         return (HAProxyMessage) super.retain();
573     }
574 
575     @Override
576     public HAProxyMessage retain(int increment) {
577         tryRecord();
578         return (HAProxyMessage) super.retain(increment);
579     }
580 
581     @Override
582     public boolean release() {
583         tryRecord();
584         return super.release();
585     }
586 
587     @Override
588     public boolean release(int decrement) {
589         tryRecord();
590         return super.release(decrement);
591     }
592 
593     private void tryRecord() {
594         if (leak != null) {
595             leak.record();
596         }
597     }
598 
599     @Override
600     protected void deallocate() {
601         try {
602             for (HAProxyTLV tlv : tlvs) {
603                 tlv.release();
604             }
605         } finally {
606             final ResourceLeakTracker<HAProxyMessage> leak = this.leak;
607             if (leak != null) {
608                 boolean closed = leak.close(this);
609                 assert closed;
610             }
611         }
612     }
613 
614     @Override
615     public String toString() {
616         StringBuilder sb = new StringBuilder(256)
617                 .append(StringUtil.simpleClassName(this))
618                 .append("(protocolVersion: ").append(protocolVersion)
619                 .append(", command: ").append(command)
620                 .append(", proxiedProtocol: ").append(proxiedProtocol)
621                 .append(", sourceAddress: ").append(sourceAddress)
622                 .append(", destinationAddress: ").append(destinationAddress)
623                 .append(", sourcePort: ").append(sourcePort)
624                 .append(", destinationPort: ").append(destinationPort)
625                 .append(", tlvs: [");
626         if (!tlvs.isEmpty()) {
627             for (HAProxyTLV tlv: tlvs) {
628                 sb.append(tlv).append(", ");
629             }
630             sb.setLength(sb.length() - 2);
631         }
632         sb.append("])");
633         return sb.toString();
634     }
635 }