1 /*
2 * Copyright 2012 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package org.jboss.netty.handler.execution;
17
18 import org.jboss.netty.buffer.ChannelBuffer;
19 import org.jboss.netty.channel.Channel;
20 import org.jboss.netty.channel.ChannelEvent;
21 import org.jboss.netty.channel.ChannelFuture;
22 import org.jboss.netty.channel.ChannelHandlerContext;
23 import org.jboss.netty.channel.ChannelState;
24 import org.jboss.netty.channel.ChannelStateEvent;
25 import org.jboss.netty.channel.Channels;
26 import org.jboss.netty.channel.MessageEvent;
27 import org.jboss.netty.channel.WriteCompletionEvent;
28 import org.jboss.netty.logging.InternalLogger;
29 import org.jboss.netty.logging.InternalLoggerFactory;
30 import org.jboss.netty.util.DefaultObjectSizeEstimator;
31 import org.jboss.netty.util.ObjectSizeEstimator;
32 import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
33 import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
34
35 import java.io.IOException;
36 import java.lang.reflect.Method;
37 import java.util.HashSet;
38 import java.util.List;
39 import java.util.Set;
40 import java.util.concurrent.ConcurrentMap;
41 import java.util.concurrent.Executor;
42 import java.util.concurrent.Executors;
43 import java.util.concurrent.LinkedBlockingQueue;
44 import java.util.concurrent.RejectedExecutionException;
45 import java.util.concurrent.RejectedExecutionHandler;
46 import java.util.concurrent.ThreadFactory;
47 import java.util.concurrent.ThreadPoolExecutor;
48 import java.util.concurrent.TimeUnit;
49 import java.util.concurrent.atomic.AtomicLong;
50
51 /**
 * A {@link ThreadPoolExecutor} which blocks task submission when there are
 * too many tasks in the queue. Both per-{@link Channel} and per-{@link Executor}
 * limits can be applied.
55 * <p>
56 * When a task (i.e. {@link Runnable}) is submitted,
57 * {@link MemoryAwareThreadPoolExecutor} calls {@link ObjectSizeEstimator#estimateSize(Object)}
58 * to get the estimated size of the task in bytes to calculate the amount of
59 * memory occupied by the unprocessed tasks.
60 * <p>
61 * If the total size of the unprocessed tasks exceeds either per-{@link Channel}
62 * or per-{@link Executor} threshold, any further {@link #execute(Runnable)}
63 * call will block until the tasks in the queue are processed so that the total
64 * size goes under the threshold.
65 *
66 * <h3>Using an alternative task size estimation strategy</h3>
67 *
68 * Although the default implementation does its best to guess the size of an
 * object of unknown type, it is always a good idea to use an alternative
70 * {@link ObjectSizeEstimator} implementation instead of the
71 * {@link DefaultObjectSizeEstimator} to avoid incorrect task size calculation,
72 * especially when:
73 * <ul>
74 * <li>you are using {@link MemoryAwareThreadPoolExecutor} independently from
75 * {@link ExecutionHandler},</li>
76 * <li>you are submitting a task whose type is not {@link ChannelEventRunnable}, or</li>
77 * <li>the message type of the {@link MessageEvent} in the {@link ChannelEventRunnable}
78 * is not {@link ChannelBuffer}.</li>
79 * </ul>
80 * Here is an example that demonstrates how to implement an {@link ObjectSizeEstimator}
81 * which understands a user-defined object:
82 * <pre>
83 * public class MyRunnable implements {@link Runnable} {
84 *
85 * <b>private final byte[] data;</b>
86 *
87 * public MyRunnable(byte[] data) {
88 * this.data = data;
89 * }
90 *
91 * public void run() {
92 * // Process 'data' ..
93 * }
94 * }
95 *
96 * public class MyObjectSizeEstimator extends {@link DefaultObjectSizeEstimator} {
97 *
98 * {@literal @Override}
99 * public int estimateSize(Object o) {
100 * if (<b>o instanceof MyRunnable</b>) {
101 * <b>return ((MyRunnable) o).data.length + 8;</b>
102 * }
103 * return super.estimateSize(o);
104 * }
105 * }
106 *
107 * {@link ThreadPoolExecutor} pool = new {@link MemoryAwareThreadPoolExecutor}(
108 * 16, 65536, 1048576, 30, {@link TimeUnit}.SECONDS,
109 * <b>new MyObjectSizeEstimator()</b>,
110 * {@link Executors}.defaultThreadFactory());
111 *
112 * <b>pool.execute(new MyRunnable(data));</b>
113 * </pre>
114 *
115 * <h3>Event execution order</h3>
116 *
 * Please note that this executor does not maintain the order of the
 * {@link ChannelEvent}s for the same {@link Channel}. For example,
 * you can even receive a {@code "channelClosed"} event before a
 * {@code "messageReceived"} event. The events can be processed as
 * depicted in the following diagram:
123 *
124 * <pre>
125 * --------------------------------> Timeline -------------------------------->
126 *
127 * Thread X: --- Channel A (Event 1) --- Channel A (Event 2) --------------------------->
128 *
129 * Thread Y: --- Channel A (Event 3) --- Channel B (Event 2) --- Channel B (Event 3) --->
130 *
131 * Thread Z: --- Channel B (Event 1) --- Channel B (Event 4) --- Channel A (Event 4) --->
132 * </pre>
133 *
134 * To maintain the event order, you must use {@link OrderedMemoryAwareThreadPoolExecutor}.
135 *
136 * @apiviz.has org.jboss.netty.util.ObjectSizeEstimator oneway - -
137 * @apiviz.has org.jboss.netty.handler.execution.ChannelEventRunnable oneway - - executes
138 */
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140
141 private static final InternalLogger logger =
142 InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143
144 private static final SharedResourceMisuseDetector misuseDetector =
145 new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146
147 private volatile Settings settings;
148
149 private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150 new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151 private final Limiter totalLimiter;
152
153 private volatile boolean notifyOnShutdown;
154
155 /**
156 * Creates a new instance.
157 *
158 * @param corePoolSize the maximum number of active threads
159 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
160 * Specify {@code 0} to disable.
161 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
162 * Specify {@code 0} to disable.
163 */
164 public MemoryAwareThreadPoolExecutor(
165 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166
167 this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168 }
169
170 /**
171 * Creates a new instance.
172 *
173 * @param corePoolSize the maximum number of active threads
174 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
175 * Specify {@code 0} to disable.
176 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
177 * Specify {@code 0} to disable.
178 * @param keepAliveTime the amount of time for an inactive thread to shut itself down
179 * @param unit the {@link TimeUnit} of {@code keepAliveTime}
180 */
181 public MemoryAwareThreadPoolExecutor(
182 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183 long keepAliveTime, TimeUnit unit) {
184
185 this(
186 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187 Executors.defaultThreadFactory());
188 }
189
190 /**
191 * Creates a new instance.
192 *
193 * @param corePoolSize the maximum number of active threads
194 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
195 * Specify {@code 0} to disable.
196 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
197 * Specify {@code 0} to disable.
198 * @param keepAliveTime the amount of time for an inactive thread to shut itself down
199 * @param unit the {@link TimeUnit} of {@code keepAliveTime}
200 * @param threadFactory the {@link ThreadFactory} of this pool
201 */
202 public MemoryAwareThreadPoolExecutor(
203 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204 long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205
206 this(
207 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208 new DefaultObjectSizeEstimator(), threadFactory);
209 }
210
211 /**
212 * Creates a new instance.
213 *
214 * @param corePoolSize the maximum number of active threads
215 * @param maxChannelMemorySize the maximum total size of the queued events per channel.
216 * Specify {@code 0} to disable.
217 * @param maxTotalMemorySize the maximum total size of the queued events for this pool
218 * Specify {@code 0} to disable.
219 * @param keepAliveTime the amount of time for an inactive thread to shut itself down
220 * @param unit the {@link TimeUnit} of {@code keepAliveTime}
221 * @param threadFactory the {@link ThreadFactory} of this pool
222 * @param objectSizeEstimator the {@link ObjectSizeEstimator} of this pool
223 */
224 public MemoryAwareThreadPoolExecutor(
225 int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226 long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227 ThreadFactory threadFactory) {
228
229 super(corePoolSize, corePoolSize, keepAliveTime, unit,
230 new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231
232 if (objectSizeEstimator == null) {
233 throw new NullPointerException("objectSizeEstimator");
234 }
235 if (maxChannelMemorySize < 0) {
236 throw new IllegalArgumentException(
237 "maxChannelMemorySize: " + maxChannelMemorySize);
238 }
239 if (maxTotalMemorySize < 0) {
240 throw new IllegalArgumentException(
241 "maxTotalMemorySize: " + maxTotalMemorySize);
242 }
243
244 // Call allowCoreThreadTimeOut(true) using reflection
245 // because it is not supported in Java 5.
246 try {
247 Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248 m.invoke(this, Boolean.TRUE);
249 } catch (Throwable t) {
250 // Java 5
251 logger.debug(
252 "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253 "supported in this platform.");
254 }
255
256 settings = new Settings(
257 objectSizeEstimator, maxChannelMemorySize);
258
259 if (maxTotalMemorySize == 0) {
260 totalLimiter = null;
261 } else {
262 totalLimiter = new Limiter(maxTotalMemorySize);
263 }
264
265 // Misuse check
266 misuseDetector.increase();
267 }
268
269 @Override
270 protected void terminated() {
271 super.terminated();
272 misuseDetector.decrease();
273 }
274
275 /**
276 * This will call {@link #shutdownNow(boolean)} with the value of {@link #getNotifyChannelFuturesOnShutdown()}.
277 */
278 @Override
279 public List<Runnable> shutdownNow() {
280 return shutdownNow(notifyOnShutdown);
281 }
282
283 /**
284 * See {@link ThreadPoolExecutor#shutdownNow()} for how it handles the shutdown.
285 * If {@code true} is given to this method it also notifies all {@link ChannelFuture}'s
286 * of the not executed {@link ChannelEventRunnable}'s.
287 *
288 * <p>
289 * Be aware that if you call this with {@code false} you will need to handle the
290 * notification of the {@link ChannelFuture}'s by your self. So only use this if you
291 * really have a use-case for it.
292 * </p>
293 *
294 */
295 public List<Runnable> shutdownNow(boolean notify) {
296 if (!notify) {
297 return super.shutdownNow();
298 }
299 Throwable cause = null;
300 Set<Channel> channels = null;
301
302 List<Runnable> tasks = super.shutdownNow();
303
304 // loop over all tasks and cancel the ChannelFuture of the ChannelEventRunable's
305 for (Runnable task: tasks) {
306 if (task instanceof ChannelEventRunnable) {
307 if (cause == null) {
308 cause = new IOException("Unable to process queued event");
309 }
310 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311 event.getFuture().setFailure(cause);
312
313 if (channels == null) {
314 channels = new HashSet<Channel>();
315 }
316
317 // store the Channel of the event for later notification of the exceptionCaught event
318 channels.add(event.getChannel());
319 }
320 }
321
322 // loop over all channels and fire an exceptionCaught event
323 if (channels != null) {
324 for (Channel channel: channels) {
325 Channels.fireExceptionCaughtLater(channel, cause);
326 }
327 }
328 return tasks;
329 }
330
331 /**
332 * Returns the {@link ObjectSizeEstimator} of this pool.
333 */
334 public ObjectSizeEstimator getObjectSizeEstimator() {
335 return settings.objectSizeEstimator;
336 }
337
338 /**
339 * Sets the {@link ObjectSizeEstimator} of this pool.
340 */
341 public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
342 if (objectSizeEstimator == null) {
343 throw new NullPointerException("objectSizeEstimator");
344 }
345
346 settings = new Settings(
347 objectSizeEstimator,
348 settings.maxChannelMemorySize);
349 }
350
351 /**
352 * Returns the maximum total size of the queued events per channel.
353 */
354 public long getMaxChannelMemorySize() {
355 return settings.maxChannelMemorySize;
356 }
357
358 /**
359 * Sets the maximum total size of the queued events per channel.
360 * Specify {@code 0} to disable.
361 */
362 public void setMaxChannelMemorySize(long maxChannelMemorySize) {
363 if (maxChannelMemorySize < 0) {
364 throw new IllegalArgumentException(
365 "maxChannelMemorySize: " + maxChannelMemorySize);
366 }
367
368 if (getTaskCount() > 0) {
369 throw new IllegalStateException(
370 "can't be changed after a task is executed");
371 }
372
373 settings = new Settings(
374 settings.objectSizeEstimator,
375 maxChannelMemorySize);
376 }
377
378 /**
379 * Returns the maximum total size of the queued events for this pool.
380 */
381 public long getMaxTotalMemorySize() {
382 if (totalLimiter == null) {
383 return 0;
384 }
385 return totalLimiter.limit;
386 }
387
388 /**
389 * If set to {@code false} no queued {@link ChannelEventRunnable}'s {@link ChannelFuture}
390 * will get notified once {@link #shutdownNow()} is called. If set to {@code true} every
391 * queued {@link ChannelEventRunnable} will get marked as failed via {@link ChannelFuture#setFailure(Throwable)}.
392 *
393 * <p>
394 * Please only set this to {@code false} if you want to handle the notification by yourself
395 * and know what you are doing. Default is {@code true}.
396 * </p>
397 */
398 public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
399 this.notifyOnShutdown = notifyOnShutdown;
400 }
401
402 /**
403 * Returns if the {@link ChannelFuture}'s of the {@link ChannelEventRunnable}'s should be
404 * notified about the shutdown of this {@link MemoryAwareThreadPoolExecutor}.
405 */
406 public boolean getNotifyChannelFuturesOnShutdown() {
407 return notifyOnShutdown;
408 }
409
410 @Override
411 public void execute(Runnable command) {
412 if (command instanceof ChannelDownstreamEventRunnable) {
413 throw new RejectedExecutionException("command must be enclosed with an upstream event.");
414 }
415 if (!(command instanceof ChannelEventRunnable)) {
416 command = new MemoryAwareRunnable(command);
417 }
418
419 increaseCounter(command);
420 doExecute(command);
421 }
422
423 /**
424 * Put the actual execution logic here. The default implementation simply
425 * calls {@link #doUnorderedExecute(Runnable)}.
426 */
427 protected void doExecute(Runnable task) {
428 doUnorderedExecute(task);
429 }
430
431 /**
432 * Executes the specified task without maintaining the event order.
433 */
434 protected final void doUnorderedExecute(Runnable task) {
435 super.execute(task);
436 }
437
438 @Override
439 public boolean remove(Runnable task) {
440 boolean removed = super.remove(task);
441 if (removed) {
442 decreaseCounter(task);
443 }
444 return removed;
445 }
446
447 @Override
448 protected void beforeExecute(Thread t, Runnable r) {
449 super.beforeExecute(t, r);
450 decreaseCounter(r);
451 }
452
453 protected void increaseCounter(Runnable task) {
454 if (!shouldCount(task)) {
455 return;
456 }
457
458 Settings settings = this.settings;
459 long maxChannelMemorySize = settings.maxChannelMemorySize;
460
461 int increment = settings.objectSizeEstimator.estimateSize(task);
462
463 if (task instanceof ChannelEventRunnable) {
464 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
465 eventTask.estimatedSize = increment;
466 Channel channel = eventTask.getEvent().getChannel();
467 long channelCounter = getChannelCounter(channel).addAndGet(increment);
468 //System.out.println("IC: " + channelCounter + ", " + increment);
469 if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
470 if (channel.isReadable()) {
471 //System.out.println("UNREADABLE");
472 ChannelHandlerContext ctx = eventTask.getContext();
473 if (ctx.getHandler() instanceof ExecutionHandler) {
474 // readSuspended = true;
475 ctx.setAttachment(Boolean.TRUE);
476 }
477 channel.setReadable(false);
478 }
479 }
480 } else {
481 ((MemoryAwareRunnable) task).estimatedSize = increment;
482 }
483
484 if (totalLimiter != null) {
485 totalLimiter.increase(increment);
486 }
487 }
488
489 protected void decreaseCounter(Runnable task) {
490 if (!shouldCount(task)) {
491 return;
492 }
493
494 Settings settings = this.settings;
495 long maxChannelMemorySize = settings.maxChannelMemorySize;
496
497 int increment;
498 if (task instanceof ChannelEventRunnable) {
499 increment = ((ChannelEventRunnable) task).estimatedSize;
500 } else {
501 increment = ((MemoryAwareRunnable) task).estimatedSize;
502 }
503
504 if (totalLimiter != null) {
505 totalLimiter.decrease(increment);
506 }
507
508 if (task instanceof ChannelEventRunnable) {
509 ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
510 Channel channel = eventTask.getEvent().getChannel();
511 long channelCounter = getChannelCounter(channel).addAndGet(-increment);
512 //System.out.println("DC: " + channelCounter + ", " + increment);
513 if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
514 if (!channel.isReadable()) {
515 //System.out.println("READABLE");
516 ChannelHandlerContext ctx = eventTask.getContext();
517 if (ctx.getHandler() instanceof ExecutionHandler) {
518 // check if the attachment was set as this means that we suspend the channel
519 // from reads. This only works when this pool is used with ExecutionHandler
520 // but I guess thats good enough for us.
521 //
522 // See #215
523 if (ctx.getAttachment() != null) {
524 // readSuspended = false;
525 ctx.setAttachment(null);
526 channel.setReadable(true);
527 }
528 } else {
529 channel.setReadable(true);
530 }
531 }
532 }
533 }
534 }
535
536 private AtomicLong getChannelCounter(Channel channel) {
537 AtomicLong counter = channelCounters.get(channel);
538 if (counter == null) {
539 counter = new AtomicLong();
540 AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
541 if (oldCounter != null) {
542 counter = oldCounter;
543 }
544 }
545
546 // Remove the entry when the channel closes.
547 if (!channel.isOpen()) {
548 channelCounters.remove(channel);
549 }
550 return counter;
551 }
552
553 /**
554 * Returns {@code true} if and only if the specified {@code task} should
555 * be counted to limit the global and per-channel memory consumption.
556 * To override this method, you must call {@code super.shouldCount()} to
557 * make sure important tasks are not counted.
558 */
559 protected boolean shouldCount(Runnable task) {
560 if (task instanceof ChannelUpstreamEventRunnable) {
561 ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
562 ChannelEvent e = r.getEvent();
563 if (e instanceof WriteCompletionEvent) {
564 return false;
565 } else if (e instanceof ChannelStateEvent) {
566 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
567 return false;
568 }
569 }
570 }
571 return true;
572 }
573
574 private static final class Settings {
575 final ObjectSizeEstimator objectSizeEstimator;
576 final long maxChannelMemorySize;
577
578 Settings(ObjectSizeEstimator objectSizeEstimator,
579 long maxChannelMemorySize) {
580 this.objectSizeEstimator = objectSizeEstimator;
581 this.maxChannelMemorySize = maxChannelMemorySize;
582 }
583 }
584
585 private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
586 public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
587 try {
588 final Thread t = new Thread(r, "Temporary task executor");
589 t.start();
590 } catch (Throwable e) {
591 throw new RejectedExecutionException(
592 "Failed to start a new thread", e);
593 }
594 }
595 }
596
597 private static final class MemoryAwareRunnable implements Runnable {
598 final Runnable task;
599 int estimatedSize;
600
601 MemoryAwareRunnable(Runnable task) {
602 this.task = task;
603 }
604
605 public void run() {
606 task.run();
607 }
608 }
609
610 private static class Limiter {
611
612 final long limit;
613 private long counter;
614 private int waiters;
615
616 Limiter(long limit) {
617 this.limit = limit;
618 }
619
620 synchronized void increase(long amount) {
621 while (counter >= limit) {
622 waiters ++;
623 try {
624 wait();
625 } catch (InterruptedException e) {
626 Thread.currentThread().interrupt();
627 } finally {
628 waiters --;
629 }
630 }
631 counter += amount;
632 }
633
634 synchronized void decrease(long amount) {
635 counter -= amount;
636 if (counter < limit && waiters > 0) {
637 notifyAll();
638 }
639 }
640 }
641 }