1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""User-defined ExtensionType classes.""" 16 17import abc 18import typing 19 20from tensorflow.python.framework import composite_tensor 21from tensorflow.python.framework import dtypes 22from tensorflow.python.framework import extension_type_field 23from tensorflow.python.framework import immutable_dict 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.framework import tensor_spec 27from tensorflow.python.framework import type_spec 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import composite_tensor_ops 30from tensorflow.python.ops import gen_math_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.saved_model import nested_structure_coder 33from tensorflow.python.util import nest 34from tensorflow.python.util import tf_decorator 35from tensorflow.python.util import tf_inspect 36 37# Attribute used to keep track of when we're inside a user-defined constructor 38# (in which case the fields of `self` may be modified). 
_IN_CONSTRUCTOR = '_tf_extension_type_in_constructor'


# ==============================================================================
# Utility functions
# ==============================================================================
def _create_object_from_type_and_dict(cls, obj_dict):
  """Creates an object, bypassing the constructor.

  Creates an object of type `cls`, whose `__dict__` is updated to contain
  `obj_dict`.  Because `cls.__init__` is never called, no field conversion or
  `__validate__` check is performed.

  Args:
    cls: The type of the new object.
    obj_dict: A `Mapping` (or an iterable of `(key, value)` pairs) that should
      be used to initialize the new object's `__dict__`.

  Returns:
    An object of type `cls`.
  """
  value = object.__new__(cls)
  value.__dict__.update(obj_dict)
  return value


# ==============================================================================
# Metaclass for tf.ExtensionType
# ==============================================================================
class ExtensionTypeMetaclass(abc.ABCMeta):
  """Metaclass for tf.ExtensionType types."""

  def __init__(cls, name, bases, namespace):
    # Don't transform base classes that are part of the framework -- only
    # transform user classes.  We identify classes that are part of the
    # framework by setting '_tf_extension_type_do_not_transform_this_class=True'
    # in the class definition.  (Note: we check for this in the class namespace,
    # so it is *not* inherited.)
    if not namespace.get('_tf_extension_type_do_not_transform_this_class',
                         False):
      _check_field_annotations(cls)
      _add_extension_type_constructor(cls)
      _add_type_spec(cls)
    super(ExtensionTypeMetaclass, cls).__init__(name, bases, namespace)


# ==============================================================================
# Base class for user-defined types
# ==============================================================================
class ExtensionType(
    composite_tensor.CompositeTensor, metaclass=ExtensionTypeMetaclass):
  """Base class for TensorFlow `ExtensionType` classes.

  Tensorflow `ExtensionType` classes are specialized Python classes that can be
  used transparently with TensorFlow -- e.g., they can be used with ops
  such as `tf.cond` or `tf.while_loop` and used as inputs or outputs for
  `tf.function` and Keras layers.

  New `ExtensionType` classes are defined by creating a subclass of
  `tf.ExtensionType` that contains type annotations for all instance
  variables.  The following type annotations are supported:

  Type                 | Example
  -------------------- | --------------------------------------------
  Python integers      | `i: int`
  Python floats        | `f: float`
  Python strings       | `s: str`
  Python booleans      | `b: bool`
  Python None          | `n: None`
  Tensors              | `t: tf.Tensor`
  Composite Tensors    | `rt: tf.RaggedTensor`
  Extension Types      | `m: MyMaskedTensor`
  Tensor shapes        | `shape: tf.TensorShape`
  Tensor dtypes        | `dtype: tf.DType`
  Type unions          | `length: typing.Union[int, float]`
  Tuples               | `params: typing.Tuple[int, float, int, int]`
  Tuples w/ Ellipsis   | `lengths: typing.Tuple[int, ...]`
  Mappings             | `tags: typing.Mapping[str, str]`
  TensorSpec instances | `t2: tf.TensorSpec(shape=[8, None], dtype=tf.int32)`
  TypeSpec instances   | `rt2: tf.RaggedTensorSpec(ragged_rank=2)`

  Fields annotated with `typing.Mapping` will be stored using an immutable
  mapping type.

  Due to technical limitations of Python's `typing` module, `TensorSpec`
  and `TypeSpec` instances may not currently be nested inside generic types
  (such as `typing.Union` or `typing.Tuple`).  TODO(b/184564088) Define
  tf generic types to avoid this limitation.

  ExtensionType values are immutable -- i.e., once constructed, you can not
  modify or delete any of their instance members.

  ### Examples

  >>> class MaskedTensor(ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.TensorSpec(shape=None, dtype=tf.bool)

  >>> class Toy(ExtensionType):
  ...   name: str
  ...   price: ops.Tensor
  ...   features: typing.Mapping[str, ops.Tensor]

  >>> class ToyStore(ExtensionType):
  ...   name: str
  ...   toys: typing.Tuple[Toy, ...]
  """

  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, *args, **kwargs):
    if type(self) is ExtensionType:  # pylint: disable=unidiomatic-typecheck
      raise AssertionError('ExtensionType is an abstract base class.')

  # This class variable is used to cache the return value for
  # _tf_extension_type_fields.
  _tf_extension_type_cached_fields = None

  @classmethod
  def _tf_extension_type_fields(cls):  # pylint: disable=no-self-argument
    """An ordered tuple describing the fields of this ExtensionType.

    Returns:
      A tuple of `ExtensionTypeField` objects.  Forward references are resolved
      if possible, or left unresolved otherwise.  The result is cached on the
      class only when every forward reference could be resolved.
    """
    if cls._tf_extension_type_cached_fields is not None:
      return cls._tf_extension_type_cached_fields

    try:
      type_hints = typing.get_type_hints(cls)
      ok_to_cache = True  # all forward references have been resolved.
    except (NameError, AttributeError):
      # Unresolved forward reference -- gather type hints manually.
      # * NameError comes from an annotation like `Foo` where class
      #   `Foo` hasn't been defined yet.
      # * AttributeError comes from an annotation like `foo.Bar`, where
      #   the module `foo` exists but `Bar` hasn't been defined yet.
      # Note: If a user attempts to instantiate a `ExtensionType` type that
      # still has unresolved forward references (e.g., because of a typo or a
      # missing import), then the constructor will raise an exception.
      type_hints = {}
      for base in reversed(cls.__mro__):
        type_hints.update(base.__dict__.get('__annotations__', {}))
      ok_to_cache = False

    fields = []
    for (name, value_type) in type_hints.items():
      default = getattr(cls, name,
                        extension_type_field.ExtensionTypeField.NO_DEFAULT)
      fields.append(
          extension_type_field.ExtensionTypeField(name, value_type, default))
    fields = tuple(fields)

    if ok_to_cache:
      cls._tf_extension_type_cached_fields = fields

    return fields

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    """Returns True if `name` is one of this ExtensionType's fields."""
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    """Type-checks and converts the field values stored in `self.__dict__`."""
    extension_type_field.convert_fields(self._tf_extension_type_fields(),
                                        self.__dict__)

  def __repr__(self):
    fields = ', '.join([
        f'{field.name}={getattr(self, field.name)!r}'
        for field in self._tf_extension_type_fields()
    ])
    return f'{type(self).__name__}({fields})'

  def __setattr__(self, name, value):
    """Allows assigning fields only while a user constructor is running."""
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      self.__dict__[name] = value
    else:
      raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    """Allows deleting fields only while a user constructor is running."""
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      del self.__dict__[name]
    else:
      raise AttributeError('cannot delete field %r' % name)

  def __getattr__(self, name):
    """Fallback lookup: reads fields of packed values from an unpacked copy."""
    if '_tf_extension_type_packed_variant' in self.__dict__:
      # Note: it's *not* ok to cache the results of unpack() here.  In
      # particular, it would be nice if we could do something like
      # `self.__dict__.update(unpack(self).__dict__)`, but that (potentially)
      # violates an invariant required by the `cond` operation.  E.g., if we had
      # `tf.cond(lambda: x.foo, lambda: x.bar)`, then tensor `x.bar` used in the
      # "else" branch would be created by an op in the "then" branch (when
      # looking up `x.foo`); and that's not allowed.
      return getattr(unpack(self), name)

    raise AttributeError(
        f'{type(self).__name__!r} object has no attribute {name!r}')

  def __eq__(self, other):
    """Compares `self` with `other`.

    Returns:
      `False` if `other` has a different Python type, a different TypeSpec, or
      a different number of component tensors; `True` if the TypeSpecs match
      and there are no component tensors; otherwise a scalar boolean `Tensor`
      that is true iff every pair of component tensors has the same shape and
      the same values.
    """
    if type(self) is not type(other):
      return False

    if self._type_spec != other._type_spec:
      return False

    self_tensors = nest.flatten(self, expand_composites=True)
    other_tensors = nest.flatten(other, expand_composites=True)
    if len(self_tensors) != len(other_tensors):
      return False
    if not self_tensors:
      # The TypeSpec contains all static (non-tensor) data, and we already know
      # the TypeSpecs are equal.  Returning here also avoids calling
      # `array_ops.stack` on an empty list of conditions below.
      return True

    conditions = []
    for t1, t2 in zip(self_tensors, other_tensors):
      # Explicitly check shape (values that have different shapes but broadcast
      # to the same value are considered non-equal).
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(
                  array_ops.shape(t1),
                  array_ops.shape(t2),
                  incompatible_shape_error=False)))
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(t1, t2, incompatible_shape_error=False)))
    return math_ops.reduce_all(array_ops.stack(conditions))

  def __ne__(self, other):
    """Negates `__eq__`, using `logical_not` when it returned a Tensor."""
    eq = self.__eq__(other)
    if isinstance(eq, ops.Tensor):
      return math_ops.logical_not(eq)
    else:
      return not eq

  def __validate__(self):
    """Perform post-construction validation."""

  # This instance variable is used to cache the value for the _type_spec
  # property.
  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    if self._tf_extension_type_cached_type_spec is None:
      assert not is_packed(self)  # Packed version always caches TypeSpec.
      self.__dict__[
          '_tf_extension_type_cached_type_spec'] = self.Spec.from_value(self)
    return self._tf_extension_type_cached_type_spec


