1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.example.http.websocketx.benchmarkserver;
17
18 import io.netty5.buffer.api.Buffer;
19 import io.netty5.channel.ChannelFutureListeners;
20 import io.netty5.channel.ChannelHandlerContext;
21 import io.netty5.channel.SimpleChannelInboundHandler;
22 import io.netty5.handler.codec.http.DefaultFullHttpResponse;
23 import io.netty5.handler.codec.http.FullHttpRequest;
24 import io.netty5.handler.codec.http.FullHttpResponse;
25 import io.netty5.handler.codec.http.HttpHeaderNames;
26 import io.netty5.handler.codec.http.HttpUtil;
27 import io.netty5.handler.codec.http.websocketx.BinaryWebSocketFrame;
28 import io.netty5.handler.codec.http.websocketx.CloseWebSocketFrame;
29 import io.netty5.handler.codec.http.websocketx.PingWebSocketFrame;
30 import io.netty5.handler.codec.http.websocketx.PongWebSocketFrame;
31 import io.netty5.handler.codec.http.websocketx.TextWebSocketFrame;
32 import io.netty5.handler.codec.http.websocketx.WebSocketFrame;
33 import io.netty5.handler.codec.http.websocketx.WebSocketServerHandshaker;
34 import io.netty5.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
35 import io.netty5.util.CharsetUtil;
36
37 import static io.netty5.handler.codec.http.HttpHeaderNames.CONNECTION;
38 import static io.netty5.handler.codec.http.HttpHeaderValues.CLOSE;
39 import static io.netty5.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
40 import static io.netty5.handler.codec.http.HttpMethod.GET;
41 import static io.netty5.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
42 import static io.netty5.handler.codec.http.HttpResponseStatus.FORBIDDEN;
43 import static io.netty5.handler.codec.http.HttpResponseStatus.NOT_FOUND;
44 import static io.netty5.handler.codec.http.HttpResponseStatus.OK;
45 import static io.netty5.handler.codec.http.HttpVersion.HTTP_1_0;
46 import static io.netty5.handler.codec.http.HttpVersion.HTTP_1_1;
47
48
49
50
51 public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
52
53 private static final String WEBSOCKET_PATH = "/websocket";
54
55 private WebSocketServerHandshaker handshaker;
56
57 @Override
58 public void messageReceived(ChannelHandlerContext ctx, Object msg) {
59 if (msg instanceof FullHttpRequest) {
60 handleHttpRequest(ctx, (FullHttpRequest) msg);
61 } else if (msg instanceof WebSocketFrame) {
62 handleWebSocketFrame(ctx, (WebSocketFrame) msg);
63 }
64 }
65
66 @Override
67 public void channelReadComplete(ChannelHandlerContext ctx) {
68 ctx.flush();
69 }
70
71 private void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
72
73 if (!req.decoderResult().isSuccess()) {
74 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, BAD_REQUEST,
75 ctx.bufferAllocator().allocate(0)));
76 return;
77 }
78
79
80 if (!GET.equals(req.method())) {
81 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN,
82 ctx.bufferAllocator().allocate(0)));
83 return;
84 }
85
86
87 if ("/".equals(req.uri())) {
88 Buffer content = WebSocketServerBenchmarkPage.getContent(ctx.bufferAllocator(), getWebSocketLocation(req));
89 FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, OK, content);
90
91 res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8");
92 HttpUtil.setContentLength(res, content.readableBytes());
93
94 sendHttpResponse(ctx, req, res);
95 return;
96 }
97 if ("/favicon.ico".equals(req.uri())) {
98 FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, NOT_FOUND, ctx.bufferAllocator().allocate(0));
99 sendHttpResponse(ctx, req, res);
100 return;
101 }
102
103
104 WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
105 getWebSocketLocation(req), null, true, 5 * 1024 * 1024);
106 handshaker = wsFactory.newHandshaker(req);
107 if (handshaker == null) {
108 WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
109 } else {
110 handshaker.handshake(ctx.channel(), req);
111 }
112 }
113
114 private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
115
116 try (frame) {
117 if (frame instanceof CloseWebSocketFrame) {
118 handshaker.close(ctx, (CloseWebSocketFrame) frame.send().receive());
119 return;
120 }
121 if (frame instanceof PingWebSocketFrame) {
122 ctx.write(new PongWebSocketFrame(frame.binaryData().send()));
123 return;
124 }
125 if (frame instanceof TextWebSocketFrame) {
126
127 ctx.write(frame.send().receive());
128 return;
129 }
130 if (frame instanceof BinaryWebSocketFrame) {
131
132 ctx.write(frame.send().receive());
133 }
134 }
135 }
136
137 private static void sendHttpResponse(
138 ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse res) {
139
140 if (res.status().code() != 200) {
141 res.payload().writeCharSequence(res.status().toString(), CharsetUtil.UTF_8);
142 HttpUtil.setContentLength(res, res.payload().readableBytes());
143 }
144
145
146 if (!HttpUtil.isKeepAlive(req) || res.status().code() != 200) {
147
148 res.headers().set(CONNECTION, CLOSE);
149 ctx.writeAndFlush(res).addListener(ctx, ChannelFutureListeners.CLOSE);
150 } else {
151 if (req.protocolVersion().equals(HTTP_1_0)) {
152 res.headers().set(CONNECTION, KEEP_ALIVE);
153 }
154 ctx.writeAndFlush(res);
155 }
156 }
157
158 @Override
159 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
160 cause.printStackTrace();
161 ctx.close();
162 }
163
164 private static String getWebSocketLocation(FullHttpRequest req) {
165 String location = req.headers().get(HttpHeaderNames.HOST) + WEBSOCKET_PATH;
166 if (WebSocketServer.SSL) {
167 return "wss://" + location;
168 } else {
169 return "ws://" + location;
170 }
171 }
172 }