1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.ssl;
17
18 import io.netty5.channel.Channel;
19 import io.netty5.channel.ChannelHandler;
20 import io.netty5.channel.ChannelHandlerContext;
21 import io.netty5.channel.ChannelInitializer;
22 import io.netty5.channel.ChannelPipeline;
23 import io.netty5.channel.ChannelShutdownDirection;
24 import io.netty5.handler.codec.DecoderException;
25 import io.netty5.util.internal.RecyclableArrayList;
26 import io.netty5.util.internal.logging.InternalLogger;
27 import io.netty5.util.internal.logging.InternalLoggerFactory;
28
29 import javax.net.ssl.SSLException;
30
31 import static java.util.Objects.requireNonNull;
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68 public abstract class ApplicationProtocolNegotiationHandler implements ChannelHandler {
69
70 private static final InternalLogger logger =
71 InternalLoggerFactory.getInstance(ApplicationProtocolNegotiationHandler.class);
72
73 private final String fallbackProtocol;
74 private final RecyclableArrayList bufferedMessages = RecyclableArrayList.newInstance();
75 private ChannelHandlerContext ctx;
76 private boolean sslHandlerChecked;
77
78
79
80
81
82
83
84 protected ApplicationProtocolNegotiationHandler(String fallbackProtocol) {
85 this.fallbackProtocol = requireNonNull(fallbackProtocol, "fallbackProtocol");
86 }
87
88 @Override
89 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
90 this.ctx = ctx;
91 }
92
93 @Override
94 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
95 fireBufferedMessages();
96 bufferedMessages.recycle();
97 }
98
99 @Override
100 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
101
102 bufferedMessages.add(msg);
103 if (!sslHandlerChecked) {
104 sslHandlerChecked = true;
105 if (ctx.pipeline().get(SslHandler.class) == null) {
106
107
108 removeSelfIfPresent(ctx);
109 }
110 }
111 }
112
113
114
115
116 private void fireBufferedMessages() {
117 if (!bufferedMessages.isEmpty()) {
118 for (int i = 0; i < bufferedMessages.size(); i++) {
119 ctx.fireChannelRead(bufferedMessages.get(i));
120 }
121 ctx.fireChannelReadComplete();
122 bufferedMessages.clear();
123 }
124 }
125
126 @Override
127 public void channelInboundEvent(ChannelHandlerContext ctx, Object evt) throws Exception {
128 if (evt instanceof SslHandshakeCompletionEvent) {
129
130 ctx.fireChannelInboundEvent(evt);
131
132 SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
133 try {
134 if (handshakeEvent.isSuccess()) {
135 String protocol = handshakeEvent.applicationProtocol();
136 configurePipeline(ctx, protocol != null ? protocol : fallbackProtocol);
137 } else {
138
139
140
141
142
143 }
144 } catch (Throwable cause) {
145 channelExceptionCaught(ctx, cause);
146 } finally {
147
148 if (handshakeEvent.isSuccess()) {
149 removeSelfIfPresent(ctx);
150 }
151 }
152 } else {
153 ctx.fireChannelInboundEvent(evt);
154 }
155 }
156
157 @Override
158 public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) {
159 if (direction == ChannelShutdownDirection.Inbound) {
160 fireBufferedMessages();
161 }
162 ctx.fireChannelShutdown(direction);
163 }
164
165 @Override
166 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
167 fireBufferedMessages();
168 ctx.fireChannelInactive();
169 }
170
171 private void removeSelfIfPresent(ChannelHandlerContext ctx) {
172 ChannelPipeline pipeline = ctx.pipeline();
173 if (!ctx.isRemoved()) {
174 pipeline.remove(this);
175 }
176 }
177
178
179
180
181
182
183
184
185
186 protected abstract void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception;
187
188
189
190
191 protected void handshakeFailure(ChannelHandlerContext ctx, Throwable cause) throws Exception {
192 logger.warn("{} TLS handshake failed:", ctx.channel(), cause);
193 ctx.close();
194 }
195
196 @Override
197 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
198 Throwable wrapped;
199 if (cause instanceof DecoderException && (wrapped = cause.getCause()) instanceof SSLException) {
200 try {
201 handshakeFailure(ctx, wrapped);
202 return;
203 } finally {
204 removeSelfIfPresent(ctx);
205 }
206 }
207 logger.warn("{} Failed to select the application-level protocol:", ctx.channel(), cause);
208 ctx.fireChannelExceptionCaught(cause);
209 ctx.close();
210 }
211 }