/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>  // for std::all_of / std::any_of used below
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace {

// Partial Ordering Comparator for Tensor keys containing scalar int64's
struct KeyTensorLess {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::less<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Key Equality operator for Tensor keys containing scalar int64's
struct KeyTensorEqual {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::equal_to<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Hash for Tensor keys containing scalar int64's
struct KeyTensorHash {
  std::size_t operator()(const Tensor& key) const {
    return std::hash<int64>{}(key.scalar<int64>()());
  }
};
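
// Illustrative sketch (not part of the kernel implementation): these functors
// are what let scalar-int64 Tensors act as keys in the standard containers
// used below. The container value types and key value are assumptions chosen
// purely for illustration.
//
//   std::map<Tensor, int, KeyTensorLess> ordered_by_key;
//   std::unordered_map<Tensor, int, KeyTensorHash, KeyTensorEqual> hashed;
//
//   Tensor key(DT_INT64, TensorShape({}));
//   key.scalar<int64>()() = 42;
//   ordered_by_key[key] = 1;  // ordered via KeyTensorLess
//   hashed[key] = 1;          // hashed via KeyTensorHash,
//                             // compared via KeyTensorEqual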

// Primary template.
template <bool Ordered, typename Data>
struct MapTraits;

// Partial specialization for ordered.
template <typename Data>
struct MapTraits<true, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType = std::map<KeyType, Data, KeyTensorLess>;
};

// Partial specialization for unordered.
template <typename Data>
struct MapTraits<false, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType =
      std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
};

// Wrapper around map/unordered_map.
template <bool Ordered>
class StagingMap : public ResourceBase {
 public:
  // Public typedefs
  using Tuple = std::vector<Tensor>;
  using OptionalTensor = gtl::optional<Tensor>;
  using OptionalTuple = std::vector<OptionalTensor>;

  using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
  using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;

  using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;

 private:
  // Private variables
  DataTypeVector dtypes_ TF_GUARDED_BY(mu_);
  std::size_t capacity_ TF_GUARDED_BY(mu_);
  std::size_t memory_limit_ TF_GUARDED_BY(mu_);
  std::size_t current_bytes_ TF_GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  tensorflow::condition_variable not_empty_;
  tensorflow::condition_variable full_;
  IncompleteType incomplete_ TF_GUARDED_BY(mu_);
  MapType map_ TF_GUARDED_BY(mu_);

 private:
  // private methods

  // If map is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_capacity() || has_memory_limit()) {
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_.notify_all();
    }
  }

  // Notify all removers waiting to extract values
  // that data is now available
  void notify_removers() {
    // Notify all removers. This is because they are
    // waiting for specific keys to appear in the map
    // so we don't know which one to wake up.
    not_empty_.notify_all();
  }

  bool has_capacity() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return capacity_ > 0;
  }

  bool has_memory_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return memory_limit_ > 0;
  }

  bool would_exceed_memory_limit(std::size_t bytes) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
  }

  bool is_capacity_full() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_capacity() && map_.size() >= capacity_;
  }

  // Get number of bytes in the tuple
  std::size_t get_tuple_bytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  // Get number of bytes in the incomplete tuple
  std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
    return std::accumulate(
        tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
        [](const std::size_t& lhs, const OptionalTensor& rhs) {
          // Only tensors that are actually present contribute bytes.
          return lhs + (rhs.has_value() ? rhs.value().TotalBytes() : 0);
        });
  }

  // Check that the index is within bounds
  Status check_index(const Tensor& key, std::size_t index)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (index >= dtypes_.size()) {
      return Status(errors::InvalidArgument(
          "Index '", index, "' for key '", key.scalar<int64>()(),
          "' was out of bounds '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

  Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
                              const Tensor& indices, Tuple* output,
                              bool copy = false)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Return values at specified indices
    for (std::size_t i = 0; i < findices.dimension(0); ++i) {
      std::size_t index = findices(i);

      TF_RETURN_IF_ERROR(check_index(key, index));

      // Insist on a value present at the specified index
      if (!(*map_tuple)[index].has_value()) {
        return Status(errors::InvalidArgument(
            "Tensor at index '", index, "' for key '", key.scalar<int64>()(),
            "' has already been removed."));
      }

      // Copy the contained tensor and
      // remove from the OptionalTuple
      output->push_back((*map_tuple)[index].value());

      // Clear out the entry if we're not copying (moving)
      if (!copy) {
        (*map_tuple)[index].reset();
      }
    }

    return Status::OK();
  }

  // Check that the optional value at the specified index
  // is uninitialized
  Status check_index_uninitialized(const Tensor& key, std::size_t index,
                                   const OptionalTuple& tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (tuple[index].has_value()) {
      return Status(errors::InvalidArgument(
          "The tensor for index '", index, "' for key '", key.scalar<int64>()(),
          "' was already initialized '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

  // Check that the indices are strictly ordered
  Status check_index_ordering(const Tensor& indices) {
    auto findices = indices.flat<int>();

    // Nothing to compare for an empty indices tensor; without this guard
    // the unsigned loop bound below would wrap around.
    if (findices.dimension(0) == 0) {
      return Status::OK();
    }

    for (std::size_t i = 0; i < findices.dimension(0) - 1; ++i) {
      if (findices(i) < findices(i + 1)) {
        continue;
      }

      return Status(
          errors::InvalidArgument("Indices are not strictly ordered"));
    }

    return Status::OK();
  }

  // Check that the given number of bytes fits within the memory limit
  Status check_memory_limit(std::size_t bytes)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_memory_limit() && bytes > memory_limit_) {
      return Status(errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'."));
    }

    return Status::OK();
  }

  // Insert incomplete data into the Barrier
  Status put_incomplete(const KeyType& key, const Tensor& indices,
                        OptionalTuple* tuple, tensorflow::mutex_lock* lock)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Search for the key in our incomplete set
    auto it = incomplete_.find(key);

    // Check that the tuple fits within the memory limit
    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until we don't exceed the memory limit
    while (would_exceed_memory_limit(tuple_bytes)) {
      full_.wait(*lock);
    }

    // This key isn't present in the incomplete set
    // Create OptionalTuple and insert
    if (it == incomplete_.end()) {
      OptionalTuple empty(dtypes_.size());

      // Initialize empty tuple with given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));

        // Assign tuple at this index
        empty[index] = std::move((*tuple)[i]);
      }

      // Insert into incomplete map
      incomplete_.insert({key, std::move(empty)});

      // Increment size
      current_bytes_ += tuple_bytes;
    }
    // Found an entry in the incomplete index
    // Update with given data and insert complete entries
    // into the main map
    else {
      // Reference existing incomplete tuple
      OptionalTuple& present = it->second;

      // Assign given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));
        TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));

        // Assign tuple at this index
        present[index] = std::move((*tuple)[i]);
      }

      // Increment size
      current_bytes_ += tuple_bytes;

      // Do we have values at all tuple elements?
      bool complete =
          std::all_of(present.begin(), present.end(),
                      [](const OptionalTensor& v) { return v.has_value(); });

      // If so, put the tuple in the actual map
      if (complete) {
        OptionalTuple insert_tuple = std::move(it->second);

        // Remove from incomplete
        incomplete_.erase(it);

        TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
      }
    }

    return Status::OK();
  }

  // Does the insertion into the actual staging area
  Status put_complete(const KeyType& key, OptionalTuple* tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Insert key and tuples into the map
    map_.insert({key, std::move(*tuple)});

    notify_removers();

    return Status::OK();
  }

 public:
  // public methods
  explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
                      std::size_t memory_limit)
      : dtypes_(dtypes),
        capacity_(capacity),
        memory_limit_(memory_limit),
        current_bytes_(0) {}

  Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Handle incomplete inserts
    if (indices->NumElements() != dtypes_.size()) {
      return put_incomplete(*key, *indices, tuple, &lock);
    }

    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    // Check that tuple_bytes fits within the memory limit
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until there's space for insertion.
    while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
      full_.wait(lock);
    }

    // Do the put operation
    TF_RETURN_IF_ERROR(put_complete(*key, tuple));

    // Update the current size
    current_bytes_ += tuple_bytes;

    return Status::OK();
  }

  Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple, true));

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    return Status::OK();
  }

  Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Wait until map is not empty
    while (this->map_.empty()) {
      not_empty_.wait(lock);
    }

    // Move from the first element and erase it

    auto it = map_.begin();

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    *key = it->first;

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status clear() {
    tensorflow::mutex_lock lock(mu_);
    map_.clear();
    incomplete_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded();

    return Status::OK();
  }

  std::size_t incomplete_size() {
    tensorflow::mutex_lock lock(mu_);
    return incomplete_.size();
  }

  std::size_t size() {
    tensorflow::mutex_lock lock(mu_);
    return map_.size();
  }

  string DebugString() const override { return "StagingMap"; }
};
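
// Illustrative sketch (not part of the kernel implementation): the lifecycle
// of a StagingMap as driven by the ops below. The dtypes, key, and values
// here are assumptions chosen for illustration only.
//
//   StagingMap<false> staging({DT_FLOAT, DT_FLOAT}, /*capacity=*/0,
//                             /*memory_limit=*/0);
//
//   Tensor key(DT_INT64, TensorShape({}));
//   key.scalar<int64>()() = 7;
//
//   // Stage both tuple elements at once (indices [0, 1]).
//   Tensor indices(DT_INT32, TensorShape({2}));
//   indices.flat<int>()(0) = 0;
//   indices.flat<int>()(1) = 1;
//   Tensor v0(DT_FLOAT, TensorShape({})), v1(DT_FLOAT, TensorShape({}));
//   v0.scalar<float>()() = 1.0f;
//   v1.scalar<float>()() = 2.0f;
//   StagingMap<false>::OptionalTuple values{v0, v1};
//   TF_CHECK_OK(staging.put(&key, &indices, &values));
//
//   // get() copies the values out; pop() moves them and erases the entry
//   // once every element has been consumed.
//   StagingMap<false>::Tuple peeked, popped;
//   TF_CHECK_OK(staging.get(&key, &indices, &peeked));
//   TF_CHECK_OK(staging.pop(&key, &indices, &popped));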

template <bool Ordered>
Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
                     StagingMap<Ordered>** map) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
    DataTypeVector dtypes;
    int64 capacity;
    int64 memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
      cinfo.container(), cinfo.name(), map, create_fn));
  return Status::OK();
}
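
// Illustrative sketch (not part of the kernel implementation): LookupOrCreate
// returns a new reference to a StagingMap shared across kernels with the same
// container/name, so every caller must release that reference. The kernels
// below do this with core::ScopedUnref; an assumed hand-rolled equivalent:
//
//   StagingMap<false>* map = nullptr;
//   TF_RETURN_IF_ERROR(GetStagingMap(ctx, ndef, &map));
//   // ... use map->put() / map->get() / map->pop() ...
//   map->Unref();  // what core::ScopedUnref does on scope exit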

template <bool Ordered>
class MapStageOp : public OpKernel {
 public:
  explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::OptionalTuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));

    // Create copy for insertion into Staging Area
    Tensor key(*key_tensor);

    // Create the tuple to store
    for (std::size_t i = 0; i < values_tensor.size(); ++i) {
      tuple.push_back(values_tensor[i]);
    }

    // Store the tuple in the map
    OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("MapStage").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapStageOp<true>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <bool Ordered>
class MapUnstageOp : public OpKernel {
 public:
  explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<true>);
#endif

template <bool Ordered>
class MapPeekOp : public OpKernel {
 public:
  explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapPeekOp<true>);
#endif

template <bool Ordered>
class MapUnstageNoKeyOp : public OpKernel {
 public:
  explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Pop a random (key, value) off the map
    typename StagingMap<Ordered>::KeyType key;
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));

    // Set the key as the first output
    ctx->set_output(0, key);

    // Set the rest of the outputs to the tuple Tensors
    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i + 1, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<true>);
#endif

template <bool Ordered>
class MapSizeOp : public OpKernel {
 public:
  explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_GPU).HostMemory("size"),
    MapSizeOp<true>);
#endif

template <bool Ordered>
class MapIncompleteSizeOp : public OpKernel {
 public:
  explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->incomplete_size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif

template <bool Ordered>
class MapClearOp : public OpKernel {
 public:
  explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    OP_REQUIRES_OK(ctx, map->clear());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_GPU),
                        MapClearOp<true>);
#endif

}  // namespace
}  // namespace tensorflow
801