def pack(value):
  """Returns a copy of `value` with fields packed in a single Variant.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.  If `value` is already packed, it is returned
    unchanged.

  Raises:
    ValueError: If the value's TypeSpec cannot be encoded.
  """
  if is_packed(value):
    return value

  spec = value._type_spec._tf_extension_type_with_packed(True)  # pylint: disable=protected-access
  try:
    variant = composite_tensor_ops.composite_tensor_to_variants(value)
  except nested_structure_coder.NotEncodableError as e:
    # Note: the only time `_TypeSpecCodec.can_encode` returns False is if the
    # named type is not registered.  The default error message would simply
    # tell the user that there is no encoder for the object, so we provide
    # a more useful message letting them know how to register the type.
    raise ValueError('ExtensionTypes must have a __name__ field in order '
                     'to be packed.') from e

  return _create_object_from_type_and_dict(
      type(value), {
          '_tf_extension_type_cached_type_spec': spec,
          '_tf_extension_type_packed_variant': variant,
      })


def unpack(value):
  """Returns a copy of `value` with individual fields stored in __dict__.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.  If `value` is not packed, it is returned
    unchanged.
  """
  if not is_packed(value):
    return value

  # pylint: disable=protected-access
  variant = value._tf_extension_type_packed_variant
  spec = value._tf_extension_type_cached_type_spec
  spec = spec._tf_extension_type_with_packed(False)
  return composite_tensor_ops.composite_tensor_from_variant(variant, spec)


