# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for types information, including full type information."""

from typing import List

from tensorflow.core.framework import full_type_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import type_spec
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest

# TODO(b/226455884) A python binding for DT_TO_FT or map_dtype_to_tensor() from
# tensorflow/core/framework/types.cc to avoid duplication here
_DT_TO_FT = {
    types_pb2.DT_FLOAT: full_type_pb2.TFT_FLOAT,
    types_pb2.DT_DOUBLE: full_type_pb2.TFT_DOUBLE,
    types_pb2.DT_INT32: full_type_pb2.TFT_INT32,
    types_pb2.DT_UINT8: full_type_pb2.TFT_UINT8,
    types_pb2.DT_INT16: full_type_pb2.TFT_INT16,
    types_pb2.DT_INT8: full_type_pb2.TFT_INT8,
    types_pb2.DT_STRING: full_type_pb2.TFT_STRING,
    types_pb2.DT_COMPLEX64: full_type_pb2.TFT_COMPLEX64,
    types_pb2.DT_INT64: full_type_pb2.TFT_INT64,
    types_pb2.DT_BOOL: full_type_pb2.TFT_BOOL,
    types_pb2.DT_UINT16: full_type_pb2.TFT_UINT16,
    types_pb2.DT_COMPLEX128: full_type_pb2.TFT_COMPLEX128,
    types_pb2.DT_HALF: full_type_pb2.TFT_HALF,
    types_pb2.DT_UINT32: full_type_pb2.TFT_UINT32,
    types_pb2.DT_UINT64: full_type_pb2.TFT_UINT64,
    types_pb2.DT_VARIANT: full_type_pb2.TFT_LEGACY_VARIANT,
}


def _translate_to_fulltype_for_flat_tensors(
    spec: type_spec.TypeSpec) -> List[full_type_pb2.FullTypeDef]:
  """Convert a TypeSpec to a list of FullTypeDef.

  The FullTypeDef created corresponds to the encoding used with datasets
  (and map_fn) that uses variants (and not FullTypeDef corresponding to the
  default "component" encoding).

  Currently, the only use of this is for information about the contents of
  ragged tensors, so only ragged tensors return useful full type information
  and other types return TFT_UNSET. While this could be improved in the future,
  this function is intended for temporary use and expected to be removed
  when type inference support is sufficient.

  Args:
    spec: A TypeSpec for one element of a dataset or map_fn.

  Returns:
    A list of FullTypeDef corresponding to SPEC. The length of this list
    is always the same as the length of spec._flat_tensor_specs.
  """
  if isinstance(spec, RaggedTensorSpec):
    dt = spec.dtype
    elem_t = _DT_TO_FT.get(dt)
    if elem_t is None:
      logging.vlog(1, "dtype %s that has no conversion to fulltype.", dt)
    elif elem_t == full_type_pb2.TFT_LEGACY_VARIANT:
      # The message has no format placeholder, so no format arguments are
      # passed; an extra argument would make %-formatting of the record fail.
      logging.vlog(1, "Ragged tensors containing variants are not supported.")
    else:
      assert len(spec._flat_tensor_specs) == 1  # pylint: disable=protected-access
      return [
          full_type_pb2.FullTypeDef(
              type_id=full_type_pb2.TFT_RAGGED,
              args=[full_type_pb2.FullTypeDef(type_id=elem_t)])
      ]
  # Fallback for non-ragged specs and for ragged specs whose dtype could not be
  # translated above: one TFT_UNSET entry per flat tensor spec.
  return [
      full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
      for _ in spec._flat_tensor_specs  # pylint: disable=protected-access
  ]


