View Javadoc
1   /*
2    * Copyright 2020 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.quic;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.ChannelDuplexHandler;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelPromise;
22  import io.netty.channel.MessageSizeEstimator;
23  import io.netty.channel.socket.DatagramPacket;
24  import io.netty.util.internal.logging.InternalLogger;
25  import io.netty.util.internal.logging.InternalLoggerFactory;
26  import org.jetbrains.annotations.Nullable;
27  
28  import java.net.InetSocketAddress;
29  import java.net.SocketAddress;
30  import java.nio.ByteBuffer;
31  import java.util.ArrayDeque;
32  import java.util.HashSet;
33  import java.util.Queue;
34  import java.util.Set;
35  import java.util.function.Consumer;
36  
37  import static io.netty.handler.codec.quic.Quiche.allocateNativeOrder;
38  
39  /**
40   * Abstract base class for QUIC codecs.
41   */
42  abstract class QuicheQuicCodec extends ChannelDuplexHandler {
43      private static final InternalLogger LOGGER = InternalLoggerFactory.getInstance(QuicheQuicCodec.class);
44      private final ConnectionIdChannelMap connectionIdToChannel = new ConnectionIdChannelMap();
45      private final Set<QuicheQuicChannel> channels = new HashSet<>();
46      private final Queue<QuicheQuicChannel> needsFireChannelReadComplete = new ArrayDeque<>();
47      private final Queue<QuicheQuicChannel> delayedRemoval = new ArrayDeque<>();
48  
49      private final Consumer<QuicheQuicChannel> freeTask = this::removeChannel;
50      private final FlushStrategy flushStrategy;
51      private final int localConnIdLength;
52      private final QuicheConfig config;
53  
54      private MessageSizeEstimator.Handle estimatorHandle;
55      private QuicHeaderParser headerParser;
56      private QuicHeaderParser.QuicHeaderProcessor parserCallback;
57      private int pendingBytes;
58      private int pendingPackets;
59      private boolean inChannelReadComplete;
60      private boolean delayRemoval;
61  
62      // This buffer is used to copy InetSocketAddress to sockaddr_storage and so pass it down the JNI layer.
63      private ByteBuf senderSockaddrMemory;
64      private ByteBuf recipientSockaddrMemory;
65  
66      QuicheQuicCodec(QuicheConfig config, int localConnIdLength, FlushStrategy flushStrategy) {
67          this.config = config;
68          this.localConnIdLength = localConnIdLength;
69          this.flushStrategy = flushStrategy;
70      }
71  
72      @Override
73      public final boolean isSharable() {
74          return false;
75      }
76  
77      @Nullable
78      protected final QuicheQuicChannel getChannel(ByteBuffer key) {
79          return connectionIdToChannel.get(key);
80      }
81  
82      private void addMapping(QuicheQuicChannel channel, ByteBuffer id) {
83          QuicheQuicChannel ch = connectionIdToChannel.put(id, channel);
84          assert ch == null;
85      }
86  
87      private void removeMapping(QuicheQuicChannel channel, ByteBuffer id) {
88          QuicheQuicChannel ch = connectionIdToChannel.remove(id);
89          assert ch == channel;
90      }
91  
92      private void processDelayedRemoval() {
93          for (;;) {
94              // Now remove all channels that we marked for removal.
95              QuicheQuicChannel toBeRemoved = delayedRemoval.poll();
96              if (toBeRemoved == null) {
97                  break;
98              }
99              removeChannel(toBeRemoved);
100         }
101     }
102 
103     private void removeChannel(QuicheQuicChannel channel) {
104         if (delayRemoval) {
105             boolean added = delayedRemoval.offer(channel);
106             assert added;
107         } else {
108             boolean removed = channels.remove(channel);
109             if (removed) {
110                 for (ByteBuffer id : channel.sourceConnectionIds()) {
111                     QuicheQuicChannel ch = connectionIdToChannel.remove(id);
112                     assert ch == channel;
113                 }
114             }
115         }
116     }
117 
118     protected final void addChannel(QuicheQuicChannel channel) {
119         boolean added = channels.add(channel);
120         assert added;
121         for (ByteBuffer id : channel.sourceConnectionIds()) {
122             QuicheQuicChannel ch = connectionIdToChannel.put(id.duplicate(), channel);
123             assert ch == null;
124         }
125     }
126 
127     @Override
128     public final void handlerAdded(ChannelHandlerContext ctx) {
129         senderSockaddrMemory = allocateNativeOrder(Quiche.SIZEOF_SOCKADDR_STORAGE);
130         recipientSockaddrMemory = allocateNativeOrder(Quiche.SIZEOF_SOCKADDR_STORAGE);
131         headerParser = new QuicHeaderParser(localConnIdLength);
132         parserCallback = new QuicCodecHeaderProcessor(ctx);
133         estimatorHandle = ctx.channel().config().getMessageSizeEstimator().newHandle();
134         handlerAdded(ctx, localConnIdLength);
135     }
136 
137     /**
138      * See {@link io.netty.channel.ChannelHandler#handlerAdded(ChannelHandlerContext)}.
139      */
140     protected void handlerAdded(ChannelHandlerContext ctx, int localConnIdLength) {
141         // NOOP.
142     }
143 
144     @Override
145     public void handlerRemoved(ChannelHandlerContext ctx) {
146         try {
147             // Use a copy of the array as closing the channel may cause an unwritable event that could also
148             // remove channels.
149             for (QuicheQuicChannel ch : channels.toArray(new QuicheQuicChannel[0])) {
150                 ch.forceClose();
151             }
152             if (pendingPackets > 0) {
153                 flushNow(ctx);
154             }
155         } finally {
156             channels.clear();
157             connectionIdToChannel.clear();
158             needsFireChannelReadComplete.clear();
159             delayedRemoval.clear();
160 
161             config.free();
162             if (senderSockaddrMemory != null) {
163                 senderSockaddrMemory.release();
164             }
165             if (recipientSockaddrMemory != null) {
166                 recipientSockaddrMemory.release();
167             }
168             if (headerParser != null) {
169                 headerParser.close();
170                 headerParser = null;
171             }
172         }
173     }
174 
175     @Override
176     public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
177         DatagramPacket packet = (DatagramPacket) msg;
178         try {
179             ByteBuf buffer = ((DatagramPacket) msg).content();
180             if (!buffer.isDirect()) {
181                 // We need a direct buffer as otherwise we can not access the memoryAddress.
182                 // Let's do a copy to direct memory.
183                 ByteBuf direct = ctx.alloc().directBuffer(buffer.readableBytes());
184                 try {
185                     direct.writeBytes(buffer, buffer.readerIndex(), buffer.readableBytes());
186                     handleQuicPacket(packet.sender(), packet.recipient(), direct);
187                 } finally {
188                     direct.release();
189                 }
190             } else {
191                 handleQuicPacket(packet.sender(), packet.recipient(), buffer);
192             }
193         } finally {
194             packet.release();
195         }
196     }
197 
198     private void handleQuicPacket(InetSocketAddress sender, InetSocketAddress recipient, ByteBuf buffer) {
199         try {
200             headerParser.parse(sender, recipient, buffer, parserCallback);
201         } catch (Exception e) {
202             LOGGER.debug("Error while processing QUIC packet", e);
203         }
204     }
205 
206     /**
207      * Handle a QUIC packet and return {@link QuicheQuicChannel} that is mapped to the id.
208      *
209      * @param ctx the {@link ChannelHandlerContext}.
210      * @param sender the {@link InetSocketAddress} of the sender of the QUIC packet
211      * @param recipient the {@link InetSocketAddress} of the recipient of the QUIC packet
212      * @param type the type of the packet.
213      * @param version the QUIC version
214      * @param scid the source connection id.
215      * @param dcid the destination connection id
216      * @param token the token
217      * @param senderSockaddrMemory the {@link ByteBuf} that can be used for the sender {@code struct sockaddr).
218      * @param recipientSockaddrMemory the {@link ByteBuf} that can be used for the recipient {@code struct sockaddr).
219      * @param freeTask the {@link Consumer} that will be called once native memory of the {@link QuicheQuicChannel} is
220      *                  freed and so the mappings should be deleted to the ids.
221      * @param localConnIdLength the length of the local connection ids.
222      * @param config the {@link QuicheConfig} that is used.
223      * @return the {@link QuicheQuicChannel} that is mapped to the id.
224      * @throws Exception  thrown if there is an error during processing.
225      */
226     @Nullable
227     protected abstract QuicheQuicChannel quicPacketRead(ChannelHandlerContext ctx, InetSocketAddress sender,
228                                                         InetSocketAddress recipient, QuicPacketType type, long version,
229                                                         ByteBuf scid, ByteBuf dcid, ByteBuf token,
230                                                         ByteBuf senderSockaddrMemory, ByteBuf recipientSockaddrMemory,
231                                                         Consumer<QuicheQuicChannel> freeTask,
232                                                         int localConnIdLength, QuicheConfig config) throws Exception;
233 
234     @Override
235     public final void channelReadComplete(ChannelHandlerContext ctx) {
236         inChannelReadComplete = true;
237         try {
238             for (;;) {
239                 QuicheQuicChannel channel = needsFireChannelReadComplete.poll();
240                 if (channel == null) {
241                     break;
242                 }
243                 channel.recvComplete();
244             }
245         } finally {
246             inChannelReadComplete = false;
247             if (pendingPackets > 0) {
248                 flushNow(ctx);
249             }
250         }
251     }
252 
253     @Override
254     public final void channelWritabilityChanged(ChannelHandlerContext ctx) {
255         if (ctx.channel().isWritable()) {
256             // Ensure we delay removal from the channels Set as otherwise we will might see an exception
257             // due modifications while iteration.
258             delayRemoval = true;
259             try {
260                 for (QuicheQuicChannel channel : channels) {
261                     // TODO: Be a bit smarter about this.
262                     channel.writable();
263                 }
264             } finally {
265                 // We are done with the loop, reset the flag and process the removals from the channels Set.
266                 delayRemoval = false;
267                 processDelayedRemoval();
268             }
269         } else {
270             // As we batch flushes we need to ensure we at least try to flush a batch once the channel becomes
271             // unwritable. Otherwise we may end up with buffering too much writes and so waste memory.
272             ctx.flush();
273         }
274 
275         ctx.fireChannelWritabilityChanged();
276     }
277 
278     @Override
279     public final void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)  {
280         pendingPackets ++;
281         int size = estimatorHandle.size(msg);
282         if (size > 0) {
283             pendingBytes += size;
284         }
285         try {
286             ctx.write(msg, promise);
287         } finally {
288             flushIfNeeded(ctx);
289         }
290     }
291 
292     @Override
293     public final void flush(ChannelHandlerContext ctx) {
294         // If we are in the channelReadComplete(...) method we might be able to delay the flush(...) until we finish
295         // processing all channels.
296         if (inChannelReadComplete) {
297             flushIfNeeded(ctx);
298         } else if (pendingPackets > 0) {
299             flushNow(ctx);
300         }
301     }
302 
303     @Override
304     public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
305                         ChannelPromise promise) throws Exception {
306         if (remoteAddress instanceof QuicheQuicChannelAddress) {
307             QuicheQuicChannelAddress addr = (QuicheQuicChannelAddress) remoteAddress;
308             QuicheQuicChannel channel = addr.channel;
309             connectQuicChannel(channel, remoteAddress, localAddress,
310                     senderSockaddrMemory, recipientSockaddrMemory, freeTask, localConnIdLength, config, promise);
311         } else {
312             ctx.connect(remoteAddress, localAddress, promise);
313         }
314     }
315 
316     /**
317      * Connects the given {@link QuicheQuicChannel}.
318      *
319      * @param channel                   the {@link QuicheQuicChannel} to connect.
320      * @param remoteAddress             the remote {@link SocketAddress}.
321      * @param localAddress              the local  {@link SocketAddress}
322      * @param senderSockaddrMemory      the {@link ByteBuf} that can be used for the sender {@code struct sockaddr).
323      * @param recipientSockaddrMemory   the {@link ByteBuf} that can be used for the recipient {@code struct sockaddr).
324      * @param freeTask                  the {@link Consumer} that will be called once native memory of the
325      *                                  {@link QuicheQuicChannel} is freed and so the mappings should be deleted to
326      *                                  the ids.
327      * @param localConnIdLength         the length of the local connection ids.
328      * @param config                    the {@link QuicheConfig} that is used.
329      * @param promise                   the {@link ChannelPromise} to notify once the connect is done.
330      */
331     protected abstract void connectQuicChannel(QuicheQuicChannel channel, SocketAddress remoteAddress,
332                                                SocketAddress localAddress, ByteBuf senderSockaddrMemory,
333                                                ByteBuf recipientSockaddrMemory, Consumer<QuicheQuicChannel> freeTask,
334                                                int localConnIdLength, QuicheConfig config, ChannelPromise promise);
335 
336     private void flushIfNeeded(ChannelHandlerContext ctx) {
337         // Check if we should force a flush() and so ensure the packets are delivered in a timely
338         // manner and also make room in the outboundbuffer again that belongs to the underlying channel.
339         if (flushStrategy.shouldFlushNow(pendingPackets, pendingBytes)) {
340             flushNow(ctx);
341         }
342     }
343 
344     private void flushNow(ChannelHandlerContext ctx) {
345         pendingBytes = 0;
346         pendingPackets = 0;
347         ctx.flush();
348     }
349 
350     private final class QuicCodecHeaderProcessor implements QuicHeaderParser.QuicHeaderProcessor {
351 
352         private final ChannelHandlerContext ctx;
353 
354         QuicCodecHeaderProcessor(ChannelHandlerContext ctx) {
355             this.ctx = ctx;
356         }
357 
358         @Override
359         public void process(InetSocketAddress sender, InetSocketAddress recipient, ByteBuf buffer, QuicPacketType type,
360                             long version, ByteBuf scid, ByteBuf dcid, ByteBuf token) throws Exception {
361             QuicheQuicChannel channel = quicPacketRead(ctx, sender, recipient,
362                     type, version, scid,
363                     dcid, token, senderSockaddrMemory, recipientSockaddrMemory, freeTask, localConnIdLength, config);
364             if (channel != null) {
365                 // Add to queue first, we might be able to safe some flushes and consolidate them
366                 // in channelReadComplete(...) this way.
367                 if (channel.markInFireChannelReadCompleteQueue()) {
368                     needsFireChannelReadComplete.add(channel);
369                 }
370                 channel.recv(sender, recipient, buffer);
371                 for (ByteBuffer retiredSourceConnectionId : channel.retiredSourceConnectionId()) {
372                     removeMapping(channel, retiredSourceConnectionId);
373                 }
374                 for (ByteBuffer newSourceConnectionId : channel.newSourceConnectionIds()) {
375                     addMapping(channel, newSourceConnectionId);
376                 }
377             }
378         }
379     }
380 }