# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Metadata about fields for user-defined ExtensionType classes."""

import collections
import collections.abc
import typing

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import immutable_dict
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec

# These names may not be used as the name for an ExtensionType field (to
# prevent name clashes).  All names beginning with `'_tf_extension_type'` are
# also reserved (see `ExtensionTypeField.is_reserved_name`).
RESERVED_FIELD_NAMES = [
    'self',
    # Name of the nested TypeSpec class.
    'Spec',
    # Names defined by the CompositeTensor base class.
    '_type_spec',
    '_shape_invariant_to_type_spec',
    '_consumers',
    # Names defined by the TypeSpec base class.
    'value_type',
    'is_compatible_with',
    'most_specific_compatible_type',
    '_with_tensor_ranks_only',
    '_to_components',
    '_from_components',
    '_component_specs',
    '_to_tensor_list',
    '_from_tensor_list',
    '_from_compatible_tensor_list',
    '_flat_tensor_specs',
    '_serialize',
    '_deserialize',
    '_to_legacy_output_types',
    '_to_legacy_output_shapes',
    '_to_legacy_output_classes',
]


class Sentinel:
  """Sentinel value that is not identical (w/ `is`) to any user value."""

  def __init__(self, name):
    self._name = name

  def __repr__(self):
    return self._name


# ==============================================================================
# ExtensionTypeField
# ==============================================================================
class ExtensionTypeField(
    collections.namedtuple('ExtensionTypeField',
                           ['name', 'value_type', 'default'])):
  """Metadata about a single field in a `tf.ExtensionType` object."""

  # Marker used for `default` when the field has no default value.
  NO_DEFAULT = Sentinel('ExtensionTypeField.NO_DEFAULT')

  def __new__(cls, name, value_type, default=NO_DEFAULT):
    """Constructs a new ExtensionTypeField containing metadata for a single field.

    Args:
      name: The name of the new field (`str`).  May not be a reserved name.
      value_type: A python type expression constraining what values this field
        can take.
      default: The default value for the new field, or `NO_DEFAULT` if this
        field has no default value.

    Returns:
      A new `ExtensionTypeField`.

    Raises:
      TypeError: If the type described by `value_type` is not currently
          supported by `tf.ExtensionType`.
      TypeError: If `default` is specified and its type does not match
        `value_type`.
    """
    try:
      # Forward references are allowed here; they are resolved later.
      validate_field_value_type(value_type, allow_forward_references=True)
    except TypeError as e:
      raise TypeError(f'In field {name!r}: {e}') from e

    if default is not cls.NO_DEFAULT:
      default = _convert_value(default, value_type,
                               (f'default value for {name}',))
    return super().__new__(cls, name, value_type, default)

  @staticmethod
  def is_reserved_name(name):
    """Returns true if `name` is a reserved name."""
    return name in RESERVED_FIELD_NAMES or name.lower().startswith(
        '_tf_extension_type')


def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
  """Checks that `value_type` contains only supported type annotations.

  Args:
    value_type: The type annotation to check.
    in_mapping_key: True if `value_type` is nested in the key of a mapping.
    allow_forward_references: If false, then raise an exception if a
      `value_type` contains a forward reference (i.e., a string literal).

  Raises:
    TypeError: If `value_type` contains an unsupported type annotation.
  """
  if isinstance(value_type, str) or is_forward_ref(value_type):
    if allow_forward_references:
      return
    raise TypeError(f'Unresolved forward reference {value_type!r}')

  if value_type in (int, float, str, bytes, bool, None, _NoneType,
                    dtypes.DType):
    return
  elif (value_type in (ops.Tensor, tensor_shape.TensorShape) or
        isinstance(value_type, type_spec.TypeSpec) or
        (isinstance(value_type, type) and
         issubclass(value_type, composite_tensor.CompositeTensor))):
    # Tensor-like types and TensorShape are rejected as mapping keys, since
    # mapping keys must be hashable.
    if in_mapping_key:
      raise TypeError('Key must be hashable.')
  elif is_generic_tuple(value_type) or is_generic_union(value_type):
    type_args = get_generic_type_args(value_type)
    if (len(type_args) == 2 and type_args[1] is Ellipsis and
        is_generic_tuple(value_type)):  # `Tuple[X, ...]`
      validate_field_value_type(type_args[0], in_mapping_key,
                                allow_forward_references)
    else:
      for arg in type_args:
        validate_field_value_type(arg, in_mapping_key, allow_forward_references)
  elif is_generic_mapping(value_type):
    key_type, item_type = get_generic_type_args(value_type)
    validate_field_value_type(key_type, True, allow_forward_references)
    validate_field_value_type(item_type, in_mapping_key,
                              allow_forward_references)
  elif isinstance(value_type, type):
    raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
  else:
    raise TypeError(f'Unsupported type annotation {value_type!r}')


