/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/eager_executor.h"

#include <forward_list>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace {
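// Returns the value of the TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION
// environment variable; defaults to true when the variable is unset.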
bool IsAsyncWaitForRemoteFunctionEnabled() {
  bool enabled = true;
  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION",
                                 true, &enabled));
  return enabled;
}
}  // namespace

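// Constructs the executor. In async mode a dedicated "eager_async_executor"
// thread is started to process queued nodes in Run(); in sync mode thread_
// stays null and nodes run inline on the calling thread.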
EagerExecutor::EagerExecutor(bool async)
    : next_node_id_(0),
      ok_(true),
      thread_(async ? tensorflow::Env::Default()->StartThread(
                          tensorflow::ThreadOptions(), "eager_async_executor",
                          std::bind(&EagerExecutor::Run, this))
                    : nullptr),
      last_eager_client_(nullptr),
      enable_async_wait_for_remote_function_(
          IsAsyncWaitForRemoteFunctionEnabled()) {}

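// Signals shutdown and runs all registered cleanup callbacks. The async
// thread, if any, is joined afterwards when the thread_ member is destroyed,
// since deleting a tensorflow::Thread blocks until its function returns.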
EagerExecutor::~EagerExecutor() {
  tensorflow::mutex_lock l(node_queue_mutex_);
  state_ = ExecutorState::kShutDown;
  nodes_pending_.notify_all();
  for (const auto& cleanups_for_key : cleanups_) {
    for (const std::function<void()>& cleanup : cleanups_for_key.second) {
      cleanup();
    }
  }
}

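// Transitions the executor to kShutDown, waiting for all pending nodes to
// finish and, in async mode, for the executor thread to exit. Safe to call
// more than once; repeated calls still wait for the thread to end.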
Status EagerExecutor::ShutDown() {
  {
    bool has_thread;
    Status status;
    {
      tensorflow::mutex_lock l(node_queue_mutex_);
      if (state_ != ExecutorState::kShutDown) {
        // If the state is already kShutDown, we don't return here because we
        // still want to make sure the executor thread has ended (if there is
        // one). So we fall through to
        // thread_exited_notification_.WaitForNotification() below.
        state_ = ExecutorState::kShuttingDown;
      }
      // It is OK to ignore the status returned here because it will be saved
      // as the final status_.
      WaitForAllPendingNodesLocked(&l).IgnoreError();
      state_ = ExecutorState::kShutDown;
      has_thread = thread_ != nullptr;
      status = status_;
      if (has_thread) {
        nodes_pending_.notify_all();
      }
    }
    if (!has_thread) {
      return status;
    }
  }

  thread_exited_notification_.WaitForNotification();

  return status();
}

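// Returns a human-readable name for the current state. As the "Locked"
// suffix indicates, callers are expected to hold node_queue_mutex_.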
const char* EagerExecutor::StateStringLocked() {
  switch (state_) {
    case ExecutorState::kActive:
      return "Active";
    case ExecutorState::kShuttingDown:
      return "ShuttingDown";
    case ExecutorState::kShutDown:
      return "ShutDown";
  }
}

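// Runs `node` inline on the calling thread. Only valid for a synchronous
// executor and a synchronous node. Example usage (a sketch; `MyNode` stands
// in for any concrete synchronous EagerNode subclass, nothing defined here):
//
//   EagerExecutor executor(/*async=*/false);
//   MyNode node;
//   Status s = executor.SyncExecute(&node);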
Status EagerExecutor::SyncExecute(EagerNode* node) {
  if (Async()) {
    return errors::Internal("Executor does not support async execution");
  }
  if (node->AsAsync() != nullptr) {
    return errors::Internal("Executor does not support executing async nodes");
  }
  // NOTE: SyncExecute runs every node regardless of the executor's error
  // status.

  uint64 id = next_node_id_++;

  Status s = node->Prepare();
  if (!s.ok()) {
    return s;
  }

  // Inline execution in sync mode.
  s = node->Run();
  tensorflow::mutex_lock l(node_queue_mutex_);
  if (!s.ok()) {
    status_ = s;
    ok_ = false;
  }
  NotifyWaiters(id);
  return s;
}

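// Runs `node` inline in sync mode, or enqueues it for the executor thread in
// async mode. In async mode the caller typically waits for completion later,
// e.g. (a sketch; `MyNode` is a hypothetical EagerNode subclass):
//
//   EagerExecutor executor(/*async=*/true);
//   TF_RETURN_IF_ERROR(executor.AddOrExecute(std::make_unique<MyNode>()));
//   TF_RETURN_IF_ERROR(executor.WaitForAllPendingNodes());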
Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
  Status status;
  core::RefCountPtr<NodeItem> item(new NodeItem);
  item->id = next_node_id_++;
  item->node = std::move(node);
  item->state = NodeState::kPENDING;

  status = item->node->Prepare();
  if (!status.ok()) {
    item->node->Abort(status);
    return status;
  }

  // Inline execution in sync mode.
  if (!Async()) {
    // In sync mode, run the node item regardless of executor status.
    return RunItem(std::move(item), /*from_queue=*/false);
  } else {
    tensorflow::mutex_lock l(node_queue_mutex_);
    DVLOG(3) << "Add node [id " << item->id << "] "
             << item->node->DebugString()
             << " with status: " << status_.ToString();
    if (state_ != ExecutorState::kActive) {
      status = errors::FailedPrecondition(
          "EagerExecutor accepts new EagerNodes to run only in Active state. "
          "Current state is '",
          StateStringLocked(), "'");
    } else {
      status = status_;
      if (status.ok()) {
        node_queue_.push(std::move(item));
        // If there were no previous nodes pending, wake the run thread to
        // start processing requests again.
        if (node_queue_.size() == 1) {
          nodes_pending_.notify_all();
        }

        return Status::OK();
      }
    }
  }

  // If we are unable to add the node to the queue, we must call Abort.
  // However, we want to do that outside the scope of the lock, since Abort
  // may try to call back into EagerExecutor::AddOrExecute().
  item->node->Abort(status);

  return status;
}

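// Blocks until every node issued so far has completed, or returns early with
// the executor's error status.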
tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
  tensorflow::mutex_lock l(node_queue_mutex_);
  return WaitForAllPendingNodesLocked(&l);
}

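// Registers a condition variable keyed by the most recently issued node id
// and waits for NotifyWaiters to signal it, either because that node has
// completed or because the executor entered an error state.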
tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
    mutex_lock* lock) {
  tensorflow::condition_variable cond;
  // Don't wait if an error is already set.
  if (!status_.ok()) return status_;
  if (node_queue_.empty() && unfinished_nodes_.empty())
    return tensorflow::Status::OK();
  // node_queue_ must be empty in sync mode.
  DCHECK(Async() || node_queue_.empty());
  auto last_id = next_node_id_ - 1;
  DVLOG(3) << "Wait for Node: [id " << last_id << "] ";
  node_done_notifications_.insert(std::make_pair(last_id, &cond));
  cond.wait(*lock);
  // Note that we could be woken up if an error occurs, even though the node
  // has not actually executed.
  return status_;
}

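// Resets a poisoned executor back to an OK state so that it can accept new
// nodes. By the time an error has propagated, NodeDone has already drained
// node_queue_ and node_done_notifications_, which the DCHECKs below verify.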
void EagerExecutor::ClearError() {
  // TODO(iga): Check state_ and return an error if it is not kActive.
  if (ok()) return;

  tensorflow::mutex_lock l(node_queue_mutex_);
  // If an error was set, node_done_notifications_ and node_queue_ should have
  // been cleared, and no new entries should have been added since.
  DCHECK(node_done_notifications_.empty());
  DCHECK(node_queue_.empty());
  status_ = tensorflow::Status::OK();
  ok_ = true;
  last_eager_client_ = nullptr;
  nodes_pending_.notify_all();
}

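// Marks `item` as done with the given `status`. On a fatal error this poisons
// the executor: status_ is set, all queued and unfinished nodes are collected
// for abortion, and every registered waiter is woken. The aborts run after
// node_queue_mutex_ is released because a node's Abort or destructor may
// re-enter this executor.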
void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
                             const Status& status, bool from_queue) {
  DVLOG(3) << "Node Done: [id " << item->id << "] "
           << item->node->DebugString()
           << " with status: " << status.ToString();
  DCHECK(item->state != NodeState::kDONE);
  item->state = NodeState::kDONE;

  bool async = item->node->AsAsync() != nullptr;
  // If executing synchronously, we don't need to notify if the status is OK
  // since the node was never added to the unfinished_nodes_ list and nobody
  // should ever be waiting for it.
  if (status.ok() && !from_queue && !async) {
    return;
  }

  std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
  {
    mutex_lock l(node_queue_mutex_);
    if (!status_.ok()) return;

    bool need_notification = from_queue;
    if (from_queue) {
      // Since this came from the async queue, pop it from the front of the
      // queue.
      DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
      node_queue_.pop();
    } else if (async) {
      // If it is an async node, then we will find the node in the unfinished
      // nodes list. However, we only notify if we are at the front of the
      // list, since we don't want to notify any waiters of earlier nodes.
      need_notification = item->id == unfinished_nodes_.begin()->first;
      // Remove the item if it exists in unfinished_nodes_.
      // With async execution, if two separate nodes failed and enter this
      // callback, then the second node might not find itself in
      // unfinished_nodes_ in the following scenario:
      //   1) Callback of the first failed node clears unfinished_nodes_
      //   2) ClearError is called and executor status_ is set to OK
      //   3) Callback of the second failed node is triggered
      // In this case, do not taint the executor status or other node items
      // because they were inserted after the ClearError.
      auto result = unfinished_nodes_.erase(item->id);
      if (result == 0) return;
    }

    if (!status.ok() && item->node->Fatal()) {
      // Since we received an error, broadcast to any waiters.
      need_notification = true;
      status_ = status;
      ok_ = false;
      if (Async()) {
        // We remove any pending ops so that we don't try to execute them if
        // ClearError is called.
        errors::AppendToMessage(&status_,
                                "Encountered when executing an operation using "
                                "EagerExecutor. This error cancels all future "
                                "operations and poisons their output tensors.");
      }
      while (!node_queue_.empty()) {
        items_to_destroy.push_front(std::move(node_queue_.front()));
        node_queue_.pop();
      }
      for (auto& it : unfinished_nodes_) {
        items_to_destroy.push_front(std::move(it.second));
      }
      unfinished_nodes_.clear();
    }
    if (need_notification) {
      NotifyWaiters(item->id);
    }
  }

  for (auto& item : items_to_destroy) {
    item->node->Abort(status);
  }
  // items_to_destroy will be destructed here, while not holding
  // node_queue_mutex_. This is important because, unfortunately, some nodes'
  // destructors can enqueue more operations onto this executor and cause
  // a deadlock.
}

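// Wakes threads in node_done_notifications_ that are waiting on node `id` or
// on any node up to the last contiguously completed id. That upper bound is
// the id just before the earliest node still pending or unfinished. For
// example, if node 5 just finished and the earliest entry in
// unfinished_nodes_ has id 8, waiters registered for ids 5..7 are woken. On
// error, all waiters are woken regardless of id.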
void EagerExecutor::NotifyWaiters(uint64 id) {
  if (!node_done_notifications_.empty()) {
    uint64 upperbound_id = 0;
    if (!unfinished_nodes_.empty()) {
      upperbound_id = unfinished_nodes_.begin()->first - 1;
    } else if (!node_queue_.empty()) {
      upperbound_id = node_queue_.front()->id - 1;
    } else {
      upperbound_id = next_node_id_ - 1;
    }
    if (upperbound_id < id) {
      return;
    }
    DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
             << "] ";
    // Note that we notify all waiting threads in case an error has
    // occurred. These calling threads are responsible for checking status_
    // before proceeding.
    const auto range =
        status_.ok()
            ? make_pair(node_done_notifications_.lower_bound(id),
                        node_done_notifications_.upper_bound(upperbound_id))
            : make_pair(node_done_notifications_.begin(),
                        node_done_notifications_.end());
    for (auto it = range.first; it != range.second; ++it) {
      it->second->notify_all();
    }
    node_done_notifications_.erase(range.first, range.second);
  }
}

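// Main loop of the async executor thread. Sleeps until a node is queued and
// the executor status is OK (on error, pending work is drained by NodeDone
// and new work is rejected until ClearError), and exits once shutdown is
// requested.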
void EagerExecutor::Run() {
  auto thread_exited_notifier =
      gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
  while (true) {
    core::RefCountPtr<NodeItem> curr_item;
    {
      tensorflow::mutex_lock l(node_queue_mutex_);
      while (node_queue_.empty() || !status_.ok()) {
        if (state_ == ExecutorState::kShutDown) return;
        nodes_pending_.wait(l);
      }
      // Obtain a raw pointer since we don't want to remove the item from the
      // queue until the node has been run. Otherwise, WaitForAllPendingNodes
      // can return too early.
      // Note, we don't std::move from here because the front of the queue
      // would then contain a nullptr. This can be a problem in
      // WaitForAllPendingNodes, where we get the top EagerNode pointer
      // and register a notification for its completion.
      curr_item.reset(node_queue_.front().get());
      curr_item->Ref();
    }
    Status status = RunItem(std::move(curr_item), /*from_queue=*/true);
    if (!status.ok()) {
      VLOG(1) << "Failed to run item: " << status;
    }
  }
}

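// Executes a single node item. For async remote-execute nodes this may first
// sync executors when the target eager client has changed since the last
// remote call. Sync nodes run inline; async nodes are moved into
// unfinished_nodes_ and report completion through the RunAsync callback.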
Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
                              bool from_queue) {
  DVLOG(3) << "Running Node: [id " << item->id << "] "
           << item->node->DebugString();
  AsyncRemoteExecuteNode* async_remote_node =
      item->node->AsAsyncRemoteExecuteNode();
  if (enable_async_wait_for_remote_function_) {
    if (async_remote_node != nullptr) {
      if (last_eager_client_ != nullptr &&
          async_remote_node->eager_client() != nullptr &&
          last_eager_client_ != async_remote_node->eager_client()) {
        // We are running a remote function and need to sync, because this
        // function targets a different device than the last remote
        // distributed function we ran.
        DVLOG(3) << "Executing Sync Executor for node " << item->id;
        tensorflow::Status status = async_remote_node->SyncExecutors();
        if (!status.ok()) {
          NodeDone(item, status, from_queue);
          return status;
        }
        last_eager_client_ = nullptr;
      }
      if (async_remote_node->eager_client() != nullptr &&
          async_remote_node->needs_remote_inputs() &&
          async_remote_node->allow_multiple_pending_requests()) {
        // We are running a remote distributed function; remember its client
        // in last_eager_client_.
        last_eager_client_ = async_remote_node->eager_client();
      }
    }
  }

  AsyncEagerNode* async_node = item->node->AsAsync();
  if (async_node == nullptr) {
    tensorflow::Status status = item->node->Run();
    NodeDone(item, status, from_queue);
    return status;
  }

  item->state = NodeState::kSCHEDULED;
  auto async_ref = item.get();
  async_ref->Ref();

  TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));

  async_node->RunAsync([this, async_ref](const Status& status) {
    core::RefCountPtr<NodeItem> async_item(async_ref);
    NodeDone(async_item, status, false);
  });

  // Return the status of the executor in case we are in an error state.
  return status();
}

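// Moves an async node item out of node_queue_ and into unfinished_nodes_,
// keyed by id, so NodeDone and WaitForAllPendingNodes can find it later.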
Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
                                       bool from_queue) {
  tensorflow::mutex_lock l(node_queue_mutex_);
  if (!status_.ok()) {
    return status_;
  }

  if (from_queue) {
    DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
    node_queue_.pop();
  }

  DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
  unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
                                 std::move(item));

  return Status::OK();
}

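// Registers a cleanup callback, grouped under `key`, to run when the executor
// is destroyed; RemoveCleanups(key) drops every callback registered under
// that key.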
void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
  cleanups_[key].push_back(callback);
}

void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }

}  // namespace tensorflow