View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufHolder;
20  import io.netty.buffer.CompositeByteBuf;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelHandler;
24  import io.netty.channel.ChannelHandlerContext;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.util.ReferenceCountUtil;
27  
28  import java.util.List;
29  
30  import static io.netty.buffer.Unpooled.EMPTY_BUFFER;
31  import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
32  
33  /**
34   * An abstract {@link ChannelHandler} that aggregates a series of message objects into a single aggregated message.
35   * <p>
36   * 'A series of messages' is composed of the following:
37   * <ul>
38   * <li>a single start message which optionally contains the first part of the content, and</li>
39   * <li>1 or more content messages.</li>
40   * </ul>
41   * The content of the aggregated message will be the merged content of the start message and its following content
42   * messages. If this aggregator encounters a content message where {@link #isLastContentMessage(ByteBufHolder)}
43   * return {@code true} for, the aggregator will finish the aggregation and produce the aggregated message and expect
44   * another start message.
45   * </p>
46   *
47   * @param <I> the type that covers both start message and content message
48   * @param <S> the type of the start message
49   * @param <C> the type of the content message (must be a subtype of {@link ByteBufHolder})
50   * @param <O> the type of the aggregated message (must be a subtype of {@code S} and {@link ByteBufHolder})
51   */
52  public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends ByteBufHolder>
53          extends MessageToMessageDecoder<I> {
54  
55      private static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024;
56  
57      private final int maxContentLength;
58      private O currentMessage;
59      private boolean handlingOversizedMessage;
60  
61      private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS;
62      private ChannelHandlerContext ctx;
63      private ChannelFutureListener continueResponseWriteListener;
64  
65      private boolean aggregating;
66      private boolean handleIncompleteAggregateDuringClose = true;
67  
68      /**
69       * Creates a new instance.
70       *
71       * @param maxContentLength
72       *        the maximum length of the aggregated content.
73       *        If the length of the aggregated content exceeds this value,
74       *        {@link #handleOversizedMessage(ChannelHandlerContext, Object)} will be called.
75       */
76      protected MessageAggregator(int maxContentLength) {
77          validateMaxContentLength(maxContentLength);
78          this.maxContentLength = maxContentLength;
79      }
80  
81      protected MessageAggregator(int maxContentLength, Class<? extends I> inboundMessageType) {
82          super(inboundMessageType);
83          validateMaxContentLength(maxContentLength);
84          this.maxContentLength = maxContentLength;
85      }
86  
87      private static void validateMaxContentLength(int maxContentLength) {
88          checkPositiveOrZero(maxContentLength, "maxContentLength");
89      }
90  
91      @Override
92      public boolean acceptInboundMessage(Object msg) throws Exception {
93          // No need to match last and full types because they are subset of first and middle types.
94          if (!super.acceptInboundMessage(msg)) {
95              return false;
96          }
97  
98          @SuppressWarnings("unchecked")
99          I in = (I) msg;
100 
101         if (isAggregated(in)) {
102             return false;
103         }
104 
105         // NOTE: It's tempting to make this check only if aggregating is false. There are however
106         // side conditions in decode(...) in respect to large messages.
107         if (isStartMessage(in)) {
108             return true;
109         } else {
110             return aggregating && isContentMessage(in);
111         }
112     }
113 
114     /**
115      * Returns {@code true} if and only if the specified message is a start message. Typically, this method is
116      * implemented as a single {@code return} statement with {@code instanceof}:
117      * <pre>
118      * return msg instanceof MyStartMessage;
119      * </pre>
120      */
121     protected abstract boolean isStartMessage(I msg) throws Exception;
122 
123     /**
124      * Returns {@code true} if and only if the specified message is a content message. Typically, this method is
125      * implemented as a single {@code return} statement with {@code instanceof}:
126      * <pre>
127      * return msg instanceof MyContentMessage;
128      * </pre>
129      */
130     protected abstract boolean isContentMessage(I msg) throws Exception;
131 
132     /**
133      * Returns {@code true} if and only if the specified message is the last content message. Typically, this method is
134      * implemented as a single {@code return} statement with {@code instanceof}:
135      * <pre>
136      * return msg instanceof MyLastContentMessage;
137      * </pre>
138      * or with {@code instanceof} and boolean field check:
139      * <pre>
140      * return msg instanceof MyContentMessage && msg.isLastFragment();
141      * </pre>
142      */
143     protected abstract boolean isLastContentMessage(C msg) throws Exception;
144 
145     /**
146      * Returns {@code true} if and only if the specified message is already aggregated.  If this method returns
147      * {@code true}, this handler will simply forward the message to the next handler as-is.
148      */
149     protected abstract boolean isAggregated(I msg) throws Exception;
150 
151     /**
152      * Returns the maximum allowed length of the aggregated message in bytes.
153      */
154     public final int maxContentLength() {
155         return maxContentLength;
156     }
157 
158     /**
159      * Returns the maximum number of components in the cumulation buffer.  If the number of
160      * the components in the cumulation buffer exceeds this value, the components of the
161      * cumulation buffer are consolidated into a single component, involving memory copies.
162      * The default value of this property is {@value #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS}.
163      */
164     public final int maxCumulationBufferComponents() {
165         return maxCumulationBufferComponents;
166     }
167 
168     /**
169      * Sets the maximum number of components in the cumulation buffer.  If the number of
170      * the components in the cumulation buffer exceeds this value, the components of the
171      * cumulation buffer are consolidated into a single component, involving memory copies.
172      * The default value of this property is {@value #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS}
173      * and its minimum allowed value is {@code 2}.
174      */
175     public final void setMaxCumulationBufferComponents(int maxCumulationBufferComponents) {
176         if (maxCumulationBufferComponents < 2) {
177             throw new IllegalArgumentException(
178                     "maxCumulationBufferComponents: " + maxCumulationBufferComponents +
179                     " (expected: >= 2)");
180         }
181 
182         if (ctx == null) {
183             this.maxCumulationBufferComponents = maxCumulationBufferComponents;
184         } else {
185             throw new IllegalStateException(
186                     "decoder properties cannot be changed once the decoder is added to a pipeline.");
187         }
188     }
189 
190     /**
191      * @deprecated This method will be removed in future releases.
192      */
193     @Deprecated
194     public final boolean isHandlingOversizedMessage() {
195         return handlingOversizedMessage;
196     }
197 
198     protected final ChannelHandlerContext ctx() {
199         if (ctx == null) {
200             throw new IllegalStateException("not added to a pipeline yet");
201         }
202         return ctx;
203     }
204 
205     @Override
206     protected void decode(final ChannelHandlerContext ctx, I msg, List<Object> out) throws Exception {
207         if (isStartMessage(msg)) {
208             aggregating = true;
209             handlingOversizedMessage = false;
210             if (currentMessage != null) {
211                 currentMessage.release();
212                 currentMessage = null;
213                 throw new MessageAggregationException();
214             }
215 
216             @SuppressWarnings("unchecked")
217             S m = (S) msg;
218 
219             // Send the continue response if necessary (e.g. 'Expect: 100-continue' header)
220             // Check before content length. Failing an expectation may result in a different response being sent.
221             Object continueResponse = newContinueResponse(m, maxContentLength, ctx.pipeline());
222             if (continueResponse != null) {
223                 // Cache the write listener for reuse.
224                 ChannelFutureListener listener = continueResponseWriteListener;
225                 if (listener == null) {
226                     continueResponseWriteListener = listener = new ChannelFutureListener() {
227                         @Override
228                         public void operationComplete(ChannelFuture future) throws Exception {
229                             if (!future.isSuccess()) {
230                                 ctx.fireExceptionCaught(future.cause());
231                             }
232                         }
233                     };
234                 }
235 
236                 // Make sure to call this before writing, otherwise reference counts may be invalid.
237                 boolean closeAfterWrite = closeAfterContinueResponse(continueResponse);
238                 handlingOversizedMessage = ignoreContentAfterContinueResponse(continueResponse);
239 
240                 final ChannelFuture future = ctx.writeAndFlush(continueResponse).addListener(listener);
241 
242                 if (closeAfterWrite) {
243                     handleIncompleteAggregateDuringClose = false;
244                     future.addListener(ChannelFutureListener.CLOSE);
245                     return;
246                 }
247                 if (handlingOversizedMessage) {
248                     return;
249                 }
250             } else if (isContentLengthInvalid(m, maxContentLength)) {
251                 // if content length is set, preemptively close if it's too large
252                 invokeHandleOversizedMessage(ctx, m);
253                 return;
254             }
255 
256             if (m instanceof DecoderResultProvider && !((DecoderResultProvider) m).decoderResult().isSuccess()) {
257                 O aggregated;
258                 if (m instanceof ByteBufHolder) {
259                     aggregated = beginAggregation(m, ((ByteBufHolder) m).content().retain());
260                 } else {
261                     aggregated = beginAggregation(m, EMPTY_BUFFER);
262                 }
263                 finishAggregation0(aggregated);
264                 out.add(aggregated);
265                 return;
266             }
267 
268             // A streamed message - initialize the cumulative buffer, and wait for incoming chunks.
269             CompositeByteBuf content = ctx.alloc().compositeBuffer(maxCumulationBufferComponents);
270             if (m instanceof ByteBufHolder) {
271                 appendPartialContent(content, ((ByteBufHolder) m).content());
272             }
273             currentMessage = beginAggregation(m, content);
274         } else if (isContentMessage(msg)) {
275             if (currentMessage == null) {
276                 // it is possible that a TooLongFrameException was already thrown but we can still discard data
277                 // until the begging of the next request/response.
278                 return;
279             }
280 
281             // Merge the received chunk into the content of the current message.
282             CompositeByteBuf content = (CompositeByteBuf) currentMessage.content();
283 
284             @SuppressWarnings("unchecked")
285             final C m = (C) msg;
286             // Handle oversized message.
287             if (content.readableBytes() > maxContentLength - m.content().readableBytes()) {
288                 // By convention, full message type extends first message type.
289                 @SuppressWarnings("unchecked")
290                 S s = (S) currentMessage;
291                 invokeHandleOversizedMessage(ctx, s);
292                 return;
293             }
294 
295             // Append the content of the chunk.
296             appendPartialContent(content, m.content());
297 
298             // Give the subtypes a chance to merge additional information such as trailing headers.
299             aggregate(currentMessage, m);
300 
301             final boolean last;
302             if (m instanceof DecoderResultProvider) {
303                 DecoderResult decoderResult = ((DecoderResultProvider) m).decoderResult();
304                 if (!decoderResult.isSuccess()) {
305                     if (currentMessage instanceof DecoderResultProvider) {
306                         ((DecoderResultProvider) currentMessage).setDecoderResult(
307                                 DecoderResult.failure(decoderResult.cause()));
308                     }
309                     last = true;
310                 } else {
311                     last = isLastContentMessage(m);
312                 }
313             } else {
314                 last = isLastContentMessage(m);
315             }
316 
317             if (last) {
318                 finishAggregation0(currentMessage);
319 
320                 // All done
321                 out.add(currentMessage);
322                 currentMessage = null;
323             }
324         } else {
325             throw new MessageAggregationException();
326         }
327     }
328 
329     private static void appendPartialContent(CompositeByteBuf content, ByteBuf partialContent) {
330         if (partialContent.isReadable()) {
331             content.addComponent(true, partialContent.retain());
332         }
333     }
334 
335     /**
336      * Determine if the message {@code start}'s content length is known, and if it greater than
337      * {@code maxContentLength}.
338      * @param start The message which may indicate the content length.
339      * @param maxContentLength The maximum allowed content length.
340      * @return {@code true} if the message {@code start}'s content length is known, and if it greater than
341      * {@code maxContentLength}. {@code false} otherwise.
342      */
343     protected abstract boolean isContentLengthInvalid(S start, int maxContentLength) throws Exception;
344 
345     /**
346      * Returns the 'continue response' for the specified start message if necessary. For example, this method is
347      * useful to handle an HTTP 100-continue header.
348      *
349      * @return the 'continue response', or {@code null} if there's no message to send
350      */
351     protected abstract Object newContinueResponse(S start, int maxContentLength, ChannelPipeline pipeline)
352             throws Exception;
353 
354     /**
355      * Determine if the channel should be closed after the result of
356      * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written.
357      * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}.
358      * @return {@code true} if the channel should be closed after the result of
359      * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written. {@code false} otherwise.
360      */
361     protected abstract boolean closeAfterContinueResponse(Object msg) throws Exception;
362 
363     /**
364      * Determine if all objects for the current request/response should be ignored or not.
365      * Messages will stop being ignored the next time {@link #isContentMessage(Object)} returns {@code true}.
366      *
367      * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}.
368      * @return {@code true} if all objects for the current request/response should be ignored or not.
369      * {@code false} otherwise.
370      */
371     protected abstract boolean ignoreContentAfterContinueResponse(Object msg) throws Exception;
372 
373     /**
374      * Creates a new aggregated message from the specified start message and the specified content.  If the start
375      * message implements {@link ByteBufHolder}, its content is appended to the specified {@code content}.
376      * This aggregator will continue to append the received content to the specified {@code content}.
377      */
378     protected abstract O beginAggregation(S start, ByteBuf content) throws Exception;
379 
380     /**
381      * Transfers the information provided by the specified content message to the specified aggregated message.
382      * Note that the content of the specified content message has been appended to the content of the specified
383      * aggregated message already, so that you don't need to.  Use this method to transfer the additional information
384      * that the content message provides to {@code aggregated}.
385      */
386     protected void aggregate(O aggregated, C content) throws Exception { }
387 
388     private void finishAggregation0(O aggregated) throws Exception {
389         aggregating = false;
390         finishAggregation(aggregated);
391     }
392 
393     /**
394      * Invoked when the specified {@code aggregated} message is about to be passed to the next handler in the pipeline.
395      */
396     protected void finishAggregation(O aggregated) throws Exception { }
397 
398     private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception {
399         handlingOversizedMessage = true;
400         currentMessage = null;
401         handleIncompleteAggregateDuringClose = false;
402         try {
403             handleOversizedMessage(ctx, oversized);
404         } finally {
405             // Release the message in case it is a full one.
406             ReferenceCountUtil.release(oversized);
407         }
408     }
409 
410     /**
411      * Invoked when an incoming request exceeds the maximum content length.  The default behvaior is to trigger an
412      * {@code exceptionCaught()} event with a {@link TooLongFrameException}.
413      *
414      * @param ctx the {@link ChannelHandlerContext}
415      * @param oversized the accumulated message up to this point, whose type is {@code S} or {@code O}
416      */
417     protected void handleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception {
418         ctx.fireExceptionCaught(
419                 new TooLongFrameException("content length exceeded " + maxContentLength() + " bytes."));
420     }
421 
422     @Override
423     public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
424         // We might need keep reading the channel until the full message is aggregated.
425         //
426         // See https://github.com/netty/netty/issues/6583
427         if (currentMessage != null && !ctx.channel().config().isAutoRead()) {
428             ctx.read();
429         }
430         ctx.fireChannelReadComplete();
431     }
432 
433     @Override
434     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
435         if (aggregating && handleIncompleteAggregateDuringClose) {
436             ctx.fireExceptionCaught(
437                     new PrematureChannelClosureException("Channel closed while still aggregating message"));
438         }
439         try {
440             // release current message if it is not null as it may be a left-over
441             super.channelInactive(ctx);
442         } finally {
443             releaseCurrentMessage();
444         }
445     }
446 
447     @Override
448     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
449         this.ctx = ctx;
450     }
451 
452     @Override
453     public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
454         try {
455             super.handlerRemoved(ctx);
456         } finally {
457             // release current message if it is not null as it may be a left-over as there is not much more we can do in
458             // this case
459             releaseCurrentMessage();
460         }
461     }
462 
463     private void releaseCurrentMessage() {
464         if (currentMessage != null) {
465             currentMessage.release();
466             currentMessage = null;
467             handlingOversizedMessage = false;
468             aggregating = false;
469         }
470     }
471 }