# ==============================================================================
# Type-checking & conversion for ExtensionTypeField values
# ==============================================================================


def convert_fields(fields, field_values):
  """Type-checks and converts each field in `field_values` (in place).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  Must contain an entry
      for each field.  I.e., `set(field_values.keys())` must be equal to
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields, field_values, for_spec=False)


def convert_fields_for_spec(fields, field_values):
  """Type-checks and converts field values for a TypeSpec (in place).

  This is similar to `convert_fields`, except that we expect a TypeSpec
  for tensor-like types.  In particular, if the `value_type` of a field
  specifies a tensor-like type (tf.Tensor, CompositeTensor, or TypeSpec),
  then the corresponding value in `field_values` is expected to contain a
  TypeSpec (rather than a value described by that TypeSpec).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  Must contain an entry
      for each field.  I.e., `set(field_values.keys())` must be equal to
      `set([f.name for f in fields])`.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  _convert_fields(fields, field_values, for_spec=True)


def _convert_fields(fields, field_values, for_spec):
  """Type-checks and converts each field in `field_values` (in place).

  Args:
    fields: A list of `ExtensionTypeField` objects.
    field_values: A `dict` mapping field names to values.  Must contain an entry
      for each field.  I.e., `set(field_values.keys())` must be equal to
      `set([f.name for f in fields])`.
    for_spec: If false, then expect a value for tensor-like types; if true, then
      expect a TypeSpec for tensor-like types.

  Raises:
    ValueError: If the keys of `field_values` do not match the names of
      the fields in `fields`.
    TypeError: If any value in `field_values` does not have the type indicated
      by the corresponding `ExtensionTypeField` object.
  """
  converted = {}
  if len(fields) != len(field_values):
    _report_field_mismatches(fields, field_values)
  for field in fields:
    if field.name not in field_values:
      _report_field_mismatches(fields, field_values)
    field_value = field_values[field.name]
    converted[field.name] = _convert_value(field_value, field.value_type,
                                           (field.name,), for_spec)
  # Only mutate `field_values` once every field has converted successfully, so
  # a TypeError partway through leaves the caller's dict untouched.
  field_values.update(converted)


def _convert_value(value, expected_type, path, for_spec=False):
  """Type-checks and converts a value.

  Args:
    value: The value to type-check.
    expected_type: The expected type for the value.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: If false, then expect a value for tensor-like types; if true, then
      expect a TypeSpec for tensor-like types.

  Returns:
    A copy of `value`, converted to the expected type.

  Raises:
    TypeError: If `value` can not be converted to the expected type.
  """
  assert isinstance(path, tuple)

  if expected_type is None:
    expected_type = _NoneType

  if expected_type is ops.Tensor:
    return _convert_tensor(value, path, for_spec)
  elif isinstance(expected_type, tensor_spec.TensorSpec):
    return _convert_tensor_spec(value, expected_type, path, for_spec)
  elif isinstance(expected_type, type_spec.TypeSpec):
    return _convert_type_spec(value, expected_type, path, for_spec)
  elif (isinstance(expected_type, type) and
        issubclass(expected_type, composite_tensor.CompositeTensor)):
    return _convert_composite_tensor(value, expected_type, path, for_spec)
  elif expected_type in (int, float, bool, str, bytes, _NoneType, dtypes.DType,
                         tensor_shape.TensorShape):
    if not isinstance(value, expected_type):
      raise TypeError(f'{"".join(path)}: expected '
                      f'{expected_type.__name__}, got {value!r}')
    return value
  elif is_generic_tuple(expected_type):
    return _convert_tuple(value, expected_type, path, for_spec)
  elif is_generic_mapping(expected_type):
    return _convert_mapping(value, expected_type, path, for_spec)
  elif is_generic_union(expected_type):
    return _convert_union(value, expected_type, path, for_spec)
  else:
    raise TypeError(f'{"".join(path)}: Unsupported type annotation '
                    f'{expected_type!r}')


