1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17
18 #include <forward_list>
19
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22 #include "tensorflow/core/util/env_var.h"
23
24 namespace tensorflow {
25 namespace {
IsAsyncWaitForRemoteFunctionEnabled()26 bool IsAsyncWaitForRemoteFunctionEnabled() {
27 bool enabled = true;
28 TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION",
29 true, &enabled));
30 return enabled;
31 }
32 } // namespace
33
EagerExecutor(bool async,bool enable_streaming_enqueue)34 EagerExecutor::EagerExecutor(bool async, bool enable_streaming_enqueue)
35 : next_node_id_(0),
36 ok_(true),
37 thread_(async ? tensorflow::Env::Default()->StartThread(
38 tensorflow::ThreadOptions(), "eager_async_executor",
39 std::bind(&EagerExecutor::Run, this))
40 : nullptr),
41 last_eager_client_(nullptr),
42 enable_async_wait_for_remote_function_(
43 IsAsyncWaitForRemoteFunctionEnabled()),
44 enable_streaming_enqueue_(enable_streaming_enqueue) {}
45
~EagerExecutor()46 EagerExecutor::~EagerExecutor() {
47 tensorflow::mutex_lock l(node_queue_mutex_);
48 state_ = ExecutorState::kShutDown;
49 nodes_pending_.notify_all();
50 for (const auto& cleanups_for_key : cleanups_) {
51 for (const std::function<void()>& cleanup : cleanups_for_key.second) {
52 cleanup();
53 }
54 }
55 }
56
ShutDown()57 Status EagerExecutor::ShutDown() {
58 {
59 bool has_thread;
60 Status status;
61 {
62 tensorflow::mutex_lock l(node_queue_mutex_);
63 if (state_ != ExecutorState::kShutDown) {
64 // if the state is kShutDown, we don't return here because we want to
65 // make sure the executor thread has ended (if there is one).
66 // So, we fall through to
67 // thread_exited_notification_.WaitForNotification() below.
68 state_ = ExecutorState::kShuttingDown;
69 }
70 // It is OK to ignore the returned status here because it will be saved
71 // as the final status_.
72 WaitForAllPendingNodesLocked(&l).IgnoreError();
73 state_ = ExecutorState::kShutDown;
74 has_thread = thread_ != nullptr;
75 status = status_;
76 if (has_thread) {
77 nodes_pending_.notify_all();
78 }
79 }
80 if (!has_thread) {
81 return status;
82 }
83 }
84
85 thread_exited_notification_.WaitForNotification();
86
87 return status();
88 }
89
StateStringLocked()90 const char* EagerExecutor::StateStringLocked() {
91 switch (state_) {
92 case ExecutorState::kActive:
93 return "Active";
94 case ExecutorState::kShuttingDown:
95 return "ShuttingDown";
96 case ExecutorState::kShutDown:
97 return "ShutDown";
98 }
99 }
100
SyncExecute(EagerNode * node)101 Status EagerExecutor::SyncExecute(EagerNode* node) {
102 if (Async()) {
103 return errors::Internal("Executor does not support async execution");
104 }
105 if (node->AsAsync() != nullptr) {
106 return errors::Internal("Executor does not support executing async nodes");
107 }
108 // NOTE: SyncExecute runs every node regardless of error status in executor.
109
110 uint64 id = next_node_id_++;
111
112 Status s = node->Prepare();
113 if (!s.ok()) {
114 return s;
115 }
116
117 // Inline execution in sync mode.
118 s = node->Run();
119 tensorflow::mutex_lock l(node_queue_mutex_);
120 NotifyWaiters(id);
121 return s;
122 }
123
AddOrExecute(std::unique_ptr<EagerNode> node)124 Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
125 Status status;
126 core::RefCountPtr<NodeItem> item(new NodeItem);
127 item->id = next_node_id_++;
128 item->node = std::move(node);
129 item->state = NodeState::kPENDING;
130
131 status = item->node->Prepare();
132 if (!status.ok()) {
133 item->node->Abort(status);
134 return status;
135 }
136
137 // Inline execution in sync mode.
138 if (!Async()) {
139 // In sync mode, run the node item regardless of executor status.
140 return RunItem(std::move(item), /*from_queue=*/false);
141 } else {
142 tensorflow::mutex_lock l(node_queue_mutex_);
143 DVLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
144 << " with status: " << status_.ToString();
145 if (state_ != ExecutorState::kActive) {
146 status = errors::FailedPrecondition(
147 "EagerExecutor accepts new EagerNodes to run only in Active state. "
148 "Current state is '",
149 StateStringLocked(), "'");
150 } else {
151 status = status_;
152 if (status.ok()) {
153 node_queue_.push(std::move(item));
154 // If there were no previous nodes pending, wake the run thread to
155 // start processing requests again.
156 if (node_queue_.size() == 1) {
157 nodes_pending_.notify_all();
158 }
159
160 return OkStatus();
161 }
162 }
163 }
164
165 // If we are unable to add the node to the queue, we must call Abort. However,
166 // we want to do that outside of the scope of the lock since the Abort may
167 // try to call EagerExecutor::AddOrExecute()
168 item->node->Abort(status);
169
170 return status;
171 }
172
WaitForAllPendingNodes()173 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
174 tensorflow::mutex_lock l(node_queue_mutex_);
175 return WaitForAllPendingNodesLocked(&l);
176 }
177
WaitForAllPendingNodesLocked(mutex_lock * lock)178 tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
179 mutex_lock* lock) {
180 tensorflow::condition_variable cond;
181 // Don't wait if an error is already set.
182 if (!status_.ok()) return status_;
183 if (node_queue_.empty() && unfinished_nodes_.empty()) return OkStatus();
184 // node_queue_ must be empty in sync mode.
185 DCHECK(Async() || node_queue_.empty());
186 auto last_id = next_node_id_ - 1;
187 DVLOG(3) << "Wait for Node: [id " << last_id << "] ";
188 node_done_notifications_.insert(std::make_pair(last_id, &cond));
189 cond.wait(*lock);
190 // Note that we could be woken up if an error occurs, even though the node has
191 // not actually executed.
192 return status_;
193 }
194
ClearError()195 void EagerExecutor::ClearError() {
196 // TODO(iga): Check state_ and return an error if it is not kActive.
197 if (ok()) return;
198
199 tensorflow::mutex_lock l(node_queue_mutex_);
200 // If an error was set, node_done_notifications_ and node_queue_ should have
201 // been cleared, and no new entries should have been added since.
202 DCHECK(node_done_notifications_.empty());
203 DCHECK(node_queue_.empty());
204 status_ = OkStatus();
205 ok_ = true;
206 last_eager_client_ = nullptr;
207 nodes_pending_.notify_all();
208 }
209
NodeDone(const core::RefCountPtr<NodeItem> & item,const Status & status,bool from_queue)210 void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
211 const Status& status, bool from_queue) {
212 DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
213 << " with status: " << status.ToString();
214 DCHECK(item->state != NodeState::kDONE);
215 item->state = NodeState::kDONE;
216
217 bool async = item->node->AsAsync() != nullptr;
218 // If executing synchronously we don't need to notify if status is OK since
219 // the node was never added to the unfinished_nodes_ list and nobody should
220 // ever be waiting for it.
221 if (status.ok() && !from_queue && !async) {
222 return;
223 }
224
225 std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
226 {
227 mutex_lock l(node_queue_mutex_);
228 if (!status_.ok()) return;
229
230 bool need_notification = from_queue;
231 if (from_queue) {
232 // Since this was from the async queue, pop it from the front of the queue
233 DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
234 node_queue_.pop();
235 } else if (async) {
236 // If it is an Async node then we will find the node in the unfinished
237 // nodes list. However we only notify if we are at the front of the list
238 // since we don't want to notify any waiters of earlier nodes.
239 need_notification = item->id == unfinished_nodes_.begin()->first;
240 // Remove item if it exists in unfinished_nodes_.
241 // With async execution, if two separate nodes failed and enter this
242 // callback, then the second node might not find itself in
243 // unfinished_nodes_ in the following senario:
244 // 1) Callback of the first failed node clears unfinished_nodes_
245 // 2) ClearError is called and executor status_ is set to OK
246 // 3) Callback of the second failed node is triggered
247 // In this case, do not taint the executor status or other note items
248 // because they are inserted after the ClearError.
249 auto result = unfinished_nodes_.erase(item->id);
250 if (result == 0) return;
251 }
252
253 if (!status.ok() && item->node->Fatal()) {
254 // Since we received an error, broadcast to any waiters.
255 need_notification = true;
256 status_ = status;
257 ok_ = false;
258 if (Async()) {
259 // We remove any pending ops so that we don't try to execute them if
260 // ClearError is called.
261 errors::AppendToMessage(&status_,
262 "Encountered when executing an operation using "
263 "EagerExecutor. This error cancels all future "
264 "operations and poisons their output tensors.");
265 }
266 while (!node_queue_.empty()) {
267 items_to_destroy.push_front(std::move(node_queue_.front()));
268 node_queue_.pop();
269 }
270 for (auto& it : unfinished_nodes_) {
271 items_to_destroy.push_front(std::move(it.second));
272 }
273 unfinished_nodes_.clear();
274 }
275 if (need_notification) {
276 NotifyWaiters(item->id);
277 }
278 }
279
280 for (auto& item : items_to_destroy) {
281 item->node->Abort(status);
282 }
283 // nodes_to_destroy will be destructed here, while not holding
284 // node_queue_mutex_. This is important because, unfortunately, some nodes'
285 // destructors can enqueue more operations onto this executor and cause
286 // a deadlock.
287 }
288
NotifyWaiters(uint64 id)289 void EagerExecutor::NotifyWaiters(uint64 id) {
290 if (!node_done_notifications_.empty()) {
291 uint64 upperbound_id = 0;
292 if (!unfinished_nodes_.empty()) {
293 upperbound_id = unfinished_nodes_.begin()->first - 1;
294 } else if (!node_queue_.empty()) {
295 upperbound_id = node_queue_.front()->id - 1;
296 } else {
297 upperbound_id = next_node_id_ - 1;
298 }
299 if (upperbound_id < id) {
300 return;
301 }
302 DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
303 << "] ";
304 // Note that we notify all waiting threads in case an error has
305 // occurred. These calling threads are responsible for checking status_
306 // before proceeding.
307 const auto range =
308 status_.ok()
309 ? make_pair(node_done_notifications_.lower_bound(id),
310 node_done_notifications_.upper_bound(upperbound_id))
311 : make_pair(node_done_notifications_.begin(),
312 node_done_notifications_.end());
313 for (auto it = range.first; it != range.second; ++it) {
314 it->second->notify_all();
315 }
316 node_done_notifications_.erase(range.first, range.second);
317 }
318 }
319
Run()320 void EagerExecutor::Run() {
321 auto thread_exited_notifier =
322 gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
323 while (true) {
324 core::RefCountPtr<NodeItem> curr_item;
325 {
326 tensorflow::mutex_lock l(node_queue_mutex_);
327 while (node_queue_.empty() || !status_.ok()) {
328 if (state_ == ExecutorState::kShutDown) return;
329 nodes_pending_.wait(l);
330 }
331 // Obtain raw pointer since we don't want to remove from the queue until
332 // the node has been run. Otherwise, WaitForAllPendingNodes can return
333 // too early.
334 // Note, we don't std::move from the here because the front of the queue
335 // will then contain a nullptr. This can be a problem in
336 // WaitForAllPendingNodes where we get the top EagerNode pointer
337 // and register a notification for its completion.
338 curr_item.reset(node_queue_.front().get());
339 curr_item->Ref();
340 }
341 Status status = RunItem(std::move(curr_item), /*from_queue=*/true);
342 if (!status.ok()) {
343 VLOG(1) << "Failed to run item: " << status;
344 }
345 }
346 }
347
RunItem(core::RefCountPtr<NodeItem> item,bool from_queue)348 Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
349 bool from_queue) {
350 DVLOG(3) << "Running Node: [id " << item->id << "] "
351 << item->node->DebugString();
352 AsyncRemoteExecuteNode* async_remote_node =
353 item->node->AsAsyncRemoteExecuteNode();
354 if (enable_async_wait_for_remote_function_) {
355 if (async_remote_node != nullptr) {
356 if (last_eager_client_ != nullptr &&
357 async_remote_node->eager_client() != nullptr &&
358 last_eager_client_ != async_remote_node->eager_client()) {
359 // Running a remote function, need to sync if the function is going to
360 // different device than last time we run remote distributed function.
361 DVLOG(3) << "Executing Sync Executor for node" << item->id;
362 tensorflow::Status status = async_remote_node->SyncExecutors();
363 if (!status.ok()) {
364 NodeDone(item, status, from_queue);
365 return status;
366 }
367 last_eager_client_ = nullptr;
368 }
369 if (async_remote_node->eager_client() != nullptr &&
370 async_remote_node->needs_remote_inputs() &&
371 async_remote_node->allow_multiple_pending_requests()) {
372 // We are running remote distributed function, update
373 // last_remote_device_name_.
374 last_eager_client_ = async_remote_node->eager_client();
375 }
376 }
377 }
378
379 AsyncEagerNode* async_node = item->node->AsAsync();
380 if (async_node == nullptr) {
381 tensorflow::Status status = item->node->Run();
382 NodeDone(item, status, from_queue);
383 return status;
384 }
385
386 item->state = NodeState::kSCHEDULED;
387 auto async_ref = item.get();
388 async_ref->Ref();
389
390 TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
391
392 async_node->RunAsync([this, async_ref](const Status& status) {
393 core::RefCountPtr<NodeItem> async_item(async_ref);
394 NodeDone(async_item, status, false);
395 });
396
397 // Return the status of the executor in case we are in an error state.
398 return status();
399 }
400
MoveToUnfinished(core::RefCountPtr<NodeItem> item,bool from_queue)401 Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
402 bool from_queue) {
403 tensorflow::mutex_lock l(node_queue_mutex_);
404 if (!status_.ok()) {
405 return status_;
406 }
407
408 if (from_queue) {
409 DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
410 node_queue_.pop();
411 }
412
413 DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
414 unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
415 std::move(item));
416
417 return OkStatus();
418 }
419
AddCleanup(intptr_t key,std::function<void ()> callback)420 void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
421 cleanups_[key].push_back(callback);
422 }
423
RemoveCleanups(intptr_t key)424 void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
425
426 } // namespace tensorflow
427