/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <queue>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/priority_queue_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

PriorityQueue::PriorityQueue(int32 capacity,
                             const DataTypeVector& component_dtypes,
                             const std::vector<TensorShape>& component_shapes,
                             const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}

Status PriorityQueue::Initialize() {
  Status s = TypedQueue::Initialize();
  if (!s.ok()) return s;

  mutex_lock lock(mu_);
  if (component_dtypes_[0] != DT_INT64) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be type int64, but "
        "dtype is: ",
        DataTypeString(component_dtypes_[0]));
  }
  if (specified_shapes() && !TensorShapeUtils::IsScalar(component_shapes_[0])) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be a scalar, but shape "
        "is: ",
        component_shapes_[0].DebugString());
  }
  return Status::OK();
}

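// Pops the top entry of each component's underlying heap and appends the
// corresponding tensors to *tuple. Assumes mu_ is held by the caller and
// that the queue is non-empty.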
void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), 0);
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    PersistentTensor persistent_tensor = gtl::ConsumeTop(&queues_[i]).second;
    (*tuple).push_back(*persistent_tensor.AccessTensor(ctx));
  }
}

void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                               DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              if (!TensorShapeUtils::IsScalar(tuple[0].shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    tuple[0].shape().DebugString()));
                return kComplete;
              }
              const int64 priority = tuple[0].scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].emplace(priority, PersistentTensor(tuple[i]));
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

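// Copies the `index`-th slice of the batched `component` tensor in `tuple`
// into a newly allocated persistent tensor pointed to by *out_tensor.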
/* static */
Status PriorityQueue::GetElementComponentFromBatch(
    const PriorityQueue::Tuple& tuple, int index, int component,
    OpKernelContext* ctx, PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], element_access, index));
  return Status::OK();
}

void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                                   DoneCallback callback) {
  const int64 batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this,
           ctx](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
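              // attempt->elements_requested counts down from batch_size, so
              // this is the next unprocessed row of the input batch.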
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;

              PersistentTensor priority_element;
              attempt->context->SetStatus(GetElementComponentFromBatch(
                  tuple, index, 0, attempt->context, &priority_element));
              if (!attempt->context->status().ok()) return kComplete;
              Tensor* priority_tensor = priority_element.AccessTensor(ctx);
              if (!TensorShapeUtils::IsScalar(priority_tensor->shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    priority_tensor->shape().DebugString()));
                return kComplete;
              }
              const int64 priority = priority_tensor->scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                PersistentTensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].emplace(priority, element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

void PriorityQueue::TryDequeue(OpKernelContext* ctx,
                               CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            const int32 s = queues_[0].size();
            if (closed_ && s == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "PriorityQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ", s,
                  ")"));
              return kComplete;
            }
            if (s > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                   bool allow_small_batch,
                                   CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(
        errors::InvalidArgument("PriorityQueue's DequeueMany requires the "
                                "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().  Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx.  For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx?  (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status status = ctx->allocate_temp(component_dtypes_[i],
                                         ManyOutShape(i, 0), &element);
      if (!status.ok()) {
        ctx->SetStatus(status);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this, allow_small_batch](
              Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 s = queues_[0].size();
            // Return OutOfRange if closed and there are fewer elements
            // available than requested.  *Unless* allow_small_batch
            // is true, in which case we return as many elements as
            // possible.
            if (closed_) {
              if (s == 0 ||
                  (!allow_small_batch && s < attempt->elements_requested)) {
                attempt->context->SetStatus(errors::OutOfRange(
                    "PriorityQueue '", name_, "' is closed and has ",
                    "insufficient elements (requested ",
                    attempt->elements_requested, ", current size ", s, ")"));
                return kComplete;
              }
            }

            // The PriorityQueue is expected to always return a
            // sorted set of entries.  In order to do this, the underlying
            // queue must have at least this many entries already.
            // Doing the dynamic thing and pulling out a portion at a
            // time leads to unordered output in calls to DequeueMany.
            //
            // An alternative solution is to store the attempt tuple
            // entries in an identical priority_queue and push onto
            // this queue dynamically, then when it is full, do all
            // the Tensor concatenation at the very end.
            // TODO(ebrevdo): Change approach if this leads to locking issues.
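            //
            // For example, if elements with priorities {2, 9} are present
            // when a DequeueMany(3) arrives and an element with priority 5 is
            // enqueued later, handing out {2, 9} eagerly and appending 5
            // afterwards would produce the unsorted batch [2, 9, 5] instead
            // of [2, 5, 9].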
            if (s < attempt->elements_requested) {
              // If we have no elements at all, then wait.
              // Otherwise proceed if closed and allow small batch is true.
              // Otherwise wait until we have more enqueued elements.
              if (s == 0 || !(closed_ && allow_small_batch)) {
                return kNoProgress;
              }
            }

            RunResult result = kNoProgress;
            for (; s > 0; --s) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
    return errors::InvalidArgument("Expected PriorityQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
  return Status::OK();
}

Status PriorityQueue::MatchesPriorityNodeDefTypes(
    const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
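  // The first component is the implicit int64 priority, which is not part of
  // the node's component_types attr, so prepend it before comparing.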
  requested_dtypes.insert(requested_dtypes.begin(), DT_INT64);
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return Status::OK();
}

Status PriorityQueue::MatchesPriorityNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  requested_shapes.insert(requested_shapes.begin(), TensorShape({}));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return Status::OK();
}

}  // namespace tensorflow