1 /*
2 * Copyright 2012 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package org.jboss.netty.handler.execution;
17
18 import java.io.IOException;
19 import java.lang.reflect.Method;
20 import java.util.HashSet;
21 import java.util.List;
22 import java.util.Set;
23 import java.util.concurrent.ConcurrentMap;
24 import java.util.concurrent.Executor;
25 import java.util.concurrent.Executors;
26 import java.util.concurrent.LinkedBlockingQueue;
27 import java.util.concurrent.RejectedExecutionException;
28 import java.util.concurrent.RejectedExecutionHandler;
29 import java.util.concurrent.ThreadFactory;
30 import java.util.concurrent.ThreadPoolExecutor;
31 import java.util.concurrent.TimeUnit;
32 import java.util.concurrent.atomic.AtomicLong;
33
34 import org.jboss.netty.buffer.ChannelBuffer;
35 import org.jboss.netty.channel.Channel;
36 import org.jboss.netty.channel.ChannelEvent;
37 import org.jboss.netty.channel.ChannelFuture;
38 import org.jboss.netty.channel.ChannelHandlerContext;
39 import org.jboss.netty.channel.ChannelState;
40 import org.jboss.netty.channel.ChannelStateEvent;
41 import org.jboss.netty.channel.Channels;
42 import org.jboss.netty.channel.MessageEvent;
43 import org.jboss.netty.channel.WriteCompletionEvent;
44 import org.jboss.netty.logging.InternalLogger;
45 import org.jboss.netty.logging.InternalLoggerFactory;
46 import org.jboss.netty.util.DefaultObjectSizeEstimator;
47 import org.jboss.netty.util.ObjectSizeEstimator;
48 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
49 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
50
51 /**
52 * A {@link ThreadPoolExecutor} which blocks the task submission when there's
53 * too many tasks in the queue. Both per-{@link Channel} and per-{@link Executor}
54 * limitation can be applied.
55 * <p>
56 * When a task (i.e. {@link Runnable}) is submitted,
57 * {@link MemoryAwareThreadPoolExecutor} calls {@link ObjectSizeEstimator#estimateSize(Object)}
58 * to get the estimated size of the task in bytes to calculate the amount of
59 * memory occupied by the unprocessed tasks.
60 * <p>
61 * If the total size of the unprocessed tasks exceeds either per-{@link Channel}
62 * or per-{@link Executor} threshold, any further {@link #execute(Runnable)}
63 * call will block until the tasks in the queue are processed so that the total
64 * size goes under the threshold.
65 *
66 * <h3>Using an alternative task size estimation strategy</h3>
67 *
68 * Although the default implementation does its best to guess the size of an
 * object of unknown type, it is always a good idea to use an alternative
70 * {@link ObjectSizeEstimator} implementation instead of the
71 * {@link DefaultObjectSizeEstimator} to avoid incorrect task size calculation,
72 * especially when:
73 * <ul>
74 * <li>you are using {@link MemoryAwareThreadPoolExecutor} independently from
75 * {@link ExecutionHandler},</li>
76 * <li>you are submitting a task whose type is not {@link ChannelEventRunnable}, or</li>
77 * <li>the message type of the {@link MessageEvent} in the {@link ChannelEventRunnable}
78 * is not {@link ChannelBuffer}.</li>
79 * </ul>
80 * Here is an example that demonstrates how to implement an {@link ObjectSizeEstimator}
81 * which understands a user-defined object:
82 * <pre>
83 * public class MyRunnable implements {@link Runnable} {
84 *
85 * <b>private final byte[] data;</b>
86 *
87 * public MyRunnable(byte[] data) {
88 * this.data = data;
89 * }
90 *
91 * public void run() {
92 * // Process 'data' ..
93 * }
94 * }
95 *
96 * public class MyObjectSizeEstimator extends {@link DefaultObjectSizeEstimator} {
97 *
98 * {@literal @Override}
99 * public int estimateSize(Object o) {
100 * if (<b>o instanceof MyRunnable</b>) {
101 * <b>return ((MyRunnable) o).data.length + 8;</b>
102 * }
103 * return super.estimateSize(o);
104 * }
105 * }
106 *
107 * {@link ThreadPoolExecutor} pool = new {@link MemoryAwareThreadPoolExecutor}(
108 * 16, 65536, 1048576, 30, {@link TimeUnit}.SECONDS,
109 * <b>new MyObjectSizeEstimator()</b>,
110 * {@link Executors}.defaultThreadFactory());
111 *
112 * <b>pool.execute(new MyRunnable(data));</b>
113 * </pre>
114 *
115 * <h3>Event execution order</h3>
116 *
117 * Please note that this executor does not maintain the order of the
118 * {@link ChannelEvent}s for the same {@link Channel}. For example,
119 * you can even receive a {@code "channelClosed"} event before a
 * {@code "messageReceived"} event.
 *
 * The following diagram shows one order in which the events can be processed:
123 *
124 * <pre>
125 * --------------------------------> Timeline -------------------------------->
126 *
127 * Thread X: --- Channel A (Event 2) --- Channel A (Event 1) --------------------------->
128 *
129 * Thread Y: --- Channel A (Event 3) --- Channel B (Event 2) --- Channel B (Event 3) --->
130 *
131 * Thread Z: --- Channel B (Event 1) --- Channel B (Event 4) --- Channel A (Event 4) --->
132 * </pre>
133 *
134 * To maintain the event order, you must use {@link OrderedMemoryAwareThreadPoolExecutor}.
135 *
136 * @apiviz.has org.jboss.netty.util.ObjectSizeEstimator oneway - -
137 * @apiviz.has org.jboss.netty.handler.execution.ChannelEventRunnable oneway - - executes
138 */
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
// Logger used by the constructor to report that allowCoreThreadTimeOut() is unavailable.
private static final InternalLogger logger =
        InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);

