View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  
17  package io.netty.handler.proxy;
18  
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelDuplexHandler;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelPromise;
25  import io.netty.channel.PendingWriteQueue;
26  import io.netty.util.ReferenceCountUtil;
27  import io.netty.util.concurrent.DefaultPromise;
28  import io.netty.util.concurrent.EventExecutor;
29  import io.netty.util.concurrent.Future;
30  import io.netty.util.concurrent.ScheduledFuture;
31  import io.netty.util.internal.logging.InternalLogger;
32  import io.netty.util.internal.logging.InternalLoggerFactory;
33  
34  import java.net.SocketAddress;
35  import java.nio.channels.ConnectionPendingException;
36  import java.util.concurrent.TimeUnit;
37  
38  public abstract class ProxyHandler extends ChannelDuplexHandler {
39  
40      private static final InternalLogger logger = InternalLoggerFactory.getInstance(ProxyHandler.class);
41  
42      /**
43       * The default connect timeout: 10 seconds.
44       */
45      private static final long DEFAULT_CONNECT_TIMEOUT_MILLIS = 10000;
46  
47      /**
48       * A string that signifies 'no authentication' or 'anonymous'.
49       */
50      static final String AUTH_NONE = "none";
51  
52      private final SocketAddress proxyAddress;
53      private volatile SocketAddress destinationAddress;
54      private volatile long connectTimeoutMillis = DEFAULT_CONNECT_TIMEOUT_MILLIS;
55  
56      private volatile ChannelHandlerContext ctx;
57      private PendingWriteQueue pendingWrites;
58      private boolean finished;
59      private boolean suppressChannelReadComplete;
60      private boolean flushedPrematurely;
61      private final LazyChannelPromise connectPromise = new LazyChannelPromise();
62      private ScheduledFuture<?> connectTimeoutFuture;
63      private final ChannelFutureListener writeListener = new ChannelFutureListener() {
64          @Override
65          public void operationComplete(ChannelFuture future) throws Exception {
66              if (!future.isSuccess()) {
67                  setConnectFailure(future.cause());
68              }
69          }
70      };
71  
72      protected ProxyHandler(SocketAddress proxyAddress) {
73          if (proxyAddress == null) {
74              throw new NullPointerException("proxyAddress");
75          }
76          this.proxyAddress = proxyAddress;
77      }
78  
79      /**
80       * Returns the name of the proxy protocol in use.
81       */
82      public abstract String protocol();
83  
84      /**
85       * Returns the name of the authentication scheme in use.
86       */
87      public abstract String authScheme();
88  
89      /**
90       * Returns the address of the proxy server.
91       */
92      @SuppressWarnings("unchecked")
93      public final <T extends SocketAddress> T proxyAddress() {
94          return (T) proxyAddress;
95      }
96  
97      /**
98       * Returns the address of the destination to connect to via the proxy server.
99       */
100     @SuppressWarnings("unchecked")
101     public final <T extends SocketAddress> T destinationAddress() {
102         return (T) destinationAddress;
103     }
104 
105     /**
106      * Returns {@code true} if and only if the connection to the destination has been established successfully.
107      */
108     public final boolean isConnected() {
109         return connectPromise.isSuccess();
110     }
111 
112     /**
113      * Returns a {@link Future} that is notified when the connection to the destination has been established
114      * or the connection attempt has failed.
115      */
116     public final Future<Channel> connectFuture() {
117         return connectPromise;
118     }
119 
120     /**
121      * Returns the connect timeout in millis.  If the connection attempt to the destination does not finish within
122      * the timeout, the connection attempt will be failed.
123      */
124     public final long connectTimeoutMillis() {
125         return connectTimeoutMillis;
126     }
127 
128     /**
129      * Sets the connect timeout in millis.  If the connection attempt to the destination does not finish within
130      * the timeout, the connection attempt will be failed.
131      */
132     public final void setConnectTimeoutMillis(long connectTimeoutMillis) {
133         if (connectTimeoutMillis <= 0) {
134             connectTimeoutMillis = 0;
135         }
136 
137         this.connectTimeoutMillis = connectTimeoutMillis;
138     }
139 
140     @Override
141     public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
142         this.ctx = ctx;
143         addCodec(ctx);
144 
145         if (ctx.channel().isActive()) {
146             // channelActive() event has been fired already, which means this.channelActive() will
147             // not be invoked. We have to initialize here instead.
148             sendInitialMessage(ctx);
149         } else {
150             // channelActive() event has not been fired yet.  this.channelOpen() will be invoked
151             // and initialization will occur there.
152         }
153     }
154 
155     /**
156      * Adds the codec handlers required to communicate with the proxy server.
157      */
158     protected abstract void addCodec(ChannelHandlerContext ctx) throws Exception;
159 
160     /**
161      * Removes the encoders added in {@link #addCodec(ChannelHandlerContext)}.
162      */
163     protected abstract void removeEncoder(ChannelHandlerContext ctx) throws Exception;
164 
165     /**
166      * Removes the decoders added in {@link #addCodec(ChannelHandlerContext)}.
167      */
168     protected abstract void removeDecoder(ChannelHandlerContext ctx) throws Exception;
169 
170     @Override
171     public final void connect(
172             ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
173             ChannelPromise promise) throws Exception {
174 
175         if (destinationAddress != null) {
176             promise.setFailure(new ConnectionPendingException());
177             return;
178         }
179 
180         destinationAddress = remoteAddress;
181         ctx.connect(proxyAddress, localAddress, promise);
182     }
183 
184     @Override
185     public final void channelActive(ChannelHandlerContext ctx) throws Exception {
186         sendInitialMessage(ctx);
187         ctx.fireChannelActive();
188     }
189 
190     /**
191      * Sends the initial message to be sent to the proxy server. This method also starts a timeout task which marks
192      * the {@link #connectPromise} as failure if the connection attempt does not success within the timeout.
193      */
194     private void sendInitialMessage(final ChannelHandlerContext ctx) throws Exception {
195         final long connectTimeoutMillis = this.connectTimeoutMillis;
196         if (connectTimeoutMillis > 0) {
197             connectTimeoutFuture = ctx.executor().schedule(new Runnable() {
198                 @Override
199                 public void run() {
200                     if (!connectPromise.isDone()) {
201                         setConnectFailure(new ProxyConnectException(exceptionMessage("timeout")));
202                     }
203                 }
204             }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
205         }
206 
207         final Object initialMessage = newInitialMessage(ctx);
208         if (initialMessage != null) {
209             sendToProxyServer(initialMessage);
210         }
211 
212         readIfNeeded(ctx);
213     }
214 
215     /**
216      * Returns a new message that is sent at first time when the connection to the proxy server has been established.
217      *
218      * @return the initial message, or {@code null} if the proxy server is expected to send the first message instead
219      */
220     protected abstract Object newInitialMessage(ChannelHandlerContext ctx) throws Exception;
221 
222     /**
223      * Sends the specified message to the proxy server.  Use this method to send a response to the proxy server in
224      * {@link #handleResponse(ChannelHandlerContext, Object)}.
225      */
226     protected final void sendToProxyServer(Object msg) {
227         ctx.writeAndFlush(msg).addListener(writeListener);
228     }
229 
230     @Override
231     public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
232         if (finished) {
233             ctx.fireChannelInactive();
234         } else {
235             // Disconnected before connected to the destination.
236             setConnectFailure(new ProxyConnectException(exceptionMessage("disconnected")));
237         }
238     }
239 
240     @Override
241     public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
242         if (finished) {
243             ctx.fireExceptionCaught(cause);
244         } else {
245             // Exception was raised before the connection attempt is finished.
246             setConnectFailure(cause);
247         }
248     }
249 
250     @Override
251     public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
252         if (finished) {
253             // Received a message after the connection has been established; pass through.
254             suppressChannelReadComplete = false;
255             ctx.fireChannelRead(msg);
256         } else {
257             suppressChannelReadComplete = true;
258             Throwable cause = null;
259             try {
260                 boolean done = handleResponse(ctx, msg);
261                 if (done) {
262                     setConnectSuccess();
263                 }
264             } catch (Throwable t) {
265                 cause = t;
266             } finally {
267                 ReferenceCountUtil.release(msg);
268                 if (cause != null) {
269                     setConnectFailure(cause);
270                 }
271             }
272         }
273     }
274 
275     /**
276      * Handles the message received from the proxy server.
277      *
278      * @return {@code true} if the connection to the destination has been established,
279      *         {@code false} if the connection to the destination has not been established and more messages are
280      *         expected from the proxy server
281      */
282     protected abstract boolean handleResponse(ChannelHandlerContext ctx, Object response) throws Exception;
283 
284     private void setConnectSuccess() {
285         finished = true;
286         cancelConnectTimeoutFuture();
287 
288         if (!connectPromise.isDone()) {
289             boolean removedCodec = true;
290 
291             removedCodec &= safeRemoveEncoder();
292 
293             ctx.fireUserEventTriggered(
294                     new ProxyConnectionEvent(protocol(), authScheme(), proxyAddress, destinationAddress));
295 
296             removedCodec &= safeRemoveDecoder();
297 
298             if (removedCodec) {
299                 writePendingWrites();
300 
301                 if (flushedPrematurely) {
302                     ctx.flush();
303                 }
304                 connectPromise.trySuccess(ctx.channel());
305             } else {
306                 // We are at inconsistent state because we failed to remove all codec handlers.
307                 Exception cause = new ProxyConnectException(
308                         "failed to remove all codec handlers added by the proxy handler; bug?");
309                 failPendingWritesAndClose(cause);
310             }
311         }
312     }
313 
314     private boolean safeRemoveDecoder() {
315         try {
316             removeDecoder(ctx);
317             return true;
318         } catch (Exception e) {
319             logger.warn("Failed to remove proxy decoders:", e);
320         }
321 
322         return false;
323     }
324 
325     private boolean safeRemoveEncoder() {
326         try {
327             removeEncoder(ctx);
328             return true;
329         } catch (Exception e) {
330             logger.warn("Failed to remove proxy encoders:", e);
331         }
332 
333         return false;
334     }
335 
336     private void setConnectFailure(Throwable cause) {
337         finished = true;
338         cancelConnectTimeoutFuture();
339 
340         if (!connectPromise.isDone()) {
341 
342             if (!(cause instanceof ProxyConnectException)) {
343                 cause = new ProxyConnectException(
344                         exceptionMessage(cause.toString()), cause);
345             }
346 
347             safeRemoveDecoder();
348             safeRemoveEncoder();
349             failPendingWritesAndClose(cause);
350         }
351     }
352 
353     private void failPendingWritesAndClose(Throwable cause) {
354         failPendingWrites(cause);
355         connectPromise.tryFailure(cause);
356         ctx.fireExceptionCaught(cause);
357         ctx.close();
358     }
359 
360     private void cancelConnectTimeoutFuture() {
361         if (connectTimeoutFuture != null) {
362             connectTimeoutFuture.cancel(false);
363             connectTimeoutFuture = null;
364         }
365     }
366 
367     /**
368      * Decorates the specified exception message with the common information such as the current protocol,
369      * authentication scheme, proxy address, and destination address.
370      */
371     protected final String exceptionMessage(String msg) {
372         if (msg == null) {
373             msg = "";
374         }
375 
376         StringBuilder buf = new StringBuilder(128 + msg.length())
377             .append(protocol())
378             .append(", ")
379             .append(authScheme())
380             .append(", ")
381             .append(proxyAddress)
382             .append(" => ")
383             .append(destinationAddress);
384         if (!msg.isEmpty()) {
385             buf.append(", ").append(msg);
386         }
387 
388         return buf.toString();
389     }
390 
391     @Override
392     public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
393         if (suppressChannelReadComplete) {
394             suppressChannelReadComplete = false;
395 
396             readIfNeeded(ctx);
397         } else {
398             ctx.fireChannelReadComplete();
399         }
400     }
401 
402     @Override
403     public final void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
404         if (finished) {
405             writePendingWrites();
406             ctx.write(msg, promise);
407         } else {
408             addPendingWrite(ctx, msg, promise);
409         }
410     }
411 
412     @Override
413     public final void flush(ChannelHandlerContext ctx) throws Exception {
414         if (finished) {
415             writePendingWrites();
416             ctx.flush();
417         } else {
418             flushedPrematurely = true;
419         }
420     }
421 
422     private static void readIfNeeded(ChannelHandlerContext ctx) {
423         if (!ctx.channel().config().isAutoRead()) {
424             ctx.read();
425         }
426     }
427 
428     private void writePendingWrites() {
429         if (pendingWrites != null) {
430             pendingWrites.removeAndWriteAll();
431             pendingWrites = null;
432         }
433     }
434 
435     private void failPendingWrites(Throwable cause) {
436         if (pendingWrites != null) {
437             pendingWrites.removeAndFailAll(cause);
438             pendingWrites = null;
439         }
440     }
441 
442     private void addPendingWrite(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
443         PendingWriteQueue pendingWrites = this.pendingWrites;
444         if (pendingWrites == null) {
445             this.pendingWrites = pendingWrites = new PendingWriteQueue(ctx);
446         }
447         pendingWrites.add(msg, promise);
448     }
449 
450     private final class LazyChannelPromise extends DefaultPromise<Channel> {
451         @Override
452         protected EventExecutor executor() {
453             if (ctx == null) {
454                 throw new IllegalStateException();
455             }
456             return ctx.executor();
457         }
458     }
459 }