• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17 
18 #include <forward_list>
19 
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22 #include "tensorflow/core/util/env_var.h"
23 
24 namespace tensorflow {
25 namespace {
IsAsyncWaitForRemoteFunctionEnabled()26 bool IsAsyncWaitForRemoteFunctionEnabled() {
27   bool enabled = true;
28   TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_ASYNC_WAIT_FOR_REMOTE_FUNCTION",
29                                  true, &enabled));
30   return enabled;
31 }
32 }  // namespace
33 
EagerExecutor(bool async)34 EagerExecutor::EagerExecutor(bool async)
35     : next_node_id_(0),
36       ok_(true),
37       thread_(async ? tensorflow::Env::Default()->StartThread(
38                           tensorflow::ThreadOptions(), "eager_async_executor",
39                           std::bind(&EagerExecutor::Run, this))
40                     : nullptr),
41       last_eager_client_(nullptr),
42       enable_async_wait_for_remote_function_(
43           IsAsyncWaitForRemoteFunctionEnabled()) {}
44 
~EagerExecutor()45 EagerExecutor::~EagerExecutor() {
46   tensorflow::mutex_lock l(node_queue_mutex_);
47   state_ = ExecutorState::kShutDown;
48   nodes_pending_.notify_all();
49   for (const auto& cleanups_for_key : cleanups_) {
50     for (const std::function<void()>& cleanup : cleanups_for_key.second) {
51       cleanup();
52     }
53   }
54 }
55 
ShutDown()56 Status EagerExecutor::ShutDown() {
57   {
58     bool has_thread;
59     Status status;
60     {
61       tensorflow::mutex_lock l(node_queue_mutex_);
62       if (state_ != ExecutorState::kShutDown) {
63         // if the state is kShutDown, we don't return here because we want to
64         // make sure the executor thread has ended (if there is one).
65         // So, we fall through to
66         // thread_exited_notification_.WaitForNotification() below.
67         state_ = ExecutorState::kShuttingDown;
68       }
69       // It is OK to ignore the returned status here because it will be saved
70       // as the final status_.
71       WaitForAllPendingNodesLocked(&l).IgnoreError();
72       state_ = ExecutorState::kShutDown;
73       has_thread = thread_ != nullptr;
74       status = status_;
75       if (has_thread) {
76         nodes_pending_.notify_all();
77       }
78     }
79     if (!has_thread) {
80       return status;
81     }
82   }
83 
84   thread_exited_notification_.WaitForNotification();
85 
86   return status();
87 }
88 
StateStringLocked()89 const char* EagerExecutor::StateStringLocked() {
90   switch (state_) {
91     case ExecutorState::kActive:
92       return "Active";
93     case ExecutorState::kShuttingDown:
94       return "ShuttingDown";
95     case ExecutorState::kShutDown:
96       return "ShutDown";
97   }
98 }
99 
SyncExecute(EagerNode * node)100 Status EagerExecutor::SyncExecute(EagerNode* node) {
101   if (Async()) {
102     return errors::Internal("Executor does not support async execution");
103   }
104   if (node->AsAsync() != nullptr) {
105     return errors::Internal("Executor does not support executing async nodes");
106   }
107   // NOTE: SyncExecute runs every node regardless of error status in executor.
108 
109   uint64 id = next_node_id_++;
110 
111   Status s = node->Prepare();
112   if (!s.ok()) {
113     return s;
114   }
115 
116   // Inline execution in sync mode.
117   s = node->Run();
118   tensorflow::mutex_lock l(node_queue_mutex_);
119   if (!s.ok()) {
120     status_ = s;
121     ok_ = false;
122   }
123   NotifyWaiters(id);
124   return s;
125 }
126 
AddOrExecute(std::unique_ptr<EagerNode> node)127 Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
128   Status status;
129   core::RefCountPtr<NodeItem> item(new NodeItem);
130   item->id = next_node_id_++;
131   item->node = std::move(node);
132   item->state = NodeState::kPENDING;
133 
134   status = item->node->Prepare();
135   if (!status.ok()) {
136     item->node->Abort(status);
137     return status;
138   }
139 
140   // Inline execution in sync mode.
141   if (!Async()) {
142     // In sync mode, run the node item regardless of executor status.
143     return RunItem(std::move(item), /*from_queue=*/false);
144   } else {
145     tensorflow::mutex_lock l(node_queue_mutex_);
146     DVLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
147              << " with status: " << status_.ToString();
148     if (state_ != ExecutorState::kActive) {
149       status = errors::FailedPrecondition(
150           "EagerExecutor accepts new EagerNodes to run only in Active state. "
151           "Current state is '",
152           StateStringLocked(), "'");
153     } else {
154       status = status_;
155       if (status.ok()) {
156         node_queue_.push(std::move(item));
157         // If there were no previous nodes pending, wake the run thread to
158         // start processing requests again.
159         if (node_queue_.size() == 1) {
160           nodes_pending_.notify_all();
161         }
162 
163         return Status::OK();
164       }
165     }
166   }
167 
168   // If we are unable to add the node to the queue, we must call Abort. However,
169   // we want to do that outside of the scope of the lock since the Abort may
170   // try to call EagerExecutor::AddOrExecute()
171   item->node->Abort(status);
172 
173   return status;
174 }
175 
WaitForAllPendingNodes()176 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
177   tensorflow::mutex_lock l(node_queue_mutex_);
178   return WaitForAllPendingNodesLocked(&l);
179 }
180 
WaitForAllPendingNodesLocked(mutex_lock * lock)181 tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
182     mutex_lock* lock) {
183   tensorflow::condition_variable cond;
184   // Don't wait if an error is already set.
185   if (!status_.ok()) return status_;
186   if (node_queue_.empty() && unfinished_nodes_.empty())
187     return tensorflow::Status::OK();
188   // node_queue_ must be empty in sync mode.
189   DCHECK(Async() || node_queue_.empty());
190   auto last_id = next_node_id_ - 1;
191   DVLOG(3) << "Wait for Node: [id " << last_id << "] ";
192   node_done_notifications_.insert(std::make_pair(last_id, &cond));
193   cond.wait(*lock);
194   // Note that we could be woken up if an error occurs, even though the node has
195   // not actually executed.
196   return status_;
197 }
198 
ClearError()199 void EagerExecutor::ClearError() {
200   // TODO(iga): Check state_ and return an error if it is not kActive.
201   if (ok()) return;
202 
203   tensorflow::mutex_lock l(node_queue_mutex_);
204   // If an error was set, node_done_notifications_ and node_queue_ should have
205   // been cleared, and no new entries should have been added since.
206   DCHECK(node_done_notifications_.empty());
207   DCHECK(node_queue_.empty());
208   status_ = tensorflow::Status::OK();
209   ok_ = true;
210   last_eager_client_ = nullptr;
211   nodes_pending_.notify_all();
212 }
213 
NodeDone(const core::RefCountPtr<NodeItem> & item,const Status & status,bool from_queue)214 void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
215                              const Status& status, bool from_queue) {
216   DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
217            << " with status: " << status.ToString();
218   DCHECK(item->state != NodeState::kDONE);
219   item->state = NodeState::kDONE;
220 
221   bool async = item->node->AsAsync() != nullptr;
222   // If executing synchronously we don't need to notify if status is OK since
223   // the node  was never added to the unfinished_nodes_ list and nobody should
224   // ever be waiting for it.
225   if (status.ok() && !from_queue && !async) {
226     return;
227   }
228 
229   std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
230   {
231     mutex_lock l(node_queue_mutex_);
232     if (!status_.ok()) return;
233 
234     bool need_notification = from_queue;
235     if (from_queue) {
236       // Since this was from the async queue, pop it from the front of the queue
237       DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
238       node_queue_.pop();
239     } else if (async) {
240       // If it is an Async node then we will find the node in the unfinished
241       // nodes list. However we only notify if we are at the front of the list
242       // since we don't want to notify any waiters of earlier nodes.
243       need_notification = item->id == unfinished_nodes_.begin()->first;
244       // Remove item if it exists in unfinished_nodes_.
245       // With async execution, if two separate nodes failed and enter this
246       // callback, then the second node might not find itself in
247       // unfinished_nodes_ in the following senario:
248       //   1) Callback of the first failed node clears unfinished_nodes_
249       //   2) ClearError is called and executor status_ is set to OK
250       //   3) Callback of the second failed node is triggered
251       // In this case, do not taint the executor status or other note items
252       // because they are inserted after the ClearError.
253       auto result = unfinished_nodes_.erase(item->id);
254       if (result == 0) return;
255     }
256 
257     if (!status.ok() && item->node->Fatal()) {
258       // Since we received an error, broadcast to any waiters.
259       need_notification = true;
260       status_ = status;
261       ok_ = false;
262       if (Async()) {
263         // We remove any pending ops so that we don't try to execute them if
264         // ClearError is called.
265         errors::AppendToMessage(&status_,
266                                 "Encountered when executing an operation using "
267                                 "EagerExecutor. This error cancels all future "
268                                 "operations and poisons their output tensors.");
269       }
270       while (!node_queue_.empty()) {
271         items_to_destroy.push_front(std::move(node_queue_.front()));
272         node_queue_.pop();
273       }
274       for (auto& it : unfinished_nodes_) {
275         items_to_destroy.push_front(std::move(it.second));
276       }
277       unfinished_nodes_.clear();
278     }
279     if (need_notification) {
280       NotifyWaiters(item->id);
281     }
282   }
283 
284   for (auto& item : items_to_destroy) {
285     item->node->Abort(status);
286   }
287   // nodes_to_destroy will be destructed here, while not holding
288   // node_queue_mutex_. This is important because, unfortunately, some nodes'
289   // destructors can enqueue more operations onto this executor and cause
290   // a deadlock.
291 }
292 
NotifyWaiters(uint64 id)293 void EagerExecutor::NotifyWaiters(uint64 id) {
294   if (!node_done_notifications_.empty()) {
295     uint64 upperbound_id = 0;
296     if (!unfinished_nodes_.empty()) {
297       upperbound_id = unfinished_nodes_.begin()->first - 1;
298     } else if (!node_queue_.empty()) {
299       upperbound_id = node_queue_.front()->id - 1;
300     } else {
301       upperbound_id = next_node_id_ - 1;
302     }
303     if (upperbound_id < id) {
304       return;
305     }
306     DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
307              << "] ";
308     // Note that we notify all waiting threads in case an error has
309     // occurred. These calling threads are responsible for checking status_
310     // before proceeding.
311     const auto range =
312         status_.ok()
313             ? make_pair(node_done_notifications_.lower_bound(id),
314                         node_done_notifications_.upper_bound(upperbound_id))
315             : make_pair(node_done_notifications_.begin(),
316                         node_done_notifications_.end());
317     for (auto it = range.first; it != range.second; ++it) {
318       it->second->notify_all();
319     }
320     node_done_notifications_.erase(range.first, range.second);
321   }
322 }
323 
Run()324 void EagerExecutor::Run() {
325   auto thread_exited_notifier =
326       gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
327   while (true) {
328     core::RefCountPtr<NodeItem> curr_item;
329     {
330       tensorflow::mutex_lock l(node_queue_mutex_);
331       while (node_queue_.empty() || !status_.ok()) {
332         if (state_ == ExecutorState::kShutDown) return;
333         nodes_pending_.wait(l);
334       }
335       // Obtain raw pointer since we don't want to remove from the queue until
336       // the node has been run. Otherwise, WaitForAllPendingNodes can return
337       // too early.
338       // Note, we don't std::move from the here because the front of the queue
339       // will then contain a nullptr. This can be a problem in
340       // WaitForAllPendingNodes where we get the top EagerNode pointer
341       // and register a notification for its completion.
342       curr_item.reset(node_queue_.front().get());
343       curr_item->Ref();
344     }
345     Status status = RunItem(std::move(curr_item), /*from_queue=*/true);
346     if (!status.ok()) {
347       VLOG(1) << "Failed to run item: " << status;
348     }
349   }
350 }
351 
RunItem(core::RefCountPtr<NodeItem> item,bool from_queue)352 Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
353                               bool from_queue) {
354   DVLOG(3) << "Running Node: [id " << item->id << "] "
355            << item->node->DebugString();
356   AsyncRemoteExecuteNode* async_remote_node =
357       item->node->AsAsyncRemoteExecuteNode();
358   if (enable_async_wait_for_remote_function_) {
359     if (async_remote_node != nullptr) {
360       if (last_eager_client_ != nullptr &&
361           async_remote_node->eager_client() != nullptr &&
362           last_eager_client_ != async_remote_node->eager_client()) {
363         // Running a remote function, need to sync if the function is going to
364         // different device than last time we run remote distributed function.
365         DVLOG(3) << "Executing Sync Executor for node" << item->id;
366         tensorflow::Status status = async_remote_node->SyncExecutors();
367         if (!status.ok()) {
368           NodeDone(item, status, from_queue);
369           return status;
370         }
371         last_eager_client_ = nullptr;
372       }
373       if (async_remote_node->eager_client() != nullptr &&
374           async_remote_node->needs_remote_inputs() &&
375           async_remote_node->allow_multiple_pending_requests()) {
376         // We are running remote distributed function, update
377         // last_remote_device_name_.
378         last_eager_client_ = async_remote_node->eager_client();
379       }
380     }
381   }
382 
383   AsyncEagerNode* async_node = item->node->AsAsync();
384   if (async_node == nullptr) {
385     tensorflow::Status status = item->node->Run();
386     NodeDone(item, status, from_queue);
387     return status;
388   }
389 
390   item->state = NodeState::kSCHEDULED;
391   auto async_ref = item.get();
392   async_ref->Ref();
393 
394   TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
395 
396   async_node->RunAsync([this, async_ref](const Status& status) {
397     core::RefCountPtr<NodeItem> async_item(async_ref);
398     NodeDone(async_item, status, false);
399   });
400 
401   // Return the status of the executor in case we are in an error state.
402   return status();
403 }
404 
MoveToUnfinished(core::RefCountPtr<NodeItem> item,bool from_queue)405 Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
406                                        bool from_queue) {
407   tensorflow::mutex_lock l(node_queue_mutex_);
408   if (!status_.ok()) {
409     return status_;
410   }
411 
412   if (from_queue) {
413     DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
414     node_queue_.pop();
415   }
416 
417   DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
418   unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
419                                  std::move(item));
420 
421   return Status::OK();
422 }
423 
AddCleanup(intptr_t key,std::function<void ()> callback)424 void EagerExecutor::AddCleanup(intptr_t key, std::function<void()> callback) {
425   cleanups_[key].push_back(callback);
426 }
427 
RemoveCleanups(intptr_t key)428 void EagerExecutor::RemoveCleanups(intptr_t key) { cleanups_.erase(key); }
429 
430 }  // namespace tensorflow
431