# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for ExtensionTypes (aka Composite Tensors)."""

from tensorflow.core.protobuf import composite_tensor_variant_pb2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_composite_tensor_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import nest


def _serialize_type_spec_metadata(type_spec):
  """Returns a serialized `CompositeTensorVariantMetadata` describing `type_spec`.

  Shared by `composite_tensor_to_variants` and `composite_tensor_from_variant`
  so that the encode and decode ops receive metadata built the same way.

  Args:
    type_spec: The `TypeSpec` to record in the metadata.

  Returns:
    The result of `SerializeToString()` on the metadata proto, suitable for the
    `metadata` attr of the composite-tensor variant ops.
  """
  coder = nested_structure_coder.StructureCoder()
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      coder.encode_structure(type_spec).type_spec_value)
  return metadata.SerializeToString()


def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  Args:
    value: The `ExtensionType` value to encode.
    type_spec: Information about the value's type that should be included in the
      encoding. Defaults to `value._type_spec` when None.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor.")

  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError("TypeSpec %r is not compatible with value %r" %
                     (type_spec, value))

  # The op receives the value's flattened component tensors plus the encoded
  # `type_spec`, which is what lets the decoder rebuild the structure later.
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=nest.flatten(value, expand_composites=True),
      metadata=_serialize_type_spec_metadata(type_spec),
      name=name)


def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Returns the `ExtensionType` value encoded by a variant scalar tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value. This is used to determine
      the number and types of the component tensors that comprise the decoded
      value. Must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with a
      scalar shape.
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  # Wrap `encoded` in a 1-tuple so that %-formatting treats it as a single
  # argument even when the caller passed a tuple.
  if not isinstance(encoded, ops.Tensor):
    raise TypeError("Expected `encoded` to be a Tensor, got %r." % (encoded,))
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got %r." %
                    (encoded,))
  encoded.shape.assert_is_compatible_with(())

  # The output dtypes of the decode op are taken from the flattened `type_spec`.
  component_dtypes = [
      t.dtype for t in nest.flatten(type_spec, expand_composites=True)
  ]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=_serialize_type_spec_metadata(type_spec),
      Tcomponents=component_dtypes,
      name=name)
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)


@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient for `CompositeTensorVariantFromComponents`.

  Decodes the incoming variant gradient back into per-component gradients,
  reusing the forward op's `metadata` and `Tcomponents` attrs.
  """
  return gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=grad,
      metadata=op.get_attr("metadata"),
      Tcomponents=op.get_attr("Tcomponents"))


@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient for `CompositeTensorVariantToComponents`.

  Re-encodes the per-component gradients as a single variant, reusing the
  forward op's `metadata` attr.
  """
  # Internal invariant: one incoming gradient per op output. This is kept as an
  # assert because `zip` below would otherwise silently truncate on a mismatch.
  assert len(grad) == len(op.outputs)
  # `components` is `op.outputs`, but with any tensors for which we're
  # taking the gradient replaced by the corresponding value from `grad`.
  components = [
      output if g is None else g for output, g in zip(op.outputs, grad)
  ]
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=op.get_attr("metadata"))