1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http3;
17
18 import io.netty.channel.ChannelHandler;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.channel.ChannelPromise;
21 import io.netty.channel.socket.ChannelInputShutdownReadComplete;
22
23 import java.util.function.BooleanSupplier;
24
25 import static io.netty.handler.codec.http.HttpMethod.HEAD;
26 import static io.netty.handler.codec.http3.Http3FrameValidationUtils.frameTypeUnexpected;
27 import static io.netty.handler.codec.http3.Http3RequestStreamValidationUtils.INVALID_FRAME_READ;
28 import static io.netty.handler.codec.http3.Http3RequestStreamValidationUtils.sendStreamAbandonedIfRequired;
29 import static io.netty.handler.codec.http3.Http3RequestStreamValidationUtils.validateClientWrite;
30 import static io.netty.handler.codec.http3.Http3RequestStreamValidationUtils.validateDataFrameRead;
31 import static io.netty.handler.codec.http3.Http3RequestStreamValidationUtils.validateHeaderFrameRead;
32 import static io.netty.handler.codec.http3.Http3RequestStreamValidationUtils.validateOnStreamClosure;
33
34 final class Http3RequestStreamValidationHandler extends Http3FrameTypeDuplexValidationHandler<Http3RequestStreamFrame> {
35 private final boolean server;
36 private final BooleanSupplier goAwayReceivedSupplier;
37 private final QpackAttributes qpackAttributes;
38 private final QpackDecoder qpackDecoder;
39 private final Http3RequestStreamCodecState decodeState;
40 private final Http3RequestStreamCodecState encodeState;
41
42 private boolean clientHeadRequest;
43 private long expectedLength = -1;
44 private long seenLength;
45
46 static ChannelHandler newServerValidator(QpackAttributes qpackAttributes, QpackDecoder decoder,
47 Http3RequestStreamCodecState encodeState,
48 Http3RequestStreamCodecState decodeState) {
49 return new Http3RequestStreamValidationHandler(true, () -> false, qpackAttributes, decoder,
50 encodeState, decodeState);
51 }
52
53 static ChannelHandler newClientValidator(BooleanSupplier goAwayReceivedSupplier, QpackAttributes qpackAttributes,
54 QpackDecoder decoder, Http3RequestStreamCodecState encodeState,
55 Http3RequestStreamCodecState decodeState) {
56 return new Http3RequestStreamValidationHandler(false, goAwayReceivedSupplier, qpackAttributes, decoder,
57 encodeState, decodeState);
58 }
59
60 private Http3RequestStreamValidationHandler(boolean server, BooleanSupplier goAwayReceivedSupplier,
61 QpackAttributes qpackAttributes, QpackDecoder qpackDecoder,
62 Http3RequestStreamCodecState encodeState,
63 Http3RequestStreamCodecState decodeState) {
64 super(Http3RequestStreamFrame.class);
65 this.server = server;
66 this.goAwayReceivedSupplier = goAwayReceivedSupplier;
67 this.qpackAttributes = qpackAttributes;
68 this.qpackDecoder = qpackDecoder;
69 this.decodeState = decodeState;
70 this.encodeState = encodeState;
71 }
72
73 @Override
74 void write(ChannelHandlerContext ctx, Http3RequestStreamFrame frame, ChannelPromise promise) {
75 if (!server) {
76 if (!validateClientWrite(frame, promise, ctx, goAwayReceivedSupplier, encodeState)) {
77 return;
78 }
79 if (frame instanceof Http3HeadersFrame) {
80 clientHeadRequest = HEAD.asciiName().equals(((Http3HeadersFrame) frame).headers().method());
81 }
82 }
83 ctx.write(frame, promise);
84 }
85
86 @Override
87 void channelRead(ChannelHandlerContext ctx, Http3RequestStreamFrame frame) {
88 if (frame instanceof Http3PushPromiseFrame) {
89 if (server) {
90
91
92 frameTypeUnexpected(ctx, frame);
93 } else {
94 ctx.fireChannelRead(frame);
95 }
96 return;
97 }
98
99 if (frame instanceof Http3HeadersFrame) {
100 Http3HeadersFrame headersFrame = (Http3HeadersFrame) frame;
101 long maybeContentLength = validateHeaderFrameRead(headersFrame, ctx, decodeState);
102 if (maybeContentLength >= 0) {
103 expectedLength = maybeContentLength;
104 } else if (maybeContentLength == INVALID_FRAME_READ) {
105 return;
106 }
107 }
108
109 if (frame instanceof Http3DataFrame) {
110 final Http3DataFrame dataFrame = (Http3DataFrame) frame;
111 long maybeContentLength = validateDataFrameRead(dataFrame, ctx, expectedLength, seenLength,
112 clientHeadRequest);
113 if (maybeContentLength >= 0) {
114 seenLength = maybeContentLength;
115 } else if (maybeContentLength == INVALID_FRAME_READ) {
116 return;
117 }
118 }
119
120 ctx.fireChannelRead(frame);
121 }
122
123 @Override
124 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
125 if (evt == ChannelInputShutdownReadComplete.INSTANCE) {
126 sendStreamAbandonedIfRequired(ctx, qpackAttributes, qpackDecoder, decodeState);
127 if (!validateOnStreamClosure(ctx, expectedLength, seenLength, clientHeadRequest)) {
128 return;
129 }
130 }
131 ctx.fireUserEventTriggered(evt);
132 }
133
134 @Override
135 public boolean isSharable() {
136
137 return false;
138 }
139 }