1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.example.http.websocketx.server;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.Unpooled;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelPipeline;
24 import io.netty.channel.SimpleChannelInboundHandler;
25 import io.netty.handler.codec.http.DefaultFullHttpResponse;
26 import io.netty.handler.codec.http.FullHttpRequest;
27 import io.netty.handler.codec.http.FullHttpResponse;
28 import io.netty.handler.codec.http.HttpHeaders;
29 import io.netty.handler.codec.http.HttpRequest;
30 import io.netty.handler.ssl.SslHandler;
31 import io.netty.util.CharsetUtil;
32
33 import static io.netty.handler.codec.http.HttpMethod.GET;
34 import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
35 import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
36 import static io.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
37 import static io.netty.handler.codec.http.HttpResponseStatus.OK;
38 import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
39
40
41
42
43 public class WebSocketIndexPageHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
44
45 private final String websocketPath;
46
47 public WebSocketIndexPageHandler(String websocketPath) {
48 this.websocketPath = websocketPath;
49 }
50
51 @Override
52 protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception {
53
54 if (!req.getDecoderResult().isSuccess()) {
55 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
56 return;
57 }
58
59
60 if (req.getMethod() != GET) {
61 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
62 return;
63 }
64
65
66 if ("/".equals(req.getUri()) || "/index.html".equals(req.getUri())) {
67 String webSocketLocation = getWebSocketLocation(ctx.pipeline(), req, websocketPath);
68 ByteBuf content = WebSocketServerIndexPage.getContent(webSocketLocation);
69 FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, OK, content);
70
71 res.headers().set(HttpHeaders.Names.CONTENT_TYPE, "text/html; charset=UTF-8");
72 HttpHeaders.setContentLength(res, content.readableBytes());
73
74 sendHttpResponse(ctx, req, res);
75 } else {
76 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, NOT_FOUND));
77 }
78 }
79
80 @Override
81 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
82 cause.printStackTrace();
83 ctx.close();
84 }
85
86 private static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
87
88 if (res.getStatus().code() != 200) {
89 ByteBuf buf = Unpooled.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8);
90 res.content().writeBytes(buf);
91 buf.release();
92 HttpHeaders.setContentLength(res, res.content().readableBytes());
93 }
94
95
96 ChannelFuture f = ctx.channel().writeAndFlush(res);
97 if (!HttpHeaders.isKeepAlive(req) || res.getStatus().code() != 200) {
98 f.addListener(ChannelFutureListener.CLOSE);
99 }
100 }
101
102 private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
103 String protocol = "ws";
104 if (cp.get(SslHandler.class) != null) {
105
106 protocol = "wss";
107 }
108 return protocol + "://" + req.headers().get(HttpHeaders.Names.HOST) + path;
109 }
110 }