• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EXTENSION_TYPE_VARIANT_H_
17 #define TENSORFLOW_CORE_KERNELS_EXTENSION_TYPE_VARIANT_H_
18 
#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
24 
25 namespace tensorflow {
26 
27 class CompositeTensorVariantMetadata;
28 
29 // Encoding for a `tf.ExtensionType` value, that can be saved as a Variant.
30 //
// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class
// used to define Python types that are supported by TensorFlow APIs.  Example
// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`.
34 //
35 // `CompositeTensorVariant` decomposes the `ExtensionType` value into two
36 // parts:
37 //
38 //   * `components`: A list of Tensors, which encodes the value's dynamic
39 //     data -- i.e., data that may change for different executions of a graph.
40 //   * `metadata`: A serialized TypeSpec, which encodes the value's
41 //     static data -- i.e., data that is the same for all executions of a graph.
42 //
43 // CompositeTensorVariant can be stored in a Tensor with dtype=DT_VARIANT.
44 // Typically, extension type values are encoded with a scalar tensor containing
45 // a single CompositeTensorVariant value.
46 class CompositeTensorVariant {
47  public:
48   CompositeTensorVariant(const CompositeTensorVariantMetadata& metadata,
49                          absl::Span<Tensor> flat_components);
50 
51   CompositeTensorVariant();
52   CompositeTensorVariant(const CompositeTensorVariant& other);
53   CompositeTensorVariant& operator=(CompositeTensorVariant&& other) = default;
54   CompositeTensorVariant& operator=(const CompositeTensorVariant& other) =
55       delete;
56 
57   // Returns the list of Tensor components that encode this value's dynamic
58   // data.
flat_components()59   absl::Span<const Tensor> flat_components() const {
60     return absl::MakeConstSpan(flat_components_);
61   }
62 
63   // Returns the serialized TypeSpec that encodes the value's static data.
metadata()64   const CompositeTensorVariantMetadata& metadata() const { return *metadata_; }
65 
66   // Variant methods.
TypeName()67   string TypeName() const { return kTypeName; }
68 
69   // Updates `VariantTensorData` with an encoding for this value.
70   void Encode(VariantTensorData* data) const;
71 
72   // Updates this value to match the encoding in a given `VariantTensorData`.
73   bool Decode(const VariantTensorData& data);
74 
75   // Returns a string summary for this value.
76   string DebugString() const;
77 
78   // Name of this type (used for variant serialization).
79   static constexpr const char kTypeName[] = "CompositeTensorVariant";
80 
81  private:
82   // Tensor components for this value.
83   std::vector<Tensor> flat_components_;
84 
85   // TypeSpec for this value.  CompositeTensorVariantMetadata is a thin wrapper
86   // around a TypeSpecProto, which is used to retain flexibility to change the
87   // variant encoding.
88   //
89   // Note: we use a unique_ptr, because header files in the kernels/ directory
90   // are not allowed to import .pb.h files.
91   std::unique_ptr<CompositeTensorVariantMetadata> metadata_;
92 };
93 
94 }  // namespace tensorflow
95 
96 #endif  // TENSORFLOW_CORE_KERNELS_EXTENSION_TYPE_VARIANT_H_
97