// Shared across all instances: increased by the constructor, decreased in terminated().
private static final SharedResourceMisuseDetector misuseDetector =
        new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);

// Snapshot of the tunable settings; the setters replace the whole object instead of mutating it.
private volatile Settings settings;

// Per-channel sum of the estimated sizes of the counted tasks that have not started yet.
private final ConcurrentMap<Channel, AtomicLong> channelCounters =
        new ConcurrentIdentityHashMap<Channel, AtomicLong>();
// Pool-wide limiter; null when the pool was created with maxTotalMemorySize == 0.
private final Limiter totalLimiter;

// Argument that shutdownNow() passes to shutdownNow(boolean); false unless changed via its setter.
private volatile boolean notifyOnShutdown;
154
/**
 * Creates a new instance whose idle threads shut down after 30 seconds and which
 * uses {@link Executors#defaultThreadFactory()} and a {@link DefaultObjectSizeEstimator}.
 *
 * @param corePoolSize          the maximum number of active threads
 * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
 *                              Specify {@code 0} to disable.
 * @param maxTotalMemorySize    the maximum total size of the queued events for this pool.
 *                              Specify {@code 0} to disable.
 *
 * @throws IllegalArgumentException if either memory size is negative
 */
public MemoryAwareThreadPoolExecutor(
        int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {

    this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
}
169
/**
 * Creates a new instance which uses {@link Executors#defaultThreadFactory()} and a
 * {@link DefaultObjectSizeEstimator}.
 *
 * @param corePoolSize          the maximum number of active threads
 * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
 *                              Specify {@code 0} to disable.
 * @param maxTotalMemorySize    the maximum total size of the queued events for this pool.
 *                              Specify {@code 0} to disable.
 * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
 * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
 *
 * @throws IllegalArgumentException if either memory size is negative
 */
public MemoryAwareThreadPoolExecutor(
        int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
        long keepAliveTime, TimeUnit unit) {

    this(
            corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
            Executors.defaultThreadFactory());
}
189
/**
 * Creates a new instance which estimates task sizes with a {@link DefaultObjectSizeEstimator}.
 *
 * @param corePoolSize          the maximum number of active threads
 * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
 *                              Specify {@code 0} to disable.
 * @param maxTotalMemorySize    the maximum total size of the queued events for this pool.
 *                              Specify {@code 0} to disable.
 * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
 * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
 * @param threadFactory         the {@link ThreadFactory} of this pool
 *
 * @throws IllegalArgumentException if either memory size is negative
 */
public MemoryAwareThreadPoolExecutor(
        int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
        long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {

    this(
            corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
            new DefaultObjectSizeEstimator(), threadFactory);
}
210
/**
 * Creates a new instance.
 *
 * @param corePoolSize          the maximum number of active threads
 * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
 *                              Specify {@code 0} to disable.
 * @param maxTotalMemorySize    the maximum total size of the queued events for this pool.
 *                              Specify {@code 0} to disable.
 * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
 * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
 * @param objectSizeEstimator   the {@link ObjectSizeEstimator} of this pool
 * @param threadFactory         the {@link ThreadFactory} of this pool
 *
 * @throws NullPointerException     if {@code objectSizeEstimator} is {@code null}
 * @throws IllegalArgumentException if either memory size is negative
 */
public MemoryAwareThreadPoolExecutor(
        int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
        long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
        ThreadFactory threadFactory) {

    // Fixed-size pool (core == max) on top of an unbounded queue; the throttling is done by
    // the memory counters of this class rather than by the queue capacity.
    super(corePoolSize, corePoolSize, keepAliveTime, unit,
          new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());

    if (objectSizeEstimator == null) {
        throw new NullPointerException("objectSizeEstimator");
    }
    if (maxChannelMemorySize < 0) {
        throw new IllegalArgumentException("maxChannelMemorySize: " + maxChannelMemorySize);
    }
    if (maxTotalMemorySize < 0) {
        throw new IllegalArgumentException("maxTotalMemorySize: " + maxTotalMemorySize);
    }

    // allowCoreThreadTimeOut(true) is looked up reflectively because the method
    // does not exist in Java 5.
    try {
        Method allowTimeOut = getClass().getMethod("allowCoreThreadTimeOut", boolean.class);
        allowTimeOut.invoke(this, Boolean.TRUE);
    } catch (Throwable ignored) {
        // Java 5
        logger.debug("ThreadPoolExecutor.allowCoreThreadTimeOut() is not supported in this platform.");
    }

    settings = new Settings(objectSizeEstimator, maxChannelMemorySize);

    // A limit of 0 disables the pool-wide limitation altogether.
    totalLimiter = maxTotalMemorySize == 0 ? null : new Limiter(maxTotalMemorySize);

    // Misuse check
    misuseDetector.increase();
}
268
/**
 * Performs the default termination handling and then decreases the shared misuse
 * detector counter that the constructor increased for this instance.
 */
@Override
protected void terminated() {
    super.terminated();
    misuseDetector.decrease();
}
274
/**
 * This will call {@link #shutdownNow(boolean)} with the value of {@link #getNotifyChannelFuturesOnShutdown()}.
 *
 * @return the list returned by {@link #shutdownNow(boolean)}
 */
@Override
public List<Runnable> shutdownNow() {
    return shutdownNow(notifyOnShutdown);
}
282
283 /**
284 * See {@link ThreadPoolExecutor#shutdownNow()} for how it handles the shutdown.
285 * If <code>true</code> is given to this method it also notifies all {@link ChannelFuture}'s
286 * of the not executed {@link ChannelEventRunnable}'s.
287 *
288 * <p>
289 * Be aware that if you call this with <code>false</code> you will need to handle the
290 * notification of the {@link ChannelFuture}'s by your self. So only use this if you
291 * really have a use-case for it.
292 * </p>
293 *
294 */
295 public List<Runnable> shutdownNow(boolean notify) {
296 if (!notify) {
297 return super.shutdownNow();
298 }
299 Throwable cause = null;
300 Set<Channel> channels = null;
301
302 List<Runnable> tasks = super.shutdownNow();
303
304 // loop over all tasks and cancel the ChannelFuture of the ChannelEventRunable's
305 for (Runnable task: tasks) {
306 if (task instanceof ChannelEventRunnable) {
307 if (cause == null) {
308 cause = new IOException("Unable to process queued event");
309 }
310 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311 event.getFuture().setFailure(cause);
312
313 if (channels == null) {
314 channels = new HashSet<Channel>();
315 }
316
317
318 // store the Channel of the event for later notification of the exceptionCaught event
319 channels.add(event.getChannel());
320 }
321 }
322
323 // loop over all channels and fire an exceptionCaught event
324 if (channels != null) {
325 for (Channel channel: channels) {
326 Channels.fireExceptionCaughtLater(channel, cause);
327 }
328 }
329 return tasks;
330 }
331
/**
 * Returns the {@link ObjectSizeEstimator} of this pool.
 *
 * @return the estimator held by the current settings
 */
public ObjectSizeEstimator getObjectSizeEstimator() {
    return settings.objectSizeEstimator;
}
338
339 /**
340 * Sets the {@link ObjectSizeEstimator} of this pool.
341 */
342 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
343 if (objectSizeEstimator == null) {
344 throw new NullPointerException("objectSizeEstimator");
345 }
346
347 settings = new Settings(
348 objectSizeEstimator,
349 settings.maxChannelMemorySize);
350 }
351
/**
 * Returns the maximum total size of the queued events per channel.
 *
 * @return the per-channel limit held by the current settings; {@code 0} means disabled
 */
public long getMaxChannelMemorySize() {
    return settings.maxChannelMemorySize;
}
358
359 /**
360 * Sets the maximum total size of the queued events per channel.
361 * Specify {@code 0} to disable.
362 */
363 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
364 if (maxChannelMemorySize < 0) {
365 throw new IllegalArgumentException(
366 "maxChannelMemorySize: " + maxChannelMemorySize);
367 }
368
369 if (getTaskCount() > 0) {
370 throw new IllegalStateException(
371 "can't be changed after a task is executed");
372 }
373
374 settings = new Settings(
375 settings.objectSizeEstimator,
376 maxChannelMemorySize);
377 }
378
379 /**
380 * Returns the maximum total size of the queued events for this pool.
381 */
382 public long getMaxTotalMemorySize() {
383 if (totalLimiter == null) {
384 return 0;
385 }
386 return totalLimiter.limit;
387 }
388
389
/**
 * Validates the arguments and the state of this pool but does not change the limit.
 *
 * @param maxTotalMemorySize ignored apart from the validation
 * @throws IllegalArgumentException if {@code maxTotalMemorySize} is negative
 * @throws IllegalStateException    if {@link #getTaskCount()} is already greater than zero
 *
 * @deprecated <tt>maxTotalMemorySize</tt> is not modifiable anymore.
 */
