1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package io.netty.handler.proxy;
18
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelDuplexHandler;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.channel.PendingWriteQueue;
25 import io.netty.util.ReferenceCountUtil;
26 import io.netty.util.concurrent.DefaultPromise;
27 import io.netty.util.concurrent.EventExecutor;
28 import io.netty.util.concurrent.Future;
29 import io.netty.util.internal.ObjectUtil;
30 import io.netty.util.internal.logging.InternalLogger;
31 import io.netty.util.internal.logging.InternalLoggerFactory;
32
33 import java.net.SocketAddress;
34 import java.nio.channels.ConnectionPendingException;
35 import java.util.concurrent.TimeUnit;
36
37
38
39
40 public abstract class ProxyHandler extends ChannelDuplexHandler {
41
42 private static final InternalLogger logger = InternalLoggerFactory.getInstance(ProxyHandler.class);
43
44
45
46
47 private static final long DEFAULT_CONNECT_TIMEOUT_MILLIS = 10000;
48
49
50
51
52 static final String AUTH_NONE = "none";
53
54 private final SocketAddress proxyAddress;
55 private volatile SocketAddress destinationAddress;
56 private volatile long connectTimeoutMillis = DEFAULT_CONNECT_TIMEOUT_MILLIS;
57
58 private volatile ChannelHandlerContext ctx;
59 private PendingWriteQueue pendingWrites;
60 private boolean finished;
61 private boolean suppressChannelReadComplete;
62 private boolean flushedPrematurely;
63 private final LazyChannelPromise connectPromise = new LazyChannelPromise();
64 private Future<?> connectTimeoutFuture;
65 private final ChannelFutureListener writeListener = future -> {
66 if (!future.isSuccess()) {
67 setConnectFailure(future.cause());
68 }
69 };
70
71 protected ProxyHandler(SocketAddress proxyAddress) {
72 this.proxyAddress = ObjectUtil.checkNotNull(proxyAddress, "proxyAddress");
73 }
74
75
76
77
78 public abstract String protocol();
79
80
81
82
83 public abstract String authScheme();
84
85
86
87
88 @SuppressWarnings("unchecked")
89 public final <T extends SocketAddress> T proxyAddress() {
90 return (T) proxyAddress;
91 }
92
93
94
95
96 @SuppressWarnings("unchecked")
97 public final <T extends SocketAddress> T destinationAddress() {
98 return (T) destinationAddress;
99 }
100
101
102
103
104 public final boolean isConnected() {
105 return connectPromise.isSuccess();
106 }
107
108
109
110
111
112 public final Future<Channel> connectFuture() {
113 return connectPromise;
114 }
115
116
117
118
119
120 public final long connectTimeoutMillis() {
121 return connectTimeoutMillis;
122 }
123
124
125
126
127
128 public final void setConnectTimeoutMillis(long connectTimeoutMillis) {
129 if (connectTimeoutMillis <= 0) {
130 connectTimeoutMillis = 0;
131 }
132
133 this.connectTimeoutMillis = connectTimeoutMillis;
134 }
135
136 @Override
137 public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
138 this.ctx = ctx;
139 addCodec(ctx);
140
141 if (ctx.channel().isActive()) {
142
143
144 sendInitialMessage(ctx);
145 } else {
146
147
148 }
149 }
150
151
152
153
154 protected abstract void addCodec(ChannelHandlerContext ctx) throws Exception;
155
156
157
158
159 protected abstract void removeEncoder(ChannelHandlerContext ctx) throws Exception;
160
161
162
163
164 protected abstract void removeDecoder(ChannelHandlerContext ctx) throws Exception;
165
166 @Override
167 public final void connect(
168 ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
169 ChannelPromise promise) throws Exception {
170
171 if (destinationAddress != null) {
172 promise.setFailure(new ConnectionPendingException());
173 return;
174 }
175
176 destinationAddress = remoteAddress;
177 ctx.connect(proxyAddress, localAddress, promise);
178 }
179
180 @Override
181 public final void channelActive(ChannelHandlerContext ctx) throws Exception {
182 sendInitialMessage(ctx);
183 ctx.fireChannelActive();
184 }
185
186
187
188
189
190 private void sendInitialMessage(final ChannelHandlerContext ctx) throws Exception {
191 final long connectTimeoutMillis = this.connectTimeoutMillis;
192 if (connectTimeoutMillis > 0) {
193 connectTimeoutFuture = ctx.executor().schedule(new Runnable() {
194 @Override
195 public void run() {
196 if (!connectPromise.isDone()) {
197 setConnectFailure(new ProxyConnectException(exceptionMessage("timeout")));
198 }
199 }
200 }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
201 }
202
203 final Object initialMessage = newInitialMessage(ctx);
204 if (initialMessage != null) {
205 sendToProxyServer(initialMessage);
206 }
207
208 readIfNeeded(ctx);
209 }
210
211
212
213
214
215
216 protected abstract Object newInitialMessage(ChannelHandlerContext ctx) throws Exception;
217
218
219
220
221
222 protected final void sendToProxyServer(Object msg) {
223 ctx.writeAndFlush(msg).addListener(writeListener);
224 }
225
226 @Override
227 public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
228 if (finished) {
229 ctx.fireChannelInactive();
230 } else {
231
232 setConnectFailure(new ProxyConnectException(exceptionMessage("disconnected")));
233 }
234 }
235
236 @Override
237 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
238 if (finished) {
239 ctx.fireExceptionCaught(cause);
240 } else {
241
242 setConnectFailure(cause);
243 }
244 }
245
246 @Override
247 public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
248 if (finished) {
249
250 suppressChannelReadComplete = false;
251 ctx.fireChannelRead(msg);
252 } else {
253 suppressChannelReadComplete = true;
254 Throwable cause = null;
255 try {
256 boolean done = handleResponse(ctx, msg);
257 if (done) {
258 setConnectSuccess();
259 }
260 } catch (Throwable t) {
261 cause = t;
262 } finally {
263 ReferenceCountUtil.release(msg);
264 if (cause != null) {
265 setConnectFailure(cause);
266 }
267 }
268 }
269 }
270
271
272
273
274
275
276
277
278 protected abstract boolean handleResponse(ChannelHandlerContext ctx, Object response) throws Exception;
279
280 private void setConnectSuccess() {
281 finished = true;
282 cancelConnectTimeoutFuture();
283
284 if (!connectPromise.isDone()) {
285 boolean removedCodec = true;
286
287 removedCodec &= safeRemoveEncoder();
288
289 ctx.fireUserEventTriggered(
290 new ProxyConnectionEvent(protocol(), authScheme(), proxyAddress, destinationAddress));
291
292 removedCodec &= safeRemoveDecoder();
293
294 if (removedCodec) {
295 writePendingWrites();
296
297 if (flushedPrematurely) {
298 ctx.flush();
299 }
300 connectPromise.trySuccess(ctx.channel());
301 } else {
302
303 Exception cause = new ProxyConnectException(
304 "failed to remove all codec handlers added by the proxy handler; bug?");
305 failPendingWritesAndClose(cause);
306 }
307 }
308 }
309
310 private boolean safeRemoveDecoder() {
311 try {
312 removeDecoder(ctx);
313 return true;
314 } catch (Exception e) {
315 logger.warn("Failed to remove proxy decoders:", e);
316 }
317
318 return false;
319 }
320
321 private boolean safeRemoveEncoder() {
322 try {
323 removeEncoder(ctx);
324 return true;
325 } catch (Exception e) {
326 logger.warn("Failed to remove proxy encoders:", e);
327 }
328
329 return false;
330 }
331
332 private void setConnectFailure(Throwable cause) {
333 finished = true;
334 cancelConnectTimeoutFuture();
335
336 if (!connectPromise.isDone()) {
337
338 if (!(cause instanceof ProxyConnectException)) {
339 cause = new ProxyConnectException(
340 exceptionMessage(cause.toString()), cause);
341 }
342
343 safeRemoveDecoder();
344 safeRemoveEncoder();
345 failPendingWritesAndClose(cause);
346 }
347 }
348
349 private void failPendingWritesAndClose(Throwable cause) {
350 failPendingWrites(cause);
351 connectPromise.tryFailure(cause);
352 ctx.fireExceptionCaught(cause);
353 ctx.close();
354 }
355
356 private void cancelConnectTimeoutFuture() {
357 if (connectTimeoutFuture != null) {
358 connectTimeoutFuture.cancel(false);
359 connectTimeoutFuture = null;
360 }
361 }
362
363
364
365
366
367 protected final String exceptionMessage(String msg) {
368 if (msg == null) {
369 msg = "";
370 }
371
372 StringBuilder buf = new StringBuilder(128 + msg.length())
373 .append(protocol())
374 .append(", ")
375 .append(authScheme())
376 .append(", ")
377 .append(proxyAddress)
378 .append(" => ")
379 .append(destinationAddress);
380 if (!msg.isEmpty()) {
381 buf.append(", ").append(msg);
382 }
383
384 return buf.toString();
385 }
386
387 @Override
388 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
389 if (suppressChannelReadComplete) {
390 suppressChannelReadComplete = false;
391
392 readIfNeeded(ctx);
393 } else {
394 ctx.fireChannelReadComplete();
395 }
396 }
397
398 @Override
399 public final void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
400 if (finished) {
401 writePendingWrites();
402 ctx.write(msg, promise);
403 } else {
404 addPendingWrite(ctx, msg, promise);
405 }
406 }
407
408 @Override
409 public final void flush(ChannelHandlerContext ctx) throws Exception {
410 if (finished) {
411 writePendingWrites();
412 ctx.flush();
413 } else {
414 flushedPrematurely = true;
415 }
416 }
417
418 private static void readIfNeeded(ChannelHandlerContext ctx) {
419 if (!ctx.channel().config().isAutoRead()) {
420 ctx.read();
421 }
422 }
423
424 private void writePendingWrites() {
425 if (pendingWrites != null) {
426 pendingWrites.removeAndWriteAll();
427 pendingWrites = null;
428 }
429 }
430
431 private void failPendingWrites(Throwable cause) {
432 if (pendingWrites != null) {
433 pendingWrites.removeAndFailAll(cause);
434 pendingWrites = null;
435 }
436 }
437
438 private void addPendingWrite(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
439 PendingWriteQueue pendingWrites = this.pendingWrites;
440 if (pendingWrites == null) {
441 this.pendingWrites = pendingWrites = new PendingWriteQueue(ctx);
442 }
443 pendingWrites.add(msg, promise);
444 }
445
446 private final class LazyChannelPromise extends DefaultPromise<Channel> {
447 @Override
448 protected EventExecutor executor() {
449 if (ctx == null) {
450 throw new IllegalStateException();
451 }
452 return ctx.executor();
453 }
454 }
455 }