1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package io.netty5.handler.ssl.util;
18
19 import io.netty5.buffer.BufferUtil;
20 import io.netty5.util.concurrent.FastThreadLocal;
21 import io.netty5.util.internal.EmptyArrays;
22 import io.netty5.util.internal.StringUtil;
23
24 import javax.net.ssl.ManagerFactoryParameters;
25 import javax.net.ssl.TrustManager;
26 import javax.net.ssl.TrustManagerFactory;
27 import javax.net.ssl.X509TrustManager;
28 import java.security.KeyStore;
29 import java.security.MessageDigest;
30 import java.security.NoSuchAlgorithmException;
31 import java.security.cert.CertificateEncodingException;
32 import java.security.cert.CertificateException;
33 import java.security.cert.X509Certificate;
34 import java.util.ArrayList;
35 import java.util.Arrays;
36 import java.util.List;
37 import java.util.regex.Pattern;
38
39 import static java.util.Objects.requireNonNull;
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82 public final class FingerprintTrustManagerFactory extends SimpleTrustManagerFactory {
83
84 private static final Pattern FINGERPRINT_PATTERN = Pattern.compile("^[0-9a-fA-F:]+$");
85 private static final Pattern FINGERPRINT_STRIP_PATTERN = Pattern.compile(":");
86
87
88
89
90
91
92
93 public static FingerprintTrustManagerFactoryBuilder builder(String algorithm) {
94 return new FingerprintTrustManagerFactoryBuilder(algorithm);
95 }
96
97 private final FastThreadLocal<MessageDigest> tlmd;
98
99 private final TrustManager tm = new X509TrustManager() {
100
101 @Override
102 public void checkClientTrusted(X509Certificate[] chain, String s) throws CertificateException {
103 checkTrusted("client", chain);
104 }
105
106 @Override
107 public void checkServerTrusted(X509Certificate[] chain, String s) throws CertificateException {
108 checkTrusted("server", chain);
109 }
110
111 private void checkTrusted(String type, X509Certificate[] chain) throws CertificateException {
112 X509Certificate cert = chain[0];
113 byte[] fingerprint = fingerprint(cert);
114 boolean found = false;
115 for (byte[] allowedFingerprint: fingerprints) {
116 if (Arrays.equals(fingerprint, allowedFingerprint)) {
117 found = true;
118 break;
119 }
120 }
121
122 if (!found) {
123 throw new CertificateException(
124 type + " certificate with unknown fingerprint: " + cert.getSubjectDN());
125 }
126 }
127
128 private byte[] fingerprint(X509Certificate cert) throws CertificateEncodingException {
129 MessageDigest md = tlmd.get();
130 md.reset();
131 return md.digest(cert.getEncoded());
132 }
133
134 @Override
135 public X509Certificate[] getAcceptedIssuers() {
136 return EmptyArrays.EMPTY_X509_CERTIFICATES;
137 }
138 };
139
140 private final byte[][] fingerprints;
141
142
143
144
145
146
147
148 FingerprintTrustManagerFactory(final String algorithm, byte[][] fingerprints) {
149 requireNonNull(algorithm, "algorithm");
150 requireNonNull(fingerprints, "fingerprints");
151
152 if (fingerprints.length == 0) {
153 throw new IllegalArgumentException("No fingerprints provided");
154 }
155
156
157 final MessageDigest md;
158 try {
159 md = MessageDigest.getInstance(algorithm);
160 } catch (NoSuchAlgorithmException e) {
161 throw new IllegalArgumentException(
162 String.format("Unsupported hash algorithm: %s", algorithm), e);
163 }
164
165 int hashLength = md.getDigestLength();
166 List<byte[]> list = new ArrayList<>(fingerprints.length);
167 for (byte[] f: fingerprints) {
168 if (f == null) {
169 break;
170 }
171 if (f.length != hashLength) {
172 throw new IllegalArgumentException(
173 String.format("malformed fingerprint (length is %d but expected %d): %s",
174 f.length, hashLength, BufferUtil.hexDump(f)));
175 }
176 list.add(f.clone());
177 }
178
179 tlmd = new FastThreadLocal<>() {
180
181 @Override
182 protected MessageDigest initialValue() {
183 try {
184 return MessageDigest.getInstance(algorithm);
185 } catch (NoSuchAlgorithmException e) {
186 throw new IllegalArgumentException(
187 String.format("Unsupported hash algorithm: %s", algorithm), e);
188 }
189 }
190 };
191
192 this.fingerprints = list.toArray(EmptyArrays.EMPTY_BYTES_BYTES);
193 }
194
195 static byte[][] toFingerprintArray(Iterable<String> fingerprints) {
196 requireNonNull(fingerprints, "fingerprints");
197
198 List<byte[]> list = new ArrayList<>();
199 for (String f: fingerprints) {
200 if (f == null) {
201 break;
202 }
203
204 if (!FINGERPRINT_PATTERN.matcher(f).matches()) {
205 throw new IllegalArgumentException("malformed fingerprint: " + f);
206 }
207 f = FINGERPRINT_STRIP_PATTERN.matcher(f).replaceAll("");
208
209 list.add(StringUtil.decodeHexDump(f));
210 }
211
212 return list.toArray(EmptyArrays.EMPTY_BYTES_BYTES);
213 }
214
215 @Override
216 protected void engineInit(KeyStore keyStore) throws Exception { }
217
218 @Override
219 protected void engineInit(ManagerFactoryParameters managerFactoryParameters) throws Exception { }
220
221 @Override
222 protected TrustManager[] engineGetTrustManagers() {
223 return new TrustManager[] { tm };
224 }
225 }