def _convert_tensor(value, path, for_spec):
  """Converts `value` to a `Tensor` (or checks it is a TensorSpec if for_spec)."""
  if for_spec:
    if not isinstance(value, tensor_spec.TensorSpec):
      raise TypeError(f'{"".join(path)}: expected a TensorSpec, got {value!r}')
    return value

  if not isinstance(value, ops.Tensor):
    try:
      value = ops.convert_to_tensor(value)
    except (ValueError, TypeError) as e:
      raise TypeError(f'{"".join(path)}: expected a Tensor, '
                      f'got {value!r}') from e
  return value


def _convert_tensor_spec(value, expected_type, path, for_spec):
  """Converts `value` to a Tensor compatible with TensorSpec expected_type."""
  if for_spec:
    if not (isinstance(value, tensor_spec.TensorSpec) and
            expected_type.is_compatible_with(value)):
      raise TypeError(f'{"".join(path)}: expected a TensorSpec compatible '
                      f'with {expected_type}, got {value!r}')
    return value

  if not isinstance(value, ops.Tensor):
    try:
      value = ops.convert_to_tensor(value, expected_type.dtype)
    except (ValueError, TypeError) as e:
      raise TypeError(f'{"".join(path)}: expected a {expected_type.dtype!r} '
                      f'Tensor, got {value!r}') from e
  if not expected_type.is_compatible_with(value):
    raise TypeError(f'{"".join(path)}: expected a Tensor compatible with '
                    f'{expected_type}, got {value!r}')
  return value


def _convert_type_spec(value, expected_type, path, for_spec):
  """Converts `value` to a value compatible with TypeSpec `expected_type`."""
  if for_spec:
    if not (isinstance(value, type_spec.TypeSpec) and
            expected_type.is_compatible_with(value)):
      raise TypeError(f'{"".join(path)}: expected a TypeSpec compatible '
                      f'with {expected_type}, got {value!r}')
    return value

  # When a value (not a spec) is expected, a TypeSpec is always rejected even
  # if it would be reported as compatible.
  if (isinstance(value, type_spec.TypeSpec) or
      not expected_type.is_compatible_with(value)):
    raise TypeError(f'{"".join(path)}: expected {expected_type!r}, '
                    f'got {value!r}')
  return value


def _convert_composite_tensor(value, expected_type, path, for_spec):
  """Converts `value` to a value of type `expected_type`."""
  if for_spec:
    if not (isinstance(value, type_spec.TypeSpec) and
            issubclass(value.value_type, expected_type)):
      raise TypeError(f'{"".join(path)}: expected a TypeSpec for '
                      f'{expected_type.__name__}, got {value!r}')
    return value

  if not isinstance(value, expected_type):
    raise TypeError(f'{"".join(path)}: expected {expected_type.__name__}, '
                    f'got {value!r}')
  return value


