1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.util.internal.StringUtil;
19
20 import java.nio.ByteBuffer;
21 import java.util.LinkedHashSet;
22 import java.util.List;
23 import java.util.function.BiConsumer;
24 import java.util.function.BiFunction;
25 import javax.net.ssl.SSLEngine;
26 import javax.net.ssl.SSLEngineResult;
27 import javax.net.ssl.SSLException;
28
29 import static io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelectionListener;
30 import static io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelector;
31 import static io.netty.handler.ssl.SslUtils.toSSLHandshakeException;
32
33 class JdkAlpnSslEngine extends JdkSslEngine {
34 private final ProtocolSelectionListener selectionListener;
35 private final AlpnSelector alpnSelector;
36
37 final class AlpnSelector implements BiFunction<SSLEngine, List<String>, String> {
38 private final ProtocolSelector selector;
39 private boolean called;
40
41 AlpnSelector(ProtocolSelector selector) {
42 this.selector = selector;
43 }
44
45 @Override
46 public String apply(SSLEngine sslEngine, List<String> strings) {
47 assert !called;
48 called = true;
49
50 try {
51 String selected = selector.select(strings);
52 return selected == null ? StringUtil.EMPTY_STRING : selected;
53 } catch (Exception cause) {
54
55
56
57
58 return null;
59 }
60 }
61
62 void checkUnsupported() {
63 if (called) {
64
65
66
67
68 return;
69 }
70 String protocol = getApplicationProtocol();
71 assert protocol != null;
72
73 if (protocol.isEmpty()) {
74
75 selector.unsupported();
76 }
77 }
78 }
79
80 JdkAlpnSslEngine(SSLEngine engine,
81 @SuppressWarnings("deprecation") JdkApplicationProtocolNegotiator applicationNegotiator,
82 boolean isServer, BiConsumer<SSLEngine, AlpnSelector> setHandshakeApplicationProtocolSelector,
83 BiConsumer<SSLEngine, List<String>> setApplicationProtocols) {
84 super(engine);
85 if (isServer) {
86 selectionListener = null;
87 alpnSelector = new AlpnSelector(applicationNegotiator.protocolSelectorFactory().
88 newSelector(this, new LinkedHashSet<String>(applicationNegotiator.protocols())));
89 setHandshakeApplicationProtocolSelector.accept(engine, alpnSelector);
90 } else {
91 selectionListener = applicationNegotiator.protocolListenerFactory()
92 .newListener(this, applicationNegotiator.protocols());
93 alpnSelector = null;
94 setApplicationProtocols.accept(engine, applicationNegotiator.protocols());
95 }
96 }
97
98 JdkAlpnSslEngine(SSLEngine engine,
99 @SuppressWarnings("deprecation") JdkApplicationProtocolNegotiator applicationNegotiator,
100 boolean isServer) {
101 this(engine, applicationNegotiator, isServer,
102 new BiConsumer<SSLEngine, AlpnSelector>() {
103 @Override
104 public void accept(SSLEngine e, AlpnSelector s) {
105 JdkAlpnSslUtils.setHandshakeApplicationProtocolSelector(e, s);
106 }
107 },
108 new BiConsumer<SSLEngine, List<String>>() {
109 @Override
110 public void accept(SSLEngine e, List<String> p) {
111 JdkAlpnSslUtils.setApplicationProtocols(e, p);
112 }
113 });
114 }
115
116 private SSLEngineResult verifyProtocolSelection(SSLEngineResult result) throws SSLException {
117 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
118 if (alpnSelector == null) {
119
120 try {
121 String protocol = getApplicationProtocol();
122 assert protocol != null;
123 if (protocol.isEmpty()) {
124
125
126
127
128 selectionListener.unsupported();
129 } else {
130 selectionListener.selected(protocol);
131 }
132 } catch (Throwable e) {
133 throw toSSLHandshakeException(e);
134 }
135 } else {
136 assert selectionListener == null;
137 alpnSelector.checkUnsupported();
138 }
139 }
140 return result;
141 }
142
143 @Override
144 public SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) throws SSLException {
145 return verifyProtocolSelection(super.wrap(src, dst));
146 }
147
148 @Override
149 public SSLEngineResult wrap(ByteBuffer[] srcs, ByteBuffer dst) throws SSLException {
150 return verifyProtocolSelection(super.wrap(srcs, dst));
151 }
152
153 @Override
154 public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int len, ByteBuffer dst) throws SSLException {
155 return verifyProtocolSelection(super.wrap(srcs, offset, len, dst));
156 }
157
158 @Override
159 public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer dst) throws SSLException {
160 return verifyProtocolSelection(super.unwrap(src, dst));
161 }
162
163 @Override
164 public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts) throws SSLException {
165 return verifyProtocolSelection(super.unwrap(src, dsts));
166 }
167
168 @Override
169 public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dst, int offset, int len) throws SSLException {
170 return verifyProtocolSelection(super.unwrap(src, dst, offset, len));
171 }
172
173 @Override
174 void setNegotiatedApplicationProtocol(String applicationProtocol) {
175
176 }
177
178 @Override
179 public String getNegotiatedApplicationProtocol() {
180 String protocol = getApplicationProtocol();
181 if (protocol != null) {
182 return protocol.isEmpty() ? null : protocol;
183 }
184 return null;
185 }
186
187
188
189 @SuppressWarnings("override")
190 public String getApplicationProtocol() {
191 return JdkAlpnSslUtils.getApplicationProtocol(getWrappedEngine());
192 }
193
194 @SuppressWarnings("override")
195 public String getHandshakeApplicationProtocol() {
196 return JdkAlpnSslUtils.getHandshakeApplicationProtocol(getWrappedEngine());
197 }
198
199 @SuppressWarnings("override")
200 public void setHandshakeApplicationProtocolSelector(BiFunction<SSLEngine, List<String>, String> selector) {
201 JdkAlpnSslUtils.setHandshakeApplicationProtocolSelector(getWrappedEngine(), selector);
202 }
203
204 @SuppressWarnings("override")
205 public BiFunction<SSLEngine, List<String>, String> getHandshakeApplicationProtocolSelector() {
206 return JdkAlpnSslUtils.getHandshakeApplicationProtocolSelector(getWrappedEngine());
207 }
208 }