1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.example.http.websocketx.benchmarkserver;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.Unpooled;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.SimpleChannelInboundHandler;
24 import io.netty.handler.codec.http.DefaultFullHttpResponse;
25 import io.netty.handler.codec.http.FullHttpRequest;
26 import io.netty.handler.codec.http.FullHttpResponse;
27 import io.netty.handler.codec.http.HttpHeaders;
28 import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
29 import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
30 import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
31 import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
32 import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
33 import io.netty.handler.codec.http.websocketx.WebSocketFrame;
34 import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
35 import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
36 import io.netty.util.CharsetUtil;
37
38 import static io.netty.handler.codec.http.HttpHeaders.Names.*;
39 import static io.netty.handler.codec.http.HttpMethod.*;
40 import static io.netty.handler.codec.http.HttpResponseStatus.*;
41 import static io.netty.handler.codec.http.HttpVersion.*;
42
43
44
45
46 public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
47
48 private static final String WEBSOCKET_PATH = "/websocket";
49
50 private WebSocketServerHandshaker handshaker;
51
52 @Override
53 public void channelRead0(ChannelHandlerContext ctx, Object msg) {
54 if (msg instanceof FullHttpRequest) {
55 handleHttpRequest(ctx, (FullHttpRequest) msg);
56 } else if (msg instanceof WebSocketFrame) {
57 handleWebSocketFrame(ctx, (WebSocketFrame) msg);
58 }
59 }
60
61 @Override
62 public void channelReadComplete(ChannelHandlerContext ctx) {
63 ctx.flush();
64 }
65
66 private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
67
68 if (!req.getDecoderResult().isSuccess()) {
69 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST));
70 return;
71 }
72
73
74 if (req.getMethod() != GET) {
75 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
76 return;
77 }
78
79
80 if ("/".equals(req.getUri())) {
81 ByteBuf content = WebSocketServerBenchmarkPage.getContent(getWebSocketLocation(req));
82 FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, OK, content);
83
84 res.headers().set(CONTENT_TYPE, "text/html; charset=UTF-8");
85 HttpHeaders.setContentLength(res, content.readableBytes());
86
87 sendHttpResponse(ctx, req, res);
88 return;
89 }
90 if ("/favicon.ico".equals(req.getUri())) {
91 FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, NOT_FOUND);
92 sendHttpResponse(ctx, req, res);
93 return;
94 }
95
96
97 WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
98 getWebSocketLocation(req), null, true, 5 * 1024 * 1024);
99 handshaker = wsFactory.newHandshaker(req);
100 if (handshaker == null) {
101 WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
102 } else {
103 handshaker.handshake(ctx.channel(), req);
104 }
105 }
106
107 private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
108
109
110 if (frame instanceof CloseWebSocketFrame) {
111 handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
112 return;
113 }
114 if (frame instanceof PingWebSocketFrame) {
115 ctx.write(new PongWebSocketFrame(frame.content().retain()));
116 return;
117 }
118 if (frame instanceof TextWebSocketFrame) {
119
120 ctx.write(frame.retain());
121 return;
122 }
123 if (frame instanceof BinaryWebSocketFrame) {
124
125 ctx.write(frame.retain());
126 }
127 }
128
129 private static void sendHttpResponse(
130 ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
131
132 if (res.getStatus().code() != 200) {
133 ByteBuf buf = Unpooled.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8);
134 res.content().writeBytes(buf);
135 buf.release();
136 HttpHeaders.setContentLength(res, res.content().readableBytes());
137 }
138
139
140 ChannelFuture f = ctx.channel().writeAndFlush(res);
141 if (!HttpHeaders.isKeepAlive(req) || res.getStatus().code() != 200) {
142 f.addListener(ChannelFutureListener.CLOSE);
143 }
144 }
145
146 @Override
147 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
148 cause.printStackTrace();
149 ctx.close();
150 }
151
152 private static String getWebSocketLocation(FullHttpRequest req) {
153 String location = req.headers().get(HOST) + WEBSOCKET_PATH;
154 if (WebSocketServer.SSL) {
155 return "wss://" + location;
156 } else {
157 return "ws://" + location;
158 }
159 }
160 }