• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""User-defined ExtensionType classes."""
16
17import abc
18import typing
19
20from tensorflow.python.framework import composite_tensor
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import extension_type_field
23from tensorflow.python.framework import immutable_dict
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import tensor_spec
27from tensorflow.python.framework import type_spec
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import composite_tensor_ops
30from tensorflow.python.ops import gen_math_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.saved_model import nested_structure_coder
33from tensorflow.python.util import nest
34from tensorflow.python.util import tf_decorator
35from tensorflow.python.util import tf_inspect
36
37# Attribute used to keep track of when we're inside a user-defined constructor
38# (in which case the fields of `self` may be modified).
39_IN_CONSTRUCTOR = '_tf_extension_type_in_constructor'
40
41
42# ==============================================================================
43# Utility functions
44# ==============================================================================
45def _create_object_from_type_and_dict(cls, obj_dict):
46  """Creates an object, bypassing the constructor.
47
48  Creates an object of type `cls`, whose `__dict__` is updated to contain
49  `obj_dict`.
50
51  Args:
52    cls: The type of the new object.
53    obj_dict: A `Mapping` that should be used to initialize the new object's
54      `__dict__`.
55
56  Returns:
57    An object of type `cls`.
58  """
59  value = object.__new__(cls)
60  value.__dict__.update(obj_dict)
61  return value
62
63
64# ==============================================================================
65# Metaclass for tf.ExtensionType
66# ==============================================================================
class ExtensionTypeMetaclass(abc.ABCMeta):
  """Metaclass for tf.ExtensionType types."""

  def __init__(cls, name, bases, namespace):
    # Only user-defined subclasses are transformed.  Framework classes opt out
    # by defining `_tf_extension_type_do_not_transform_this_class = True` in
    # their own class body.  The flag is read from `namespace` (the class
    # body) rather than via attribute lookup, so it is *not* inherited by
    # subclasses.
    is_framework_class = namespace.get(
        '_tf_extension_type_do_not_transform_this_class', False)
    if not is_framework_class:
      _check_field_annotations(cls)
      _add_extension_type_constructor(cls)
      _add_type_spec(cls)
    super().__init__(name, bases, namespace)
82
83
84# ==============================================================================
85# Base class for user-defined types
86# ==============================================================================
class ExtensionType(
    composite_tensor.CompositeTensor, metaclass=ExtensionTypeMetaclass):
  """Base class for TensorFlow `ExtensionType` classes.

  TensorFlow `ExtensionType` classes are specialized Python classes that can be
  used transparently with TensorFlow -- e.g., they can be used with ops
  such as `tf.cond` or `tf.while_loop` and used as inputs or outputs for
  `tf.function` and Keras layers.

  New `ExtensionType` classes are defined by creating a subclass of
  `tf.ExtensionType` that contains type annotations for all instance
  variables.  The following type annotations are supported:

  Type                 | Example
  -------------------- | --------------------------------------------
  Python integers      | `i: int`
  Python floats        | `f: float`
  Python strings       | `s: str`
  Python booleans      | `b: bool`
  Python None          | `n: None`
  Tensors              | `t: tf.Tensor`
  Composite Tensors    | `rt: tf.RaggedTensor`
  Extension Types      | `m: MyMaskedTensor`
  Tensor shapes        | `shape: tf.TensorShape`
  Tensor dtypes        | `dtype: tf.DType`
  Type unions          | `length: typing.Union[int, float]`
  Tuples               | `params: typing.Tuple[int, float, int, int]`
  Tuples w/ Ellipsis   | `lengths: typing.Tuple[int, ...]`
  Mappings             | `tags: typing.Mapping[str, str]`
  TensorSpec instances | `t2: tf.TensorSpec(shape=[8, None], dtype=tf.int32)`
  TypeSpec instances   | `rt2: tf.RaggedTensorSpec(ragged_rank=2)`

  Fields annotated with `typing.Mapping` will be stored using an immutable
  mapping type.

  Due to technical limitations of Python's `typing` module, `TensorSpec`
  and `TypeSpec` instances may not currently be nested inside generic types
  (such as `typing.Union` or `typing.Tuple`).  TODO(b/184564088) Define
  tf generic types to avoid this limitation.

  ExtensionType values are immutable -- i.e., once constructed, you can not
  modify or delete any of their instance members.

  ### Examples

  >>> class MaskedTensor(ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.TensorSpec(shape=None, dtype=tf.bool)

  >>> class Toy(ExtensionType):
  ...   name: str
  ...   price: ops.Tensor
  ...   features: typing.Mapping[str, ops.Tensor]

  >>> class ToyStore(ExtensionType):
  ...   name: str
  ...   toys: typing.Tuple[Toy, ...]
  """

  # Let the metaclass know that it should *not* transform this class (since
  # this class is part of the ExtensionType framework, and not a user class).
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, *args, **kwargs):
    if type(self) is ExtensionType:  # pylint: disable=unidiomatic-typecheck
      raise AssertionError('ExtensionType is an abstract base class.')

  # This class variable is used to cache the return value for
  # _tf_extension_type_fields.
  _tf_extension_type_cached_fields = None

  @classmethod
  def _tf_extension_type_fields(cls):  # pylint: disable=no-self-argument
    """An ordered list describing the fields of this ExtensionType.

    Returns:
      A list of `ExtensionTypeField` objects.  Forward references are resolved
      if possible, or left unresolved otherwise.
    """
    if cls._tf_extension_type_cached_fields is not None:
      return cls._tf_extension_type_cached_fields

    try:
      type_hints = typing.get_type_hints(cls)
      ok_to_cache = True  # all forward references have been resolved.
    except (NameError, AttributeError):
      # Unresolved forward reference -- gather type hints manually.
      # * NameError comes from an annotation like `Foo` where class
      #   `Foo` hasn't been defined yet.
      # * AttributeError comes from an annotation like `foo.Bar`, where
      #   the module `foo` exists but `Bar` hasn't been defined yet.
      # Note: If a user attempts to instantiate a `ExtensionType` type that
      # still has unresolved forward references (e.g., because of a typo or a
      # missing import), then the constructor will raise an exception.
      type_hints = {}
      for base in reversed(cls.__mro__):
        type_hints.update(base.__dict__.get('__annotations__', {}))
      ok_to_cache = False

    fields = []
    for (name, value_type) in type_hints.items():
      default = getattr(cls, name,
                        extension_type_field.ExtensionTypeField.NO_DEFAULT)
      fields.append(
          extension_type_field.ExtensionTypeField(name, value_type, default))
    fields = tuple(fields)

    if ok_to_cache:
      cls._tf_extension_type_cached_fields = fields

    return fields

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    extension_type_field.convert_fields(self._tf_extension_type_fields(),
                                        self.__dict__)

  def __repr__(self):
    fields = ', '.join([
        f'{field.name}={getattr(self, field.name)!r}'
        for field in self._tf_extension_type_fields()
    ])
    return f'{type(self).__name__}({fields})'

  def __setattr__(self, name, value):
    # Fields are only writable while a user-defined constructor is running.
    # Check `__dict__` directly: `hasattr` would fall through to `__getattr__`,
    # which unpacks packed values (creating ops) just to answer "no".
    if (_IN_CONSTRUCTOR in self.__dict__ and
        self._tf_extension_type_has_field(name)):
      self.__dict__[name] = value
    else:
      raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    # See `__setattr__` for why `__dict__` is checked instead of `hasattr`.
    if (_IN_CONSTRUCTOR in self.__dict__ and
        self._tf_extension_type_has_field(name)):
      del self.__dict__[name]
    else:
      raise AttributeError('cannot delete field %r' % name)

  def __getattr__(self, name):
    if '_tf_extension_type_packed_variant' in self.__dict__:
      # Note: it's *not* ok to cache the results of unpack() here.  In
      # particular, it would be nice if we could do something like
      # `self.__dict__.update(unpack(self).__dict__)`, but that (potentially)
      # violates an invariant required by the `cond` operation.  E.g., if we had
      # `tf.cond(lambda: x.foo, lambda: x.bar)`, then tensor `x.bar` used in the
      # "else" branch would be created by an op in the "then" branch (when
      # looking up `x.foo`); and that's not allowed.
      return getattr(unpack(self), name)

    raise AttributeError(
        f'{type(self).__name__!r} object has no attribute {name!r}')

  def __eq__(self, other):
    """Returns whether `self` and `other` are equal.

    Returns Python `False` if the types or TypeSpecs differ.  Otherwise returns
    a scalar boolean tensor that is true iff every pair of corresponding tensor
    components has the same shape and the same values -- or Python `True` if
    there are no tensor components at all.
    """
    if type(self) is not type(other):
      return False

    if self._type_spec != other._type_spec:
      return False

    self_tensors = nest.flatten(self, expand_composites=True)
    other_tensors = nest.flatten(other, expand_composites=True)
    if len(self_tensors) != len(other_tensors):
      return False
    if not self_tensors:
      # No tensor components: the TypeSpec (compared above) holds all of the
      # static data, so the values are equal.  An empty `conditions` list
      # can't be stacked into the boolean tensor `reduce_all` expects.
      return True

    conditions = []
    for t1, t2 in zip(self_tensors, other_tensors):
      # Explicitly check shape (values that have different shapes but broadcast
      # to the same value are considered non-equal).
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(
                  array_ops.shape(t1),
                  array_ops.shape(t2),
                  incompatible_shape_error=False)))
      conditions.append(
          math_ops.reduce_all(
              gen_math_ops.equal(t1, t2, incompatible_shape_error=False)))
    return math_ops.reduce_all(array_ops.stack(conditions))

  def __ne__(self, other):
    eq = self.__eq__(other)
    if isinstance(eq, ops.Tensor):
      return math_ops.logical_not(eq)
    else:
      return not eq

  def __validate__(self):
    """Perform post-construction validation."""

  # This instance variable is used to cache the value for the _type_spec
  # property.
  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # Note: the TypeSpec contains all static (non-tensor) data from `self`.
    if self._tf_extension_type_cached_type_spec is None:
      assert not is_packed(self)  # Packed version always caches TypeSpec.
      self.__dict__[
          '_tf_extension_type_cached_type_spec'] = self.Spec.from_value(self)
    return self._tf_extension_type_cached_type_spec
291
292
def pack(value):
  """Returns a copy of `value` whose fields are packed into a single Variant.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object of the same type as `value`.  If `value` is
    already packed, it is returned unchanged.

  Raises:
    ValueError: If `value` is not an `ExtensionType`, or if its fields cannot
      be encoded into a variant.
  """
  if is_packed(value):
    return value

  # pylint: disable=protected-access
  packed_spec = value._type_spec._tf_extension_type_with_packed(True)
  # pylint: enable=protected-access
  try:
    packed_variant = composite_tensor_ops.composite_tensor_to_variants(value)
  except nested_structure_coder.NotEncodableError as e:
    # `_TypeSpecCodec.can_encode` only returns False when the named type is
    # not registered, and the default error merely says no encoder exists.
    # Point the user at the actual remedy instead.
    raise ValueError('ExtensionTypes must have a __name__ field in order '
                     'to be packed.') from e

  packed_state = {
      '_tf_extension_type_cached_type_spec': packed_spec,
      '_tf_extension_type_packed_variant': packed_variant,
  }
  return _create_object_from_type_and_dict(type(value), packed_state)
321
322
def unpack(value):
  """Returns a copy of `value` whose fields are stored individually in __dict__.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.  If `value` is not packed, it is returned
    unchanged.

  Raises:
    ValueError: If `value` is not an `ExtensionType`.
  """
  if not is_packed(value):
    return value

  # pylint: disable=protected-access
  packed_variant = value._tf_extension_type_packed_variant
  unpacked_spec = value._tf_extension_type_cached_type_spec._tf_extension_type_with_packed(
      False)
  # pylint: enable=protected-access
  return composite_tensor_ops.composite_tensor_from_variant(
      packed_variant, unpacked_spec)
340
341
def is_packed(value):
  """Returns True if `value`'s fields are packed into a single Variant.

  Raises:
    ValueError: If `value` is not an `ExtensionType`.
  """
  if isinstance(value, ExtensionType):
    return '_tf_extension_type_packed_variant' in value.__dict__
  raise ValueError(f'Expected ExtensionType, got {value}')
347
348
349# ==============================================================================
350# Base class for the tf.ExtensionType TypeSpecs
351# ==============================================================================
352# TODO(b/184565242) Support custom TypeSpec constructors.
353# TODO(b/184565242) Support custom TypeSpec methods & properties.
354# TODO(b/184565242) Support custom TypeSpec validation.
355# TODO(b/184565242) Support custom TypeSpec repr.
356# TODO(b/184565242) Support customizing type relaxation for tracing.
357# TODO(b/184565242) Support conversion to/from FullType
358
359
class ExtensionTypeSpec(type_spec.TypeSpec):
  """Base class for tf.ExtensionType TypeSpec.

  Apart from the optional `_tf_extension_type_is_packed` flag, every entry in
  a spec's instance `__dict__` is treated as a field by `_to_components`,
  `_from_components` and `_component_specs`.
  """

  def _serialize(self):  # TypeSpec API.
    # Use a tuple of (name, value) pairs, to ensure we preserve field ordering.
    fields = [f.name for f in self._tf_extension_type_fields()]
    if self._tf_extension_type_is_packed:
      fields.append('_tf_extension_type_is_packed')
    return tuple(
        (f, _change_nested_mappings_to(self.__dict__[f], dict)) for f in fields)

  @classmethod
  def _deserialize(cls, state):  # TypeSpec API.
    state = _change_nested_mappings_to(state, immutable_dict.ImmutableDict)
    return _create_object_from_type_and_dict(cls, state)

  def _to_components(self, value):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return value._tf_extension_type_packed_variant  # pylint: disable=protected-access

    tensor_or_composite = (ops.Tensor, composite_tensor.CompositeTensor)
    # Retrieve fields in the order of the spec's dict to preserve field
    # ordering.  This is needed because nest.flatten would sort dictionary
    # entries by key.
    value_tuple = tuple(value.__dict__[key] for key in self.__dict__)
    return tuple(
        x for x in nest.flatten(value_tuple)
        if isinstance(x, tensor_or_composite))

  def _from_components(self, components):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return _create_object_from_type_and_dict(
          self.value_type, {
              '_tf_extension_type_cached_type_spec': self,
              '_tf_extension_type_packed_variant': components
          })

    spec_tuple = tuple(self.__dict__.values())
    components_iter = iter(components)
    flat = [
        next(components_iter) if isinstance(x, type_spec.TypeSpec) else x
        for x in nest.flatten(spec_tuple)
    ]
    if list(components_iter):
      raise ValueError('Components do not match spec.')
    value_tuple = nest.pack_sequence_as(spec_tuple, flat)
    fields = dict(zip(self.__dict__.keys(), value_tuple))

    # Build the new value.  Bypass the constructor (__init__), in case the user
    # who defined the ExtensionType used a custom constructor.
    return _create_object_from_type_and_dict(self.value_type, fields)

  @property
  def _component_specs(self):  # TypeSpec API.
    if self._tf_extension_type_is_packed:
      return tensor_spec.TensorSpec((), dtypes.variant)

    components = []

    def push_if_type_spec(x):
      if isinstance(x, type_spec.TypeSpec):
        components.append(x)

    nest.map_structure(push_if_type_spec, tuple(self.__dict__.values()))
    return tuple(components)

  @classmethod
  def from_value(cls, value):
    cached_spec = getattr(value, '_tf_extension_type_cached_type_spec', None)
    if cached_spec is not None:
      return cached_spec

    value_fields = value.__dict__
    spec_fields = nest.map_structure(_replace_tensor_with_spec, value_fields)
    spec_fields.pop('_tf_extension_type_cached_fields', None)
    return _create_object_from_type_and_dict(cls, spec_fields)

  def __setattr__(self, name, value):
    if (_IN_CONSTRUCTOR in self.__dict__ and
        self._tf_extension_type_has_field(name)):
      self.__dict__[name] = value
    else:
      raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    if (_IN_CONSTRUCTOR in self.__dict__ and
        self._tf_extension_type_has_field(name)):
      del self.__dict__[name]
    else:
      raise AttributeError('cannot delete field %r' % name)

  def __validate__(self):
    """Perform post-construction validation."""

  @classmethod
  def _tf_extension_type_fields(cls):
    return cls.value_type._tf_extension_type_fields()  # pylint: disable=protected-access

  @classmethod
  def _tf_extension_type_has_field(cls, name):
    return any(name == field.name for field in cls._tf_extension_type_fields())

  def _tf_extension_type_convert_fields(self):
    extension_type_field.convert_fields_for_spec(
        self._tf_extension_type_fields(), self.__dict__)

  def __repr__(self):
    fields = ', '.join([f'{k}={v!r}' for (k, v) in self._serialize()])
    return f'{type(self).__name__}({fields})'

  # Class-level default: specs are unpacked unless their `__dict__` says
  # otherwise.
  _tf_extension_type_is_packed = False

  def _tf_extension_type_with_packed(self, value):
    """Returns a copy of this `TypeSpec` with `packed=value`.

    Args:
      value: A boolean value.

    Returns:
      A copy of `self` with `_tf_extension_type_is_packed=value`.
    """
    copy = _create_object_from_type_and_dict(type(self), self.__dict__)
    if value:
      copy.__dict__['_tf_extension_type_is_packed'] = value
    else:
      # "Unpacked" is the class-level default, so drop the flag rather than
      # storing False: `_to_components`/`_from_components` treat every
      # `__dict__` key as a field, and a stray flag entry would make them
      # look up a non-existent field on ordinary values.
      copy.__dict__.pop('_tf_extension_type_is_packed', None)
    return copy
483
484
def _replace_tensor_with_spec(value):
  """Maps a single field value to the value stored for it in a TypeSpec.

  * `tf.Tensor` values become a `TensorSpec` with the same shape and dtype.
  * Values that have a `_type_spec` attribute are replaced by that TypeSpec.
  * Anything else is returned unchanged.
  """
  if isinstance(value, ops.Tensor):
    # The tensor's name is deliberately not carried into the spec.
    replacement = tensor_spec.TensorSpec(value.shape, value.dtype)
  elif hasattr(value, '_type_spec'):
    replacement = value._type_spec  # pylint: disable=protected-access
  else:
    replacement = value
  return replacement
492
493
def _change_nested_mappings_to(value, new_type):
  """Returns `value` with every nested mapping rebuilt as `new_type`.

  Recurses through `dict` / `ImmutableDict` values and through tuple elements.
  Tuples (including tuple subclasses) come back as plain tuples; every other
  value is returned unchanged.

  Args:
    value: The possibly-nested value to convert.
    new_type: A callable that builds a mapping from a list of `(key, value)`
      pairs, e.g. `dict` or `immutable_dict.ImmutableDict`.

  Returns:
    The converted value.
  """
  if isinstance(value, (dict, immutable_dict.ImmutableDict)):
    converted_items = [
        (key, _change_nested_mappings_to(item, new_type))
        for key, item in value.items()
    ]
    return new_type(converted_items)
  if not isinstance(value, tuple):
    return value
  return tuple(_change_nested_mappings_to(item, new_type) for item in value)
503
504
505# ==============================================================================
506# Helper methods for tf.ExtensionTypeMetaclass
507# ==============================================================================
508
509
def _check_field_annotations(cls):
  """Validates the field annotations for tf.ExtensionType subclass `cls`.

  Raises:
    ValueError: If any name defined in the class body is reserved, or if a
      class-body attribute that is not exempt (callables, properties,
      class/static methods, dunders, `_abc_*` bookkeeping) lacks a type
      annotation.
  """
  is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name

  # No name in the class body may use a reserved name.
  for name in cls.__dict__:
    if is_reserved(name):
      raise ValueError(f"The field name '{name}' is reserved.")

  # Every remaining plain data attribute must be an annotated field.
  annotations = getattr(cls, '__annotations__', {})
  for key, value in cls.__dict__.items():
    is_dunder = key.startswith('__') and key.endswith('__')
    exempt = (
        key in annotations or callable(value) or key.startswith('_abc_') or
        key == '_tf_extension_type_fields' or is_dunder or
        isinstance(value, (property, classmethod, staticmethod)))
    if not exempt:
      raise ValueError('Field %s must have a type annotation' % key)
525
526
def _add_extension_type_constructor(cls):
  """Installs the constructor for tf.ExtensionType subclass `cls`.

  An `__init__` defined directly in `cls`'s own body is wrapped; otherwise a
  constructor is generated from the class's field annotations.
  """
  defines_own_init = '__init__' in cls.__dict__
  install = (_wrap_user_constructor if defines_own_init
             else _build_extension_type_constructor)
  install(cls)
533
534
def _wrap_user_constructor(cls):
  """Wraps the user-defined constructor of tf.ExtensionType subclass `cls`.

  While the user's `__init__` runs, the instance is flagged as "under
  construction" (via `_IN_CONSTRUCTOR` in its `__dict__`), which is what lets
  it assign its fields.  Afterwards, the field values are converted and
  `__validate__` is called.
  """
  user_constructor = cls.__init__

  def wrapped_init(self, *args, **kwargs):
    self.__dict__[_IN_CONSTRUCTOR] = True
    try:
      user_constructor(self, *args, **kwargs)
    finally:
      # Always clear the flag -- even if the user constructor raises -- so a
      # partially-built instance never remains mutable.
      self.__dict__.pop(_IN_CONSTRUCTOR, None)

    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  cls.__init__ = tf_decorator.make_decorator(user_constructor, wrapped_init)
548
549
550# TODO(b/184565242) Consider using the templating system from autograph here.
def _build_extension_type_constructor(cls):
  """Generates `__init__` for tf.ExtensionType subclass `cls` from its fields.

  The generated constructor takes one positional-or-keyword parameter per
  field (in field order, using the field's default when it has one). It
  stores the bound arguments in the instance `__dict__`, converts them, and
  then calls `__validate__`.

  Raises:
    ValueError: If a field without a default follows a field with a default.
  """
  no_default = extension_type_field.ExtensionTypeField.NO_DEFAULT
  fields = cls._tf_extension_type_fields()  # pylint: disable=protected-access

  # A Python signature can't place a required parameter after an optional
  # one, so reject that ordering up front with a descriptive message.
  defaulted_names = []
  for field in fields:
    if field.default is not no_default:
      defaulted_names.append(field.name)
      continue
    if defaulted_names:
      raise ValueError(
          f'In definition for {cls.__name__}: Field without default '
          f'{field.name!r} follows field with default {defaulted_names[-1]!r}.  '
          f'Either add a default value for {field.name!r}, or move it before '
          f'{defaulted_names[0]!r} in the field annotations.')

  positional_or_keyword = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  params = [
      tf_inspect.Parameter(
          field.name,
          positional_or_keyword,
          default=(tf_inspect.Parameter.empty
                   if field.default is no_default else field.default),
          annotation=field.value_type) for field in fields
  ]
  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    self.__dict__.update(bound.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # Some inspection/documentation tools honor __signature__ (note that
  # typing.get_type_hints does not).
  self_param = tf_inspect.Parameter('self', positional_or_keyword)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param] + params, return_annotation=cls)

  cls.__init__ = __init__
598
599
def _build_spec_constructor(cls):
  """Generates `__init__` for ExtensionTypeSpec subclass `cls`.

  The generated constructor takes one required positional-or-keyword
  parameter per field of the associated ExtensionType. It stores the bound
  arguments in the instance `__dict__`, converts them, and then calls
  `__validate__`.
  """
  positional_or_keyword = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
  field_list = cls._tf_extension_type_fields()  # pylint: disable=protected-access
  params = [
      tf_inspect.Parameter(field.name, positional_or_keyword)
      for field in field_list
  ]
  signature = tf_inspect.Signature(params, return_annotation=cls.__name__)

  def __init__(self, *args, **kwargs):  # pylint: disable=invalid-name
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    self.__dict__.update(bound.arguments)
    self._tf_extension_type_convert_fields()  # pylint: disable=protected-access
    self.__validate__()

  # Some inspection/documentation tools honor __signature__.
  self_param = tf_inspect.Parameter('self', positional_or_keyword)
  __init__.__signature__ = tf_inspect.Signature(
      [self_param] + params, return_annotation=cls)

  cls.__init__ = __init__
625
626
def _add_type_spec(cls):
  """Attaches a nested `Spec` TypeSpec class to tf.ExtensionType subclass `cls`.

  The new class is named `<cls.__name__>.Spec`, subclasses
  `ExtensionTypeSpec`, and has `value_type` set to `cls`.  If `cls` defines
  `__name__` in its own body, the spec is also registered under
  `<__name__>.Spec` so it can be used in SavedModel signatures.
  """
  spec_cls = type(cls.__name__ + '.Spec', (ExtensionTypeSpec,),
                  {'value_type': cls})
  cls.Spec = spec_cls
  _build_spec_constructor(spec_cls)

  # ExtensionType supplies `_type_spec` as a property, so remove it from the
  # abstract-method set to make `cls` instantiable.
  cls.__abstractmethods__ -= {'_type_spec'}

  if '__name__' in cls.__dict__:
    registered_name = cls.__dict__['__name__'] + '.Spec'
    type_spec.register(registered_name)(spec_cls)
645
646
647# ==============================================================================
648# Anonymous ExtensionType
649# ==============================================================================
class AnonymousExtensionType(ExtensionType):
  """Fallback used to decode `tf.ExtensionType` when the original type is unavailable.

  When a SavedModel is serialized, the signatures of any functions in the
  SavedModel can include `tf.ExtensionType` subclasses.  These subclasses are
  usually registered, so they can be restored when the SavedModel is loaded.
  However, if a SavedModel is loaded without first registering the
  ExtensionType types in its signature, then the SavedModel will fall back to
  using the `AnonymousExtensionType` type instead.

  If necessary, `AnonymousExtensionType` objects can be converted to a concrete
  `tf.ExtensionType` subclass (and vice versa) using `reinterpret`.
  """

  # Framework class: tell the metaclass *not* to transform it.
  _tf_extension_type_do_not_transform_this_class = True

  def __init__(self, **fields):
    is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    for field_name in fields:
      is_dunder = field_name.startswith('__') and field_name.endswith('__')
      if is_reserved(field_name) or is_dunder:
        raise ValueError(f'The field name {field_name!r} is reserved.')
    converted = [
        (key, _convert_anonymous_fields(val)) for key, val in fields.items()
    ]
    self.__dict__.update(converted)
    self._tf_extension_type_convert_fields()
    super().__init__()

  @classmethod
  def _tf_extension_type_fields(cls):
    # NOTE(review): this enumerates the *class* `__dict__`, not an instance's,
    # so it does not list the per-instance fields set in `__init__` -- confirm
    # this is intended.
    is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    return [
        extension_type_field.ExtensionTypeField(attr_name, None)
        for attr_name in cls.__dict__
        if not is_reserved(attr_name)
    ]

  def __setattr__(self, name, value):
    raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    raise AttributeError('cannot delete field %r' % name)

  def _tf_extension_type_convert_fields(self):
    is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    converted = [
        (key, _convert_anonymous_fields(val))
        for key, val in self.__dict__.items()
        if not is_reserved(key)
    ]
    self.__dict__.update(converted)

  def __repr__(self):
    is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    parts = [
        f'{key}={val!r}' for key, val in self.__dict__.items()
        if not is_reserved(key)
    ]
    return f'AnonymousExtensionType({", ".join(parts)})'

  _tf_extension_type_cached_type_spec = None

  @property
  def _type_spec(self):  # CompositeTensor API.
    # The TypeSpec captures all static (non-tensor) data from `self`; it is
    # computed once and then cached in the instance `__dict__`.
    cached = self._tf_extension_type_cached_type_spec
    if cached is None:
      cached = AnonymousExtensionTypeSpec.from_value(self)
      self.__dict__['_tf_extension_type_cached_type_spec'] = cached
    return cached
718
719
@type_spec.register('tf.AnonymousExtensionType.Spec')
class AnonymousExtensionTypeSpec(ExtensionTypeSpec):
  """TypeSpec for AnonymousExtensionType.

  Fields are passed as keyword arguments and stored directly in the instance
  `__dict__`; they cannot be modified or deleted afterwards.
  """

  def __init__(self, **fields):
    is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    for field_name in fields:
      is_dunder = field_name.startswith('__') and field_name.endswith('__')
      if is_reserved(field_name) or is_dunder:
        raise ValueError(f'The field name {field_name!r} is reserved.')
    converted = [
        (key, _convert_anonymous_fields(val, for_spec=True))
        for key, val in fields.items()
    ]
    self.__dict__.update(converted)
    super().__init__()

  value_type = AnonymousExtensionType  # TypeSpec API.

  def _serialize(self):  # TypeSpec API.
    # (name, value) pairs in `__dict__` order, with nested mappings as dicts.
    is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name
    return tuple(
        (field_name, _change_nested_mappings_to(field_value, dict))
        for field_name, field_value in self.__dict__.items()
        if not is_reserved(field_name))

  def __setattr__(self, name, value):
    raise AttributeError('cannot assign to field %r' % name)

  def __delattr__(self, name):
    raise AttributeError('cannot delete field %r' % name)
747
748
def _convert_anonymous_fields(value, for_spec=False):
  """Type-checks and converts `value` for inclusion in an AnonymousExtensionType.

  Args:
    value: The field value to check and convert.
    for_spec: If True, `value` is destined for an `AnonymousExtensionTypeSpec`,
      so `TypeSpec`s are accepted and Tensors/CompositeTensors are rejected.
      If False, the reverse holds.

  Returns:
    `value` itself for scalar-like and tensor/spec values.  Tuples are
    converted elementwise, and mappings are rebuilt as `ImmutableDict`s with
    converted keys and values.

  Raises:
    ValueError: If `value` (or any nested element) has an unsupported type.
  """
  passthrough_types = (int, float, bool, str, bytes, type(None), dtypes.DType,
                       tensor_shape.TensorShape)
  if isinstance(value, passthrough_types):
    return value

  if isinstance(value, tuple):
    return tuple(_convert_anonymous_fields(item, for_spec) for item in value)

  if isinstance(value, typing.Mapping):
    converted_items = [(_convert_anonymous_fields(key, for_spec),
                        _convert_anonymous_fields(item, for_spec))
                       for key, item in value.items()]
    return immutable_dict.ImmutableDict(converted_items)

  if for_spec:
    if isinstance(value, type_spec.TypeSpec):
      return value
  elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
    return value

  raise ValueError(f'Unsupported field value: {value!r}')
772
773
774# ==============================================================================
775# reinterpret
776# ==============================================================================
def reinterpret(value, new_type):
  """Converts a given `ExtensionType` to a new type with compatible fields.

  In particular, this can be used to convert a concrete subclass of
  `ExtensionType` to an `AnonymousExtensionType`, or vice versa.  When
  converting to a non-anonymous ExtensionType, field values are type-checked to
  ensure they are consistent with `new_type`'s type annotations, and validated
  with `new_type.__validate__`.  Packed values are unpacked first, so their
  fields are carried over as well.

  Args:
    value: An instance of a subclass of `tf.ExtensionType`
    new_type: A subclass of `tf.ExtensionType`

  Returns:
    An instance of `new_type`, whose fields are copied from `value`.

  Raises:
    ValueError: If `value` is not a `tf.ExtensionType` instance, or if
      `new_type` is not a subclass of `tf.ExtensionType`.
  """
  if not isinstance(value, ExtensionType):
    raise ValueError(
        f'Expected `value` to be a tf.ExtensionType; got {value!r}')
  if not (isinstance(new_type, type) and issubclass(new_type, ExtensionType)):
    raise ValueError('Expected `new_type` to be a subclass of tf.ExtensionType;'
                     f' got {new_type!r}')

  # A packed value's `__dict__` holds only the packed variant and its cached
  # TypeSpec, not the individual fields.  Copying it directly would produce a
  # new value with no fields, so unpack first (a no-op for unpacked values).
  value = unpack(value)

  is_reserved = extension_type_field.ExtensionTypeField.is_reserved_name
  fields = [
      (name, field_value) for (name, field_value) in value.__dict__.items()
      if not is_reserved(name)
  ]
  new_value = _create_object_from_type_and_dict(new_type, fields)
  new_value._tf_extension_type_convert_fields()  # pylint: disable=protected-access
  new_value.__validate__()
  return new_value
808