1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufHolder;
20 import io.netty.buffer.CompositeByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.ChannelFuture;
23 import io.netty.channel.ChannelFutureListener;
24 import io.netty.channel.ChannelHandler;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.util.ReferenceCountUtil;
27
28 import java.util.List;
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49 public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends ByteBufHolder>
50 extends MessageToMessageDecoder<I> {
51
52 private static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024;
53
54 private final int maxContentLength;
55 private O currentMessage;
56 private boolean handlingOversizedMessage;
57
58 private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS;
59 private ChannelHandlerContext ctx;
60 private ChannelFutureListener continueResponseWriteListener;
61
62
63
64
65
66
67
68
69
70 protected MessageAggregator(int maxContentLength) {
71 validateMaxContentLength(maxContentLength);
72 this.maxContentLength = maxContentLength;
73 }
74
75 protected MessageAggregator(int maxContentLength, Class<? extends I> inboundMessageType) {
76 super(inboundMessageType);
77 validateMaxContentLength(maxContentLength);
78 this.maxContentLength = maxContentLength;
79 }
80
81 private static void validateMaxContentLength(int maxContentLength) {
82 if (maxContentLength <= 0) {
83 throw new IllegalArgumentException("maxContentLength must be a positive integer: " + maxContentLength);
84 }
85 }
86
87 @Override
88 public boolean acceptInboundMessage(Object msg) throws Exception {
89
90 if (!super.acceptInboundMessage(msg)) {
91 return false;
92 }
93
94 @SuppressWarnings("unchecked")
95 I in = (I) msg;
96
97 return (isContentMessage(in) || isStartMessage(in)) && !isAggregated(in);
98 }
99
100
101
102
103
104
105
106
107 protected abstract boolean isStartMessage(I msg) throws Exception;
108
109
110
111
112
113
114
115
116 protected abstract boolean isContentMessage(I msg) throws Exception;
117
118
119
120
121
122
123
124
125
126
127
128
129 protected abstract boolean isLastContentMessage(C msg) throws Exception;
130
131
132
133
134
135 protected abstract boolean isAggregated(I msg) throws Exception;
136
137
138
139
140 public final int maxContentLength() {
141 return maxContentLength;
142 }
143
144
145
146
147
148
149
150 public final int maxCumulationBufferComponents() {
151 return maxCumulationBufferComponents;
152 }
153
154
155
156
157
158
159
160
161 public final void setMaxCumulationBufferComponents(int maxCumulationBufferComponents) {
162 if (maxCumulationBufferComponents < 2) {
163 throw new IllegalArgumentException(
164 "maxCumulationBufferComponents: " + maxCumulationBufferComponents +
165 " (expected: >= 2)");
166 }
167
168 if (ctx == null) {
169 this.maxCumulationBufferComponents = maxCumulationBufferComponents;
170 } else {
171 throw new IllegalStateException(
172 "decoder properties cannot be changed once the decoder is added to a pipeline.");
173 }
174 }
175
176 public final boolean isHandlingOversizedMessage() {
177 return handlingOversizedMessage;
178 }
179
180 protected final ChannelHandlerContext ctx() {
181 if (ctx == null) {
182 throw new IllegalStateException("not added to a pipeline yet");
183 }
184 return ctx;
185 }
186
187 @Override
188 protected void decode(final ChannelHandlerContext ctx, I msg, List<Object> out) throws Exception {
189 O currentMessage = this.currentMessage;
190
191 if (isStartMessage(msg)) {
192 handlingOversizedMessage = false;
193 if (currentMessage != null) {
194 throw new MessageAggregationException();
195 }
196
197 @SuppressWarnings("unchecked")
198 S m = (S) msg;
199
200
201 if (hasContentLength(m)) {
202 if (contentLength(m) > maxContentLength) {
203
204 invokeHandleOversizedMessage(ctx, m);
205 return;
206 }
207 }
208
209
210 Object continueResponse = newContinueResponse(m);
211 if (continueResponse != null) {
212
213 ChannelFutureListener listener = continueResponseWriteListener;
214 if (listener == null) {
215 continueResponseWriteListener = listener = new ChannelFutureListener() {
216 @Override
217 public void operationComplete(ChannelFuture future) throws Exception {
218 if (!future.isSuccess()) {
219 ctx.fireExceptionCaught(future.cause());
220 }
221 }
222 };
223 }
224 ctx.writeAndFlush(continueResponse).addListener(listener);
225 }
226
227 if (m instanceof DecoderResultProvider && !((DecoderResultProvider) m).decoderResult().isSuccess()) {
228 O aggregated;
229 if (m instanceof ByteBufHolder && ((ByteBufHolder) m).content().isReadable()) {
230 aggregated = beginAggregation(m, ((ByteBufHolder) m).content().retain());
231 } else {
232 aggregated = beginAggregation(m, Unpooled.EMPTY_BUFFER);
233 }
234 finishAggregation(aggregated);
235 out.add(aggregated);
236 this.currentMessage = null;
237 return;
238 }
239
240
241 CompositeByteBuf content = ctx.alloc().compositeBuffer(maxCumulationBufferComponents);
242 if (m instanceof ByteBufHolder) {
243 appendPartialContent(content, ((ByteBufHolder) m).content());
244 }
245 this.currentMessage = beginAggregation(m, content);
246
247 } else if (isContentMessage(msg)) {
248 @SuppressWarnings("unchecked")
249 final C m = (C) msg;
250 final ByteBuf partialContent = ((ByteBufHolder) msg).content();
251 final boolean isLastContentMessage = isLastContentMessage(m);
252 if (handlingOversizedMessage) {
253 if (isLastContentMessage) {
254 this.currentMessage = null;
255 }
256
257 return;
258 }
259
260 if (currentMessage == null) {
261 throw new MessageAggregationException();
262 }
263
264
265 CompositeByteBuf content = (CompositeByteBuf) currentMessage.content();
266
267
268 if (content.readableBytes() > maxContentLength - partialContent.readableBytes()) {
269
270 @SuppressWarnings("unchecked")
271 S s = (S) currentMessage;
272 invokeHandleOversizedMessage(ctx, s);
273 return;
274 }
275
276
277 appendPartialContent(content, partialContent);
278
279
280 aggregate(currentMessage, m);
281
282 final boolean last;
283 if (m instanceof DecoderResultProvider) {
284 DecoderResult decoderResult = ((DecoderResultProvider) m).decoderResult();
285 if (!decoderResult.isSuccess()) {
286 if (currentMessage instanceof DecoderResultProvider) {
287 ((DecoderResultProvider) currentMessage).setDecoderResult(
288 DecoderResult.failure(decoderResult.cause()));
289 }
290 last = true;
291 } else {
292 last = isLastContentMessage;
293 }
294 } else {
295 last = isLastContentMessage;
296 }
297
298 if (last) {
299 finishAggregation(currentMessage);
300
301
302 out.add(currentMessage);
303 this.currentMessage = null;
304 }
305 } else {
306 throw new MessageAggregationException();
307 }
308 }
309
310 private static void appendPartialContent(CompositeByteBuf content, ByteBuf partialContent) {
311 if (partialContent.isReadable()) {
312 partialContent.retain();
313 content.addComponent(partialContent);
314 content.writerIndex(content.writerIndex() + partialContent.readableBytes());
315 }
316 }
317
318
319
320
321
322 protected abstract boolean hasContentLength(S start) throws Exception;
323
324
325
326
327
328 protected abstract long contentLength(S start) throws Exception;
329
330
331
332
333
334
335
336 protected abstract Object newContinueResponse(S start) throws Exception;
337
338
339
340
341
342
343 protected abstract O beginAggregation(S start, ByteBuf content) throws Exception;
344
345
346
347
348
349
350
351 protected void aggregate(O aggregated, C content) throws Exception { }
352
353
354
355
356 protected void finishAggregation(O aggregated) throws Exception { }
357
358 private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception {
359 handlingOversizedMessage = true;
360 currentMessage = null;
361 try {
362 handleOversizedMessage(ctx, oversized);
363 } finally {
364
365 ReferenceCountUtil.release(oversized);
366 }
367 }
368
369
370
371
372
373
374
375
376 protected void handleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception {
377 ctx.fireExceptionCaught(
378 new TooLongFrameException("content length exceeded " + maxContentLength() + " bytes."));
379 }
380
381 @Override
382 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
383
384 if (currentMessage != null) {
385 currentMessage.release();
386 currentMessage = null;
387 }
388
389 super.channelInactive(ctx);
390 }
391
392 @Override
393 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
394 this.ctx = ctx;
395 }
396
397 @Override
398 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
399 super.handlerRemoved(ctx);
400
401
402
403 if (currentMessage != null) {
404 currentMessage.release();
405 currentMessage = null;
406 }
407 }
408 }