def is_packed(value):
  """Returns true if `value`'s fields are packed in a single Variant."""
  if not isinstance(value, ExtensionType):
    raise ValueError(f'Expected ExtensionType, got {value}')
  return '_tf_extension_type_packed_variant' in value.__dict__


# ==============================================================================
# Base class for the tf.ExtensionType TypeSpecs
# ==============================================================================
# TODO(b/184565242) Support custom TypeSpec constructors.
# TODO(b/184565242) Support custom TypeSpec methods & properties.
# TODO(b/184565242) Support custom TypeSpec validation.
# TODO(b/184565242) Support custom TypeSpec repr.
# TODO(b/184565242) Support customizing type relaxation for tracing.
# TODO(b/184565242) Support conversion to/from FullType


class ExtensionTypeSpec(type_spec.TypeSpec):
  """Base class for tf.ExtensionType TypeSpec."""

  def _serialize(self):  # TypeSpec API.
    # Use a tuple of (name, value) pairs, to ensure we preserve field ordering.
    fields = [f.name for f in self._tf_extension_type_fields()]
    if self._tf_extension_type_is_packed:
      fields.append('_tf_extension_type_is_packed')
    return tuple(
        (f, _change_nested_mappings_to(self.__dict__[f], dict)) for f in fields)

  @classmethod
  def _deserialize(cls, state):  # TypeSpec API.
    state = _change_nested_mappings_to(state, immutable_dict.ImmutableDict)
    return _create_object_from_type_and_dict(cls, state)

  def _to_components(self, value):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return value._tf_extension_type_packed_variant  # pylint: disable=protected-access

    tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor)
    # Retrieve fields by the order of spec dict to preserve field ordering. This
    # is needed as nest.flatten would sort dictionary entries by key.
    value_tuple = tuple(value.__dict__[key] for key in self.__dict__)
    return tuple(
        x for x in nest.flatten(value_tuple)
        if isinstance(x, tensor_or_composite))

  def _from_components(self, components):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return _create_object_from_type_and_dict(
          self.value_type, {
              '_tf_extension_type_cached_type_spec': self,
              '_tf_extension_type_packed_variant': components
          })

    spec_tuple = tuple(self.__dict__.values())
    components_iter = iter(components)
    # Every nested TypeSpec is replaced by the next component; every other
    # (static) value is copied over unchanged.
    flat = [
        next(components_iter) if isinstance(x, type_spec.TypeSpec) else x
        for x in nest.flatten(spec_tuple)
    ]
    if list(components_iter):
      raise ValueError('Components do not match spec.')
    value_tuple = nest.pack_sequence_as(spec_tuple, flat)
    fields = dict(zip(self.__dict__.keys(), value_tuple))

    # Build the new value.  Bypass the constructor (__init__), in case the user
    # who defined the ExtensionType used a custom constructor.
    return _create_object_from_type_and_dict(self.value_type, fields)

  @property
  def _component_specs(self):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return tensor_spec.TensorSpec((), dtypes.variant)

    components = []

    def push_if_type_spec(x):
      if isinstance(x, type_spec.TypeSpec):
        components.append(x)

    nest.map_structure(push_if_type_spec, tuple(self.__dict__.values()))
    return tuple(components)

  @classmethod
  def from_value(cls, value):
    """Returns the TypeSpec for `value`, reusing a cached spec if present."""
    cached_spec = getattr(value, '_tf_extension_type_cached_type_spec', None)
    if cached_spec is not None:
      return cached_spec

    value_fields = value.__dict__
    spec_fields = nest.map_structure(_replace_tensor_with_spec, value_fields)
    spec_fields.pop('_tf_extension_type_cached_fields', None)
    return _create_object_from_type_and_dict(cls, spec_fields)

  def __setattr__(self, name, value):
    """Allows assigning fields only while a constructor is running."""
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      self.__dict__[name] = value
    else:
      raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    """Allows deleting fields only while a constructor is running."""
    if (hasattr(self, _IN_CONSTRUCTOR) and
        self._tf_extension_type_has_field(name)):
      del self.__dict__[name]
    else:
      raise AttributeError('cannot delete field %r' % name)

  def __validate__(self):
    """Perform post-construction validation."""

  @classmethod
  def _tf_extension_type_fields(cls):
    """Returns the fields of the ExtensionType described by this spec."""
    return cls.value_type._tf_extension_type_fields()  # pylint: disable=protected-access

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    """Returns True if `name` is one of the described type's fields."""
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    """Type-checks and converts the spec values stored in `self.__dict__`."""
    extension_type_field.convert_fields_for_spec(
        self._tf_extension_type_fields(), self.__dict__)

  def __repr__(self):
    fields = ', '.join([f'{k}={v!r}' for (k, v) in self._serialize()])
    return f'{type(self).__name__}({fields})'

  _tf_extension_type_is_packed = False

  def _tf_extension_type_with_packed(self, value):
    """Returns a copy of this `TypeSpec` with `packed=value`.

    Args:
      value: A boolean value.

    Returns:
      A copy of `self` with `_tf_extension_type_is_packed=value`.
    """
    copy = _create_object_from_type_and_dict(type(self), self.__dict__)
    if value:
      copy.__dict__['_tf_extension_type_is_packed'] = value
    else:
      # The class-level default is already False, so don't store the flag in
      # `__dict__`.  `_to_components` and `_from_components` treat every
      # `__dict__` entry as a field, so a stray entry would be copied into
      # unpacked values and required of every value passed to
      # `_to_components`.
      copy.__dict__.pop('_tf_extension_type_is_packed', None)
    return copy


