1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.codec.http.websocketx;
17
import static org.jboss.netty.handler.codec.http.HttpHeaders.isKeepAlive;
import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;

import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.logging.InternalLogger;
import org.jboss.netty.logging.InternalLoggerFactory;
37
38
39
40
41 public class WebSocketServerProtocolHandshakeHandler extends SimpleChannelUpstreamHandler {
42
43 private static final InternalLogger logger =
44 InternalLoggerFactory.getInstance(WebSocketServerProtocolHandshakeHandler.class);
45 private final String websocketPath;
46 private final String subprotocols;
47 private final boolean allowExtensions;
48
49 public WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
50 boolean allowExtensions) {
51 this.websocketPath = websocketPath;
52 this.subprotocols = subprotocols;
53 this.allowExtensions = allowExtensions;
54 }
55
56 @Override
57 public void messageReceived(final ChannelHandlerContext ctx, MessageEvent e) throws Exception {
58 if (e.getMessage() instanceof HttpRequest) {
59 HttpRequest req = (HttpRequest) e.getMessage();
60 if (req.getMethod() != GET) {
61 sendHttpResponse(ctx, req, new DefaultHttpResponse(HTTP_1_1, FORBIDDEN));
62 return;
63 }
64
65 final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
66 getWebSocketLocation(ctx.getPipeline(), req, websocketPath), subprotocols, allowExtensions);
67 final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
68 if (handshaker == null) {
69 wsFactory.sendUnsupportedWebSocketVersionResponse(ctx.getChannel());
70 } else {
71 final ChannelFuture handshakeFuture = handshaker.handshake(ctx.getChannel(), req);
72 handshakeFuture.addListener(new ChannelFutureListener() {
73 public void operationComplete(ChannelFuture future) throws Exception {
74 if (!future.isSuccess()) {
75 Channels.fireExceptionCaught(ctx, future.getCause());
76 }
77 }
78 });
79 WebSocketServerProtocolHandler.setHandshaker(ctx, handshaker);
80 ctx.getPipeline().replace(this, "WS403Responder",
81 WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
82 }
83 }
84 }
85
86 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
87 logger.error("Exception Caught", cause);
88 ctx.getChannel().close();
89 }
90
91 private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
92 ChannelFuture f = ctx.getChannel().write(res);
93 if (!isKeepAlive(req) || res.getStatus().getCode() != 200) {
94 f.addListener(ChannelFutureListener.CLOSE);
95 }
96 }
97
98 private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
99 String protocol = "ws";
100 if (cp.get(SslHandler.class) != null) {
101
102 protocol = "wss";
103 }
104 return protocol + "://" + req.headers().get(HttpHeaders.Names.HOST) + path;
105 }
106
107 }