• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type specifications for TensorFlow APIs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23import re
24
25import numpy as np
26import six
27
28from tensorflow.python.framework import composite_tensor
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.util import _pywrap_utils
33from tensorflow.python.util import compat
34from tensorflow.python.util import nest
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util.compat import collections_abc
37from tensorflow.python.util.lazy_loader import LazyLoader
38from tensorflow.python.util.tf_export import tf_export
39
# Use LazyLoader to avoid circular dependencies: `tensor_spec` and `ops` both
# depend on this module, so they are bound lazily here instead of imported.
# (Presumably the real import is deferred until first attribute access.)
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ops = LazyLoader(
    "ops", globals(),
    "tensorflow.python.framework.ops")
47
48
@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
@six.add_metaclass(abc.ABCMeta)
class TypeSpec(object):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs.  Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

  Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not
  currently supported.  In particular, we may make breaking changes to the
  private methods and properties defined by this base class.

  Example:

  >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
  >>> @tf.function(input_signature=[spec])
  ... def double(x):
  ...   return x * 2
  >>> print(double(tf.ragged.constant([[1, 2], [3]])))
  <tf.RaggedTensor [[2, 4], [6]]>
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  #   * A "component encoding" for values.
  #   * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value.  Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute.  Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive.  In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`.  See the documentation for
  # `_serialize()` for more information.

  __slots__ = []

  @abc.abstractproperty
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec.

    In particular, all values that are compatible with this TypeSpec must be an
    instance of this type.
    """
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec.

    Args:
      spec_or_value: A `TypeSpec`, or a value that `type_spec_from_value` can
        convert into one.

    Returns:
      True if `spec_or_value` has exactly the same `TypeSpec` type as `self`
      and their serializations are compatible.
    """
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; `np.ndarray`s are
    # compared elementwise; and all other types are considered compatible if
    # they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(),
                                spec_or_value._serialize())  # pylint: disable=protected-access

  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to raise a
    # `ValueError` if `self` and `other` have different types, or if their type
    # serializations differ by anything other than `TensorShape`s.  Otherwise,
    # the two type serializations are combined (using
    # `most_specific_compatible_shape` to combine `TensorShape`s), and the
    # result is used to construct and return a new `TypeSpec`.
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    merged = self.__most_specific_compatible_type_serialization(
        self._serialize(), other._serialize())  # pylint: disable=protected-access
    return self._deserialize(merged)

  def _with_tensor_ranks_only(self):
    """Returns a TypeSpec compatible with `self`, with tensor shapes relaxed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where any `TensorShape`
      information has been relaxed to include only tensor rank (and not
      the dimension sizes for individual axes).
    """

    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, relax any TensorSpec or TensorShape values, and
    # deserialize the result.

    def relax(value):
      if isinstance(value, TypeSpec):
        return value._with_tensor_ranks_only()  # pylint: disable=protected-access
      elif (isinstance(value, tensor_shape.TensorShape) and
            value.rank is not None):
        # Keep the rank but forget every dimension size.
        return tensor_shape.TensorShape([None] * value.rank)
      else:
        return value

    return self._deserialize(nest.map_structure(relax, self._serialize()))

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`.  (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @abc.abstractproperty
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method.  In particular, for a
      TypeSpec `spec` and a compatible value `value`, every component spec is
      compatible with the corresponding component:

      ```
      nest.map_structure(lambda t, c: t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```

      evaluates to a structure containing only `True`.
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value):
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`.  However, subclasses may override this to return a
    different tensor encoding for values.  In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched.  See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list):
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
      `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self, tensor_list):
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.  (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(nest.pack_sequence_as(
        self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `_serialize`.

    The default implementation passes the serialization's elements as
    positional arguments to the constructor: `cls(*serialization)`.
    """
    return cls(*serialization)

  # === Operators ===

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    return type(self), self._serialize()

  def __repr__(self):
    return "%s%r" % (type(self).__name__, self._serialize())

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  def __check_tensor_list(self, tensor_list):
    """Raises ValueError unless `tensor_list` matches `_flat_tensor_specs`."""
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError("Incompatible input: wrong number of tensors")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError("Incompatible input: tensor %d (%s) is incompatible "
                         "with %s" % (i, tensor_list[i], s2))

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` (part of a serialization) to a hashable key."""
    if isinstance(value,
                  (int, float, bool, np.generic, dtypes.DType, TypeSpec)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      # Sorted keys make the key independent of dict insertion order.
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      # Tag lists so they never compare equal to a tuple with the same items.
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, tensor_shape.TensorShape):
      if value.ndims is None:
        # Note: we include a type object in the tuple, to ensure we can't get
        # false-positive matches (since users can't include type objects).
        return (tensor_shape.TensorShape, None)
      return (tensor_shape.TensorShape, tuple(value.as_list()))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError("Unsupported value type %s returned by "
                     "%s._serialize" %
                     (type(value).__name__, type(self).__name__))

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if isinstance(a, TypeSpec):
      return a.is_compatible_with(b)
    if type(a) is not type(b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
          TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    if isinstance(a, np.ndarray):
      # `a == b` is elementwise for arrays; its result has no single truth
      # value, so callers such as `all(...)` would raise.  Reduce it to one
      # bool (same shape and same elements, matching `__make_cmp_key`).
      return np.array_equal(a, b)
    return a == b

  @staticmethod
  def __is_named_tuple(t):
    """Returns true if the given tuple t is a namedtuple."""
    return (hasattr(t, "_fields") and
            isinstance(t._fields, collections_abc.Sequence) and
            all(isinstance(f, six.string_types) for f in t._fields))

  @staticmethod
  def __most_specific_compatible_type_serialization(a, b):
    """Helper for most_specific_compatible_type.

    Combines two type serializations as follows:

    * If they are both tuples (or both lists) of the same length, then
      recursively combine the respective elements.  The result is a
      namedtuple of the same type for namedtuples, and a plain tuple
      otherwise (including when the inputs are lists).
    * If they are both dicts with the same keys, then recursively combine
      the respective dict elements.
    * If they are both TypeSpecs, then combine using
      TypeSpec.most_specific_compatible_type.
    * If they are both TensorShapes, then combine using
      TensorShape.most_specific_compatible_shape.
    * If they are both TensorSpecs with the same dtype, then combine using
      TensorShape.most_specific_compatible_shape to combine shapes.
    * If they are both `np.ndarray`s with the same shape and elements, then
      return a.
    * If they are equal, then return a.
    * If none of the above, then raise a ValueError.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A value with the same structure as `a` and `b`.

    Raises:
      ValueError: If `a` and `b` are incompatible.
    """
    if type(a) is not type(b):
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    if isinstance(a, (list, tuple)):
      if len(a) != len(b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      if TypeSpec.__is_named_tuple(a):
        if not hasattr(b, "_fields") or not isinstance(
            b._fields, collections_abc.Sequence) or a._fields != b._fields:
          raise ValueError("Types are not compatible: %r vs %r" % (a, b))
        return type(a)(*[
            TypeSpec.__most_specific_compatible_type_serialization(x, y)
            for (x, y) in zip(a, b)])
      # Note: `_serialize()` is not documented to return lists; if it does,
      # they are merged elementwise and returned as a tuple.
      return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
                   for (x, y) in zip(a, b))
    if isinstance(a, collections.OrderedDict):
      a_keys, b_keys = a.keys(), b.keys()
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return collections.OrderedDict([
          (k,
           TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k]))
          for k in a_keys
      ])
    if isinstance(a, dict):
      a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return {
          k: TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])
          for k in a_keys
      }
    if isinstance(a, tensor_shape.TensorShape):
      return a.most_specific_compatible_shape(b)
    if isinstance(a, TypeSpec):
      return a.most_specific_compatible_type(b)
    if isinstance(a, np.ndarray):
      # `a != b` is elementwise for arrays and has no single truth value, so
      # compare shapes and elements explicitly.
      if not np.array_equal(a, b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return a
    if a != b:
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    return a
485
486
class BatchableTypeSpec(TypeSpec):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching.  In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists.  Unlike the component encoding
  (returned by `self._to_components`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  @abc.abstractmethod
  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._batch" % type(self).__name__)

  @abc.abstractmethod
  def _unbatch(self):
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._unbatch" % type(self).__name__)

  def _to_batched_tensor_list(self, value):
    """Returns a tensor list encoding for `value` with rank > 0.

    Args:
      value: A value compatible with this `TypeSpec`.

    Returns:
      The list returned by `self._to_tensor_list(value)`.

    Raises:
      ValueError: If any tensor in that list has rank 0.
    """
    tensor_list = self._to_tensor_list(value)
    if any(t.shape.ndims == 0 for t in tensor_list):
      # Wrap `value` in a 1-tuple so that `%` never tries to unpack it (a
      # tuple-valued `value` would otherwise raise TypeError here).
      raise ValueError("Value %s has insufficient rank for batching." %
                       (value,))
    return tensor_list
532
533
@tf_export("type_spec_from_value")
def type_spec_from_value(value):
  """Returns a `tf.TypeSpec` that represents the given `value`.

  Examples:

    >>> tf.type_spec_from_value(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)
    >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64))
    TensorSpec(shape=(2,), dtype=tf.float64, name=None)
    >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]]))
    RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64)

    >>> example_input = tf.ragged.constant([[1, 2], [3]])
    >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)])
    ... def f(x):
    ...   return tf.reduce_sum(x, axis=1)

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs.
      Accepted types for `value` include `tf.Tensor`, any value that can be
      converted to `tf.Tensor` using `tf.convert_to_tensor`, and any subclass
      of `CompositeTensor` (such as `tf.RaggedTensor`).

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If a TypeSpec cannot be built for `value`, because its type
      is not supported.
  """
  # First ask the direct (no-conversion) resolver.
  direct_spec = _type_spec_from_value(value)
  if direct_spec is not None:
    return direct_spec

  # Otherwise, try describing the tensor that `value` converts to.  Conversion
  # failures are only logged; the TypeError below reports the final outcome.
  try:
    converted = ops.convert_to_tensor(value)
    converted_spec = _type_spec_from_value(converted)
  except (ValueError, TypeError) as err:
    failure_message = "Failed to convert %r to tensor: %s" % (
        type(value).__name__, err)
    logging.vlog(3, failure_message)
  else:
    if converted_spec is not None:
      return converted_spec

  raise TypeError("Could not build a TypeSpec for %r with type %s" %
                  (value, type(value).__name__))
581
582
def _type_spec_from_value(value):
  """Returns a `TypeSpec` that represents `value`, or None if none is known.

  Resolution order: `tf.Tensor`, then `CompositeTensor`, then a non-empty list
  whose element specs merge into one batchable spec, then the registered
  converters (most recently registered first).  No tensor conversion is
  attempted here.
  """
  if isinstance(value, ops.Tensor):
    # Tensor names are intentionally left out of the resulting TypeSpec.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access

  # A non-empty list whose elements all share a (mergeable) batchable spec is
  # described as a batch of that spec.  This is more precise than the
  # `convert_to_tensor` fallback used by `type_spec_from_value`.
  if isinstance(value, list) and value:
    element_specs = [_type_spec_from_value(element) for element in value]
    first_spec = element_specs[0]
    if isinstance(first_spec, BatchableTypeSpec):
      try:
        merged_spec = first_spec
        for other_spec in element_specs[1:]:
          merged_spec = merged_spec.most_specific_compatible_type(other_spec)
        return merged_spec._batch(len(element_specs))  # pylint: disable=protected-access
      except (ValueError, TypeError):
        pass  # Element specs are incompatible; try the registered converters.

  for registered_type, converter_fn, allow_subclass in reversed(
      _TYPE_CONVERSION_FUNCTION_REGISTRY):
    exact_match = type(value) is registered_type  # pylint: disable=unidiomatic-typecheck
    if exact_match or (allow_subclass and isinstance(value, registered_type)):
      return converter_fn(value)

  return None
614
# List of `(type_object, converter_fn, allow_subclass)` tuples.  Entries are
# appended by `register_type_spec_from_value_converter` and scanned in reverse
# (newest first) by `_type_spec_from_value`.
_TYPE_CONVERSION_FUNCTION_REGISTRY = []
616
617
def register_type_spec_from_value_converter(type_object, converter_fn,
                                            allow_subclass=False):
  """Registers a function that builds TypeSpecs for values of a given type.

  When several registered `type_object`s match a value, the registration made
  most recently wins.  Do not register converters for `CompositeTensor`s;
  implement `CompositeTensor._type_spec` instead.

  Args:
    type_object: A Python `type` object representing the type of values
      accepted by `converter_fn`.
    converter_fn: A function that takes one argument (an instance of the
      type represented by `type_object`) and returns a `TypeSpec`.
    allow_subclass: If true, then use `isinstance(value, type_object)` to
      check for matches.  If false, then use `type(value) is type_object`.
  """
  # Strip any `tf_decorator` wrapping and register the underlying target.
  _, unwrapped_type = tf_decorator.unwrap(type_object)
  registry_entry = (unwrapped_type, converter_fn, allow_subclass)
  _TYPE_CONVERSION_FUNCTION_REGISTRY.append(registry_entry)
637
638
# Register `TypeSpec` with `_pywrap_utils` under the name "TypeSpec"
# (presumably so native utilities can recognize TypeSpec instances).
_pywrap_utils.RegisterType("TypeSpec", TypeSpec)


# Two-way registry maintained by `register`: TypeSpec subclass -> registered
# name, and registered name -> TypeSpec subclass.
_TYPE_SPEC_TO_NAME = {}
_NAME_TO_TYPE_SPEC = {}


# Regular expression for valid TypeSpec names: two or more dot-separated runs
# of word characters, e.g. "my_project.MyTypeSpec".
# NOTE(review): `$` also matches just before a trailing "\n", so use
# `fullmatch` (not `match`) when validating names with this pattern.
_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$")
648
649
650# TODO(b/173744905) tf_export this as "tf.register_type_spec".  (And add a
651# usage example to the docstring, once the API is public.)
652#
653# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than
654# TypeSpec (once we do refactoring to move to_components/from_components from
655# TypeSpec to ExtensionType).
def register(name):
  """Decorator used to register a globally unique name for a TypeSpec subclass.

  Args:
    name: The name of the type spec.  Must be globally unique.  Must have
      the form `"{project_name}.{type_name}"`.  E.g. `"my_project.MyTypeSpec"`.

  Returns:
    A class decorator that registers the decorated class with the given name.
    The decorator raises `TypeError` if applied to anything other than a
    `TypeSpec` subclass, and `ValueError` if the class or the name has
    already been registered.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: If `name` does not have the required form.
  """
  if not isinstance(name, str):
    raise TypeError("Expected `name` to be a string; got %r" % (name,))
  # `fullmatch` rather than `match`: the pattern ends in `$`, which also
  # matches just before a trailing newline, so `match` would accept names
  # such as "my_project.MyTypeSpec\n".
  if not _REGISTERED_NAME_RE.fullmatch(name):
    raise ValueError(
        "Registered name must have the form '{project_name}.{type_name}' "
        "(e.g. 'my_project.MyTypeSpec'); got %r." % name)

  def decorator_fn(cls):
    if not (isinstance(cls, type) and issubclass(cls, TypeSpec)):
      raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
    if cls in _TYPE_SPEC_TO_NAME:
      raise ValueError("Class %s.%s has already been registered with name %s."
                       % (cls.__module__, cls.__name__,
                          _TYPE_SPEC_TO_NAME[cls]))
    if name in _NAME_TO_TYPE_SPEC:
      raise ValueError("Name %s has already been registered for class %s.%s."
                       % (name, _NAME_TO_TYPE_SPEC[name].__module__,
                          _NAME_TO_TYPE_SPEC[name].__name__))
    # Both directions are updated together so that `get_name` and `lookup`
    # stay consistent.
    _TYPE_SPEC_TO_NAME[cls] = name
    _NAME_TO_TYPE_SPEC[name] = cls
    return cls

  return decorator_fn
689
690
691# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name)
def get_name(cls):
  """Returns the name under which TypeSpec subclass `cls` was registered.

  Args:
    cls: A `TypeSpec` subclass.

  Returns:
    The name that was passed to `register` for `cls`.

  Raises:
    TypeError: If `cls` is not a `TypeSpec` subclass.
    ValueError: If `cls` has not been registered.
  """
  is_type_spec_class = isinstance(cls, type) and issubclass(cls, TypeSpec)
  if not is_type_spec_class:
    raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
  if cls in _TYPE_SPEC_TO_NAME:
    return _TYPE_SPEC_TO_NAME[cls]
  raise ValueError("TypeSpec %s.%s has not been registered." %
                   (cls.__module__, cls.__name__))
700
701
702# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name)
def lookup(name):
  """Returns the TypeSpec subclass registered under `name`.

  Args:
    name: A name previously passed to `register`.

  Returns:
    The registered `TypeSpec` subclass.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: If no TypeSpec has been registered under `name`.
  """
  if isinstance(name, str):
    if name in _NAME_TO_TYPE_SPEC:
      return _NAME_TO_TYPE_SPEC[name]
    raise ValueError("No TypeSpec has been registered with name %r" % (name,))
  raise TypeError("Expected `name` to be a string; got %r" % (name,))
710