1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/queue_base.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_shape.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/platform/mutex.h"
23 #include "tensorflow/core/platform/types.h"
24 #include "tensorflow/core/util/batch_util.h"
25
26 namespace tensorflow {
27
28 namespace {
29
30 template <DataType DT>
HandleSliceToElement(const Tensor & parent,Tensor * element,int64 index)31 Status HandleSliceToElement(const Tensor& parent, Tensor* element,
32 int64 index) {
33 typedef typename EnumToDataType<DT>::Type T;
34 DCHECK_NE(parent.dim_size(0), 0);
35 DCHECK_GE(index, 0);
36 if (element->NumElements() != (parent.NumElements() / parent.dim_size(0))) {
37 TensorShape chip_shape = parent.shape();
38 chip_shape.RemoveDim(0);
39 return errors::Internal(
40 "HandleSliceToElement Cannot copy slice: number of elements does not "
41 "match. Shapes are: [element]: ",
42 element->shape().DebugString(),
43 ", [parent slice]: ", chip_shape.DebugString());
44 }
45 auto parent_as_matrix = parent.flat_outer_dims<T>();
46 element->flat<T>() = parent_as_matrix.chip(index, 0);
47 return Status::OK();
48 }
49
50 } // namespace
51
QueueBase(int32 capacity,const DataTypeVector & component_dtypes,const std::vector<TensorShape> & component_shapes,const string & name)52 QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
53 const std::vector<TensorShape>& component_shapes,
54 const string& name)
55 : capacity_(capacity),
56 component_dtypes_(component_dtypes),
57 component_shapes_(component_shapes),
58 name_(name),
59 closed_(false) {}
60
~QueueBase()61 QueueBase::~QueueBase() {}
62
ValidateTupleCommon(const Tuple & tuple) const63 Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
64 if (tuple.size() != static_cast<size_t>(num_components())) {
65 return errors::InvalidArgument(
66 "Wrong number of components in tuple. Expected ", num_components(),
67 ", got ", tuple.size());
68 }
69 for (size_t i = 0; i < tuple.size(); ++i) {
70 if (tuple[i].dtype() != component_dtypes_[i]) {
71 return errors::InvalidArgument(
72 "Type mismatch in tuple component ", i, ". Expected ",
73 DataTypeString(component_dtypes_[i]), ", got ",
74 DataTypeString(tuple[i].dtype()));
75 }
76 }
77 return Status::OK();
78 }
79
80 // static
ShapeListString(const gtl::ArraySlice<TensorShape> & shapes)81 string QueueBase::ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
82 string result = "[";
83 bool first = true;
84 for (const TensorShape& shape : shapes) {
85 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
86 first = false;
87 }
88 strings::StrAppend(&result, "]");
89 return result;
90 }
91
MatchesNodeDefOp(const NodeDef & node_def,const string & op) const92 Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def,
93 const string& op) const {
94 if (node_def.op() != op) {
95 return errors::InvalidArgument("Shared queue '", name_, "' has type '", op,
96 "' that does not match type of Node '",
97 node_def.name(), "': ", node_def.op());
98 }
99 return Status::OK();
100 }
101
MatchesNodeDefCapacity(const NodeDef & node_def,int32 capacity) const102 Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def,
103 int32 capacity) const {
104 int32 requested_capacity = -1;
105 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity));
106 if (requested_capacity < 0) requested_capacity = kUnbounded;
107 if (requested_capacity != capacity) {
108 return errors::InvalidArgument("Shared queue '", name_, "' has capacity ",
109 capacity, " but requested capacity was ",
110 requested_capacity);
111 }
112 return Status::OK();
113 }
114
MatchesNodeDefTypes(const NodeDef & node_def) const115 Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const {
116 DataTypeVector requested_dtypes;
117 TF_RETURN_IF_ERROR(
118 GetNodeAttr(node_def, "component_types", &requested_dtypes));
119 if (requested_dtypes != component_dtypes_) {
120 return errors::InvalidArgument("Shared queue '", name_,
121 "' has component types ",
122 DataTypeSliceString(component_dtypes_),
123 " but requested component types were ",
124 DataTypeSliceString(requested_dtypes));
125 }
126 return Status::OK();
127 }
128
MatchesNodeDefShapes(const NodeDef & node_def) const129 Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
130 std::vector<TensorShape> requested_shapes;
131 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
132 if (requested_shapes != component_shapes_) {
133 return errors::InvalidArgument("Shared queue '", name_,
134 "' has component shapes ",
135 ShapeListString(component_shapes_),
136 " but requested component shapes were ",
137 ShapeListString(requested_shapes));
138 }
139 return Status::OK();
140 }
141
142 // TODO(mrry): If these checks become a bottleneck, find a way to
143 // reduce the number of times that they are called.
ValidateTuple(const Tuple & tuple)144 Status QueueBase::ValidateTuple(const Tuple& tuple) {
145 TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
146 if (specified_shapes()) {
147 for (size_t i = 0; i < tuple.size(); ++i) {
148 if (!component_shapes_[i].IsSameSize(tuple[i].shape())) {
149 return errors::InvalidArgument(
150 "Shape mismatch in tuple component ", i, ". Expected ",
151 component_shapes_[i].DebugString(), ", got ",
152 tuple[i].shape().DebugString());
153 }
154 }
155 }
156 return Status::OK();
157 }
158
159 // TODO(mrry): If these checks become a bottleneck, find a way to
160 // reduce the number of times that they are called.
ValidateManyTuple(const Tuple & tuple)161 Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
162 TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
163 const int64 batch_size = tuple[0].dim_size(0);
164 if (specified_shapes()) {
165 for (size_t i = 0; i < tuple.size(); ++i) {
166 // Expected shape is [batch_size] + component_shapes_[i]
167 const TensorShape expected_shape = ManyOutShape(i, batch_size);
168 if (!expected_shape.IsSameSize(tuple[i].shape())) {
169 return errors::InvalidArgument("Shape mismatch in tuple component ", i,
170 ". Expected ",
171 expected_shape.DebugString(), ", got ",
172 tuple[i].shape().DebugString());
173 }
174 }
175 } else {
176 for (size_t i = 1; i < tuple.size(); ++i) {
177 if (tuple[i].dim_size(0) != batch_size) {
178 return errors::InvalidArgument(
179 "All input tensors must have the same size in the 0th ",
180 "dimension. Component ", i, " has ", tuple[i].dim_size(0),
181 ", and should have ", batch_size);
182 }
183 }
184 }
185 return Status::OK();
186 }
187
Cancel(Action action,CancellationManager * cancellation_manager,CancellationToken token)188 void QueueBase::Cancel(Action action, CancellationManager* cancellation_manager,
189 CancellationToken token) {
190 DoneCallback callback = nullptr;
191 {
192 mutex_lock lock(mu_);
193 std::deque<Attempt>* attempts =
194 action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
195
196 for (Attempt& attempt : *attempts) {
197 if (attempt.cancellation_manager == cancellation_manager &&
198 attempt.cancellation_token == token) {
199 if (!attempt.is_cancelled) {
200 attempt.is_cancelled = true;
201 if (action == kEnqueue) {
202 attempt.context->SetStatus(
203 errors::Cancelled("Enqueue operation was cancelled"));
204 } else {
205 attempt.context->SetStatus(
206 errors::Cancelled("Dequeue operation was cancelled"));
207 }
208 std::swap(callback, attempt.done_callback);
209 }
210 break;
211 }
212 }
213 }
214 if (callback) {
215 callback();
216 FlushUnlocked();
217 }
218 }
219
CloseAndCancel()220 void QueueBase::CloseAndCancel() {
221 std::vector<DoneCallback> callbacks;
222 {
223 mutex_lock lock(mu_);
224 closed_ = true;
225 for (Attempt& attempt : enqueue_attempts_) {
226 if (!attempt.is_cancelled) {
227 attempt.is_cancelled = true;
228 attempt.context->SetStatus(
229 errors::Cancelled("Enqueue operation was cancelled"));
230 callbacks.emplace_back(std::move(attempt.done_callback));
231 }
232 }
233 }
234 for (const DoneCallback& callback : callbacks) {
235 callback();
236 }
237 FlushUnlocked();
238 }
239
Close(OpKernelContext * ctx,bool cancel_pending_enqueues,DoneCallback callback)240 void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
241 DoneCallback callback) {
242 if (cancel_pending_enqueues) {
243 CloseAndCancel();
244 callback();
245 } else {
246 {
247 mutex_lock lock(mu_);
248 enqueue_attempts_.emplace_back(
249 0, callback, ctx, nullptr, CancellationManager::kInvalidToken,
250 [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
251 if (closed_) {
252 attempt->context->SetStatus(
253 errors::Cancelled("Queue '", name_, "' is already closed."));
254 } else {
255 closed_ = true;
256 }
257 return kComplete;
258 });
259 }
260 FlushUnlocked();
261 }
262 }
263
TryAttemptLocked(Action action,std::vector<CleanUp> * clean_up)264 bool QueueBase::TryAttemptLocked(Action action,
265 std::vector<CleanUp>* clean_up) {
266 std::deque<Attempt>* attempts =
267 action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
268
269 bool progress = false;
270 bool done = false;
271 while (!done && !attempts->empty()) {
272 if (attempts->front().is_cancelled) {
273 if (action == kEnqueue) {
274 if (closed_) {
275 VLOG(1) << "Skipping cancelled enqueue attempt";
276 } else {
277 LOG(WARNING)
278 << name_
279 << ": Skipping cancelled enqueue attempt with queue not closed";
280 }
281 } else {
282 if (closed_) {
283 VLOG(1) << "Skipping cancelled dequeue attempt";
284 } else {
285 LOG(WARNING)
286 << name_
287 << ": Skipping cancelled dequeue attempt with queue not closed";
288 }
289 }
290 attempts->pop_front();
291 } else {
292 Attempt* cur_attempt = &attempts->front();
293 switch (cur_attempt->run_callback(cur_attempt)) {
294 case kNoProgress:
295 done = true;
296 break;
297 case kProgress:
298 done = true;
299 progress = true;
300 break;
301 case kComplete:
302 progress = true;
303 clean_up->emplace_back(std::move(cur_attempt->done_callback),
304 cur_attempt->cancellation_token,
305 cur_attempt->context->cancellation_manager());
306 attempts->pop_front();
307 break;
308 }
309 }
310 }
311 return progress;
312 }
313
FlushUnlocked()314 void QueueBase::FlushUnlocked() {
315 std::vector<CleanUp> clean_up;
316 Ref();
317 {
318 mutex_lock lock(mu_);
319 bool changed;
320 do {
321 changed = TryAttemptLocked(kEnqueue, &clean_up);
322 changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
323 } while (changed);
324 }
325 Unref();
326 for (const auto& to_clean : clean_up) {
327 if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
328 // NOTE(mrry): We can safely ignore the return value of
329 // DeregisterCallback because the mutex mu_ ensures that the
330 // cleanup action only executes once.
331 to_clean.cm->DeregisterCallback(to_clean.to_deregister);
332 }
333 to_clean.finished();
334 }
335 }
336
CopySliceToElement(const Tensor & parent,Tensor * element,int64 index)337 Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
338 int64 index) {
339 return batch_util::CopySliceToElement(parent, element, index);
340 }
341
342 /* static */
CopyElementToSlice(const Tensor & element,Tensor * parent,int64 index)343 Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
344 int64 index) {
345 return batch_util::CopyElementToSlice(element, parent, index);
346 }
347
348 } // namespace tensorflow
349