1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.channel.socket.http;
17
18 import java.io.EOFException;
19 import java.io.IOException;
20 import java.io.PushbackInputStream;
21 import java.net.SocketAddress;
22
23 import javax.servlet.ServletConfig;
24 import javax.servlet.ServletException;
25 import javax.servlet.ServletOutputStream;
26 import javax.servlet.http.HttpServlet;
27 import javax.servlet.http.HttpServletRequest;
28 import javax.servlet.http.HttpServletResponse;
29
30 import org.jboss.netty.buffer.ChannelBuffer;
31 import org.jboss.netty.buffer.ChannelBuffers;
32 import org.jboss.netty.channel.Channel;
33 import org.jboss.netty.channel.ChannelFactory;
34 import org.jboss.netty.channel.ChannelFuture;
35 import org.jboss.netty.channel.ChannelFutureListener;
36 import org.jboss.netty.channel.ChannelHandlerContext;
37 import org.jboss.netty.channel.ChannelPipeline;
38 import org.jboss.netty.channel.Channels;
39 import org.jboss.netty.channel.ExceptionEvent;
40 import org.jboss.netty.channel.MessageEvent;
41 import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
42 import org.jboss.netty.channel.local.DefaultLocalClientChannelFactory;
43 import org.jboss.netty.channel.local.LocalAddress;
44 import org.jboss.netty.handler.codec.http.HttpHeaders;
45 import org.jboss.netty.logging.InternalLogger;
46 import org.jboss.netty.logging.InternalLoggerFactory;
47
48
49
50
51
52
53
54
55 public class HttpTunnelingServlet extends HttpServlet {
56
57 private static final long serialVersionUID = 4259910275899756070L;
58
59 private static final String ENDPOINT = "endpoint";
60 private static final String CONNECT_ATTEMPTS = "connectAttempts";
61 private static final String RETRY_DELAY = "retryDelay";
62
63 static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpTunnelingServlet.class);
64
65 private volatile SocketAddress remoteAddress;
66 private volatile ChannelFactory channelFactory;
67 private volatile long connectAttempts = 1;
68 private volatile long retryDelay;
69
70 @Override
71 public void init() throws ServletException {
72 ServletConfig config = getServletConfig();
73 String endpoint = config.getInitParameter(ENDPOINT);
74 if (endpoint == null) {
75 throw new ServletException("init-param '" + ENDPOINT + "' must be specified.");
76 }
77
78 try {
79 remoteAddress = parseEndpoint(endpoint.trim());
80 } catch (ServletException e) {
81 throw e;
82 } catch (Exception e) {
83 throw new ServletException("Failed to parse an endpoint.", e);
84 }
85
86 try {
87 channelFactory = createChannelFactory(remoteAddress);
88 } catch (ServletException e) {
89 throw e;
90 } catch (Exception e) {
91 throw new ServletException("Failed to create a channel factory.", e);
92 }
93
94 String temp = config.getInitParameter(CONNECT_ATTEMPTS);
95 if (temp != null) {
96 try {
97 connectAttempts = Long.parseLong(temp);
98 } catch (NumberFormatException e) {
99 throw new ServletException(
100 "init-param '" + CONNECT_ATTEMPTS + "' is not a valid number. Actual value: " + temp);
101 }
102 if (connectAttempts < 1) {
103 throw new ServletException(
104 "init-param '" + CONNECT_ATTEMPTS + "' must be >= 1. Actual value: " + connectAttempts);
105 }
106 }
107
108 temp = config.getInitParameter(RETRY_DELAY);
109 if (temp != null) {
110 try {
111 retryDelay = Long.parseLong(temp);
112 } catch (NumberFormatException e) {
113 throw new ServletException(
114 "init-param '" + RETRY_DELAY + "' is not a valid number. Actual value: " + temp);
115 }
116 if (retryDelay < 0) {
117 throw new ServletException(
118 "init-param '" + RETRY_DELAY + "' must be >= 0. Actual value: " + retryDelay);
119 }
120 }
121
122
123
124
125
126
127 }
128
129 protected SocketAddress parseEndpoint(String endpoint) throws Exception {
130 if (endpoint.startsWith("local:")) {
131 return new LocalAddress(endpoint.substring(6).trim());
132 } else {
133 throw new ServletException(
134 "Invalid or unknown endpoint: " + endpoint);
135 }
136 }
137
138 protected ChannelFactory createChannelFactory(SocketAddress remoteAddress) throws Exception {
139 if (remoteAddress instanceof LocalAddress) {
140 return new DefaultLocalClientChannelFactory();
141 } else {
142 throw new ServletException(
143 "Unsupported remote address type: " +
144 remoteAddress.getClass().getName());
145 }
146 }
147
148 @Override
149 public void destroy() {
150 try {
151 destroyChannelFactory(channelFactory);
152 } catch (Exception e) {
153 if (logger.isWarnEnabled()) {
154 logger.warn("Failed to destroy a channel factory.", e);
155 }
156 }
157 }
158
159 protected void destroyChannelFactory(ChannelFactory factory) throws Exception {
160 factory.releaseExternalResources();
161 }
162
163 @Override
164 protected void service(HttpServletRequest req, HttpServletResponse res)
165 throws ServletException, IOException {
166 if (!"POST".equalsIgnoreCase(req.getMethod())) {
167 if (logger.isWarnEnabled()) {
168 logger.warn("Unallowed method: " + req.getMethod());
169 }
170 res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
171 return;
172 }
173
174 final ChannelPipeline pipeline = Channels.pipeline();
175 final ServletOutputStream out = res.getOutputStream();
176 final OutboundConnectionHandler handler = new OutboundConnectionHandler(out);
177 pipeline.addLast("handler", handler);
178
179 Channel channel = channelFactory.newChannel(pipeline);
180 int tries = 0;
181 ChannelFuture future = null;
182
183 while (tries < connectAttempts) {
184 future = channel.connect(remoteAddress).awaitUninterruptibly();
185 if (!future.isSuccess()) {
186 tries++;
187 try {
188 Thread.sleep(retryDelay);
189 } catch (InterruptedException e) {
190
191 }
192 } else {
193 break;
194 }
195 }
196
197 if (!future.isSuccess()) {
198 if (logger.isWarnEnabled()) {
199 Throwable cause = future.getCause();
200 logger.warn("Endpoint unavailable: " + cause.getMessage(), cause);
201 }
202 res.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
203 return;
204 }
205
206 ChannelFuture lastWriteFuture = null;
207 try {
208 res.setStatus(HttpServletResponse.SC_OK);
209 res.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/octet-stream");
210 res.setHeader(HttpHeaders.Names.CONTENT_TRANSFER_ENCODING, HttpHeaders.Values.BINARY);
211
212
213 out.flush();
214
215 PushbackInputStream in =
216 new PushbackInputStream(req.getInputStream());
217 while (channel.isConnected()) {
218 ChannelBuffer buffer;
219 try {
220 buffer = read(in);
221 } catch (EOFException e) {
222 break;
223 }
224 if (buffer == null) {
225 break;
226 }
227 lastWriteFuture = channel.write(buffer);
228 }
229 } finally {
230 if (lastWriteFuture == null) {
231 channel.close();
232 } else {
233 lastWriteFuture.addListener(ChannelFutureListener.CLOSE);
234 }
235 }
236 }
237
238 private static ChannelBuffer read(PushbackInputStream in) throws IOException {
239 byte[] buf;
240 int readBytes;
241
242 int bytesToRead = in.available();
243 if (bytesToRead > 0) {
244 buf = new byte[bytesToRead];
245 readBytes = in.read(buf);
246 } else if (bytesToRead == 0) {
247 int b = in.read();
248 if (b < 0 || in.available() < 0) {
249 return null;
250 }
251 in.unread(b);
252 bytesToRead = in.available();
253 buf = new byte[bytesToRead];
254 readBytes = in.read(buf);
255 } else {
256 return null;
257 }
258
259 assert readBytes > 0;
260
261 ChannelBuffer buffer;
262 if (readBytes == buf.length) {
263 buffer = ChannelBuffers.wrappedBuffer(buf);
264 } else {
265
266 buffer = ChannelBuffers.wrappedBuffer(buf, 0, readBytes);
267 }
268 return buffer;
269 }
270
271 private static final class OutboundConnectionHandler extends SimpleChannelUpstreamHandler {
272
273 private final ServletOutputStream out;
274
275 public OutboundConnectionHandler(ServletOutputStream out) {
276 this.out = out;
277 }
278
279 @Override
280 public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
281 ChannelBuffer buffer = (ChannelBuffer) e.getMessage();
282 synchronized (this) {
283 buffer.readBytes(out, buffer.readableBytes());
284 out.flush();
285 }
286 }
287
288 @Override
289 public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
290 if (logger.isWarnEnabled()) {
291 logger.warn("Unexpected exception while HTTP tunneling", e.getCause());
292 }
293 e.getChannel().close();
294 }
295 }
296 }