1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
18
19 #include <functional>
20 #include <iostream>
21 #include <memory>
22 #include <type_traits>
23 #include <unordered_map>
24 #include <utility>
25
26 #include "absl/memory/memory.h"
27 #include "tensorflow/core/framework/type_index.h"
28 #include "tensorflow/core/framework/variant_tensor_data.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/strcat.h"
31
32 namespace tensorflow {
33
34 template <typename T>
35 std::string TypeNameVariant(const T& value);
36
37 template <typename T>
38 std::string DebugStringVariant(const T& value);
39
40 // Allows for specializations of Variant Decoding. `data` may be modified in
41 // the process of decoding to `value`.
42 template <typename T>
43 bool DecodeVariant(VariantTensorData* data, T* value);
44
45 template <typename T>
46 bool DecodeVariant(std::string* buf, T* value);
47
48 template <typename T>
49 void EncodeVariant(const T& value, VariantTensorData* data);
50
51 template <typename T>
52 void EncodeVariant(const T& value, std::string* buf);
53
54 // This is an implementation of a type-erased container that can store an
55 // object of any type. The implementation is very similar to std::any, but has
56 // restrictions on the types of objects that can be stored, and eschews some of
57 // the fancier constructors available for std::any. An object of
58 // tensorflow::Variant is intended to be used as the value that will be stored
59 // in a tensorflow::Tensor object when its type is DT_VARIANT.
60 //
61 // tensorflow::Variant can store an object of a class that satisfies the
62 // following constraints:
63 //
64 // * The class is CopyConstructible.
65 // * The class has a default constructor.
66 // * It's either a protocol buffer, a tensorflow::Tensor, or defines the
67 // following functions:
68 //
69 // string TypeName() const;
70 // void Encode(VariantTensorData* data) const;
71 // bool Decode(VariantTensorData data);
72 //
// Simple POD types can omit the Encode/Decode functions; they are provided by
// helper methods.
75 // Here are some typical usage patterns:
76 //
77 // Variant x = 10;
78 // EXPECT_EQ(*x.get<int>(), 10);
79 //
80 // Tensor t(DT_FLOAT, TensorShape({}));
81 // t.flat<float>()(0) = 42.0f;
82 // Variant x = t;
83 // EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
84 //
85 // Accessing the stored object:
86 //
87 // The get<T> function is the main mechanism to access the object
// stored in the container. It is type-safe; that is, calling
// get<T> when the stored object's type is not T returns
// nullptr. A raw pointer to the stored object can be obtained by calling
91 // get<void>().
92 //
93 // Serializing/deserializing Variant object:
94 //
95 // The Variant class delegates serializing and deserializing operations to the
96 // contained object. Helper functions to do these operations are provided for
97 // POD data types, tensorflow::Tensor, and protocol buffer objects. However,
98 // other classes have to provide Encode/Decode functions to handle
99 // serialization.
100 //
101 // Objects stored in a Variant object often contain references to other
// tensorflow::Tensors of primitive types (e.g., a list of tensorflow::Tensors).
103 // To efficiently support those use cases, a structure is imposed on the
104 // serialization format. Namely, classes should serialize their contents into a
105 // VariantTensorData object:
106 //
107 // struct VariantTensorData {
108 // string type_name;
109 // string metadata;
110 // std::vector<Tensor> tensors;
111 // };
112 //
113 // Objects with references to other Tensors can simply store those tensors in
// the `tensors` field, and serialize other metadata content into the
115 // `metadata` field.
116 //
117 // Serialization example:
118 //
119 // Foo f = Foo {...};
120 // Variant x = f;
121 // string serialized_f;
122 // x.Encode(&serialized_f);
123 //
124 // Variant y = Foo(); // default constructed Foo.
125 // y.Decode(std::move(serialized_f));
126 // EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
127 //
128 //
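// A Variant storing serialized Variant data (a value of type
// VariantTensorDataProto) has different behavior from a standard Variant.
// Namely, its TypeName() matches the TypeName() of the original Variant,
// while its TypeId() is still that of VariantTensorDataProto. The get<T>()
// template in this header only compares TypeIds and does not deserialize, so
// get<Foo>() on such a Variant returns nullptr even though TypeName() looks
// like Foo.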
133 //
134 // Decode and copy example:
135 //
136 // Foo f = Foo {...};
137 // Variant x = f;
138 //
139 // VariantTensorData serialized_data_f;
140 // VariantTensorDataProto serialized_proto_f;
141 // x.Encode(&serialized_data_f);
142 // serialized_data_f.ToProto(&serialized_proto_f);
143 //
144 // Variant y_type_unknown = serialized_proto_f; // Store serialized Variant.
145 //
146 // EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
147 // EXPECT_EQ(TypeIndex::Make<VariantTensorDataProto>(),
148 // y_type_unknown.TypeId());
149 //
150 class Variant {
151 public:
152 // Constructs a Variant holding no value (aka `is_empty()`).
153 //
154 // This is done by pointing at nullptr via the heap value.
Variant()155 Variant() noexcept : heap_value_(/*pointer=*/nullptr), is_inline_(false) {}
156
157 ~Variant();
158
159 Variant(const Variant& other);
160 Variant(Variant&& other) noexcept;
161
162 // Make sure that the type is CopyConstructible and not a
163 // tensorflow::Variant object itself. We want the copy constructor to be
164 // chosen for the tensorflow::Variant case.
165 template <typename T, typename VT = typename std::decay<T>::type,
166 typename std::enable_if<!std::is_same<Variant, VT>::value &&
167 std::is_move_constructible<VT>::value,
168 void>::type* = nullptr>
169 Variant(T&& value);
170
171 template <typename T, typename VT = typename std::decay<T>::type,
172 typename std::enable_if<!std::is_same<Variant, VT>::value &&
173 std::is_copy_constructible<VT>::value,
174 void>::type* = nullptr>
175 Variant(const T& value);
176
177 template <typename T, typename VT = typename std::decay<T>::type,
178 typename std::enable_if<!std::is_same<Variant, VT>::value &&
179 std::is_copy_constructible<VT>::value,
180 void>::type* = nullptr>
181 Variant& operator=(const T& value);
182
183 template <typename T, typename VT = typename std::decay<T>::type,
184 typename std::enable_if<!std::is_same<Variant, VT>::value &&
185 std::is_move_constructible<VT>::value,
186 void>::type* = nullptr>
187 Variant& operator=(T&& value);
188
189 Variant& operator=(const Variant& rhs) {
190 if (&rhs == this) return *this;
191 Variant(rhs).swap(*this);
192 return *this;
193 }
194
195 Variant& operator=(Variant&& rhs) noexcept {
196 if (&rhs == this) return *this;
197 Variant(std::move(rhs)).swap(*this);
198 return *this;
199 }
200
201 // Constructs a value of type T with the given args in-place in this Variant.
202 // Returns a reference to the newly constructed value.
203 // The signature is based on std::variant<Types...>::emplace() in C++17.
204 template <typename T, class... Args>
emplace(Args &&...args)205 T& emplace(Args&&... args) {
206 ResetMemory();
207 is_inline_ = CanInlineType<T>();
208 if (is_inline_) {
209 new (&inline_value_)
210 InlineValue(InlineValue::Tag<T>{}, std::forward<Args>(args)...);
211 return static_cast<Variant::Value<T>*>(inline_value_.AsValueInterface())
212 ->value;
213 } else {
214 new (&heap_value_) HeapValue(
215 absl::make_unique<Value<T>>(InPlace(), std::forward<Args>(args)...));
216 return static_cast<Variant::Value<T>*>(heap_value_.get())->value;
217 }
218 }
219
is_empty()220 bool is_empty() const { return GetValue() == nullptr; }
221
222 void clear() noexcept;
223
224 void swap(Variant& other) noexcept;
225
226 // Note, unlike TypeName(), TypeId() does not return the TypeIndex
227 // of the original type when a TensorValueDataProto is stored as the
228 // value. In this case, it returns the TypeIndex of TensorValueDataProto.
TypeId()229 TypeIndex TypeId() const {
230 const TypeIndex VoidTypeIndex = TypeIndex::Make<void>();
231 if (is_empty()) {
232 return VoidTypeIndex;
233 }
234 return GetValue()->TypeId();
235 }
236
DebugString()237 std::string DebugString() const {
238 return strings::StrCat(
239 "Variant<type: ", TypeName(),
240 " value: ", is_empty() ? "[empty]" : GetValue()->DebugString(), ">");
241 }
242
243 // Returns a pointer to the stored value if it is type T, or nullptr
244 // otherwise.
245 template <typename T>
get()246 T* get() {
247 const TypeIndex TTypeIndex = TypeIndex::Make<T>();
248 if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
249 return std::addressof(static_cast<Variant::Value<T>*>(GetValue())->value);
250 }
251
252 // Returns a pointer to the stored value if it is type T, or nullptr
253 // otherwise.
254 template <typename T>
get()255 const T* get() const {
256 const TypeIndex TTypeIndex = TypeIndex::Make<T>();
257 if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
258 return std::addressof(
259 static_cast<const Variant::Value<T>*>(GetValue())->value);
260 }
261
262 // Returns TypeNameVariant(value).
263 //
264 // In the special case that a serialized Variant is stored (value
265 // is a VariantTensorDataProto), returns value.TypeName(), the
266 // TypeName field stored in the VariantTensorDataProto buffer.
TypeName()267 std::string TypeName() const {
268 if (is_empty()) {
269 return "";
270 }
271 return GetValue()->TypeName();
272 }
273
274 // Serialize the contents of the stored object into `data`.
Encode(VariantTensorData * data)275 void Encode(VariantTensorData* data) const {
276 if (!is_empty()) {
277 GetValue()->Encode(data);
278 }
279 }
280
281 // Deserialize `data` and update the stored object.
282 bool Decode(VariantTensorData data);
283
284 // Helper methods to directly serialize/deserialize from strings.
Encode(std::string * buf)285 void Encode(std::string* buf) const {
286 if (!is_empty()) {
287 GetValue()->Encode(buf);
288 }
289 }
Decode(std::string buf)290 bool Decode(std::string buf) {
291 if (!is_empty()) {
292 return GetValue()->Decode(std::move(buf));
293 }
294 return true;
295 }
296
297 template <typename VT>
CanInlineType()298 static constexpr bool CanInlineType() {
299 return ((sizeof(Value<VT>) <= InlineValue::kMaxValueSize) &&
300 (alignof(Value<VT>) <= kMaxInlineValueAlignSize));
301 }
302
303 private:
304 struct in_place_t {};
InPlace()305 static constexpr in_place_t InPlace() { return in_place_t{}; }
306
307 struct ValueInterface {
308 virtual ~ValueInterface() = default;
309 virtual TypeIndex TypeId() const = 0;
310 virtual void* RawPtr() = 0;
311 virtual const void* RawPtr() const = 0;
312 virtual std::unique_ptr<ValueInterface> Clone() const = 0;
313 virtual void CloneInto(ValueInterface* memory) const = 0;
314 virtual void MoveAssign(ValueInterface* memory) = 0;
315 virtual void MoveInto(ValueInterface* memory) = 0;
316 virtual std::string TypeName() const = 0;
317 virtual std::string DebugString() const = 0;
318 virtual void Encode(VariantTensorData* data) const = 0;
319 virtual bool Decode(VariantTensorData data) = 0;
320 virtual void Encode(std::string* buf) const = 0;
321 virtual bool Decode(std::string data) = 0;
322 };
323
324 template <typename T>
325 struct Value final : ValueInterface {
326 template <class... Args>
Valuefinal327 explicit Value(in_place_t /*tag*/, Args&&... args)
328 : value(std::forward<Args>(args)...) {}
329
330 // NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily
331 // build `alignof(Variant<void*>)`.
332 ~Value() final = default;
333
TypeIdfinal334 TypeIndex TypeId() const final {
335 const TypeIndex value_type_index =
336 TypeIndex::Make<typename std::decay<T>::type>();
337 return value_type_index;
338 }
339
RawPtrfinal340 void* RawPtr() final { return &value; }
341
RawPtrfinal342 const void* RawPtr() const final { return &value; }
343
Clonefinal344 std::unique_ptr<ValueInterface> Clone() const final {
345 return absl::make_unique<Value>(InPlace(), value);
346 }
347
MoveAssignfinal348 void MoveAssign(ValueInterface* memory) final {
349 CHECK(TypeId() == memory->TypeId())
350 << TypeId().name() << " vs. " << memory->TypeId().name();
351 static_cast<Value*>(memory)->value = std::move(value);
352 }
353
CloneIntofinal354 void CloneInto(ValueInterface* memory) const final {
355 new (memory) Value(InPlace(), value);
356 }
357
MoveIntofinal358 void MoveInto(ValueInterface* memory) final {
359 new (memory) Value(InPlace(), std::move(value));
360 }
361
TypeNamefinal362 std::string TypeName() const final { return TypeNameVariant(value); }
363
DebugStringfinal364 std::string DebugString() const final { return DebugStringVariant(value); }
365
Encodefinal366 void Encode(VariantTensorData* data) const final {
367 EncodeVariant(value, data);
368 }
369
Decodefinal370 bool Decode(VariantTensorData data) final {
371 return DecodeVariant(&data, &value);
372 }
373
Encodefinal374 void Encode(std::string* buf) const final { EncodeVariant(value, buf); }
375
Decodefinal376 bool Decode(std::string buf) final { return DecodeVariant(&buf, &value); }
377
378 T value;
379 };
380 static constexpr int kMaxInlineValueAlignSize = alignof(Value<void*>);
381
382 using HeapValue = std::unique_ptr<ValueInterface>;
383
384 struct InlineValue {
385 // We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit
386 // into the aligned space of a TensorBuffer.
387 static constexpr int kMaxValueSize = (64 - /*some extra padding=*/8);
388
389 typedef char ValueDataArray[kMaxValueSize];
390 alignas(kMaxInlineValueAlignSize) ValueDataArray value_data;
391
392 // Tag is used for deducing the right type when constructing a Value in
393 // place.
394 template <typename VT>
395 struct Tag {};
396
397 template <typename VT, class... Args>
InlineValueInlineValue398 explicit InlineValue(Tag<VT> /*tag*/, Args&&... args) noexcept {
399 Value<VT>* inline_value_data = reinterpret_cast<Value<VT>*>(value_data);
400 new (inline_value_data) Value<VT>(InPlace(), std::forward<Args>(args)...);
401 }
402
InlineValueInlineValue403 InlineValue(const InlineValue& other) noexcept {
404 other.AsValueInterface()->CloneInto(AsValueInterface());
405 }
406
InlineValueInlineValue407 InlineValue(InlineValue&& other) noexcept {
408 other.AsValueInterface()->MoveInto(AsValueInterface());
409 }
410
ResetMemoryInlineValue411 void ResetMemory() { AsValueInterface()->~ValueInterface(); }
412
413 InlineValue& operator=(const InlineValue& other) {
414 if (&other == this) return *this;
415 ResetMemory();
416 other.AsValueInterface()->CloneInto(AsValueInterface());
417 return *this;
418 }
419
420 InlineValue& operator=(InlineValue&& other) {
421 if (&other == this) return *this;
422 if (AsValueInterface()->TypeId() == other.AsValueInterface()->TypeId()) {
423 other.AsValueInterface()->MoveAssign(AsValueInterface());
424 } else {
425 ResetMemory();
426 other.AsValueInterface()->MoveInto(AsValueInterface());
427 }
428 return *this;
429 }
430
AsValueInterfaceInlineValue431 ValueInterface* AsValueInterface() {
432 return reinterpret_cast<ValueInterface*>(value_data);
433 }
434
AsValueInterfaceInlineValue435 const ValueInterface* AsValueInterface() const {
436 return reinterpret_cast<const ValueInterface*>(value_data);
437 }
438
~InlineValueInlineValue439 ~InlineValue() { ResetMemory(); }
440 };
441
442 union {
443 HeapValue heap_value_;
444 InlineValue inline_value_;
445 };
446 // is_inline_ provides discrimination between which member of the prior union
447 // is currently within it's lifetime. To switch from one member to the other,
448 // the destructor must be called on the currently alive member before calling
449 // the constructor on the other member. In effect, a member is expected to be
450 // live at any given time and that member is tracked via this boolean.
451 bool is_inline_;
452
IsInlineValue()453 bool IsInlineValue() const { return is_inline_; }
454
455 // ResetMemory causes the destructor of the currently active member of the
456 // union to be run. This must be follwed with a placement new call on the
457 // member whose lifetime is to start. Additionally, is_inline_ needs to be set
458 // accordingly. ResetAndSetInline and ResetAndSetHeap are simple helper
459 // functions for performing the actions that are required to follow.
ResetMemory()460 void ResetMemory() {
461 if (IsInlineValue()) {
462 inline_value_.~InlineValue();
463 } else {
464 heap_value_.~HeapValue();
465 }
466 }
467
468 // ResetAndSetInline clears the current state and then constructs a new value
469 // inline with the provided arguments.
470 template <typename... Args>
ResetAndSetInline(Args &&...args)471 void ResetAndSetInline(Args&&... args) noexcept {
472 ResetMemory();
473 new (&inline_value_) InlineValue(std::forward<Args>(args)...);
474 is_inline_ = true;
475 }
476
477 // ResetAndSetHeap clears the current state then constructs a new value on the
478 // heap with the provided arguments.
479 template <typename... Args>
ResetAndSetHeap(Args &&...args)480 void ResetAndSetHeap(Args&&... args) noexcept {
481 ResetMemory();
482 new (&heap_value_) HeapValue(std::forward<Args>(args)...);
483 is_inline_ = false;
484 }
485
GetValue()486 ValueInterface* GetValue() {
487 if (IsInlineValue()) {
488 return inline_value_.AsValueInterface();
489 } else {
490 return heap_value_.get();
491 }
492 }
493
GetValue()494 const ValueInterface* GetValue() const {
495 if (IsInlineValue()) {
496 return inline_value_.AsValueInterface();
497 } else {
498 return heap_value_.get();
499 }
500 }
501
502 // PRECONDITION: Called on construction or ResetMemory() has been called
503 // before this method.
504 template <typename VT, typename T>
InsertValue(T && value)505 void InsertValue(T&& value) {
506 if (IsInlineValue()) {
507 new (&inline_value_)
508 InlineValue(InlineValue::Tag<VT>{}, std::forward<T>(value));
509 } else {
510 new (&heap_value_) HeapValue(
511 absl::make_unique<Value<VT>>(InPlace(), std::forward<T>(value)));
512 }
513 }
514 };
515
516 // Make sure that a Variant object can reside in a 64-byte aligned Tensor
517 // buffer.
518 static_assert(sizeof(Variant) <= 64,
519 "Expected internal representation to be 64 bytes.");
520
// Copy constructor: mirrors `other`'s storage kind. An inline value is
// copy-constructed in place. A heap value is deep-copied via Clone(), and an
// empty (null) heap pointer stays empty.
inline Variant::Variant(const Variant& other)
    : is_inline_(other.IsInlineValue()) {
  if (!IsInlineValue()) {
    const ValueInterface* source = other.heap_value_.get();
    new (&heap_value_) HeapValue(source == nullptr ? nullptr : source->Clone());
    return;
  }
  new (&inline_value_) InlineValue(other.inline_value_);
}
530
Variant(Variant && other)531 inline Variant::Variant(Variant&& other) noexcept
532 : is_inline_(other.IsInlineValue()) {
533 if (IsInlineValue()) {
534 new (&inline_value_) InlineValue(std::move(other.inline_value_));
535 } else {
536 new (&heap_value_) HeapValue(std::move(other.heap_value_));
537 }
538 }
539
540 template <typename T, typename VT,
541 typename std::enable_if<!std::is_same<Variant, VT>::value &&
542 std::is_move_constructible<VT>::value,
543 void>::type*>
Variant(T && value)544 inline Variant::Variant(T&& value) : is_inline_(CanInlineType<VT>()) {
545 InsertValue<VT>(std::forward<T>(value));
546 }
547
548 template <typename T, typename VT,
549 typename std::enable_if<!std::is_same<Variant, VT>::value &&
550 std::is_copy_constructible<VT>::value,
551 void>::type*>
Variant(const T & value)552 inline Variant::Variant(const T& value) : is_inline_(CanInlineType<VT>()) {
553 InsertValue<VT>(value);
554 }
555
556 template <typename T, typename VT,
557 typename std::enable_if<!std::is_same<Variant, VT>::value &&
558 std::is_move_constructible<VT>::value,
559 void>::type*>
560 inline Variant& Variant::operator=(T&& value) {
561 ResetMemory();
562 is_inline_ = CanInlineType<VT>();
563 InsertValue<VT>(std::forward<T>(value));
564 return *this;
565 }
566
567 template <typename T, typename VT,
568 typename std::enable_if<!std::is_same<Variant, VT>::value &&
569 std::is_copy_constructible<VT>::value,
570 void>::type*>
571 inline Variant& Variant::operator=(const T& value) {
572 ResetMemory();
573 is_inline_ = CanInlineType<VT>();
574 InsertValue<VT>(value);
575 return *this;
576 }
577
clear()578 inline void Variant::clear() noexcept {
579 // We set the internal unique_ptr to nullptr so that we preserve the
580 // invariant that one of the two states must be set at all times. nullptr
581 // indicates that the variant is empty.
582 ResetAndSetHeap(/*pointer=*/nullptr);
583 }
584
// Exchanges the contents of *this and `other`.
//
// Exactly one union member of each Variant is live (tracked by is_inline_),
// so the contents cannot be exchanged as raw bytes. The cases are:
//   * *this is empty: take over `other`'s live member, then empty `other`.
//   * `other` is empty: the mirror image of the previous case.
//   * both non-empty with the same storage kind: swap that member directly.
//   * mixed storage: park the heap pointer in a local, move the inline value
//     into the Variant that held the heap pointer, then install the parked
//     pointer in the other Variant.
inline void Variant::swap(Variant& other) noexcept {
  if (is_empty()) {
    if (other.IsInlineValue()) {
      ResetAndSetInline(std::move(other.inline_value_));
    } else {
      ResetAndSetHeap(std::move(other.heap_value_));
    }
    // Destroys whatever (possibly moved-from) member `other` still holds.
    other.clear();
  } else if (other.is_empty()) {
    if (IsInlineValue()) {
      other.ResetAndSetInline(std::move(inline_value_));
    } else {
      other.ResetAndSetHeap(std::move(heap_value_));
    }
    clear();
  } else {  // Both Variants have values.
    if (other.IsInlineValue() && IsInlineValue()) {
      std::swap(inline_value_, other.inline_value_);
    } else if (!other.IsInlineValue() && !IsInlineValue()) {
      std::swap(heap_value_, other.heap_value_);
    } else if (other.IsInlineValue() && !IsInlineValue()) {
      // Park our heap pointer first: ResetAndSetInline destroys heap_value_.
      HeapValue v = std::move(heap_value_);
      ResetAndSetInline(std::move(other.inline_value_));
      other.ResetAndSetHeap(std::move(v));
    } else {  // !other.IsInlineValue() && IsInlineValue()
      // Park other's heap pointer first: its ResetAndSetInline destroys it.
      HeapValue v = std::move(other.heap_value_);
      other.ResetAndSetInline(std::move(inline_value_));
      ResetAndSetHeap(std::move(v));
    }
  }
}
616
617 template <>
618 void* Variant::get();
619
620 template <>
621 const void* Variant::get() const;
622
623 } // end namespace tensorflow
624
625 #endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
626