/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <queue>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/priority_queue_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

PriorityQueue::PriorityQueue(int32 capacity,
                             const DataTypeVector& component_dtypes,
                             const std::vector<TensorShape>& component_shapes,
                             const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}

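// Validates the queue configuration: component 0 carries the element priority
// and must have dtype int64 and, if shapes are specified, be a scalar.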
Status PriorityQueue::Initialize() {
  Status s = TypedQueue::Initialize();
  if (!s.ok()) return s;

  mutex_lock lock(mu_);
  if (component_dtypes_[0] != DT_INT64) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be type int64, but "
        "dtype is: ",
        DataTypeString(component_dtypes_[0]));
  }
  if (specified_shapes() && !TensorShapeUtils::IsScalar(component_shapes_[0])) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be a scalar, but shape "
        "is: ",
        component_shapes_[0].DebugString());
  }
  return Status::OK();
}

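// Pops the top entry of each per-component priority heap into *tuple.
// Callers must hold mu_ and ensure the queue is non-empty.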
void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), 0);
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    PersistentTensor persistent_tensor = gtl::ConsumeTop(&queues_[i]).second;
    (*tuple).push_back(*persistent_tensor.AccessTensor(ctx));
  }
}

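// Registers a cancellable attempt to enqueue a single element; component 0 of
// `tuple` must be a scalar int64 and supplies the element's priority.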
void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                               DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              if (!TensorShapeUtils::IsScalar(tuple[0].shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    tuple[0].shape().DebugString()));
                return kComplete;
              }
              const int64 priority = tuple[0].scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].emplace(priority, PersistentTensor(tuple[i]));
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

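// Copies element `index` of the batched component `component` out of `tuple`
// into a freshly allocated persistent tensor, with the batch dimension removed.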
/* static */
Status PriorityQueue::GetElementComponentFromBatch(
    const PriorityQueue::Tuple& tuple, int index, int component,
    OpKernelContext* ctx, PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], element_access, index));
  return Status::OK();
}

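// Registers a cancellable attempt to enqueue a batch of elements: row i of
// every component is enqueued as one element, with its priority taken from
// the corresponding row of component 0.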
void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                                   DoneCallback callback) {
  const int64 batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this,
           ctx](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;

              PersistentTensor priority_element;
              attempt->context->SetStatus(GetElementComponentFromBatch(
                  tuple, index, 0, attempt->context, &priority_element));
              if (!attempt->context->status().ok()) return kComplete;
              Tensor* priority_tensor = priority_element.AccessTensor(ctx);
              if (!TensorShapeUtils::IsScalar(priority_tensor->shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    priority_tensor->shape().DebugString()));
                return kComplete;
              }
              const int64 priority = priority_tensor->scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                PersistentTensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].emplace(priority, element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

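// Registers a cancellable attempt to dequeue the next element in priority
// order; the callback receives the dequeued tuple.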
void PriorityQueue::TryDequeue(OpKernelContext* ctx,
                               CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            const int32 s = queues_[0].size();
            if (closed_ && s == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "PriorityQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ", s,
                  ")"));
              return kComplete;
            }
            if (s > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

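// Registers a cancellable attempt to dequeue a batch of `num_elements`
// elements sorted by priority. The attempt waits until enough elements are
// present to fill the batch, unless the queue is closed and
// `allow_small_batch` permits returning fewer elements.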
void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                   bool allow_small_batch,
                                   CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(
        errors::InvalidArgument("PriorityQueue's DequeueMany requires the "
                                "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output(). Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx. For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx? (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status status = ctx->allocate_temp(component_dtypes_[i],
                                         ManyOutShape(i, 0), &element);
      if (!status.ok()) {
        ctx->SetStatus(status);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this, allow_small_batch](
              Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 s = queues_[0].size();
            // Return OutOfRange if closed and there are fewer elements
            // available than requested. *Unless* allow_small_batch
            // is true, in which case we return as many elements as
            // possible.
            if (closed_) {
              if (s == 0 ||
                  (!allow_small_batch && s < attempt->elements_requested)) {
                attempt->context->SetStatus(errors::OutOfRange(
                    "PriorityQueue '", name_, "' is closed and has ",
                    "insufficient elements (requested ",
                    attempt->elements_requested, ", current size ", s, ")"));
                return kComplete;
              }
            }

            // The PriorityQueue is expected to always return a
            // sorted set of entries. In order to do this, the underlying
            // queue must have at least this many entries already.
            // Doing the dynamic thing and pulling out a portion at a
            // time leads to unordered output in calls to DequeueMany.
            //
            // An alternative solution is to store the attempt tuple
            // entries in an identical priority_queue and push onto
            // this queue dynamically, then when it is full, do all
            // the Tensor concatenation at the very end.
            // TODO(ebrevdo): Change approach if this leads to locking issues.
            if (s < attempt->elements_requested) {
              // If we have no elements at all, then wait.
              // Otherwise proceed if closed and allow small batch is true.
              // Otherwise wait until we have more enqueued elements.
              if (s == 0 || !(closed_ && allow_small_batch)) {
                return kNoProgress;
              }
            }

            RunResult result = kNoProgress;
            for (; s > 0; --s) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

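// Checks that `node_def` describes a queue compatible with this one: the op
// must be PriorityQueue or PriorityQueueV2 with matching capacity, component
// types, and shapes.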
Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
    return errors::InvalidArgument("Expected PriorityQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
  return Status::OK();
}

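// The node's `component_types` attr does not include the implicit priority
// component, so DT_INT64 is prepended before comparing against this queue's
// component types.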
Status PriorityQueue::MatchesPriorityNodeDefTypes(
    const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  requested_dtypes.insert(requested_dtypes.begin(), DT_INT64);
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return Status::OK();
}

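// Likewise, a scalar shape for the implicit priority component is prepended
// before comparing against the node's `shapes` attr.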
Status PriorityQueue::MatchesPriorityNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  requested_shapes.insert(requested_shapes.begin(), TensorShape({}));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return Status::OK();
}

}  // namespace tensorflow