1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package io.netty5.handler.codec;
16
17 import io.netty5.buffer.api.BufferAllocator;
18 import io.netty5.channel.ChannelFutureListeners;
19 import io.netty5.channel.ChannelHandler;
20 import io.netty5.channel.ChannelHandlerContext;
21 import io.netty5.channel.ChannelOption;
22 import io.netty5.channel.ChannelPipeline;
23 import io.netty5.util.concurrent.Future;
24 import io.netty5.util.concurrent.FutureContextListener;
25
26 import static io.netty5.util.internal.ObjectUtil.checkPositiveOrZero;
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47 public abstract class MessageAggregator<I, S, C extends AutoCloseable, A extends AutoCloseable>
48 extends MessageToMessageDecoder<I> {
49 private final int maxContentLength;
50 private A currentMessage;
51 private boolean handlingOversizedMessage;
52
53 private ChannelHandlerContext ctx;
54 private FutureContextListener<ChannelHandlerContext, Void> continueResponseWriteListener;
55
56 private boolean aggregating;
57
58
59
60
61
62
63
64
65
66 protected MessageAggregator(int maxContentLength) {
67 validateMaxContentLength(maxContentLength);
68 this.maxContentLength = maxContentLength;
69 }
70
71 protected MessageAggregator(int maxContentLength, Class<? extends I> inboundMessageType) {
72 super(inboundMessageType);
73 validateMaxContentLength(maxContentLength);
74 this.maxContentLength = maxContentLength;
75 }
76
77 private static void validateMaxContentLength(int maxContentLength) {
78 checkPositiveOrZero(maxContentLength, "maxContentLength");
79 }
80
81 @Override
82 public boolean acceptInboundMessage(Object msg) throws Exception {
83
84 if (!super.acceptInboundMessage(msg)) {
85 return false;
86 }
87
88 if (isAggregated(msg)) {
89 return false;
90 }
91
92
93
94 if (tryStartMessage(msg) != null) {
95 aggregating = true;
96 return true;
97 }
98 return aggregating && tryContentMessage(msg) != null;
99 }
100
101
102
103
104
105 protected abstract S tryStartMessage(Object msg);
106
107
108
109
110
111 protected abstract C tryContentMessage(Object msg);
112
113
114
115
116
117
118
119
120
121
122
123
124 protected abstract boolean isLastContentMessage(C msg) throws Exception;
125
126
127
128
129
130 protected abstract boolean isAggregated(Object msg) throws Exception;
131
132
133
134
135
136
137
138 protected abstract int lengthForContent(C msg);
139
140
141
142
143
144
145
146 protected abstract int lengthForAggregation(A msg);
147
148
149
150
151 public final int maxContentLength() {
152 return maxContentLength;
153 }
154
155 protected final ChannelHandlerContext ctx() {
156 if (ctx == null) {
157 throw new IllegalStateException("not added to a pipeline yet");
158 }
159 return ctx;
160 }
161
162 @Override
163 protected void decode(final ChannelHandlerContext ctx, I msg) throws Exception {
164 assert aggregating;
165 final S startMsg = tryStartMessage(msg);
166 if (startMsg != null) {
167 handlingOversizedMessage = false;
168 if (currentMessage != null) {
169 currentMessage.close();
170 currentMessage = null;
171 throw new MessageAggregationException();
172 }
173
174
175
176 Object continueResponse = newContinueResponse(startMsg, maxContentLength, ctx.pipeline());
177 if (continueResponse != null) {
178
179 FutureContextListener<ChannelHandlerContext, Void> listener = continueResponseWriteListener;
180 if (listener == null) {
181 continueResponseWriteListener = listener = (context, future) -> {
182 if (future.isFailed()) {
183 context.fireChannelExceptionCaught(future.cause());
184 }
185 };
186 }
187
188
189 boolean closeAfterWrite = closeAfterContinueResponse(continueResponse);
190 handlingOversizedMessage = ignoreContentAfterContinueResponse(continueResponse);
191
192 Future<Void> future = ctx.writeAndFlush(continueResponse).addListener(ctx, listener);
193
194 if (closeAfterWrite) {
195 future.addListener(ctx, ChannelFutureListeners.CLOSE);
196 return;
197 }
198 if (handlingOversizedMessage) {
199 return;
200 }
201 } else if (isContentLengthInvalid(startMsg, maxContentLength)) {
202
203 invokeHandleOversizedMessage(ctx, startMsg);
204 return;
205 }
206
207 if (startMsg instanceof DecoderResultProvider &&
208 !((DecoderResultProvider) startMsg).decoderResult().isSuccess()) {
209 final A aggregated = beginAggregation(ctx.bufferAllocator(), startMsg);
210 finishAggregation(ctx.bufferAllocator(), aggregated);
211 ctx.fireChannelRead(aggregated);
212 return;
213 }
214
215 currentMessage = beginAggregation(ctx.bufferAllocator(), startMsg);
216 return;
217 }
218
219 final C contentMsg = tryContentMessage(msg);
220 if (contentMsg != null) {
221 if (currentMessage == null) {
222
223
224 return;
225 }
226
227
228 if (lengthForAggregation(currentMessage) > maxContentLength - lengthForContent(contentMsg)) {
229 invokeHandleOversizedMessage(ctx, currentMessage);
230 return;
231 }
232
233 aggregate(ctx.bufferAllocator(), currentMessage, contentMsg);
234
235 final boolean last;
236 if (contentMsg instanceof DecoderResultProvider) {
237 DecoderResult decoderResult = ((DecoderResultProvider) contentMsg).decoderResult();
238 if (!decoderResult.isSuccess()) {
239 if (currentMessage instanceof DecoderResultProvider) {
240 ((DecoderResultProvider) currentMessage).setDecoderResult(
241 DecoderResult.failure(decoderResult.cause()));
242 }
243 last = true;
244 } else {
245 last = isLastContentMessage(contentMsg);
246 }
247 } else {
248 last = isLastContentMessage(contentMsg);
249 }
250
251 if (last) {
252 finishAggregation0(ctx.bufferAllocator(), currentMessage);
253
254
255 A message = currentMessage;
256 currentMessage = null;
257 ctx.fireChannelRead(message);
258 }
259 } else {
260 throw new MessageAggregationException();
261 }
262 }
263
264
265
266
267
268
269
270
271
272 protected abstract boolean isContentLengthInvalid(S start, int maxContentLength) throws Exception;
273
274
275
276
277
278
279
280 protected abstract Object newContinueResponse(S start, int maxContentLength, ChannelPipeline pipeline)
281 throws Exception;
282
283
284
285
286
287
288
289
290 protected abstract boolean closeAfterContinueResponse(Object msg) throws Exception;
291
292
293
294
295
296
297
298
299
300
301 protected abstract boolean ignoreContentAfterContinueResponse(Object msg) throws Exception;
302
303
304
305
306 protected abstract A beginAggregation(BufferAllocator allocator, S start) throws Exception;
307
308
309
310
311 protected abstract void aggregate(BufferAllocator allocator, A aggregated, C content) throws Exception;
312
313 private void finishAggregation0(BufferAllocator allocator, A aggregated) throws Exception {
314 aggregating = false;
315 finishAggregation(allocator, aggregated);
316 }
317
318
319
320
321 protected void finishAggregation(BufferAllocator allocator, A aggregated) throws Exception { }
322
323 private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, Object oversized) throws Exception {
324 handlingOversizedMessage = true;
325 currentMessage = null;
326 try {
327 handleOversizedMessage(ctx, oversized);
328 } finally {
329 if (oversized instanceof AutoCloseable) {
330 ((AutoCloseable) oversized).close();
331 }
332 }
333 }
334
335
336
337
338
339
340
341
342 protected void handleOversizedMessage(ChannelHandlerContext ctx, @SuppressWarnings("unused") Object oversized)
343 throws Exception {
344 ctx.fireChannelExceptionCaught(
345 new TooLongFrameException("content length exceeded " + maxContentLength() + " bytes."));
346 }
347
348 @Override
349 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
350
351
352
353 if (currentMessage != null && !ctx.channel().getOption(ChannelOption.AUTO_READ)) {
354 ctx.read();
355 }
356 ctx.fireChannelReadComplete();
357 }
358
359 @Override
360 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
361 try {
362
363 super.channelInactive(ctx);
364 } finally {
365 releaseCurrentMessage();
366 }
367 }
368
369 @Override
370 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
371 this.ctx = ctx;
372 }
373
374 @Override
375 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
376 try {
377 super.handlerRemoved(ctx);
378 } finally {
379
380
381 releaseCurrentMessage();
382 }
383 }
384
385 private void releaseCurrentMessage() throws Exception {
386 if (currentMessage != null) {
387 currentMessage.close();
388 currentMessage = null;
389 handlingOversizedMessage = false;
390 aggregating = false;
391 }
392 }
393 }