1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17
18 #include <forward_list>
19
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22
23 namespace tensorflow {
24
EagerExecutor(bool async)25 EagerExecutor::EagerExecutor(bool async)
26 : next_node_id_(0),
27 ok_(true),
28 thread_(async ? tensorflow::Env::Default()->StartThread(
29 tensorflow::ThreadOptions(), "eager_async_executor",
30 std::bind(&EagerExecutor::Run, this))
31 : nullptr) {}
32
~EagerExecutor()33 EagerExecutor::~EagerExecutor() {
34 tensorflow::mutex_lock l(node_queue_mutex_);
35 state_ = ExecutorState::kShutDown;
36 nodes_pending_.notify_all();
37 }
38
ShutDown()39 Status EagerExecutor::ShutDown() {
40 {
41 std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
42 bool has_thread;
43 Status status;
44 {
45 tensorflow::mutex_lock l(node_queue_mutex_);
46 if (state_ != ExecutorState::kShutDown) {
47 // if the state is kShutDown, we don't return here because we want to
48 // make sure the executor thread has ended (if there is one).
49 // So, we fall through to
50 // thread_exited_notification_.WaitForNotification() below.
51 state_ = ExecutorState::kShuttingDown;
52 }
53 // It is OK to ignore the returned status here because it will be saved
54 // as the final status_.
55 WaitForAllPendingNodesLocked(&l).IgnoreError();
56 state_ = ExecutorState::kShutDown;
57 has_thread = thread_ != nullptr;
58 status = status_;
59 if (has_thread) {
60 nodes_pending_.notify_all();
61 }
62 }
63 for (auto& item : items_to_destroy) {
64 item->node->Abort(status);
65 }
66 if (!has_thread) {
67 return status;
68 }
69 }
70
71 thread_exited_notification_.WaitForNotification();
72
73 return status();
74 }
75
StateStringLocked()76 const char* EagerExecutor::StateStringLocked() {
77 switch (state_) {
78 case ExecutorState::kActive:
79 return "Active";
80 case ExecutorState::kShuttingDown:
81 return "ShuttingDown";
82 case ExecutorState::kShutDown:
83 return "ShutDown";
84 }
85 }
86
SyncExecute(EagerNode * node)87 Status EagerExecutor::SyncExecute(EagerNode* node) {
88 if (Async()) {
89 return errors::Internal("Executor does not support sync execution");
90 }
91 if (node->AsAsync() != nullptr) {
92 return errors::Internal("Executor does not support executing async nodes");
93 }
94 // NOTE: SyncExecute runs every node regardless of error status in executor.
95
96 uint64 id = next_node_id_++;
97
98 Status s = node->Prepare();
99 if (!s.ok()) {
100 return s;
101 }
102
103 // Inline execution in sync mode.
104 s = node->Run();
105 tensorflow::mutex_lock l(node_queue_mutex_);
106 if (!s.ok()) {
107 status_ = s;
108 ok_ = false;
109 }
110 NotifyWaiters(id);
111 return s;
112 }
113
AddOrExecute(std::unique_ptr<EagerNode> node)114 Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
115 Status status;
116 core::RefCountPtr<NodeItem> item(new NodeItem);
117 item->id = next_node_id_++;
118 item->node = std::move(node);
119 item->state = NodeState::kPENDING;
120
121 status = item->node->Prepare();
122 if (!status.ok()) {
123 item->node->Abort(status);
124 return status;
125 }
126
127 // Inline execution in sync mode.
128 if (!Async()) {
129 // In sync mode, run the node item regardless of executor status.
130 return RunItem(std::move(item), /*from_queue=*/false);
131 } else {
132 tensorflow::mutex_lock l(node_queue_mutex_);
133 DVLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
134 << " with status: " << status_.ToString();
135 if (state_ != ExecutorState::kActive) {
136 status = errors::FailedPrecondition(
137 "EagerExecutor accepts new EagerNodes to run only in Active state. "
138 "Current state is '",
139 StateStringLocked(), "'");
140 } else {
141 status = status_;
142 if (status.ok()) {
143 node_queue_.push(std::move(item));
144 // If there were no previous nodes pending, wake the run thread to
145 // start processing requests again.
146 if (node_queue_.size() == 1) {
147 nodes_pending_.notify_all();
148 }
149
150 return Status::OK();
151 }
152 }
153 }
154
155 // If we are unable to add the node to the queue, we must call Abort. However,
156 // we want to do that outside of the scope of the lock since the Abort may
157 // try to call EagerExecutor::AddOrExecute()
158 item->node->Abort(status);
159
160 return status;
161 }
162
WaitForAllPendingNodes()163 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
164 tensorflow::mutex_lock l(node_queue_mutex_);
165 return WaitForAllPendingNodesLocked(&l);
166 }
167
WaitForAllPendingNodesLocked(mutex_lock * lock)168 tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
169 mutex_lock* lock) {
170 tensorflow::condition_variable cond;
171 // Don't wait if an error is already set.
172 if (!status_.ok()) return status_;
173 if (node_queue_.empty() && unfinished_nodes_.empty())
174 return tensorflow::Status::OK();
175 // node_queue_ must be empty in sync mode.
176 DCHECK(Async() || node_queue_.empty());
177 auto last_id = next_node_id_ - 1;
178 DVLOG(3) << "Wait for Node: [id " << last_id << "] ";
179 node_done_notifications_.insert(std::make_pair(last_id, &cond));
180 cond.wait(*lock);
181 // Note that we could be woken up if an error occurs, even though the node has
182 // not actually executed.
183 return status_;
184 }
185
ClearError()186 void EagerExecutor::ClearError() {
187 // TODO(iga): Check state_ and return an error if it is not kActive.
188 if (ok()) return;
189
190 tensorflow::mutex_lock l(node_queue_mutex_);
191 // If an error was set, node_done_notifications_ and node_queue_ should have
192 // been cleared, and no new entries should have been added since.
193 DCHECK(node_done_notifications_.empty());
194 DCHECK(node_queue_.empty());
195 status_ = tensorflow::Status::OK();
196 ok_ = true;
197 nodes_pending_.notify_all();
198 }
199
NodeDone(const core::RefCountPtr<NodeItem> & item,const Status & status,bool from_queue)200 void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
201 const Status& status, bool from_queue) {
202 DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
203 << " with status: " << status.ToString();
204 DCHECK(item->state != NodeState::kDONE);
205 item->state = NodeState::kDONE;
206
207 bool async = item->node->AsAsync() != nullptr;
208 // If executing synchronously we don't need to notify if status is OK since
209 // the node was never added to the unfinished_nodes_ list and nobody should
210 // ever be waiting for it.
211 if (status.ok() && !from_queue && !async) {
212 return;
213 }
214
215 std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
216 {
217 mutex_lock l(node_queue_mutex_);
218 if (!status_.ok()) return;
219
220 bool need_notification = from_queue;
221 if (from_queue) {
222 // Since this was from the async queue, pop it from the front of the queue
223 DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
224 node_queue_.pop();
225 } else if (async) {
226 // If it is an Async node then we will find the node in the unfinished
227 // nodes list. However we only notify if we are at the front of the list
228 // since we don't want to notify any waiters of earlier nodes.
229 need_notification = item->id == unfinished_nodes_.begin()->first;
230 auto result = unfinished_nodes_.erase(item->id);
231 DCHECK_GT(result, 0);
232 }
233
234 if (!status.ok()) {
235 // Since we received an error, broadcast to any waiters.
236 need_notification = true;
237 status_ = status;
238 ok_ = false;
239 if (Async()) {
240 // We remove any pending ops so that we don't try to execute them if
241 // ClearError is called.
242 errors::AppendToMessage(&status_,
243 "Encountered when executing an operation using "
244 "EagerExecutor. This error cancels all future "
245 "operations and poisons their output tensors.");
246 }
247 while (!node_queue_.empty()) {
248 items_to_destroy.push_front(std::move(node_queue_.front()));
249 node_queue_.pop();
250 }
251 for (auto& it : unfinished_nodes_) {
252 items_to_destroy.push_front(std::move(it.second));
253 }
254 unfinished_nodes_.clear();
255 }
256 if (need_notification) {
257 NotifyWaiters(item->id);
258 }
259 }
260
261 for (auto& item : items_to_destroy) {
262 item->node->Abort(status);
263 }
264 // nodes_to_destroy will be destructed here, while not holding
265 // node_queue_mutex_. This is important because, unfortunately, some nodes'
266 // destructors can enqueue more operations onto this executor and cause
267 // a deadlock.
268 }
269
NotifyWaiters(uint64 id)270 void EagerExecutor::NotifyWaiters(uint64 id) {
271 if (!node_done_notifications_.empty()) {
272 uint64 upperbound_id = 0;
273 if (!unfinished_nodes_.empty()) {
274 upperbound_id = unfinished_nodes_.begin()->first - 1;
275 } else if (!node_queue_.empty()) {
276 upperbound_id = node_queue_.front()->id - 1;
277 } else {
278 upperbound_id = next_node_id_ - 1;
279 }
280 DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
281 << "] ";
282 // Note that we notify all waiting threads in case an error has
283 // occurred. These calling threads are responsible for checking status_
284 // before proceeding.
285 const auto range =
286 status_.ok()
287 ? make_pair(node_done_notifications_.lower_bound(id),
288 node_done_notifications_.upper_bound(upperbound_id))
289 : make_pair(node_done_notifications_.begin(),
290 node_done_notifications_.end());
291 for (auto it = range.first; it != range.second; ++it) {
292 it->second->notify_all();
293 }
294 node_done_notifications_.erase(range.first, range.second);
295 }
296 }
297
Run()298 void EagerExecutor::Run() {
299 auto thread_exited_notifier =
300 gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
301 while (true) {
302 core::RefCountPtr<NodeItem> curr_item;
303 {
304 tensorflow::mutex_lock l(node_queue_mutex_);
305 while (node_queue_.empty() || !status_.ok()) {
306 if (state_ == ExecutorState::kShutDown) return;
307 nodes_pending_.wait(l);
308 }
309 // Obtain raw pointer since we don't want to remove from the queue until
310 // the node has been run. Otherwise, WaitForAllPendingNodes can return
311 // too early.
312 // Note, we don't std::move from the here because the front of the queue
313 // will then contain a nullptr. This can be a problem in
314 // WaitForAllPendingNodes where we get the top EagerNode pointer
315 // and register a notification for its completion.
316 curr_item.reset(node_queue_.front().get());
317 curr_item->Ref();
318 }
319 Status status = RunItem(std::move(curr_item), /*from_queue=*/true);
320 if (!status.ok()) {
321 VLOG(1) << "Failed to run item: " << status;
322 }
323 }
324 }
325
RunItem(core::RefCountPtr<NodeItem> item,bool from_queue)326 Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
327 bool from_queue) {
328 DVLOG(3) << "Running Node: [id " << item->id << "] "
329 << item->node->DebugString();
330 AsyncEagerNode* async_node = item->node->AsAsync();
331 if (async_node == nullptr) {
332 tensorflow::Status status = item->node->Run();
333 NodeDone(item, status, from_queue);
334 return status;
335 }
336
337 item->state = NodeState::kSCHEDULED;
338 auto async_ref = item.get();
339 async_ref->Ref();
340
341 TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
342
343 async_node->RunAsync([this, async_ref](const Status& status) {
344 core::RefCountPtr<NodeItem> async_item(async_ref);
345 NodeDone(async_item, status, false);
346 });
347
348 // Return the status of the executor in case we are in an error state.
349 return status();
350 }
351
MoveToUnfinished(core::RefCountPtr<NodeItem> item,bool from_queue)352 Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
353 bool from_queue) {
354 tensorflow::mutex_lock l(node_queue_mutex_);
355 if (!status_.ok()) {
356 return status_;
357 }
358
359 if (from_queue) {
360 DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
361 node_queue_.pop();
362 }
363
364 DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
365 unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
366 std::move(item));
367
368 return Status::OK();
369 }
370
371 } // namespace tensorflow
372