View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package org.jboss.netty.handler.execution;
17  
18  import org.jboss.netty.buffer.ChannelBuffer;
19  import org.jboss.netty.channel.Channel;
20  import org.jboss.netty.channel.ChannelEvent;
21  import org.jboss.netty.channel.ChannelFuture;
22  import org.jboss.netty.channel.ChannelHandlerContext;
23  import org.jboss.netty.channel.ChannelState;
24  import org.jboss.netty.channel.ChannelStateEvent;
25  import org.jboss.netty.channel.Channels;
26  import org.jboss.netty.channel.MessageEvent;
27  import org.jboss.netty.channel.WriteCompletionEvent;
28  import org.jboss.netty.logging.InternalLogger;
29  import org.jboss.netty.logging.InternalLoggerFactory;
30  import org.jboss.netty.util.DefaultObjectSizeEstimator;
31  import org.jboss.netty.util.ObjectSizeEstimator;
32  import org.jboss.netty.util.internal.ConcurrentIdentityHashMap;
33  import org.jboss.netty.util.internal.SharedResourceMisuseDetector;
34  
35  import java.io.IOException;
36  import java.lang.reflect.Method;
37  import java.util.HashSet;
38  import java.util.List;
39  import java.util.Set;
40  import java.util.concurrent.ConcurrentMap;
41  import java.util.concurrent.Executor;
42  import java.util.concurrent.Executors;
43  import java.util.concurrent.LinkedBlockingQueue;
44  import java.util.concurrent.RejectedExecutionException;
45  import java.util.concurrent.RejectedExecutionHandler;
46  import java.util.concurrent.ThreadFactory;
47  import java.util.concurrent.ThreadPoolExecutor;
48  import java.util.concurrent.TimeUnit;
49  import java.util.concurrent.atomic.AtomicLong;
50  
51  /**
52   * A {@link ThreadPoolExecutor} which blocks the task submission when there's
53   * too many tasks in the queue.  Both per-{@link Channel} and per-{@link Executor}
54   * limitation can be applied.
55   * <p>
56   * When a task (i.e. {@link Runnable}) is submitted,
57   * {@link MemoryAwareThreadPoolExecutor} calls {@link ObjectSizeEstimator#estimateSize(Object)}
58   * to get the estimated size of the task in bytes to calculate the amount of
59   * memory occupied by the unprocessed tasks.
60   * <p>
61   * If the total size of the unprocessed tasks exceeds either per-{@link Channel}
62   * or per-{@link Executor} threshold, any further {@link #execute(Runnable)}
63   * call will block until the tasks in the queue are processed so that the total
64   * size goes under the threshold.
65   *
66   * <h3>Using an alternative task size estimation strategy</h3>
67   *
68   * Although the default implementation does its best to guess the size of an
69   * object of unknown type, it is always good idea to to use an alternative
70   * {@link ObjectSizeEstimator} implementation instead of the
71   * {@link DefaultObjectSizeEstimator} to avoid incorrect task size calculation,
72   * especially when:
73   * <ul>
74   *   <li>you are using {@link MemoryAwareThreadPoolExecutor} independently from
75   *       {@link ExecutionHandler},</li>
76   *   <li>you are submitting a task whose type is not {@link ChannelEventRunnable}, or</li>
77   *   <li>the message type of the {@link MessageEvent} in the {@link ChannelEventRunnable}
78   *       is not {@link ChannelBuffer}.</li>
79   * </ul>
80   * Here is an example that demonstrates how to implement an {@link ObjectSizeEstimator}
81   * which understands a user-defined object:
82   * <pre>
83   * public class MyRunnable implements {@link Runnable} {
84   *
85   *     <b>private final byte[] data;</b>
86   *
87   *     public MyRunnable(byte[] data) {
88   *         this.data = data;
89   *     }
90   *
91   *     public void run() {
92   *         // Process 'data' ..
93   *     }
94   * }
95   *
96   * public class MyObjectSizeEstimator extends {@link DefaultObjectSizeEstimator} {
97   *
98   *     {@literal @Override}
99   *     public int estimateSize(Object o) {
100  *         if (<b>o instanceof MyRunnable</b>) {
101  *             <b>return ((MyRunnable) o).data.length + 8;</b>
102  *         }
103  *         return super.estimateSize(o);
104  *     }
105  * }
106  *
107  * {@link ThreadPoolExecutor} pool = new {@link MemoryAwareThreadPoolExecutor}(
108  *         16, 65536, 1048576, 30, {@link TimeUnit}.SECONDS,
109  *         <b>new MyObjectSizeEstimator()</b>,
110  *         {@link Executors}.defaultThreadFactory());
111  *
112  * <b>pool.execute(new MyRunnable(data));</b>
113  * </pre>
114  *
115  * <h3>Event execution order</h3>
116  *
117  * Please note that this executor does not maintain the order of the
118  * {@link ChannelEvent}s for the same {@link Channel}.  For example,
119  * you can even receive a {@code "channelClosed"} event before a
120  * {@code "messageReceived"} event, as depicted by the following diagram.
121  *
122  * For example, the events can be processed as depicted below:
123  *
124  * <pre>
125  *           --------------------------------&gt; Timeline --------------------------------&gt;
126  *
127  * Thread X: --- Channel A (Event 1) --- Channel A (Event 2) ---------------------------&gt;
128  *
129  * Thread Y: --- Channel A (Event 3) --- Channel B (Event 2) --- Channel B (Event 3) ---&gt;
130  *
131  * Thread Z: --- Channel B (Event 1) --- Channel B (Event 4) --- Channel A (Event 4) ---&gt;
132  * </pre>
133  *
134  * To maintain the event order, you must use {@link OrderedMemoryAwareThreadPoolExecutor}.
135  *
136  * @apiviz.has org.jboss.netty.util.ObjectSizeEstimator oneway - -
137  * @apiviz.has org.jboss.netty.handler.execution.ChannelEventRunnable oneway - - executes
138  */
139 public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
140 
141     private static final InternalLogger logger =
142         InternalLoggerFactory.getInstance(MemoryAwareThreadPoolExecutor.class);
143 
144     private static final SharedResourceMisuseDetector misuseDetector =
145         new SharedResourceMisuseDetector(MemoryAwareThreadPoolExecutor.class);
146 
147     private volatile Settings settings;
148 
149     private final ConcurrentMap<Channel, AtomicLong> channelCounters =
150         new ConcurrentIdentityHashMap<Channel, AtomicLong>();
151     private final Limiter totalLimiter;
152 
153     private volatile boolean notifyOnShutdown;
154 
155     /**
156      * Creates a new instance.
157      *
158      * @param corePoolSize          the maximum number of active threads
159      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
160      *                              Specify {@code 0} to disable.
161      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
162      *                              Specify {@code 0} to disable.
163      */
164     public MemoryAwareThreadPoolExecutor(
165             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize) {
166 
167         this(corePoolSize, maxChannelMemorySize, maxTotalMemorySize, 30, TimeUnit.SECONDS);
168     }
169 
170     /**
171      * Creates a new instance.
172      *
173      * @param corePoolSize          the maximum number of active threads
174      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
175      *                              Specify {@code 0} to disable.
176      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
177      *                              Specify {@code 0} to disable.
178      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
179      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
180      */
181     public MemoryAwareThreadPoolExecutor(
182             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
183             long keepAliveTime, TimeUnit unit) {
184 
185         this(
186                 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
187                 Executors.defaultThreadFactory());
188     }
189 
190     /**
191      * Creates a new instance.
192      *
193      * @param corePoolSize          the maximum number of active threads
194      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
195      *                              Specify {@code 0} to disable.
196      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
197      *                              Specify {@code 0} to disable.
198      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
199      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
200      * @param threadFactory         the {@link ThreadFactory} of this pool
201      */
202     public MemoryAwareThreadPoolExecutor(
203             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
204             long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
205 
206         this(
207                 corePoolSize, maxChannelMemorySize, maxTotalMemorySize, keepAliveTime, unit,
208                 new DefaultObjectSizeEstimator(), threadFactory);
209     }
210 
211     /**
212      * Creates a new instance.
213      *
214      * @param corePoolSize          the maximum number of active threads
215      * @param maxChannelMemorySize  the maximum total size of the queued events per channel.
216      *                              Specify {@code 0} to disable.
217      * @param maxTotalMemorySize    the maximum total size of the queued events for this pool
218      *                              Specify {@code 0} to disable.
219      * @param keepAliveTime         the amount of time for an inactive thread to shut itself down
220      * @param unit                  the {@link TimeUnit} of {@code keepAliveTime}
221      * @param threadFactory         the {@link ThreadFactory} of this pool
222      * @param objectSizeEstimator   the {@link ObjectSizeEstimator} of this pool
223      */
224     public MemoryAwareThreadPoolExecutor(
225             int corePoolSize, long maxChannelMemorySize, long maxTotalMemorySize,
226             long keepAliveTime, TimeUnit unit, ObjectSizeEstimator objectSizeEstimator,
227             ThreadFactory threadFactory) {
228 
229         super(corePoolSize, corePoolSize, keepAliveTime, unit,
230               new LinkedBlockingQueue<Runnable>(), threadFactory, new NewThreadRunsPolicy());
231 
232         if (objectSizeEstimator == null) {
233             throw new NullPointerException("objectSizeEstimator");
234         }
235         if (maxChannelMemorySize < 0) {
236             throw new IllegalArgumentException(
237                     "maxChannelMemorySize: " + maxChannelMemorySize);
238         }
239         if (maxTotalMemorySize < 0) {
240             throw new IllegalArgumentException(
241                     "maxTotalMemorySize: " + maxTotalMemorySize);
242         }
243 
244         // Call allowCoreThreadTimeOut(true) using reflection
245         // because it is not supported in Java 5.
246         try {
247             Method m = getClass().getMethod("allowCoreThreadTimeOut", new Class[] { boolean.class });
248             m.invoke(this, Boolean.TRUE);
249         } catch (Throwable t) {
250             // Java 5
251             logger.debug(
252                     "ThreadPoolExecutor.allowCoreThreadTimeOut() is not " +
253                     "supported in this platform.");
254         }
255 
256         settings = new Settings(
257                 objectSizeEstimator, maxChannelMemorySize);
258 
259         if (maxTotalMemorySize == 0) {
260             totalLimiter = null;
261         } else {
262             totalLimiter = new Limiter(maxTotalMemorySize);
263         }
264 
265         // Misuse check
266         misuseDetector.increase();
267     }
268 
269     @Override
270     protected void terminated() {
271         super.terminated();
272         misuseDetector.decrease();
273     }
274 
275     /**
276      * This will call {@link #shutdownNow(boolean)} with the value of {@link #getNotifyChannelFuturesOnShutdown()}.
277      */
278     @Override
279     public List<Runnable> shutdownNow() {
280         return shutdownNow(notifyOnShutdown);
281     }
282 
283     /**
284      * See {@link ThreadPoolExecutor#shutdownNow()} for how it handles the shutdown.
285      * If {@code true} is given to this method it also notifies all {@link ChannelFuture}'s
286      * of the not executed {@link ChannelEventRunnable}'s.
287      *
288      * <p>
289      * Be aware that if you call this with {@code false} you will need to handle the
290      * notification of the {@link ChannelFuture}'s by your self. So only use this if you
291      * really have a use-case for it.
292      * </p>
293      *
294      */
295     public List<Runnable> shutdownNow(boolean notify) {
296         if (!notify) {
297             return super.shutdownNow();
298         }
299         Throwable cause = null;
300         Set<Channel> channels = null;
301 
302         List<Runnable> tasks = super.shutdownNow();
303 
304         // loop over all tasks and cancel the ChannelFuture of the ChannelEventRunable's
305         for (Runnable task: tasks) {
306             if (task instanceof ChannelEventRunnable) {
307                 if (cause == null) {
308                     cause = new IOException("Unable to process queued event");
309                 }
310                 ChannelEvent event = ((ChannelEventRunnable) task).getEvent();
311                 event.getFuture().setFailure(cause);
312 
313                 if (channels == null) {
314                     channels = new HashSet<Channel>();
315                 }
316 
317                 // store the Channel of the event for later notification of the exceptionCaught event
318                 channels.add(event.getChannel());
319             }
320         }
321 
322         // loop over all channels and fire an exceptionCaught event
323         if (channels != null) {
324             for (Channel channel: channels) {
325                 Channels.fireExceptionCaughtLater(channel, cause);
326             }
327         }
328         return tasks;
329     }
330 
331     /**
332      * Returns the {@link ObjectSizeEstimator} of this pool.
333      */
334     public ObjectSizeEstimator getObjectSizeEstimator() {
335         return settings.objectSizeEstimator;
336     }
337 
338     /**
339      * Sets the {@link ObjectSizeEstimator} of this pool.
340      */
341     public void setObjectSizeEstimator(ObjectSizeEstimator objectSizeEstimator) {
342         if (objectSizeEstimator == null) {
343             throw new NullPointerException("objectSizeEstimator");
344         }
345 
346         settings = new Settings(
347                 objectSizeEstimator,
348                 settings.maxChannelMemorySize);
349     }
350 
351     /**
352      * Returns the maximum total size of the queued events per channel.
353      */
354     public long getMaxChannelMemorySize() {
355         return settings.maxChannelMemorySize;
356     }
357 
358     /**
359      * Sets the maximum total size of the queued events per channel.
360      * Specify {@code 0} to disable.
361      */
362     public void setMaxChannelMemorySize(long maxChannelMemorySize) {
363         if (maxChannelMemorySize < 0) {
364             throw new IllegalArgumentException(
365                     "maxChannelMemorySize: " + maxChannelMemorySize);
366         }
367 
368         if (getTaskCount() > 0) {
369             throw new IllegalStateException(
370                     "can't be changed after a task is executed");
371         }
372 
373         settings = new Settings(
374                 settings.objectSizeEstimator,
375                 maxChannelMemorySize);
376     }
377 
378     /**
379      * Returns the maximum total size of the queued events for this pool.
380      */
381     public long getMaxTotalMemorySize() {
382         if (totalLimiter == null) {
383             return 0;
384         }
385         return totalLimiter.limit;
386     }
387 
388     /**
389      * If set to {@code false} no queued {@link ChannelEventRunnable}'s {@link ChannelFuture}
390      * will get notified once {@link #shutdownNow()} is called.  If set to {@code true} every
391      * queued {@link ChannelEventRunnable} will get marked as failed via {@link ChannelFuture#setFailure(Throwable)}.
392      *
393      * <p>
394      * Please only set this to {@code false} if you want to handle the notification by yourself
395      * and know what you are doing. Default is {@code true}.
396      * </p>
397      */
398     public void setNotifyChannelFuturesOnShutdown(boolean notifyOnShutdown) {
399         this.notifyOnShutdown = notifyOnShutdown;
400     }
401 
402     /**
403      * Returns if the {@link ChannelFuture}'s of the {@link ChannelEventRunnable}'s should be
404      * notified about the shutdown of this {@link MemoryAwareThreadPoolExecutor}.
405      */
406     public boolean getNotifyChannelFuturesOnShutdown() {
407         return notifyOnShutdown;
408     }
409 
410     @Override
411     public void execute(Runnable command) {
412         if (command instanceof ChannelDownstreamEventRunnable) {
413             throw new RejectedExecutionException("command must be enclosed with an upstream event.");
414         }
415         if (!(command instanceof ChannelEventRunnable)) {
416             command = new MemoryAwareRunnable(command);
417         }
418 
419         increaseCounter(command);
420         doExecute(command);
421     }
422 
423     /**
424      * Put the actual execution logic here.  The default implementation simply
425      * calls {@link #doUnorderedExecute(Runnable)}.
426      */
427     protected void doExecute(Runnable task) {
428         doUnorderedExecute(task);
429     }
430 
431     /**
432      * Executes the specified task without maintaining the event order.
433      */
434     protected final void doUnorderedExecute(Runnable task) {
435         super.execute(task);
436     }
437 
438     @Override
439     public boolean remove(Runnable task) {
440         boolean removed = super.remove(task);
441         if (removed) {
442             decreaseCounter(task);
443         }
444         return removed;
445     }
446 
447     @Override
448     protected void beforeExecute(Thread t, Runnable r) {
449         super.beforeExecute(t, r);
450         decreaseCounter(r);
451     }
452 
453     protected void increaseCounter(Runnable task) {
454         if (!shouldCount(task)) {
455             return;
456         }
457 
458         Settings settings = this.settings;
459         long maxChannelMemorySize = settings.maxChannelMemorySize;
460 
461         int increment = settings.objectSizeEstimator.estimateSize(task);
462 
463         if (task instanceof ChannelEventRunnable) {
464             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
465             eventTask.estimatedSize = increment;
466             Channel channel = eventTask.getEvent().getChannel();
467             long channelCounter = getChannelCounter(channel).addAndGet(increment);
468             //System.out.println("IC: " + channelCounter + ", " + increment);
469             if (maxChannelMemorySize != 0 && channelCounter >= maxChannelMemorySize && channel.isOpen()) {
470                 if (channel.isReadable()) {
471                     //System.out.println("UNREADABLE");
472                     ChannelHandlerContext ctx = eventTask.getContext();
473                     if (ctx.getHandler() instanceof ExecutionHandler) {
474                         // readSuspended = true;
475                         ctx.setAttachment(Boolean.TRUE);
476                     }
477                     channel.setReadable(false);
478                 }
479             }
480         } else {
481             ((MemoryAwareRunnable) task).estimatedSize = increment;
482         }
483 
484         if (totalLimiter != null) {
485             totalLimiter.increase(increment);
486         }
487     }
488 
489     protected void decreaseCounter(Runnable task) {
490         if (!shouldCount(task)) {
491             return;
492         }
493 
494         Settings settings = this.settings;
495         long maxChannelMemorySize = settings.maxChannelMemorySize;
496 
497         int increment;
498         if (task instanceof ChannelEventRunnable) {
499             increment = ((ChannelEventRunnable) task).estimatedSize;
500         } else {
501             increment = ((MemoryAwareRunnable) task).estimatedSize;
502         }
503 
504         if (totalLimiter != null) {
505             totalLimiter.decrease(increment);
506         }
507 
508         if (task instanceof ChannelEventRunnable) {
509             ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
510             Channel channel = eventTask.getEvent().getChannel();
511             long channelCounter = getChannelCounter(channel).addAndGet(-increment);
512             //System.out.println("DC: " + channelCounter + ", " + increment);
513             if (maxChannelMemorySize != 0 && channelCounter < maxChannelMemorySize && channel.isOpen()) {
514                 if (!channel.isReadable()) {
515                     //System.out.println("READABLE");
516                     ChannelHandlerContext ctx = eventTask.getContext();
517                     if (ctx.getHandler() instanceof ExecutionHandler) {
518                         // check if the attachment was set as this means that we suspend the channel
519                         // from reads. This only works when this pool is used with ExecutionHandler
520                         // but I guess thats good enough for us.
521                         //
522                         // See #215
523                         if (ctx.getAttachment() != null) {
524                             // readSuspended = false;
525                             ctx.setAttachment(null);
526                             channel.setReadable(true);
527                         }
528                     } else {
529                         channel.setReadable(true);
530                     }
531                 }
532             }
533         }
534     }
535 
536     private AtomicLong getChannelCounter(Channel channel) {
537         AtomicLong counter = channelCounters.get(channel);
538         if (counter == null) {
539             counter = new AtomicLong();
540             AtomicLong oldCounter = channelCounters.putIfAbsent(channel, counter);
541             if (oldCounter != null) {
542                 counter = oldCounter;
543             }
544         }
545 
546         // Remove the entry when the channel closes.
547         if (!channel.isOpen()) {
548             channelCounters.remove(channel);
549         }
550         return counter;
551     }
552 
553     /**
554      * Returns {@code true} if and only if the specified {@code task} should
555      * be counted to limit the global and per-channel memory consumption.
556      * To override this method, you must call {@code super.shouldCount()} to
557      * make sure important tasks are not counted.
558      */
559     protected boolean shouldCount(Runnable task) {
560         if (task instanceof ChannelUpstreamEventRunnable) {
561             ChannelUpstreamEventRunnable r = (ChannelUpstreamEventRunnable) task;
562             ChannelEvent e = r.getEvent();
563             if (e instanceof WriteCompletionEvent) {
564                 return false;
565             } else if (e instanceof ChannelStateEvent) {
566                 if (((ChannelStateEvent) e).getState() == ChannelState.INTEREST_OPS) {
567                     return false;
568                 }
569             }
570         }
571         return true;
572     }
573 
574     private static final class Settings {
575         final ObjectSizeEstimator objectSizeEstimator;
576         final long maxChannelMemorySize;
577 
578         Settings(ObjectSizeEstimator objectSizeEstimator,
579                  long maxChannelMemorySize) {
580             this.objectSizeEstimator = objectSizeEstimator;
581             this.maxChannelMemorySize = maxChannelMemorySize;
582         }
583     }
584 
585     private static final class NewThreadRunsPolicy implements RejectedExecutionHandler {
586         public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
587             try {
588                 final Thread t = new Thread(r, "Temporary task executor");
589                 t.start();
590             } catch (Throwable e) {
591                 throw new RejectedExecutionException(
592                         "Failed to start a new thread", e);
593             }
594         }
595     }
596 
597     private static final class MemoryAwareRunnable implements Runnable {
598         final Runnable task;
599         int estimatedSize;
600 
601         MemoryAwareRunnable(Runnable task) {
602             this.task = task;
603         }
604 
605         public void run() {
606             task.run();
607         }
608     }
609 
610     private static class Limiter {
611 
612         final long limit;
613         private long counter;
614         private int waiters;
615 
616         Limiter(long limit) {
617             this.limit = limit;
618         }
619 
620         synchronized void increase(long amount) {
621             while (counter >= limit) {
622                 waiters ++;
623                 try {
624                     wait();
625                 } catch (InterruptedException e) {
626                     Thread.currentThread().interrupt();
627                 } finally {
628                     waiters --;
629                 }
630             }
631             counter += amount;
632         }
633 
634         synchronized void decrease(long amount) {
635             counter -= amount;
636             if (counter < limit && waiters > 0) {
637                 notifyAll();
638             }
639         }
640     }
641 }