1 /*
2 * Copyright 2024 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package io.netty.handler.codec.quic;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelInboundHandlerAdapter;
22 import io.netty.channel.socket.DatagramPacket;
23 import io.netty.util.internal.ObjectUtil;
24 import org.jetbrains.annotations.Nullable;
25
26 import java.nio.ByteBuffer;
27 import java.util.List;
28 import java.util.concurrent.CopyOnWriteArrayList;
29 import java.util.concurrent.atomic.AtomicBoolean;
30
31
32 /**
33 * Special {@link io.netty.channel.ChannelHandler} that should be used to init {@link Channel}s that will be used
34 * for QUIC while <a href="https://man7.org/linux/man-pages/man7/socket.7.html">SO_REUSEPORT</a> is used to
35 * bind to same {@link java.net.InetSocketAddress} multiple times. This is necessary to ensure QUIC packets are always
36 * dispatched to the correct codec that keeps the mapping for the connection id.
37 * This implementation use a very simple mapping strategy by encoding the index of the internal datastructure that
38 * keeps track of the different {@link ChannelHandlerContext}s into the destination connection id. This way once a
39 * {@code QUIC} packet is received its possible to forward it to the right codec.
40 * Subclasses might change how encoding / decoding of the index is done by overriding {@link #decodeIndex(ByteBuf)}
41 * and {@link #newIdGenerator(int)}.
42 * <p>
43 * It is important that the same {@link QuicCodecDispatcher} instance is shared between all the {@link Channel}s that
44 * are bound to the same {@link java.net.InetSocketAddress} and use {@code SO_REUSEPORT}.
45 * <p>
46 * An alternative way to handle this would be to do the "routing" to the correct socket in an {@code epbf} program
47 * by implementing your own {@link QuicConnectionIdGenerator} that issue ids that can be understood and handled by the
48 * {@code epbf} program to route the packet to the correct socket.
49 *
50 */
51 public abstract class QuicCodecDispatcher extends ChannelInboundHandlerAdapter {
52 // 20 is the max as per RFC.
53 // See https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
54 private static final int MAX_LOCAL_CONNECTION_ID_LENGTH = 20;
55
56 // Use a CopyOnWriteArrayList as modifications to the List should only happen during bootstrapping and teardown
57 // of the channels.
58 private final List<ChannelHandlerContextDispatcher> contextList = new CopyOnWriteArrayList<>();
59 private final int localConnectionIdLength;
60
61 /**
62 * Create a new instance using the default connection id length.
63 */
64 protected QuicCodecDispatcher() {
65 this(MAX_LOCAL_CONNECTION_ID_LENGTH);
66 }
67
68 /**
69 * Create a new instance
70 *
71 * @param localConnectionIdLength the local connection id length. This must be between 10 and 20.
72 */
73 protected QuicCodecDispatcher(int localConnectionIdLength) {
74 // Let's use 10 as a minimum to ensure we still have some bytes left for randomness as we already use
75 // 2 of the bytes to encode the index.
76 this.localConnectionIdLength = ObjectUtil.checkInRange(localConnectionIdLength,
77 10, MAX_LOCAL_CONNECTION_ID_LENGTH, "localConnectionIdLength");
78 }
79
80 @Override
81 public final boolean isSharable() {
82 return true;
83 }
84
85 @Override
86 public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
87 super.handlerAdded(ctx);
88
89 ChannelHandlerContextDispatcher ctxDispatcher = new ChannelHandlerContextDispatcher(ctx);
90 contextList.add(ctxDispatcher);
91 int idx = contextList.indexOf(ctxDispatcher);
92 try {
93 QuicConnectionIdGenerator idGenerator = newIdGenerator((short) idx);
94 initChannel(ctx.channel(), localConnectionIdLength, idGenerator);
95 } catch (Exception e) {
96 // Null out on exception and rethrow. We not remove the element as the indices need to be
97 // stable.
98 contextList.set(idx, null);
99 throw e;
100 }
101 }
102
103 @Override
104 public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
105 super.handlerRemoved(ctx);
106
107 for (int idx = 0; idx < contextList.size(); idx++) {
108 ChannelHandlerContextDispatcher ctxDispatcher = contextList.get(idx);
109 if (ctxDispatcher != null && ctxDispatcher.ctx.equals(ctx)) {
110 // null out, so we can collect the ChannelHandlerContext that was stored in the List.
111 contextList.set(idx, null);
112 break;
113 }
114 }
115 }
116
117 @Override
118 public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
119 DatagramPacket packet = (DatagramPacket) msg;
120 ByteBuf connectionId = getDestinationConnectionId(packet.content(), localConnectionIdLength);
121 if (connectionId != null) {
122 int idx = decodeIndex(connectionId);
123 if (contextList.size() > idx) {
124 ChannelHandlerContextDispatcher selectedCtx = contextList.get(idx);
125 if (selectedCtx != null) {
126 selectedCtx.fireChannelRead(msg);
127 return;
128 }
129 }
130 }
131 // We were not be-able to dispatch to a specific ChannelHandlerContext, just forward and let the
132 // Quic*Codec handle it directly.
133 ctx.fireChannelRead(msg);
134 }
135
136 @Override
137 public final void channelReadComplete(ChannelHandlerContext ctx) {
138 // Loop over all ChannelHandlerContextDispatchers and ensure fireChannelReadComplete() is called if required.
139 // We use and old style for loop as CopyOnWriteArrayList implements RandomAccess and so we can
140 // reduce the object creations.
141 boolean dispatchForOwnContextAlready = false;
142 for (int i = 0; i < contextList.size(); i++) {
143 ChannelHandlerContextDispatcher ctxDispatcher = contextList.get(i);
144 if (ctxDispatcher != null) {
145 boolean fired = ctxDispatcher.fireChannelReadCompleteIfNeeded();
146 if (fired && !dispatchForOwnContextAlready) {
147 // Check if we dispatched to ctx so if we didnt at the end we can do it manually.
148 dispatchForOwnContextAlready = ctx.equals(ctxDispatcher.ctx);
149 }
150 }
151 }
152 if (!dispatchForOwnContextAlready) {
153 ctx.fireChannelReadComplete();
154 }
155 }
156
157 /**
158 * Init the {@link Channel} and add all the needed {@link io.netty.channel.ChannelHandler} to the pipeline.
159 * This also included building the {@code QUIC} codec via {@link QuicCodecBuilder} sub-type using the given local
160 * connection id length and {@link QuicConnectionIdGenerator}.
161 *
162 * @param channel the {@link Channel} to init.
163 * @param localConnectionIdLength the local connection id length that must be used with the
164 * {@link QuicCodecBuilder}.
165 * @param idGenerator the {@link QuicConnectionIdGenerator} that must be used with the
166 * {@link QuicCodecBuilder}.
167 * @throws Exception thrown on error.
168 */
169 protected abstract void initChannel(Channel channel, int localConnectionIdLength,
170 QuicConnectionIdGenerator idGenerator) throws Exception;
171
172 /**
173 * Return the idx that was encoded into the connectionId via the {@link QuicConnectionIdGenerator} before,
174 * or {@code -1} if decoding was not successful.
175 * <p>
176 * Subclasses may override this. In this case {@link #newIdGenerator(int)} should be overridden as well
177 * to implement the encoding scheme for the encoding side.
178 *
179 *
180 * @param connectionId the destination connection id of the {@code QUIC} connection.
181 * @return the index or -1.
182 */
183 protected int decodeIndex(ByteBuf connectionId) {
184 return decodeIdx(connectionId);
185 }
186
187 /**
188 * Return the destination connection id or {@code null} if decoding was not possible.
189 *
190 * @param buffer the buffer
191 * @return the id or {@code null}.
192 */
193 // Package-private for testing
194 @Nullable
195 static ByteBuf getDestinationConnectionId(ByteBuf buffer, int localConnectionIdLength) throws QuicException {
196 if (buffer.readableBytes() > Byte.BYTES) {
197 int offset = buffer.readerIndex();
198 boolean shortHeader = hasShortHeader(buffer);
199 offset += Byte.BYTES;
200 // We are only interested in packets with short header as these the packets that
201 // are exchanged after the server did provide the connection id that the client should use.
202 if (shortHeader) {
203 // See https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3
204 // 1-RTT Packet {
205 // Header Form (1) = 0,
206 // Fixed Bit (1) = 1,
207 // Spin Bit (1),
208 // Reserved Bits (2),
209 // Key Phase (1),
210 // Packet Number Length (2),
211 // Destination Connection ID (0..160),
212 // Packet Number (8..32),
213 // Packet Payload (8..),
214 //}
215 return QuicHeaderParser.sliceCid(buffer, offset, localConnectionIdLength);
216 }
217 }
218 return null;
219 }
220
221 // Package-private for testing
222 static boolean hasShortHeader(ByteBuf buffer) {
223 return QuicHeaderParser.hasShortHeader(buffer.getByte(buffer.readerIndex()));
224 }
225
226 // Package-private for testing
227 static int decodeIdx(ByteBuf connectionId) {
228 if (connectionId.readableBytes() >= 2) {
229 return connectionId.getUnsignedShort(connectionId.readerIndex());
230 }
231 return -1;
232 }
233
234 // Package-private for testing
235 static ByteBuffer encodeIdx(ByteBuffer buffer, int idx) {
236 // Allocate a new buffer and prepend it with the index.
237 ByteBuffer b = ByteBuffer.allocate(buffer.capacity() + Short.BYTES);
238 // We encode it as unsigned short.
239 b.putShort((short) idx).put(buffer).flip();
240 return b;
241 }
242
243 /**
244 * Returns a {@link QuicConnectionIdGenerator} that will encode the given index into all the
245 * ids that it produces.
246 * <p>
247 * Subclasses may override this. In this case {@link #decodeIndex(ByteBuf)} should be overridden as well
248 * to implement the encoding scheme for the decoding side.
249 *
250 * @param idx the index to encode into each id.
251 * @return the {@link QuicConnectionIdGenerator}.
252 */
253 protected QuicConnectionIdGenerator newIdGenerator(int idx) {
254 return new IndexAwareQuicConnectionIdGenerator(idx, SecureRandomQuicConnectionIdGenerator.INSTANCE);
255 }
256
257 private static final class IndexAwareQuicConnectionIdGenerator implements QuicConnectionIdGenerator {
258 private final int idx;
259 private final QuicConnectionIdGenerator idGenerator;
260
261 IndexAwareQuicConnectionIdGenerator(int idx, QuicConnectionIdGenerator idGenerator) {
262 this.idx = idx;
263 this.idGenerator = idGenerator;
264 }
265
266 @Override
267 public ByteBuffer newId(int length) {
268 if (length > Short.BYTES) {
269 return encodeIdx(idGenerator.newId(length - Short.BYTES), idx);
270 }
271 return idGenerator.newId(length);
272 }
273
274 @Override
275 public ByteBuffer newId(ByteBuffer input, int length) {
276 if (length > Short.BYTES) {
277 return encodeIdx(idGenerator.newId(input, length - Short.BYTES), idx);
278 }
279 return idGenerator.newId(input, length);
280 }
281
282 @Override
283 public ByteBuffer newId(ByteBuffer scid, ByteBuffer dcid, int length) {
284 if (length > Short.BYTES) {
285 return encodeIdx(idGenerator.newId(scid, dcid, length - Short.BYTES), idx);
286 }
287 return idGenerator.newId(scid, dcid, length);
288 }
289
290 @Override
291 public int maxConnectionIdLength() {
292 return idGenerator.maxConnectionIdLength();
293 }
294
295 @Override
296 public boolean isIdempotent() {
297 // Return false as the id might be different because of the idx that is encoded into it.
298 return false;
299 }
300 }
301
302 private static final class ChannelHandlerContextDispatcher extends AtomicBoolean {
303
304 private final ChannelHandlerContext ctx;
305
306 ChannelHandlerContextDispatcher(ChannelHandlerContext ctx) {
307 this.ctx = ctx;
308 }
309
310 void fireChannelRead(Object msg) {
311 ctx.fireChannelRead(msg);
312 set(true);
313 }
314
315 boolean fireChannelReadCompleteIfNeeded() {
316 if (getAndSet(false)) {
317 // There was a fireChannelRead() before, let's call fireChannelReadComplete()
318 // so the user is aware that we might be done with the reading loop.
319 ctx.fireChannelReadComplete();
320 return true;
321 }
322 return false;
323 }
324 }
325 }