• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utility functions for type information, including full type information."""
16
17from typing import List
18
19from tensorflow.core.framework import full_type_pb2
20from tensorflow.core.framework import types_pb2
21from tensorflow.python.framework import type_spec
22from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
23from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.util import nest
26
# TODO(b/226455884) A python binding for DT_TO_FT or map_dtype_to_tensor() from
# tensorflow/core/framework/types.cc to avoid duplication here
#
# Lookup table from `types_pb2` DT_* enum values to the `full_type_pb2` TFT_*
# enum value used as the element type id in a FullTypeDef.
#
# `_translate_to_fulltype_for_flat_tensors` reads this table with `.get()`, so
# a dtype that is not listed here yields None there (which it logs and then
# falls back to TFT_UNSET).
_DT_TO_FT = {
    types_pb2.DT_FLOAT: full_type_pb2.TFT_FLOAT,
    types_pb2.DT_DOUBLE: full_type_pb2.TFT_DOUBLE,
    types_pb2.DT_INT32: full_type_pb2.TFT_INT32,
    types_pb2.DT_UINT8: full_type_pb2.TFT_UINT8,
    types_pb2.DT_INT16: full_type_pb2.TFT_INT16,
    types_pb2.DT_INT8: full_type_pb2.TFT_INT8,
    types_pb2.DT_STRING: full_type_pb2.TFT_STRING,
    types_pb2.DT_COMPLEX64: full_type_pb2.TFT_COMPLEX64,
    types_pb2.DT_INT64: full_type_pb2.TFT_INT64,
    types_pb2.DT_BOOL: full_type_pb2.TFT_BOOL,
    types_pb2.DT_UINT16: full_type_pb2.TFT_UINT16,
    types_pb2.DT_COMPLEX128: full_type_pb2.TFT_COMPLEX128,
    types_pb2.DT_HALF: full_type_pb2.TFT_HALF,
    types_pb2.DT_UINT32: full_type_pb2.TFT_UINT32,
    types_pb2.DT_UINT64: full_type_pb2.TFT_UINT64,
    # Variants map to the legacy variant id; the ragged translation in this
    # file treats that id as "not supported" rather than as an element type.
    types_pb2.DT_VARIANT: full_type_pb2.TFT_LEGACY_VARIANT,
}
47
48
def _translate_to_fulltype_for_flat_tensors(
    spec: type_spec.TypeSpec) -> List[full_type_pb2.FullTypeDef]:
  """Convert a TypeSpec to a list of FullTypeDef.

  The FullTypeDef created corresponds to the encoding used with datasets
  (and map_fn) that uses variants (and not FullTypeDef corresponding to the
  default "component" encoding).

  Currently, the only use of this is for information about the contents of
  ragged tensors, so only ragged tensors return useful full type information
  and other types return TFT_UNSET. While this could be improved in the future,
  this function is intended for temporary use and expected to be removed
  when type inference support is sufficient.

  Args:
    spec: A TypeSpec for one element of a dataset or map_fn.

  Returns:
    A list of FullTypeDef corresponding to SPEC. The length of this list
    is always the same as the length of spec._flat_tensor_specs. For a
    RaggedTensorSpec whose dtype has a (non-variant) entry in `_DT_TO_FT`,
    this is a single TFT_RAGGED FullTypeDef whose only arg is the element
    type; in every other case each entry is TFT_UNSET.
  """
  if isinstance(spec, RaggedTensorSpec):
    dt = spec.dtype
    elem_t = _DT_TO_FT.get(dt)
    if elem_t is None:
      logging.vlog(1, "dtype %s that has no conversion to fulltype.", dt)
    elif elem_t == full_type_pb2.TFT_LEGACY_VARIANT:
      # This message has no `%` placeholder, so it must not be given a format
      # argument: an extra argument makes `msg % args` fail when the record is
      # emitted and the message is replaced by a logging error.
      logging.vlog(1, "Ragged tensors containing variants are not supported.")
    else:
      # A ragged tensor is encoded as exactly one (variant) flat tensor, so a
      # single FullTypeDef keeps the documented length invariant.
      assert len(spec._flat_tensor_specs) == 1  # pylint: disable=protected-access
      return [
          full_type_pb2.FullTypeDef(
              type_id=full_type_pb2.TFT_RAGGED,
              args=[full_type_pb2.FullTypeDef(type_id=elem_t)])
      ]
  # Fallback (non-ragged specs and unsupported ragged dtypes): one TFT_UNSET
  # entry per flat tensor spec.
  return [
      full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
      for _ in spec._flat_tensor_specs  # pylint: disable=protected-access
  ]
89
90
91# LINT.IfChange(_specs_for_flat_tensors)
def _specs_for_flat_tensors(element_spec):
  """Flatten element_spec into the list of specs used by the "flat" encoding.

  "Flat" here (and in `_flat_tensor_specs`) is shorthand for the "batchable
  tensor list" encoding that datasets and map_fn use internally (in
  C++/graphs). Two properties characterize this encoding: values can be
  batched, unbatched and re-batched, and a ragged or sparse tensor is
  represented as a single tensor of type variant (with special ops to
  encode/decode to/from variants).

  (By contrast, the more common "component encoding", e.g. the C++/graph
  representation when calling a tf.function, represents sparse and ragged
  tensors as multiple dense tensors and uses neither variants nor special
  encode/decode ops.)

  Args:
    element_spec: A nest of TypeSpec describing the elements of a dataset (or
      map_fn).

  Returns:
    A non-nested list of TypeSpec used by the encoding of tensors by
    datasets and map_fn for ELEMENT_SPEC. The items
    in this list correspond to the items in `_flat_tensor_specs`.
  """
  if isinstance(element_spec, StructuredTensor.Spec):
    # Visit fields in sorted-name order and recurse into each field's spec,
    # concatenating the per-field results.
    fields = element_spec._field_specs  # pylint: disable=protected-access
    return [
        flat_spec
        for _, field_spec in sorted(fields.items(), key=lambda kv: kv[0])
        for flat_spec in _specs_for_flat_tensors(field_spec)
    ]

  # True only for BatchableTypeSpec subclasses that inherit (do not override)
  # `_flat_tensor_specs`; the attribute is only read once the isinstance
  # check has succeeded.
  inherits_default_flat_specs = (
      isinstance(element_spec, type_spec.BatchableTypeSpec) and
      element_spec.__class__._flat_tensor_specs is  # pylint: disable=protected-access
      type_spec.BatchableTypeSpec._flat_tensor_specs)  # pylint: disable=protected-access
  if inherits_default_flat_specs:
    # Such classes are encoded using their `component_specs`.
    return nest.flatten(
        element_spec._component_specs,  # pylint: disable=protected-access
        expand_composites=False)

  # Default case: besides flattening any Python nesting, this covers
  # everything encoded by one tensor -- dense tensors (unchanged by the
  # encoding) and ragged/sparse tensors (encoded by a variant tensor).
  return nest.flatten(element_spec, expand_composites=False)
138# LINT.ThenChange()
139# Note that _specs_for_flat_tensors must correspond to _flat_tensor_specs
140
141
def fulltypes_for_flat_tensors(element_spec):
  """Build the list of FullTypeDef for a dataset's element_spec.

  "Flat" here (and in `_flat_tensor_specs`) is shorthand for the "batchable
  tensor list" encoding used by datasets and map_fn. The FullTypeDef values
  produced describe that encoding (i.e. the one that uses variants), not the
  default "component" encoding.

  This is meant for temporary internal use and is expected to go away once
  type inference support is sufficient. See the limitations described in
  `_translate_to_fulltype_for_flat_tensors`.

  Args:
    element_spec: A nest of TypeSpec describing the elements of a dataset (or
      map_fn).

  Returns:
    A list of FullTypeDef corresponding to ELEMENT_SPEC. The items
    in this list correspond to the items in `_flat_tensor_specs`.
  """
  flat_specs = _specs_for_flat_tensors(element_spec)
  # Each translation yields its own list, so this is a list of lists.
  per_spec_fulltypes = list(
      map(_translate_to_fulltype_for_flat_tensors, flat_specs))
  # Collapse the list-of-lists into one flat list.
  return nest.flatten(per_spec_fulltypes)
166
167
def fulltype_list_to_product(fulltype_list):
  """Wrap a list of FullTypeDef in a single TFT_PRODUCT FullTypeDef.

  Args:
    fulltype_list: The FullTypeDef values to use as the args of the product.

  Returns:
    A FullTypeDef with type_id TFT_PRODUCT and FULLTYPE_LIST as its args.
  """
  product = full_type_pb2.FullTypeDef(
      type_id=full_type_pb2.TFT_PRODUCT, args=fulltype_list)
  return product
172
173
def iterator_full_type_from_spec(element_spec):
  """Builds the FullTypeDef describing an iterator over the elements.

  The result has the shape PRODUCT[ITERATOR[PRODUCT[<element full types>]]],
  where the element full types come from `fulltypes_for_flat_tensors`.

  Args:
     element_spec: A nested structure of `tf.TypeSpec` objects representing the
       element type specification.

  Returns:
    A FullTypeDef for an iterator for the element tensor representation.
  """
  # Build the message from the inside out.
  element_product = full_type_pb2.FullTypeDef(
      type_id=full_type_pb2.TFT_PRODUCT,
      args=fulltypes_for_flat_tensors(element_spec))
  iterator_type = full_type_pb2.FullTypeDef(
      type_id=full_type_pb2.TFT_ITERATOR, args=[element_product])
  return full_type_pb2.FullTypeDef(
      type_id=full_type_pb2.TFT_PRODUCT, args=[iterator_type])
195