1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.jboss.netty.handler.execution;
17
18 import org.jboss.netty.buffer.ChannelBuffer;
19 import org.jboss.netty.channel.Channel;
20 import org.jboss.netty.channel.ChannelEvent;
21 import org.jboss.netty.channel.ChannelFuture;
22 import org.jboss.netty.channel.ChannelHandlerContext;
23 import org.jboss.netty.channel.ChannelState;
24 import org.jboss.netty.channel.ChannelStateEvent;
25 import org.jboss.netty.channel.Channels;
26 import org.jboss.netty.channel.MessageEvent;
27 import org.jboss.netty.channel.WriteCompletionEvent;
28 import org.jboss.netty.logging.InternalLogger;
29 import org.jboss.netty.logging.InternalLoggerFactory;
30 import org.jboss.netty.util.DefaultObjectSizeEstimator;
31 import org.jboss.netty.util.ObjectSizeEstimator;
32 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
33 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
34
35 import java.io.IOException;
36 import java.lang.reflect.Method;
37 import java.util.HashSet;
38 import java.util.List;
39 import java.util.Set;
40 import java.util.concurrent.ConcurrentMap;
41 import java.util.concurrent.Executor;
42 import java.util.concurrent.Executors;
43 import java.util.concurrent.LinkedBlockingQueue;
44 import java.util.concurrent.RejectedExecutionException;
45 import java.util.concurrent.RejectedExecutionHandler;
46 import java.util.concurrent.ThreadFactory;
47 import java.util.concurrent.ThreadPoolExecutor;
48 import java.util.concurrent.TimeUnit;
49 import java.util.concurrent.atomic.AtomicLong;
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
141 private static final InternalLogger logger =
142 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143
144 private static final SharedResourceMisuseDetector misuseDetector =
145 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146
147 private volatile Settings settings;
148
149 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151 private final Limiter totalLimiter;
152
153 private volatile boolean notifyOnShutdown;
154
155
156
157
158
159
160
161
162
163
164 public MemoryAwareThreadPoolExecutor(
165 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166
167 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168 }
169
170
171
172
173
174
175
176
177
178
179
180
181 public MemoryAwareThreadPoolExecutor(
182 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183 long keepAliveTime, TimeUnit unit) {
184
185 this(
186 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187 Executors.defaultThreadFactory());
188 }
189
190
191
192
193
194
195
196
197
198
199
200
201
202 public MemoryAwareThreadPoolExecutor(
203 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205
206 this(
207 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208 new DefaultObjectSizeEstimator(), threadFactory);
209 }
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224 public MemoryAwareThreadPoolExecutor(
225 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227 ThreadFactory threadFactory) {
228
229 super(corePoolSize, corePoolSize, keepAliveTime, unit,
230 new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231
232 if (objectSizeEstimator == null) {
233 throw new NullPointerException("objectSizeEstimator");
234 }
235 if (maxChannelMemorySize < 0) {
236 throw new IllegalArgumentException(
237 "maxChannelMemorySize: " + maxChannelMemorySize);
238 }
239 if (maxTotalMemorySize < 0) {
240 throw new IllegalArgumentException(
241 "maxTotalMemorySize: " + maxTotalMemorySize);
242 }
243
244
245
246 try {
247 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248 m.invoke(this, Boolean.TRUE);
249 } catch (Throwable t) {
250
251 logger.debug(
252 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253 "supported in this platform.");
254 }
255
256 settings = new Settings(
257 objectSizeEstimator, maxChannelMemorySize);
258
259 if (maxTotalMemorySize == 0) {
260 totalLimiter = null;
261 } else {
262 totalLimiter = new Limiter(maxTotalMemorySize);
263 }
264
265
266 misuseDetector.increase();
267 }
268
269 @Override
270 protected void terminated() {
271 super.terminated();
272 misuseDetector.decrease();
273 }
274
275
276
277
278 @Override
279 public List<Runnable> shutdownNow() {
280 return shutdownNow(notifyOnShutdown);
281 }
282
283
284
285
286
287
288
289
290
291
292
293
294
295 public List<Runnable> shutdownNow(boolean notify) {
296 if (!notify) {
297 return super.shutdownNow();
298 }
299 Throwable cause = null;
300 Set<Channel> channels = null;
301
302 List<Runnable> tasks = super.shutdownNow();
303
304
305 for (Runnable task: tasks) {
306 if (task instanceof ChannelEventRunnable) {
307 if (cause == null) {
308 cause = new IOException("Unable to process queued event");
309 }
310 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311 event.getFuture().setFailure(cause);
312
313 if (channels == null) {
314 channels = new HashSet<Channel>();
315 }
316
317
318 channels.add(event.getChannel());
319 }
320 }
321
322
323 if (channels != null) {
324 for (Channel channel: channels) {
325 Channels.fireExceptionCaughtLater(channel, cause);
326 }
327 }
328 return tasks;
329 }
330
331
332
333
334 public ObjectSizeEstimator getObjectSizeEstimator() {
335 return settings.objectSizeEstimator;
336 }
337
338
339
340
341 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
342 if (objectSizeEstimator == null) {
343 throw new NullPointerException("objectSizeEstimator");
344 }
345
346 settings = new Settings(
347 objectSizeEstimator,
348 settings.maxChannelMemorySize);
349 }
350
351
352
353
354 public long getMaxChannelMemorySize() {
355 return settings.maxChannelMemorySize;
356 }
357
358
359
360
361
362 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
363 if (maxChannelMemorySize < 0) {
364 throw new IllegalArgumentException(
365 "maxChannelMemorySize: " + maxChannelMemorySize);
366 }
367
368 if (getTaskCount() > 0) {
369 throw new IllegalStateException(
370 "can't be changed after a task is executed");
371 }
372
373 settings = new Settings(
374 settings.objectSizeEstimator,
375 maxChannelMemorySize);
376 }
377
378
379
380
381 public long getMaxTotalMemorySize() {
382 if (totalLimiter == null) {
383 return 0;
384 }
385 return totalLimiter.limit;
386 }
387
388
389
390
391
392
393
394
395
396
397
398 public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
399 this.notifyOnShutdown = notifyOnShutdown;
400 }
401
402
403
404
405
406 public boolean getNotifyChannelFuturesOnShutdown() {
407 return notifyOnShutdown;
408 }
409
410 @Override
411 public void execute(Runnable command) {
412 if (command instanceof ChannelDownstreamEventRunnable) {
413 throw new RejectedExecutionException("command must be enclosed with an upstream event.");
414 }
415 if (!(command instanceof ChannelEventRunnable)) {
416 command = new MemoryAwareRunnable(command);
417 }
418
419 increaseCounter(command);
420 doExecute(command);
421 }
422
423
424
425
426
427 protected void doExecute(Runnable task) {
428 doUnorderedExecute(task);
429 }
430
431
432
433
434 protected final void doUnorderedExecute(Runnable task) {
435 super.execute(task);
436 }
437
438 @Override
439 public boolean remove(Runnable task) {
440 boolean removed = super.remove(task);
441 if (removed) {
442 decreaseCounter(task);
443 }
444 return removed;
445 }
446
447 @Override
448 protected void beforeExecute(Thread t, Runnable r) {
449 super.beforeExecute(t, r);
450 decreaseCounter(r);
451 }
452
453 protected void increaseCounter(Runnable task) {
454 if (!shouldCount(task)) {
455 return;
456 }
457
458 Settings settings = this.settings;
459 long maxChannelMemorySize = settings.maxChannelMemorySize;
460
461 int increment = settings.objectSizeEstimator.estimateSize(task);
462
463 if (task instanceof ChannelEventRunnable) {
464 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
465 eventTask.estimatedSize = increment;
466 Channel channel = eventTask.getEvent().getChannel();
467 long channelCounter = getChannelCounter(channel).addAndGet(increment);
468
469 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
470 if (channel.isReadable()) {
471
472 ChannelHandlerContext ctx = eventTask.getContext();
473 if (ctx.getHandler() instanceof ExecutionHandler) {
474
475 ctx.setAttachment(Boolean.TRUE);
476 }
477 channel.setReadable(false);
478 }
479 }
480 } else {
481 ((MemoryAwareRunnable) task).estimatedSize = increment;
482 }
483
484 if (totalLimiter != null) {
485 totalLimiter.increase(increment);
486 }
487 }
488
489 protected void decreaseCounter(Runnable task) {
490 if (!shouldCount(task)) {
491 return;
492 }
493
494 Settings settings = this.settings;
495 long maxChannelMemorySize = settings.maxChannelMemorySize;
496
497 int increment;
498 if (task instanceof ChannelEventRunnable) {
499 increment = ((ChannelEventRunnable) task).estimatedSize;
500 } else {
501 increment = ((MemoryAwareRunnable) task).estimatedSize;
502 }
503
504 if (totalLimiter != null) {
505 totalLimiter.decrease(increment);
506 }
507
508 if (task instanceof ChannelEventRunnable) {
509 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
510 Channel channel = eventTask.getEvent().getChannel();
511 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
512
513 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
514 if (!channel.isReadable()) {
515
516 ChannelHandlerContext ctx = eventTask.getContext();
517 if (ctx.getHandler() instanceof ExecutionHandler) {
518
519
520
521
522
523 if (ctx.getAttachment() != null) {
524
525 ctx.setAttachment(null);
526 channel.setReadable(true);
527 }
528 } else {
529 channel.setReadable(true);
530 }
531 }
532 }
533 }
534 }
535
536 private AtomicLong getChannelCounter(Channel channel) {
537 AtomicLong counter = channelCounters.get(channel);
538 if (counter == null) {
539 counter = new AtomicLong();
540 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
541 if (oldCounter != null) {
542 counter = oldCounter;
543 }
544 }
545
546
547 if (!channel.isOpen()) {
548 channelCounters.remove(channel);
549 }
550 return counter;
551 }
552
553
554
555
556
557
558
559 protected boolean shouldCount(Runnable task) {
560 if (task instanceof ChannelUpstreamEventRunnable) {
561 ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
562 ChannelEvent e = r.getEvent();
563 if (e instanceof WriteCompletionEvent) {
564 return false;
565 } else if (e instanceof ChannelStateEvent) {
566 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
567 return false;
568 }
569 }
570 }
571 return true;
572 }
573
574 private static final class Settings {
575 final ObjectSizeEstimator objectSizeEstimator;
576 final long maxChannelMemorySize;
577
578 Settings(ObjectSizeEstimator objectSizeEstimator,
579 long maxChannelMemorySize) {
580 this.objectSizeEstimator = objectSizeEstimator;
581 this.maxChannelMemorySize = maxChannelMemorySize;
582 }
583 }
584
585 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
586 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
587 try {
588 final Thread t = new Thread(r, "Temporary task executor");
589 t.start();
590 } catch (Throwable e) {
591 throw new RejectedExecutionException(
592 "Failed to start a new thread", e);
593 }
594 }
595 }
596
597 private static final class MemoryAwareRunnable implements Runnable {
598 final Runnable task;
599 int estimatedSize;
600
601 MemoryAwareRunnable(Runnable task) {
602 this.task = task;
603 }
604
605 public void run() {
606 task.run();
607 }
608 }
609
610 private static class Limiter {
611
612 final long limit;
613 private long counter;
614 private int waiters;
615
616 Limiter(long limit) {
617 this.limit = limit;
618 }
619
620 synchronized void increase(long amount) {
621 while (counter >= limit) {
622 waiters ++;
623 try {
624 wait();
625 } catch (InterruptedException e) {
626 Thread.currentThread().interrupt();
627 } finally {
628 waiters --;
629 }
630 }
631 counter += amount;
632 }
633
634 synchronized void decrease(long amount) {
635 counter -= amount;
636 if (counter < limit && waiters > 0) {
637 notifyAll();
638 }
639 }
640 }
641 }