1 /*
2 * Copyright 2021 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
5 * "License"); you may not use this file except in compliance with the License. You may obtain a
6 * copy of the License at:
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software distributed under the License
11 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12 * or implied. See the License for the specific language governing permissions and limitations under
13 * the License.
14 */
15 package io.netty5.handler.codec;
16
17 import io.netty5.buffer.api.BufferAllocator;
18 import io.netty5.channel.ChannelFutureListeners;
19 import io.netty5.channel.ChannelHandler;
20 import io.netty5.channel.ChannelHandlerContext;
21 import io.netty5.channel.ChannelOption;
22 import io.netty5.channel.ChannelPipeline;
23 import io.netty5.util.concurrent.Future;
24 import io.netty5.util.concurrent.FutureContextListener;
25
26 import static io.netty5.util.internal.ObjectUtil.checkPositiveOrZero;
27
28 /**
29 * An abstract {@link ChannelHandler} that aggregates a series of message objects into a single aggregated message.
30 * <p>
31 * 'A series of messages' is composed of the following:
32 * <ul>
33 * <li>a single start message which optionally contains the first part of the content, and</li>
34 * <li>1 or more content messages.</li>
35 * </ul>
36 * The content of the aggregated message will be the merged content of the start message and its following content
37 * messages. If this aggregator encounters a content message where {@link #isLastContentMessage(AutoCloseable)}
38 * return {@code true} for, the aggregator will finish the aggregation and produce the aggregated message and expect
39 * another start message.
40 * </p>
41 *
42 * @param <I> the type that covers both start message and content message
43 * @param <S> the type of the start message
44 * @param <C> the type of the content message
45 * @param <A> the type of the aggregated message
46 */
47 public abstract class MessageAggregator<I, S, C extends AutoCloseable, A extends AutoCloseable>
48 extends MessageToMessageDecoder<I> {
49 private final int maxContentLength;
50 private A currentMessage;
51 private boolean handlingOversizedMessage;
52
53 private ChannelHandlerContext ctx;
54 private FutureContextListener<ChannelHandlerContext, Void> continueResponseWriteListener;
55
56 private boolean aggregating;
57
58 /**
59 * Creates a new instance.
60 *
61 * @param maxContentLength
62 * the maximum length of the aggregated content.
63 * If the length of the aggregated content exceeds this value,
64 * {@link #handleOversizedMessage(ChannelHandlerContext, Object)} will be called.
65 */
66 protected MessageAggregator(int maxContentLength) {
67 validateMaxContentLength(maxContentLength);
68 this.maxContentLength = maxContentLength;
69 }
70
71 protected MessageAggregator(int maxContentLength, Class<? extends I> inboundMessageType) {
72 super(inboundMessageType);
73 validateMaxContentLength(maxContentLength);
74 this.maxContentLength = maxContentLength;
75 }
76
77 private static void validateMaxContentLength(int maxContentLength) {
78 checkPositiveOrZero(maxContentLength, "maxContentLength");
79 }
80
81 @Override
82 public boolean acceptInboundMessage(Object msg) throws Exception {
83 // No need to match last and full types because they are subset of first and middle types.
84 if (!super.acceptInboundMessage(msg)) {
85 return false;
86 }
87
88 if (isAggregated(msg)) {
89 return false;
90 }
91
92 // NOTE: It's tempting to make this check only if aggregating is false. There are however
93 // side conditions in decode(...) in respect to large messages.
94 if (tryStartMessage(msg) != null) {
95 aggregating = true;
96 return true;
97 }
98 return aggregating && tryContentMessage(msg) != null;
99 }
100
101 /**
102 * If the passed {@code msg} is a {@linkplain S start message} then cast and return the same, else return
103 * {@code null}.
104 */
105 protected abstract S tryStartMessage(Object msg);
106
107 /**
108 * If the passed {@code msg} is a {@linkplain C content message} then cast and return the same, else return
109 * {@code null}.
110 */
111 protected abstract C tryContentMessage(Object msg);
112
113 /**
114 * Returns {@code true} if and only if the specified message is the last content message. Typically, this method is
115 * implemented as a single {@code return} statement with {@code instanceof}:
116 * <pre>
117 * return msg instanceof MyLastContentMessage;
118 * </pre>
119 * or with {@code instanceof} and boolean field check:
120 * <pre>
121 * return msg instanceof MyContentMessage && msg.isLastFragment();
122 * </pre>
123 */
124 protected abstract boolean isLastContentMessage(C msg) throws Exception;
125
126 /**
127 * Returns {@code true} if and only if the specified message is already aggregated. If this method returns
128 * {@code true}, this handler will simply forward the message to the next handler as-is.
129 */
130 protected abstract boolean isAggregated(Object msg) throws Exception;
131
132 /**
133 * Returns the length in bytes of the passed message.
134 *
135 * @param msg to calculate length.
136 * @return Length in bytes of the passed message.
137 */
138 protected abstract int lengthForContent(C msg);
139
140 /**
141 * Returns the length in bytes of the passed message.
142 *
143 * @param msg to calculate length.
144 * @return Length in bytes of the passed message.
145 */
146 protected abstract int lengthForAggregation(A msg);
147
148 /**
149 * Returns the maximum allowed length of the aggregated message in bytes.
150 */
151 public final int maxContentLength() {
152 return maxContentLength;
153 }
154
155 protected final ChannelHandlerContext ctx() {
156 if (ctx == null) {
157 throw new IllegalStateException("not added to a pipeline yet");
158 }
159 return ctx;
160 }
161
162 @Override
163 protected void decode(final ChannelHandlerContext ctx, I msg) throws Exception {
164 assert aggregating;
165 final S startMsg = tryStartMessage(msg);
166 if (startMsg != null) {
167 handlingOversizedMessage = false;
168 if (currentMessage != null) {
169 currentMessage.close();
170 currentMessage = null;
171 throw new MessageAggregationException();
172 }
173
174 // Send the continue response if necessary (e.g. 'Expect: 100-continue' header)
175 // Check before content length. Failing an expectation may result in a different response being sent.
176 Object continueResponse = newContinueResponse(startMsg, maxContentLength, ctx.pipeline());
177 if (continueResponse != null) {
178 // Cache the write listener for reuse.
179 FutureContextListener<ChannelHandlerContext, Void> listener = continueResponseWriteListener;
180 if (listener == null) {
181 continueResponseWriteListener = listener = (context, future) -> {
182 if (future.isFailed()) {
183 context.fireChannelExceptionCaught(future.cause());
184 }
185 };
186 }
187
188 // Make sure to call this before writing, otherwise reference counts may be invalid.
189 boolean closeAfterWrite = closeAfterContinueResponse(continueResponse);
190 handlingOversizedMessage = ignoreContentAfterContinueResponse(continueResponse);
191
192 Future<Void> future = ctx.writeAndFlush(continueResponse).addListener(ctx, listener);
193
194 if (closeAfterWrite) {
195 future.addListener(ctx, ChannelFutureListeners.CLOSE);
196 return;
197 }
198 if (handlingOversizedMessage) {
199 return;
200 }
201 } else if (isContentLengthInvalid(startMsg, maxContentLength)) {
202 // if content length is set, preemptively close if it's too large
203 invokeHandleOversizedMessage(ctx, startMsg);
204 return;
205 }
206
207 if (startMsg instanceof DecoderResultProvider &&
208 !((DecoderResultProvider) startMsg).decoderResult().isSuccess()) {
209 final A aggregated = beginAggregation(ctx.bufferAllocator(), startMsg);
210 finishAggregation(ctx.bufferAllocator(), aggregated);
211 ctx.fireChannelRead(aggregated);
212 return;
213 }
214
215 currentMessage = beginAggregation(ctx.bufferAllocator(), startMsg);
216 return;
217 }
218
219 final C contentMsg = tryContentMessage(msg);
220 if (contentMsg != null) {
221 if (currentMessage == null) {
222 // it is possible that a TooLongFrameException was already thrown but we can still discard data
223 // until the beginning of the next request/response.
224 return;
225 }
226
227 // Handle oversized message.
228 if (lengthForAggregation(currentMessage) > maxContentLength - lengthForContent(contentMsg)) {
229 invokeHandleOversizedMessage(ctx, currentMessage);
230 return;
231 }
232
233 aggregate(ctx.bufferAllocator(), currentMessage, contentMsg);
234
235 final boolean last;
236 if (contentMsg instanceof DecoderResultProvider) {
237 DecoderResult decoderResult = ((DecoderResultProvider) contentMsg).decoderResult();
238 if (!decoderResult.isSuccess()) {
239 if (currentMessage instanceof DecoderResultProvider) {
240 ((DecoderResultProvider) currentMessage).setDecoderResult(
241 DecoderResult.failure(decoderResult.cause()));
242 }
243 last = true;
244 } else {
245 last = isLastContentMessage(contentMsg);
246 }
247 } else {
248 last = isLastContentMessage(contentMsg);
249 }
250
251 if (last) {
252 finishAggregation0(ctx.bufferAllocator(), currentMessage);
253
254 // All done
255 A message = currentMessage;
256 currentMessage = null;
257 ctx.fireChannelRead(message);
258 }
259 } else {
260 throw new MessageAggregationException();
261 }
262 }
263
264 /**
265 * Determine if the message {@code start}'s content length is known, and if it greater than
266 * {@code maxContentLength}.
267 * @param start The message which may indicate the content length.
268 * @param maxContentLength The maximum allowed content length.
269 * @return {@code true} if the message {@code start}'s content length is known, and if it greater than
270 * {@code maxContentLength}. {@code false} otherwise.
271 */
272 protected abstract boolean isContentLengthInvalid(S start, int maxContentLength) throws Exception;
273
274 /**
275 * Returns the 'continue response' for the specified start message if necessary. For example, this method is
276 * useful to handle an HTTP 100-continue header.
277 *
278 * @return the 'continue response', or {@code null} if there's no message to send
279 */
280 protected abstract Object newContinueResponse(S start, int maxContentLength, ChannelPipeline pipeline)
281 throws Exception;
282
283 /**
284 * Determine if the channel should be closed after the result of
285 * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written.
286 * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}.
287 * @return {@code true} if the channel should be closed after the result of
288 * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written. {@code false} otherwise.
289 */
290 protected abstract boolean closeAfterContinueResponse(Object msg) throws Exception;
291
292 /**
293 * Determine if all objects for the current request/response should be ignored or not.
294 * Messages will stop being ignored the next time {@link #tryContentMessage(Object)} returns a {@code non null}
295 * value.
296 *
297 * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}.
298 * @return {@code true} if all objects for the current request/response should be ignored or not.
299 * {@code false} otherwise.
300 */
301 protected abstract boolean ignoreContentAfterContinueResponse(Object msg) throws Exception;
302
303 /**
304 * Creates a new aggregated message from the specified start message.
305 */
306 protected abstract A beginAggregation(BufferAllocator allocator, S start) throws Exception;
307
308 /**
309 * Aggregated the passed {@code content} in the passed {@code aggregate}.
310 */
311 protected abstract void aggregate(BufferAllocator allocator, A aggregated, C content) throws Exception;
312
313 private void finishAggregation0(BufferAllocator allocator, A aggregated) throws Exception {
314 aggregating = false;
315 finishAggregation(allocator, aggregated);
316 }
317
318 /**
319 * Invoked when the specified {@code aggregated} message is about to be passed to the next handler in the pipeline.
320 */
321 protected void finishAggregation(BufferAllocator allocator, A aggregated) throws Exception { }
322
323 private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, Object oversized) throws Exception {
324 handlingOversizedMessage = true;
325 currentMessage = null;
326 try {
327 handleOversizedMessage(ctx, oversized);
328 } finally {
329 if (oversized instanceof AutoCloseable) {
330 ((AutoCloseable) oversized).close();
331 }
332 }
333 }
334
335 /**
336 * Invoked when an incoming request exceeds the maximum content length. The default behavior is to trigger an
337 * {@code exceptionCaught()} event with a {@link TooLongFrameException}.
338 *
339 * @param ctx the {@link ChannelHandlerContext}
340 * @param oversized the accumulated message up to this point.
341 */
342 protected void handleOversizedMessage(ChannelHandlerContext ctx, @SuppressWarnings("unused") Object oversized)
343 throws Exception {
344 ctx.fireChannelExceptionCaught(
345 new TooLongFrameException("content length exceeded " + maxContentLength() + " bytes."));
346 }
347
348 @Override
349 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
350 // We might need keep reading the channel until the full message is aggregated.
351 //
352 // See https://github.com/netty/netty/issues/6583
353 if (currentMessage != null && !ctx.channel().getOption(ChannelOption.AUTO_READ)) {
354 ctx.read();
355 }
356 ctx.fireChannelReadComplete();
357 }
358
359 @Override
360 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
361 try {
362 // release current message if it is not null as it may be a left-over
363 super.channelInactive(ctx);
364 } finally {
365 releaseCurrentMessage();
366 }
367 }
368
369 @Override
370 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
371 this.ctx = ctx;
372 }
373
374 @Override
375 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
376 try {
377 super.handlerRemoved(ctx);
378 } finally {
379 // release current message if it is not null as it may be a left-over as there is not much more we can do in
380 // this case
381 releaseCurrentMessage();
382 }
383 }
384
385 private void releaseCurrentMessage() throws Exception {
386 if (currentMessage != null) {
387 currentMessage.close();
388 currentMessage = null;
389 handlingOversizedMessage = false;
390 aggregating = false;
391 }
392 }
393 }