1 /*
2 * Copyright 2012 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package org.jboss.netty.handler.execution;
17
18 import org.jboss.netty.buffer.ChannelBuffer;
19 import org.jboss.netty.channel.Channel;
20 import org.jboss.netty.channel.ChannelEvent;
21 import org.jboss.netty.channel.ChannelFuture;
22 import org.jboss.netty.channel.ChannelHandlerContext;
23 import org.jboss.netty.channel.ChannelState;
24 import org.jboss.netty.channel.ChannelStateEvent;
25 import org.jboss.netty.channel.Channels;
26 import org.jboss.netty.channel.MessageEvent;
27 import org.jboss.netty.channel.WriteCompletionEvent;
28 import org.jboss.netty.logging.InternalLogger;
29 import org.jboss.netty.logging.InternalLoggerFactory;
30 import org.jboss.netty.util.DefaultObjectSizeEstimator;
31 import org.jboss.netty.util.ObjectSizeEstimator;
32 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
33 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
34
35 import java.io.IOException;
36 import java.lang.reflect.Method;
37 import java.util.HashSet;
38 import java.util.List;
39 import java.util.Set;
40 import java.util.concurrent.ConcurrentMap;
41 import java.util.concurrent.Executor;
42 import java.util.concurrent.Executors;
43 import java.util.concurrent.LinkedBlockingQueue;
44 import java.util.concurrent.RejectedExecutionException;
45 import java.util.concurrent.RejectedExecutionHandler;
46 import java.util.concurrent.ThreadFactory;
47 import java.util.concurrent.ThreadPoolExecutor;
48 import java.util.concurrent.TimeUnit;
49 import java.util.concurrent.atomic.AtomicLong;
50
51 /**
 * A {@link ThreadPoolExecutor} which blocks task submission when there are
 * too many tasks in the queue. Both per-{@link Channel} and per-{@link Executor}
 * limits can be applied.
55 * <p>
56 * When a task (i.e. {@link Runnable}) is submitted,
57 * {@link MemoryAwareThreadPoolExecutor} calls {@link ObjectSizeEstimator#estimateSize(Object)}
58 * to get the estimated size of the task in bytes to calculate the amount of
59 * memory occupied by the unprocessed tasks.
60 * <p>
61 * If the total size of the unprocessed tasks exceeds either per-{@link Channel}
62 * or per-{@link Executor} threshold, any further {@link #execute(Runnable)}
63 * call will block until the tasks in the queue are processed so that the total
64 * size goes under the threshold.
65 *
66 * <h3>Using an alternative task size estimation strategy</h3>
67 *
 * Although the default implementation does its best to guess the size of an
 * object of unknown type, it is always a good idea to use an alternative
 * {@link ObjectSizeEstimator} implementation instead of the
 * {@link DefaultObjectSizeEstimator} to avoid incorrect task size calculation,
 * especially when:
73 * <ul>
74 * <li>you are using {@link MemoryAwareThreadPoolExecutor} independently from
75 * {@link ExecutionHandler},</li>
76 * <li>you are submitting a task whose type is not {@link ChannelEventRunnable}, or</li>
77 * <li>the message type of the {@link MessageEvent} in the {@link ChannelEventRunnable}
78 * is not {@link ChannelBuffer}.</li>
79 * </ul>
80 * Here is an example that demonstrates how to implement an {@link ObjectSizeEstimator}
81 * which understands a user-defined object:
82 * <pre>
83 * public class MyRunnable implements {@link Runnable} {
84 *
85 * <b>private final byte[] data;</b>
86 *
87 * public MyRunnable(byte[] data) {
88 * this.data = data;
89 * }
90 *
91 * public void run() {
92 * // Process 'data' ..
93 * }
94 * }
95 *
96 * public class MyObjectSizeEstimator extends {@link DefaultObjectSizeEstimator} {
97 *
98 * {@literal @Override}
99 * public int estimateSize(Object o) {
100 * if (<b>o instanceof MyRunnable</b>) {
101 * <b>return ((MyRunnable) o).data.length + 8;</b>
102 * }
103 * return super.estimateSize(o);
104 * }
105 * }
106 *
107 * {@link ThreadPoolExecutor} pool = new {@link MemoryAwareThreadPoolExecutor}(
108 * 16, 65536, 1048576, 30, {@link TimeUnit}.SECONDS,
109 * <b>new MyObjectSizeEstimator()</b>,
110 * {@link Executors}.defaultThreadFactory());
111 *
112 * <b>pool.execute(new MyRunnable(data));</b>
113 * </pre>
114 *
115 * <h3>Event execution order</h3>
116 *
 * Please note that this executor does not maintain the order of the
 * {@link ChannelEvent}s for the same {@link Channel}. For example,
 * you can even receive a {@code "channelClosed"} event before a
 * {@code "messageReceived"} event.
 *
 * The events can be processed as depicted below:
123 *
124 * <pre>
125 * --------------------------------> Timeline -------------------------------->
126 *
127 * Thread X: --- Channel A (Event 1) --- Channel A (Event 2) --------------------------->
128 *
129 * Thread Y: --- Channel A (Event 3) --- Channel B (Event 2) --- Channel B (Event 3) --->
130 *
131 * Thread Z: --- Channel B (Event 1) --- Channel B (Event 4) --- Channel A (Event 4) --->
132 * </pre>
133 *
134 * To maintain the event order, you must use {@link OrderedMemoryAwareThreadPoolExecutor}.
135 *
136 * @apiviz.has org.jboss.netty.util.ObjectSizeEstimator oneway - -
137 * @apiviz.has org.jboss.netty.handler.execution.ChannelEventRunnable oneway - - executes
138 */
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
141 private static final InternalLogger logger =
142 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143
144 private static final SharedResourceMisuseDetector misuseDetector =
145 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146
147 private volatile Settings settings;
148
149 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151 private final Limiter totalLimiter;
152
153 private volatile boolean notifyOnShutdown;
154
155 /**
156 * Creates a new instance.
157 *
158 * @param corePoolSize the maximum number of active threads
159 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
160 * Specify {@code 0} to disable.
161 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
162 * Specify {@code 0} to disable.
163 */
164 public MemoryAwareThreadPoolExecutor(
165 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166
167 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168 }
169
170 /**
171 * Creates a new instance.
172 *
173 * @param corePoolSize the maximum number of active threads
174 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
175 * Specify {@code 0} to disable.
176 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
177 * Specify {@code 0} to disable.
178 * @param keepAliveTime the amount of time for an inactive thread to shut itself down
179 * @param unit the {@link TimeUnit} of {@code keepAliveTime}
180 */
181 public MemoryAwareThreadPoolExecutor(
182 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183 long keepAliveTime, TimeUnit unit) {
184
185 this(
186 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187 Executors.defaultThreadFactory());
188 }
189
190 /**
191 * Creates a new instance.
192 *
193 * @param corePoolSize the maximum number of active threads
194 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
195 * Specify {@code 0} to disable.
196 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
197 * Specify {@code 0} to disable.
198 * @param keepAliveTime the amount of time for an inactive thread to shut itself down
199 * @param unit the {@link TimeUnit} of {@code keepAliveTime}
200 * @param threadFactory the {@link ThreadFactory} of this pool
201 */
202 public MemoryAwareThreadPoolExecutor(
203 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205
206 this(
207 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208 new DefaultObjectSizeEstimator(), threadFactory);
209 }
210
211 /**
212 * Creates a new instance.
213 *
214 * @param corePoolSize the maximum number of active threads
215 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
216 * Specify {@code 0} to disable.
217 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
218 * Specify {@code 0} to disable.
219 * @param keepAliveTime the amount of time for an inactive thread to shut itself down
220 * @param unit the {@link TimeUnit} of {@code keepAliveTime}
221 * @param threadFactory the {@link ThreadFactory} of this pool
222 * @param objectSizeEstimator the {@link ObjectSizeEstimator} of this pool
223 */
224 public MemoryAwareThreadPoolExecutor(
225 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227 ThreadFactory threadFactory) {
228
229 super(corePoolSize, corePoolSize, keepAliveTime, unit,
230 new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231
232 if (objectSizeEstimator == null) {
233 throw new NullPointerException("objectSizeEstimator");
234 }
235 if (maxChannelMemorySize < 0) {
236 throw new IllegalArgumentException(
237 "maxChannelMemorySize: " + maxChannelMemorySize);
238 }
239 if (maxTotalMemorySize < 0) {
240 throw new IllegalArgumentException(
241 "maxTotalMemorySize: " + maxTotalMemorySize);
242 }
243
244 // Call allowCoreThreadTimeOut(true) using reflection
245 // because it is not supported in Java 5.
246 try {
247 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248 m.invoke(this, Boolean.TRUE);
249 } catch (Throwable t) {
250 // Java 5
251 logger.debug(
252 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253 "supported in this platform.");
254 }
255
256 settings = new Settings(
257 objectSizeEstimator, maxChannelMemorySize);
258
259 if (maxTotalMemorySize == 0) {
260 totalLimiter = null;
261 } else {
262 totalLimiter = new Limiter(maxTotalMemorySize);
263 }
264
265 // Misuse check
266 misuseDetector.increase();
267 }
268
269 @Override
270 protected void terminated() {
271 super.terminated();
272 misuseDetector.decrease();
273 }
274
275 /**
276 * This will call {@link #shutdownNow(boolean)} with the value of {@link #getNotifyChannelFuturesOnShutdown()}.
277 */
278 @Override
279 public List<Runnable> shutdownNow() {
280 return shutdownNow(notifyOnShutdown);
281 }
282
283 /**
284 * See {@link ThreadPoolExecutor#shutdownNow()} for how it handles the shutdown.
285 * If {@code true} is given to this method it also notifies all {@link ChannelFuture}'s
286 * of the not executed {@link ChannelEventRunnable}'s.
287 *
288 * <p>
289 * Be aware that if you call this with {@code false} you will need to handle the
290 * notification of the {@link ChannelFuture}'s by your self. So only use this if you
291 * really have a use-case for it.
292 * </p>
293 *
294 */
295 public List<Runnable> shutdownNow(boolean notify) {
296 if (!notify) {
297 return super.shutdownNow();
298 }
299 Throwable cause = null;
300 Set<Channel> channels = null;
301
302 List<Runnable> tasks = super.shutdownNow();
303
304 // loop over all tasks and cancel the ChannelFuture of the ChannelEventRunable's
305 for (Runnable task: tasks) {
306 if (task instanceof ChannelEventRunnable) {
307 if (cause == null) {
308 cause = new IOException("Unable to process queued event");
309 }
310 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311 event.getFuture().setFailure(cause);
312
313 if (channels == null) {
314 channels = new HashSet<Channel>();
315 }
316
317 // store the Channel of the event for later notification of the exceptionCaught event
318 channels.add(event.getChannel());
319 }
320 }
321
322 // loop over all channels and fire an exceptionCaught event
323 if (channels != null) {
324 for (Channel channel: channels) {
325 Channels.fireExceptionCaughtLater(channel, cause);
326 }
327 }
328 return tasks;
329 }
330
331 /**
332 * Returns the {@link ObjectSizeEstimator} of this pool.
333 */
334 public ObjectSizeEstimator getObjectSizeEstimator() {
335 return settings.objectSizeEstimator;
336 }
337
338 /**
339 * Sets the {@link ObjectSizeEstimator} of this pool.
340 */
341 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
342 if (objectSizeEstimator == null) {
343 throw new NullPointerException("objectSizeEstimator");
344 }
345
346 settings = new Settings(
347 objectSizeEstimator,
348 settings.maxChannelMemorySize);
349 }
350
351 /**
352 * Returns the maximum total size of the queued events per channel.
353 */
354 public long getMaxChannelMemorySize() {
355 return settings.maxChannelMemorySize;
356 }
357
358 /**
359 * Sets the maximum total size of the queued events per channel.
360 * Specify {@code 0} to disable.
361 */
362 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
363 if (maxChannelMemorySize < 0) {
364 throw new IllegalArgumentException(
365 "maxChannelMemorySize: " + maxChannelMemorySize);
366 }
367
368 if (getTaskCount() > 0) {
369 throw new IllegalStateException(
370 "can't be changed after a task is executed");
371 }
372
373 settings = new Settings(
374 settings.objectSizeEstimator,
375 maxChannelMemorySize);
376 }
377
378 /**
379 * Returns the maximum total size of the queued events for this pool.
380 */
381 public long getMaxTotalMemorySize() {
382 if (totalLimiter == null) {
383 return 0;
384 }
385 return totalLimiter.limit;
386 }
387
388 /**
389 * @deprecated <tt>maxTotalMemorySize</tt> is not modifiable anymore.
390 */
391 @Deprecated
392 public void setMaxTotalMemorySize(long maxTotalMemorySize) {
393 if (maxTotalMemorySize < 0) {
394 throw new IllegalArgumentException(
395 "maxTotalMemorySize: " + maxTotalMemorySize);
396 }
397
398 if (getTaskCount() > 0) {
399 throw new IllegalStateException(
400 "can't be changed after a task is executed");
401 }
402 }
403
404 /**
405 * If set to {@code false} no queued {@link ChannelEventRunnable}'s {@link ChannelFuture}
406 * will get notified once {@link #shutdownNow()} is called. If set to {@code true} every
407 * queued {@link ChannelEventRunnable} will get marked as failed via {@link ChannelFuture#setFailure(Throwable)}.
408 *
409 * <p>
410 * Please only set this to {@code false} if you want to handle the notification by yourself
411 * and know what you are doing. Default is {@code true}.
412 * </p>
413 */
414 public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
415 this.notifyOnShutdown = notifyOnShutdown;
416 }
417
418 /**
419 * Returns if the {@link ChannelFuture}'s of the {@link ChannelEventRunnable}'s should be
420 * notified about the shutdown of this {@link MemoryAwareThreadPoolExecutor}.
421 */
422 public boolean getNotifyChannelFuturesOnShutdown() {
423 return notifyOnShutdown;
424 }
425
426 @Override
427 public void execute(Runnable command) {
428 if (command instanceof ChannelDownstreamEventRunnable) {
429 throw new RejectedExecutionException("command must be enclosed with an upstream event.");
430 }
431 if (!(command instanceof ChannelEventRunnable)) {
432 command = new MemoryAwareRunnable(command);
433 }
434
435 increaseCounter(command);
436 doExecute(command);
437 }
438
439 /**
440 * Put the actual execution logic here. The default implementation simply
441 * calls {@link #doUnorderedExecute(Runnable)}.
442 */
443 protected void doExecute(Runnable task) {
444 doUnorderedExecute(task);
445 }
446
447 /**
448 * Executes the specified task without maintaining the event order.
449 */
450 protected final void doUnorderedExecute(Runnable task) {
451 super.execute(task);
452 }
453
454 @Override
455 public boolean remove(Runnable task) {
456 boolean removed = super.remove(task);
457 if (removed) {
458 decreaseCounter(task);
459 }
460 return removed;
461 }
462
463 @Override
464 protected void beforeExecute(Thread t, Runnable r) {
465 super.beforeExecute(t, r);
466 decreaseCounter(r);
467 }
468
469 protected void increaseCounter(Runnable task) {
470 if (!shouldCount(task)) {
471 return;
472 }
473
474 Settings settings = this.settings;
475 long maxChannelMemorySize = settings.maxChannelMemorySize;
476
477 int increment = settings.objectSizeEstimator.estimateSize(task);
478
479 if (task instanceof ChannelEventRunnable) {
480 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
481 eventTask.estimatedSize = increment;
482 Channel channel = eventTask.getEvent().getChannel();
483 long channelCounter = getChannelCounter(channel).addAndGet(increment);
484 //System.out.println("IC: " + channelCounter + ", " + increment);
485 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
486 if (channel.isReadable()) {
487 //System.out.println("UNREADABLE");
488 ChannelHandlerContext ctx = eventTask.getContext();
489 if (ctx.getHandler() instanceof ExecutionHandler) {
490 // readSuspended = true;
491 ctx.setAttachment(Boolean.TRUE);
492 }
493 channel.setReadable(false);
494 }
495 }
496 } else {
497 ((MemoryAwareRunnable) task).estimatedSize = increment;
498 }
499
500 if (totalLimiter != null) {
501 totalLimiter.increase(increment);
502 }
503 }
504
505 protected void decreaseCounter(Runnable task) {
506 if (!shouldCount(task)) {
507 return;
508 }
509
510 Settings settings = this.settings;
511 long maxChannelMemorySize = settings.maxChannelMemorySize;
512
513 int increment;
514 if (task instanceof ChannelEventRunnable) {
515 increment = ((ChannelEventRunnable) task).estimatedSize;
516 } else {
517 increment = ((MemoryAwareRunnable) task).estimatedSize;
518 }
519
520 if (totalLimiter != null) {
521 totalLimiter.decrease(increment);
522 }
523
524 if (task instanceof ChannelEventRunnable) {
525 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
526 Channel channel = eventTask.getEvent().getChannel();
527 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
528 //System.out.println("DC: " + channelCounter + ", " + increment);
529 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
530 if (!channel.isReadable()) {
531 //System.out.println("READABLE");
532 ChannelHandlerContext ctx = eventTask.getContext();
533 if (ctx.getHandler() instanceof ExecutionHandler) {
534 // check if the attachment was set as this means that we suspend the channel
535 // from reads. This only works when this pool is used with ExecutionHandler
536 // but I guess thats good enough for us.
537 //
538 // See #215
539 if (ctx.getAttachment() != null) {
540 // readSuspended = false;
541 ctx.setAttachment(null);
542 channel.setReadable(true);
543 }
544 } else {
545 channel.setReadable(true);
546 }
547 }
548 }
549 }
550 }
551
552 private AtomicLong getChannelCounter(Channel channel) {
553 AtomicLong counter = channelCounters.get(channel);
554 if (counter == null) {
555 counter = new AtomicLong();
556 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
557 if (oldCounter != null) {
558 counter = oldCounter;
559 }
560 }
561
562 // Remove the entry when the channel closes.
563 if (!channel.isOpen()) {
564 channelCounters.remove(channel);
565 }
566 return counter;
567 }
568
569 /**
570 * Returns {@code true} if and only if the specified {@code task} should
571 * be counted to limit the global and per-channel memory consumption.
572 * To override this method, you must call {@code super.shouldCount()} to
573 * make sure important tasks are not counted.
574 */
575 protected boolean shouldCount(Runnable task) {
576 if (task instanceof ChannelUpstreamEventRunnable) {
577 ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
578 ChannelEvent e = r.getEvent();
579 if (e instanceof WriteCompletionEvent) {
580 return false;
581 } else if (e instanceof ChannelStateEvent) {
582 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
583 return false;
584 }
585 }
586 }
587 return true;
588 }
589
590 private static final class Settings {
591 final ObjectSizeEstimator objectSizeEstimator;
592 final long maxChannelMemorySize;
593
594 Settings(ObjectSizeEstimator objectSizeEstimator,
595 long maxChannelMemorySize) {
596 this.objectSizeEstimator = objectSizeEstimator;
597 this.maxChannelMemorySize = maxChannelMemorySize;
598 }
599 }
600
601 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
602 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
603 try {
604 final Thread t = new Thread(r, "Temporary task executor");
605 t.start();
606 } catch (Throwable e) {
607 throw new RejectedExecutionException(
608 "Failed to start a new thread", e);
609 }
610 }
611 }
612
613 private static final class MemoryAwareRunnable implements Runnable {
614 final Runnable task;
615 int estimatedSize;
616
617 MemoryAwareRunnable(Runnable task) {
618 this.task = task;
619 }
620
621 public void run() {
622 task.run();
623 }
624 }
625
626 private static class Limiter {
627
628 final long limit;
629 private long counter;
630 private int waiters;
631
632 Limiter(long limit) {
633 this.limit = limit;
634 }
635
636 synchronized void increase(long amount) {
637 while (counter >= limit) {
638 waiters ++;
639 try {
640 wait();
641 } catch (InterruptedException e) {
642 Thread.currentThread().interrupt();
643 } finally {
644 waiters --;
645 }
646 }
647 counter += amount;
648 }
649
650 synchronized void decrease(long amount) {
651 counter -= amount;
652 if (counter < limit && waiters > 0) {
653 notifyAll();
654 }
655 }
656 }
657 }