/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/padding_fifo_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

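// The base FIFOQueue stores fully-defined TensorShapes, so every unknown
// (-1) dimension in `partial_shapes` is first mapped to zero via
// ConvertShapesPartialDimensionsToZero (defined at the bottom of this file);
// the original partial shapes are kept in partial_shapes_ for padding at
// dequeue time. Illustrative mapping (hypothetical shapes, not from any
// caller in this file):
//
//   partial_shapes  = {[?, 2]}   // per-element shape, first dim unknown
//   FIFOQueue sees  = {[0, 2]}   // unknown dim stored as zero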
PaddingFIFOQueue::PaddingFIFOQueue(
    int capacity, const DataTypeVector& component_dtypes,
    const std::vector<PartialTensorShape>& partial_shapes, const string& name)
    : FIFOQueue(capacity, component_dtypes,
                ConvertShapesPartialDimensionsToZero(partial_shapes), name),
      partial_shapes_(partial_shapes) {}

Status PaddingFIFOQueue::Initialize() {
  Status s = FIFOQueue::Initialize();
  if (!s.ok()) return s;

  if (component_dtypes_.size() != partial_shapes_.size()) {
    return errors::InvalidArgument(
        "Shapes must be provided for all components, but received ",
        component_dtypes_.size(), " dtypes and ", partial_shapes_.size(),
        " shapes.");
  }

  return Status::OK();
}

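// GetElementComponent deep-copies one component of `tuple` into a freshly
// allocated PersistentTensor. TryDequeueMany uses it below to push the
// components of a partially-dequeued batch back onto the front of the queue
// when the queue is closed before a full batch is available.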
/* static */
Status PaddingFIFOQueue::GetElementComponent(
    const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx,
    PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  *element_access = tuple[component];
  return Status::OK();
}

void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                      bool allow_small_batch,
                                      CallbackWithTuple callback) {
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().
      // See similar comment in fifo_queue.cc
      Tensor element;
      // Here, ManyOutShape returns zeros for undetermined shapes,
      // which is exactly what we want to use.
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(component_dtypes_[i],
                                             ManyOutShape(i, 0), &element));
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

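  // Dequeue cancellation follows the standard kernel pattern: register a
  // callback with the step's CancellationManager; if RegisterCallback fails,
  // the step was already cancelled, so the attempt is never queued and the
  // caller is notified below with an empty Tuple.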
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if
      // possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuples.empty()) {
                // Restore already-dequeued elements to the front of the queue.
                for (int64 i = attempt->tuples.size() - 1; i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    PersistentTensor element;
                    Status s = GetElementComponent(attempt->tuples[i], j,
                                                   attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to PaddingFIFOQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_front(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuples.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some enqueue attempts containing
                  // values. If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "PaddingFIFOQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            RunResult result = kNoProgress;
            for (; queue_size > 0; --queue_size) {
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->tuples.push_back(tuple);
              tuple.clear();
              --attempt->elements_requested;

              if (attempt->elements_requested == 0) {
                // Finished. Allocate attempt->tuple and
                // copy from attempt->tuples to attempt->tuple.
                attempt->tuple.reserve(num_components());
                std::vector<Tuple>& tuples = attempt->tuples;

                std::vector<bool> dynamic_shape;
                const int64 batch_size = tuples.size();

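                // The output shape for component i is
                // [batch_size] + partial_shapes_[i], with every unknown (-1)
                // dimension replaced by the maximum of that dimension across
                // the batch. Worked example (hypothetical values): with
                // partial_shapes_[i] = [?, 2] and two elements shaped [3, 2]
                // and [5, 2], the padded output shape is [2, 5, 2].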
                for (int i = 0; i < num_components(); ++i) {
                  const PartialTensorShape partial_shape =
                      PartialTensorShape({batch_size})
                          .Concatenate(partial_shapes_[i]);
                  TensorShape shape({batch_size});

                  for (int j = 0; j < partial_shape.dims() - 1; ++j) {
                    if (partial_shape.dim_size(j + 1) > -1) {
                      shape.AddDim(partial_shape.dim_size(j + 1));
                    } else {
                      // Expand sizes to match.
                      int64 max_val = 0;
                      for (const Tuple& t : tuples) {
                        max_val = std::max(max_val, t[i].shape().dim_size(j));
                      }
                      shape.AddDim(max_val);
                    }
                  }

                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;

                  bool has_dynamic_shape = !partial_shape.IsFullyDefined();
                  if (has_dynamic_shape) {
                    // Set all values to zero because not all values
                    // will get written over.
                    attempt->context->SetStatus(SetElementZero(&element));
                    if (!attempt->context->status().ok()) return kComplete;
                  }

                  dynamic_shape.push_back(has_dynamic_shape);

                  // TODO(ebrevdo): should this be a persistent tensor?
                  attempt->tuple.emplace_back(element);
                }

                for (size_t index = 0; index < tuples.size(); ++index) {
                  for (int i = 0; i < num_components(); ++i) {
                    if (dynamic_shape[i]) {
                      // Slightly slower copy operation
                      attempt->context->SetStatus(CopyElementToLargerSlice(
                          tuples[index][i], &attempt->tuple[i], index));
                    } else {
                      attempt->context->SetStatus(
                          batch_util::CopyElementToSlice(
                              std::move(tuples[index][i]), &attempt->tuple[i],
                              index));
                    }
                    if (!attempt->context->status().ok()) return kComplete;
                  }
                }
                tuple = attempt->tuple;
                attempt->tuples.clear();
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}
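
// Illustrative caller sketch (hypothetical, not code from this file): a
// dequeue kernel passes a callback that receives the padded batch, or an
// empty Tuple on error or cancellation. `done` stands in for the kernel's
// async-completion callback.
//
//   queue->TryDequeueMany(
//       num_elements, ctx, /*allow_small_batch=*/false,
//       [ctx, done](const QueueInterface::Tuple& tuple) {
//         if (ctx->status().ok()) {
//           for (size_t i = 0; i < tuple.size(); ++i) {
//             ctx->set_output(i, tuple[i]);
//           }
//         }
//         done();
//       });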

Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  for (size_t i = 0; i < tuple.size(); ++i) {
    if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     partial_shapes_[i].DebugString(), ", got ",
                                     tuple[i].shape().DebugString());
    }
  }
  return Status::OK();
}

Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  const int64 batch_size = tuple[0].dim_size(0);
  for (size_t i = 0; i < tuple.size(); ++i) {
    // Expected shape is [batch_size] + partial_shapes_[i]
    const PartialTensorShape expected_shape =
        PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]);
    if (!expected_shape.IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     expected_shape.DebugString(), ", got ",
                                     tuple[i].shape().DebugString());
    }
  }
  return Status::OK();
}

Status PaddingFIFOQueue::CompatibleNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<PartialTensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  if (!PartialTensorShapeUtils::AreCompatible(requested_shapes,
                                              partial_shapes_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has component shapes ",
        PartialTensorShapeUtils::PartialShapeListString(partial_shapes_),
        " but requested component shapes were ",
        PartialTensorShapeUtils::PartialShapeListString(requested_shapes));
  } else {
    return Status::OK();
  }
}

Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) {
    return errors::InvalidArgument("Expected PaddingFIFOQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
  return Status::OK();
}

static Status ValidateElementToLargerSlice(const Tensor& element,
                                           Tensor* parent) {
  DCHECK_NE(parent->dim_size(0), 0);
  if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
    TensorShape chip_shape = parent->shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleElementToLargerSlice Cannot copy slice: number of entries in "
        "element is greater than number of elements in parent slice. ",
        "Shapes are: [element]: ", element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return Status::OK();
}

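// HandleElementToLargerSlice writes `element` (rank NDIMS) into slice
// `index` of `parent` (rank NDIMS + 1), leaving the rest of that slice
// untouched. Illustrative example (hypothetical shapes): copying a [3, 2]
// element at index 1 of a [4, 5, 2] parent fills parent[1, 0:3, 0:2] and
// leaves parent[1, 3:5, :] holding the zeros written by SetElementZero.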
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  Status s = ValidateElementToLargerSlice(element, parent);
  if (!s.ok()) {
    return s;
  }
  if (element.NumElements() == 0) {
    return Status::OK();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return Status::OK();
}

namespace {

template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
                                          int index) {
#define HANDLE_TYPE(T)                                                   \
  case DataTypeToEnum<T>::value: {                                       \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
  }

  switch (element.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          DataTypeString(element.dtype()));
  }
}

}  // namespace

Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element,
                                                  Tensor* parent, int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks. Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return Status::OK();                                                    \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}

// Static method
Status PaddingFIFOQueue::SetElementZero(Tensor* element) {
#define HANDLE_TYPE(T)                                \
  if (element->dtype() == DataTypeToEnum<T>::value) { \
    element->flat<T>().setConstant(T());              \
    return Status::OK();                              \
  }
  TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               DataTypeString(element->dtype()));
}

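// Maps every unknown (-1) dimension to zero so the shapes can be stored as
// fully-defined TensorShapes by the base FIFOQueue. Example (hypothetical
// shapes): {[?, 3], [2]} -> {[0, 3], [2]}.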
std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero(
    const gtl::ArraySlice<PartialTensorShape>& partial_shapes) {
  std::vector<TensorShape> shapes(partial_shapes.size());
  for (size_t i = 0; i < shapes.size(); ++i) {
    const PartialTensorShape& partial = partial_shapes[i];
    TensorShape& shape = shapes[i];
    for (int64 s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s);
  }
  return shapes;
}

}  // namespace tensorflow