1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.execution;
17
18 import java.io.IOException;
19 import java.lang.reflect.Method;
20 import java.util.HashSet;
21 import java.util.List;
22 import java.util.Set;
23 import java.util.concurrent.ConcurrentMap;
24 import java.util.concurrent.Executor;
25 import java.util.concurrent.Executors;
26 import java.util.concurrent.LinkedBlockingQueue;
27 import java.util.concurrent.RejectedExecutionException;
28 import java.util.concurrent.RejectedExecutionHandler;
29 import java.util.concurrent.ThreadFactory;
30 import java.util.concurrent.ThreadPoolExecutor;
31 import java.util.concurrent.TimeUnit;
32 import java.util.concurrent.atomic.AtomicLong;
33
34 import org.jboss.netty.buffer.ChannelBuffer;
35 import org.jboss.netty.channel.Channel;
36 import org.jboss.netty.channel.ChannelEvent;
37 import org.jboss.netty.channel.ChannelFuture;
38 import org.jboss.netty.channel.ChannelHandlerContext;
39 import org.jboss.netty.channel.ChannelState;
40 import org.jboss.netty.channel.ChannelStateEvent;
41 import org.jboss.netty.channel.Channels;
42 import org.jboss.netty.channel.MessageEvent;
43 import org.jboss.netty.channel.WriteCompletionEvent;
44 import org.jboss.netty.logging.InternalLogger;
45 import org.jboss.netty.logging.InternalLoggerFactory;
46 import org.jboss.netty.util.DefaultObjectSizeEstimator;
47 import org.jboss.netty.util.ObjectSizeEstimator;
48 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
49 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
141 private static final InternalLogger logger =
142 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143
144 private static final SharedResourceMisuseDetector misuseDetector =
145 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146
147 private volatile Settings settings;
148
149 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151 private final Limiter totalLimiter;
152
153 private volatile boolean notifyOnShutdown;
154
155
156
157
158
159
160
161
162
163
164 public MemoryAwareThreadPoolExecutor(
165 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166
167 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168 }
169
170
171
172
173
174
175
176
177
178
179
180
181 public MemoryAwareThreadPoolExecutor(
182 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183 long keepAliveTime, TimeUnit unit) {
184
185 this(
186 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187 Executors.defaultThreadFactory());
188 }
189
190
191
192
193
194
195
196
197
198
199
200
201
202 public MemoryAwareThreadPoolExecutor(
203 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205
206 this(
207 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208 new DefaultObjectSizeEstimator(), threadFactory);
209 }
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224 public MemoryAwareThreadPoolExecutor(
225 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227 ThreadFactory threadFactory) {
228
229 super(corePoolSize, corePoolSize, keepAliveTime, unit,
230 new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231
232 if (objectSizeEstimator == null) {
233 throw new NullPointerException("objectSizeEstimator");
234 }
235 if (maxChannelMemorySize < 0) {
236 throw new IllegalArgumentException(
237 "maxChannelMemorySize: " + maxChannelMemorySize);
238 }
239 if (maxTotalMemorySize < 0) {
240 throw new IllegalArgumentException(
241 "maxTotalMemorySize: " + maxTotalMemorySize);
242 }
243
244
245
246 try {
247 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248 m.invoke(this, Boolean.TRUE);
249 } catch (Throwable t) {
250
251 logger.debug(
252 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253 "supported in this platform.");
254 }
255
256 settings = new Settings(
257 objectSizeEstimator, maxChannelMemorySize);
258
259 if (maxTotalMemorySize == 0) {
260 totalLimiter = null;
261 } else {
262 totalLimiter = new Limiter(maxTotalMemorySize);
263 }
264
265
266 misuseDetector.increase();
267 }
268
269 @Override
270 protected void terminated() {
271 super.terminated();
272 misuseDetector.decrease();
273 }
274
275
276
277
278 @Override
279 public List<Runnable> shutdownNow() {
280 return shutdownNow(notifyOnShutdown);
281 }
282
283
284
285
286
287
288
289
290
291
292
293
294
295 public List<Runnable> shutdownNow(boolean notify) {
296 if (!notify) {
297 return super.shutdownNow();
298 }
299 Throwable cause = null;
300 Set<Channel> channels = null;
301
302 List<Runnable> tasks = super.shutdownNow();
303
304
305 for (Runnable task: tasks) {
306 if (task instanceof ChannelEventRunnable) {
307 if (cause == null) {
308 cause = new IOException("Unable to process queued event");
309 }
310 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311 event.getFuture().setFailure(cause);
312
313 if (channels == null) {
314 channels = new HashSet<Channel>();
315 }
316
317
318
319 channels.add(event.getChannel());
320 }
321 }
322
323
324 if (channels != null) {
325 for (Channel channel: channels) {
326 Channels.fireExceptionCaughtLater(channel, cause);
327 }
328 }
329 return tasks;
330 }
331
332
333
334
335 public ObjectSizeEstimator getObjectSizeEstimator() {
336 return settings.objectSizeEstimator;
337 }
338
339
340
341
342 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
343 if (objectSizeEstimator == null) {
344 throw new NullPointerException("objectSizeEstimator");
345 }
346
347 settings = new Settings(
348 objectSizeEstimator,
349 settings.maxChannelMemorySize);
350 }
351
352
353
354
355 public long getMaxChannelMemorySize() {
356 return settings.maxChannelMemorySize;
357 }
358
359
360
361
362
363 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
364 if (maxChannelMemorySize < 0) {
365 throw new IllegalArgumentException(
366 "maxChannelMemorySize: " + maxChannelMemorySize);
367 }
368
369 if (getTaskCount() > 0) {
370 throw new IllegalStateException(
371 "can't be changed after a task is executed");
372 }
373
374 settings = new Settings(
375 settings.objectSizeEstimator,
376 maxChannelMemorySize);
377 }
378
379
380
381
382 public long getMaxTotalMemorySize() {
383 if (totalLimiter == null) {
384 return 0;
385 }
386 return totalLimiter.limit;
387 }
388
389
390
391
392
393 @Deprecated
394 public void setMaxTotalMemorySize(long maxTotalMemorySize) {
395 if (maxTotalMemorySize < 0) {
396 throw new IllegalArgumentException(
397 "maxTotalMemorySize: " + maxTotalMemorySize);
398 }
399
400 if (getTaskCount() > 0) {
401 throw new IllegalStateException(
402 "can't be changed after a task is executed");
403 }
404 }
405
406
407
408
409
410
411
412
413
414
415
416 public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
417 this.notifyOnShutdown = notifyOnShutdown;
418 }
419
420
421
422
423
424 public boolean getNotifyChannelFuturesOnShutdown() {
425 return notifyOnShutdown;
426 }
427
428
429
430 @Override
431 public void execute(Runnable command) {
432 if (command instanceof ChannelDownstreamEventRunnable) {
433 throw new RejectedExecutionException("command must be enclosed with an upstream event.");
434 }
435 if (!(command instanceof ChannelEventRunnable)) {
436 command = new MemoryAwareRunnable(command);
437 }
438
439 increaseCounter(command);
440 doExecute(command);
441 }
442
443
444
445
446
447 protected void doExecute(Runnable task) {
448 doUnorderedExecute(task);
449 }
450
451
452
453
454 protected final void doUnorderedExecute(Runnable task) {
455 super.execute(task);
456 }
457
458 @Override
459 public boolean remove(Runnable task) {
460 boolean removed = super.remove(task);
461 if (removed) {
462 decreaseCounter(task);
463 }
464 return removed;
465 }
466
467 @Override
468 protected void beforeExecute(Thread t, Runnable r) {
469 super.beforeExecute(t, r);
470 decreaseCounter(r);
471 }
472
473 protected void increaseCounter(Runnable task) {
474 if (!shouldCount(task)) {
475 return;
476 }
477
478 Settings settings = this.settings;
479 long maxChannelMemorySize = settings.maxChannelMemorySize;
480
481 int increment = settings.objectSizeEstimator.estimateSize(task);
482
483 if (task instanceof ChannelEventRunnable) {
484 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
485 eventTask.estimatedSize = increment;
486 Channel channel = eventTask.getEvent().getChannel();
487 long channelCounter = getChannelCounter(channel).addAndGet(increment);
488
489 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
490 if (channel.isReadable()) {
491
492 ChannelHandlerContext ctx = eventTask.getContext();
493 if (ctx.getHandler() instanceof ExecutionHandler) {
494
495 ctx.setAttachment(Boolean.TRUE);
496 }
497 channel.setReadable(false);
498 }
499 }
500 } else {
501 ((MemoryAwareRunnable) task).estimatedSize = increment;
502 }
503
504 if (totalLimiter != null) {
505 totalLimiter.increase(increment);
506 }
507 }
508
509 protected void decreaseCounter(Runnable task) {
510 if (!shouldCount(task)) {
511 return;
512 }
513
514 Settings settings = this.settings;
515 long maxChannelMemorySize = settings.maxChannelMemorySize;
516
517 int increment;
518 if (task instanceof ChannelEventRunnable) {
519 increment = ((ChannelEventRunnable) task).estimatedSize;
520 } else {
521 increment = ((MemoryAwareRunnable) task).estimatedSize;
522 }
523
524 if (totalLimiter != null) {
525 totalLimiter.decrease(increment);
526 }
527
528 if (task instanceof ChannelEventRunnable) {
529 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
530 Channel channel = eventTask.getEvent().getChannel();
531 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
532
533 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
534 if (!channel.isReadable()) {
535
536 ChannelHandlerContext ctx = eventTask.getContext();
537 if (ctx.getHandler() instanceof ExecutionHandler) {
538
539
540
541
542
543 if (ctx.getAttachment() != null) {
544
545 ctx.setAttachment(null);
546 channel.setReadable(true);
547 }
548 } else {
549 channel.setReadable(true);
550 }
551 }
552 }
553 }
554 }
555
556 private AtomicLong getChannelCounter(Channel channel) {
557 AtomicLong counter = channelCounters.get(channel);
558 if (counter == null) {
559 counter = new AtomicLong();
560 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
561 if (oldCounter != null) {
562 counter = oldCounter;
563 }
564 }
565
566
567 if (!channel.isOpen()) {
568 channelCounters.remove(channel);
569 }
570 return counter;
571 }
572
573
574
575
576
577
578
579 protected boolean shouldCount(Runnable task) {
580 if (task instanceof ChannelUpstreamEventRunnable) {
581 ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
582 ChannelEvent e = r.getEvent();
583 if (e instanceof WriteCompletionEvent) {
584 return false;
585 } else if (e instanceof ChannelStateEvent) {
586 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
587 return false;
588 }
589 }
590 }
591 return true;
592 }
593
594 private static final class Settings {
595 final ObjectSizeEstimator objectSizeEstimator;
596 final long maxChannelMemorySize;
597
598 Settings(ObjectSizeEstimator objectSizeEstimator,
599 long maxChannelMemorySize) {
600 this.objectSizeEstimator = objectSizeEstimator;
601 this.maxChannelMemorySize = maxChannelMemorySize;
602 }
603 }
604
605 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
606 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
607 try {
608 final Thread t = new Thread(r, "Temporary task executor");
609 t.start();
610 } catch (Throwable e) {
611 throw new RejectedExecutionException(
612 "Failed to start a new thread", e);
613 }
614 }
615 }
616
617 private static final class MemoryAwareRunnable implements Runnable {
618 final Runnable task;
619 int estimatedSize;
620
621 MemoryAwareRunnable(Runnable task) {
622 this.task = task;
623 }
624
625 public void run() {
626 task.run();
627 }
628 }
629
630
631 private static class Limiter {
632
633 final long limit;
634 private long counter;
635 private int waiters;
636
637 Limiter(long limit) {
638 this.limit = limit;
639 }
640
641 synchronized void increase(long amount) {
642 while (counter >= limit) {
643 waiters ++;
644 try {
645 wait();
646 } catch (InterruptedException e) {
647 Thread.currentThread().interrupt();
648 } finally {
649 waiters --;
650 }
651 }
652 counter += amount;
653 }
654
655 synchronized void decrease(long amount) {
656 counter -= amount;
657 if (counter < limit && waiters > 0) {
658 notifyAll();
659 }
660 }
661 }
662 }