1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ 16 #define TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ 17 18 #include <utility> 19 20 #include "absl/container/flat_hash_map.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor_key.h" 23 #include "tensorflow/core/framework/variant.h" 24 #include "tensorflow/core/framework/variant_tensor_data.h" 25 #include "tensorflow/core/lib/core/refcount.h" 26 27 namespace tensorflow { 28 29 // Variant compatible type for a map of tensors. This is mutable but instances 30 // should never be mutated after stored in a variant tensor. 31 // 32 // **NOTE**: TensorMap stores a refcounted container of tf::Tensor objects, 33 // which are accessible via TensorMap::tensors(). Because it is refcounted, 34 // straight copies of the form: 35 // 36 // TensorMap b = a; 37 // b.tensors().insert(k,v); // WARNING: This modifies a.tensors(). 38 // 39 // Do not create a true copy of the underlying container - but instead increment 40 // a reference count. Modifying b.tensors() modifies a.tensors(). In this way, 41 // TensorMap should be considered similar to the tf::Tensor object. 42 // 43 // In order to get a copy of the underlying map, use the Copy method: 44 // 45 // TensorMap b = a.Copy(); 46 // b.tensors().insert(k, v); // This does not modify a.tensors(). 47 // 48 // Note that this is not a deep copy: the memory locations of the underlying 49 // tensors will still point to the same locations of the corresponding tensors 50 // in the original. To truly perform a deep copy, Device and Type-specific 51 // code needs to be applied to the underlying tensors as usual. 52 // 53 // The most important implication of RefCounted TensorMaps is that OpKernels 54 // wishing to reuse TensorMap inputs as outputs via context->forward_input() 55 // need to perform an additional check on the refcount of the TensorList, 56 // to ensure aliasing can be performed safely. For example: 57 // 58 // bool can_alias = false; 59 // auto fw = c->forward_input(..., DT_VARIANT, {}, ...); 60 // if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) { 61 // auto* tl = fw->scalar<Variant>()().get<TensorMap>(); 62 // if (tl && tl->RefCountIsOne()) { 63 // can_alias = true; 64 // } 65 // } 66 // 67 class TensorMap { 68 public: TensorMap()69 TensorMap() : tensors_(new Tensors) {} 70 ~TensorMap(); 71 TensorMap(const TensorMap & other)72 TensorMap(const TensorMap& other) : tensors_(other.tensors_) { 73 tensors_->Ref(); 74 } 75 TensorMap(TensorMap && rhs)76 TensorMap(TensorMap&& rhs) : tensors_(rhs.tensors_) { 77 rhs.tensors_ = nullptr; 78 } 79 80 TensorMap& operator=(const TensorMap& rhs) { 81 if (this == &rhs) return *this; 82 tensors_->Unref(); 83 tensors_ = rhs.tensors_; 84 tensors_->Ref(); 85 return *this; 86 } 87 88 TensorMap& operator=(TensorMap&& rhs) { 89 if (this == &rhs) return *this; 90 std::swap(tensors_, rhs.tensors_); 91 return *this; 92 } 93 94 static const char kTypeName[]; 95 TypeName()96 string TypeName() const { return kTypeName; } 97 98 void Encode(VariantTensorData* data) const; 99 100 bool Decode(const VariantTensorData& data); 101 102 // TODO(apassos) fill this out DebugString()103 string DebugString() const { return "TensorMap"; } 104 105 // Access to the underlying tensor container. tensors()106 absl::flat_hash_map<TensorKey, Tensor>& tensors() { 107 return tensors_->values_; 108 } 109 tensors()110 const absl::flat_hash_map<TensorKey, Tensor>& tensors() const { 111 return tensors_->values_; 112 } 113 114 // Get a new TensorMap containing a copy of the underlying tensor container. Copy()115 TensorMap Copy() const { 116 TensorMap out; 117 // This performs a copy of the absl::hashmap. 118 out.tensors_->values_ = tensors_->values_; 119 return out; 120 } 121 122 // Insert key and value if the key does not already exist. 123 // Returns true if the insertion happens. insert(const TensorKey & key,const Tensor & value)124 bool insert(const TensorKey& key, const Tensor& value) { 125 auto r = tensors_->values_.try_emplace(key, value); 126 return r.second; 127 } 128 129 // Lookup given key. Returns iterator to found key or end. find(TensorKey key)130 absl::flat_hash_map<TensorKey, Tensor>::iterator find(TensorKey key) { 131 return tensors_->values_.find(key); 132 } 133 lookup(TensorKey key)134 Tensor& lookup(TensorKey key) { return tensors_->values_.find(key)->second; } 135 136 Tensor& operator[](TensorKey& k) { return tensors_->values_[k]; } 137 replace(const TensorKey & k,const Tensor & v)138 bool replace(const TensorKey& k, const Tensor& v) { 139 tensors_->values_[k] = v; 140 return true; 141 } 142 143 // Removes element with given key. Return size of removed element. erase(TensorKey key)144 size_t erase(TensorKey key) { return tensors_->values_.erase(key); } 145 146 // Size returns the number of elements in the map size()147 size_t size() const { return tensors_->values_.size(); } 148 keys()149 std::vector<Tensor> keys() const { 150 std::vector<Tensor> keys; 151 keys.reserve(tensors_->values_.size()); 152 absl::flat_hash_map<TensorKey, Tensor>::iterator it = 153 tensors_->values_.begin(); 154 while (it != tensors_->values_.end()) { 155 keys.push_back(it->first); 156 it++; 157 } 158 return keys; 159 } 160 161 // Is this TensorMap the only one with a reference to the underlying 162 // container? RefCountIsOne()163 bool RefCountIsOne() const { return tensors_->RefCountIsOne(); } 164 165 private: 166 class Tensors : public core::RefCounted { 167 public: 168 absl::flat_hash_map<TensorKey, Tensor> values_; 169 }; 170 Tensors* tensors_; 171 }; 172 173 #if defined(PLATFORM_GOOGLE) 174 // TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices. 175 // For 32-bit devices, it's acceptable not to inline. 176 static_assert(Variant::CanInlineType<TensorMap>() || sizeof(void*) < 8, 177 "Must be able to inline TensorMap into a Variant"); 178 #endif 179 } // namespace tensorflow 180 181 #endif // TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ 182