View Javadoc
1   /*
2    * Copyright 2020 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.example.stomp.websocket;
17  
18  import io.netty.channel.ChannelFuture;
19  import io.netty.channel.ChannelFutureListener;
20  import io.netty.channel.ChannelHandler.Sharable;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.SimpleChannelInboundHandler;
23  import io.netty.handler.codec.DecoderResult;
24  import io.netty.handler.codec.stomp.DefaultStompFrame;
25  import io.netty.handler.codec.stomp.StompCommand;
26  import io.netty.handler.codec.stomp.StompFrame;
27  import io.netty.util.CharsetUtil;
28  
29  import java.util.HashSet;
30  import java.util.Iterator;
31  import java.util.Map.Entry;
32  import java.util.Set;
33  import java.util.UUID;
34  import java.util.concurrent.ConcurrentHashMap;
35  import java.util.concurrent.ConcurrentMap;
36  
37  import static io.netty.handler.codec.stomp.StompHeaders.*;
38  
39  @Sharable
40  public class StompChatHandler extends SimpleChannelInboundHandler<StompFrame> {
41  
42      private final ConcurrentMap<String, Set<StompSubscription>> chatDestinations =
43              new ConcurrentHashMap<String, Set<StompSubscription>>();
44  
45      @Override
46      protected void channelRead0(ChannelHandlerContext ctx, StompFrame inboundFrame) throws Exception {
47          DecoderResult decoderResult = inboundFrame.decoderResult();
48          if (decoderResult.isFailure()) {
49              sendErrorFrame("rejected frame", decoderResult.toString(), ctx);
50              return;
51          }
52  
53          switch (inboundFrame.command()) {
54          case STOMP:
55          case CONNECT:
56              onConnect(ctx, inboundFrame);
57              break;
58          case SUBSCRIBE:
59              onSubscribe(ctx, inboundFrame);
60              break;
61          case SEND:
62              onSend(ctx, inboundFrame);
63              break;
64          case UNSUBSCRIBE:
65              onUnsubscribe(ctx, inboundFrame);
66              break;
67          case DISCONNECT:
68              onDisconnect(ctx, inboundFrame);
69              break;
70          default:
71              sendErrorFrame("unsupported command",
72                             "Received unsupported command " + inboundFrame.command(), ctx);
73          }
74      }
75  
76      private void onSubscribe(ChannelHandlerContext ctx, StompFrame inboundFrame) {
77          String destination = inboundFrame.headers().getAsString(DESTINATION);
78          String subscriptionId = inboundFrame.headers().getAsString(ID);
79  
80          if (destination == null || subscriptionId == null) {
81              sendErrorFrame("missed header", "Required 'destination' or 'id' header missed", ctx);
82              return;
83          }
84  
85          Set<StompSubscription> subscriptions = chatDestinations.get(destination);
86          if (subscriptions == null) {
87              subscriptions = new HashSet<StompSubscription>();
88              Set<StompSubscription> previousSubscriptions = chatDestinations.putIfAbsent(destination, subscriptions);
89              if (previousSubscriptions != null) {
90                  subscriptions = previousSubscriptions;
91              }
92          }
93  
94          final StompSubscription subscription = new StompSubscription(subscriptionId, destination, ctx.channel());
95          if (subscriptions.contains(subscription)) {
96              sendErrorFrame("duplicate subscription",
97                             "Received duplicate subscription id=" + subscriptionId, ctx);
98              return;
99          }
100 
101         subscriptions.add(subscription);
102         ctx.channel().closeFuture().addListener(f ->
103                 chatDestinations.get(subscription.destination()).remove(subscription));
104 
105         String receiptId = inboundFrame.headers().getAsString(RECEIPT);
106         if (receiptId != null) {
107             StompFrame receiptFrame = new DefaultStompFrame(StompCommand.RECEIPT);
108             receiptFrame.headers().set(RECEIPT_ID, receiptId);
109             ctx.writeAndFlush(receiptFrame);
110         }
111     }
112 
113     private void onSend(ChannelHandlerContext ctx, StompFrame inboundFrame) {
114         String destination = inboundFrame.headers().getAsString(DESTINATION);
115         if (destination == null) {
116             sendErrorFrame("missed header", "required 'destination' header missed", ctx);
117             return;
118         }
119 
120         Set<StompSubscription> subscriptions = chatDestinations.get(destination);
121         for (StompSubscription subscription : subscriptions) {
122             subscription.channel().writeAndFlush(transformToMessage(inboundFrame, subscription));
123         }
124     }
125 
126     private void onUnsubscribe(ChannelHandlerContext ctx, StompFrame inboundFrame) {
127         String subscriptionId = inboundFrame.headers().getAsString(SUBSCRIPTION);
128         for (Entry<String, Set<StompSubscription>> entry : chatDestinations.entrySet()) {
129             Iterator<StompSubscription> iterator = entry.getValue().iterator();
130             while (iterator.hasNext()) {
131                 StompSubscription subscription = iterator.next();
132                 if (subscription.id().equals(subscriptionId) && subscription.channel().equals(ctx.channel())) {
133                     iterator.remove();
134                     return;
135                 }
136             }
137         }
138     }
139 
140     private static void onConnect(ChannelHandlerContext ctx, StompFrame inboundFrame) {
141         String acceptVersions = inboundFrame.headers().getAsString(ACCEPT_VERSION);
142         StompVersion handshakeAcceptVersion = ctx.channel().attr(StompVersion.CHANNEL_ATTRIBUTE_KEY).get();
143         if (acceptVersions == null || !acceptVersions.contains(handshakeAcceptVersion.version())) {
144             sendErrorFrame("invalid version",
145                            "Received invalid version, expected " + handshakeAcceptVersion.version(), ctx);
146             return;
147         }
148 
149         StompFrame connectedFrame = new DefaultStompFrame(StompCommand.CONNECTED);
150         connectedFrame.headers()
151                       .set(VERSION, handshakeAcceptVersion.version())
152                       .set(SERVER, "Netty-Server")
153                       .set(HEART_BEAT, "0,0");
154         ctx.writeAndFlush(connectedFrame);
155     }
156 
157     private static void onDisconnect(ChannelHandlerContext ctx, StompFrame inboundFrame) {
158         String receiptId = inboundFrame.headers().getAsString(RECEIPT);
159         if (receiptId == null) {
160             ctx.close();
161             return;
162         }
163 
164         StompFrame receiptFrame = new DefaultStompFrame(StompCommand.RECEIPT);
165         receiptFrame.headers().set(RECEIPT_ID, receiptId);
166         ctx.writeAndFlush(receiptFrame).addListener(ChannelFutureListener.CLOSE);
167     }
168 
169     private static void sendErrorFrame(String message, String description, ChannelHandlerContext ctx) {
170         StompFrame errorFrame = new DefaultStompFrame(StompCommand.ERROR);
171         errorFrame.headers().set(MESSAGE, message);
172 
173         if (description != null) {
174             errorFrame.content().writeCharSequence(description, CharsetUtil.UTF_8);
175         }
176 
177         ctx.writeAndFlush(errorFrame).addListener(ChannelFutureListener.CLOSE);
178     }
179 
180     private static StompFrame transformToMessage(StompFrame sendFrame, StompSubscription subscription) {
181         StompFrame messageFrame = new DefaultStompFrame(StompCommand.MESSAGE, sendFrame.content().retainedDuplicate());
182         String id = UUID.randomUUID().toString();
183         messageFrame.headers()
184                     .set(MESSAGE_ID, id)
185                     .set(SUBSCRIPTION, subscription.id())
186                     .set(CONTENT_LENGTH, Integer.toString(messageFrame.content().readableBytes()));
187 
188         CharSequence contentType = sendFrame.headers().get(CONTENT_TYPE);
189         if (contentType != null) {
190             messageFrame.headers().set(CONTENT_TYPE, contentType);
191         }
192 
193         return messageFrame;
194     }
195 }