View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  
17  package io.netty.handler.proxy;
18  
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelDuplexHandler;
21  import io.netty.channel.ChannelFutureListener;
22  import io.netty.channel.ChannelHandlerContext;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.PendingWriteQueue;
25  import io.netty.util.ReferenceCountUtil;
26  import io.netty.util.concurrent.DefaultPromise;
27  import io.netty.util.concurrent.EventExecutor;
28  import io.netty.util.concurrent.Future;
29  import io.netty.util.internal.ObjectUtil;
30  import io.netty.util.internal.logging.InternalLogger;
31  import io.netty.util.internal.logging.InternalLoggerFactory;
32  
33  import java.net.SocketAddress;
34  import java.nio.channels.ConnectionPendingException;
35  import java.util.concurrent.TimeUnit;
36  
37  /**
38   * A common abstraction for protocols that establish blind forwarding proxy tunnels.
39   */
40  public abstract class ProxyHandler extends ChannelDuplexHandler {
41  
42      private static final InternalLogger logger = InternalLoggerFactory.getInstance(ProxyHandler.class);
43  
44      /**
45       * The default connect timeout: 10 seconds.
46       */
47      private static final long DEFAULT_CONNECT_TIMEOUT_MILLIS = 10000;
48  
49      /**
50       * A string that signifies 'no authentication' or 'anonymous'.
51       */
52      static final String AUTH_NONE = "none";
53  
54      private final SocketAddress proxyAddress;
55      private volatile SocketAddress destinationAddress;
56      private volatile long connectTimeoutMillis = DEFAULT_CONNECT_TIMEOUT_MILLIS;
57  
58      private volatile ChannelHandlerContext ctx;
59      private PendingWriteQueue pendingWrites;
60      private boolean finished;
61      private boolean suppressChannelReadComplete;
62      private boolean flushedPrematurely;
63      private final LazyChannelPromise connectPromise = new LazyChannelPromise();
64      private Future<?> connectTimeoutFuture;
65      private final ChannelFutureListener writeListener = future -> {
66          if (!future.isSuccess()) {
67              setConnectFailure(future.cause());
68          }
69      };
70  
71      protected ProxyHandler(SocketAddress proxyAddress) {
72          this.proxyAddress = ObjectUtil.checkNotNull(proxyAddress, "proxyAddress");
73      }
74  
75      /**
76       * Returns the name of the proxy protocol in use.
77       */
78      public abstract String protocol();
79  
80      /**
81       * Returns the name of the authentication scheme in use.
82       */
83      public abstract String authScheme();
84  
85      /**
86       * Returns the address of the proxy server.
87       */
88      @SuppressWarnings("unchecked")
89      public final <T extends SocketAddress> T proxyAddress() {
90          return (T) proxyAddress;
91      }
92  
93      /**
94       * Returns the address of the destination to connect to via the proxy server.
95       */
96      @SuppressWarnings("unchecked")
97      public final <T extends SocketAddress> T destinationAddress() {
98          return (T) destinationAddress;
99      }
100 
101     /**
102      * Returns {@code true} if and only if the connection to the destination has been established successfully.
103      */
104     public final boolean isConnected() {
105         return connectPromise.isSuccess();
106     }
107 
108     /**
109      * Returns a {@link Future} that is notified when the connection to the destination has been established
110      * or the connection attempt has failed.
111      */
112     public final Future<Channel> connectFuture() {
113         return connectPromise;
114     }
115 
116     /**
117      * Returns the connect timeout in millis.  If the connection attempt to the destination does not finish within
118      * the timeout, the connection attempt will be failed.
119      */
120     public final long connectTimeoutMillis() {
121         return connectTimeoutMillis;
122     }
123 
124     /**
125      * Sets the connect timeout in millis.  If the connection attempt to the destination does not finish within
126      * the timeout, the connection attempt will be failed.
127      */
128     public final void setConnectTimeoutMillis(long connectTimeoutMillis) {
129         if (connectTimeoutMillis <= 0) {
130             connectTimeoutMillis = 0;
131         }
132 
133         this.connectTimeoutMillis = connectTimeoutMillis;
134     }
135 
136     @Override
137     public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
138         this.ctx = ctx;
139         addCodec(ctx);
140 
141         if (ctx.channel().isActive()) {
142             // channelActive() event has been fired already, which means this.channelActive() will
143             // not be invoked. We have to initialize here instead.
144             sendInitialMessage(ctx);
145         } else {
146             // channelActive() event has not been fired yet.  this.channelOpen() will be invoked
147             // and initialization will occur there.
148         }
149     }
150 
151     /**
152      * Adds the codec handlers required to communicate with the proxy server.
153      */
154     protected abstract void addCodec(ChannelHandlerContext ctx) throws Exception;
155 
156     /**
157      * Removes the encoders added in {@link #addCodec(ChannelHandlerContext)}.
158      */
159     protected abstract void removeEncoder(ChannelHandlerContext ctx) throws Exception;
160 
161     /**
162      * Removes the decoders added in {@link #addCodec(ChannelHandlerContext)}.
163      */
164     protected abstract void removeDecoder(ChannelHandlerContext ctx) throws Exception;
165 
166     @Override
167     public final void connect(
168             ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
169             ChannelPromise promise) throws Exception {
170 
171         if (destinationAddress != null) {
172             promise.setFailure(new ConnectionPendingException());
173             return;
174         }
175 
176         destinationAddress = remoteAddress;
177         ctx.connect(proxyAddress, localAddress, promise);
178     }
179 
180     @Override
181     public final void channelActive(ChannelHandlerContext ctx) throws Exception {
182         sendInitialMessage(ctx);
183         ctx.fireChannelActive();
184     }
185 
186     /**
187      * Sends the initial message to be sent to the proxy server. This method also starts a timeout task which marks
188      * the {@link #connectPromise} as failure if the connection attempt does not success within the timeout.
189      */
190     private void sendInitialMessage(final ChannelHandlerContext ctx) throws Exception {
191         final long connectTimeoutMillis = this.connectTimeoutMillis;
192         if (connectTimeoutMillis > 0) {
193             connectTimeoutFuture = ctx.executor().schedule(new Runnable() {
194                 @Override
195                 public void run() {
196                     if (!connectPromise.isDone()) {
197                         setConnectFailure(new ProxyConnectException(exceptionMessage("timeout")));
198                     }
199                 }
200             }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
201         }
202 
203         final Object initialMessage = newInitialMessage(ctx);
204         if (initialMessage != null) {
205             sendToProxyServer(initialMessage);
206         }
207 
208         readIfNeeded(ctx);
209     }
210 
211     /**
212      * Returns a new message that is sent at first time when the connection to the proxy server has been established.
213      *
214      * @return the initial message, or {@code null} if the proxy server is expected to send the first message instead
215      */
216     protected abstract Object newInitialMessage(ChannelHandlerContext ctx) throws Exception;
217 
218     /**
219      * Sends the specified message to the proxy server.  Use this method to send a response to the proxy server in
220      * {@link #handleResponse(ChannelHandlerContext, Object)}.
221      */
222     protected final void sendToProxyServer(Object msg) {
223         ctx.writeAndFlush(msg).addListener(writeListener);
224     }
225 
226     @Override
227     public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
228         if (finished) {
229             ctx.fireChannelInactive();
230         } else {
231             // Disconnected before connected to the destination.
232             setConnectFailure(new ProxyConnectException(exceptionMessage("disconnected")));
233         }
234     }
235 
236     @Override
237     public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
238         if (finished) {
239             ctx.fireExceptionCaught(cause);
240         } else {
241             // Exception was raised before the connection attempt is finished.
242             setConnectFailure(cause);
243         }
244     }
245 
246     @Override
247     public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
248         if (finished) {
249             // Received a message after the connection has been established; pass through.
250             suppressChannelReadComplete = false;
251             ctx.fireChannelRead(msg);
252         } else {
253             suppressChannelReadComplete = true;
254             Throwable cause = null;
255             try {
256                 boolean done = handleResponse(ctx, msg);
257                 if (done) {
258                     setConnectSuccess();
259                 }
260             } catch (Throwable t) {
261                 cause = t;
262             } finally {
263                 ReferenceCountUtil.release(msg);
264                 if (cause != null) {
265                     setConnectFailure(cause);
266                 }
267             }
268         }
269     }
270 
271     /**
272      * Handles the message received from the proxy server.
273      *
274      * @return {@code true} if the connection to the destination has been established,
275      *         {@code false} if the connection to the destination has not been established and more messages are
276      *         expected from the proxy server
277      */
278     protected abstract boolean handleResponse(ChannelHandlerContext ctx, Object response) throws Exception;
279 
280     private void setConnectSuccess() {
281         finished = true;
282         cancelConnectTimeoutFuture();
283 
284         if (!connectPromise.isDone()) {
285             boolean removedCodec = true;
286 
287             removedCodec &= safeRemoveEncoder();
288 
289             ctx.fireUserEventTriggered(
290                     new ProxyConnectionEvent(protocol(), authScheme(), proxyAddress, destinationAddress));
291 
292             removedCodec &= safeRemoveDecoder();
293 
294             if (removedCodec) {
295                 writePendingWrites();
296 
297                 if (flushedPrematurely) {
298                     ctx.flush();
299                 }
300                 connectPromise.trySuccess(ctx.channel());
301             } else {
302                 // We are at inconsistent state because we failed to remove all codec handlers.
303                 Exception cause = new ProxyConnectException(
304                         "failed to remove all codec handlers added by the proxy handler; bug?");
305                 failPendingWritesAndClose(cause);
306             }
307         }
308     }
309 
310     private boolean safeRemoveDecoder() {
311         try {
312             removeDecoder(ctx);
313             return true;
314         } catch (Exception e) {
315             logger.warn("Failed to remove proxy decoders:", e);
316         }
317 
318         return false;
319     }
320 
321     private boolean safeRemoveEncoder() {
322         try {
323             removeEncoder(ctx);
324             return true;
325         } catch (Exception e) {
326             logger.warn("Failed to remove proxy encoders:", e);
327         }
328 
329         return false;
330     }
331 
332     private void setConnectFailure(Throwable cause) {
333         finished = true;
334         cancelConnectTimeoutFuture();
335 
336         if (!connectPromise.isDone()) {
337 
338             if (!(cause instanceof ProxyConnectException)) {
339                 cause = new ProxyConnectException(
340                         exceptionMessage(cause.toString()), cause);
341             }
342 
343             safeRemoveDecoder();
344             safeRemoveEncoder();
345             failPendingWritesAndClose(cause);
346         }
347     }
348 
349     private void failPendingWritesAndClose(Throwable cause) {
350         failPendingWrites(cause);
351         connectPromise.tryFailure(cause);
352         ctx.fireExceptionCaught(cause);
353         ctx.close();
354     }
355 
356     private void cancelConnectTimeoutFuture() {
357         if (connectTimeoutFuture != null) {
358             connectTimeoutFuture.cancel(false);
359             connectTimeoutFuture = null;
360         }
361     }
362 
363     /**
364      * Decorates the specified exception message with the common information such as the current protocol,
365      * authentication scheme, proxy address, and destination address.
366      */
367     protected final String exceptionMessage(String msg) {
368         if (msg == null) {
369             msg = "";
370         }
371 
372         StringBuilder buf = new StringBuilder(128 + msg.length())
373             .append(protocol())
374             .append(", ")
375             .append(authScheme())
376             .append(", ")
377             .append(proxyAddress)
378             .append(" => ")
379             .append(destinationAddress);
380         if (!msg.isEmpty()) {
381             buf.append(", ").append(msg);
382         }
383 
384         return buf.toString();
385     }
386 
387     @Override
388     public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
389         if (suppressChannelReadComplete) {
390             suppressChannelReadComplete = false;
391 
392             readIfNeeded(ctx);
393         } else {
394             ctx.fireChannelReadComplete();
395         }
396     }
397 
398     @Override
399     public final void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
400         if (finished) {
401             writePendingWrites();
402             ctx.write(msg, promise);
403         } else {
404             addPendingWrite(ctx, msg, promise);
405         }
406     }
407 
408     @Override
409     public final void flush(ChannelHandlerContext ctx) throws Exception {
410         if (finished) {
411             writePendingWrites();
412             ctx.flush();
413         } else {
414             flushedPrematurely = true;
415         }
416     }
417 
418     private static void readIfNeeded(ChannelHandlerContext ctx) {
419         if (!ctx.channel().config().isAutoRead()) {
420             ctx.read();
421         }
422     }
423 
424     private void writePendingWrites() {
425         if (pendingWrites != null) {
426             pendingWrites.removeAndWriteAll();
427             pendingWrites = null;
428         }
429     }
430 
431     private void failPendingWrites(Throwable cause) {
432         if (pendingWrites != null) {
433             pendingWrites.removeAndFailAll(cause);
434             pendingWrites = null;
435         }
436     }
437 
438     private void addPendingWrite(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
439         PendingWriteQueue pendingWrites = this.pendingWrites;
440         if (pendingWrites == null) {
441             this.pendingWrites = pendingWrites = new PendingWriteQueue(ctx);
442         }
443         pendingWrites.add(msg, promise);
444     }
445 
446     private final class LazyChannelPromise extends DefaultPromise<Channel> {
447         @Override
448         protected EventExecutor executor() {
449             if (ctx == null) {
450                 throw new IllegalStateException();
451             }
452             return ctx.executor();
453         }
454     }
455 }