1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.channel.socket.http;
17
18 import java.io.EOFException;
19 import java.io.IOException;
20 import java.io.PushbackInputStream;
21 import java.net.SocketAddress;
22
23 import javax.servlet.ServletConfig;
24 import javax.servlet.ServletException;
25 import javax.servlet.ServletOutputStream;
26 import javax.servlet.http.HttpServlet;
27 import javax.servlet.http.HttpServletRequest;
28 import javax.servlet.http.HttpServletResponse;
29
30 import org.jboss.netty.buffer.ChannelBuffer;
31 import org.jboss.netty.buffer.ChannelBuffers;
32 import org.jboss.netty.channel.Channel;
33 import org.jboss.netty.channel.ChannelFactory;
34 import org.jboss.netty.channel.ChannelFuture;
35 import org.jboss.netty.channel.ChannelFutureListener;
36 import org.jboss.netty.channel.ChannelHandlerContext;
37 import org.jboss.netty.channel.ChannelPipeline;
38 import org.jboss.netty.channel.Channels;
39 import org.jboss.netty.channel.ExceptionEvent;
40 import org.jboss.netty.channel.MessageEvent;
41 import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
42 import org.jboss.netty.channel.local.DefaultLocalClientChannelFactory;
43 import org.jboss.netty.channel.local.LocalAddress;
44 import org.jboss.netty.handler.codec.http.HttpHeaders;
45 import org.jboss.netty.logging.InternalLogger;
46 import org.jboss.netty.logging.InternalLoggerFactory;
47
48
49
50
51
52
53
54
55 public class HttpTunnelingServlet extends HttpServlet {
56
57 private static final long serialVersionUID = 4259910275899756070L;
58
59 private static final String ENDPOINT = "endpoint";
60
61 static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpTunnelingServlet.class);
62
63 private volatile SocketAddress remoteAddress;
64 private volatile ChannelFactory channelFactory;
65
66 @Override
67 public void init() throws ServletException {
68 ServletConfig config = getServletConfig();
69 String endpoint = config.getInitParameter(ENDPOINT);
70 if (endpoint == null) {
71 throw new ServletException("init-param '" + ENDPOINT + "' must be specified.");
72 }
73
74 try {
75 remoteAddress = parseEndpoint(endpoint.trim());
76 } catch (ServletException e) {
77 throw e;
78 } catch (Exception e) {
79 throw new ServletException("Failed to parse an endpoint.", e);
80 }
81
82 try {
83 channelFactory = createChannelFactory(remoteAddress);
84 } catch (ServletException e) {
85 throw e;
86 } catch (Exception e) {
87 throw new ServletException("Failed to create a channel factory.", e);
88 }
89
90
91
92
93
94
95 }
96
97 protected SocketAddress parseEndpoint(String endpoint) throws Exception {
98 if (endpoint.startsWith("local:")) {
99 return new LocalAddress(endpoint.substring(6).trim());
100 } else {
101 throw new ServletException(
102 "Invalid or unknown endpoint: " + endpoint);
103 }
104 }
105
106 protected ChannelFactory createChannelFactory(SocketAddress remoteAddress) throws Exception {
107 if (remoteAddress instanceof LocalAddress) {
108 return new DefaultLocalClientChannelFactory();
109 } else {
110 throw new ServletException(
111 "Unsupported remote address type: " +
112 remoteAddress.getClass().getName());
113 }
114 }
115
116 @Override
117 public void destroy() {
118 try {
119 destroyChannelFactory(channelFactory);
120 } catch (Exception e) {
121 if (logger.isWarnEnabled()) {
122 logger.warn("Failed to destroy a channel factory.", e);
123 }
124 }
125 }
126
127 protected void destroyChannelFactory(ChannelFactory factory) throws Exception {
128 factory.releaseExternalResources();
129 }
130
131 @Override
132 protected void service(HttpServletRequest req, HttpServletResponse res)
133 throws ServletException, IOException {
134 if (!"POST".equalsIgnoreCase(req.getMethod())) {
135 if (logger.isWarnEnabled()) {
136 logger.warn("Unallowed method: " + req.getMethod());
137 }
138 res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
139 return;
140 }
141
142 final ChannelPipeline pipeline = Channels.pipeline();
143 final ServletOutputStream out = res.getOutputStream();
144 final OutboundConnectionHandler handler = new OutboundConnectionHandler(out);
145 pipeline.addLast("handler", handler);
146
147 Channel channel = channelFactory.newChannel(pipeline);
148 ChannelFuture future = channel.connect(remoteAddress).awaitUninterruptibly();
149 if (!future.isSuccess()) {
150 if (logger.isWarnEnabled()) {
151 Throwable cause = future.getCause();
152 logger.warn("Endpoint unavailable: " + cause.getMessage(), cause);
153 }
154 res.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
155 return;
156 }
157
158 ChannelFuture lastWriteFuture = null;
159 try {
160 res.setStatus(HttpServletResponse.SC_OK);
161 res.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/octet-stream");
162 res.setHeader(HttpHeaders.Names.CONTENT_TRANSFER_ENCODING, HttpHeaders.Values.BINARY);
163
164
165 out.flush();
166
167 PushbackInputStream in =
168 new PushbackInputStream(req.getInputStream());
169 while (channel.isConnected()) {
170 ChannelBuffer buffer;
171 try {
172 buffer = read(in);
173 } catch (EOFException e) {
174 break;
175 }
176 if (buffer == null) {
177 break;
178 }
179 lastWriteFuture = channel.write(buffer);
180 }
181 } finally {
182 if (lastWriteFuture == null) {
183 channel.close();
184 } else {
185 lastWriteFuture.addListener(ChannelFutureListener.CLOSE);
186 }
187 }
188 }
189
190 private static ChannelBuffer read(PushbackInputStream in) throws IOException {
191 byte[] buf;
192 int readBytes;
193
194 int bytesToRead = in.available();
195 if (bytesToRead > 0) {
196 buf = new byte[bytesToRead];
197 readBytes = in.read(buf);
198 } else if (bytesToRead == 0) {
199 int b = in.read();
200 if (b < 0 || in.available() < 0) {
201 return null;
202 }
203 in.unread(b);
204 bytesToRead = in.available();
205 buf = new byte[bytesToRead];
206 readBytes = in.read(buf);
207 } else {
208 return null;
209 }
210
211 assert readBytes > 0;
212
213 ChannelBuffer buffer;
214 if (readBytes == buf.length) {
215 buffer = ChannelBuffers.wrappedBuffer(buf);
216 } else {
217
218 buffer = ChannelBuffers.wrappedBuffer(buf, 0, readBytes);
219 }
220 return buffer;
221 }
222
223 private static final class OutboundConnectionHandler extends SimpleChannelUpstreamHandler {
224
225 private final ServletOutputStream out;
226
227 public OutboundConnectionHandler(ServletOutputStream out) {
228 this.out = out;
229 }
230
231 @Override
232 public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
233 ChannelBuffer buffer = (ChannelBuffer) e.getMessage();
234 synchronized (this) {
235 buffer.readBytes(out, buffer.readableBytes());
236 out.flush();
237 }
238 }
239
240 @Override
241 public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
242 if (logger.isWarnEnabled()) {
243 logger.warn("Unexpected exception while HTTP tunneling", e.getCause());
244 }
245 e.getChannel().close();
246 }
247 }
248 }