View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package org.jboss.netty.handler.execution;
17  
18  import java.io.IOException;
19  import java.lang.reflect.Method;
20  import java.util.HashSet;
21  import java.util.List;
22  import java.util.Set;
23  import java.util.concurrent.ConcurrentMap;
24  import java.util.concurrent.Executor;
25  import java.util.concurrent.Executors;
26  import java.util.concurrent.LinkedBlockingQueue;
27  import java.util.concurrent.RejectedExecutionException;
28  import java.util.concurrent.RejectedExecutionHandler;
29  import java.util.concurrent.ThreadFactory;
30  import java.util.concurrent.ThreadPoolExecutor;
31  import java.util.concurrent.TimeUnit;
32  import java.util.concurrent.atomic.AtomicLong;
33  
34  import org.jboss.netty.buffer.ChannelBuffer;
35  import org.jboss.netty.channel.Channel;
36  import org.jboss.netty.channel.ChannelEvent;
37  import org.jboss.netty.channel.ChannelFuture;
38  import org.jboss.netty.channel.ChannelHandlerContext;
39  import org.jboss.netty.channel.ChannelState;
40  import org.jboss.netty.channel.ChannelStateEvent;
41  import org.jboss.netty.channel.Channels;
42  import org.jboss.netty.channel.MessageEvent;
43  import org.jboss.netty.channel.WriteCompletionEvent;
44  import org.jboss.netty.logging.InternalLogger;
45  import org.jboss.netty.logging.InternalLoggerFactory;
46  import org.jboss.netty.util.DefaultObjectSizeEstimator;
47  import org.jboss.netty.util.ObjectSizeEstimator;
48  import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
49  import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
50  
51  /**
52   * A {@link ThreadPoolExecutor} which blocks the task submission when there's
53   * too many tasks in the queue.  Both per-{@link Channel} and per-{@link Executor}
54   * limitation can be applied.
55   * <p>
56   * When a task (i.e. {@link Runnable}) is submitted,
57   * {@link MemoryAwareThreadPoolExecutor} calls {@link ObjectSizeEstimator#estimateSize(Object)}
58   * to get the estimated size of the task in bytes to calculate the amount of
59   * memory occupied by the unprocessed tasks.
60   * <p>
61   * If the total size of the unprocessed tasks exceeds either per-{@link Channel}
62   * or per-{@link Executor} threshold, any further {@link #execute(Runnable)}
63   * call will block until the tasks in the queue are processed so that the total
64   * size goes under the threshold.
65   *
66   * <h3>Using an alternative task size estimation strategy</h3>
67   *
68   * Although the default implementation does its best to guess the size of an
69   * object of unknown type, it is always good idea to to use an alternative
70   * {@link ObjectSizeEstimator} implementation instead of the
71   * {@link DefaultObjectSizeEstimator} to avoid incorrect task size calculation,
72   * especially when:
73   * <ul>
74   *   <li>you are using {@link MemoryAwareThreadPoolExecutor} independently from
75   *       {@link ExecutionHandler},</li>
76   *   <li>you are submitting a task whose type is not {@link ChannelEventRunnable}, or</li>
77   *   <li>the message type of the {@link MessageEvent} in the {@link ChannelEventRunnable}
78   *       is not {@link ChannelBuffer}.</li>
79   * </ul>
80   * Here is an example that demonstrates how to implement an {@link ObjectSizeEstimator}
81   * which understands a user-defined object:
82   * <pre>
83   * public class MyRunnable implements {@link Runnable} {
84   *
85   *     <b>private final byte[] data;</b>
86   *
87   *     public MyRunnable(byte[] data) {
88   *         this.data = data;
89   *     }
90   *
91   *     public void run() {
92   *         // Process 'data' ..
93   *     }
94   * }
95   *
96   * public class MyObjectSizeEstimator extends {@link DefaultObjectSizeEstimator} {
97   *
98   *     {@literal @Override}
99   *     public int estimateSize(Object o) {
100  *         if (<b>o instanceof MyRunnable</b>) {
101  *             <b>return ((MyRunnable) o).data.length + 8;</b>
102  *         }
103  *         return super.estimateSize(o);
104  *     }
105  * }
106  *
107  * {@link ThreadPoolExecutor} pool = new {@link MemoryAwareThreadPoolExecutor}(
108  *         16, 65536, 1048576, 30, {@link TimeUnit}.SECONDS,
109  *         <b>new MyObjectSizeEstimator()</b>,
110  *         {@link Executors}.defaultThreadFactory());
111  *
112  * <b>pool.execute(new MyRunnable(data));</b>
113  * </pre>
114  *
115  * <h3>Event execution order</h3>
116  *
117  * Please note that this executor does not maintain the order of the
118  * {@link ChannelEvent}s for the same {@link Channel}.  For example,
119  * you can even receive a {@code "channelClosed"} event before a
120  * {@code "messageReceived"} event, as depicted by the following diagram.
121  *
122  * For example, the events can be processed as depicted below:
123  *
124  * <pre>
125  *           --------------------------------&gt; Timeline --------------------------------&gt;
126  *
127  * Thread X: --- Channel A (Event 2) --- Channel A (Event 1) ---------------------------&gt;
128  *
129  * Thread Y: --- Channel A (Event 3) --- Channel B (Event 2) --- Channel B (Event 3) ---&gt;
130  *
131  * Thread Z: --- Channel B (Event 1) --- Channel B (Event 4) --- Channel A (Event 4) ---&gt;
132  * </pre>
133  *
134  * To maintain the event order, you must use {@link OrderedMemoryAwareThreadPoolExecutor}.
135  *
136  * @apiviz.has org.jboss.netty.util.ObjectSizeEstimator oneway - -
137  * @apiviz.has org.jboss.netty.handler.execution.ChannelEventRunnable oneway - - executes
138  */
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140 
141     private static final InternalLogger logger =
142         InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143 
144     private static final SharedResourceMisuseDetector misuseDetector =
145         new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146 
147     private volatile Settings settings;
148 
149     private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150         new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151     private final Limiter totalLimiter;
152 
153     private volatile boolean notifyOnShutdown;
154 
155     /**
156      * Creates a new instance.
157      *
158      * @param corePoolSize          the maximum number of active threads
159      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
160      *                              Specify {@code 0} to disable.
161      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
162      *                              Specify {@code 0} to disable.
163      */
164     public MemoryAwareThreadPoolExecutor(
165             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166 
167         this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168     }
169 
170     /**
171      * Creates a new instance.
172      *
173      * @param corePoolSize          the maximum number of active threads
174      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
175      *                              Specify {@code 0} to disable.
176      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
177      *                              Specify {@code 0} to disable.
178      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
179      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
180      */
181     public MemoryAwareThreadPoolExecutor(
182             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183             long keepAliveTime, TimeUnit unit) {
184 
185         this(
186                 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187                 Executors.defaultThreadFactory());
188     }
189 
190     /**
191      * Creates a new instance.
192      *
193      * @param corePoolSize          the maximum number of active threads
194      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
195      *                              Specify {@code 0} to disable.
196      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
197      *                              Specify {@code 0} to disable.
198      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
199      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
200      * @param threadFactory         the {@link ThreadFactory} of this pool
201      */
202     public MemoryAwareThreadPoolExecutor(
203             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204             long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205 
206         this(
207                 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208                 new DefaultObjectSizeEstimator(), threadFactory);
209     }
210 
211     /**
212      * Creates a new instance.
213      *
214      * @param corePoolSize          the maximum number of active threads
215      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
216      *                              Specify {@code 0} to disable.
217      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
218      *                              Specify {@code 0} to disable.
219      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
220      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
221      * @param threadFactory         the {@link ThreadFactory} of this pool
222      * @param objectSizeEstimator   the {@link ObjectSizeEstimator} of this pool
223      */
224     public MemoryAwareThreadPoolExecutor(
225             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226             long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227             ThreadFactory threadFactory) {
228 
229         super(corePoolSize, corePoolSize, keepAliveTime, unit,
230               new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231 
232         if (objectSizeEstimator == null) {
233             throw new NullPointerException("objectSizeEstimator");
234         }
235         if (maxChannelMemorySize < 0) {
236             throw new IllegalArgumentException(
237                     "maxChannelMemorySize: " + maxChannelMemorySize);
238         }
239         if (maxTotalMemorySize < 0) {
240             throw new IllegalArgumentException(
241                     "maxTotalMemorySize: " + maxTotalMemorySize);
242         }
243 
244         // Call allowCoreThreadTimeOut(true) using reflection
245         // because it is not supported in Java 5.
246         try {
247             Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248             m.invoke(this, Boolean.TRUE);
249         } catch (Throwable t) {
250             // Java 5
251             logger.debug(
252                     "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253                     "supported in this platform.");
254         }
255 
256         settings = new Settings(
257                 objectSizeEstimator, maxChannelMemorySize);
258 
259         if (maxTotalMemorySize == 0) {
260             totalLimiter = null;
261         } else {
262             totalLimiter = new Limiter(maxTotalMemorySize);
263         }
264 
265         // Misuse check
266         misuseDetector.increase();
267     }
268 
269     @Override
270     protected void terminated() {
271         super.terminated();
272         misuseDetector.decrease();
273     }
274 
275     /**
276      * This will call {@link #shutdownNow(boolean)} with the value of {@link #getNotifyChannelFuturesOnShutdown()}.
277      */
278     @Override
279     public List<Runnable> shutdownNow() {
280         return shutdownNow(notifyOnShutdown);
281     }
282 
283     /**
284      * See {@link ThreadPoolExecutor#shutdownNow()} for how it handles the shutdown.
285      * If <code>true</code> is given to this method it also notifies all {@link ChannelFuture}'s
286      * of the not executed {@link ChannelEventRunnable}'s.
287      *
288      * <p>
289      * Be aware that if you call this with <code>false</code> you will need to handle the
290      * notification of the {@link ChannelFuture}'s by your self. So only use this if you
291      * really have a use-case for it.
292      * </p>
293      *
294      */
295     public List<Runnable> shutdownNow(boolean notify) {
296         if (!notify) {
297             return super.shutdownNow();
298         }
299         Throwable cause = null;
300         Set<Channel> channels = null;
301 
302         List<Runnable> tasks = super.shutdownNow();
303 
304         // loop over all tasks and cancel the ChannelFuture of the ChannelEventRunable's
305         for (Runnable task: tasks) {
306             if (task instanceof ChannelEventRunnable) {
307                 if (cause == null) {
308                     cause = new IOException("Unable to process queued event");
309                 }
310                 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311                 event.getFuture().setFailure(cause);
312 
313                 if (channels == null) {
314                     channels = new HashSet<Channel>();
315                 }
316 
317 
318                 // store the Channel of the event for later notification of the exceptionCaught event
319                 channels.add(event.getChannel());
320             }
321         }
322 
323         // loop over all channels and fire an exceptionCaught event
324         if (channels != null) {
325             for (Channel channel: channels) {
326                 Channels.fireExceptionCaughtLater(channel, cause);
327             }
328         }
329         return tasks;
330     }
331 
332     /**
333      * Returns the {@link ObjectSizeEstimator} of this pool.
334      */
335     public ObjectSizeEstimator getObjectSizeEstimator() {
336         return settings.objectSizeEstimator;
337     }
338 
339     /**
340      * Sets the {@link ObjectSizeEstimator} of this pool.
341      */
342     public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
343         if (objectSizeEstimator == null) {
344             throw new NullPointerException("objectSizeEstimator");
345         }
346 
347         settings = new Settings(
348                 objectSizeEstimator,
349                 settings.maxChannelMemorySize);
350     }
351 
352     /**
353      * Returns the maximum total size of the queued events per channel.
354      */
355     public long getMaxChannelMemorySize() {
356         return settings.maxChannelMemorySize;
357     }
358 
359     /**
360      * Sets the maximum total size of the queued events per channel.
361      * Specify {@code 0} to disable.
362      */
363     public void setMaxChannelMemorySize(long maxChannelMemorySize) {
364         if (maxChannelMemorySize < 0) {
365             throw new IllegalArgumentException(
366                     "maxChannelMemorySize: " + maxChannelMemorySize);
367         }
368 
369         if (getTaskCount() > 0) {
370             throw new IllegalStateException(
371                     "can't be changed after a task is executed");
372         }
373 
374         settings = new Settings(
375                 settings.objectSizeEstimator,
376                 maxChannelMemorySize);
377     }
378 
379     /**
380      * Returns the maximum total size of the queued events for this pool.
381      */
382     public long getMaxTotalMemorySize() {
383         if (totalLimiter == null) {
384             return 0;
385         }
386         return totalLimiter.limit;
387     }
388 
389 
390     /**
391      * @deprecated <tt>maxTotalMemorySize</tt> is not modifiable anymore.
392      */
393     @Deprecated
394     public void setMaxTotalMemorySize(long maxTotalMemorySize) {
395         if (maxTotalMemorySize < 0) {
396             throw new IllegalArgumentException(
397                     "maxTotalMemorySize: " + maxTotalMemorySize);
398         }
399 
400         if (getTaskCount() > 0) {
401             throw new IllegalStateException(
402                     "can't be changed after a task is executed");
403         }
404     }
405 
406     /**
407      * If set to <code>false</code> no queued {@link ChannelEventRunnable}'s {@link ChannelFuture}
408      * will get notified once {@link #shutdownNow()} is called.  If set to <code>true</code> every
409      * queued {@link ChannelEventRunnable} will get marked as failed via {@link ChannelFuture#setFailure(Throwable)}.
410      *
411      * <p>
412      * Please only set this to <code>false</code> if you want to handle the notification by yourself
413      * and know what you are doing. Default is <code>true</code>.
414      * </p>
415      */
416     public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
417         this.notifyOnShutdown = notifyOnShutdown;
418     }
419 
420     /**
421      * Returns if the {@link ChannelFuture}'s of the {@link ChannelEventRunnable}'s should be
422      * notified about the shutdown of this {@link MemoryAwareThreadPoolExecutor}.
423      */
424     public boolean getNotifyChannelFuturesOnShutdown() {
425         return notifyOnShutdown;
426     }
427 
428 
429 
430     @Override
431     public void execute(Runnable command) {
432         if (command instanceof ChannelDownstreamEventRunnable) {
433             throw new RejectedExecutionException("command must be enclosed with an upstream event.");
434         }
435         if (!(command instanceof ChannelEventRunnable)) {
436             command = new MemoryAwareRunnable(command);
437         }
438 
439         increaseCounter(command);
440         doExecute(command);
441     }
442 
443     /**
444      * Put the actual execution logic here.  The default implementation simply
445      * calls {@link #doUnorderedExecute(Runnable)}.
446      */
447     protected void doExecute(Runnable task) {
448         doUnorderedExecute(task);
449     }
450 
451     /**
452      * Executes the specified task without maintaining the event order.
453      */
454     protected final void doUnorderedExecute(Runnable task) {
455         super.execute(task);
456     }
457 
458     @Override
459     public boolean remove(Runnable task) {
460         boolean removed = super.remove(task);
461         if (removed) {
462             decreaseCounter(task);
463         }
464         return removed;
465     }
466 
467     @Override
468     protected void beforeExecute(Thread t, Runnable r) {
469         super.beforeExecute(t, r);
470         decreaseCounter(r);
471     }
472 
473     protected void increaseCounter(Runnable task) {
474         if (!shouldCount(task)) {
475             return;
476         }
477 
478         Settings settings = this.settings;
479         long maxChannelMemorySize = settings.maxChannelMemorySize;
480 
481         int increment = settings.objectSizeEstimator.estimateSize(task);
482 
483         if (task instanceof ChannelEventRunnable) {
484             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
485             eventTask.estimatedSize = increment;
486             Channel channel = eventTask.getEvent().getChannel();
487             long channelCounter = getChannelCounter(channel).addAndGet(increment);
488             //System.out.println("IC: " + channelCounter + ", " + increment);
489             if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
490                 if (channel.isReadable()) {
491                     //System.out.println("UNREADABLE");
492                     ChannelHandlerContext ctx = eventTask.getContext();
493                     if (ctx.getHandler() instanceof ExecutionHandler) {
494                         // readSuspended = true;
495                         ctx.setAttachment(Boolean.TRUE);
496                     }
497                     channel.setReadable(false);
498                 }
499             }
500         } else {
501             ((MemoryAwareRunnable) task).estimatedSize = increment;
502         }
503 
504         if (totalLimiter != null) {
505             totalLimiter.increase(increment);
506         }
507     }
508 
509     protected void decreaseCounter(Runnable task) {
510         if (!shouldCount(task)) {
511             return;
512         }
513 
514         Settings settings = this.settings;
515         long maxChannelMemorySize = settings.maxChannelMemorySize;
516 
517         int increment;
518         if (task instanceof ChannelEventRunnable) {
519             increment = ((ChannelEventRunnable) task).estimatedSize;
520         } else {
521             increment = ((MemoryAwareRunnable) task).estimatedSize;
522         }
523 
524         if (totalLimiter != null) {
525             totalLimiter.decrease(increment);
526         }
527 
528         if (task instanceof ChannelEventRunnable) {
529             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
530             Channel channel = eventTask.getEvent().getChannel();
531             long channelCounter = getChannelCounter(channel).addAndGet(-increment);
532             //System.out.println("DC: " + channelCounter + ", " + increment);
533             if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
534                 if (!channel.isReadable()) {
535                     //System.out.println("READABLE");
536                     ChannelHandlerContext ctx = eventTask.getContext();
537                     if (ctx.getHandler() instanceof ExecutionHandler) {
538                         // check if the attachment was set as this means that we suspend the channel
539                         // from reads. This only works when this pool is used with ExecutionHandler
540                         // but I guess thats good enough for us.
541                         //
542                         // See #215
543                         if (ctx.getAttachment() != null) {
544                             // readSuspended = false;
545                             ctx.setAttachment(null);
546                             channel.setReadable(true);
547                         }
548                     } else {
549                         channel.setReadable(true);
550                     }
551                 }
552             }
553         }
554     }
555 
556     private AtomicLong getChannelCounter(Channel channel) {
557         AtomicLong counter = channelCounters.get(channel);
558         if (counter == null) {
559             counter = new AtomicLong();
560             AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
561             if (oldCounter != null) {
562                 counter = oldCounter;
563             }
564         }
565 
566         // Remove the entry when the channel closes.
567         if (!channel.isOpen()) {
568             channelCounters.remove(channel);
569         }
570         return counter;
571     }
572 
573     /**
574      * Returns {@code true} if and only if the specified {@code task} should
575      * be counted to limit the global and per-channel memory consumption.
576      * To override this method, you must call {@code super.shouldCount()} to
577      * make sure important tasks are not counted.
578      */
579     protected boolean shouldCount(Runnable task) {
580         if (task instanceof ChannelUpstreamEventRunnable) {
581             ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
582             ChannelEvent e = r.getEvent();
583             if (e instanceof WriteCompletionEvent) {
584                 return false;
585             } else if (e instanceof ChannelStateEvent) {
586                 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
587                     return false;
588                 }
589             }
590         }
591         return true;
592     }
593 
594     private static final class Settings {
595         final ObjectSizeEstimator objectSizeEstimator;
596         final long maxChannelMemorySize;
597 
598         Settings(ObjectSizeEstimator objectSizeEstimator,
599                  long maxChannelMemorySize) {
600             this.objectSizeEstimator = objectSizeEstimator;
601             this.maxChannelMemorySize = maxChannelMemorySize;
602         }
603     }
604 
605     private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
606         public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
607             try {
608                 final Thread t = new Thread(r, "Temporary task executor");
609                 t.start();
610             } catch (Throwable e) {
611                 throw new RejectedExecutionException(
612                         "Failed to start a new thread", e);
613             }
614         }
615     }
616 
617     private static final class MemoryAwareRunnable implements Runnable {
618         final Runnable task;
619         int estimatedSize;
620 
621         MemoryAwareRunnable(Runnable task) {
622             this.task = task;
623         }
624 
625         public void run() {
626             task.run();
627         }
628     }
629 
630 
631     private static class Limiter {
632 
633         final long limit;
634         private long counter;
635         private int waiters;
636 
637         Limiter(long limit) {
638             this.limit = limit;
639         }
640 
641         synchronized void increase(long amount) {
642             while (counter >= limit) {
643                 waiters ++;
644                 try {
645                     wait();
646                 } catch (InterruptedException e) {
647                     Thread.currentThread().interrupt();
648                 } finally {
649                     waiters --;
650                 }
651             }
652             counter += amount;
653         }
654 
655         synchronized void decrease(long amount) {
656             counter -= amount;
657             if (counter < limit && waiters > 0) {
658                 notifyAll();
659             }
660         }
661     }
662 }