View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package org.jboss.netty.channel.socket.http;
17  
18  import java.io.EOFException;
19  import java.io.IOException;
20  import java.io.PushbackInputStream;
21  import java.net.SocketAddress;
22  
23  import javax.servlet.ServletConfig;
24  import javax.servlet.ServletException;
25  import javax.servlet.ServletOutputStream;
26  import javax.servlet.http.HttpServlet;
27  import javax.servlet.http.HttpServletRequest;
28  import javax.servlet.http.HttpServletResponse;
29  
30  import org.jboss.netty.buffer.ChannelBuffer;
31  import org.jboss.netty.buffer.ChannelBuffers;
32  import org.jboss.netty.channel.Channel;
33  import org.jboss.netty.channel.ChannelFactory;
34  import org.jboss.netty.channel.ChannelFuture;
35  import org.jboss.netty.channel.ChannelFutureListener;
36  import org.jboss.netty.channel.ChannelHandlerContext;
37  import org.jboss.netty.channel.ChannelPipeline;
38  import org.jboss.netty.channel.Channels;
39  import org.jboss.netty.channel.ExceptionEvent;
40  import org.jboss.netty.channel.MessageEvent;
41  import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
42  import org.jboss.netty.channel.local.DefaultLocalClientChannelFactory;
43  import org.jboss.netty.channel.local.LocalAddress;
44  import org.jboss.netty.handler.codec.http.HttpHeaders;
45  import org.jboss.netty.logging.InternalLogger;
46  import org.jboss.netty.logging.InternalLoggerFactory;
47  
48  /**
49   * An {@link HttpServlet} that proxies an incoming data to the actual server
50   * and vice versa.  Please refer to the
51   * <a href="package-summary.html#package_description">package summary</a> for
52   * the detailed usage.
53   * @apiviz.landmark
54   */
55  public class HttpTunnelingServlet extends HttpServlet {
56  
57      private static final long serialVersionUID = 4259910275899756070L;
58  
59      private static final String ENDPOINT = "endpoint";
60      private static final String CONNECT_ATTEMPTS = "connectAttempts";
61      private static final String RETRY_DELAY = "retryDelay";
62  
63      static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpTunnelingServlet.class);
64  
65      private volatile SocketAddress remoteAddress;
66      private volatile ChannelFactory channelFactory;
67      private volatile long connectAttempts = 1;
68      private volatile long retryDelay;
69  
70      @Override
71      public void init() throws ServletException {
72          ServletConfig config = getServletConfig();
73          String endpoint = config.getInitParameter(ENDPOINT);
74          if (endpoint == null) {
75              throw new ServletException("init-param '" + ENDPOINT + "' must be specified.");
76          }
77  
78          try {
79              remoteAddress = parseEndpoint(endpoint.trim());
80          } catch (ServletException e) {
81              throw e;
82          } catch (Exception e) {
83              throw new ServletException("Failed to parse an endpoint.", e);
84          }
85  
86          try {
87              channelFactory = createChannelFactory(remoteAddress);
88          } catch (ServletException e) {
89              throw e;
90          } catch (Exception e) {
91              throw new ServletException("Failed to create a channel factory.", e);
92          }
93  
94          String temp = config.getInitParameter(CONNECT_ATTEMPTS);
95          if (temp != null) {
96              try {
97                  connectAttempts = Long.parseLong(temp);
98              } catch (NumberFormatException e) {
99                  throw new ServletException(
100                    "init-param '" + CONNECT_ATTEMPTS + "' is not a valid number. Actual value: " + temp);
101             }
102             if (connectAttempts < 1) {
103                 throw new ServletException(
104                    "init-param '" + CONNECT_ATTEMPTS + "' must be >= 1. Actual value: " + connectAttempts);
105             }
106         }
107 
108         temp = config.getInitParameter(RETRY_DELAY);
109         if (temp != null) {
110             try {
111                 retryDelay = Long.parseLong(temp);
112             } catch (NumberFormatException e) {
113                 throw new ServletException(
114                    "init-param '" + RETRY_DELAY + "' is not a valid number. Actual value: " + temp);
115             }
116             if (retryDelay < 0) {
117                 throw new ServletException(
118                    "init-param '" + RETRY_DELAY + "' must be >= 0. Actual value: " + retryDelay);
119             }
120         }
121 
122         // Stuff for testing purpose
123         //ServerBootstrap b = new ServerBootstrap(new DefaultLocalServerChannelFactory());
124         //b.getPipeline().addLast("logger", new LoggingHandler(getClass(), InternalLogLevel.INFO, true));
125         //b.getPipeline().addLast("handler", new EchoHandler());
126         //b.bind(remoteAddress);
127     }
128 
129     protected SocketAddress parseEndpoint(String endpoint) throws Exception {
130         if (endpoint.startsWith("local:")) {
131             return new LocalAddress(endpoint.substring(6).trim());
132         } else {
133             throw new ServletException(
134                     "Invalid or unknown endpoint: " + endpoint);
135         }
136     }
137 
138     protected ChannelFactory createChannelFactory(SocketAddress remoteAddress) throws Exception {
139         if (remoteAddress instanceof LocalAddress) {
140             return new DefaultLocalClientChannelFactory();
141         } else {
142             throw new ServletException(
143                     "Unsupported remote address type: " +
144                     remoteAddress.getClass().getName());
145         }
146     }
147 
148     @Override
149     public void destroy() {
150         try {
151             destroyChannelFactory(channelFactory);
152         } catch (Exception e) {
153             if (logger.isWarnEnabled()) {
154                 logger.warn("Failed to destroy a channel factory.", e);
155             }
156         }
157     }
158 
159     protected void destroyChannelFactory(ChannelFactory factory) throws Exception {
160         factory.releaseExternalResources();
161     }
162 
163     @Override
164     protected void service(HttpServletRequest req, HttpServletResponse res)
165             throws ServletException, IOException {
166         if (!"POST".equalsIgnoreCase(req.getMethod())) {
167             if (logger.isWarnEnabled()) {
168                 logger.warn("Unallowed method: " + req.getMethod());
169             }
170             res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
171             return;
172         }
173 
174         final ChannelPipeline pipeline = Channels.pipeline();
175         final ServletOutputStream out = res.getOutputStream();
176         final OutboundConnectionHandler handler = new OutboundConnectionHandler(out);
177         pipeline.addLast("handler", handler);
178 
179         Channel channel = channelFactory.newChannel(pipeline);
180         int tries = 0;
181         ChannelFuture future = null;
182 
183         while (tries < connectAttempts) {
184             future = channel.connect(remoteAddress).awaitUninterruptibly();
185             if (!future.isSuccess()) {
186                 tries++;
187                 try {
188                     Thread.sleep(retryDelay);
189                 } catch (InterruptedException e) {
190                     // ignore
191                 }
192             } else {
193                 break;
194             }
195         }
196 
197         if (!future.isSuccess()) {
198             if (logger.isWarnEnabled()) {
199                 Throwable cause = future.getCause();
200                 logger.warn("Endpoint unavailable: " + cause.getMessage(), cause);
201             }
202             res.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
203             return;
204         }
205 
206         ChannelFuture lastWriteFuture = null;
207         try {
208             res.setStatus(HttpServletResponse.SC_OK);
209             res.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/octet-stream");
210             res.setHeader(HttpHeaders.Names.CONTENT_TRANSFER_ENCODING, HttpHeaders.Values.BINARY);
211 
212             // Initiate chunked encoding by flushing the headers.
213             out.flush();
214 
215             PushbackInputStream in =
216                     new PushbackInputStream(req.getInputStream());
217             while (channel.isConnected()) {
218                 ChannelBuffer buffer;
219                 try {
220                     buffer = read(in);
221                 } catch (EOFException e) {
222                     break;
223                 }
224                 if (buffer == null) {
225                     break;
226                 }
227                 lastWriteFuture = channel.write(buffer);
228             }
229         } finally {
230             if (lastWriteFuture == null) {
231                 channel.close();
232             } else {
233                 lastWriteFuture.addListener(ChannelFutureListener.CLOSE);
234             }
235         }
236     }
237 
238     private static ChannelBuffer read(PushbackInputStream in) throws IOException {
239         byte[] buf;
240         int readBytes;
241 
242         int bytesToRead = in.available();
243         if (bytesToRead > 0) {
244             buf = new byte[bytesToRead];
245             readBytes = in.read(buf);
246         } else if (bytesToRead == 0) {
247             int b = in.read();
248             if (b < 0 || in.available() < 0) {
249                 return null;
250             }
251             in.unread(b);
252             bytesToRead = in.available();
253             buf = new byte[bytesToRead];
254             readBytes = in.read(buf);
255         } else {
256             return null;
257         }
258 
259         assert readBytes > 0;
260 
261         ChannelBuffer buffer;
262         if (readBytes == buf.length) {
263             buffer = ChannelBuffers.wrappedBuffer(buf);
264         } else {
265             // A rare case, but it sometimes happen.
266             buffer = ChannelBuffers.wrappedBuffer(buf, 0, readBytes);
267         }
268         return buffer;
269     }
270 
271     private static final class OutboundConnectionHandler extends SimpleChannelUpstreamHandler {
272 
273         private final ServletOutputStream out;
274 
275         public OutboundConnectionHandler(ServletOutputStream out) {
276             this.out = out;
277         }
278 
279         @Override
280         public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
281             ChannelBuffer buffer = (ChannelBuffer) e.getMessage();
282             synchronized (this) {
283                 buffer.readBytes(out, buffer.readableBytes());
284                 out.flush();
285             }
286         }
287 
288         @Override
289         public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
290             if (logger.isWarnEnabled()) {
291                 logger.warn("Unexpected exception while HTTP tunneling", e.getCause());
292             }
293             e.getChannel().close();
294         }
295     }
296 }