View Javadoc
1   /*
2    * Copyright 2020 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.example.stomp.websocket;
17  
18  import io.netty.channel.ChannelFuture;
19  import io.netty.channel.ChannelFutureListener;
20  import io.netty.channel.ChannelHandler.Sharable;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.SimpleChannelInboundHandler;
23  import io.netty.handler.codec.DecoderResult;
24  import io.netty.handler.codec.stomp.DefaultStompFrame;
25  import io.netty.handler.codec.stomp.StompCommand;
26  import io.netty.handler.codec.stomp.StompFrame;
27  import io.netty.util.CharsetUtil;
28  
29  import java.util.HashSet;
30  import java.util.Iterator;
31  import java.util.Map.Entry;
32  import java.util.Set;
33  import java.util.UUID;
34  import java.util.concurrent.ConcurrentHashMap;
35  import java.util.concurrent.ConcurrentMap;
36  
37  import static io.netty.handler.codec.stomp.StompHeaders.*;
38  
39  @Sharable
40  public class StompChatHandler extends SimpleChannelInboundHandler<StompFrame> {
41  
42      private final ConcurrentMap<String, Set<StompSubscription>> chatDestinations =
43              new ConcurrentHashMap<String, Set<StompSubscription>>();
44  
45      @Override
46      protected void channelRead0(ChannelHandlerContext ctx, StompFrame inboundFrame) throws Exception {
47          DecoderResult decoderResult = inboundFrame.decoderResult();
48          if (decoderResult.isFailure()) {
49              sendErrorFrame("rejected frame", decoderResult.toString(), ctx);
50              return;
51          }
52  
53          switch (inboundFrame.command()) {
54          case STOMP:
55          case CONNECT:
56              onConnect(ctx, inboundFrame);
57              break;
58          case SUBSCRIBE:
59              onSubscribe(ctx, inboundFrame);
60              break;
61          case SEND:
62              onSend(ctx, inboundFrame);
63              break;
64          case UNSUBSCRIBE:
65              onUnsubscribe(ctx, inboundFrame);
66              break;
67          case DISCONNECT:
68              onDisconnect(ctx, inboundFrame);
69              break;
70          default:
71              sendErrorFrame("unsupported command",
72                             "Received unsupported command " + inboundFrame.command(), ctx);
73          }
74      }
75  
76      private void onSubscribe(ChannelHandlerContext ctx, StompFrame inboundFrame) {
77          String destination = inboundFrame.headers().getAsString(DESTINATION);
78          String subscriptionId = inboundFrame.headers().getAsString(ID);
79  
80          if (destination == null || subscriptionId == null) {
81              sendErrorFrame("missed header", "Required 'destination' or 'id' header missed", ctx);
82              return;
83          }
84  
85          Set<StompSubscription> subscriptions = chatDestinations.get(destination);
86          if (subscriptions == null) {
87              subscriptions = new HashSet<StompSubscription>();
88              Set<StompSubscription> previousSubscriptions = chatDestinations.putIfAbsent(destination, subscriptions);
89              if (previousSubscriptions != null) {
90                  subscriptions = previousSubscriptions;
91              }
92          }
93  
94          final StompSubscription subscription = new StompSubscription(subscriptionId, destination, ctx.channel());
95          if (subscriptions.contains(subscription)) {
96              sendErrorFrame("duplicate subscription",
97                             "Received duplicate subscription id=" + subscriptionId, ctx);
98              return;
99          }
100 
101         subscriptions.add(subscription);
102         ctx.channel().closeFuture().addListener(new ChannelFutureListener() {
103             @Override
104             public void operationComplete(ChannelFuture future) {
105                 chatDestinations.get(subscription.destination()).remove(subscription);
106             }
107         });
108 
109         String receiptId = inboundFrame.headers().getAsString(RECEIPT);
110         if (receiptId != null) {
111             StompFrame receiptFrame = new DefaultStompFrame(StompCommand.RECEIPT);
112             receiptFrame.headers().set(RECEIPT_ID, receiptId);
113             ctx.writeAndFlush(receiptFrame);
114         }
115     }
116 
117     private void onSend(ChannelHandlerContext ctx, StompFrame inboundFrame) {
118         String destination = inboundFrame.headers().getAsString(DESTINATION);
119         if (destination == null) {
120             sendErrorFrame("missed header", "required 'destination' header missed", ctx);
121             return;
122         }
123 
124         Set<StompSubscription> subscriptions = chatDestinations.get(destination);
125         for (StompSubscription subscription : subscriptions) {
126             subscription.channel().writeAndFlush(transformToMessage(inboundFrame, subscription));
127         }
128     }
129 
130     private void onUnsubscribe(ChannelHandlerContext ctx, StompFrame inboundFrame) {
131         String subscriptionId = inboundFrame.headers().getAsString(SUBSCRIPTION);
132         for (Entry<String, Set<StompSubscription>> entry : chatDestinations.entrySet()) {
133             Iterator<StompSubscription> iterator = entry.getValue().iterator();
134             while (iterator.hasNext()) {
135                 StompSubscription subscription = iterator.next();
136                 if (subscription.id().equals(subscriptionId) && subscription.channel().equals(ctx.channel())) {
137                     iterator.remove();
138                     return;
139                 }
140             }
141         }
142     }
143 
144     private static void onConnect(ChannelHandlerContext ctx, StompFrame inboundFrame) {
145         String acceptVersions = inboundFrame.headers().getAsString(ACCEPT_VERSION);
146         StompVersion handshakeAcceptVersion = ctx.channel().attr(StompVersion.CHANNEL_ATTRIBUTE_KEY).get();
147         if (acceptVersions == null || !acceptVersions.contains(handshakeAcceptVersion.version())) {
148             sendErrorFrame("invalid version",
149                            "Received invalid version, expected " + handshakeAcceptVersion.version(), ctx);
150             return;
151         }
152 
153         StompFrame connectedFrame = new DefaultStompFrame(StompCommand.CONNECTED);
154         connectedFrame.headers()
155                       .set(VERSION, handshakeAcceptVersion.version())
156                       .set(SERVER, "Netty-Server")
157                       .set(HEART_BEAT, "0,0");
158         ctx.writeAndFlush(connectedFrame);
159     }
160 
161     private static void onDisconnect(ChannelHandlerContext ctx, StompFrame inboundFrame) {
162         String receiptId = inboundFrame.headers().getAsString(RECEIPT);
163         if (receiptId == null) {
164             ctx.close();
165             return;
166         }
167 
168         StompFrame receiptFrame = new DefaultStompFrame(StompCommand.RECEIPT);
169         receiptFrame.headers().set(RECEIPT_ID, receiptId);
170         ctx.writeAndFlush(receiptFrame).addListener(ChannelFutureListener.CLOSE);
171     }
172 
173     private static void sendErrorFrame(String message, String description, ChannelHandlerContext ctx) {
174         StompFrame errorFrame = new DefaultStompFrame(StompCommand.ERROR);
175         errorFrame.headers().set(MESSAGE, message);
176 
177         if (description != null) {
178             errorFrame.content().writeCharSequence(description, CharsetUtil.UTF_8);
179         }
180 
181         ctx.writeAndFlush(errorFrame).addListener(ChannelFutureListener.CLOSE);
182     }
183 
184     private static StompFrame transformToMessage(StompFrame sendFrame, StompSubscription subscription) {
185         StompFrame messageFrame = new DefaultStompFrame(StompCommand.MESSAGE, sendFrame.content().retainedDuplicate());
186         String id = UUID.randomUUID().toString();
187         messageFrame.headers()
188                     .set(MESSAGE_ID, id)
189                     .set(SUBSCRIPTION, subscription.id())
190                     .set(CONTENT_LENGTH, Integer.toString(messageFrame.content().readableBytes()));
191 
192         CharSequence contentType = sendFrame.headers().get(CONTENT_TYPE);
193         if (contentType != null) {
194             messageFrame.headers().set(CONTENT_TYPE, contentType);
195         }
196 
197         return messageFrame;
198     }
199 }