• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
18 
19 #include <functional>
20 #include <iostream>
21 #include <memory>
22 #include <type_traits>
23 #include <unordered_map>
24 #include <utility>
25 
26 #include "absl/memory/memory.h"
27 #include "tensorflow/core/framework/type_index.h"
28 #include "tensorflow/core/framework/variant_tensor_data.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/strcat.h"
31 
32 namespace tensorflow {
33 
34 template <typename T>
35 std::string TypeNameVariant(const T& value);
36 
37 template <typename T>
38 std::string DebugStringVariant(const T& value);
39 
40 // Allows for specializations of Variant Decoding.  `data` may be modified in
41 // the process of decoding to `value`.
42 template <typename T>
43 bool DecodeVariant(VariantTensorData* data, T* value);
44 
45 template <typename T>
46 bool DecodeVariant(std::string* buf, T* value);
47 
48 template <typename T>
49 void EncodeVariant(const T& value, VariantTensorData* data);
50 
51 template <typename T>
52 void EncodeVariant(const T& value, std::string* buf);
53 
54 // This is an implementation of a type-erased container that can store an
55 // object of any type. The implementation is very similar to std::any, but has
56 // restrictions on the types of objects that can be stored, and eschews some of
57 // the fancier constructors available for std::any. An object of
58 // tensorflow::Variant is intended to be used as the value that will be stored
59 // in a tensorflow::Tensor object when its type is DT_VARIANT.
60 //
61 // tensorflow::Variant can store an object of a class that satisfies the
62 // following constraints:
63 //
64 // * The class is CopyConstructible.
65 // * The class has a default constructor.
66 // * It's either a protocol buffer, a tensorflow::Tensor, or defines the
67 // following functions:
68 //
69 //   string TypeName() const;
70 //   void Encode(VariantTensorData* data) const;
71 //   bool Decode(VariantTensorData data);
72 //
73 // Simple POD types can elide the Encode/Decode functions, they are provided by
74 // helper methods.
75 // Here are some typical usage patterns:
76 //
77 //   Variant x = 10;
78 //   EXPECT_EQ(*x.get<int>(), 10);
79 //
80 //   Tensor t(DT_FLOAT, TensorShape({}));
81 //   t.flat<float>()(0) = 42.0f;
82 //   Variant x = t;
83 //   EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
84 //
85 // Accessing the stored object:
86 //
87 // The get<T> function is the main mechanism to access the object
88 // stored in the container. It is type-safe, that is, calling
89 // get<T> when the stored object's type is not T, returns a
90 // nullptr. A raw pointer to the stored object can be obtained by calling
91 // get<void>().
92 //
93 // Serializing/deserializing Variant object:
94 //
95 // The Variant class delegates serializing and deserializing operations to the
96 // contained object. Helper functions to do these operations are provided for
97 // POD data types, tensorflow::Tensor, and protocol buffer objects. However,
98 // other classes have to provide Encode/Decode functions to handle
99 // serialization.
100 //
101 // Objects stored in a Variant object often contain references to other
102 // tensorflow::Tensors of primitive types (Eg., a list of tensorflow::Tensors).
103 // To efficiently support those use cases, a structure is imposed on the
104 // serialization format. Namely, classes should serialize their contents into a
105 // VariantTensorData object:
106 //
107 //   struct VariantTensorData {
108 //     string type_name;
109 //     string metadata;
110 //     std::vector<Tensor> tensors;
111 //   };
112 //
113 // Objects with references to other Tensors can simply store those tensors in
114 // the `tensors` field, and serialize other metadata content in to the
115 // `metadata` field.
116 //
117 // Serialization example:
118 //
119 //   Foo f = Foo {...};
120 //   Variant x = f;
121 //   string serialized_f;
122 //   x.Encode(&serialized_f);
123 //
124 //   Variant y = Foo(); // default constructed Foo.
125 //   y.Decode(std::move(serialized_f));
126 //   EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
127 //
128 //
129 // A Variant storing serialized Variant data (a value of type
130 // VariantTensorDataProto) has different behavior from a standard Variant.
131 // Namely, its TypeName matches the TypeName of the original Variant;
132 // and its non-const get method performs lazy deserialization.
133 //
134 // Decode and copy example:
135 //
136 //   Foo f = Foo {...};
137 //   Variant x = f;
138 //
139 //   VariantTensorData serialized_data_f;
140 //   VariantTensorDataProto serialized_proto_f;
141 //   x.Encode(&serialized_data_f);
142 //   serialized_data_f.ToProto(&serialized_proto_f);
143 //
144 //   Variant y_type_unknown = serialized_proto_f;  // Store serialized Variant.
145 //
146 //   EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName());  // Looks like Foo.
147 //   EXPECT_EQ(TypeIndex::Make<VariantTensorDataProto>(),
148 //             y_type_unknown.TypeId());
149 //
150 class Variant {
151  public:
152   // Constructs a Variant holding no value (aka `is_empty()`).
153   //
154   // This is done by pointing at nullptr via the heap value.
Variant()155   Variant() noexcept : heap_value_(/*pointer=*/nullptr), is_inline_(false) {}
156 
157   ~Variant();
158 
159   Variant(const Variant& other);
160   Variant(Variant&& other) noexcept;
161 
162   // Make sure that the type is CopyConstructible and not a
163   // tensorflow::Variant object itself. We want the copy constructor to be
164   // chosen for the tensorflow::Variant case.
165   template <typename T, typename VT = typename std::decay<T>::type,
166             typename std::enable_if<!std::is_same<Variant, VT>::value &&
167                                         std::is_move_constructible<VT>::value,
168                                     void>::type* = nullptr>
169   Variant(T&& value);
170 
171   template <typename T, typename VT = typename std::decay<T>::type,
172             typename std::enable_if<!std::is_same<Variant, VT>::value &&
173                                         std::is_copy_constructible<VT>::value,
174                                     void>::type* = nullptr>
175   Variant(const T& value);
176 
177   template <typename T, typename VT = typename std::decay<T>::type,
178             typename std::enable_if<!std::is_same<Variant, VT>::value &&
179                                         std::is_copy_constructible<VT>::value,
180                                     void>::type* = nullptr>
181   Variant& operator=(const T& value);
182 
183   template <typename T, typename VT = typename std::decay<T>::type,
184             typename std::enable_if<!std::is_same<Variant, VT>::value &&
185                                         std::is_move_constructible<VT>::value,
186                                     void>::type* = nullptr>
187   Variant& operator=(T&& value);
188 
189   Variant& operator=(const Variant& rhs) {
190     if (&rhs == this) return *this;
191     Variant(rhs).swap(*this);
192     return *this;
193   }
194 
195   Variant& operator=(Variant&& rhs) noexcept {
196     if (&rhs == this) return *this;
197     Variant(std::move(rhs)).swap(*this);
198     return *this;
199   }
200 
201   // Constructs a value of type T with the given args in-place in this Variant.
202   // Returns a reference to the newly constructed value.
203   // The signature is based on std::variant<Types...>::emplace() in C++17.
204   template <typename T, class... Args>
emplace(Args &&...args)205   T& emplace(Args&&... args) {
206     ResetMemory();
207     is_inline_ = CanInlineType<T>();
208     if (is_inline_) {
209       new (&inline_value_)
210           InlineValue(InlineValue::Tag<T>{}, std::forward<Args>(args)...);
211       return static_cast<Variant::Value<T>*>(inline_value_.AsValueInterface())
212           ->value;
213     } else {
214       new (&heap_value_) HeapValue(
215           absl::make_unique<Value<T>>(InPlace(), std::forward<Args>(args)...));
216       return static_cast<Variant::Value<T>*>(heap_value_.get())->value;
217     }
218   }
219 
is_empty()220   bool is_empty() const { return GetValue() == nullptr; }
221 
222   void clear() noexcept;
223 
224   void swap(Variant& other) noexcept;
225 
226   // Note, unlike TypeName(), TypeId() does not return the TypeIndex
227   // of the original type when a TensorValueDataProto is stored as the
228   // value.  In this case, it returns the TypeIndex of TensorValueDataProto.
TypeId()229   TypeIndex TypeId() const {
230     const TypeIndex VoidTypeIndex = TypeIndex::Make<void>();
231     if (is_empty()) {
232       return VoidTypeIndex;
233     }
234     return GetValue()->TypeId();
235   }
236 
DebugString()237   std::string DebugString() const {
238     return strings::StrCat(
239         "Variant<type: ", TypeName(),
240         " value: ", is_empty() ? "[empty]" : GetValue()->DebugString(), ">");
241   }
242 
243   // Returns a pointer to the stored value if it is type T, or nullptr
244   // otherwise.
245   template <typename T>
get()246   T* get() {
247     const TypeIndex TTypeIndex = TypeIndex::Make<T>();
248     if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
249     return std::addressof(static_cast<Variant::Value<T>*>(GetValue())->value);
250   }
251 
252   // Returns a pointer to the stored value if it is type T, or nullptr
253   // otherwise.
254   template <typename T>
get()255   const T* get() const {
256     const TypeIndex TTypeIndex = TypeIndex::Make<T>();
257     if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
258     return std::addressof(
259         static_cast<const Variant::Value<T>*>(GetValue())->value);
260   }
261 
262   // Returns TypeNameVariant(value).
263   //
264   // In the special case that a serialized Variant is stored (value
265   // is a VariantTensorDataProto), returns value.TypeName(), the
266   // TypeName field stored in the VariantTensorDataProto buffer.
TypeName()267   std::string TypeName() const {
268     if (is_empty()) {
269       return "";
270     }
271     return GetValue()->TypeName();
272   }
273 
274   // Serialize the contents of the stored object into `data`.
Encode(VariantTensorData * data)275   void Encode(VariantTensorData* data) const {
276     if (!is_empty()) {
277       GetValue()->Encode(data);
278     }
279   }
280 
281   // Deserialize `data` and update the stored object.
282   bool Decode(VariantTensorData data);
283 
284   // Helper methods to directly serialize/deserialize from strings.
Encode(std::string * buf)285   void Encode(std::string* buf) const {
286     if (!is_empty()) {
287       GetValue()->Encode(buf);
288     }
289   }
Decode(std::string buf)290   bool Decode(std::string buf) {
291     if (!is_empty()) {
292       return GetValue()->Decode(std::move(buf));
293     }
294     return true;
295   }
296 
297   template <typename VT>
CanInlineType()298   static constexpr bool CanInlineType() {
299     return ((sizeof(Value<VT>) <= InlineValue::kMaxValueSize) &&
300             (alignof(Value<VT>) <= kMaxInlineValueAlignSize));
301   }
302 
303  private:
304   struct in_place_t {};
InPlace()305   static constexpr in_place_t InPlace() { return in_place_t{}; }
306 
307   struct ValueInterface {
308     virtual ~ValueInterface() = default;
309     virtual TypeIndex TypeId() const = 0;
310     virtual void* RawPtr() = 0;
311     virtual const void* RawPtr() const = 0;
312     virtual std::unique_ptr<ValueInterface> Clone() const = 0;
313     virtual void CloneInto(ValueInterface* memory) const = 0;
314     virtual void MoveAssign(ValueInterface* memory) = 0;
315     virtual void MoveInto(ValueInterface* memory) = 0;
316     virtual std::string TypeName() const = 0;
317     virtual std::string DebugString() const = 0;
318     virtual void Encode(VariantTensorData* data) const = 0;
319     virtual bool Decode(VariantTensorData data) = 0;
320     virtual void Encode(std::string* buf) const = 0;
321     virtual bool Decode(std::string data) = 0;
322   };
323 
324   template <typename T>
325   struct Value final : ValueInterface {
326     template <class... Args>
Valuefinal327     explicit Value(in_place_t /*tag*/, Args&&... args)
328         : value(std::forward<Args>(args)...) {}
329 
330     // NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily
331     // build `alignof(Variant<void*>)`.
332     ~Value() final = default;
333 
TypeIdfinal334     TypeIndex TypeId() const final {
335       const TypeIndex value_type_index =
336           TypeIndex::Make<typename std::decay<T>::type>();
337       return value_type_index;
338     }
339 
RawPtrfinal340     void* RawPtr() final { return &value; }
341 
RawPtrfinal342     const void* RawPtr() const final { return &value; }
343 
Clonefinal344     std::unique_ptr<ValueInterface> Clone() const final {
345       return absl::make_unique<Value>(InPlace(), value);
346     }
347 
MoveAssignfinal348     void MoveAssign(ValueInterface* memory) final {
349       CHECK(TypeId() == memory->TypeId())
350           << TypeId().name() << " vs. " << memory->TypeId().name();
351       static_cast<Value*>(memory)->value = std::move(value);
352     }
353 
CloneIntofinal354     void CloneInto(ValueInterface* memory) const final {
355       new (memory) Value(InPlace(), value);
356     }
357 
MoveIntofinal358     void MoveInto(ValueInterface* memory) final {
359       new (memory) Value(InPlace(), std::move(value));
360     }
361 
TypeNamefinal362     std::string TypeName() const final { return TypeNameVariant(value); }
363 
DebugStringfinal364     std::string DebugString() const final { return DebugStringVariant(value); }
365 
Encodefinal366     void Encode(VariantTensorData* data) const final {
367       EncodeVariant(value, data);
368     }
369 
Decodefinal370     bool Decode(VariantTensorData data) final {
371       return DecodeVariant(&data, &value);
372     }
373 
Encodefinal374     void Encode(std::string* buf) const final { EncodeVariant(value, buf); }
375 
Decodefinal376     bool Decode(std::string buf) final { return DecodeVariant(&buf, &value); }
377 
378     T value;
379   };
380   static constexpr int kMaxInlineValueAlignSize = alignof(Value<void*>);
381 
382   using HeapValue = std::unique_ptr<ValueInterface>;
383 
384   struct InlineValue {
385     // We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit
386     // into the aligned space of a TensorBuffer.
387     static constexpr int kMaxValueSize = (64 - /*some extra padding=*/8);
388 
389     typedef char ValueDataArray[kMaxValueSize];
390     alignas(kMaxInlineValueAlignSize) ValueDataArray value_data;
391 
392     // Tag is used for deducing the right type when constructing a Value in
393     // place.
394     template <typename VT>
395     struct Tag {};
396 
397     template <typename VT, class... Args>
InlineValueInlineValue398     explicit InlineValue(Tag<VT> /*tag*/, Args&&... args) noexcept {
399       Value<VT>* inline_value_data = reinterpret_cast<Value<VT>*>(value_data);
400       new (inline_value_data) Value<VT>(InPlace(), std::forward<Args>(args)...);
401     }
402 
InlineValueInlineValue403     InlineValue(const InlineValue& other) noexcept {
404       other.AsValueInterface()->CloneInto(AsValueInterface());
405     }
406 
InlineValueInlineValue407     InlineValue(InlineValue&& other) noexcept {
408       other.AsValueInterface()->MoveInto(AsValueInterface());
409     }
410 
ResetMemoryInlineValue411     void ResetMemory() { AsValueInterface()->~ValueInterface(); }
412 
413     InlineValue& operator=(const InlineValue& other) {
414       if (&other == this) return *this;
415       ResetMemory();
416       other.AsValueInterface()->CloneInto(AsValueInterface());
417       return *this;
418     }
419 
420     InlineValue& operator=(InlineValue&& other) {
421       if (&other == this) return *this;
422       if (AsValueInterface()->TypeId() == other.AsValueInterface()->TypeId()) {
423         other.AsValueInterface()->MoveAssign(AsValueInterface());
424       } else {
425         ResetMemory();
426         other.AsValueInterface()->MoveInto(AsValueInterface());
427       }
428       return *this;
429     }
430 
AsValueInterfaceInlineValue431     ValueInterface* AsValueInterface() {
432       return reinterpret_cast<ValueInterface*>(value_data);
433     }
434 
AsValueInterfaceInlineValue435     const ValueInterface* AsValueInterface() const {
436       return reinterpret_cast<const ValueInterface*>(value_data);
437     }
438 
~InlineValueInlineValue439     ~InlineValue() { ResetMemory(); }
440   };
441 
442   union {
443     HeapValue heap_value_;
444     InlineValue inline_value_;
445   };
446   // is_inline_ provides discrimination between which member of the prior union
447   // is currently within it's lifetime. To switch from one member to the other,
448   // the destructor must be called on the currently alive member before calling
449   // the constructor on the other member. In effect, a member is expected to be
450   // live at any given time and that member is tracked via this boolean.
451   bool is_inline_;
452 
IsInlineValue()453   bool IsInlineValue() const { return is_inline_; }
454 
455   // ResetMemory causes the destructor of the currently active member of the
456   // union to be run. This must be follwed with a placement new call on the
457   // member whose lifetime is to start. Additionally, is_inline_ needs to be set
458   // accordingly. ResetAndSetInline and ResetAndSetHeap are simple helper
459   // functions for performing the actions that are required to follow.
ResetMemory()460   void ResetMemory() {
461     if (IsInlineValue()) {
462       inline_value_.~InlineValue();
463     } else {
464       heap_value_.~HeapValue();
465     }
466   }
467 
468   // ResetAndSetInline clears the current state and then constructs a new value
469   // inline with the provided arguments.
470   template <typename... Args>
ResetAndSetInline(Args &&...args)471   void ResetAndSetInline(Args&&... args) noexcept {
472     ResetMemory();
473     new (&inline_value_) InlineValue(std::forward<Args>(args)...);
474     is_inline_ = true;
475   }
476 
477   // ResetAndSetHeap clears the current state then constructs a new value on the
478   // heap with the provided arguments.
479   template <typename... Args>
ResetAndSetHeap(Args &&...args)480   void ResetAndSetHeap(Args&&... args) noexcept {
481     ResetMemory();
482     new (&heap_value_) HeapValue(std::forward<Args>(args)...);
483     is_inline_ = false;
484   }
485 
GetValue()486   ValueInterface* GetValue() {
487     if (IsInlineValue()) {
488       return inline_value_.AsValueInterface();
489     } else {
490       return heap_value_.get();
491     }
492   }
493 
GetValue()494   const ValueInterface* GetValue() const {
495     if (IsInlineValue()) {
496       return inline_value_.AsValueInterface();
497     } else {
498       return heap_value_.get();
499     }
500   }
501 
502   // PRECONDITION: Called on construction or ResetMemory() has been called
503   // before this method.
504   template <typename VT, typename T>
InsertValue(T && value)505   void InsertValue(T&& value) {
506     if (IsInlineValue()) {
507       new (&inline_value_)
508           InlineValue(InlineValue::Tag<VT>{}, std::forward<T>(value));
509     } else {
510       new (&heap_value_) HeapValue(
511           absl::make_unique<Value<VT>>(InPlace(), std::forward<T>(value)));
512     }
513   }
514 };
515 
516 // Make sure that a Variant object can reside in a 64-byte aligned Tensor
517 // buffer.
518 static_assert(sizeof(Variant) <= 64,
519               "Expected internal representation to be 64 bytes.");
520 
// Copy constructor: produces an independent deep copy of `other`, using the
// same storage kind (inline buffer or heap) that `other` uses. Copying an
// empty Variant yields an empty Variant.
inline Variant::Variant(const Variant& other)
    : is_inline_(other.IsInlineValue()) {
  if (is_inline_) {
    // InlineValue's copy constructor clones the stored value into our buffer.
    new (&inline_value_) InlineValue(other.inline_value_);
    return;
  }
  const ValueInterface* const source = other.heap_value_.get();
  new (&heap_value_) HeapValue(source != nullptr ? source->Clone() : nullptr);
}
530 
Variant(Variant && other)531 inline Variant::Variant(Variant&& other) noexcept
532     : is_inline_(other.IsInlineValue()) {
533   if (IsInlineValue()) {
534     new (&inline_value_) InlineValue(std::move(other.inline_value_));
535   } else {
536     new (&heap_value_) HeapValue(std::move(other.heap_value_));
537   }
538 }
539 
540 template <typename T, typename VT,
541           typename std::enable_if<!std::is_same<Variant, VT>::value &&
542                                       std::is_move_constructible<VT>::value,
543                                   void>::type*>
Variant(T && value)544 inline Variant::Variant(T&& value) : is_inline_(CanInlineType<VT>()) {
545   InsertValue<VT>(std::forward<T>(value));
546 }
547 
548 template <typename T, typename VT,
549           typename std::enable_if<!std::is_same<Variant, VT>::value &&
550                                       std::is_copy_constructible<VT>::value,
551                                   void>::type*>
Variant(const T & value)552 inline Variant::Variant(const T& value) : is_inline_(CanInlineType<VT>()) {
553   InsertValue<VT>(value);
554 }
555 
556 template <typename T, typename VT,
557           typename std::enable_if<!std::is_same<Variant, VT>::value &&
558                                       std::is_move_constructible<VT>::value,
559                                   void>::type*>
560 inline Variant& Variant::operator=(T&& value) {
561   ResetMemory();
562   is_inline_ = CanInlineType<VT>();
563   InsertValue<VT>(std::forward<T>(value));
564   return *this;
565 }
566 
567 template <typename T, typename VT,
568           typename std::enable_if<!std::is_same<Variant, VT>::value &&
569                                       std::is_copy_constructible<VT>::value,
570                                   void>::type*>
571 inline Variant& Variant::operator=(const T& value) {
572   ResetMemory();
573   is_inline_ = CanInlineType<VT>();
574   InsertValue<VT>(value);
575   return *this;
576 }
577 
clear()578 inline void Variant::clear() noexcept {
579   // We set the internal unique_ptr to nullptr so that we preserve the
580   // invariant that one of the two states must be set at all times. nullptr
581   // indicates that the variant is empty.
582   ResetAndSetHeap(/*pointer=*/nullptr);
583 }
584 
swap(Variant & other)585 inline void Variant::swap(Variant& other) noexcept {
586   if (is_empty()) {
587     if (other.IsInlineValue()) {
588       ResetAndSetInline(std::move(other.inline_value_));
589     } else {
590       ResetAndSetHeap(std::move(other.heap_value_));
591     }
592     other.clear();
593   } else if (other.is_empty()) {
594     if (IsInlineValue()) {
595       other.ResetAndSetInline(std::move(inline_value_));
596     } else {
597       other.ResetAndSetHeap(std::move(heap_value_));
598     }
599     clear();
600   } else {  // Both Variants have values.
601     if (other.IsInlineValue() && IsInlineValue()) {
602       std::swap(inline_value_, other.inline_value_);
603     } else if (!other.IsInlineValue() && !IsInlineValue()) {
604       std::swap(heap_value_, other.heap_value_);
605     } else if (other.IsInlineValue() && !IsInlineValue()) {
606       HeapValue v = std::move(heap_value_);
607       ResetAndSetInline(std::move(other.inline_value_));
608       other.ResetAndSetHeap(std::move(v));
609     } else {  // !other.IsInlineValue() && IsInlineValue()
610       HeapValue v = std::move(other.heap_value_);
611       other.ResetAndSetInline(std::move(inline_value_));
612       ResetAndSetHeap(std::move(v));
613     }
614   }
615 }
616 
617 template <>
618 void* Variant::get();
619 
620 template <>
621 const void* Variant::get() const;
622 
623 }  // end namespace tensorflow
624 
625 #endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
626