• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/core/platform/env.h"
20 #include "tensorflow/core/platform/mutex.h"
21 
22 namespace tensorflow {
23 namespace data {
24 
25 // A lightweight wrapper for creating logical threads in a `UnboundedThreadPool`
26 // that can be shared (e.g.) in an `IteratorContext`.
27 class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
28  public:
LogicalThreadFactory(UnboundedThreadPool * pool)29   explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
30 
StartThread(const string & name,std::function<void ()> fn)31   std::unique_ptr<Thread> StartThread(const string& name,
32                                       std::function<void()> fn) override {
33     return pool_->RunOnPooledThread(std::move(fn));
34   }
35 
36  private:
37   UnboundedThreadPool* const pool_;  // Not owned.
38 };
39 
40 // A logical implementation of the `tensorflow::Thread` interface that uses
41 // physical threads in an `UnboundedThreadPool` to perform the work.
42 //
43 // NOTE: This object represents a logical thread of control that may be mapped
44 // onto the same physical thread as other work items that are submitted to the
45 // same `UnboundedThreadPool`.
46 class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
47  public:
LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)48   explicit LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)
49       : join_notification_(std::move(join_notification)) {}
50 
~LogicalThreadWrapper()51   ~LogicalThreadWrapper() override {
52     // NOTE: The `Thread` destructor is expected to "join" the created thread,
53     // but the physical thread may continue to execute after the work for this
54     // thread is complete. We simulate this by waiting on a notification that
55     // the `CachedThreadFunc` will notify when the thread's work function is
56     // complete.
57     join_notification_->WaitForNotification();
58   }
59 
60  private:
61   std::shared_ptr<Notification> join_notification_;
62 };
63 
~UnboundedThreadPool()64 UnboundedThreadPool::~UnboundedThreadPool() {
65   {
66     mutex_lock l(work_queue_mu_);
67     // Wake up all `CachedThreadFunc` threads and cause them to terminate before
68     // joining them when `threads_` is cleared.
69     cancelled_ = true;
70     work_queue_cv_.notify_all();
71     if (!work_queue_.empty()) {
72       LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
73                  << "deleted with pending work in its queue. This may indicate "
74                  << "a potential use-after-free bug.";
75     }
76   }
77 
78   {
79     mutex_lock l(thread_pool_mu_);
80     // Clear the list of pooled threads, which will eventually terminate due to
81     // the previous notification.
82     //
83     // NOTE: It is safe to do this while holding `pooled_threads_mu_`, because
84     // no subsequent calls to `this->StartThread()` should be issued after the
85     // destructor starts.
86     thread_pool_.clear();
87   }
88 }
89 
get_thread_factory()90 std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
91   return std::make_shared<LogicalThreadFactory>(this);
92 }
93 
size()94 size_t UnboundedThreadPool::size() {
95   tf_shared_lock l(thread_pool_mu_);
96   return thread_pool_.size();
97 }
98 
RunOnPooledThread(std::function<void ()> fn)99 std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
100     std::function<void()> fn) {
101   auto join_notification = std::make_shared<Notification>();
102   bool all_threads_busy;
103   {
104     // Enqueue a work item for the new thread's function, and wake up a
105     // cached thread to process it.
106     mutex_lock l(work_queue_mu_);
107     work_queue_.push_back({std::move(fn), join_notification});
108     work_queue_cv_.notify_one();
109     // NOTE: The queue may be non-empty, so we must account for queued work when
110     // considering how many threads are free.
111     all_threads_busy = work_queue_.size() > num_idle_threads_;
112   }
113 
114   if (all_threads_busy) {
115     // Spawn a new physical thread to process the given function.
116     // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
117     // at the beginning of its work loop.
118     Thread* new_thread = env_->StartThread(
119         {}, thread_name_,
120         std::bind(&UnboundedThreadPool::PooledThreadFunc, this));
121 
122     mutex_lock l(thread_pool_mu_);
123     thread_pool_.emplace_back(new_thread);
124   }
125 
126   return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
127 }
128 
PooledThreadFunc()129 void UnboundedThreadPool::PooledThreadFunc() {
130   while (true) {
131     WorkItem work_item;
132     {
133       mutex_lock l(work_queue_mu_);
134       ++num_idle_threads_;
135       while (!cancelled_ && work_queue_.empty()) {
136         // Wait for a new work function to be submitted, or the cache to be
137         // destroyed.
138         work_queue_cv_.wait(l);
139       }
140       if (cancelled_) {
141         return;
142       }
143       work_item = std::move(work_queue_.front());
144       work_queue_.pop_front();
145       --num_idle_threads_;
146     }
147 
148     work_item.work_function();
149 
150     // Notify any thread that has "joined" the cached thread for this work item.
151     work_item.done_notification->Notify();
152   }
153 }
154 
155 }  // namespace data
156 }  // namespace tensorflow
157