@Deprecated
public void setMaxTotalMemorySize(long maxTotalMemorySize) {
    if (maxTotalMemorySize < 0) {
        throw new IllegalArgumentException(
                "maxTotalMemorySize: " + maxTotalMemorySize);
    }

    if (getTaskCount() > 0) {
        throw new IllegalStateException(
                "can't be changed after a task is executed");
    }
}
405
/**
 * If set to <code>false</code> no queued {@link ChannelEventRunnable}'s {@link ChannelFuture}
 * will get notified once {@link #shutdownNow()} is called. If set to <code>true</code> every
 * queued {@link ChannelEventRunnable} will get marked as failed via {@link ChannelFuture#setFailure(Throwable)}.
 *
 * <p>
 * With <code>false</code> you have to handle the notification by yourself, so be sure
 * you know what you are doing. Default is <code>false</code>.
 * </p>
 *
 * @param notifyOnShutdown the value {@link #shutdownNow()} passes to {@link #shutdownNow(boolean)}
 */
public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
    this.notifyOnShutdown = notifyOnShutdown;
}
419
/**
 * Returns if the {@link ChannelFuture}'s of the {@link ChannelEventRunnable}'s should be
 * notified about the shutdown of this {@link MemoryAwareThreadPoolExecutor}.
 *
 * @return the value {@link #shutdownNow()} passes to {@link #shutdownNow(boolean)}
 */
