• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/data_flow_ops.cc.
17 
18 #include <deque>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/kernels/padding_fifo_queue.h"
27 #include "tensorflow/core/kernels/queue_base.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/batch_util.h"
33 
34 namespace tensorflow {
35 
PaddingFIFOQueue(int capacity,const DataTypeVector & component_dtypes,const std::vector<PartialTensorShape> & partial_shapes,const string & name)36 PaddingFIFOQueue::PaddingFIFOQueue(
37     int capacity, const DataTypeVector& component_dtypes,
38     const std::vector<PartialTensorShape>& partial_shapes, const string& name)
39     : FIFOQueue(capacity, component_dtypes,
40                 ConvertShapesPartialDimensionsToZero(partial_shapes), name),
41       partial_shapes_(partial_shapes) {}
42 
Initialize()43 Status PaddingFIFOQueue::Initialize() {
44   Status s = FIFOQueue::Initialize();
45   if (!s.ok()) return s;
46 
47   if (component_dtypes_.size() != partial_shapes_.size()) {
48     return errors::InvalidArgument(
49         "Shapes must be provided for all components, but received ",
50         component_dtypes_.size(), " dtypes and ", partial_shapes_.size(),
51         " shapes.");
52   }
53 
54   return Status::OK();
55 }
56 
57 /* static */
GetElementComponent(const PaddingFIFOQueue::Tuple & tuple,int component,OpKernelContext * ctx,PersistentTensor * out_tensor)58 Status PaddingFIFOQueue::GetElementComponent(
59     const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx,
60     PersistentTensor* out_tensor) {
61   TensorShape element_shape(tuple[component].shape());
62   Tensor* element_access = nullptr;
63   TF_RETURN_IF_ERROR(ctx->allocate_persistent(
64       tuple[component].dtype(), element_shape, out_tensor, &element_access));
65   *element_access = tuple[component];
66   return Status::OK();
67 }
68 
TryDequeueMany(int num_elements,OpKernelContext * ctx,bool allow_small_batch,CallbackWithTuple callback)69 void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
70                                       bool allow_small_batch,
71                                       CallbackWithTuple callback) {
72   if (num_elements == 0) {
73     Tuple tuple;
74     tuple.reserve(num_components());
75     for (int i = 0; i < num_components(); ++i) {
76       // TODO(josh11b,misard): Switch to allocate_output().
77       // See similar comment in fifo_queue.cc
78       Tensor element;
79       // Here, ManyOutShape returns zeros for undetermined shapes,
80       // which is exactly what we want to use.
81       OP_REQUIRES_OK(ctx, ctx->allocate_temp(component_dtypes_[i],
82                                              ManyOutShape(i, 0), &element));
83       tuple.emplace_back(element);
84     }
85     callback(tuple);
86     return;
87   }
88 
89   CancellationManager* cm = ctx->cancellation_manager();
90   CancellationToken token = cm->get_cancellation_token();
91   bool already_cancelled;
92   {
93     mutex_lock l(mu_);
94     already_cancelled = !cm->RegisterCallback(
95         token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
96     if (!already_cancelled) {
97       // TODO(josh11b): This makes two copies of callback, avoid this if possible.
98       dequeue_attempts_.emplace_back(
99           num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
100           [callback, allow_small_batch,
101            this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
102             int32 queue_size = queues_[0].size();
103             if (closed_ && queue_size < attempt->elements_requested) {
104               // If we don't have enough for a full dequeue, we have
105               // to reset the attempt tuple.
106               if (!attempt->tuples.empty()) {
107                 // Restore already-dequeued elements to the front of the queue.
108                 for (int64 i = attempt->tuples.size() - 1; i >= 0; --i) {
109                   for (int j = 0; j < num_components(); ++j) {
110                     PersistentTensor element;
111                     Status s = GetElementComponent(attempt->tuples[i], j,
112                                                    attempt->context, &element);
113                     if (!s.ok()) {
114                       attempt->context->SetStatus(
115                           errors::DataLoss("Failed to restore element from "
116                                            "partially-dequeued batch "
117                                            "to PaddingFIFOQueue: ",
118                                            s.error_message()));
119                     }
120                     queues_[j].push_front(element);
121                   }
122                 }
123               }
124               if (allow_small_batch && !queues_[0].empty()) {
125                 // Request all remaining elements in the queue.
126                 queue_size = queues_[0].size();
127                 attempt->tuples.clear();
128                 attempt->elements_requested = queue_size;
129               } else {
130                 if (allow_small_batch) {
131                   // There may be some enqueue attempts containing
132                   // values.  If so, we'll yield and wait for them
133                   // to add elements to the queue.
134                   if (!enqueue_attempts_.empty()) return kProgress;
135                 }
136                 if (attempt->context->status().ok()) {
137                   attempt->context->SetStatus(errors::OutOfRange(
138                       "PaddingFIFOQueue '", name_, "' is closed and has ",
139                       "insufficient elements (requested ",
140                       attempt->elements_requested, ", current size ",
141                       queue_size, ")"));
142                 }
143                 return kComplete;
144               }
145             }
146 
147             RunResult result = kNoProgress;
148             for (; queue_size > 0; --queue_size) {
149               result = kProgress;
150               Tuple tuple;
151               DequeueLocked(attempt->context, &tuple);
152               attempt->tuples.push_back(tuple);
153               tuple.clear();
154               --attempt->elements_requested;
155 
156               if (attempt->elements_requested == 0) {
157                 // Finished.  Allocate attempt->tuple and
158                 // copy from attempt->tuples to attempt->tuple.
159                 attempt->tuple.reserve(num_components());
160                 std::vector<Tuple>& tuples = attempt->tuples;
161 
162                 std::vector<bool> dynamic_shape;
163                 const int64 batch_size = tuples.size();
164 
165                 for (int i = 0; i < num_components(); ++i) {
166                   const PartialTensorShape partial_shape =
167                       PartialTensorShape({batch_size})
168                           .Concatenate(partial_shapes_[i]);
169                   TensorShape shape({batch_size});
170 
171                   for (int j = 0; j < partial_shape.dims() - 1; ++j) {
172                     if (partial_shape.dim_size(j + 1) > -1) {
173                       shape.AddDim(partial_shape.dim_size(j + 1));
174                     } else {
175                       // Expand sizes to match.
176                       int64 max_val = 0;
177                       for (const Tuple& t : tuples) {
178                         max_val = std::max(max_val, t[i].shape().dim_size(j));
179                       }
180                       shape.AddDim(max_val);
181                     }
182                   }
183 
184                   Tensor element;
185                   attempt->context->SetStatus(attempt->context->allocate_temp(
186                       component_dtypes_[i], shape, &element));
187                   if (!attempt->context->status().ok()) return kComplete;
188 
189                   bool has_dynamic_shape = !partial_shape.IsFullyDefined();
190                   if (has_dynamic_shape) {
191                     // Set all values to zero because not all values
192                     // will get written over.
193                     attempt->context->SetStatus(SetElementZero(&element));
194                     if (!attempt->context->status().ok()) return kComplete;
195                   }
196 
197                   dynamic_shape.push_back(has_dynamic_shape);
198 
199                   // TODO(ebrevdo): should this be a persistent tensor?
200                   attempt->tuple.emplace_back(element);
201                 }
202 
203                 for (size_t index = 0; index < tuples.size(); ++index) {
204                   for (int i = 0; i < num_components(); ++i) {
205                     if (dynamic_shape[i]) {
206                       // Slightly slower copy operation
207                       attempt->context->SetStatus(CopyElementToLargerSlice(
208                           tuples[index][i], &attempt->tuple[i], index));
209                     } else {
210                       attempt->context->SetStatus(
211                           batch_util::CopyElementToSlice(
212                               std::move(tuples[index][i]), &attempt->tuple[i],
213                               index));
214                     }
215                     if (!attempt->context->status().ok()) return kComplete;
216                   }
217                 }
218                 tuple = attempt->tuple;
219                 attempt->tuples.clear();
220                 attempt->done_callback = [callback, tuple]() {
221                   callback(tuple);
222                 };
223                 return kComplete;
224               }
225             }
226             return result;
227           });
228     }
229   }
230   if (!already_cancelled) {
231     FlushUnlocked();
232   } else {
233     ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
234     callback(Tuple());
235   }
236 }
237 
ValidateTuple(const Tuple & tuple)238 Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) {
239   TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
240   for (size_t i = 0; i < tuple.size(); ++i) {
241     if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) {
242       return errors::InvalidArgument("Shape mismatch in tuple component ", i,
243                                      ". Expected ",
244                                      partial_shapes_[i].DebugString(), ", got ",
245                                      tuple[i].shape().DebugString());
246     }
247   }
248   return Status::OK();
249 }
250 
ValidateManyTuple(const Tuple & tuple)251 Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) {
252   TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
253   const int64 batch_size = tuple[0].dim_size(0);
254   for (size_t i = 0; i < tuple.size(); ++i) {
255     // Expected shape is [batch_size] + partial_shapes_[i]
256     const PartialTensorShape expected_shape =
257         PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]);
258     if (!expected_shape.IsCompatibleWith(tuple[i].shape())) {
259       return errors::InvalidArgument("Shape mismatch in tuple component ", i,
260                                      ". Expected ",
261                                      expected_shape.DebugString(), ", got ",
262                                      tuple[i].shape().DebugString());
263     }
264   }
265   return Status::OK();
266 }
267 
CompatibleNodeDefShapes(const NodeDef & node_def) const268 Status PaddingFIFOQueue::CompatibleNodeDefShapes(
269     const NodeDef& node_def) const {
270   std::vector<PartialTensorShape> requested_shapes;
271   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
272   if (!PartialTensorShapeUtils::AreCompatible(requested_shapes,
273                                               partial_shapes_)) {
274     return errors::InvalidArgument(
275         "Shared queue '", name_, "' has component shapes ",
276         PartialTensorShapeUtils::PartialShapeListString(partial_shapes_),
277         " but requested component shapes were ",
278         PartialTensorShapeUtils::PartialShapeListString(requested_shapes));
279   } else {
280     return Status::OK();
281   }
282 }
283 
MatchesNodeDef(const NodeDef & node_def)284 Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
285   if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() &&
286       !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) {
287     return errors::InvalidArgument("Expected PaddingFIFOQueue, found ",
288                                    node_def.op());
289   }
290   TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
291   TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
292   TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
293   return Status::OK();
294 }
295 
ValidateElementToLargerSlice(const Tensor & element,Tensor * parent)296 static Status ValidateElementToLargerSlice(const Tensor& element,
297                                            Tensor* parent) {
298   DCHECK_NE(parent->dim_size(0), 0);
299   if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
300     TensorShape chip_shape = parent->shape();
301     chip_shape.RemoveDim(0);
302     return errors::Internal(
303         "HandleElementToLargerSlice Cannot copy slice: number of entries in "
304         "element is greater than number of elements in parent slice.  ",
305         "Shapes are: [element]: ", element.shape().DebugString(),
306         ", [parent slice]: ", chip_shape.DebugString());
307   }
308   return Status::OK();
309 }
310 
311 template <typename T, int NDIMS>
HandleElementToLargerSlice(const Tensor & element,Tensor * parent,int index)312 Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
313                                   int index) {
314   Status s = ValidateElementToLargerSlice(element, parent);
315   if (!s.ok()) {
316     return s;
317   }
318   if (element.NumElements() == 0) {
319     return Status::OK();
320   }
321   auto element_t = element.tensor<T, NDIMS>();
322   auto parent_t = parent->tensor<T, NDIMS + 1>();
323   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
324   slice_indices[0] = index;
325   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
326   slice_size[0] = 1;
327   for (size_t i = 1; i < slice_size.size(); ++i) {
328     slice_size[i] = element_t.dimension(i - 1);
329   }
330   parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
331   return Status::OK();
332 }
333 
334 namespace {
335 
336 template <int NDIMS>
HandleElementToLargerSliceWithRank(const Tensor & element,Tensor * parent,int index)337 Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
338                                           int index) {
339 #define HANDLE_TYPE(T)                                                   \
340   case DataTypeToEnum<T>::value: {                                       \
341     return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
342   }
343 
344   switch (element.dtype()) {
345     TF_CALL_ALL_TYPES(HANDLE_TYPE);
346 #undef HANDLE_TYPE
347     default:
348       return errors::Unimplemented(
349           "HandleElementToLargerSliceWithRank Unhandled data type: ",
350           DataTypeString(element.dtype()));
351   }
352 }
353 
354 }  // namespace
355 
CopyElementToLargerSlice(const Tensor & element,Tensor * parent,int index)356 Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element,
357                                                   Tensor* parent, int index) {
358   if (parent->dims() != element.dims() + 1) {
359     return errors::Internal(
360         "Mismatched ranks.  Element's rank is: ", element.dims(),
361         " but element is meant to be a slice in output Tensor having rank: ",
362         parent->dims(), " (should be: ", element.dims() + 1, ")");
363   }
364 
365 #define HANDLE_DIMS(NDIMS)                                                  \
366   case NDIMS: {                                                             \
367     TF_RETURN_IF_ERROR(                                                     \
368         HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
369     return Status::OK();                                                    \
370   }
371 
372   switch (element.dims()) {
373     HANDLE_DIMS(0);
374     HANDLE_DIMS(1);
375     HANDLE_DIMS(2);
376     HANDLE_DIMS(3);
377     HANDLE_DIMS(4);
378 #undef HANDLE_DIMS
379     default:
380       return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
381                                    element.dims());
382   }
383 }
384 
385 // Static method
SetElementZero(Tensor * element)386 Status PaddingFIFOQueue::SetElementZero(Tensor* element) {
387 #define HANDLE_TYPE(T)                                \
388   if (element->dtype() == DataTypeToEnum<T>::value) { \
389     element->flat<T>().setConstant(T());              \
390     return Status::OK();                              \
391   }
392   TF_CALL_ALL_TYPES(HANDLE_TYPE);
393 #undef HANDLE_TYPE
394   return errors::Unimplemented("SetElementZero Unhandled data type: ",
395                                DataTypeString(element->dtype()));
396 }
397 
ConvertShapesPartialDimensionsToZero(const gtl::ArraySlice<PartialTensorShape> & partial_shapes)398 std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero(
399     const gtl::ArraySlice<PartialTensorShape>& partial_shapes) {
400   std::vector<TensorShape> shapes(partial_shapes.size());
401   for (size_t i = 0; i < shapes.size(); ++i) {
402     const PartialTensorShape& partial = partial_shapes[i];
403     TensorShape& shape = shapes[i];
404     for (int64 s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s);
405   }
406   return shapes;
407 }
408 
409 }  // namespace tensorflow
410