1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.execution;
17
18 import org.jboss.netty.buffer.ChannelBuffer;
19 import org.jboss.netty.channel.Channel;
20 import org.jboss.netty.channel.ChannelEvent;
21 import org.jboss.netty.channel.ChannelFuture;
22 import org.jboss.netty.channel.ChannelHandlerContext;
23 import org.jboss.netty.channel.ChannelState;
24 import org.jboss.netty.channel.ChannelStateEvent;
25 import org.jboss.netty.channel.Channels;
26 import org.jboss.netty.channel.MessageEvent;
27 import org.jboss.netty.channel.WriteCompletionEvent;
28 import org.jboss.netty.logging.InternalLogger;
29 import org.jboss.netty.logging.InternalLoggerFactory;
30 import org.jboss.netty.util.DefaultObjectSizeEstimator;
31 import org.jboss.netty.util.ObjectSizeEstimator;
32 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
33 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
34
35 import java.io.IOException;
36 import java.lang.reflect.Method;
37 import java.util.HashSet;
38 import java.util.List;
39 import java.util.Set;
40 import java.util.concurrent.ConcurrentMap;
41 import java.util.concurrent.Executor;
42 import java.util.concurrent.Executors;
43 import java.util.concurrent.LinkedBlockingQueue;
44 import java.util.concurrent.RejectedExecutionException;
45 import java.util.concurrent.RejectedExecutionHandler;
46 import java.util.concurrent.ThreadFactory;
47 import java.util.concurrent.ThreadPoolExecutor;
48 import java.util.concurrent.TimeUnit;
49 import java.util.concurrent.atomic.AtomicLong;
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
141 private static final InternalLogger logger =
142 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143
144 private static final SharedResourceMisuseDetector misuseDetector =
145 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146
147 private volatile Settings settings;
148
149 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151 private final Limiter totalLimiter;
152
153 private volatile boolean notifyOnShutdown;
154
155
156
157
158
159
160
161
162
163
164 public MemoryAwareThreadPoolExecutor(
165 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166
167 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168 }
169
170
171
172
173
174
175
176
177
178
179
180
181 public MemoryAwareThreadPoolExecutor(
182 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183 long keepAliveTime, TimeUnit unit) {
184
185 this(
186 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187 Executors.defaultThreadFactory());
188 }
189
190
191
192
193
194
195
196
197
198
199
200
201
202 public MemoryAwareThreadPoolExecutor(
203 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205
206 this(
207 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208 new DefaultObjectSizeEstimator(), threadFactory);
209 }
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224 public MemoryAwareThreadPoolExecutor(
225 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227 ThreadFactory threadFactory) {
228
229 super(corePoolSize, corePoolSize, keepAliveTime, unit,
230 new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231
232 if (objectSizeEstimator == null) {
233 throw new NullPointerException("objectSizeEstimator");
234 }
235 if (maxChannelMemorySize < 0) {
236 throw new IllegalArgumentException(
237 "maxChannelMemorySize: " + maxChannelMemorySize);
238 }
239 if (maxTotalMemorySize < 0) {
240 throw new IllegalArgumentException(
241 "maxTotalMemorySize: " + maxTotalMemorySize);
242 }
243
244
245
246 try {
247 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248 m.invoke(this, Boolean.TRUE);
249 } catch (Throwable t) {
250
251 logger.debug(
252 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253 "supported in this platform.");
254 }
255
256 settings = new Settings(
257 objectSizeEstimator, maxChannelMemorySize);
258
259 if (maxTotalMemorySize == 0) {
260 totalLimiter = null;
261 } else {
262 totalLimiter = new Limiter(maxTotalMemorySize);
263 }
264
265
266 misuseDetector.increase();
267 }
268
269 @Override
270 protected void terminated() {
271 super.terminated();
272 misuseDetector.decrease();
273 }
274
275
276
277
278 @Override
279 public List<Runnable> shutdownNow() {
280 return shutdownNow(notifyOnShutdown);
281 }
282
283
284
285
286
287
288
289
290
291
292
293
294
295 public List<Runnable> shutdownNow(boolean notify) {
296 if (!notify) {
297 return super.shutdownNow();
298 }
299 Throwable cause = null;
300 Set<Channel> channels = null;
301
302 List<Runnable> tasks = super.shutdownNow();
303
304
305 for (Runnable task: tasks) {
306 if (task instanceof ChannelEventRunnable) {
307 if (cause == null) {
308 cause = new IOException("Unable to process queued event");
309 }
310 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311 event.getFuture().setFailure(cause);
312
313 if (channels == null) {
314 channels = new HashSet<Channel>();
315 }
316
317
318 channels.add(event.getChannel());
319 }
320 }
321
322
323 if (channels != null) {
324 for (Channel channel: channels) {
325 Channels.fireExceptionCaughtLater(channel, cause);
326 }
327 }
328 return tasks;
329 }
330
331
332
333
334 public ObjectSizeEstimator getObjectSizeEstimator() {
335 return settings.objectSizeEstimator;
336 }
337
338
339
340
341 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
342 if (objectSizeEstimator == null) {
343 throw new NullPointerException("objectSizeEstimator");
344 }
345
346 settings = new Settings(
347 objectSizeEstimator,
348 settings.maxChannelMemorySize);
349 }
350
351
352
353
354 public long getMaxChannelMemorySize() {
355 return settings.maxChannelMemorySize;
356 }
357
358
359
360
361
362 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
363 if (maxChannelMemorySize < 0) {
364 throw new IllegalArgumentException(
365 "maxChannelMemorySize: " + maxChannelMemorySize);
366 }
367
368 if (getTaskCount() > 0) {
369 throw new IllegalStateException(
370 "can't be changed after a task is executed");
371 }
372
373 settings = new Settings(
374 settings.objectSizeEstimator,
375 maxChannelMemorySize);
376 }
377
378
379
380
381 public long getMaxTotalMemorySize() {
382 if (totalLimiter == null) {
383 return 0;
384 }
385 return totalLimiter.limit;
386 }
387
388
389
390
391 @Deprecated
392 public void setMaxTotalMemorySize(long maxTotalMemorySize) {
393 if (maxTotalMemorySize < 0) {
394 throw new IllegalArgumentException(
395 "maxTotalMemorySize: " + maxTotalMemorySize);
396 }
397
398 if (getTaskCount() > 0) {
399 throw new IllegalStateException(
400 "can't be changed after a task is executed");
401 }
402 }
403
404
405
406
407
408
409
410
411
412
413
414 public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
415 this.notifyOnShutdown = notifyOnShutdown;
416 }
417
418
419
420
421
422 public boolean getNotifyChannelFuturesOnShutdown() {
423 return notifyOnShutdown;
424 }
425
426 @Override
427 public void execute(Runnable command) {
428 if (command instanceof ChannelDownstreamEventRunnable) {
429 throw new RejectedExecutionException("command must be enclosed with an upstream event.");
430 }
431 if (!(command instanceof ChannelEventRunnable)) {
432 command = new MemoryAwareRunnable(command);
433 }
434
435 increaseCounter(command);
436 doExecute(command);
437 }
438
439
440
441
442
443 protected void doExecute(Runnable task) {
444 doUnorderedExecute(task);
445 }
446
447
448
449
450 protected final void doUnorderedExecute(Runnable task) {
451 super.execute(task);
452 }
453
454 @Override
455 public boolean remove(Runnable task) {
456 boolean removed = super.remove(task);
457 if (removed) {
458 decreaseCounter(task);
459 }
460 return removed;
461 }
462
463 @Override
464 protected void beforeExecute(Thread t, Runnable r) {
465 super.beforeExecute(t, r);
466 decreaseCounter(r);
467 }
468
469 protected void increaseCounter(Runnable task) {
470 if (!shouldCount(task)) {
471 return;
472 }
473
474 Settings settings = this.settings;
475 long maxChannelMemorySize = settings.maxChannelMemorySize;
476
477 int increment = settings.objectSizeEstimator.estimateSize(task);
478
479 if (task instanceof ChannelEventRunnable) {
480 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
481 eventTask.estimatedSize = increment;
482 Channel channel = eventTask.getEvent().getChannel();
483 long channelCounter = getChannelCounter(channel).addAndGet(increment);
484
485 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
486 if (channel.isReadable()) {
487
488 ChannelHandlerContext ctx = eventTask.getContext();
489 if (ctx.getHandler() instanceof ExecutionHandler) {
490
491 ctx.setAttachment(Boolean.TRUE);
492 }
493 channel.setReadable(false);
494 }
495 }
496 } else {
497 ((MemoryAwareRunnable) task).estimatedSize = increment;
498 }
499
500 if (totalLimiter != null) {
501 totalLimiter.increase(increment);
502 }
503 }
504
505 protected void decreaseCounter(Runnable task) {
506 if (!shouldCount(task)) {
507 return;
508 }
509
510 Settings settings = this.settings;
511 long maxChannelMemorySize = settings.maxChannelMemorySize;
512
513 int increment;
514 if (task instanceof ChannelEventRunnable) {
515 increment = ((ChannelEventRunnable) task).estimatedSize;
516 } else {
517 increment = ((MemoryAwareRunnable) task).estimatedSize;
518 }
519
520 if (totalLimiter != null) {
521 totalLimiter.decrease(increment);
522 }
523
524 if (task instanceof ChannelEventRunnable) {
525 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
526 Channel channel = eventTask.getEvent().getChannel();
527 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
528
529 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
530 if (!channel.isReadable()) {
531
532 ChannelHandlerContext ctx = eventTask.getContext();
533 if (ctx.getHandler() instanceof ExecutionHandler) {
534
535
536
537
538
539 if (ctx.getAttachment() != null) {
540
541 ctx.setAttachment(null);
542 channel.setReadable(true);
543 }
544 } else {
545 channel.setReadable(true);
546 }
547 }
548 }
549 }
550 }
551
552 private AtomicLong getChannelCounter(Channel channel) {
553 AtomicLong counter = channelCounters.get(channel);
554 if (counter == null) {
555 counter = new AtomicLong();
556 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
557 if (oldCounter != null) {
558 counter = oldCounter;
559 }
560 }
561
562
563 if (!channel.isOpen()) {
564 channelCounters.remove(channel);
565 }
566 return counter;
567 }
568
569
570
571
572
573
574
575 protected boolean shouldCount(Runnable task) {
576 if (task instanceof ChannelUpstreamEventRunnable) {
577 ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
578 ChannelEvent e = r.getEvent();
579 if (e instanceof WriteCompletionEvent) {
580 return false;
581 } else if (e instanceof ChannelStateEvent) {
582 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
583 return false;
584 }
585 }
586 }
587 return true;
588 }
589
590 private static final class Settings {
591 final ObjectSizeEstimator objectSizeEstimator;
592 final long maxChannelMemorySize;
593
594 Settings(ObjectSizeEstimator objectSizeEstimator,
595 long maxChannelMemorySize) {
596 this.objectSizeEstimator = objectSizeEstimator;
597 this.maxChannelMemorySize = maxChannelMemorySize;
598 }
599 }
600
601 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
602 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
603 try {
604 final Thread t = new Thread(r, "Temporary task executor");
605 t.start();
606 } catch (Throwable e) {
607 throw new RejectedExecutionException(
608 "Failed to start a new thread", e);
609 }
610 }
611 }
612
613 private static final class MemoryAwareRunnable implements Runnable {
614 final Runnable task;
615 int estimatedSize;
616
617 MemoryAwareRunnable(Runnable task) {
618 this.task = task;
619 }
620
621 public void run() {
622 task.run();
623 }
624 }
625
626 private static class Limiter {
627
628 final long limit;
629 private long counter;
630 private int waiters;
631
632 Limiter(long limit) {
633 this.limit = limit;
634 }
635
636 synchronized void increase(long amount) {
637 while (counter >= limit) {
638 waiters ++;
639 try {
640 wait();
641 } catch (InterruptedException e) {
642 Thread.currentThread().interrupt();
643 } finally {
644 waiters --;
645 }
646 }
647 counter += amount;
648 }
649
650 synchronized void decrease(long amount) {
651 counter -= amount;
652 if (counter < limit && waiters > 0) {
653 notifyAll();
654 }
655 }
656 }
657 }