public boolean getNotifyChannelFuturesOnShutdown() {
    return notifyOnShutdown;
}
427
428
429
430 @Override
431 public void execute(Runnable command) {
432 if (command instanceof ChannelDownstreamEventRunnable) {
433 throw new RejectedExecutionException("command must be enclosed with an upstream event.");
434 }
435 if (!(command instanceof ChannelEventRunnable)) {
436 command = new MemoryAwareRunnable(command);
437 }
438
439 increaseCounter(command);
440 doExecute(command);
441 }
442
/**
 * Put the actual execution logic here. The default implementation simply
 * calls {@link #doUnorderedExecute(Runnable)}.
 *
 * @param task the task handed over by {@link #execute(Runnable)} after its size was accounted for
 */
protected void doExecute(Runnable task) {
    doUnorderedExecute(task);
}
450
/**
 * Executes the specified task without maintaining the event order, by passing it
 * straight to {@link ThreadPoolExecutor#execute(Runnable)}.
 *
 * @param task the task to execute
 */
protected final void doUnorderedExecute(Runnable task) {
    super.execute(task);
}
457
458 @Override
459 public boolean remove(Runnable task) {
460 boolean removed = super.remove(task);
461 if (removed) {
462 decreaseCounter(task);
463 }
464 return removed;
465 }
466
/**
 * Releases the size accounted for the task right before it runs.
 *
 * @param t the thread that will run the task
 * @param r the task that is about to run
 */
