• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
16 #define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
17 
18 #include <utility>
19 
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/variant.h"
22 #include "tensorflow/core/framework/variant_tensor_data.h"
23 #include "tensorflow/core/lib/core/refcount.h"
24 
25 namespace tensorflow {
26 
27 // Variant compatible type for a list of tensors. This is mutable but instances
28 // should never be mutated after stored in a variant tensor.
29 //
30 // **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
31 // which are accessible via TensorList::tensors().  Because it is refcounted,
32 // straight copies of the form:
33 //
34 //    TensorList b = a;
35 //    b.tensors().push_back(t);  // WARNING: This modifies a.tensors().
36 //
37 // Do not create a true copy of the underlying container - but instead increment
38 // a reference count.  Modifying b.tensors() modifies a.tensors().  In this way,
39 // TensorList should be considered similar to the tf::Tensor object.
40 //
41 // In order to get a copy of the underlying list, use the Copy method:
42 //
43 //    TensorList b = a.Copy();
44 //    b.tensors().push_back(t);  // This does not modify a.tensors().
45 //
46 // Note that this is not a deep copy: the memory locations of the underlying
47 // tensors will still point to the same locations of the corresponding tensors
48 // in the original.  To truly perform a deep copy, Device and Type-specific
49 // code needs to be applied to the underlying tensors as usual.
50 //
51 // The most important implication of RefCounted TLs is that OpKernels
52 // wishing to reuse TensorList inputs as outputs via context->forward_input()
53 // need to perform an additional check on the refcount of the TensorList,
54 // to ensure aliasing can be performed safely.  For example:
55 //
56 //     bool can_alias = false;
57 //     auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
58 //     if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
59 //       auto* tl = fw->scalar<Variant>()().get<TensorList>();
60 //       if (tl && tl->RefCountIsOne()) {
61 //         can_alias = true;
62 //       }
63 //     }
64 //
65 class TensorList {
66  public:
TensorList()67   TensorList() : tensors_(new Tensors) {}
68   ~TensorList();
69 
TensorList(const TensorList & other)70   TensorList(const TensorList& other)
71       : element_shape(other.element_shape),
72         element_dtype(other.element_dtype),
73         max_num_elements(other.max_num_elements),
74         tensors_(other.tensors_) {
75     tensors_->Ref();
76   }
77 
TensorList(TensorList && rhs)78   TensorList(TensorList&& rhs)
79       : element_shape(std::move(rhs.element_shape)),
80         element_dtype(rhs.element_dtype),
81         max_num_elements(rhs.max_num_elements),
82         tensors_(rhs.tensors_) {
83     rhs.tensors_ = nullptr;
84   }
85 
86   TensorList& operator=(const TensorList& rhs) {
87     if (this == &rhs) return *this;
88     element_shape = rhs.element_shape;
89     element_dtype = rhs.element_dtype;
90     max_num_elements = rhs.max_num_elements;
91     tensors_->Unref();
92     tensors_ = rhs.tensors_;
93     tensors_->Ref();
94     return *this;
95   }
96 
97   TensorList& operator=(TensorList&& rhs) {
98     if (this == &rhs) return *this;
99     element_shape = rhs.element_shape;
100     element_dtype = rhs.element_dtype;
101     max_num_elements = rhs.max_num_elements;
102     std::swap(tensors_, rhs.tensors_);
103     return *this;
104   }
105 
106   static const char kTypeName[];
107 
TypeName()108   string TypeName() const { return kTypeName; }
109 
110   void Encode(VariantTensorData* data) const;
111 
112   bool Decode(const VariantTensorData& data);
113 
114   // TODO(apassos) fill this out
DebugString()115   string DebugString() const { return "TensorList"; }
116 
117   PartialTensorShape element_shape;
118 
119   DataType element_dtype;
120 
121   // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
122   // of `tensors` is unbounded.
123   int max_num_elements = -1;
124 
125   // Access to the underlying tensor container.
tensors()126   std::vector<Tensor>& tensors() { return tensors_->values_; }
tensors()127   const std::vector<Tensor>& tensors() const { return tensors_->values_; }
128 
129   // Get a new TensorList containing a copy of the underlying tensor container.
Copy()130   TensorList Copy() const {
131     TensorList out;
132     out.element_shape = element_shape;
133     out.element_dtype = element_dtype;
134     out.max_num_elements = max_num_elements;
135     // This performs a copy of the std::vector.
136     out.tensors_->values_ = tensors_->values_;
137     return out;
138   }
139 
140   // Is this TensorList the only one with a reference to the underlying
141   // container?
RefCountIsOne()142   bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
143 
144  private:
145   class Tensors : public core::RefCounted {
146    public:
147     std::vector<Tensor> values_;
148   };
149   Tensors* tensors_;
150 };
151 
#ifdef PLATFORM_GOOGLE
// Compile-time guard: on 64-bit (or wider) pointer targets, TensorList must fit
// in Variant's inline storage. Narrower-pointer targets (e.g. 32-bit devices)
// are exempt and may store it out of line.
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
static_assert(sizeof(void*) < 8 || Variant::CanInlineType<TensorList>(),
              "Must be able to inline TensorList into a Variant");
#endif
158 }  // namespace tensorflow
159 
160 #endif  // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
161