• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_
16 #define TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_
17 
18 #include <utility>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_key.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/framework/variant_tensor_data.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 
27 namespace tensorflow {
28 
29 // Variant-compatible type for a map of tensors. This is mutable, but instances
30 // should never be mutated after being stored in a variant tensor.
31 //
32 // **NOTE**: TensorMap stores a refcounted container of tf::Tensor objects,
33 // which are accessible via TensorMap::tensors().  Because it is refcounted,
34 // straight copies of the form:
35 //
36 //    TensorMap b = a;
37 //    b.tensors().insert(k,v);  // WARNING: This modifies a.tensors().
38 //
39 // do not create a true copy of the underlying container, but instead increment
40 // a reference count.  Modifying b.tensors() modifies a.tensors().  In this way,
41 // TensorMap should be considered similar to the tf::Tensor object.
42 //
43 // In order to get a copy of the underlying map, use the Copy method:
44 //
45 //    TensorMap b = a.Copy();
46 //    b.tensors().insert(k, v);  // This does not modify a.tensors().
47 //
48 // Note that this is not a deep copy: the memory locations of the underlying
49 // tensors will still point to the same locations of the corresponding tensors
50 // in the original.  To truly perform a deep copy, Device and Type-specific
51 // code needs to be applied to the underlying tensors as usual.
52 //
53 // The most important implication of RefCounted TensorMaps is that OpKernels
54 // wishing to reuse TensorMap inputs as outputs via context->forward_input()
55 // need to perform an additional check on the refcount of the TensorMap,
56 // to ensure aliasing can be performed safely.  For example:
57 //
58 //     bool can_alias = false;
59 //     auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
60 //     if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
61 //       auto* tl = fw->scalar<Variant>()().get<TensorMap>();
62 //       if (tl && tl->RefCountIsOne()) {
63 //         can_alias = true;
64 //       }
65 //     }
66 //
67 class TensorMap {
68  public:
TensorMap()69   TensorMap() : tensors_(new Tensors) {}
70   ~TensorMap();
71 
TensorMap(const TensorMap & other)72   TensorMap(const TensorMap& other) : tensors_(other.tensors_) {
73     tensors_->Ref();
74   }
75 
TensorMap(TensorMap && rhs)76   TensorMap(TensorMap&& rhs) : tensors_(rhs.tensors_) {
77     rhs.tensors_ = nullptr;
78   }
79 
80   TensorMap& operator=(const TensorMap& rhs) {
81     if (this == &rhs) return *this;
82     tensors_->Unref();
83     tensors_ = rhs.tensors_;
84     tensors_->Ref();
85     return *this;
86   }
87 
88   TensorMap& operator=(TensorMap&& rhs) {
89     if (this == &rhs) return *this;
90     std::swap(tensors_, rhs.tensors_);
91     return *this;
92   }
93 
94   static const char kTypeName[];
95 
TypeName()96   string TypeName() const { return kTypeName; }
97 
98   void Encode(VariantTensorData* data) const;
99 
100   bool Decode(const VariantTensorData& data);
101 
102   // TODO(apassos) fill this out
DebugString()103   string DebugString() const { return "TensorMap"; }
104 
105   // Access to the underlying tensor container.
tensors()106   absl::flat_hash_map<TensorKey, Tensor>& tensors() {
107     return tensors_->values_;
108   }
109 
tensors()110   const absl::flat_hash_map<TensorKey, Tensor>& tensors() const {
111     return tensors_->values_;
112   }
113 
114   // Get a new TensorMap containing a copy of the underlying tensor container.
Copy()115   TensorMap Copy() const {
116     TensorMap out;
117     // This performs a copy of the absl::hashmap.
118     out.tensors_->values_ = tensors_->values_;
119     return out;
120   }
121 
122   // Insert key and value if the key does not already exist.
123   // Returns true if the insertion happens.
insert(const TensorKey & key,const Tensor & value)124   bool insert(const TensorKey& key, const Tensor& value) {
125     auto r = tensors_->values_.try_emplace(key, value);
126     return r.second;
127   }
128 
129   // Lookup given key. Returns iterator to found key or end.
find(TensorKey key)130   absl::flat_hash_map<TensorKey, Tensor>::iterator find(TensorKey key) {
131     return tensors_->values_.find(key);
132   }
133 
lookup(TensorKey key)134   Tensor& lookup(TensorKey key) { return tensors_->values_.find(key)->second; }
135 
136   Tensor& operator[](TensorKey& k) { return tensors_->values_[k]; }
137 
replace(const TensorKey & k,const Tensor & v)138   bool replace(const TensorKey& k, const Tensor& v) {
139     tensors_->values_[k] = v;
140     return true;
141   }
142 
143   // Removes element with given key. Return size of removed element.
erase(TensorKey key)144   size_t erase(TensorKey key) { return tensors_->values_.erase(key); }
145 
146   // Size returns the number of elements in the map
size()147   size_t size() const { return tensors_->values_.size(); }
148 
keys()149   std::vector<Tensor> keys() const {
150     std::vector<Tensor> keys;
151     keys.reserve(tensors_->values_.size());
152     absl::flat_hash_map<TensorKey, Tensor>::iterator it =
153         tensors_->values_.begin();
154     while (it != tensors_->values_.end()) {
155       keys.push_back(it->first);
156       it++;
157     }
158     return keys;
159   }
160 
161   // Is this TensorMap the only one with a reference to the underlying
162   // container?
RefCountIsOne()163   bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
164 
165  private:
166   class Tensors : public core::RefCounted {
167    public:
168     absl::flat_hash_map<TensorKey, Tensor> values_;
169   };
170   Tensors* tensors_;
171 };
172 
#ifdef PLATFORM_GOOGLE
// On 64-bit Google platforms a TensorMap (whose only data member is a single
// pointer) is required to fit in a Variant's inline storage. Platforms with
// pointers narrower than 8 bytes are exempt and may store it out of line.
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
static_assert(sizeof(void*) < 8 || Variant::CanInlineType<TensorMap>(),
              "Must be able to inline TensorMap into a Variant");
#endif  // PLATFORM_GOOGLE
179 }  // namespace tensorflow
180 
181 #endif  // TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_
182