View Javadoc
1   /*
2    * Copyright 2021 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
5    * "License"); you may not use this file except in compliance with the License. You may obtain a
6    * copy of the License at:
7    *
8    * https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software distributed under the License
11   * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12   * or implied. See the License for the specific language governing permissions and limitations under
13   * the License.
14   */
15  package io.netty5.handler.codec;
16  
17  import io.netty5.buffer.api.BufferAllocator;
18  import io.netty5.channel.ChannelFutureListeners;
19  import io.netty5.channel.ChannelHandler;
20  import io.netty5.channel.ChannelHandlerContext;
21  import io.netty5.channel.ChannelOption;
22  import io.netty5.channel.ChannelPipeline;
23  import io.netty5.util.concurrent.Future;
24  import io.netty5.util.concurrent.FutureContextListener;
25  
26  import static io.netty5.util.internal.ObjectUtil.checkPositiveOrZero;
27  
28  /**
29   * An abstract {@link ChannelHandler} that aggregates a series of message objects into a single aggregated message.
30   * <p>
31   * 'A series of messages' is composed of the following:
32   * <ul>
33   * <li>a single start message which optionally contains the first part of the content, and</li>
34   * <li>1 or more content messages.</li>
35   * </ul>
36   * The content of the aggregated message will be the merged content of the start message and its following content
37   * messages. If this aggregator encounters a content message where {@link #isLastContentMessage(AutoCloseable)}
38   * return {@code true} for, the aggregator will finish the aggregation and produce the aggregated message and expect
39   * another start message.
40   * </p>
41   *
42   * @param <I> the type that covers both start message and content message
43   * @param <S> the type of the start message
44   * @param <C> the type of the content message
45   * @param <A> the type of the aggregated message
46   */
47  public abstract class MessageAggregator<I, S, C extends AutoCloseable, A extends AutoCloseable>
48          extends MessageToMessageDecoder<I> {
49      private final int maxContentLength;
50      private A currentMessage;
51      private boolean handlingOversizedMessage;
52  
53      private ChannelHandlerContext ctx;
54      private FutureContextListener<ChannelHandlerContext, Void> continueResponseWriteListener;
55  
56      private boolean aggregating;
57  
58      /**
59       * Creates a new instance.
60       *
61       * @param maxContentLength
62       *        the maximum length of the aggregated content.
63       *        If the length of the aggregated content exceeds this value,
64       *        {@link #handleOversizedMessage(ChannelHandlerContext, Object)} will be called.
65       */
66      protected MessageAggregator(int maxContentLength) {
67          validateMaxContentLength(maxContentLength);
68          this.maxContentLength = maxContentLength;
69      }
70  
71      protected MessageAggregator(int maxContentLength, Class<? extends I> inboundMessageType) {
72          super(inboundMessageType);
73          validateMaxContentLength(maxContentLength);
74          this.maxContentLength = maxContentLength;
75      }
76  
77      private static void validateMaxContentLength(int maxContentLength) {
78          checkPositiveOrZero(maxContentLength, "maxContentLength");
79      }
80  
81      @Override
82      public boolean acceptInboundMessage(Object msg) throws Exception {
83          // No need to match last and full types because they are subset of first and middle types.
84          if (!super.acceptInboundMessage(msg)) {
85              return false;
86          }
87  
88          if (isAggregated(msg)) {
89              return false;
90          }
91  
92          // NOTE: It's tempting to make this check only if aggregating is false. There are however
93          // side conditions in decode(...) in respect to large messages.
94          if (tryStartMessage(msg) != null) {
95              aggregating = true;
96              return true;
97          }
98          return aggregating && tryContentMessage(msg) != null;
99      }
100 
101     /**
102      * If the passed {@code msg} is a {@linkplain S start message} then cast and return the same, else return
103      * {@code null}.
104      */
105     protected abstract S tryStartMessage(Object msg);
106 
107     /**
108      * If the passed {@code msg} is a {@linkplain C content message} then cast and return the same, else return
109      * {@code null}.
110      */
111     protected abstract C tryContentMessage(Object msg);
112 
113     /**
114      * Returns {@code true} if and only if the specified message is the last content message. Typically, this method is
115      * implemented as a single {@code return} statement with {@code instanceof}:
116      * <pre>
117      * return msg instanceof MyLastContentMessage;
118      * </pre>
119      * or with {@code instanceof} and boolean field check:
120      * <pre>
121      * return msg instanceof MyContentMessage && msg.isLastFragment();
122      * </pre>
123      */
124     protected abstract boolean isLastContentMessage(C msg) throws Exception;
125 
126     /**
127      * Returns {@code true} if and only if the specified message is already aggregated.  If this method returns
128      * {@code true}, this handler will simply forward the message to the next handler as-is.
129      */
130     protected abstract boolean isAggregated(Object msg) throws Exception;
131 
132     /**
133      * Returns the length in bytes of the passed message.
134      *
135      * @param msg to calculate length.
136      * @return Length in bytes of the passed message.
137      */
138     protected abstract int lengthForContent(C msg);
139 
140     /**
141      * Returns the length in bytes of the passed message.
142      *
143      * @param msg to calculate length.
144      * @return Length in bytes of the passed message.
145      */
146     protected abstract int lengthForAggregation(A msg);
147 
148     /**
149      * Returns the maximum allowed length of the aggregated message in bytes.
150      */
151     public final int maxContentLength() {
152         return maxContentLength;
153     }
154 
155     protected final ChannelHandlerContext ctx() {
156         if (ctx == null) {
157             throw new IllegalStateException("not added to a pipeline yet");
158         }
159         return ctx;
160     }
161 
162     @Override
163     protected void decode(final ChannelHandlerContext ctx, I msg) throws Exception {
164         assert aggregating;
165         final S startMsg = tryStartMessage(msg);
166         if (startMsg != null) {
167             handlingOversizedMessage = false;
168             if (currentMessage != null) {
169                 currentMessage.close();
170                 currentMessage = null;
171                 throw new MessageAggregationException();
172             }
173 
174             // Send the continue response if necessary (e.g. 'Expect: 100-continue' header)
175             // Check before content length. Failing an expectation may result in a different response being sent.
176             Object continueResponse = newContinueResponse(startMsg, maxContentLength, ctx.pipeline());
177             if (continueResponse != null) {
178                 // Cache the write listener for reuse.
179                 FutureContextListener<ChannelHandlerContext, Void> listener = continueResponseWriteListener;
180                 if (listener == null) {
181                     continueResponseWriteListener = listener = (context, future) -> {
182                         if (future.isFailed()) {
183                             context.fireChannelExceptionCaught(future.cause());
184                         }
185                     };
186                 }
187 
188                 // Make sure to call this before writing, otherwise reference counts may be invalid.
189                 boolean closeAfterWrite = closeAfterContinueResponse(continueResponse);
190                 handlingOversizedMessage = ignoreContentAfterContinueResponse(continueResponse);
191 
192                 Future<Void> future = ctx.writeAndFlush(continueResponse).addListener(ctx, listener);
193 
194                 if (closeAfterWrite) {
195                     future.addListener(ctx, ChannelFutureListeners.CLOSE);
196                     return;
197                 }
198                 if (handlingOversizedMessage) {
199                     return;
200                 }
201             } else if (isContentLengthInvalid(startMsg, maxContentLength)) {
202                 // if content length is set, preemptively close if it's too large
203                 invokeHandleOversizedMessage(ctx, startMsg);
204                 return;
205             }
206 
207             if (startMsg instanceof DecoderResultProvider &&
208                     !((DecoderResultProvider) startMsg).decoderResult().isSuccess()) {
209                 final A aggregated = beginAggregation(ctx.bufferAllocator(), startMsg);
210                 finishAggregation(ctx.bufferAllocator(), aggregated);
211                 ctx.fireChannelRead(aggregated);
212                 return;
213             }
214 
215             currentMessage = beginAggregation(ctx.bufferAllocator(), startMsg);
216             return;
217         }
218 
219         final C contentMsg = tryContentMessage(msg);
220         if (contentMsg != null) {
221             if (currentMessage == null) {
222                 // it is possible that a TooLongFrameException was already thrown but we can still discard data
223                 // until the beginning of the next request/response.
224                 return;
225             }
226 
227             // Handle oversized message.
228             if (lengthForAggregation(currentMessage) > maxContentLength - lengthForContent(contentMsg)) {
229                 invokeHandleOversizedMessage(ctx, currentMessage);
230                 return;
231             }
232 
233             aggregate(ctx.bufferAllocator(), currentMessage, contentMsg);
234 
235             final boolean last;
236             if (contentMsg instanceof DecoderResultProvider) {
237                 DecoderResult decoderResult = ((DecoderResultProvider) contentMsg).decoderResult();
238                 if (!decoderResult.isSuccess()) {
239                     if (currentMessage instanceof DecoderResultProvider) {
240                         ((DecoderResultProvider) currentMessage).setDecoderResult(
241                                 DecoderResult.failure(decoderResult.cause()));
242                     }
243                     last = true;
244                 } else {
245                     last = isLastContentMessage(contentMsg);
246                 }
247             } else {
248                 last = isLastContentMessage(contentMsg);
249             }
250 
251             if (last) {
252                 finishAggregation0(ctx.bufferAllocator(), currentMessage);
253 
254                 // All done
255                 A message = currentMessage;
256                 currentMessage = null;
257                 ctx.fireChannelRead(message);
258             }
259         } else {
260             throw new MessageAggregationException();
261         }
262     }
263 
264     /**
265      * Determine if the message {@code start}'s content length is known, and if it greater than
266      * {@code maxContentLength}.
267      * @param start The message which may indicate the content length.
268      * @param maxContentLength The maximum allowed content length.
269      * @return {@code true} if the message {@code start}'s content length is known, and if it greater than
270      * {@code maxContentLength}. {@code false} otherwise.
271      */
272     protected abstract boolean isContentLengthInvalid(S start, int maxContentLength) throws Exception;
273 
274     /**
275      * Returns the 'continue response' for the specified start message if necessary. For example, this method is
276      * useful to handle an HTTP 100-continue header.
277      *
278      * @return the 'continue response', or {@code null} if there's no message to send
279      */
280     protected abstract Object newContinueResponse(S start, int maxContentLength, ChannelPipeline pipeline)
281             throws Exception;
282 
283     /**
284      * Determine if the channel should be closed after the result of
285      * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written.
286      * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}.
287      * @return {@code true} if the channel should be closed after the result of
288      * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written. {@code false} otherwise.
289      */
290     protected abstract boolean closeAfterContinueResponse(Object msg) throws Exception;
291 
292     /**
293      * Determine if all objects for the current request/response should be ignored or not.
294      * Messages will stop being ignored the next time {@link #tryContentMessage(Object)} returns a {@code non null}
295      * value.
296      *
297      * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}.
298      * @return {@code true} if all objects for the current request/response should be ignored or not.
299      * {@code false} otherwise.
300      */
301     protected abstract boolean ignoreContentAfterContinueResponse(Object msg) throws Exception;
302 
303     /**
304      * Creates a new aggregated message from the specified start message.
305      */
306     protected abstract A beginAggregation(BufferAllocator allocator, S start) throws Exception;
307 
308     /**
309      * Aggregated the passed {@code content} in the passed {@code aggregate}.
310      */
311     protected abstract void aggregate(BufferAllocator allocator, A aggregated, C content) throws Exception;
312 
313     private void finishAggregation0(BufferAllocator allocator, A aggregated) throws Exception {
314         aggregating = false;
315         finishAggregation(allocator, aggregated);
316     }
317 
318     /**
319      * Invoked when the specified {@code aggregated} message is about to be passed to the next handler in the pipeline.
320      */
321     protected void finishAggregation(BufferAllocator allocator, A aggregated) throws Exception { }
322 
323     private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, Object oversized) throws Exception {
324         handlingOversizedMessage = true;
325         currentMessage = null;
326         try {
327             handleOversizedMessage(ctx, oversized);
328         } finally {
329             if (oversized instanceof AutoCloseable) {
330                 ((AutoCloseable) oversized).close();
331             }
332         }
333     }
334 
335     /**
336      * Invoked when an incoming request exceeds the maximum content length.  The default behavior is to trigger an
337      * {@code exceptionCaught()} event with a {@link TooLongFrameException}.
338      *
339      * @param ctx the {@link ChannelHandlerContext}
340      * @param oversized the accumulated message up to this point.
341      */
342     protected void handleOversizedMessage(ChannelHandlerContext ctx, @SuppressWarnings("unused") Object oversized)
343             throws Exception {
344         ctx.fireChannelExceptionCaught(
345                 new TooLongFrameException("content length exceeded " + maxContentLength() + " bytes."));
346     }
347 
348     @Override
349     public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
350         // We might need keep reading the channel until the full message is aggregated.
351         //
352         // See https://github.com/netty/netty/issues/6583
353         if (currentMessage != null && !ctx.channel().getOption(ChannelOption.AUTO_READ)) {
354             ctx.read();
355         }
356         ctx.fireChannelReadComplete();
357     }
358 
359     @Override
360     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
361         try {
362             // release current message if it is not null as it may be a left-over
363             super.channelInactive(ctx);
364         } finally {
365             releaseCurrentMessage();
366         }
367     }
368 
369     @Override
370     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
371         this.ctx = ctx;
372     }
373 
374     @Override
375     public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
376         try {
377             super.handlerRemoved(ctx);
378         } finally {
379             // release current message if it is not null as it may be a left-over as there is not much more we can do in
380             // this case
381             releaseCurrentMessage();
382         }
383     }
384 
385     private void releaseCurrentMessage() throws Exception {
386         if (currentMessage != null) {
387             currentMessage.close();
388             currentMessage = null;
389             handlingOversizedMessage = false;
390             aggregating = false;
391         }
392     }
393 }