1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.example.http.websocketx.server;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelPipeline;
24 import io.netty.channel.SimpleChannelInboundHandler;
25 import io.netty.handler.codec.http.DefaultFullHttpResponse;
26 import io.netty.handler.codec.http.FullHttpRequest;
27 import io.netty.handler.codec.http.FullHttpResponse;
28 import io.netty.handler.codec.http.HttpHeaderNames;
29 import io.netty.handler.codec.http.HttpRequest;
30 import io.netty.handler.codec.http.HttpResponseStatus;
31 import io.netty.handler.codec.http.HttpUtil;
32 import io.netty.handler.ssl.SslHandler;
33
34 import static io.netty.handler.codec.http.HttpHeaderNames.*;
35 import static io.netty.handler.codec.http.HttpMethod.*;
36 import static io.netty.handler.codec.http.HttpResponseStatus.*;
37
38
39
40
41 public class WebSocketIndexPageHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
42
43 private final String websocketPath;
44
45 public WebSocketIndexPageHandler(String websocketPath) {
46 this.websocketPath = websocketPath;
47 }
48
49 @Override
50 protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception {
51
52 if (!req.decoderResult().isSuccess()) {
53 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), BAD_REQUEST,
54 ctx.alloc().buffer(0)));
55 return;
56 }
57
58
59 if (!GET.equals(req.method())) {
60 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), FORBIDDEN,
61 ctx.alloc().buffer(0)));
62 return;
63 }
64
65
66 if ("/".equals(req.uri()) || "/index.html".equals(req.uri())) {
67 String webSocketLocation = getWebSocketLocation(ctx.pipeline(), req, websocketPath);
68 ByteBuf content = WebSocketServerIndexPage.getContent(webSocketLocation);
69 FullHttpResponse res = new DefaultFullHttpResponse(req.protocolVersion(), OK, content);
70
71 res.headers().set(CONTENT_TYPE, "text/html; charset=UTF-8");
72 HttpUtil.setContentLength(res, content.readableBytes());
73
74 sendHttpResponse(ctx, req, res);
75 } else {
76 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(req.protocolVersion(), NOT_FOUND,
77 ctx.alloc().buffer(0)));
78 }
79 }
80
81 @Override
82 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
83 cause.printStackTrace();
84 ctx.close();
85 }
86
87 private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
88
89 HttpResponseStatus responseStatus = res.status();
90 if (responseStatus.code() != 200) {
91 ByteBufUtil.writeUtf8(res.content(), responseStatus.toString());
92 HttpUtil.setContentLength(res, res.content().readableBytes());
93 }
94
95 boolean keepAlive = HttpUtil.isKeepAlive(req) && responseStatus.code() == 200;
96 HttpUtil.setKeepAlive(res, keepAlive);
97 ChannelFuture future = ctx.writeAndFlush(res);
98 if (!keepAlive) {
99 future.addListener(ChannelFutureListener.CLOSE);
100 }
101 }
102
103 private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
104 String protocol = "ws";
105 if (cp.get(SslHandler.class) != null) {
106
107 protocol = "wss";
108 }
109 return protocol + "://" + req.headers().get(HttpHeaderNames.HOST) + path;
110 }
111 }