• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17 
18 #include <forward_list>
19 
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22 #include "tensorflow/core/util/env_var.h"
23 
24 namespace tensorflow {
25 namespace {
IsAsyncWaitForRemoteFunctionEnabled()26 bool IsAsyncWaitForRemoteFunctionEnabled() {
27   bool enabled = true;
28   TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION",
29                                  true, &enabled));
30   return enabled;
31 }
32 }  // namespace
33 
EagerExecutor(bool async,bool enable_streaming_enqueue)34 EagerExecutor::EagerExecutor(bool async, bool enable_streaming_enqueue)
35     : next_node_id_(0),
36       ok_(true),
37       thread_(async ? tensorflow::Env::Default()->StartThread(
38                           tensorflow::ThreadOptions(), "eager_async_executor",
39                           std::bind(&EagerExecutor::Run, this))
40                     : nullptr),
41       last_eager_client_(nullptr),
42       enable_async_wait_for_remote_function_(
43           IsAsyncWaitForRemoteFunctionEnabled()),
44       enable_streaming_enqueue_(enable_streaming_enqueue) {}
45 
~EagerExecutor()46 EagerExecutor::~EagerExecutor() {
47   tensorflow::mutex_lock l(node_queue_mutex_);
48   state_ = ExecutorState::kShutDown;
49   nodes_pending_.notify_all();
50   for (const auto& cleanups_for_key : cleanups_) {
51     for (const std::function<void()>& cleanup : cleanups_for_key.second) {
52       cleanup();
53     }
54   }
55 }
56 
ShutDown()57 Status EagerExecutor::ShutDown() {
58   {
59     bool has_thread;
60     Status status;
61     {
62       tensorflow::mutex_lock l(node_queue_mutex_);
63       if (state_ != ExecutorState::kShutDown) {
64         // if the state is kShutDown, we don't return here because we want to
65         // make sure the executor thread has ended (if there is one).
66         // So, we fall through to
67         // thread_exited_notification_.WaitForNotification() below.
68         state_ = ExecutorState::kShuttingDown;
69       }
70       // It is OK to ignore the returned status here because it will be saved
71       // as the final status_.
72       WaitForAllPendingNodesLocked(&l).IgnoreError();
73       state_ = ExecutorState::kShutDown;
74       has_thread = thread_ != nullptr;
75       status = status_;
76       if (has_thread) {
77         nodes_pending_.notify_all();
78       }
79     }
80     if (!has_thread) {
81       return status;
82     }
83   }
84 
85   thread_exited_notification_.WaitForNotification();
86 
87   return status();
88 }
89 
StateStringLocked()90 const char* EagerExecutor::StateStringLocked() {
91   switch (state_) {
92     case ExecutorState::kActive:
93       return "Active";
94     case ExecutorState::kShuttingDown:
95       return "ShuttingDown";
96     case ExecutorState::kShutDown:
97       return "ShutDown";
98   }
99 }
100 
SyncExecute(EagerNode * node)101 Status EagerExecutor::SyncExecute(EagerNode* node) {
102   if (Async()) {
103     return errors::Internal("Executor does not support async execution");
104   }
105   if (node->AsAsync() != nullptr) {
106     return errors::Internal("Executor does not support executing async nodes");
107   }
108   // NOTE: SyncExecute runs every node regardless of error status in executor.
109 
110   uint64 id = next_node_id_++;
111 
112   Status s = node->Prepare();
113   if (!s.ok()) {
114     return s;
115   }
116 
117   // Inline execution in sync mode.
118   s = node->Run();
119   tensorflow::mutex_lock l(node_queue_mutex_);
120   NotifyWaiters(id);
121   return s;
122 }
123 
AddOrExecute(std::unique_ptr<EagerNode> node)124 Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
125   Status status;
126   core::RefCountPtr<NodeItem> item(new NodeItem);
127   item->id = next_node_id_++;
128   item->node = std::move(node);
129   item->state = NodeState::kPENDING;
130 
131   status = item->node->Prepare();
132   if (!status.ok()) {
133     item->node->Abort(status);
134     return status;
135   }
136 
137   // Inline execution in sync mode.
138   if (!Async()) {
139     // In sync mode, run the node item regardless of executor status.
140     return RunItem(std::move(item), /*from_queue=*/false);
141   } else {
142     tensorflow::mutex_lock l(node_queue_mutex_);
143     DVLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
144              << " with status: " << status_.ToString();
145     if (state_ != ExecutorState::kActive) {
146       status = errors::FailedPrecondition(
147           "EagerExecutor accepts new EagerNodes to run only in Active state. "
148           "Current state is '",
149           StateStringLocked(), "'");
150     } else {
151       status = status_;
152       if (status.ok()) {
153         node_queue_.push(std::move(item));
154         // If there were no previous nodes pending, wake the run thread to
155         // start processing requests again.
156         if (node_queue_.size() == 1) {
157           nodes_pending_.notify_all();
158         }
159 
160         return OkStatus();
161       }
162     }
163   }
164 
165   // If we are unable to add the node to the queue, we must call Abort. However,
166   // we want to do that outside of the scope of the lock since the Abort may
167   // try to call EagerExecutor::AddOrExecute()
168   item->node->Abort(status);
169 
170   return status;
171 }
172 
WaitForAllPendingNodes()173 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
174   tensorflow::mutex_lock l(node_queue_mutex_);
175   return WaitForAllPendingNodesLocked(&l);
176 }
177 
WaitForAllPendingNodesLocked(mutex_lock * lock)178 tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
179     mutex_lock* lock) {
180   tensorflow::condition_variable cond;
181   // Don't wait if an error is already set.
182   if (!status_.ok()) return status_;
183   if (node_queue_.empty() && unfinished_nodes_.empty()) return OkStatus();
184   // node_queue_ must be empty in sync mode.
185   DCHECK(Async() || node_queue_.empty());
186   auto last_id = next_node_id_ - 1;
187   DVLOG(3) << "Wait for Node: [id " << last_id << "] ";
188   node_done_notifications_.insert(std::make_pair(last_id, &cond));
189   cond.wait(*lock);
190   // Note that we could be woken up if an error occurs, even though the node has
191   // not actually executed.
192   return status_;
193 }
194 
ClearError()195 void EagerExecutor::ClearError() {
196   // TODO(iga): Check state_ and return an error if it is not kActive.
197   if (ok()) return;
198 
199   tensorflow::mutex_lock l(node_queue_mutex_);
200   // If an error was set, node_done_notifications_ and node_queue_ should have
201   // been cleared, and no new entries should have been added since.
202   DCHECK(node_done_notifications_.empty());
203   DCHECK(node_queue_.empty());
204   status_ = OkStatus();
205   ok_ = true;
206   last_eager_client_ = nullptr;
207   nodes_pending_.notify_all();
208 }
209 
NodeDone(const core::RefCountPtr<NodeItem> & item,const Status & status,bool from_queue)210 void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
211                              const Status& status, bool from_queue) {
212   DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
213            << " with status: " << status.ToString();
214   DCHECK(item->state != NodeState::kDONE);
215   item->state = NodeState::kDONE;
216 
217   bool async = item->node->AsAsync() != nullptr;
218   // If executing synchronously we don't need to notify if status is OK since
219   // the node  was never added to the unfinished_nodes_ list and nobody should
220   // ever be waiting for it.
221   if (status.ok() && !from_queue && !async) {
222     return;
223   }
224 
225   std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
226   {
227     mutex_lock l(node_queue_mutex_);
228     if (!status_.ok()) return;
229 
230     bool need_notification = from_queue;
231     if (from_queue) {
232       // Since this was from the async queue, pop it from the front of the queue
233       DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
234       node_queue_.pop();
235     } else if (async) {
236       // If it is an Async node then we will find the node in the unfinished
237       // nodes list. However we only notify if we are at the front of the list
238       // since we don't want to notify any waiters of earlier nodes.
239       need_notification = item->id == unfinished_nodes_.begin()->first;
240       // Remove item if it exists in unfinished_nodes_.
241       // With async execution, if two separate nodes failed and enter this
242       // callback, then the second node might not find itself in
243       // unfinished_nodes_ in the following senario:
244       //   1) Callback of the first failed node clears unfinished_nodes_
245       //   2) ClearError is called and executor status_ is set to OK
246       //   3) Callback of the second failed node is triggered
247       // In this case, do not taint the executor status or other note items
248       // because they are inserted after the ClearError.
249       auto result = unfinished_nodes_.erase(item->id);
250       if (result == 0) return;
251     }
252 
253     if (!status.ok() && item->node->Fatal()) {
254       // Since we received an error, broadcast to any waiters.
255       need_notification = true;
256       status_ = status;
257       ok_ = false;
258       if (Async()) {
259         // We remove any pending ops so that we don't try to execute them if
260         // ClearError is called.
261         errors::AppendToMessage(&status_,
262                                 "Encountered when executing an operation using "
263                                 "EagerExecutor. This error cancels all future "
264                                 "operations and poisons their output tensors.");
265       }
266       while (!node_queue_.empty()) {
267         items_to_destroy.push_front(std::move(node_queue_.front()));
268         node_queue_.pop();
269       }
270       for (auto& it : unfinished_nodes_) {
271         items_to_destroy.push_front(std::move(it.second));
272       }
273       unfinished_nodes_.clear();
274     }
275     if (need_notification) {
276       NotifyWaiters(item->id);
277     }
278   }
279 
280   for (auto& item : items_to_destroy) {
281     item->node->Abort(status);
282   }
283   // nodes_to_destroy will be destructed here, while not holding
284   // node_queue_mutex_. This is important because, unfortunately, some nodes'
285   // destructors can enqueue more operations onto this executor and cause
286   // a deadlock.
287 }
288 
NotifyWaiters(uint64 id)289 void EagerExecutor::NotifyWaiters(uint64 id) {
290   if (!node_done_notifications_.empty()) {
291     uint64 upperbound_id = 0;
292     if (!unfinished_nodes_.empty()) {
293       upperbound_id = unfinished_nodes_.begin()->first - 1;
294     } else if (!node_queue_.empty()) {
295       upperbound_id = node_queue_.front()->id - 1;
296     } else {
297       upperbound_id = next_node_id_ - 1;
298     }
299     if (upperbound_id < id) {
300       return;
301     }
302     DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
303              << "] ";
304     // Note that we notify all waiting threads in case an error has
305     // occurred. These calling threads are responsible for checking status_
306     // before proceeding.
307     const auto range =
308         status_.ok()
309             ? make_pair(node_done_notifications_.lower_bound(id),
310                         node_done_notifications_.upper_bound(upperbound_id))
311             : make_pair(node_done_notifications_.begin(),
312                         node_done_notifications_.end());
313     for (auto it = range.first; it != range.second; ++it) {
314       it->second->notify_all();
315     }
316     node_done_notifications_.erase(range.first, range.second);
317   }
318 }
319 
Run()320 void EagerExecutor::Run() {
321   auto thread_exited_notifier =
322       gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
323   while (true) {
324     core::RefCountPtr<NodeItem> curr_item;
325     {
326       tensorflow::mutex_lock l(node_queue_mutex_);
327       while (node_queue_.empty() || !status_.ok()) {
328         if (state_ == ExecutorState::kShutDown) return;
329         nodes_pending_.wait(l);
330       }
331       // Obtain raw pointer since we don't want to remove from the queue until
332       // the node has been run. Otherwise, WaitForAllPendingNodes can return
333       // too early.
334       // Note, we don't std::move from the here because the front of the queue
335       // will then contain a nullptr. This can be a problem in
336       // WaitForAllPendingNodes where we get the top EagerNode pointer
337       // and register a notification for its completion.
338       curr_item.reset(node_queue_.front().get());
339       curr_item->Ref();
340     }
341     Status status = RunItem(std::move(curr_item), /*from_queue=*/true);
342     if (!status.ok()) {
343       VLOG(1) << "Failed to run item: " << status;
344     }
345   }
346 }
347 
RunItem(core::RefCountPtr<NodeItem> item,bool from_queue)348 Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
349                               bool from_queue) {
350   DVLOG(3) << "Running Node: [id " << item->id << "] "
351            << item->node->DebugString();
352   AsyncRemoteExecuteNode* async_remote_node =
353       item->node->AsAsyncRemoteExecuteNode();
354   if (enable_async_wait_for_remote_function_) {
355     if (async_remote_node != nullptr) {
356       if (last_eager_client_ != nullptr &&
357           async_remote_node->eager_client() != nullptr &&
358           last_eager_client_ != async_remote_node->eager_client()) {
359         // Running a remote function, need to sync if the function is going to
360         // different device than last time we run remote distributed function.
361         DVLOG(3) << "Executing Sync Executor for node" << item->id;
362         tensorflow::Status status = async_remote_node->SyncExecutors();
363         if (!status.ok()) {
364           NodeDone(item, status, from_queue);
365           return status;
366         }
367         last_eager_client_ = nullptr;
368       }
369       if (async_remote_node->eager_client() != nullptr &&
370           async_remote_node->needs_remote_inputs() &&
371           async_remote_node->allow_multiple_pending_requests()) {
372         // We are running remote distributed function, update
373         // last_remote_device_name_.
374         last_eager_client_ = async_remote_node->eager_client();
375       }
376     }
377   }
378 
379   AsyncEagerNode* async_node = item->node->AsAsync();
380   if (async_node == nullptr) {
381     tensorflow::Status status = item->node->Run();
382     NodeDone(item, status, from_queue);
383     return status;
384   }
385 
386   item->state = NodeState::kSCHEDULED;
387   auto async_ref = item.get();
388   async_ref->Ref();
389 
390   TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
391 
392   async_node->RunAsync([this, async_ref](const Status& status) {
393     core::RefCountPtr<NodeItem> async_item(async_ref);
394     NodeDone(async_item, status, false);
395   });
396 
397   // Return the status of the executor in case we are in an error state.
398   return status();
399 }
400 
MoveToUnfinished(core::RefCountPtr<NodeItem> item,bool from_queue)401 Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
402                                        bool from_queue) {
403   tensorflow::mutex_lock l(node_queue_mutex_);
404   if (!status_.ok()) {
405     return status_;
406   }
407 
408   if (from_queue) {
409     DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
410     node_queue_.pop();
411   }
412 
413   DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
414   unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
415                                  std::move(item));
416 
417   return OkStatus();
418 }
419 
AddCleanup(intptr_t key,std::function<void ()> callback)420 void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
421   cleanups_[key].push_back(callback);
422 }
423 
RemoveCleanups(intptr_t key)424 void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
425 
426 }  // namespace tensorflow
427