• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Operations for ExtensionTypes (aka Composite Tensors)."""
16
17from tensorflow.core.protobuf import composite_tensor_variant_pb2
18from tensorflow.python.framework import composite_tensor
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import gen_composite_tensor_ops
22from tensorflow.python.saved_model import nested_structure_coder
23from tensorflow.python.util import nest
24
25
def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  The component tensors of `value` (as produced by
  `nest.flatten(value, expand_composites=True)`) are passed to a
  `CompositeTensorVariantFromComponents` op, together with serialized metadata
  describing `type_spec`.

  Args:
    value: The `ExtensionType` value to encode.  Must be a `CompositeTensor`.
    type_spec: Information about the value's type that should be included in the
      encoding.  If `None`, `value`'s own `TypeSpec` is used.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    # Keep the original message prefix, but also report what was received so
    # misuse is easy to diagnose.
    raise TypeError("Expected `value` to be a CompositeTensor. Received %r." %
                    (type(value),))

  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError("TypeSpec %r is not compatible with value %r" %
                     (type_spec, value))

  # Serialize `type_spec` into the metadata proto that is attached to the op
  # as a string attribute.
  coder = nested_structure_coder.StructureCoder()
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      coder.encode_structure(type_spec).type_spec_value)

  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=nest.flatten(value, expand_composites=True),
      metadata=metadata.SerializeToString(),
      name=name)
58
59
def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Returns the `ExtensionType` value encoded by a variant scalar tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value.  This is used to determine
      the number and types of the component tensors that comprise the decoded
      value.  Must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with the
      scalar shape `()` (raised by `TensorShape.assert_is_compatible_with`).
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    # Wrap `encoded` in a 1-tuple: if the caller passed a tuple, a bare
    # `% encoded` would treat its elements as the format arguments and either
    # misreport the value or fail with a string-formatting error instead.
    raise TypeError("Expected `encoded` to be a Tensor, got %r." % (encoded,))
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got %r." %
                    (encoded,))
  # `encoded` must be a scalar (checked statically, as far as its shape is
  # known).
  encoded.shape.assert_is_compatible_with(())

  # Serialize `type_spec` into the op's metadata attribute, so the op has the
  # caller's expected `TypeSpec` alongside the one serialized in `encoded`.
  coder = nested_structure_coder.StructureCoder()
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      coder.encode_structure(type_spec).type_spec_value)

  # The op needs the dtype of every component it will produce; these come from
  # the flattened `type_spec`.
  component_dtypes = [
      t.dtype for t in nest.flatten(type_spec, expand_composites=True)
  ]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=metadata.SerializeToString(),
      Tcomponents=component_dtypes,
      name=name)
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)
99
100
@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient function registered for `CompositeTensorVariantFromComponents`.

  The gradient flowing into the forward op's single variant output is decoded
  back into component tensors, reusing the forward op's own `metadata` and
  `Tcomponents` attributes.

  Args:
    op: The forward `CompositeTensorVariantFromComponents` operation.
    grad: The gradient with respect to `op`'s variant output.

  Returns:
    The decoded component tensors, whose dtypes are given by `op`'s
    `Tcomponents` attribute.
  """
  serialized_metadata = op.get_attr("metadata")
  component_dtypes = op.get_attr("Tcomponents")
  decoded_components = (
      gen_composite_tensor_ops.CompositeTensorVariantToComponents(
          encoded=grad,
          metadata=serialized_metadata,
          Tcomponents=component_dtypes))
  return decoded_components
107
108
@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient function registered for `CompositeTensorVariantToComponents`.

  The per-component gradients are re-encoded into a single variant tensor,
  using the forward op's `metadata` attribute.

  Args:
    op: The forward `CompositeTensorVariantToComponents` operation.
    *grad: One gradient per output of `op`.  An entry is `None` when no
      gradient flows into the corresponding output.

  Returns:
    A scalar variant tensor holding the gradient with respect to `op`'s
    `encoded` input.
  """
  assert len(grad) == len(op.outputs)
  # Every component is needed to build the variant, so an output with no
  # gradient is stood in for by the forward op's own output tensor.
  components = [
      op.outputs[index] if component_grad is None else component_grad
      for index, component_grad in enumerate(grad)
  ]
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=op.get_attr("metadata"))
119