def _replace_tensor_with_spec(value):
  """Maps a Tensor to a TensorSpec, and a composite value to its `_type_spec`.

  Any other value is returned unchanged.
  """
  if isinstance(value, ops.Tensor):
    # Note: we intentionally exclude `value.name` from the `TensorSpec`.
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  if hasattr(value, '_type_spec'):
    return value._type_spec  # pylint: disable=protected-access
  return value


def _change_nested_mappings_to(value, new_type):
  """Recursively replace mappings with `new_type`.

  Recurses into `dict`/`ImmutableDict` values and into tuples; all other
  values (including lists) are returned unchanged.
  """
  if isinstance(value, (dict, immutable_dict.ImmutableDict)):
    return new_type([(k, _change_nested_mappings_to(v, new_type))
                     for (k, v) in value.items()])
  elif isinstance(value, tuple):
    return tuple(_change_nested_mappings_to(elt, new_type) for elt in value)
  else:
    return value


# ==============================================================================
# Helper methods for tf.ExtensionTypeMetaclass
# ==============================================================================


def _check_field_annotations(cls):
  """Validates the field annotations for tf.ExtensionType subclass `cls`.

  Raises:
    ValueError: If a name in the class body is reserved, or if a class
      attribute that is not exempt (callables, dunder names, `_abc_*` names,
      properties, classmethods, staticmethods) has no type annotation.
  """
  # Check that no fields use reserved names.
  for name in cls.__dict__:
    if extension_type_field.ExtensionTypeField.is_reserved_name(name):
      raise ValueError(f"The field name '{name}' is reserved.")

  # Check that all fields have type annotations.
  annotations = getattr(cls, '__annotations__', {})
  for (key, value) in cls.__dict__.items():
    if not (key in annotations or callable(value) or key.startswith('_abc_') or
            key == '_tf_extension_type_fields' or
            (key.startswith('__') and key.endswith('__')) or
            isinstance(value, (property, classmethod, staticmethod))):
      raise ValueError('Field %s must have a type annotation' % key)


