• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/training/queue_runner.h"
17 #include "tensorflow/core/kernels/ops_util.h"
18 #include "tensorflow/core/platform/env.h"
19 
20 namespace tensorflow {
21 
New(const QueueRunnerDef & queue_runner_def,std::unique_ptr<QueueRunner> * result)22 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
23                         std::unique_ptr<QueueRunner>* result) {
24   result->reset(new QueueRunner());
25   return (*result)->Init(queue_runner_def);
26 }
27 
New(const QueueRunnerDef & queue_runner_def,Coordinator * coord,std::unique_ptr<QueueRunner> * result)28 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
29                         Coordinator* coord,
30                         std::unique_ptr<QueueRunner>* result) {
31   result->reset(new QueueRunner());
32   (*result)->coord_ = coord;
33   return (*result)->Init(queue_runner_def);
34 }
35 
AddErrorCallback(const std::function<void (Status)> & cb)36 void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) {
37   mutex_lock l(cb_mu_);
38   callbacks_.push_back(cb);
39 }
40 
ClearErrorCallbacks()41 void QueueRunner::ClearErrorCallbacks() {
42   mutex_lock l(cb_mu_);
43   callbacks_.clear();
44 }
45 
Init(const QueueRunnerDef & queue_runner_def)46 Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
47   queue_name_ = queue_runner_def.queue_name();
48   enqueue_op_names_.clear();
49   enqueue_op_names_.insert(enqueue_op_names_.end(),
50                            queue_runner_def.enqueue_op_name().begin(),
51                            queue_runner_def.enqueue_op_name().end());
52   size_t op_names_size = enqueue_op_names_.size();
53   if (op_names_size > kint32max) {
54     return Status(error::INVALID_ARGUMENT,
55                   "Enqueue ops to run cannot exceed kint32max");
56   }
57   runs_ = static_cast<int>(op_names_size);
58   if (runs_ == 0) {
59     return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run.");
60   }
61   close_op_name_ = queue_runner_def.close_op_name();
62   cancel_op_name_ = queue_runner_def.cancel_op_name();
63   if (queue_runner_def.queue_closed_exception_types_size() == 0) {
64     queue_closed_exception_types_.insert(error::OUT_OF_RANGE);
65   } else {
66     for (const auto& code : queue_runner_def.queue_closed_exception_types()) {
67       queue_closed_exception_types_.insert(static_cast<int>(code));
68     }
69   }
70 
71   int nthreads = runs_;
72   if (coord_) {
73     // One more thread to call Stop()
74     nthreads++;
75   }
76   thread_pool_.reset(new thread::ThreadPool(
77       Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads));
78 
79   return Status::OK();
80 }
81 
~QueueRunner()82 QueueRunner::~QueueRunner() {
83   // Cannot run Stop() here because the session might already be closed or
84   // destroyed.
85   Join().IgnoreError();
86 }
87 
Start(Session * sess)88 Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
89 
StartAndCollectCostGraph(Session * sess,const RunOptions & run_options)90 Status QueueRunner::StartAndCollectCostGraph(Session* sess,
91                                              const RunOptions& run_options) {
92   SetRunArgumentsAndCostGraph(run_options);
93   return Start(sess, 0);
94 }
95 
Start(Session * sess,int wait_for)96 Status QueueRunner::Start(Session* sess, int wait_for) {
97   counter_.reset(new BlockingCounter(runs_));
98   for (const string& enqueue_op : enqueue_op_names_) {
99     thread_pool_->Schedule(
100         std::bind(&QueueRunner::Run, this, sess, enqueue_op));
101   }
102   if (coord_) {
103     thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
104   }
105   // Wait for up to 'wait_for' milliseconds.
106   if (wait_for > 0) {
107     if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
108       return Status(error::DEADLINE_EXCEEDED,
109                     "Queues not fed before the timeout");
110     }
111     // Check the status of the queue runner as well as the result of the enqueue
112     // operations.
113     mutex_lock l(mu_);
114     if (!enqueue_status_.ok()) {
115       return enqueue_status_;
116     } else {
117       return status_;
118     }
119   }
120   return Status::OK();
121 }
122 
StartAndCollectCostGraph(Session * session,int wait_for_ms,const RunOptions & run_options)123 Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms,
124                                              const RunOptions& run_options) {
125   SetRunArgumentsAndCostGraph(run_options);
126   return Start(session, wait_for_ms);
127 }
128 
Stop(Session * sess)129 void QueueRunner::Stop(Session* sess) {
130   if (coord_ != nullptr) {
131     coord_->WaitForStop();
132   }
133   if (!cancel_op_name_.empty()) {
134     UpdateStatus(RealRun(sess, cancel_op_name_, false));
135   }
136   stopped_ = true;
137 }
138 
Join()139 Status QueueRunner::Join() {
140   thread_pool_.reset();
141   mutex_lock l(mu_);
142   return status_;
143 }
144 
UpdateStatus(const Status & status)145 void QueueRunner::UpdateStatus(const Status& status) {
146   {
147     mutex_lock l(mu_);
148     if (!status_.ok() || status.ok() || IsQueueClosed(status)) {
149       return;
150     }
151     status_ = status;
152   }
153   if (coord_) {
154     coord_->ReportStatus(status);
155   }
156   mutex_lock l(cb_mu_);
157   for (auto& cb : callbacks_) {
158     cb(status);
159   }
160 }
161 
Run(Session * sess,const string & enqueue_op)162 void QueueRunner::Run(Session* sess, const string& enqueue_op) {
163   bool first_iteration = true;
164   Status status;
165   while (status.ok()) {
166     if (coord_ && coord_->ShouldStop()) {
167       break;
168     }
169     status = RealRun(sess, enqueue_op, true);
170     if (first_iteration) {
171       if (!status.ok()) {
172         mutex_lock l(mu_);
173         enqueue_status_ = status;
174       }
175       counter_->DecrementCount();
176       first_iteration = false;
177     }
178   }
179   bool last_run = false;
180   {
181     mutex_lock l(mu_);
182     runs_--;
183     last_run = (runs_ == 0);
184   }
185 
186   // Close the queue unless the coordinator is shutting down since the cancel op
187   // will be run anway in this case.
188   if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
189     if (last_run && !close_op_name_.empty()) {
190       UpdateStatus(RealRun(sess, close_op_name_, false));
191     }
192   } else if (!status.ok()) {
193     LOG(ERROR) << "Queue runner thread got a failure status: "
194                << status.ToString();
195     UpdateStatus(status);
196     if (coord_) {
197       coord_->RequestStop().IgnoreError();
198     }
199   }
200 }
201 
GetStatus()202 Status QueueRunner::GetStatus() {
203   mutex_lock l(mu_);
204   return status_;
205 }
206 
ExportCostGraph(CostGraphDef * cost_graph) const207 Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const {
208   if (!cg_mu_) {
209     return Status(error::FAILED_PRECONDITION,
210                   "This QueueRunner doesn't collect a cost graph.");
211   }
212   mutex_lock l(*cg_mu_);
213   cost_graph->MergeFrom(*cost_graph_);
214   return Status::OK();
215 }
216 
SetRunArgumentsAndCostGraph(const RunOptions & run_options)217 void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) {
218   cg_mu_.reset(new mutex());
219   {
220     mutex_lock l(*cg_mu_);
221     cost_graph_.reset(new CostGraphDef());
222   }
223   run_options_ = run_options;
224 }
225 
RealRun(Session * sess,const string & op,bool update_costs)226 Status QueueRunner::RealRun(Session* sess, const string& op,
227                             bool update_costs) {
228   Status s;
229   if (update_costs && cg_mu_) {
230     RunMetadata metadata;
231     s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
232     mutex_lock l(*cg_mu_);
233     cost_graph_->Swap(metadata.mutable_cost_graph());
234   } else {
235     s = sess->Run({}, {}, {op}, nullptr);
236   }
237   return s;
238 }
239 
240 }  // namespace tensorflow
241