/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <queue>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/priority_queue_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

PriorityQueue::PriorityQueue(int32 capacity,
                             const DataTypeVector& component_dtypes,
                             const std::vector<TensorShape>& component_shapes,
                             const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}
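
// Validates the queue configuration beyond what TypedQueue::Initialize()
// checks: component 0 holds the priority and must be of type DT_INT64, and
// must be a scalar when shapes are specified.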
Status PriorityQueue::Initialize() {
  Status s = TypedQueue::Initialize();
  if (!s.ok()) return s;

  mutex_lock lock(mu_);
  if (component_dtypes_[0] != DT_INT64) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be type int64, but "
        "dtype is: ",
        DataTypeString(component_dtypes_[0]));
  }
  if (specified_shapes() && !TensorShapeUtils::IsScalar(component_shapes_[0])) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be a scalar, but shape "
        "is: ",
        component_shapes_[0].DebugString());
  }
  return Status::OK();
}
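
// Pops the top entry of every component queue into *tuple. Caller must hold
// mu_ and the queue must be non-empty; entries are consumed via
// gtl::ConsumeTop(), so ordering follows the comparator used by the
// underlying priority queues (see priority_queue.h).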
void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), 0);
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    PersistentTensor persistent_tensor = gtl::ConsumeTop(&queues_[i]).second;
    (*tuple).push_back(*persistent_tensor.AccessTensor(ctx));
  }
}
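
// Attempts to enqueue a single element. tuple[0] must be a scalar int64
// priority; every component is pushed with that priority. The work is
// registered as an enqueue attempt (with cancellation support) and later
// run under mu_ by FlushUnlocked(), so `callback` may fire asynchronously.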
void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                               DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              if (!TensorShapeUtils::IsScalar(tuple[0].shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    tuple[0].shape().DebugString()));
                return kComplete;
              }
              const int64 priority = tuple[0].scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].emplace(priority, PersistentTensor(tuple[i]));
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}
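
// Copies slice `index` of `component` from a batched tuple into a newly
// allocated PersistentTensor (the leading batch dimension is dropped).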
/* static */
Status PriorityQueue::GetElementComponentFromBatch(
    const PriorityQueue::Tuple& tuple, int index, int component,
    OpKernelContext* ctx, PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], element_access, index));
  return Status::OK();
}
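
// Attempts to enqueue a batch of elements. tuple[0] holds one int64 priority
// per row; rows are split out with GetElementComponentFromBatch() and pushed
// one at a time, making progress whenever capacity is available.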
void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                                   DoneCallback callback) {
  const int64 batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this, ctx](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;

              PersistentTensor priority_element;
              attempt->context->SetStatus(GetElementComponentFromBatch(
                  tuple, index, 0, attempt->context, &priority_element));
              if (!attempt->context->status().ok()) return kComplete;
              Tensor* priority_tensor = priority_element.AccessTensor(ctx);
              if (!TensorShapeUtils::IsScalar(priority_tensor->shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    priority_tensor->shape().DebugString()));
                return kComplete;
              }
              const int64 priority = priority_tensor->scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                PersistentTensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].emplace(priority, element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}
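
// Attempts to dequeue a single element, returning it to `callback` as a
// tuple whose component 0 is the stored int64 priority. Fails with
// OutOfRange if the queue is closed and empty.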
void PriorityQueue::TryDequeue(OpKernelContext* ctx,
                               CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            const int32 s = queues_[0].size();
            if (closed_ && s == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "PriorityQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ", s,
                  ")"));
              return kComplete;
            }
            if (s > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}
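
// Attempts to dequeue `num_elements` elements into a single batched tuple.
// Requires fully specified shapes. To keep the output sorted, no element is
// copied out until the underlying queue holds enough entries (or the queue
// is closed and allow_small_batch permits a partial batch).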
void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                   bool allow_small_batch,
                                   CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(
        errors::InvalidArgument("PriorityQueue's DequeueMany requires the "
                                "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output(). Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx. For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx? (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status status = ctx->allocate_temp(component_dtypes_[i],
                                         ManyOutShape(i, 0), &element);
      if (!status.ok()) {
        ctx->SetStatus(status);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this,
           allow_small_batch](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 s = queues_[0].size();
            // Return OutOfRange if closed and there are fewer elements
            // available than requested. *Unless* allow_small_batch
            // is true, in which case we return as many elements as
            // possible.
            if (closed_) {
              if (s == 0 ||
                  (!allow_small_batch && s < attempt->elements_requested)) {
                attempt->context->SetStatus(errors::OutOfRange(
                    "PriorityQueue '", name_, "' is closed and has ",
                    "insufficient elements (requested ",
                    attempt->elements_requested, ", current size ", s, ")"));
                return kComplete;
              }
            }

            // The PriorityQueue is expected to always return a
            // sorted set of entries. In order to do this, the underlying
            // queue must have at least this many entries already.
            // Doing the dynamic thing and pulling out a portion at a
            // time leads to unordered output in calls to DequeueMany.
            //
            // An alternative solution is to store the attempt tuple
            // entries in an identical priority_queue and push onto
            // this queue dynamically, then when it is full, do all
            // the Tensor concatenation at the very end.
            // TODO(ebrevdo): Change approach if this leads to locking issues.
            if (s < attempt->elements_requested) {
              // If we have no elements at all, then wait.
              // Otherwise proceed if closed and allow small batch is true.
              // Otherwise wait until we have more enqueued elements.
              if (s == 0 || !(closed_ && allow_small_batch)) {
                return kNoProgress;
              }
            }

            RunResult result = kNoProgress;
            for (; s > 0; --s) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}
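
// Checks that `node_def` describes a compatible PriorityQueue(V2) op:
// same op type, capacity, component types, and shapes as this queue.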
Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
    return errors::InvalidArgument("Expected PriorityQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
  return Status::OK();
}
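
// The user-visible 'component_types' attr excludes the implicit int64
// priority component, so DT_INT64 is prepended before comparing against
// component_dtypes_. For example, a queue created with
// component_types=[DT_STRING] stores component_dtypes_ =
// [DT_INT64, DT_STRING].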
Status PriorityQueue::MatchesPriorityNodeDefTypes(
    const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  requested_dtypes.insert(requested_dtypes.begin(), DT_INT64);
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return Status::OK();
}
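
// Likewise, the 'shapes' attr excludes the scalar priority component, so a
// scalar TensorShape is prepended before comparing against component_shapes_.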
Status PriorityQueue::MatchesPriorityNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  requested_shapes.insert(requested_shapes.begin(), TensorShape({}));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return Status::OK();
}

}  // namespace tensorflow