def _add_extension_type_constructor(cls):
  """Creates a constructor for a ExtensionType or ExtensionTypeSpec subclass."""
  if '__init__' in cls.__dict__:
    _wrap_user_constructor(cls)
  else:
    _build_extension_type_constructor(cls)


def _wrap_user_constructor(cls):
  """Wraps a user-defined constructor for tf.ExtensionType subclass `cls`.

  While the user constructor runs, fields of `self` may be assigned.
  Afterwards the fields are converted and `__validate__` is called.
  """
  user_constructor = cls.__init__

  def wrapped_init(self, *args, **kwargs):
    self.__dict__[_IN_CONSTRUCTOR] = True
    try:
      user_constructor(self, *args, **kwargs)
    finally:
      # Always leave "constructor mode", even if the user constructor raised,
      # so the object can never be left mutable.
      self.__dict__.pop(_IN_CONSTRUCTOR, None)

    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  cls.__init__ = tf_decorator.make_decorator(user_constructor, wrapped_init)


# TODO(b/184565242) Consider using the templating system from autograph here.
def _build_extension_type_constructor(cls):
  """Builds a constructor for tf.ExtensionType subclass `cls`.

  Raises:
    ValueError: If a field without a default follows a field with a default.
  """
  fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

  # Check that no-default fields don't follow default fields.  (Otherwise, we
  # can't build a well-formed constructor.)
  default_fields = []
  for field in fields:
    if field.default is not extension_type_field.ExtensionTypeField.NO_DEFAULT:
      default_fields.append(field.name)
    elif default_fields:
      raise ValueError(
          f'In definition for {cls.__name__}: Field without default '
          f'{field.name!r} follows field with default {default_fields[-1]!r}. '
          f'Either add a default value for {field.name!r}, or move it before '
          f'{default_fields[0]!r} in the field annotations.')

  params = []
  kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  for field in fields:
    if field.default is extension_type_field.ExtensionTypeField.NO_DEFAULT:
      default = tf_inspect.Parameter.empty
    else:
      default = field.default
    params.append(
        tf_inspect.Parameter(
            field.name, kind, default=default, annotation=field.value_type))

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    self.__dict__.update(bound_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools
  # (but note: typing.get_type_hints does not respect __signature__).
  __init__.__signature__ = tf_inspect.Signature(
      [
          tf_inspect.Parameter('self',
                               tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
      ] + params,
      return_annotation=cls)

  cls.__init__ = __init__


def _build_spec_constructor(cls):
  """Builds a constructor for ExtensionTypeSpec subclass `cls`.

  The constructor takes one positional-or-keyword parameter per field (with no
  defaults), converts the values, and calls `__validate__`.
  """
  params = []
  kind = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  for field in cls._tf_extension_type_fields():  # pylint: disable=protected-access
    params.append(tf_inspect.Parameter(field.name, kind))

  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    self.__dict__.update(bound_args.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # __signature__ is supported by some inspection/documentation tools.
  __init__.__signature__ = tf_inspect.Signature(
      [
          tf_inspect.Parameter('self',
                               tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
      ] + params,
      return_annotation=cls)

  cls.__init__ = __init__


def _add_type_spec(cls):
  """Creates a nested TypeSpec class for tf.ExtensionType subclass `cls`."""
  # Build the TypeSpec class for this ExtensionType, and add it as a
  # nested class.
  spec_name = cls.__name__ + '.Spec'
  spec_dict = {'value_type': cls}
  spec = type(spec_name, (ExtensionTypeSpec,), spec_dict)
  setattr(cls, 'Spec', spec)

  # Build a constructor for the TypeSpec class.
  _build_spec_constructor(spec)

  cls.__abstractmethods__ -= {'_type_spec'}

  # If the user included an explicit `__name__` attribute, then use that to
  # register the TypeSpec (so it can be used in SavedModel signatures).
  if '__name__' in cls.__dict__:
    type_spec.register(cls.__dict__['__name__'] + '.Spec')(spec)


# ==============================================================================
# Anonymous ExtensionType
# ==============================================================================
class AnonymousExtensionType(ExtensionType):
  """Fallback used to decode `tf.ExtensionType` when the original type is unavailable.

  When a SavedModel is serialized, the signatures of any functions in the
  SavedModel can include `tf.ExtensionType` subclasses.  These subclasses are
  usually registered, so they can be restored when the SavedModel is loaded.
  However, if a SavedModel is loaded without first registering the
  ExtensionType types in its signature, then the SavedModel will fall back to
  using the `AnonymousExtensionType` type instead.

  If necessary, `AnonymousExtensionType` objects can be converted to a concrete
  `tf.ExtensionType` subclass (and vice versa) using `reinterpret`.
  """

  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, **fields):
    for name in fields:
      if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or
          (name.startswith('__') and name.endswith('__'))):
        raise ValueError(f'The field name {name!r} is reserved.')
    fields = [(k, _convert_anonymous_fields(v)) for (k, v) in fields.items()]
    self.__dict__.update(fields)
    self._tf_extension_type_convert_fields()
    super().__init__()

  @classmethod
  def _tf_extension_type_fields(cls):
    return [
        extension_type_field.ExtensionTypeField(name, None)
        for name in cls.__dict__
        if not extension_type_field.ExtensionTypeField.is_reserved_name(name)
    ]

  def __setattr__(self, name, value):
    raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    raise AttributeError('cannot delete field %r' % name)

  def _tf_extension_type_convert_fields(self):
    """Re-checks and converts every non-reserved entry of `self.__dict__`."""
    fields = [(k, _convert_anonymous_fields(v))
              for (k, v) in self.__dict__.items()
              if not extension_type_field.ExtensionTypeField.is_reserved_name(k)
             ]
    self.__dict__.update(fields)

  def __repr__(self):
    fields = [
        f'{k}={v!r}' for (k, v) in self.__dict__.items()
        if not extension_type_field.ExtensionTypeField.is_reserved_name(k)
    ]
    return f'AnonymousExtensionType({", ".join(fields)})'

  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    if self._tf_extension_type_cached_type_spec is None:
      spec = AnonymousExtensionTypeSpec.from_value(self)
      self.__dict__['_tf_extension_type_cached_type_spec'] = spec
    return self._tf_extension_type_cached_type_spec


@type_spec.register('tf.AnonymousExtensionType.Spec')
class AnonymousExtensionTypeSpec(ExtensionTypeSpec):
  """TypeSpec for AnonymousExtensionType."""

  def __init__(self, **fields):
    for name in fields:
      if (extension_type_field.ExtensionTypeField.is_reserved_name(name) or
          (name.startswith('__') and name.endswith('__'))):
        raise ValueError(f'The field name {name!r} is reserved.')
    fields = [(k, _convert_anonymous_fields(v, for_spec=True))
              for (k, v) in fields.items()]
    self.__dict__.update(fields)
    super().__init__()

  value_type = AnonymousExtensionType  # TypeSpec API.

  def _serialize(self):  # TypeSpec API.
    return tuple(
        (name, _change_nested_mappings_to(value, dict))
        for (name, value) in self.__dict__.items()
        if not extension_type_field.ExtensionTypeField.is_reserved_name(name))

  def __setattr__(self, name, value):
    raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    raise AttributeError('cannot delete field %r' % name)


def _convert_anonymous_fields(value, for_spec=False):
  """Type-checks and converts `value` for inclusion in an AnonymousExtensionType.

  Args:
    value: The field value to check.
    for_spec: If True, `TypeSpec` leaves are accepted (for
      `AnonymousExtensionTypeSpec`); otherwise `Tensor`/`CompositeTensor`
      leaves are accepted.

  Returns:
    `value`, with tuples converted element-wise and mappings converted to
    `ImmutableDict`.

  Raises:
    ValueError: If `value` (or a nested element) has an unsupported type.
  """
  if isinstance(value, (int, float, bool, str, bytes, type(None), dtypes.DType,
                        tensor_shape.TensorShape)):
    return value

  if isinstance(value, tuple):
    return tuple(_convert_anonymous_fields(v, for_spec) for v in value)

  if isinstance(value, typing.Mapping):
    return immutable_dict.ImmutableDict([
        (_convert_anonymous_fields(k, for_spec),
         _convert_anonymous_fields(v, for_spec)) for (k, v) in value.items()
    ])

  if (isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)) and
      not for_spec):
    return value

  if isinstance(value, type_spec.TypeSpec) and for_spec:
    return value

  raise ValueError(f'Unsupported field value: {value!r}')


