/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
16  
17  package io.netty.util;
18  
19  import io.netty.util.concurrent.DefaultThreadFactory;
20  import io.netty.util.internal.MpscLinkedQueueNode;
21  import io.netty.util.internal.PlatformDependent;
22  import io.netty.util.internal.logging.InternalLogger;
23  import io.netty.util.internal.logging.InternalLoggerFactory;
24  
25  import java.util.ArrayList;
26  import java.util.List;
27  import java.util.Queue;
28  import java.util.concurrent.ThreadFactory;
29  import java.util.concurrent.TimeUnit;
30  import java.util.concurrent.atomic.AtomicBoolean;
31  
32  /**
33   * Checks if a thread is alive periodically and runs a task when a thread dies.
34   * <p>
35   * This thread starts a daemon thread to check the state of the threads being watched and to invoke their
36   * associated {@link Runnable}s.  When there is no thread to watch (i.e. all threads are dead), the daemon thread
37   * will terminate itself, and a new daemon thread will be started again when a new watch is added.
38   * </p>
39   */
40  public final class ThreadDeathWatcher {
41  
42      private static final InternalLogger logger = InternalLoggerFactory.getInstance(ThreadDeathWatcher.class);
43      private static final ThreadFactory threadFactory =
44              new DefaultThreadFactory(ThreadDeathWatcher.class, true, Thread.MIN_PRIORITY);
45  
46      private static final Queue<Entry> pendingEntries = PlatformDependent.newMpscQueue();
47      private static final Watcher watcher = new Watcher();
48      private static final AtomicBoolean started = new AtomicBoolean();
49      private static volatile Thread watcherThread;
50  
51      /**
52       * Schedules the specified {@code task} to run when the specified {@code thread} dies.
53       *
54       * @param thread the {@link Thread} to watch
55       * @param task the {@link Runnable} to run when the {@code thread} dies
56       *
57       * @throws IllegalArgumentException if the specified {@code thread} is not alive
58       */
59      public static void watch(Thread thread, Runnable task) {
60          if (thread == null) {
61              throw new NullPointerException("thread");
62          }
63          if (task == null) {
64              throw new NullPointerException("task");
65          }
66          if (!thread.isAlive()) {
67              throw new IllegalArgumentException("thread must be alive.");
68          }
69  
70          schedule(thread, task, true);
71      }
72  
73      /**
74       * Cancels the task scheduled via {@link #watch(Thread, Runnable)}.
75       */
76      public static void unwatch(Thread thread, Runnable task) {
77          if (thread == null) {
78              throw new NullPointerException("thread");
79          }
80          if (task == null) {
81              throw new NullPointerException("task");
82          }
83  
84          schedule(thread, task, false);
85      }
86  
87      private static void schedule(Thread thread, Runnable task, boolean isWatch) {
88          pendingEntries.add(new Entry(thread, task, isWatch));
89  
90          if (started.compareAndSet(false, true)) {
91              Thread watcherThread = threadFactory.newThread(watcher);
92              watcherThread.start();
93              ThreadDeathWatcher.watcherThread = watcherThread;
94          }
95      }
96  
97      /**
98       * Waits until the thread of this watcher has no threads to watch and terminates itself.
99       * Because a new watcher thread will be started again on {@link #watch(Thread, Runnable)},
100      * this operation is only useful when you want to ensure that the watcher thread is terminated
101      * <strong>after</strong> your application is shut down and there's no chance of calling
102      * {@link #watch(Thread, Runnable)} afterwards.
103      *
104      * @return {@code true} if and only if the watcher thread has been terminated
105      */
106     public static boolean awaitInactivity(long timeout, TimeUnit unit) throws InterruptedException {
107         if (unit == null) {
108             throw new NullPointerException("unit");
109         }
110 
111         Thread watcherThread = ThreadDeathWatcher.watcherThread;
112         if (watcherThread != null) {
113             watcherThread.join(unit.toMillis(timeout));
114             return !watcherThread.isAlive();
115         } else {
116             return true;
117         }
118     }
119 
120     private ThreadDeathWatcher() { }
121 
122     private static final class Watcher implements Runnable {
123 
124         private final List<Entry> watchees = new ArrayList<Entry>();
125 
126         @Override
127         public void run() {
128             for (;;) {
129                 fetchWatchees();
130                 notifyWatchees();
131 
132                 // Try once again just in case notifyWatchees() triggered watch() or unwatch().
133                 fetchWatchees();
134                 notifyWatchees();
135 
136                 try {
137                     Thread.sleep(1000);
138                 } catch (InterruptedException ignore) {
139                     // Ignore the interrupt; do not terminate until all tasks are run.
140                 }
141 
142                 if (watchees.isEmpty() && pendingEntries.isEmpty()) {
143 
144                     // Mark the current worker thread as stopped.
145                     // The following CAS must always success and must be uncontended,
146                     // because only one watcher thread should be running at the same time.
147                     boolean stopped = started.compareAndSet(true, false);
148                     assert stopped;
149 
150                     // Check if there are pending entries added by watch() while we do CAS above.
151                     if (pendingEntries.isEmpty()) {
152                         // A) watch() was not invoked and thus there's nothing to handle
153                         //    -> safe to terminate because there's nothing left to do
154                         // B) a new watcher thread started and handled them all
155                         //    -> safe to terminate the new watcher thread will take care the rest
156                         break;
157                     }
158 
159                     // There are pending entries again, added by watch()
160                     if (!started.compareAndSet(false, true)) {
161                         // watch() started a new watcher thread and set 'started' to true.
162                         // -> terminate this thread so that the new watcher reads from pendingEntries exclusively.
163                         break;
164                     }
165 
166                     // watch() added an entry, but this worker was faster to set 'started' to true.
167                     // i.e. a new watcher thread was not started
168                     // -> keep this thread alive to handle the newly added entries.
169                 }
170             }
171         }
172 
173         private void fetchWatchees() {
174             for (;;) {
175                 Entry e = pendingEntries.poll();
176                 if (e == null) {
177                     break;
178                 }
179 
180                 if (e.isWatch) {
181                     watchees.add(e);
182                 } else {
183                     watchees.remove(e);
184                 }
185             }
186         }
187 
188         private void notifyWatchees() {
189             List<Entry> watchees = this.watchees;
190             for (int i = 0; i < watchees.size();) {
191                 Entry e = watchees.get(i);
192                 if (!e.thread.isAlive()) {
193                     watchees.remove(i);
194                     try {
195                         e.task.run();
196                     } catch (Throwable t) {
197                         logger.warn("Thread death watcher task raised an exception:", t);
198                     }
199                 } else {
200                     i ++;
201                 }
202             }
203         }
204     }
205 
206     private static final class Entry extends MpscLinkedQueueNode<Entry> {
207         final Thread thread;
208         final Runnable task;
209         final boolean isWatch;
210 
211         Entry(Thread thread, Runnable task, boolean isWatch) {
212             this.thread = thread;
213             this.task = task;
214             this.isWatch = isWatch;
215         }
216 
217         @Override
218         public Entry value() {
219             return this;
220         }
221 
222         @Override
223         public int hashCode() {
224             return thread.hashCode() ^ task.hashCode();
225         }
226 
227         @Override
228         public boolean equals(Object obj) {
229             if (obj == this) {
230                 return true;
231             }
232 
233             if (!(obj instanceof Entry)) {
234                 return false;
235             }
236 
237             Entry that = (Entry) obj;
238             return thread == that.thread && task == that.task;
239         }
240     }
241 }