1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
17
18 namespace tensorflow {
19
EagerNode(tensorflow::uint64 id)20 EagerNode::EagerNode(tensorflow::uint64 id) : id(id) {}
21
~EagerExecutor()22 EagerExecutor::~EagerExecutor() {
23 tensorflow::mutex_lock l(node_queue_mutex_);
24 thread_done_ = true;
25 nodes_pending_.notify_all();
26 }
27
NextId()28 tensorflow::uint64 EagerExecutor::NextId() {
29 tensorflow::mutex_lock l(next_id_mutex_);
30 return next_id_++;
31 }
32
EnableAsync()33 void EagerExecutor::EnableAsync() {
34 tensorflow::mutex_lock l(node_queue_mutex_);
35 if (thread_ == nullptr) {
36 thread_.reset(tensorflow::Env::Default()->StartThread(
37 tensorflow::ThreadOptions(), "eager_async_executor",
38 std::bind(&EagerExecutor::Run, this)));
39 }
40 }
41
Add(EagerNode * node)42 void EagerExecutor::Add(EagerNode* node) {
43 tensorflow::mutex_lock l(node_queue_mutex_);
44 DCHECK(thread_) << "EnableAsync should have been called before Add";
45 if (!status_.ok()) {
46 delete node;
47 return;
48 }
49 int64 qlen = node_queue_.size();
50 if (qlen > 0) {
51 if (node_queue_.back()->id >= node->id) {
52 status_ = tensorflow::errors::InvalidArgument(
53 "Inserting EagerNode with non-increasing ids:",
54 node_queue_.back()->id, " vs ", node->id);
55 delete node;
56 return;
57 }
58 node_queue_.push(node);
59 } else {
60 node_queue_.push(node);
61 nodes_pending_.notify_all();
62 }
63 }
64
WaitFor(tensorflow::uint64 node_id)65 tensorflow::Status EagerExecutor::WaitFor(tensorflow::uint64 node_id) {
66 return WaitImpl(false, node_id);
67 }
68
WaitForAllPendingNodes()69 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
70 return WaitImpl(true, 0);
71 }
72
WaitImpl(bool wait_all,tensorflow::uint64 node_id)73 tensorflow::Status EagerExecutor::WaitImpl(bool wait_all,
74 tensorflow::uint64 node_id) {
75 tensorflow::condition_variable cond;
76 tensorflow::mutex_lock l(node_queue_mutex_);
77 // Don't wait if an error is already set.
78 if (!status_.ok()) return status_;
79 if (node_queue_.empty()) return tensorflow::Status::OK();
80 if (wait_all) {
81 node_id = node_queue_.back()->id;
82 } else if (node_id < node_queue_.front()->id) {
83 // Note that we are relying on the ops being dispatched sequentially from
84 // the queue.
85 return tensorflow::Status::OK();
86 }
87 node_done_notifications_.insert(std::make_pair(node_id, &cond));
88 cond.wait(l);
89 // Note that we could be woken up if an error occurs, even though the node has
90 // not actually executed.
91 return status_;
92 }
93
ClearError()94 void EagerExecutor::ClearError() {
95 tensorflow::mutex_lock l(node_queue_mutex_);
96 if (status_.ok()) return;
97 // If an error was set, node_done_notifications_ and node_queue_ should have
98 // been cleared, and no new entries should have been added since.
99 DCHECK(node_done_notifications_.empty());
100 DCHECK(node_queue_.empty());
101 status_ = tensorflow::Status::OK();
102 nodes_pending_.notify_all();
103 }
104
status()105 tensorflow::Status EagerExecutor::status() {
106 tensorflow::mutex_lock l(node_queue_mutex_);
107 return status_;
108 }
109
Run()110 void EagerExecutor::Run() {
111 while (true) {
112 std::unique_ptr<EagerNode> curr_node;
113 {
114 tensorflow::mutex_lock l(node_queue_mutex_);
115 while (node_queue_.empty() || !status_.ok()) {
116 if (thread_done_) return;
117 nodes_pending_.wait(l);
118 }
119 curr_node.reset(node_queue_.front());
120 }
121 tensorflow::Status status = curr_node->Run();
122 const bool ok = status.ok();
123 tensorflow::mutex_lock l(node_queue_mutex_);
124 node_queue_.pop();
125 if (!ok) {
126 status_ = status;
127 // TODO(agarwal): mark all affected handles as corrupted before clearing
128 // this queue.
129 // We remove any pending ops so that we don't try to execute them if
130 // ClearError is called.
131 for (int i = 0; i < node_queue_.size(); ++i) {
132 delete node_queue_.front();
133 node_queue_.pop();
134 }
135 }
136 if (!node_done_notifications_.empty()) {
137 tensorflow::uint64 node_id = curr_node->id;
138 // Note that we notify all waiting threads in case an error has occurred.
139 // These calling threads are responsible for checking status_ before
140 // proceeding.
141 const auto range = ok ? node_done_notifications_.equal_range(node_id)
142 : make_pair(node_done_notifications_.begin(),
143 node_done_notifications_.end());
144 for (auto it = range.first; it != range.second; ++it) {
145 it->second->notify_all();
146 }
147 node_done_notifications_.erase(range.first, range.second);
148 }
149 }
150 }
151
152 } // namespace tensorflow
153