# ==============================================================================
# reinterpret
# ==============================================================================
def reinterpret(value, new_type):
  """Converts a given `ExtensionType` to a new type with compatible fields.

  In particular, this can be used to convert a concrete subclass of
  `ExtensionType` to an `AnonymousExtensionType`, or vice versa.  When
  converting to a non-anonymous ExtensionType, field values are type-checked to
  ensure they are consistent with `new_type`'s type annotations, and validated
  with `new_type.__validate__`.

  Args:
    value: An instance of a subclass of `tf.ExtensionType`.  Packed values are
      unpacked before their fields are copied.
    new_type: A subclass of `tf.ExtensionType`

  Returns:
    An instance of `new_type`, whose fields are copied from `value`.

  Raises:
    ValueError: If `value` is not an ExtensionType, or `new_type` is not an
      ExtensionType subclass.
  """
  if not isinstance(value, ExtensionType):
    raise ValueError(
        f'Expected `value` to be a tf.ExtensionType; got {value!r}')
  if not (isinstance(new_type, type) and issubclass(new_type, ExtensionType)):
    raise ValueError('Expected `new_type` to be a subclass of tf.ExtensionType;'
                     f' got {new_type!r}')

  if is_packed(value):
    # A packed value keeps its fields inside a single variant tensor rather
    # than in `__dict__`, so unpack it first to expose the individual fields.
    value = unpack(value)

  fields = [
      item for item in value.__dict__.items()
      if not extension_type_field.ExtensionTypeField.is_reserved_name(item[0])
  ]
  new_value = _create_object_from_type_and_dict(new_type, fields)
  new_value._tf_extension_type_convert_fields()  # pylint: disable=protected-access
  new_value.__validate__()
  return new_value