1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17
18 #include <forward_list>
19
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22 #include "tensorflow/core/util/env_var.h"
23
24 namespace tensorflow {
25 namespace {
IsAsyncWaitForRemoteFunctionEnabled()26 bool IsAsyncWaitForRemoteFunctionEnabled() {
27 bool enabled = true;
28 TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION",
29 true, &enabled));
30 return enabled;
31 }
32 } // namespace
33
EagerExecutor(bool async)34 EagerExecutor::EagerExecutor(bool async)
35 : next_node_id_(0),
36 ok_(true),
37 thread_(async ? tensorflow::Env::Default()->StartThread(
38 tensorflow::ThreadOptions(), "eager_async_executor",
39 std::bind(&EagerExecutor::Run, this))
40 : nullptr),
41 last_eager_client_(nullptr),
42 enable_async_wait_for_remote_function_(
43 IsAsyncWaitForRemoteFunctionEnabled()) {}
44
~EagerExecutor()45 EagerExecutor::~EagerExecutor() {
46 tensorflow::mutex_lock l(node_queue_mutex_);
47 state_ = ExecutorState::kShutDown;
48 nodes_pending_.notify_all();
49 for (const auto& cleanups_for_key : cleanups_) {
50 for (const std::function<void()>& cleanup : cleanups_for_key.second) {
51 cleanup();
52 }
53 }
54 }
55
ShutDown()56 Status EagerExecutor::ShutDown() {
57 {
58 bool has_thread;
59 Status status;
60 {
61 tensorflow::mutex_lock l(node_queue_mutex_);
62 if (state_ != ExecutorState::kShutDown) {
63 // if the state is kShutDown, we don't return here because we want to
64 // make sure the executor thread has ended (if there is one).
65 // So, we fall through to
66 // thread_exited_notification_.WaitForNotification() below.
67 state_ = ExecutorState::kShuttingDown;
68 }
69 // It is OK to ignore the returned status here because it will be saved
70 // as the final status_.
71 WaitForAllPendingNodesLocked(&l).IgnoreError();
72 state_ = ExecutorState::kShutDown;
73 has_thread = thread_ != nullptr;
74 status = status_;
75 if (has_thread) {
76 nodes_pending_.notify_all();
77 }
78 }
79 if (!has_thread) {
80 return status;
81 }
82 }
83
84 thread_exited_notification_.WaitForNotification();
85
86 return status();
87 }
88
StateStringLocked()89 const char* EagerExecutor::StateStringLocked() {
90 switch (state_) {
91 case ExecutorState::kActive:
92 return "Active";
93 case ExecutorState::kShuttingDown:
94 return "ShuttingDown";
95 case ExecutorState::kShutDown:
96 return "ShutDown";
97 }
98 }
99
SyncExecute(EagerNode * node)100 Status EagerExecutor::SyncExecute(EagerNode* node) {
101 if (Async()) {
102 return errors::Internal("Executor does not support async execution");
103 }
104 if (node->AsAsync() != nullptr) {
105 return errors::Internal("Executor does not support executing async nodes");
106 }
107 // NOTE: SyncExecute runs every node regardless of error status in executor.
108
109 uint64 id = next_node_id_++;
110
111 Status s = node->Prepare();
112 if (!s.ok()) {
113 return s;
114 }
115
116 // Inline execution in sync mode.
117 s = node->Run();
118 tensorflow::mutex_lock l(node_queue_mutex_);
119 NotifyWaiters(id);
120 return s;
121 }
122
AddOrExecute(std::unique_ptr<EagerNode> node)123 Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
124 Status status;
125 core::RefCountPtr<NodeItem> item(new NodeItem);
126 item->id = next_node_id_++;
127 item->node = std::move(node);
128 item->state = NodeState::kPENDING;
129
130 status = item->node->Prepare();
131 if (!status.ok()) {
132 item->node->Abort(status);
133 return status;
134 }
135
136 // Inline execution in sync mode.
137 if (!Async()) {
138 // In sync mode, run the node item regardless of executor status.
139 return RunItem(std::move(item), /*from_queue=*/false);
140 } else {
141 tensorflow::mutex_lock l(node_queue_mutex_);
142 DVLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
143 << " with status: " << status_.ToString();
144 if (state_ != ExecutorState::kActive) {
145 status = errors::FailedPrecondition(
146 "EagerExecutor accepts new EagerNodes to run only in Active state. "
147 "Current state is '",
148 StateStringLocked(), "'");
149 } else {
150 status = status_;
151 if (status.ok()) {
152 node_queue_.push(std::move(item));
153 // If there were no previous nodes pending, wake the run thread to
154 // start processing requests again.
155 if (node_queue_.size() == 1) {
156 nodes_pending_.notify_all();
157 }
158
159 return Status::OK();
160 }
161 }
162 }
163
164 // If we are unable to add the node to the queue, we must call Abort. However,
165 // we want to do that outside of the scope of the lock since the Abort may
166 // try to call EagerExecutor::AddOrExecute()
167 item->node->Abort(status);
168
169 return status;
170 }
171
WaitForAllPendingNodes()172 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
173 tensorflow::mutex_lock l(node_queue_mutex_);
174 return WaitForAllPendingNodesLocked(&l);
175 }
176
WaitForAllPendingNodesLocked(mutex_lock * lock)177 tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
178 mutex_lock* lock) {
179 tensorflow::condition_variable cond;
180 // Don't wait if an error is already set.
181 if (!status_.ok()) return status_;
182 if (node_queue_.empty() && unfinished_nodes_.empty())
183 return tensorflow::Status::OK();
184 // node_queue_ must be empty in sync mode.
185 DCHECK(Async() || node_queue_.empty());
186 auto last_id = next_node_id_ - 1;
187 DVLOG(3) << "Wait for Node: [id " << last_id << "] ";
188 node_done_notifications_.insert(std::make_pair(last_id, &cond));
189 cond.wait(*lock);
190 // Note that we could be woken up if an error occurs, even though the node has
191 // not actually executed.
192 return status_;
193 }
194
ClearError()195 void EagerExecutor::ClearError() {
196 // TODO(iga): Check state_ and return an error if it is not kActive.
197 if (ok()) return;
198
199 tensorflow::mutex_lock l(node_queue_mutex_);
200 // If an error was set, node_done_notifications_ and node_queue_ should have
201 // been cleared, and no new entries should have been added since.
202 DCHECK(node_done_notifications_.empty());
203 DCHECK(node_queue_.empty());
204 status_ = tensorflow::Status::OK();
205 ok_ = true;
206 last_eager_client_ = nullptr;
207 nodes_pending_.notify_all();
208 }
209
NodeDone(const core::RefCountPtr<NodeItem> & item,const Status & status,bool from_queue)210 void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
211 const Status& status, bool from_queue) {
212 DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
213 << " with status: " << status.ToString();
214 DCHECK(item->state != NodeState::kDONE);
215 item->state = NodeState::kDONE;
216
217 bool async = item->node->AsAsync() != nullptr;
218 // If executing synchronously we don't need to notify if status is OK since
219 // the node was never added to the unfinished_nodes_ list and nobody should
220 // ever be waiting for it.
221 if (status.ok() && !from_queue && !async) {
222 return;
223 }
224
225 std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
226 {
227 mutex_lock l(node_queue_mutex_);
228 if (!status_.ok()) return;
229
230 bool need_notification = from_queue;
231 if (from_queue) {
232 // Since this was from the async queue, pop it from the front of the queue
233 DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
234 node_queue_.pop();
235 } else if (async) {
236 // If it is an Async node then we will find the node in the unfinished
237 // nodes list. However we only notify if we are at the front of the list
238 // since we don't want to notify any waiters of earlier nodes.
239 need_notification = item->id == unfinished_nodes_.begin()->first;
240 // Remove item if it exists in unfinished_nodes_.
241 // With async execution, if two separate nodes failed and enter this
242 // callback, then the second node might not find itself in
243 // unfinished_nodes_ in the following senario:
244 // 1) Callback of the first failed node clears unfinished_nodes_
245 // 2) ClearError is called and executor status_ is set to OK
246 // 3) Callback of the second failed node is triggered
247 // In this case, do not taint the executor status or other note items
248 // because they are inserted after the ClearError.
249 auto result = unfinished_nodes_.erase(item->id);
250 if (result == 0) return;
251 }
252
253 if (!status.ok() && item->node->Fatal()) {
254 // Since we received an error, broadcast to any waiters.
255 need_notification = true;
256 status_ = status;
257 ok_ = false;
258 if (Async()) {
259 // We remove any pending ops so that we don't try to execute them if
260 // ClearError is called.
261 errors::AppendToMessage(&status_,
262 "Encountered when executing an operation using "
263 "EagerExecutor. This error cancels all future "
264 "operations and poisons their output tensors.");
265 }
266 while (!node_queue_.empty()) {
267 items_to_destroy.push_front(std::move(node_queue_.front()));
268 node_queue_.pop();
269 }
270 for (auto& it : unfinished_nodes_) {
271 items_to_destroy.push_front(std::move(it.second));
272 }
273 unfinished_nodes_.clear();
274 }
275 if (need_notification) {
276 NotifyWaiters(item->id);
277 }
278 }
279
280 for (auto& item : items_to_destroy) {
281 item->node->Abort(status);
282 }
283 // nodes_to_destroy will be destructed here, while not holding
284 // node_queue_mutex_. This is important because, unfortunately, some nodes'
285 // destructors can enqueue more operations onto this executor and cause
286 // a deadlock.
287 }
288
NotifyWaiters(uint64 id)289 void EagerExecutor::NotifyWaiters(uint64 id) {
290 if (!node_done_notifications_.empty()) {
291 uint64 upperbound_id = 0;
292 if (!unfinished_nodes_.empty()) {
293 upperbound_id = unfinished_nodes_.begin()->first - 1;
294 } else if (!node_queue_.empty()) {
295 upperbound_id = node_queue_.front()->id - 1;
296 } else {
297 upperbound_id = next_node_id_ - 1;
298 }
299 if (upperbound_id < id) {
300 return;
301 }
302 DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
303 << "] ";
304 // Note that we notify all waiting threads in case an error has
305 // occurred. These calling threads are responsible for checking status_
306 // before proceeding.
307 const auto range =
308 status_.ok()
309 ? make_pair(node_done_notifications_.lower_bound(id),
310 node_done_notifications_.upper_bound(upperbound_id))
311 : make_pair(node_done_notifications_.begin(),
312 node_done_notifications_.end());
313 for (auto it = range.first; it != range.second; ++it) {
314 it->second->notify_all();
315 }
316 node_done_notifications_.erase(range.first, range.second);
317 }
318 }
319
Run()320 void EagerExecutor::Run() {
321 auto thread_exited_notifier =
322 gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
323 while (true) {
324 core::RefCountPtr<NodeItem> curr_item;
325 {
326 tensorflow::mutex_lock l(node_queue_mutex_);
327 while (node_queue_.empty() || !status_.ok()) {
328 if (state_ == ExecutorState::kShutDown) return;
329 nodes_pending_.wait(l);
330 }
331 // Obtain raw pointer since we don't want to remove from the queue until
332 // the node has been run. Otherwise, WaitForAllPendingNodes can return
333 // too early.
334 // Note, we don't std::move from the here because the front of the queue
335 // will then contain a nullptr. This can be a problem in
336 // WaitForAllPendingNodes where we get the top EagerNode pointer
337 // and register a notification for its completion.
338 curr_item.reset(node_queue_.front().get());
339 curr_item->Ref();
340 }
341 Status status = RunItem(std::move(curr_item), /*from_queue=*/true);
342 if (!status.ok()) {
343 VLOG(1) << "Failed to run item: " << status;
344 }
345 }
346 }
347
RunItem(core::RefCountPtr<NodeItem> item,bool from_queue)348 Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
349 bool from_queue) {
350 DVLOG(3) << "Running Node: [id " << item->id << "] "
351 << item->node->DebugString();
352 AsyncRemoteExecuteNode* async_remote_node =
353 item->node->AsAsyncRemoteExecuteNode();
354 if (enable_async_wait_for_remote_function_) {
355 if (async_remote_node != nullptr) {
356 if (last_eager_client_ != nullptr &&
357 async_remote_node->eager_client() != nullptr &&
358 last_eager_client_ != async_remote_node->eager_client()) {
359 // Running a remote function, need to sync if the function is going to
360 // different device than last time we run remote distributed function.
361 DVLOG(3) << "Executing Sync Executor for node" << item->id;
362 tensorflow::Status status = async_remote_node->SyncExecutors();
363 if (!status.ok()) {
364 NodeDone(item, status, from_queue);
365 return status;
366 }
367 last_eager_client_ = nullptr;
368 }
369 if (async_remote_node->eager_client() != nullptr &&
370 async_remote_node->needs_remote_inputs() &&
371 async_remote_node->allow_multiple_pending_requests()) {
372 // We are running remote distributed function, update
373 // last_remote_device_name_.
374 last_eager_client_ = async_remote_node->eager_client();
375 }
376 }
377 }
378
379 AsyncEagerNode* async_node = item->node->AsAsync();
380 if (async_node == nullptr) {
381 tensorflow::Status status = item->node->Run();
382 NodeDone(item, status, from_queue);
383 return status;
384 }
385
386 item->state = NodeState::kSCHEDULED;
387 auto async_ref = item.get();
388 async_ref->Ref();
389
390 TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
391
392 async_node->RunAsync([this, async_ref](const Status& status) {
393 core::RefCountPtr<NodeItem> async_item(async_ref);
394 NodeDone(async_item, status, false);
395 });
396
397 // Return the status of the executor in case we are in an error state.
398 return status();
399 }
400
MoveToUnfinished(core::RefCountPtr<NodeItem> item,bool from_queue)401 Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
402 bool from_queue) {
403 tensorflow::mutex_lock l(node_queue_mutex_);
404 if (!status_.ok()) {
405 return status_;
406 }
407
408 if (from_queue) {
409 DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
410 node_queue_.pop();
411 }
412
413 DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
414 unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
415 std::move(item));
416
417 return Status::OK();
418 }
419
AddCleanup(intptr_t key,std::function<void ()> callback)420 void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
421 cleanups_[key].push_back(callback);
422 }
423
RemoveCleanups(intptr_t key)424 void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
425
426 } // namespace tensorflow
427