• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17 
18 #include <forward_list>
19 
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22 
23 namespace tensorflow {
24 
EagerExecutor(bool async)25 EagerExecutor::EagerExecutor(bool async)
26     : next_node_id_(0),
27       ok_(true),
28       thread_(async ? tensorflow::Env::Default()->StartThread(
29                           tensorflow::ThreadOptions(), "eager_async_executor",
30                           std::bind(&EagerExecutor::Run, this))
31                     : nullptr) {}
32 
~EagerExecutor()33 EagerExecutor::~EagerExecutor() {
34   tensorflow::mutex_lock l(node_queue_mutex_);
35   state_ = ExecutorState::kShutDown;
36   nodes_pending_.notify_all();
37 }
38 
ShutDown()39 Status EagerExecutor::ShutDown() {
40   {
41     std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
42     bool has_thread;
43     Status status;
44     {
45       tensorflow::mutex_lock l(node_queue_mutex_);
46       if (state_ != ExecutorState::kShutDown) {
47         // if the state is kShutDown, we don't return here because we want to
48         // make sure the executor thread has ended (if there is one).
49         // So, we fall through to
50         // thread_exited_notification_.WaitForNotification() below.
51         state_ = ExecutorState::kShuttingDown;
52       }
53       // It is OK to ignore the returned status here because it will be saved
54       // as the final status_.
55       WaitForAllPendingNodesLocked(&l).IgnoreError();
56       state_ = ExecutorState::kShutDown;
57       has_thread = thread_ != nullptr;
58       status = status_;
59       if (has_thread) {
60         nodes_pending_.notify_all();
61       }
62     }
63     for (auto& item : items_to_destroy) {
64       item->node->Abort(status);
65     }
66     if (!has_thread) {
67       return status;
68     }
69   }
70 
71   thread_exited_notification_.WaitForNotification();
72 
73   return status();
74 }
75 
StateStringLocked()76 const char* EagerExecutor::StateStringLocked() {
77   switch (state_) {
78     case ExecutorState::kActive:
79       return "Active";
80     case ExecutorState::kShuttingDown:
81       return "ShuttingDown";
82     case ExecutorState::kShutDown:
83       return "ShutDown";
84   }
85 }
86 
SyncExecute(EagerNode * node)87 Status EagerExecutor::SyncExecute(EagerNode* node) {
88   if (Async()) {
89     return errors::Internal("Executor does not support sync execution");
90   }
91   if (node->AsAsync() != nullptr) {
92     return errors::Internal("Executor does not support executing async nodes");
93   }
94   // NOTE: SyncExecute runs every node regardless of error status in executor.
95 
96   uint64 id = next_node_id_++;
97 
98   Status s = node->Prepare();
99   if (!s.ok()) {
100     return s;
101   }
102 
103   // Inline execution in sync mode.
104   s = node->Run();
105   tensorflow::mutex_lock l(node_queue_mutex_);
106   if (!s.ok()) {
107     status_ = s;
108     ok_ = false;
109   }
110   NotifyWaiters(id);
111   return s;
112 }
113 
AddOrExecute(std::unique_ptr<EagerNode> node)114 Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
115   Status status;
116   core::RefCountPtr<NodeItem> item(new NodeItem);
117   item->id = next_node_id_++;
118   item->node = std::move(node);
119   item->state = NodeState::kPENDING;
120 
121   status = item->node->Prepare();
122   if (!status.ok()) {
123     item->node->Abort(status);
124     return status;
125   }
126 
127   // Inline execution in sync mode.
128   if (!Async()) {
129     // In sync mode, run the node item regardless of executor status.
130     return RunItem(std::move(item), /*from_queue=*/false);
131   } else {
132     tensorflow::mutex_lock l(node_queue_mutex_);
133     DVLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
134              << " with status: " << status_.ToString();
135     if (state_ != ExecutorState::kActive) {
136       status = errors::FailedPrecondition(
137           "EagerExecutor accepts new EagerNodes to run only in Active state. "
138           "Current state is '",
139           StateStringLocked(), "'");
140     } else {
141       status = status_;
142       if (status.ok()) {
143         node_queue_.push(std::move(item));
144         // If there were no previous nodes pending, wake the run thread to
145         // start processing requests again.
146         if (node_queue_.size() == 1) {
147           nodes_pending_.notify_all();
148         }
149 
150         return Status::OK();
151       }
152     }
153   }
154 
155   // If we are unable to add the node to the queue, we must call Abort. However,
156   // we want to do that outside of the scope of the lock since the Abort may
157   // try to call EagerExecutor::AddOrExecute()
158   item->node->Abort(status);
159 
160   return status;
161 }
162 
WaitForAllPendingNodes()163 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
164   tensorflow::mutex_lock l(node_queue_mutex_);
165   return WaitForAllPendingNodesLocked(&l);
166 }
167 
WaitForAllPendingNodesLocked(mutex_lock * lock)168 tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
169     mutex_lock* lock) {
170   tensorflow::condition_variable cond;
171   // Don't wait if an error is already set.
172   if (!status_.ok()) return status_;
173   if (node_queue_.empty() && unfinished_nodes_.empty())
174     return tensorflow::Status::OK();
175   // node_queue_ must be empty in sync mode.
176   DCHECK(Async() || node_queue_.empty());
177   auto last_id = next_node_id_ - 1;
178   DVLOG(3) << "Wait for Node: [id " << last_id << "] ";
179   node_done_notifications_.insert(std::make_pair(last_id, &cond));
180   cond.wait(*lock);
181   // Note that we could be woken up if an error occurs, even though the node has
182   // not actually executed.
183   return status_;
184 }
185 
ClearError()186 void EagerExecutor::ClearError() {
187   // TODO(iga): Check state_ and return an error if it is not kActive.
188   if (ok()) return;
189 
190   tensorflow::mutex_lock l(node_queue_mutex_);
191   // If an error was set, node_done_notifications_ and node_queue_ should have
192   // been cleared, and no new entries should have been added since.
193   DCHECK(node_done_notifications_.empty());
194   DCHECK(node_queue_.empty());
195   status_ = tensorflow::Status::OK();
196   ok_ = true;
197   nodes_pending_.notify_all();
198 }
199 
NodeDone(const core::RefCountPtr<NodeItem> & item,const Status & status,bool from_queue)200 void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
201                              const Status& status, bool from_queue) {
202   DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
203            << " with status: " << status.ToString();
204   DCHECK(item->state != NodeState::kDONE);
205   item->state = NodeState::kDONE;
206 
207   bool async = item->node->AsAsync() != nullptr;
208   // If executing synchronously we don't need to notify if status is OK since
209   // the node  was never added to the unfinished_nodes_ list and nobody should
210   // ever be waiting for it.
211   if (status.ok() && !from_queue && !async) {
212     return;
213   }
214 
215   std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
216   {
217     mutex_lock l(node_queue_mutex_);
218     if (!status_.ok()) return;
219 
220     bool need_notification = from_queue;
221     if (from_queue) {
222       // Since this was from the async queue, pop it from the front of the queue
223       DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
224       node_queue_.pop();
225     } else if (async) {
226       // If it is an Async node then we will find the node in the unfinished
227       // nodes list. However we only notify if we are at the front of the list
228       // since we don't want to notify any waiters of earlier nodes.
229       need_notification = item->id == unfinished_nodes_.begin()->first;
230       auto result = unfinished_nodes_.erase(item->id);
231       DCHECK_GT(result, 0);
232     }
233 
234     if (!status.ok()) {
235       // Since we received an error, broadcast to any waiters.
236       need_notification = true;
237       status_ = status;
238       ok_ = false;
239       if (Async()) {
240         // We remove any pending ops so that we don't try to execute them if
241         // ClearError is called.
242         errors::AppendToMessage(&status_,
243                                 "Encountered when executing an operation using "
244                                 "EagerExecutor. This error cancels all future "
245                                 "operations and poisons their output tensors.");
246       }
247       while (!node_queue_.empty()) {
248         items_to_destroy.push_front(std::move(node_queue_.front()));
249         node_queue_.pop();
250       }
251       for (auto& it : unfinished_nodes_) {
252         items_to_destroy.push_front(std::move(it.second));
253       }
254       unfinished_nodes_.clear();
255     }
256     if (need_notification) {
257       NotifyWaiters(item->id);
258     }
259   }
260 
261   for (auto& item : items_to_destroy) {
262     item->node->Abort(status);
263   }
264   // nodes_to_destroy will be destructed here, while not holding
265   // node_queue_mutex_. This is important because, unfortunately, some nodes'
266   // destructors can enqueue more operations onto this executor and cause
267   // a deadlock.
268 }
269 
NotifyWaiters(uint64 id)270 void EagerExecutor::NotifyWaiters(uint64 id) {
271   if (!node_done_notifications_.empty()) {
272     uint64 upperbound_id = 0;
273     if (!unfinished_nodes_.empty()) {
274       upperbound_id = unfinished_nodes_.begin()->first - 1;
275     } else if (!node_queue_.empty()) {
276       upperbound_id = node_queue_.front()->id - 1;
277     } else {
278       upperbound_id = next_node_id_ - 1;
279     }
280     DVLOG(3) << "Notify node done: [id " << id << " to " << upperbound_id
281              << "] ";
282     // Note that we notify all waiting threads in case an error has
283     // occurred. These calling threads are responsible for checking status_
284     // before proceeding.
285     const auto range =
286         status_.ok()
287             ? make_pair(node_done_notifications_.lower_bound(id),
288                         node_done_notifications_.upper_bound(upperbound_id))
289             : make_pair(node_done_notifications_.begin(),
290                         node_done_notifications_.end());
291     for (auto it = range.first; it != range.second; ++it) {
292       it->second->notify_all();
293     }
294     node_done_notifications_.erase(range.first, range.second);
295   }
296 }
297 
Run()298 void EagerExecutor::Run() {
299   auto thread_exited_notifier =
300       gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
301   while (true) {
302     core::RefCountPtr<NodeItem> curr_item;
303     {
304       tensorflow::mutex_lock l(node_queue_mutex_);
305       while (node_queue_.empty() || !status_.ok()) {
306         if (state_ == ExecutorState::kShutDown) return;
307         nodes_pending_.wait(l);
308       }
309       // Obtain raw pointer since we don't want to remove from the queue until
310       // the node has been run. Otherwise, WaitForAllPendingNodes can return
311       // too early.
312       // Note, we don't std::move from the here because the front of the queue
313       // will then contain a nullptr. This can be a problem in
314       // WaitForAllPendingNodes where we get the top EagerNode pointer
315       // and register a notification for its completion.
316       curr_item.reset(node_queue_.front().get());
317       curr_item->Ref();
318     }
319     Status status = RunItem(std::move(curr_item), /*from_queue=*/true);
320     if (!status.ok()) {
321       VLOG(1) << "Failed to run item: " << status;
322     }
323   }
324 }
325 
RunItem(core::RefCountPtr<NodeItem> item,bool from_queue)326 Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
327                               bool from_queue) {
328   DVLOG(3) << "Running Node: [id " << item->id << "] "
329            << item->node->DebugString();
330   AsyncEagerNode* async_node = item->node->AsAsync();
331   if (async_node == nullptr) {
332     tensorflow::Status status = item->node->Run();
333     NodeDone(item, status, from_queue);
334     return status;
335   }
336 
337   item->state = NodeState::kSCHEDULED;
338   auto async_ref = item.get();
339   async_ref->Ref();
340 
341   TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
342 
343   async_node->RunAsync([this, async_ref](const Status& status) {
344     core::RefCountPtr<NodeItem> async_item(async_ref);
345     NodeDone(async_item, status, false);
346   });
347 
348   // Return the status of the executor in case we are in an error state.
349   return status();
350 }
351 
MoveToUnfinished(core::RefCountPtr<NodeItem> item,bool from_queue)352 Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
353                                        bool from_queue) {
354   tensorflow::mutex_lock l(node_queue_mutex_);
355   if (!status_.ok()) {
356     return status_;
357   }
358 
359   if (from_queue) {
360     DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
361     node_queue_.pop();
362   }
363 
364   DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
365   unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
366                                  std::move(item));
367 
368   return Status::OK();
369 }
370 
371 }  // namespace tensorflow
372