View Javadoc
1   /*
2    * Copyright 2024 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.quic;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelInboundHandlerAdapter;
22  import io.netty.channel.socket.DatagramPacket;
23  import io.netty.util.internal.ObjectUtil;
24  import org.jetbrains.annotations.Nullable;
25  
26  import java.nio.ByteBuffer;
27  import java.util.List;
28  import java.util.concurrent.CopyOnWriteArrayList;
29  import java.util.concurrent.atomic.AtomicBoolean;
30  
31  
32  /**
33   * Special {@link io.netty.channel.ChannelHandler} that should be used to init {@link Channel}s that will be used
34   * for QUIC while <a href="https://man7.org/linux/man-pages/man7/socket.7.html">SO_REUSEPORT</a> is used to
35   * bind to same {@link java.net.InetSocketAddress} multiple times. This is necessary to ensure QUIC packets are always
36   * dispatched to the correct codec that keeps the mapping for the connection id.
37   * This implementation use a very simple mapping strategy by encoding the index of the internal datastructure that
38   * keeps track of the different {@link ChannelHandlerContext}s into the destination connection id. This way once a
39   * {@code QUIC} packet is received its possible to forward it to the right codec.
40   * Subclasses might change how encoding / decoding of the index is done by overriding {@link #decodeIndex(ByteBuf)}
41   * and {@link #newIdGenerator(int)}.
42   * <p>
43   * It is important that the same {@link QuicCodecDispatcher} instance is shared between all the {@link Channel}s that
44   * are bound to the same {@link java.net.InetSocketAddress} and use {@code SO_REUSEPORT}.
45   * <p>
46   * An alternative way to handle this would be to do the "routing" to the correct socket in an {@code epbf} program
47   * by implementing your own {@link QuicConnectionIdGenerator} that issue ids that can be understood and handled by the
48   * {@code epbf} program to route the packet to the correct socket.
49   *
50   */
51  public abstract class QuicCodecDispatcher extends ChannelInboundHandlerAdapter {
52      // 20 is the max as per RFC.
53      // See https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
54      private static final int MAX_LOCAL_CONNECTION_ID_LENGTH = 20;
55  
56      // Use a CopyOnWriteArrayList as modifications to the List should only happen during bootstrapping and teardown
57      // of the channels.
58      private final List<ChannelHandlerContextDispatcher> contextList = new CopyOnWriteArrayList<>();
59      private final int localConnectionIdLength;
60  
61      /**
62       * Create a new instance using the default connection id length.
63       */
64      protected QuicCodecDispatcher() {
65          this(MAX_LOCAL_CONNECTION_ID_LENGTH);
66      }
67  
68      /**
69       * Create a new instance
70       *
71       * @param localConnectionIdLength   the local connection id length. This must be between 10 and 20.
72       */
73      protected QuicCodecDispatcher(int localConnectionIdLength) {
74          // Let's use 10 as a minimum to ensure we still have some bytes left for randomness as we already use
75          // 2 of the bytes to encode the index.
76          this.localConnectionIdLength = ObjectUtil.checkInRange(localConnectionIdLength,
77                  10, MAX_LOCAL_CONNECTION_ID_LENGTH, "localConnectionIdLength");
78      }
79  
80      @Override
81      public final boolean isSharable() {
82          return true;
83      }
84  
85      @Override
86      public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
87          super.handlerAdded(ctx);
88  
89          ChannelHandlerContextDispatcher ctxDispatcher = new ChannelHandlerContextDispatcher(ctx);
90          contextList.add(ctxDispatcher);
91          int idx = contextList.indexOf(ctxDispatcher);
92          try {
93              QuicConnectionIdGenerator idGenerator = newIdGenerator((short) idx);
94              initChannel(ctx.channel(), localConnectionIdLength, idGenerator);
95          } catch (Exception e) {
96              // Null out on exception and rethrow. We not remove the element as the indices need to be
97              // stable.
98              contextList.set(idx, null);
99              throw e;
100         }
101     }
102 
103     @Override
104     public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
105         super.handlerRemoved(ctx);
106 
107         for (int idx = 0; idx < contextList.size(); idx++) {
108             ChannelHandlerContextDispatcher ctxDispatcher = contextList.get(idx);
109             if (ctxDispatcher != null && ctxDispatcher.ctx.equals(ctx)) {
110                 // null out, so we can collect the ChannelHandlerContext that was stored in the List.
111                 contextList.set(idx, null);
112                 break;
113             }
114         }
115     }
116 
117     @Override
118     public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
119         DatagramPacket packet = (DatagramPacket) msg;
120         ByteBuf connectionId = getDestinationConnectionId(packet.content(), localConnectionIdLength);
121         if (connectionId != null) {
122             int idx = decodeIndex(connectionId);
123             if (contextList.size() > idx) {
124                 ChannelHandlerContextDispatcher selectedCtx = contextList.get(idx);
125                 if (selectedCtx != null) {
126                     selectedCtx.fireChannelRead(msg);
127                     return;
128                 }
129             }
130         }
131         // We were not be-able to dispatch to a specific ChannelHandlerContext, just forward and let the
132         // Quic*Codec handle it directly.
133         ctx.fireChannelRead(msg);
134     }
135 
136     @Override
137     public final void channelReadComplete(ChannelHandlerContext ctx) {
138         // Loop over all ChannelHandlerContextDispatchers and ensure fireChannelReadComplete() is called if required.
139         // We use and old style for loop as CopyOnWriteArrayList implements RandomAccess and so we can
140         // reduce the object creations.
141         boolean dispatchForOwnContextAlready = false;
142         for (int i = 0; i < contextList.size(); i++) {
143             ChannelHandlerContextDispatcher ctxDispatcher = contextList.get(i);
144             if (ctxDispatcher != null) {
145                 boolean fired = ctxDispatcher.fireChannelReadCompleteIfNeeded();
146                 if (fired && !dispatchForOwnContextAlready) {
147                     // Check if we dispatched to ctx so if we didnt at the end we can do it manually.
148                     dispatchForOwnContextAlready = ctx.equals(ctxDispatcher.ctx);
149                 }
150             }
151         }
152         if (!dispatchForOwnContextAlready) {
153             ctx.fireChannelReadComplete();
154         }
155     }
156 
157     /**
158      * Init the {@link Channel} and add all the needed {@link io.netty.channel.ChannelHandler} to the pipeline.
159      * This also included building the {@code QUIC} codec via {@link QuicCodecBuilder} sub-type using the given local
160      * connection id length and {@link QuicConnectionIdGenerator}.
161      *
162      * @param channel                   the {@link Channel} to init.
163      * @param localConnectionIdLength   the local connection id length that must be used with the
164      *                                  {@link QuicCodecBuilder}.
165      * @param idGenerator               the {@link QuicConnectionIdGenerator} that must be used with the
166      *                                  {@link QuicCodecBuilder}.
167      * @throws Exception                thrown on error.
168      */
169     protected abstract void initChannel(Channel channel, int localConnectionIdLength,
170                                         QuicConnectionIdGenerator idGenerator) throws Exception;
171 
172     /**
173      * Return the idx that was encoded into the connectionId via the {@link QuicConnectionIdGenerator} before,
174      * or {@code -1} if decoding was not successful.
175      * <p>
176      * Subclasses may override this. In this case {@link #newIdGenerator(int)} should be overridden as well
177      * to implement the encoding scheme for the encoding side.
178      *
179      *
180      * @param connectionId  the destination connection id of the {@code QUIC} connection.
181      * @return              the index or -1.
182      */
183     protected int decodeIndex(ByteBuf connectionId) {
184         return decodeIdx(connectionId);
185     }
186 
187     /**
188      * Return the destination connection id or {@code null} if decoding was not possible.
189      *
190      * @param buffer    the buffer
191      * @return          the id or {@code null}.
192      */
193     // Package-private for testing
194     @Nullable
195     static ByteBuf getDestinationConnectionId(ByteBuf buffer, int localConnectionIdLength) throws QuicException {
196         if (buffer.readableBytes() > Byte.BYTES) {
197             int offset = buffer.readerIndex();
198             boolean shortHeader = hasShortHeader(buffer);
199             offset += Byte.BYTES;
200             // We are only interested in packets with short header as these the packets that
201             // are exchanged after the server did provide the connection id that the client should use.
202             if (shortHeader) {
203                 // See https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3
204                 // 1-RTT Packet {
205                 //  Header Form (1) = 0,
206                 //  Fixed Bit (1) = 1,
207                 //  Spin Bit (1),
208                 //  Reserved Bits (2),
209                 //  Key Phase (1),
210                 //  Packet Number Length (2),
211                 //  Destination Connection ID (0..160),
212                 //  Packet Number (8..32),
213                 //  Packet Payload (8..),
214                 //}
215                 return QuicHeaderParser.sliceCid(buffer, offset, localConnectionIdLength);
216             }
217         }
218         return null;
219     }
220 
221     // Package-private for testing
222     static boolean hasShortHeader(ByteBuf buffer) {
223         return QuicHeaderParser.hasShortHeader(buffer.getByte(buffer.readerIndex()));
224     }
225 
226     // Package-private for testing
227     static int decodeIdx(ByteBuf connectionId) {
228         if (connectionId.readableBytes() >= 2) {
229             return connectionId.getUnsignedShort(connectionId.readerIndex());
230         }
231         return -1;
232     }
233 
234     // Package-private for testing
235     static ByteBuffer encodeIdx(ByteBuffer buffer, int idx) {
236         // Allocate a new buffer and prepend it with the index.
237         ByteBuffer b = ByteBuffer.allocate(buffer.capacity() + Short.BYTES);
238         // We encode it as unsigned short.
239         b.putShort((short) idx).put(buffer).flip();
240         return b;
241     }
242 
243     /**
244      * Returns a {@link QuicConnectionIdGenerator} that will encode the given index into all the
245      * ids that it produces.
246      * <p>
247      * Subclasses may override this. In this case {@link #decodeIndex(ByteBuf)} should be overridden as well
248      * to implement the encoding scheme for the decoding side.
249      *
250      * @param idx       the index to encode into each id.
251      * @return          the {@link QuicConnectionIdGenerator}.
252      */
253     protected QuicConnectionIdGenerator newIdGenerator(int idx) {
254         return new IndexAwareQuicConnectionIdGenerator(idx, SecureRandomQuicConnectionIdGenerator.INSTANCE);
255     }
256 
257     private static final class IndexAwareQuicConnectionIdGenerator implements QuicConnectionIdGenerator {
258         private final int idx;
259         private final QuicConnectionIdGenerator idGenerator;
260 
261         IndexAwareQuicConnectionIdGenerator(int idx, QuicConnectionIdGenerator idGenerator) {
262             this.idx = idx;
263             this.idGenerator = idGenerator;
264         }
265 
266         @Override
267         public ByteBuffer newId(int length) {
268             if (length > Short.BYTES) {
269                 return encodeIdx(idGenerator.newId(length - Short.BYTES), idx);
270             }
271             return idGenerator.newId(length);
272         }
273 
274         @Override
275         public ByteBuffer newId(ByteBuffer input, int length) {
276             if (length > Short.BYTES) {
277                 return encodeIdx(idGenerator.newId(input, length - Short.BYTES), idx);
278             }
279             return idGenerator.newId(input, length);
280         }
281 
282         @Override
283         public ByteBuffer newId(ByteBuffer scid, ByteBuffer dcid, int length) {
284             if (length > Short.BYTES) {
285                 return encodeIdx(idGenerator.newId(scid, dcid, length - Short.BYTES), idx);
286             }
287             return idGenerator.newId(scid, dcid, length);
288         }
289 
290         @Override
291         public int maxConnectionIdLength() {
292             return idGenerator.maxConnectionIdLength();
293         }
294 
295         @Override
296         public boolean isIdempotent() {
297             // Return false as the id might be different because of the idx that is encoded into it.
298             return false;
299         }
300     }
301 
302     private static final class ChannelHandlerContextDispatcher extends AtomicBoolean {
303 
304         private final ChannelHandlerContext ctx;
305 
306         ChannelHandlerContextDispatcher(ChannelHandlerContext ctx) {
307             this.ctx = ctx;
308         }
309 
310         void fireChannelRead(Object msg) {
311             ctx.fireChannelRead(msg);
312             set(true);
313         }
314 
315         boolean fireChannelReadCompleteIfNeeded() {
316             if (getAndSet(false)) {
317                 // There was a fireChannelRead() before, let's call fireChannelReadComplete()
318                 // so the user is aware that we might be done with the reading loop.
319                 ctx.fireChannelReadComplete();
320                 return true;
321             }
322             return false;
323         }
324     }
325 }