def _convert_tuple(value, expected_type, path, for_spec):
  """Converts `value` to a tuple with type `expected_type`.

  Supports both variadic (`Tuple[X, ...]`) and fixed-length
  (`Tuple[X, Y, ...]` with explicit element types) annotations.
  """
  # NOTE(review): `str`/`bytes` are Sequences too, so they are accepted here
  # and split into their elements; kept as-is for backward compatibility.
  if not isinstance(value, collections.abc.Sequence):
    raise TypeError(f'{"".join(path)}: expected tuple, got {value!r}')
  element_types = get_generic_type_args(expected_type)
  if len(element_types) == 2 and element_types[1] is Ellipsis:
    return tuple([
        _convert_value(v, element_types[0], path + (f'[{i}]',), for_spec)
        for (i, v) in enumerate(value)
    ])
  else:
    if len(value) != len(element_types):
      raise TypeError(f'{"".join(path)}: expected tuple with length '
                      f'{len(element_types)}, got {value!r}')
    return tuple([
        _convert_value(v, t, path + (f'[{i}]',), for_spec)
        for (i, (v, t)) in enumerate(zip(value, element_types))
    ])


def _convert_mapping(value, expected_type, path, for_spec):
  """Converts `value` to an `ImmutableDict` with type `expected_type`."""
  if not isinstance(value, collections.abc.Mapping):
    raise TypeError(f'{"".join(path)}: expected mapping, got {value!r}')
  key_type, value_type = get_generic_type_args(expected_type)
  return immutable_dict.ImmutableDict([
      (_convert_value(k, key_type, path + ('[<key>]',), for_spec),
       _convert_value(v, value_type, path + (f'[{k!r}]',), for_spec))
      for (k, v) in value.items()
  ])


def _convert_union(value, expected_type, path, for_spec):
  """Converts `value` to a value with any of the types in `expected_type`.

  Alternatives are tried in declaration order; the first successful
  conversion wins.  Errors from the individual alternatives are discarded and
  a single TypeError naming the whole Union is raised if none match.
  """
  for type_option in get_generic_type_args(expected_type):
    try:
      return _convert_value(value, type_option, path, for_spec)
    except TypeError:
      pass
  raise TypeError(f'{"".join(path)}: expected {expected_type}, got {value!r}')


def _report_field_mismatches(fields, field_values):
  """Raises an exception with mismatches between fields and field_values."""
  expected = set(f.name for f in fields)
  actual = set(field_values)
  extra = actual - expected
  if extra:
    raise ValueError(f'Got unexpected fields: {extra}')
  missing = expected - actual
  if missing:
    raise ValueError(f'Missing required fields: {missing}')


# ==============================================================================
# Utilities for accessing Python generic type annotations (typing.*)
# ==============================================================================
def is_generic_union(tp):
  """Returns true if `tp` is a parameterized typing.Union value."""
  return (tp is not typing.Union and
          getattr(tp, '__origin__', None) is typing.Union)


def is_generic_tuple(tp):
  """Returns true if `tp` is a parameterized typing.Tuple value."""
  return (tp not in (tuple, typing.Tuple) and
          getattr(tp, '__origin__', None) in (tuple, typing.Tuple))


def is_generic_mapping(tp):
  """Returns true if `tp` is a parameterized typing.Mapping value."""
  return (tp not in (collections.abc.Mapping, typing.Mapping) and getattr(
      tp, '__origin__', None) in (collections.abc.Mapping, typing.Mapping))


def is_forward_ref(tp):
  """Returns true if `tp` is a typing forward reference."""
  # The forward-reference class was renamed from `_ForwardRef` to
  # `ForwardRef` across Python versions; support both spellings.
  if hasattr(typing, 'ForwardRef'):
    return isinstance(tp, typing.ForwardRef)
  elif hasattr(typing, '_ForwardRef'):
    return isinstance(tp, typing._ForwardRef)  # pylint: disable=protected-access
  else:
    return False


# Note: typing.get_args was added in Python 3.8.
if hasattr(typing, 'get_args'):
  get_generic_type_args = typing.get_args
else:

  def get_generic_type_args(tp):
    """Returns the type arguments of a parameterized generic `tp`."""
    return tp.__args__


_NoneType = type(None)