1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/training/queue_runner.h"
17 #include "tensorflow/core/kernels/ops_util.h"
18 #include "tensorflow/core/platform/env.h"
19
20 namespace tensorflow {
21
New(const QueueRunnerDef & queue_runner_def,std::unique_ptr<QueueRunner> * result)22 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
23 std::unique_ptr<QueueRunner>* result) {
24 result->reset(new QueueRunner());
25 return (*result)->Init(queue_runner_def);
26 }
27
New(const QueueRunnerDef & queue_runner_def,Coordinator * coord,std::unique_ptr<QueueRunner> * result)28 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
29 Coordinator* coord,
30 std::unique_ptr<QueueRunner>* result) {
31 result->reset(new QueueRunner());
32 (*result)->coord_ = coord;
33 return (*result)->Init(queue_runner_def);
34 }
35
AddErrorCallback(const std::function<void (Status)> & cb)36 void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) {
37 mutex_lock l(cb_mu_);
38 callbacks_.push_back(cb);
39 }
40
ClearErrorCallbacks()41 void QueueRunner::ClearErrorCallbacks() {
42 mutex_lock l(cb_mu_);
43 callbacks_.clear();
44 }
45
Init(const QueueRunnerDef & queue_runner_def)46 Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
47 queue_name_ = queue_runner_def.queue_name();
48 enqueue_op_names_.clear();
49 enqueue_op_names_.insert(enqueue_op_names_.end(),
50 queue_runner_def.enqueue_op_name().begin(),
51 queue_runner_def.enqueue_op_name().end());
52 size_t op_names_size = enqueue_op_names_.size();
53 if (op_names_size > kint32max) {
54 return Status(error::INVALID_ARGUMENT,
55 "Enqueue ops to run cannot exceed kint32max");
56 }
57 runs_ = static_cast<int>(op_names_size);
58 if (runs_ == 0) {
59 return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run.");
60 }
61 close_op_name_ = queue_runner_def.close_op_name();
62 cancel_op_name_ = queue_runner_def.cancel_op_name();
63 if (queue_runner_def.queue_closed_exception_types_size() == 0) {
64 queue_closed_exception_types_.insert(error::OUT_OF_RANGE);
65 } else {
66 for (const auto& code : queue_runner_def.queue_closed_exception_types()) {
67 queue_closed_exception_types_.insert(static_cast<int>(code));
68 }
69 }
70
71 int nthreads = runs_;
72 if (coord_) {
73 // One more thread to call Stop()
74 nthreads++;
75 }
76 thread_pool_.reset(new thread::ThreadPool(
77 Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads));
78
79 return Status::OK();
80 }
81
~QueueRunner()82 QueueRunner::~QueueRunner() {
83 // Cannot run Stop() here because the session might already be closed or
84 // destroyed.
85 Join().IgnoreError();
86 }
87
Start(Session * sess)88 Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
89
StartAndCollectCostGraph(Session * sess,const RunOptions & run_options)90 Status QueueRunner::StartAndCollectCostGraph(Session* sess,
91 const RunOptions& run_options) {
92 SetRunArgumentsAndCostGraph(run_options);
93 return Start(sess, 0);
94 }
95
Start(Session * sess,int wait_for)96 Status QueueRunner::Start(Session* sess, int wait_for) {
97 counter_.reset(new BlockingCounter(runs_));
98 for (const string& enqueue_op : enqueue_op_names_) {
99 thread_pool_->Schedule(
100 std::bind(&QueueRunner::Run, this, sess, enqueue_op));
101 }
102 if (coord_) {
103 thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
104 }
105 // Wait for up to 'wait_for' milliseconds.
106 if (wait_for > 0) {
107 if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
108 return Status(error::DEADLINE_EXCEEDED,
109 "Queues not fed before the timeout");
110 }
111 // Check the status of the queue runner as well as the result of the enqueue
112 // operations.
113 mutex_lock l(mu_);
114 if (!enqueue_status_.ok()) {
115 return enqueue_status_;
116 } else {
117 return status_;
118 }
119 }
120 return Status::OK();
121 }
122
StartAndCollectCostGraph(Session * session,int wait_for_ms,const RunOptions & run_options)123 Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms,
124 const RunOptions& run_options) {
125 SetRunArgumentsAndCostGraph(run_options);
126 return Start(session, wait_for_ms);
127 }
128
Stop(Session * sess)129 void QueueRunner::Stop(Session* sess) {
130 if (coord_ != nullptr) {
131 coord_->WaitForStop();
132 }
133 if (!cancel_op_name_.empty()) {
134 UpdateStatus(RealRun(sess, cancel_op_name_, false));
135 }
136 stopped_ = true;
137 }
138
Join()139 Status QueueRunner::Join() {
140 thread_pool_.reset();
141 mutex_lock l(mu_);
142 return status_;
143 }
144
UpdateStatus(const Status & status)145 void QueueRunner::UpdateStatus(const Status& status) {
146 {
147 mutex_lock l(mu_);
148 if (!status_.ok() || status.ok() || IsQueueClosed(status)) {
149 return;
150 }
151 status_ = status;
152 }
153 if (coord_) {
154 coord_->ReportStatus(status);
155 }
156 mutex_lock l(cb_mu_);
157 for (auto& cb : callbacks_) {
158 cb(status);
159 }
160 }
161
Run(Session * sess,const string & enqueue_op)162 void QueueRunner::Run(Session* sess, const string& enqueue_op) {
163 bool first_iteration = true;
164 Status status;
165 while (status.ok()) {
166 if (coord_ && coord_->ShouldStop()) {
167 break;
168 }
169 status = RealRun(sess, enqueue_op, true);
170 if (first_iteration) {
171 if (!status.ok()) {
172 mutex_lock l(mu_);
173 enqueue_status_ = status;
174 }
175 counter_->DecrementCount();
176 first_iteration = false;
177 }
178 }
179 bool last_run = false;
180 {
181 mutex_lock l(mu_);
182 runs_--;
183 last_run = (runs_ == 0);
184 }
185
186 // Close the queue unless the coordinator is shutting down since the cancel op
187 // will be run anway in this case.
188 if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
189 if (last_run && !close_op_name_.empty()) {
190 UpdateStatus(RealRun(sess, close_op_name_, false));
191 }
192 } else if (!status.ok()) {
193 LOG(ERROR) << "Queue runner thread got a failure status: "
194 << status.ToString();
195 UpdateStatus(status);
196 if (coord_) {
197 coord_->RequestStop().IgnoreError();
198 }
199 }
200 }
201
GetStatus()202 Status QueueRunner::GetStatus() {
203 mutex_lock l(mu_);
204 return status_;
205 }
206
ExportCostGraph(CostGraphDef * cost_graph) const207 Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const {
208 if (!cg_mu_) {
209 return Status(error::FAILED_PRECONDITION,
210 "This QueueRunner doesn't collect a cost graph.");
211 }
212 mutex_lock l(*cg_mu_);
213 cost_graph->MergeFrom(*cost_graph_);
214 return Status::OK();
215 }
216
SetRunArgumentsAndCostGraph(const RunOptions & run_options)217 void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) {
218 cg_mu_.reset(new mutex());
219 {
220 mutex_lock l(*cg_mu_);
221 cost_graph_.reset(new CostGraphDef());
222 }
223 run_options_ = run_options;
224 }
225
RealRun(Session * sess,const string & op,bool update_costs)226 Status QueueRunner::RealRun(Session* sess, const string& op,
227 bool update_costs) {
228 Status s;
229 if (update_costs && cg_mu_) {
230 RunMetadata metadata;
231 s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
232 mutex_lock l(*cg_mu_);
233 cost_graph_->Swap(metadata.mutable_cost_graph());
234 } else {
235 s = sess->Run({}, {}, {op}, nullptr);
236 }
237 return s;
238 }
239
240 } // namespace tensorflow
241