View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  
17  package io.netty.handler.proxy;
18  
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelFuture;
21  import io.netty.channel.ChannelFutureListener;
22  import io.netty.channel.ChannelHandlerAdapter;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelPromise;
25  import io.netty.channel.PendingWriteQueue;
26  import io.netty.util.ReferenceCountUtil;
27  import io.netty.util.concurrent.DefaultPromise;
28  import io.netty.util.concurrent.EventExecutor;
29  import io.netty.util.concurrent.Future;
30  import io.netty.util.concurrent.ScheduledFuture;
31  import io.netty.util.internal.OneTimeTask;
32  import io.netty.util.internal.logging.InternalLogger;
33  import io.netty.util.internal.logging.InternalLoggerFactory;
34  
35  import java.net.SocketAddress;
36  import java.nio.channels.ConnectionPendingException;
37  import java.util.concurrent.TimeUnit;
38  
39  public abstract class ProxyHandler extends ChannelHandlerAdapter {
40  
41      private static final InternalLogger logger = InternalLoggerFactory.getInstance(ProxyHandler.class);
42  
43      /**
44       * The default connect timeout: 10 seconds.
45       */
46      private static final long DEFAULT_CONNECT_TIMEOUT_MILLIS = 10000;
47  
48      /**
49       * A string that signifies 'no authentication' or 'anonymous'.
50       */
51      static final String AUTH_NONE = "none";
52  
53      private final SocketAddress proxyAddress;
54      private volatile SocketAddress destinationAddress;
55      private volatile long connectTimeoutMillis = DEFAULT_CONNECT_TIMEOUT_MILLIS;
56  
57      private volatile ChannelHandlerContext ctx;
58      private PendingWriteQueue pendingWrites;
59      private boolean finished;
60      private boolean suppressChannelReadComplete;
61      private boolean flushedPrematurely;
62      private final LazyChannelPromise connectPromise = new LazyChannelPromise();
63      private ScheduledFuture<?> connectTimeoutFuture;
64      private final ChannelFutureListener writeListener = new ChannelFutureListener() {
65          @Override
66          public void operationComplete(ChannelFuture future) throws Exception {
67              if (!future.isSuccess()) {
68                  setConnectFailure(future.cause());
69              }
70          }
71      };
72  
73      protected ProxyHandler(SocketAddress proxyAddress) {
74          if (proxyAddress == null) {
75              throw new NullPointerException("proxyAddress");
76          }
77          this.proxyAddress = proxyAddress;
78      }
79  
80      /**
81       * Returns the name of the proxy protocol in use.
82       */
83      public abstract String protocol();
84  
85      /**
86       * Returns the name of the authentication scheme in use.
87       */
88      public abstract String authScheme();
89  
90      /**
91       * Returns the address of the proxy server.
92       */
93      @SuppressWarnings("unchecked")
94      public final <T extends SocketAddress> T proxyAddress() {
95          return (T) proxyAddress;
96      }
97  
98      /**
99       * Returns the address of the destination to connect to via the proxy server.
100      */
101     @SuppressWarnings("unchecked")
102     public final <T extends SocketAddress> T destinationAddress() {
103         return (T) destinationAddress;
104     }
105 
106     /**
107      * Rerutns {@code true} if and only if the connection to the destination has been established successfully.
108      */
109     public final boolean isConnected() {
110         return connectPromise.isSuccess();
111     }
112 
113     /**
114      * Returns a {@link Future} that is notified when the connection to the destination has been established
115      * or the connection attempt has failed.
116      */
117     public final Future<Channel> connectFuture() {
118         return connectPromise;
119     }
120 
121     /**
122      * Returns the connect timeout in millis.  If the connection attempt to the destination does not finish within
123      * the timeout, the connection attempt will be failed.
124      */
125     public final long connectTimeoutMillis() {
126         return connectTimeoutMillis;
127     }
128 
129     /**
130      * Sets the connect timeout in millis.  If the connection attempt to the destination does not finish within
131      * the timeout, the connection attempt will be failed.
132      */
133     public final void setConnectTimeoutMillis(long connectTimeoutMillis) {
134         if (connectTimeoutMillis <= 0) {
135             connectTimeoutMillis = 0;
136         }
137 
138         this.connectTimeoutMillis = connectTimeoutMillis;
139     }
140 
141     @Override
142     public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
143         this.ctx = ctx;
144         addCodec(ctx);
145 
146         if (ctx.channel().isActive()) {
147             // channelActive() event has been fired already, which means this.channelActive() will
148             // not be invoked. We have to initialize here instead.
149             sendInitialMessage(ctx);
150         } else {
151             // channelActive() event has not been fired yet.  this.channelOpen() will be invoked
152             // and initialization will occur there.
153         }
154     }
155 
156     /**
157      * Adds the codec handlers required to communicate with the proxy server.
158      */
159     protected abstract void addCodec(ChannelHandlerContext ctx) throws Exception;
160 
161     /**
162      * Removes the encoders added in {@link #addCodec(ChannelHandlerContext)}.
163      */
164     protected abstract void removeEncoder(ChannelHandlerContext ctx) throws Exception;
165 
166     /**
167      * Removes the decoders added in {@link #addCodec(ChannelHandlerContext)}.
168      */
169     protected abstract void removeDecoder(ChannelHandlerContext ctx) throws Exception;
170 
171     @Override
172     public final void connect(
173             ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
174             ChannelPromise promise) throws Exception {
175 
176         if (destinationAddress != null) {
177             promise.setFailure(new ConnectionPendingException());
178             return;
179         }
180 
181         destinationAddress = remoteAddress;
182         ctx.connect(proxyAddress, localAddress, promise);
183     }
184 
185     @Override
186     public final void channelActive(ChannelHandlerContext ctx) throws Exception {
187         sendInitialMessage(ctx);
188         ctx.fireChannelActive();
189     }
190 
191     /**
192      * Sends the initial message to be sent to the proxy server. This method also starts a timeout task which marks
193      * the {@link #connectPromise} as failure if the connection attempt does not success within the timeout.
194      */
195     private void sendInitialMessage(final ChannelHandlerContext ctx) throws Exception {
196         final long connectTimeoutMillis = this.connectTimeoutMillis;
197         if (connectTimeoutMillis > 0) {
198             connectTimeoutFuture = ctx.executor().schedule(new OneTimeTask() {
199                 @Override
200                 public void run() {
201                     if (!connectPromise.isDone()) {
202                         setConnectFailure(new ProxyConnectException(exceptionMessage("timeout")));
203                     }
204                 }
205             }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
206         }
207 
208         final Object initialMessage = newInitialMessage(ctx);
209         if (initialMessage != null) {
210             sendToProxyServer(initialMessage);
211         }
212     }
213 
214     /**
215      * Returns a new message that is sent at first time when the connection to the proxy server has been established.
216      *
217      * @return the initial message, or {@code null} if the proxy server is expected to send the first message instead
218      */
219     protected abstract Object newInitialMessage(ChannelHandlerContext ctx) throws Exception;
220 
221     /**
222      * Sends the specified message to the proxy server.  Use this method to send a response to the proxy server in
223      * {@link #handleResponse(ChannelHandlerContext, Object)}.
224      */
225     protected final void sendToProxyServer(Object msg) {
226         ctx.writeAndFlush(msg).addListener(writeListener);
227     }
228 
229     @Override
230     public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
231         if (finished) {
232             ctx.fireChannelInactive();
233         } else {
234             // Disconnected before connected to the destination.
235             setConnectFailure(new ProxyConnectException(exceptionMessage("disconnected")));
236         }
237     }
238 
239     @Override
240     public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
241         if (finished) {
242             ctx.fireExceptionCaught(cause);
243         } else {
244             // Exception was raised before the connection attempt is finished.
245             setConnectFailure(cause);
246         }
247     }
248 
249     @Override
250     public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
251         if (finished) {
252             // Received a message after the connection has been established; pass through.
253             suppressChannelReadComplete = false;
254             ctx.fireChannelRead(msg);
255         } else {
256             suppressChannelReadComplete = true;
257             Throwable cause = null;
258             try {
259                 boolean done = handleResponse(ctx, msg);
260                 if (done) {
261                     setConnectSuccess();
262                 }
263             } catch (Throwable t) {
264                 cause = t;
265             } finally {
266                 ReferenceCountUtil.release(msg);
267                 if (cause != null) {
268                     setConnectFailure(cause);
269                 }
270             }
271         }
272     }
273 
274     /**
275      * Handles the message received from the proxy server.
276      *
277      * @return {@code true} if the connection to the destination has been established,
278      *         {@code false} if the connection to the destination has not been established and more messages are
279      *         expected from the proxy server
280      */
281     protected abstract boolean handleResponse(ChannelHandlerContext ctx, Object response) throws Exception;
282 
283     private void setConnectSuccess() {
284         finished = true;
285         if (connectTimeoutFuture != null) {
286             connectTimeoutFuture.cancel(false);
287         }
288 
289         if (connectPromise.trySuccess(ctx.channel())) {
290             boolean removedCodec = true;
291 
292             removedCodec &= safeRemoveEncoder();
293 
294             ctx.fireUserEventTriggered(
295                     new ProxyConnectionEvent(protocol(), authScheme(), proxyAddress, destinationAddress));
296 
297             removedCodec &= safeRemoveDecoder();
298 
299             if (removedCodec) {
300                 writePendingWrites();
301 
302                 if (flushedPrematurely) {
303                     ctx.flush();
304                 }
305             } else {
306                 // We are at inconsistent state because we failed to remove all codec handlers.
307                 Exception cause = new ProxyConnectException(
308                         "failed to remove all codec handlers added by the proxy handler; bug?");
309                 failPendingWrites(cause);
310                 ctx.fireExceptionCaught(cause);
311                 ctx.close();
312             }
313         }
314     }
315 
316     private boolean safeRemoveDecoder() {
317         try {
318             removeDecoder(ctx);
319             return true;
320         } catch (Exception e) {
321             logger.warn("Failed to remove proxy decoders:", e);
322         }
323 
324         return false;
325     }
326 
327     private boolean safeRemoveEncoder() {
328         try {
329             removeEncoder(ctx);
330             return true;
331         } catch (Exception e) {
332             logger.warn("Failed to remove proxy encoders:", e);
333         }
334 
335         return false;
336     }
337 
338     private void setConnectFailure(Throwable cause) {
339         finished = true;
340         if (connectTimeoutFuture != null) {
341             connectTimeoutFuture.cancel(false);
342         }
343 
344         if (!(cause instanceof ProxyConnectException)) {
345             cause = new ProxyConnectException(
346                     exceptionMessage(cause.toString()), cause);
347         }
348 
349         if (connectPromise.tryFailure(cause)) {
350             safeRemoveDecoder();
351             safeRemoveEncoder();
352 
353             failPendingWrites(cause);
354             ctx.fireExceptionCaught(cause);
355             ctx.close();
356         }
357     }
358 
359     /**
360      * Decorates the specified exception message with the common information such as the current protocol,
361      * authentication scheme, proxy address, and destination address.
362      */
363     protected final String exceptionMessage(String msg) {
364         if (msg == null) {
365             msg = "";
366         }
367 
368         StringBuilder buf = new StringBuilder(128 + msg.length())
369             .append(protocol())
370             .append(", ")
371             .append(authScheme())
372             .append(", ")
373             .append(proxyAddress)
374             .append(" => ")
375             .append(destinationAddress);
376         if (!msg.isEmpty()) {
377             buf.append(", ").append(msg);
378         }
379 
380         return buf.toString();
381     }
382 
383     @Override
384     public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
385         if (suppressChannelReadComplete) {
386             suppressChannelReadComplete = false;
387 
388             if (!ctx.channel().config().isAutoRead()) {
389                 ctx.read();
390             }
391         } else {
392             ctx.fireChannelReadComplete();
393         }
394     }
395 
396     @Override
397     public final void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
398         if (finished) {
399             writePendingWrites();
400             ctx.write(msg, promise);
401         } else {
402             addPendingWrite(ctx, msg, promise);
403         }
404     }
405 
406     @Override
407     public final void flush(ChannelHandlerContext ctx) throws Exception {
408         if (finished) {
409             writePendingWrites();
410             ctx.flush();
411         } else {
412             flushedPrematurely = true;
413         }
414     }
415 
416     private void writePendingWrites() {
417         if (pendingWrites != null) {
418             pendingWrites.removeAndWriteAll();
419             pendingWrites = null;
420         }
421     }
422 
423     private void failPendingWrites(Throwable cause) {
424         if (pendingWrites != null) {
425             pendingWrites.removeAndFailAll(cause);
426             pendingWrites = null;
427         }
428     }
429 
430     private void addPendingWrite(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
431         PendingWriteQueue pendingWrites = this.pendingWrites;
432         if (pendingWrites == null) {
433             this.pendingWrites = pendingWrites = new PendingWriteQueue(ctx);
434         }
435         pendingWrites.add(msg, promise);
436     }
437 
438     private final class LazyChannelPromise extends DefaultPromise<Channel> {
439         @Override
440         protected EventExecutor executor() {
441             if (ctx == null) {
442                 throw new IllegalStateException();
443             }
444             return ctx.executor();
445         }
446     }
447 }