1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.ssl;
17
18 import io.netty5.buffer.BufferUtil;
19 import io.netty5.buffer.api.Buffer;
20 import io.netty5.util.Resource;
21 import io.netty5.channel.ChannelHandlerContext;
22 import io.netty5.handler.codec.ByteToMessageDecoder;
23 import io.netty5.handler.codec.DecoderException;
24 import io.netty5.util.concurrent.Future;
25 import io.netty5.util.internal.PlatformDependent;
26 import io.netty5.util.internal.logging.InternalLogger;
27 import io.netty5.util.internal.logging.InternalLoggerFactory;
28
29
30
31
32 public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder {
33
34 private static final InternalLogger logger =
35 InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
36
37 private boolean handshakeFailed;
38 private boolean suppressRead;
39 private boolean readPending;
40 private Buffer handshakeBuffer;
41
42 @Override
43 protected void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {
44
45
46 if (!suppressRead && !handshakeFailed) {
47 try {
48 int readerIndex = in.readerOffset();
49 int readableBytes = in.readableBytes();
50 int handshakeLength = -1;
51
52
53 while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
54 final int contentType = in.getUnsignedByte(readerIndex);
55 switch (contentType) {
56 case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
57
58 case SslUtils.SSL_CONTENT_TYPE_ALERT:
59 final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
60
61
62 if (len == SslUtils.NOT_ENCRYPTED) {
63 handshakeFailed = true;
64 NotSslRecordException e = new NotSslRecordException(
65 "not an SSL/TLS record: " + BufferUtil.hexDump(in));
66 in.skipReadableBytes(in.readableBytes());
67 ctx.fireChannelInboundEvent(new SniCompletionEvent(e));
68 ctx.fireChannelInboundEvent(new SslHandshakeCompletionEvent(e));
69 throw e;
70 }
71 if (len == SslUtils.NOT_ENOUGH_DATA) {
72
73 return;
74 }
75
76 select(ctx, null);
77 return;
78 case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
79 final int majorVersion = in.getUnsignedByte(readerIndex + 1);
80
81 if (majorVersion == 3) {
82 int packetLength = in.getUnsignedShort(readerIndex + 3) +
83 SslUtils.SSL_RECORD_HEADER_LENGTH;
84
85 if (readableBytes < packetLength) {
86
87 return;
88 }
89 if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
90 select(ctx, null);
91 return;
92 }
93
94 final int endOffset = readerIndex + packetLength;
95
96
97 if (handshakeLength == -1) {
98 if (readerIndex + 4 > endOffset) {
99
100 return;
101 }
102
103 final int handshakeType = in.getUnsignedByte(readerIndex +
104 SslUtils.SSL_RECORD_HEADER_LENGTH);
105
106
107
108 if (handshakeType != 1) {
109 select(ctx, null);
110 return;
111 }
112
113
114
115 handshakeLength = in.getUnsignedMedium(readerIndex +
116 SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
117
118
119 readerIndex += 4;
120 packetLength -= 4;
121
122 if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
123
124
125 readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
126 in.readerOffset(readerIndex);
127 select(ctx, in.readSplit(handshakeLength));
128 return;
129 } else {
130 if (handshakeBuffer == null) {
131 handshakeBuffer = ctx.bufferAllocator().allocate(handshakeLength);
132 } else {
133
134 handshakeBuffer.resetOffsets();
135 }
136 }
137 }
138
139
140 int hsLen = packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH;
141 in.copyInto(readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
142 handshakeBuffer, handshakeBuffer.writerOffset(), hsLen);
143 handshakeBuffer.skipWritableBytes(hsLen);
144 readerIndex += packetLength;
145 readableBytes -= packetLength;
146 if (handshakeLength <= handshakeBuffer.readableBytes()) {
147 Buffer clientHello = handshakeBuffer.readerOffset(0).writerOffset(handshakeLength);
148 handshakeBuffer = null;
149
150 select(ctx, clientHello);
151 return;
152 }
153 break;
154 }
155
156 default:
157
158 select(ctx, null);
159 return;
160 }
161 }
162 } catch (NotSslRecordException e) {
163
164 throw e;
165 } catch (Exception e) {
166
167 if (logger.isDebugEnabled()) {
168 logger.debug("Unexpected client hello packet: " + BufferUtil.hexDump(in), e);
169 }
170 select(ctx, null);
171 }
172 }
173 }
174
175 private void releaseHandshakeBuffer() {
176 Resource.dispose(handshakeBuffer);
177 handshakeBuffer = null;
178 }
179
180 private void select(final ChannelHandlerContext ctx, Buffer clientHello) {
181 final Future<T> future;
182 try {
183 future = lookup(ctx, clientHello);
184 if (future.isDone()) {
185 Resource.dispose(clientHello);
186 onLookupComplete(ctx, future);
187 } else {
188 suppressRead = true;
189 future.addListener(f -> {
190 Resource.dispose(clientHello);
191 try {
192 suppressRead = false;
193 try {
194 onLookupComplete(ctx, f);
195 } catch (DecoderException err) {
196 ctx.fireChannelExceptionCaught(err);
197 } catch (Exception cause) {
198 ctx.fireChannelExceptionCaught(new DecoderException(cause));
199 } catch (Throwable cause) {
200 ctx.fireChannelExceptionCaught(cause);
201 }
202 } finally {
203 if (readPending) {
204 readPending = false;
205 ctx.read();
206 }
207 }
208 });
209 }
210 } catch (Throwable cause) {
211 PlatformDependent.throwException(cause);
212 }
213 }
214
215 @Override
216 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
217 releaseHandshakeBuffer();
218
219 super.handlerRemoved0(ctx);
220 }
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246 protected abstract Future<T> lookup(ChannelHandlerContext ctx, Buffer clientHello) throws Exception;
247
248
249
250
251
252
253 protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<? extends T> future) throws Exception;
254
255 @Override
256 public void read(ChannelHandlerContext ctx) {
257 if (suppressRead) {
258 readPending = true;
259 } else {
260 ctx.read();
261 }
262 }
263 }