1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http3;
17
18 import io.netty.channel.ChannelHandler;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.channel.ChannelInboundHandler;
21 import io.netty.channel.ChannelInboundHandlerAdapter;
22 import io.netty.channel.socket.ChannelInputShutdownReadComplete;
23 import io.netty.handler.codec.quic.QuicChannel;
24 import io.netty.handler.codec.quic.QuicStreamChannel;
25 import io.netty.handler.codec.quic.QuicStreamChannelBootstrap;
26 import io.netty.handler.codec.quic.QuicStreamType;
27 import io.netty.util.ReferenceCountUtil;
28 import io.netty.util.concurrent.Future;
29 import io.netty.util.concurrent.Promise;
30 import org.jetbrains.annotations.Nullable;
31
32 import java.util.concurrent.ConcurrentMap;
33 import java.util.concurrent.atomic.AtomicLongFieldUpdater;
34 import java.util.function.UnaryOperator;
35
36 import static io.netty.handler.codec.http3.Http3.maxPushIdReceived;
37 import static io.netty.handler.codec.http3.Http3CodecUtils.connectionError;
38 import static io.netty.handler.codec.http3.Http3ErrorCode.H3_ID_ERROR;
39 import static io.netty.util.internal.PlatformDependent.newConcurrentHashMap;
40 import static java.util.Objects.requireNonNull;
41 import static java.util.concurrent.atomic.AtomicLongFieldUpdater.newUpdater;
42
43
44
45
46
47
48
49 public final class Http3ServerPushStreamManager {
50 private static final AtomicLongFieldUpdater<Http3ServerPushStreamManager> nextIdUpdater =
51 newUpdater(Http3ServerPushStreamManager.class, "nextId");
52 private static final Object CANCELLED_STREAM = new Object();
53 private static final Object PUSH_ID_GENERATED = new Object();
54 private static final Object AWAITING_STREAM_ESTABLISHMENT = new Object();
55
56 private final QuicChannel channel;
57 private final ConcurrentMap<Long, Object> pushStreams;
58 private final ChannelInboundHandler controlStreamListener;
59
60 private volatile long nextId;
61
62
63
64
65
66
67 public Http3ServerPushStreamManager(QuicChannel channel) {
68 this(channel, 8);
69 }
70
71
72
73
74
75
76
77 public Http3ServerPushStreamManager(QuicChannel channel, int initialPushStreamsCountHint) {
78 this.channel = requireNonNull(channel, "channel");
79 pushStreams = newConcurrentHashMap(initialPushStreamsCountHint);
80 controlStreamListener = new ChannelInboundHandlerAdapter() {
81 @Override
82 public void channelRead(ChannelHandlerContext ctx, Object msg) {
83 if (msg instanceof Http3CancelPushFrame) {
84 final long pushId = ((Http3CancelPushFrame) msg).id();
85 if (pushId >= nextId) {
86 connectionError(ctx, H3_ID_ERROR, "CANCEL_PUSH id greater than the last known id", true);
87 return;
88 }
89
90 pushStreams.computeIfPresent(pushId, (id, existing) -> {
91 if (existing == AWAITING_STREAM_ESTABLISHMENT) {
92 return CANCELLED_STREAM;
93 }
94 if (existing == PUSH_ID_GENERATED) {
95 throw new IllegalStateException("Unexpected push stream state " + existing +
96 " for pushId: " + id);
97 }
98 assert existing instanceof QuicStreamChannel;
99 ((QuicStreamChannel) existing).close();
100
101 return null;
102 });
103 }
104 ReferenceCountUtil.release(msg);
105 }
106 };
107 }
108
109
110
111
112
113
114 public boolean isPushAllowed() {
115 return isPushAllowed(maxPushIdReceived(channel));
116 }
117
118
119
120
121
122
123
124
125
126 public long reserveNextPushId() {
127 final long maxPushId = maxPushIdReceived(channel);
128 if (isPushAllowed(maxPushId)) {
129 return nextPushId();
130 }
131 throw new IllegalStateException("MAX allowed push ID: " + maxPushId + ", next push ID: " + nextId);
132 }
133
134
135
136
137
138
139
140
141
142
143 public Future<QuicStreamChannel> newPushStream(long pushId, @Nullable ChannelHandler handler) {
144 final Promise<QuicStreamChannel> promise = channel.eventLoop().newPromise();
145 newPushStream(pushId, handler, promise);
146 return promise;
147 }
148
149
150
151
152
153
154
155
156
157
158 public void newPushStream(long pushId, @Nullable ChannelHandler handler, Promise<QuicStreamChannel> promise) {
159 validatePushId(pushId);
160 channel.createStream(QuicStreamType.UNIDIRECTIONAL, pushStreamInitializer(pushId, handler), promise);
161 setupCancelPushIfStreamCreationFails(pushId, promise, channel);
162 }
163
164
165
166
167
168
169
170
171
172
173
174 public void newPushStream(long pushId, @Nullable ChannelHandler handler,
175 UnaryOperator<QuicStreamChannelBootstrap> bootstrapConfigurator,
176 Promise<QuicStreamChannel> promise) {
177 validatePushId(pushId);
178 QuicStreamChannelBootstrap bootstrap = bootstrapConfigurator.apply(channel.newStreamBootstrap());
179 bootstrap.type(QuicStreamType.UNIDIRECTIONAL)
180 .handler(pushStreamInitializer(pushId, handler))
181 .create(promise);
182 setupCancelPushIfStreamCreationFails(pushId, promise, channel);
183 }
184
185
186
187
188
189
190
191
192 public ChannelInboundHandler controlStreamListener() {
193 return controlStreamListener;
194 }
195
196 private boolean isPushAllowed(long maxPushId) {
197 return nextId <= maxPushId;
198 }
199
200 private long nextPushId() {
201 final long pushId = nextIdUpdater.getAndIncrement(this);
202 pushStreams.put(pushId, PUSH_ID_GENERATED);
203 return pushId;
204 }
205
206 private void validatePushId(long pushId) {
207 if (!pushStreams.replace(pushId, PUSH_ID_GENERATED, AWAITING_STREAM_ESTABLISHMENT)) {
208 throw new IllegalArgumentException("Unknown push ID: " + pushId);
209 }
210 }
211
212 private Http3PushStreamServerInitializer pushStreamInitializer(long pushId, @Nullable ChannelHandler handler) {
213 final Http3PushStreamServerInitializer initializer;
214 if (handler instanceof Http3PushStreamServerInitializer) {
215 initializer = (Http3PushStreamServerInitializer) handler;
216 } else {
217 initializer = null;
218 }
219 return new Http3PushStreamServerInitializer(pushId) {
220 @Override
221 protected void initPushStream(QuicStreamChannel ch) {
222 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
223 private boolean stateUpdated;
224
225 @Override
226 public void channelActive(ChannelHandlerContext ctx) {
227 if (!stateUpdated) {
228 updatePushStreamsMap();
229 }
230 }
231
232 @Override
233 public void handlerAdded(ChannelHandlerContext ctx) {
234 if (!stateUpdated && ctx.channel().isActive()) {
235 updatePushStreamsMap();
236 }
237 }
238
239 private void updatePushStreamsMap() {
240 assert !stateUpdated;
241 stateUpdated = true;
242 pushStreams.compute(pushId, (id, existing) -> {
243 if (existing == AWAITING_STREAM_ESTABLISHMENT) {
244 return ch;
245 }
246 if (existing == CANCELLED_STREAM) {
247 ch.close();
248 return null;
249 }
250 throw new IllegalStateException("Unexpected push stream state " +
251 existing + " for pushId: " + id);
252 });
253 }
254
255 @Override
256 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
257 if (evt == ChannelInputShutdownReadComplete.INSTANCE) {
258 pushStreams.remove(pushId);
259 }
260 ctx.fireUserEventTriggered(evt);
261 }
262 });
263 if (initializer != null) {
264 initializer.initPushStream(ch);
265 } else if (handler != null) {
266 ch.pipeline().addLast(handler);
267 }
268 }
269 };
270 }
271
272 private static void setupCancelPushIfStreamCreationFails(long pushId, Future<QuicStreamChannel> future,
273 QuicChannel channel) {
274 if (future.isDone()) {
275 sendCancelPushIfFailed(future, pushId, channel);
276 } else {
277 future.addListener(f -> sendCancelPushIfFailed(future, pushId, channel));
278 }
279 }
280
281 private static void sendCancelPushIfFailed(Future<QuicStreamChannel> future, long pushId, QuicChannel channel) {
282
283
284 if (!future.isSuccess()) {
285 final QuicStreamChannel localControlStream = Http3.getLocalControlStream(channel);
286 assert localControlStream != null;
287 localControlStream.writeAndFlush(new DefaultHttp3CancelPushFrame(pushId));
288 }
289 }
290 }