/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace {

// Partial Ordering Comparator for Tensor keys containing scalar int64's
struct KeyTensorLess {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::less<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Key Equality operator for Tensor keys containing scalar int64's
struct KeyTensorEqual {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::equal_to<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Hash for Tensor keys containing scalar int64's
struct KeyTensorHash {
  std::size_t operator()(const Tensor& key) const {
    return std::hash<int64>{}(key.scalar<int64>()());
  }
};

57 
58 // Primary template.
59 template <bool Ordered, typename Data>
60 struct MapTraits;
61 
62 // Partial specialization for ordered.
63 template <typename Data>
64 struct MapTraits<true, Data> {
65   using KeyType = Tensor;
66   using DataType = Data;
67   using MapType = std::map<KeyType, Data, KeyTensorLess>;
68 };
69 
70 // Partial specialization for unordered.
71 template <typename Data>
72 struct MapTraits<false, Data> {
73   using KeyType = Tensor;
74   using DataType = Data;
75   using MapType =
76       std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
77 };
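
// For instance, StagingMap<true> below instantiates
// MapTraits<true, OptionalTuple>, whose MapType resolves to
// std::map<Tensor, OptionalTuple, KeyTensorLess>; the unordered variant
// resolves to std::unordered_map<Tensor, OptionalTuple, KeyTensorHash,
// KeyTensorEqual>.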

// Wrapper around map/unordered_map.
template <bool Ordered>
class StagingMap : public ResourceBase {
 public:
  // Public typedefs
  using Tuple = std::vector<Tensor>;
  using OptionalTensor = gtl::optional<Tensor>;
  using OptionalTuple = std::vector<OptionalTensor>;

  using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
  using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;

  using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;

 private:
  // Private variables
  DataTypeVector dtypes_ TF_GUARDED_BY(mu_);
  std::size_t capacity_ TF_GUARDED_BY(mu_);
  std::size_t memory_limit_ TF_GUARDED_BY(mu_);
  std::size_t current_bytes_ TF_GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  tensorflow::condition_variable not_empty_;
  tensorflow::condition_variable full_;
  IncompleteType incomplete_ TF_GUARDED_BY(mu_);
  MapType map_ TF_GUARDED_BY(mu_);

 private:
  // private methods

  // If map is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_capacity() || has_memory_limit()) {
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_.notify_all();
    }
  }

  // Notify all removers waiting to extract values
  // that data is now available
  void notify_removers() {
    // Notify all removers. This is because they are
    // waiting for specific keys to appear in the map
    // so we don't know which one to wake up.
    not_empty_.notify_all();
  }

  bool has_capacity() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return capacity_ > 0;
  }

  bool has_memory_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return memory_limit_ > 0;
  }

  bool would_exceed_memory_limit(std::size_t bytes) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
  }

  bool is_capacity_full() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_capacity() && map_.size() >= capacity_;
  }

  // Get number of bytes in the tuple
  std::size_t get_tuple_bytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  // Get number of bytes in the incomplete tuple
  std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
    return std::accumulate(
        tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
        [](const std::size_t& lhs, const OptionalTensor& rhs) {
          return lhs + (rhs.has_value() ? rhs.value().TotalBytes() : 0);
        });
  }

  // Check that the index is within bounds
  Status check_index(const Tensor& key, std::size_t index)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (index >= dtypes_.size()) {
      return Status(errors::InvalidArgument(
          "Index '", index, "' for key '", key.scalar<int64>()(),
          "' was out of bounds '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

  Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
                              const Tensor& indices, Tuple* output,
                              bool copy = false)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Return values at specified indices
    for (std::size_t i = 0; i < findices.dimension(0); ++i) {
      std::size_t index = findices(i);

      TF_RETURN_IF_ERROR(check_index(key, index));

      // Insist on a value present at the specified index
      if (!(*map_tuple)[index].has_value()) {
        return Status(errors::InvalidArgument(
            "Tensor at index '", index, "' for key '", key.scalar<int64>()(),
            "' has already been removed."));
      }

      // Copy the contained tensor and
      // remove from the OptionalTuple
      output->push_back((*map_tuple)[index].value());

      // Clear out the entry if we're not copying (moving)
      if (!copy) {
        (*map_tuple)[index].reset();
      }
    }

    return Status::OK();
  }

  // Check that the optional value at the specified index
  // is uninitialized
  Status check_index_uninitialized(const Tensor& key, std::size_t index,
                                   const OptionalTuple& tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (tuple[index].has_value()) {
      return Status(errors::InvalidArgument(
          "The tensor for index '", index, "' for key '", key.scalar<int64>()(),
          "' was already initialized '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

  // Check that the indices are strictly ordered
  Status check_index_ordering(const Tensor& indices) {
    auto findices = indices.flat<int>();

    for (int i = 0; i < findices.dimension(0) - 1; ++i) {
      if (findices(i) < findices(i + 1)) {
        continue;
      }

      return Status(
          errors::InvalidArgument("Indices are not strictly ordered"));
    }

    return Status::OK();
  }

  // Check that bytes are within the memory limit
  Status check_memory_limit(std::size_t bytes)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_memory_limit() && bytes > memory_limit_) {
      return Status(errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'."));
    }

    return Status::OK();
  }

  // Insert incomplete data into the Barrier
  Status put_incomplete(const KeyType& key, const Tensor& indices,
                        OptionalTuple* tuple, tensorflow::mutex_lock* lock)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Search for the key in our incomplete set
    auto it = incomplete_.find(key);

    // Check that the tuple fits within the memory limit
    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until we don't exceed the memory limit
    while (would_exceed_memory_limit(tuple_bytes)) {
      full_.wait(*lock);
    }

    // This key isn't present in the incomplete set
    // Create OptionalTuple and insert
    if (it == incomplete_.end()) {
      OptionalTuple empty(dtypes_.size());

      // Initialize empty tuple with given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));

        // Assign tuple at this index
        empty[index] = std::move((*tuple)[i]);
      }

      // Insert into incomplete map
      incomplete_.insert({key, std::move(empty)});

      // Increment size
      current_bytes_ += tuple_bytes;
    }
    // Found an entry in the incomplete index
    // Update with given data and insert complete entries
    // into the main map
    else {
      // Reference existing incomplete tuple
      OptionalTuple& present = it->second;

      // Assign given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));
        TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));

        // Assign tuple at this index
        present[index] = std::move((*tuple)[i]);
      }

      // Increment size
      current_bytes_ += tuple_bytes;

      // Do we have values at all tuple elements?
      bool complete =
          std::all_of(present.begin(), present.end(),
                      [](const OptionalTensor& v) { return v.has_value(); });

      // If so, put the tuple in the actual map
      if (complete) {
        OptionalTuple insert_tuple = std::move(it->second);

        // Remove from incomplete
        incomplete_.erase(it);

        TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
      }
    }

    return Status::OK();
  }

  // Does the insertion into the actual staging area
  Status put_complete(const KeyType& key, OptionalTuple* tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Insert key and tuples into the map
    map_.insert({key, std::move(*tuple)});

    notify_removers();

    return Status::OK();
  }

 public:
  // public methods
  explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
                      std::size_t memory_limit)
      : dtypes_(dtypes),
        capacity_(capacity),
        memory_limit_(memory_limit),
        current_bytes_(0) {}

  Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Handle incomplete inserts
    if (indices->NumElements() != dtypes_.size()) {
      return put_incomplete(*key, *indices, tuple, &lock);
    }

    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    // Check that tuple_bytes fits within the memory limit
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until there's space for insertion.
    while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
      full_.wait(lock);
    }

    // Do the put operation
    TF_RETURN_IF_ERROR(put_complete(*key, tuple));

    // Update the current size
    current_bytes_ += tuple_bytes;

    return Status::OK();
  }

  Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple, true));

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    return Status::OK();
  }

  Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Wait until map is not empty
    while (this->map_.empty()) {
      not_empty_.wait(lock);
    }

    // Move from the first element and erase it

    auto it = map_.begin();

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    *key = it->first;

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status clear() {
    tensorflow::mutex_lock lock(mu_);
    map_.clear();
    incomplete_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded();

    return Status::OK();
  }

  std::size_t incomplete_size() {
    tensorflow::mutex_lock lock(mu_);
    return incomplete_.size();
  }

  std::size_t size() {
    tensorflow::mutex_lock lock(mu_);
    return map_.size();
  }

  string DebugString() const override { return "StagingMap"; }
};
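
// A minimal usage sketch of StagingMap (illustrative comment only, not
// compiled as part of this file). It assumes a two-element dtypes vector and
// placeholder tensors value_a / value_b that match those dtypes:
//
//   StagingMap<false> staging({DT_FLOAT, DT_INT32}, /*capacity=*/0,
//                             /*memory_limit=*/0);
//   Tensor key(DT_INT64, TensorShape({}));
//   key.scalar<int64>()() = 42;
//   Tensor indices(DT_INT32, TensorShape({2}));
//   indices.vec<int>()(0) = 0;
//   indices.vec<int>()(1) = 1;
//   StagingMap<false>::OptionalTuple values = {value_a, value_b};
//   TF_CHECK_OK(staging.put(&key, &indices, &values));  // stage the tuple
//   StagingMap<false>::Tuple out;
//   TF_CHECK_OK(staging.pop(&key, &indices, &out));     // consume it again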

template <bool Ordered>
Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
                     StagingMap<Ordered>** map) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
    DataTypeVector dtypes;
    int64 capacity;
    int64 memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
      cinfo.container(), cinfo.name(), map, create_fn));
  return Status::OK();
}
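
// Note: LookupOrCreate keys the resource on (container, name), so every map
// op below that resolves to the same container and shared name operates on
// the same StagingMap instance; the first such op creates it from the node's
// dtypes/capacity/memory_limit attrs.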

template <bool Ordered>
class MapStageOp : public OpKernel {
 public:
  explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::OptionalTuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));

    // Create copy for insertion into Staging Area
    Tensor key(*key_tensor);

    // Create the tuple to store
    for (std::size_t i = 0; i < values_tensor.size(); ++i) {
      tuple.push_back(values_tensor[i]);
    }

    // Store the tuple in the map
    OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("MapStage").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapStageOp<true>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <bool Ordered>
class MapUnstageOp : public OpKernel {
 public:
  explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<true>);
#endif

template <bool Ordered>
class MapPeekOp : public OpKernel {
 public:
  explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapPeekOp<true>);
#endif

template <bool Ordered>
class MapUnstageNoKeyOp : public OpKernel {
 public:
  explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Pop a random (key, value) off the map
    typename StagingMap<Ordered>::KeyType key;
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));

    // Allocate a key tensor and assign the key as the first output
    ctx->set_output(0, key);

    // Set the rest of the outputs to the tuple Tensors
    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i + 1, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<true>);
#endif

template <bool Ordered>
class MapSizeOp : public OpKernel {
 public:
  explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_GPU).HostMemory("size"),
    MapSizeOp<true>);
#endif

template <bool Ordered>
class MapIncompleteSizeOp : public OpKernel {
 public:
  explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->incomplete_size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif

template <bool Ordered>
class MapClearOp : public OpKernel {
 public:
  explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    OP_REQUIRES_OK(ctx, map->clear());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_GPU),
                        MapClearOp<true>);
#endif

}  // namespace
}  // namespace tensorflow