@Override
protected void beforeExecute(Thread t, Runnable r) {
    super.beforeExecute(t, r);
    decreaseCounter(r);
}
472
473 protected void increaseCounter(Runnable task) {
474 if (!shouldCount(task)) {
475 return;
476 }
477
478 Settings settings = this.settings;
479 long maxChannelMemorySize = settings.maxChannelMemorySize;
480
481 int increment = settings.objectSizeEstimator.estimateSize(task);
482
483 if (task instanceof ChannelEventRunnable) {
484 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
485 eventTask.estimatedSize = increment;
486 Channel channel = eventTask.getEvent().getChannel();
487 long channelCounter = getChannelCounter(channel).addAndGet(increment);
488 //System.out.println("IC: " + channelCounter + ", " + increment);
489 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
490 if (channel.isReadable()) {
491 //System.out.println("UNREADABLE");
492 ChannelHandlerContext ctx = eventTask.getContext();
493 if (ctx.getHandler() instanceof ExecutionHandler) {
494 // readSuspended = true;
495 ctx.setAttachment(Boolean.TRUE);
496 }
497 channel.setReadable(false);
498 }
499 }
500 } else {
501 ((MemoryAwareRunnable) task).estimatedSize = increment;
502 }
503
504 if (totalLimiter != null) {
505 totalLimiter.increase(increment);
506 }
507 }
508
509 protected void decreaseCounter(Runnable task) {
510 if (!shouldCount(task)) {
511 return;
512 }
513
514 Settings settings = this.settings;
515 long maxChannelMemorySize = settings.maxChannelMemorySize;
516
517 int increment;
518 if (task instanceof ChannelEventRunnable) {
519 increment = ((ChannelEventRunnable) task).estimatedSize;
520 } else {
521 increment = ((MemoryAwareRunnable) task).estimatedSize;
522 }
523
524 if (totalLimiter != null) {
525 totalLimiter.decrease(increment);
526 }
527
528 if (task instanceof ChannelEventRunnable) {
529 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
530 Channel channel = eventTask.getEvent().getChannel();
531 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
532 //System.out.println("DC: " + channelCounter + ", " + increment);
533 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
534 if (!channel.isReadable()) {
535 //System.out.println("READABLE");
536 ChannelHandlerContext ctx = eventTask.getContext();
537 if (ctx.getHandler() instanceof ExecutionHandler) {
538 // check if the attachment was set as this means that we suspend the channel
539 // from reads. This only works when this pool is used with ExecutionHandler
540 // but I guess thats good enough for us.
541 //
542 // See #215
543 if (ctx.getAttachment() != null) {
544 // readSuspended = false;
545 ctx.setAttachment(null);
546 channel.setReadable(true);
547 }
548 } else {
549 channel.setReadable(true);
550 }
551 }
552 }
553 }
554 }
555
556 private AtomicLong getChannelCounter(Channel channel) {
557 AtomicLong counter = channelCounters.get(channel);
558 if (counter == null) {
559 counter = new AtomicLong();
560 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
561 if (oldCounter != null) {
562 counter = oldCounter;
563 }
564 }
565
566 // Remove the entry when the channel closes.
567 if (!channel.isOpen()) {
568 channelCounters.remove(channel);
569 }
570 return counter;
571 }
572
573 /**
574 * Returns {@code true} if and only if the specified {@code task} should
575 * be counted to limit the global and per-channel memory consumption.
576 * To override this method, you must call {@code super.shouldCount()} to
577 * make sure important tasks are not counted.
578 */
579 protected boolean shouldCount(Runnable task) {
580 if (task instanceof ChannelUpstreamEventRunnable) {
581 ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
582 ChannelEvent e = r.getEvent();
583 if (e instanceof WriteCompletionEvent) {
584 return false;
585 } else if (e instanceof ChannelStateEvent) {
586 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
587 return false;
588 }
589 }
590 }
591 return true;
592 }
593
594 private static final class Settings {
595 final ObjectSizeEstimator objectSizeEstimator;
596 final long maxChannelMemorySize;
597
598 Settings(ObjectSizeEstimator objectSizeEstimator,
599 long maxChannelMemorySize) {
600 this.objectSizeEstimator = objectSizeEstimator;
601 this.maxChannelMemorySize = maxChannelMemorySize;
602 }
603 }
604
605 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
606 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
607 try {
608 final Thread t = new Thread(r, "Temporary task executor");
609 t.start();
610 } catch (Throwable e) {
611 throw new RejectedExecutionException(
612 "Failed to start a new thread", e);
613 }
614 }
615 }
616
617 private static final class MemoryAwareRunnable implements Runnable {
618 final Runnable task;
619 int estimatedSize;
620
621 MemoryAwareRunnable(Runnable task) {
622 this.task = task;
623 }
624
625 public void run() {
626 task.run();
627 }
628 }
629
630
631 private static class Limiter {
632
633 final long limit;
634 private long counter;
635 private int waiters;
636
637 Limiter(long limit) {
638 this.limit = limit;
639 }
640
641 synchronized void increase(long amount) {
642 while (counter >= limit) {
643 waiters ++;
644 try {
645 wait();
646 } catch (InterruptedException e) {
647 Thread.currentThread().interrupt();
648 } finally {
649 waiters --;
650 }
651 }
652 counter += amount;
653 }
654
655 synchronized void decrease(long amount) {
656 counter -= amount;
657 if (counter < limit && waiters > 0) {
658 notifyAll();
659 }
660 }
661 }
662 }