1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/core/platform/env.h"
20 #include "tensorflow/core/platform/mutex.h"
21
22 namespace tensorflow {
23 namespace data {
24
25 // A lightweight wrapper for creating logical threads in a `UnboundedThreadPool`
26 // that can be shared (e.g.) in an `IteratorContext`.
27 class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
28 public:
LogicalThreadFactory(UnboundedThreadPool * pool)29 explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
30
StartThread(const string & name,std::function<void ()> fn)31 std::unique_ptr<Thread> StartThread(const string& name,
32 std::function<void()> fn) override {
33 return pool_->RunOnPooledThread(std::move(fn));
34 }
35
36 private:
37 UnboundedThreadPool* const pool_; // Not owned.
38 };
39
40 // A logical implementation of the `tensorflow::Thread` interface that uses
41 // physical threads in an `UnboundedThreadPool` to perform the work.
42 //
43 // NOTE: This object represents a logical thread of control that may be mapped
44 // onto the same physical thread as other work items that are submitted to the
45 // same `UnboundedThreadPool`.
46 class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
47 public:
LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)48 explicit LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)
49 : join_notification_(std::move(join_notification)) {}
50
~LogicalThreadWrapper()51 ~LogicalThreadWrapper() override {
52 // NOTE: The `Thread` destructor is expected to "join" the created thread,
53 // but the physical thread may continue to execute after the work for this
54 // thread is complete. We simulate this by waiting on a notification that
55 // the `CachedThreadFunc` will notify when the thread's work function is
56 // complete.
57 join_notification_->WaitForNotification();
58 }
59
60 private:
61 std::shared_ptr<Notification> join_notification_;
62 };
63
~UnboundedThreadPool()64 UnboundedThreadPool::~UnboundedThreadPool() {
65 {
66 mutex_lock l(work_queue_mu_);
67 // Wake up all `CachedThreadFunc` threads and cause them to terminate before
68 // joining them when `threads_` is cleared.
69 cancelled_ = true;
70 work_queue_cv_.notify_all();
71 if (!work_queue_.empty()) {
72 LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
73 << "deleted with pending work in its queue. This may indicate "
74 << "a potential use-after-free bug.";
75 }
76 }
77
78 {
79 mutex_lock l(thread_pool_mu_);
80 // Clear the list of pooled threads, which will eventually terminate due to
81 // the previous notification.
82 //
83 // NOTE: It is safe to do this while holding `pooled_threads_mu_`, because
84 // no subsequent calls to `this->StartThread()` should be issued after the
85 // destructor starts.
86 thread_pool_.clear();
87 }
88 }
89
get_thread_factory()90 std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
91 return std::make_shared<LogicalThreadFactory>(this);
92 }
93
size()94 size_t UnboundedThreadPool::size() {
95 tf_shared_lock l(thread_pool_mu_);
96 return thread_pool_.size();
97 }
98
RunOnPooledThread(std::function<void ()> fn)99 std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
100 std::function<void()> fn) {
101 auto join_notification = std::make_shared<Notification>();
102 bool all_threads_busy;
103 {
104 // Enqueue a work item for the new thread's function, and wake up a
105 // cached thread to process it.
106 mutex_lock l(work_queue_mu_);
107 work_queue_.push_back({std::move(fn), join_notification});
108 work_queue_cv_.notify_one();
109 // NOTE: The queue may be non-empty, so we must account for queued work when
110 // considering how many threads are free.
111 all_threads_busy = work_queue_.size() > num_idle_threads_;
112 }
113
114 if (all_threads_busy) {
115 // Spawn a new physical thread to process the given function.
116 // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
117 // at the beginning of its work loop.
118 Thread* new_thread = env_->StartThread(
119 {}, thread_name_,
120 std::bind(&UnboundedThreadPool::PooledThreadFunc, this));
121
122 mutex_lock l(thread_pool_mu_);
123 thread_pool_.emplace_back(new_thread);
124 }
125
126 return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
127 }
128
PooledThreadFunc()129 void UnboundedThreadPool::PooledThreadFunc() {
130 while (true) {
131 WorkItem work_item;
132 {
133 mutex_lock l(work_queue_mu_);
134 ++num_idle_threads_;
135 while (!cancelled_ && work_queue_.empty()) {
136 // Wait for a new work function to be submitted, or the cache to be
137 // destroyed.
138 work_queue_cv_.wait(l);
139 }
140 if (cancelled_) {
141 return;
142 }
143 work_item = std::move(work_queue_.front());
144 work_queue_.pop_front();
145 --num_idle_threads_;
146 }
147
148 work_item.work_function();
149
150 // Notify any thread that has "joined" the cached thread for this work item.
151 work_item.done_notification->Notify();
152 }
153 }
154
155 } // namespace data
156 } // namespace tensorflow
157