1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_EXTENSION_TYPE_VARIANT_H_ 17 #define TENSORFLOW_CORE_KERNELS_EXTENSION_TYPE_VARIANT_H_ 18 19 #include <vector> 20 21 #include "absl/types/span.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/variant_tensor_data.h" 24 25 namespace tensorflow { 26 27 class CompositeTensorVariantMetadata; 28 29 // Encoding for a `tf.ExtensionType` value, that can be saved as a Variant. 30 // 31 // `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class 32 // used to Python types that are supported by TensorFlow APIs. Example 33 // ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`. 34 // 35 // `CompositeTensorVariant` decomposes the `ExtensionType` value into two 36 // parts: 37 // 38 // * `components`: A list of Tensors, which encodes the value's dynamic 39 // data -- i.e., data that may change for different executions of a graph. 40 // * `metadata`: A serialized TypeSpec, which encodes the value's 41 // static data -- i.e., data that is the same for all executions of a graph. 42 // 43 // CompositeTensorVariant can be stored in a Tensor with dtype=DT_VARIANT. 44 // Typically, extension type values are encoded with a scalar tensor containing 45 // a single CompositeTensorVariant value. 46 class CompositeTensorVariant { 47 public: 48 CompositeTensorVariant(const CompositeTensorVariantMetadata& metadata, 49 absl::Span<Tensor> flat_components); 50 51 CompositeTensorVariant(); 52 CompositeTensorVariant(const CompositeTensorVariant& other); 53 CompositeTensorVariant& operator=(CompositeTensorVariant&& other) = default; 54 CompositeTensorVariant& operator=(const CompositeTensorVariant& other) = 55 delete; 56 57 // Returns the list of Tensor components that encode this value's dynamic 58 // data. flat_components()59 absl::Span<const Tensor> flat_components() const { 60 return absl::MakeConstSpan(flat_components_); 61 } 62 63 // Returns the serialized TypeSpec that encodes the value's static data. metadata()64 const CompositeTensorVariantMetadata& metadata() const { return *metadata_; } 65 66 // Variant methods. TypeName()67 string TypeName() const { return kTypeName; } 68 69 // Updates `VariantTensorData` with an encoding for this value. 70 void Encode(VariantTensorData* data) const; 71 72 // Updates this value to match the encoding in a given `VariantTensorData`. 73 bool Decode(const VariantTensorData& data); 74 75 // Returns a string summary for this value. 76 string DebugString() const; 77 78 // Name of this type (used for variant serialization). 79 static constexpr const char kTypeName[] = "CompositeTensorVariant"; 80 81 private: 82 // Tensor components for this value. 83 std::vector<Tensor> flat_components_; 84 85 // TypeSpec for this value. CompositeTensorVariantMetadata is a thin wrapper 86 // around a TypeSpecProto, which is used to retain flexibility to change the 87 // variant encoding. 88 // 89 // Note: we use a unique_ptr, because header files in the kernels/ directory 90 // are not allowed to import .pb.h files. 91 std::unique_ptr<CompositeTensorVariantMetadata> metadata_; 92 }; 93 94 } // namespace tensorflow 95 96 #endif // TENSORFLOW_CORE_KERNELS_EXTENSION_TYPE_VARIANT_H_ 97