# LINT.IfChange(_specs_for_flat_tensors)
def _specs_for_flat_tensors(element_spec) -> List[type_spec.TypeSpec]:
  """Return a flat list of type specs for element_spec.

  Note that "flat" in this function and in `_flat_tensor_specs` is a nickname
  for the "batchable tensor list" encoding used by datasets and map_fn
  internally (in C++/graphs). The ability to batch, unbatch and change
  batch size is one important characteristic of this encoding. A second
  important characteristic is that it represents a ragged tensor or sparse
  tensor as a single tensor of type variant (and this encoding uses special
  ops to encode/decode to/from variants).

  (In contrast, the more typical encoding, e.g. the C++/graph
  representation when calling a tf.function, is "component encoding" which
  represents sparse and ragged tensors as multiple dense tensors and does
  not use variants or special ops for encoding/decoding.)

  Args:
    element_spec: A nest of TypeSpec describing the elements of a dataset (or
      map_fn).

  Returns:
    A non-nested list of TypeSpec used by the encoding of tensors by
    datasets and map_fn for ELEMENT_SPEC. The items
    in this list correspond to the items in `_flat_tensor_specs`.
  """
  if isinstance(element_spec, StructuredTensor.Spec):
    specs = []
    # Fields are visited in sorted-by-name order and expanded recursively.
    for _, field_spec in sorted(
        element_spec._field_specs.items(), key=lambda t: t[0]):  # pylint: disable=protected-access
      specs.extend(_specs_for_flat_tensors(field_spec))
  elif isinstance(element_spec, type_spec.BatchableTypeSpec) and (
      element_spec.__class__._flat_tensor_specs is  # pylint: disable=protected-access
      type_spec.BatchableTypeSpec._flat_tensor_specs):  # pylint: disable=protected-access
    # Classes which use the default `_flat_tensor_specs` from
    # `BatchableTypeSpec` (i.e. a derived class does not override
    # `_flat_tensor_specs`) are encoded using `component_specs`.
    specs = nest.flatten(
        element_spec._component_specs,  # pylint: disable=protected-access
        expand_composites=False)
  else:
    # In addition to flattening any nesting in Python,
    # this default case covers things that are encoded by one tensor,
    # such as dense tensors which are unchanged by encoding and
    # ragged tensors and sparse tensors which are encoded by a variant tensor.
    specs = nest.flatten(element_spec, expand_composites=False)
  return specs
# LINT.ThenChange()
# Note that _specs_for_flat_tensors must correspond to _flat_tensor_specs


def fulltypes_for_flat_tensors(
    element_spec) -> List[full_type_pb2.FullTypeDef]:
  """Convert the element_spec for a dataset to a list of FullTypeDef.

  Note that "flat" in this function and in `_flat_tensor_specs` is a nickname
  for the "batchable tensor list" encoding used by datasets and map_fn.
  The FullTypeDef created corresponds to this encoding (e.g. that uses variants
  and not the FullTypeDef corresponding to the default "component" encoding).

  This is intended for temporary internal use and expected to be removed
  when type inference support is sufficient. See limitations of
  `_translate_to_fulltype_for_flat_tensors`.

  Args:
    element_spec: A nest of TypeSpec describing the elements of a dataset (or
      map_fn).

  Returns:
    A list of FullTypeDef corresponding to ELEMENT_SPEC. The items
    in this list correspond to the items in `_flat_tensor_specs`.
  """
  specs = _specs_for_flat_tensors(element_spec)
  full_types_lists = [_translate_to_fulltype_for_flat_tensors(s) for s in specs]
  rval = nest.flatten(full_types_lists)  # flattens list-of-list to flat list.
  return rval


def fulltype_list_to_product(fulltype_list) -> full_type_pb2.FullTypeDef:
  """Convert a list of FullTypeDef into a single FullTypeDef.

  Args:
    fulltype_list: A list of FullTypeDef to use as the arguments of the
      product.

  Returns:
    A FullTypeDef with type_id TFT_PRODUCT whose args are FULLTYPE_LIST.
  """
  return full_type_pb2.FullTypeDef(
      type_id=full_type_pb2.TFT_PRODUCT, args=fulltype_list)


def iterator_full_type_from_spec(element_spec) -> full_type_pb2.FullTypeDef:
  """Returns a FullTypeDef for an iterator for the elements.

  The result has the shape
  TFT_PRODUCT[TFT_ITERATOR[TFT_PRODUCT[<full types of the flat tensors>]]].

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A FullTypeDef for an iterator for the element tensor representation.
  """
  args = fulltypes_for_flat_tensors(element_spec)
  iterator_type = full_type_pb2.FullTypeDef(
      type_id=full_type_pb2.TFT_ITERATOR,
      args=[fulltype_list_to_product(args)])
  return fulltype_list_to_product([iterator_type])