1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ 16 #define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ 17 18 #include <utility> 19 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/variant.h" 22 #include "tensorflow/core/framework/variant_tensor_data.h" 23 #include "tensorflow/core/lib/core/refcount.h" 24 25 namespace tensorflow { 26 27 // Variant compatible type for a list of tensors. This is mutable but instances 28 // should never be mutated after stored in a variant tensor. 29 // 30 // **NOTE**: TensorList stores a refcounted container of tf::Tensor objects, 31 // which are accessible via TensorList::tensors(). Because it is refcounted, 32 // straight copies of the form: 33 // 34 // TensorList b = a; 35 // b.tensors().push_back(t); // WARNING: This modifies a.tensors(). 36 // 37 // Do not create a true copy of the underlying container - but instead increment 38 // a reference count. Modifying b.tensors() modifies a.tensors(). In this way, 39 // TensorList should be considered similar to the tf::Tensor object. 40 // 41 // In order to get a copy of the underlying list, use the Copy method: 42 // 43 // TensorList b = a.Copy(); 44 // b.tensors().push_back(t); // This does not modify a.tensors(). 45 // 46 // Note that this is not a deep copy: the memory locations of the underlying 47 // tensors will still point to the same locations of the corresponding tensors 48 // in the original. To truly perform a deep copy, Device and Type-specific 49 // code needs to be applied to the underlying tensors as usual. 50 // 51 // The most important implication of RefCounted TLs is that OpKernels 52 // wishing to reuse TensorList inputs as outputs via context->forward_input() 53 // need to perform an additional check on the refcount of the TensorList, 54 // to ensure aliasing can be performed safely. For example: 55 // 56 // bool can_alias = false; 57 // auto fw = c->forward_input(..., DT_VARIANT, {}, ...); 58 // if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) { 59 // auto* tl = fw->scalar<Variant>()().get<TensorList>(); 60 // if (tl && tl->RefCountIsOne()) { 61 // can_alias = true; 62 // } 63 // } 64 // 65 class TensorList { 66 public: TensorList()67 TensorList() : tensors_(new Tensors) {} 68 ~TensorList(); 69 TensorList(const TensorList & other)70 TensorList(const TensorList& other) 71 : element_shape(other.element_shape), 72 element_dtype(other.element_dtype), 73 max_num_elements(other.max_num_elements), 74 tensors_(other.tensors_) { 75 tensors_->Ref(); 76 } 77 TensorList(TensorList && rhs)78 TensorList(TensorList&& rhs) 79 : element_shape(std::move(rhs.element_shape)), 80 element_dtype(rhs.element_dtype), 81 max_num_elements(rhs.max_num_elements), 82 tensors_(rhs.tensors_) { 83 rhs.tensors_ = nullptr; 84 } 85 86 TensorList& operator=(const TensorList& rhs) { 87 if (this == &rhs) return *this; 88 element_shape = rhs.element_shape; 89 element_dtype = rhs.element_dtype; 90 max_num_elements = rhs.max_num_elements; 91 tensors_->Unref(); 92 tensors_ = rhs.tensors_; 93 tensors_->Ref(); 94 return *this; 95 } 96 97 TensorList& operator=(TensorList&& rhs) { 98 if (this == &rhs) return *this; 99 element_shape = rhs.element_shape; 100 element_dtype = rhs.element_dtype; 101 max_num_elements = rhs.max_num_elements; 102 std::swap(tensors_, rhs.tensors_); 103 return *this; 104 } 105 106 static const char kTypeName[]; 107 TypeName()108 string TypeName() const { return kTypeName; } 109 110 void Encode(VariantTensorData* data) const; 111 112 bool Decode(const VariantTensorData& data); 113 114 // TODO(apassos) fill this out DebugString()115 string DebugString() const { return "TensorList"; } 116 117 PartialTensorShape element_shape; 118 119 DataType element_dtype; 120 121 // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size 122 // of `tensors` is unbounded. 123 int max_num_elements = -1; 124 125 // Access to the underlying tensor container. tensors()126 std::vector<Tensor>& tensors() { return tensors_->values_; } tensors()127 const std::vector<Tensor>& tensors() const { return tensors_->values_; } 128 129 // Get a new TensorList containing a copy of the underlying tensor container. Copy()130 TensorList Copy() const { 131 TensorList out; 132 out.element_shape = element_shape; 133 out.element_dtype = element_dtype; 134 out.max_num_elements = max_num_elements; 135 // This performs a copy of the std::vector. 136 out.tensors_->values_ = tensors_->values_; 137 return out; 138 } 139 140 // Is this TensorList the only one with a reference to the underlying 141 // container? RefCountIsOne()142 bool RefCountIsOne() const { return tensors_->RefCountIsOne(); } 143 144 private: 145 class Tensors : public core::RefCounted { 146 public: 147 std::vector<Tensor> values_; 148 }; 149 Tensors* tensors_; 150 }; 151 152 #if defined(PLATFORM_GOOGLE) 153 // TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices. 154 // For 32-bit devices, it's acceptable not to inline. 155 static_assert(Variant::CanInlineType<TensorList>() || sizeof(void*) < 8, 156 "Must be able to inline TensorList into a Variant"); 157 #endif 158 } // namespace tensorflow 159 160 #endif // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ 161