• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type specifications for TensorFlow APIs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23import re
24
25import numpy as np
26import six
27
28from tensorflow.python.framework import composite_tensor
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.util import _pywrap_utils
33from tensorflow.python.util import compat
34from tensorflow.python.util import nest
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util.lazy_loader import LazyLoader
37from tensorflow.python.util.tf_export import tf_export
38
# `tensor_spec` and `ops` are bound through LazyLoader instead of regular
# imports in order to avoid circular dependencies.
tensor_spec = LazyLoader("tensor_spec", globals(),
                         "tensorflow.python.framework.tensor_spec")
ops = LazyLoader("ops", globals(), "tensorflow.python.framework.ops")
46
47
@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
@six.add_metaclass(abc.ABCMeta)
class TypeSpec(object):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs.  Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

  Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not
  currently supported.  In particular, we may make breaking changes to the
  private methods and properties defined by this base class.

  Example:

  >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
  >>> @tf.function(input_signature=[spec])
  ... def double(x):
  ...   return x * 2
  >>> print(double(tf.ragged.constant([[1, 2], [3]])))
  <tf.RaggedTensor [[2, 4], [6]]>
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  #   * A "component encoding" for values.
  #   * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value.  Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute.  Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive.  In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`.  See the documentation for
  # `_serialize()` for more information.

  __slots__ = []

  @abc.abstractproperty
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec.

    In particular, all values that are compatible with this TypeSpec must be an
    instance of this type.
    """
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec."""
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(),
                                spec_or_value._serialize())  # pylint: disable=protected-access

  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to raise a
    # `ValueError` if `self` and `other` have different types, or if their type
    # serializations differ by anything other than `TensorShape`s.  Otherwise,
    # the two type serializations are combined (using
    # `most_specific_compatible_shape` to combine `TensorShape`s), and the
    # result is used to construct and return a new `TypeSpec`.
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    merged = self.__most_specific_compatible_type_serialization(
        self._serialize(), other._serialize())  # pylint: disable=protected-access
    return self._deserialize(merged)

  def _with_tensor_ranks_only(self):
    """Returns a TypeSpec compatible with `self`, with tensor shapes relaxed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where any `TensorShape`
      information has been relaxed to include only tensor rank (and not
      the dimension sizes for individual axes).
    """

    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, relax any TensorSpec or TensorShape values, and
    # deserialize the result.

    def relax(value):
      if isinstance(value, TypeSpec):
        return value._with_tensor_ranks_only()  # pylint: disable=protected-access
      elif (isinstance(value, tensor_shape.TensorShape) and
            value.rank is not None):
        return tensor_shape.TensorShape([None] * value.rank)
      else:
        return value

    return self._deserialize(nest.map_structure(relax, self._serialize()))

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`.  (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @abc.abstractproperty
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method.  In particular, for a
      TypeSpec `spec` and a compatible value `value`:

      ```
      nest.map_structure(lambda t, c: assert t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value):
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`.  However, subclasses may override this to return a
    different tensor encoding for values.  In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched.  See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list):
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
      `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self, tensor_list):
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.  (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(nest.pack_sequence_as(
        self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `_serialize`.

    Args:
      serialization: A value returned by _serialize.  In some contexts,
        `namedtuple`s in `serialization` may not have the identical type
        that was returned by `_serialize` (but its type will still be a
        `namedtuple` type with the same type name and field names).  For
        example, the code that loads a SavedModel does not have access to
        the original `namedtuple` type, so it dynamically creates a new
        `namedtuple` type with the same type name and field names as the
        original one.  If necessary, you can check `serialization` for these
        duck-typed `namedtuple` types, and restore them to the original type.
        (E.g., this would be necessary if you rely on type checks such as
        `isinstance` for this `TypeSpec`'s member variables).

    Returns:
      A `TypeSpec` of type `cls`.
    """
    return cls(*serialization)

  # === Operators ===

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    return type(self), self._serialize()

  def __repr__(self):
    return "%s%r" % (type(self).__name__, self._serialize())

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  def __check_tensor_list(self, tensor_list):
    """Raises ValueError unless `tensor_list` matches `_flat_tensor_specs`."""
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError("Incompatible input: wrong number of tensors")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError("Incompatible input: tensor %d (%s) is incompatible "
                         "with %s" % (i, tensor_list[i], s2))

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key."""
    if isinstance(value,
                  (int, float, bool, np.generic, dtypes.DType, TypeSpec)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      # Lists are tagged so that they never compare equal to tuples.
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, tensor_shape.TensorShape):
      if value.ndims is None:
        # Note: we include a type object in the tuple, to ensure we can't get
        # false-positive matches (since users can't include type objects).
        return (tensor_shape.TensorShape, None)
      return (tensor_shape.TensorShape, tuple(value.as_list()))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError("Unsupported value type %s returned by "
                     "%s._serialize" %
                     (type(value).__name__, type(self).__name__))

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __same_types(a, b):
    """Returns whether a and b have the same type, up to namedtuple equivalence.

    Consistent with tf.nest.assert_same_structure(), two namedtuple types
    are considered the same iff they agree in their class name (without
    qualification by module name) and in their sequence of field names.
    This makes namedtuples recreated by StructureCoder compatible with their
    original Python definition.

    Args:
      a: a Python object.
      b: a Python object.

    Returns:
      A boolean that is true iff type(a) and type(b) are the same object
      or equivalent namedtuple types.
    """
    if nest.is_namedtuple(a) and nest.is_namedtuple(b):
      return nest.same_namedtuples(a, b)
    else:
      return type(a) is type(b)

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if isinstance(a, TypeSpec):
      return a.is_compatible_with(b)
    if not TypeSpec.__same_types(a, b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
          TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b

  @staticmethod
  def __most_specific_compatible_type_serialization(a, b):
    """Helper for most_specific_compatible_type.

    Combines two type serializations as follows:

    * If they are both namedtuples of equivalent types, then recursively
      combine the respective fields.
    * If they are both lists (or both tuples) of the same length, then
      recursively combine the respective elements.  The result is a list if
      the inputs are lists, and a tuple otherwise.
    * If they are both dicts with the same keys, then recursively combine
      the respective dict elements.
    * If they are both TypeSpecs, then combine using
      TypeSpec.most_specific_compatible_type.
    * If they are both TensorShapes, then combine using
      TensorShape.most_specific_compatible_shape.
    * If they are both TensorSpecs with the same dtype, then combine using
      TensorShape.most_specific_compatible_shape to combine shapes.
    * If they are equal, then return a.
    * If none of the above, then raise a ValueError.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A value with the same type and structure as `a` and `b`.

    Raises:
      ValueError: If `a` and `b` are incompatible.
    """
    if not TypeSpec.__same_types(a, b):
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    if nest.is_namedtuple(a):
      assert a._fields == b._fields  # Implied by __same_types(a, b).
      return type(a)(*[
          TypeSpec.__most_specific_compatible_type_serialization(x, y)
          for (x, y) in zip(a, b)])
    if isinstance(a, (list, tuple)):
      if len(a) != len(b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      merged = [TypeSpec.__most_specific_compatible_type_serialization(x, y)
                for (x, y) in zip(a, b)]
      # Keep lists as lists: `__make_cmp_key` distinguishes lists from tuples,
      # so turning a list into a tuple here would make
      # `spec.most_specific_compatible_type(spec) != spec`.
      return merged if isinstance(a, list) else tuple(merged)
    if isinstance(a, collections.OrderedDict):
      a_keys, b_keys = a.keys(), b.keys()
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return collections.OrderedDict([
          (k,
           TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k]))
          for k in a_keys
      ])
    if isinstance(a, dict):
      a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return {
          k: TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])
          for k in a_keys
      }
    if isinstance(a, tensor_shape.TensorShape):
      return a.most_specific_compatible_shape(b)
    if isinstance(a, TypeSpec):
      return a.most_specific_compatible_type(b)
    if a != b:
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    return a
515
516
class BatchableTypeSpec(TypeSpec):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching.  In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists.  Unlike the component encoding
  (returned by `self._to_components`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  @abc.abstractmethod
  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._batch" % type(self).__name__)

  @abc.abstractmethod
  def _unbatch(self):
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._unbatch" % type(self).__name__)

  def _to_batched_tensor_list(self, value):
    """Returns a tensor list encoding for value with rank>0.

    Args:
      value: A value compatible with this `TypeSpec`.

    Returns:
      The list returned by `self._to_tensor_list(value)`.

    Raises:
      ValueError: If any tensor in that list has a statically known rank of 0.
        Tensors whose rank is unknown are accepted.
    """
    tensor_list = self._to_tensor_list(value)
    if any(t.shape.ndims == 0 for t in tensor_list):
      # Wrap `value` in a 1-tuple: if `value` were itself a tuple, `%` would
      # treat it as the argument list and raise TypeError instead.
      raise ValueError("Value %s has insufficient rank for batching." %
                       (value,))
    return tensor_list
562
563
@tf_export("type_spec_from_value")
def type_spec_from_value(value):
  """Returns a `tf.TypeSpec` that represents the given `value`.

  Examples:

    >>> tf.type_spec_from_value(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)
    >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64))
    TensorSpec(shape=(2,), dtype=tf.float64, name=None)
    >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]]))
    RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64)

    >>> example_input = tf.ragged.constant([[1, 2], [3]])
    >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)])
    ... def f(x):
    ...   return tf.reduce_sum(x, axis=1)

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs.
      Supported values are `tf.Tensor`s, any value that `tf.convert_to_tensor`
      can convert to a `tf.Tensor`, and instances of any `CompositeTensor`
      subclass (such as `tf.RaggedTensor`).

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If no TypeSpec can be built for `value` because its type is
      not supported.
  """
  direct_spec = _type_spec_from_value(value)
  if direct_spec is not None:
    return direct_spec

  # Fallback: describe the tensor that `value` converts to.  Conversion
  # failures are only logged; the TypeError below reports the problem.
  try:
    converted_spec = _type_spec_from_value(ops.convert_to_tensor(value))
  except (ValueError, TypeError) as err:
    logging.vlog(
        3, "Failed to convert %r to tensor: %s" % (type(value).__name__, err))
  else:
    if converted_spec is not None:
      return converted_spec

  raise TypeError("Could not build a TypeSpec for %r with type %s" %
                  (value, type(value).__name__))
611
612
def _type_spec_from_value(value):
  """Builds a `TypeSpec` for `value` without a tensor-conversion fallback.

  Strategies are tried in this order: `tf.Tensor`; `CompositeTensor`;
  non-empty Python lists whose first element maps to a `BatchableTypeSpec`
  that can be merged with the specs of all other elements; and finally the
  converters registered via `register_type_spec_from_value_converter`, most
  recently registered first.

  Args:
    value: The value to describe.

  Returns:
    A `TypeSpec` describing `value`, or `None` if no strategy applies.
  """
  if isinstance(value, ops.Tensor):
    # Tensor names are intentionally left out of the resulting spec.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access

  # A non-empty list whose elements share a batchable spec is described by
  # batching that shared spec.  This captures the type more precisely than
  # the `convert_to_tensor` fallback would.
  if isinstance(value, list) and value:
    element_specs = [_type_spec_from_value(element) for element in value]
    head_spec = element_specs[0]
    if isinstance(head_spec, BatchableTypeSpec):
      try:
        common_spec = head_spec
        for element_spec in element_specs[1:]:
          common_spec = common_spec.most_specific_compatible_type(element_spec)
        return common_spec._batch(len(element_specs))  # pylint: disable=protected-access
      except (ValueError, TypeError):
        pass  # Element specs are not mutually compatible.

  # Later registrations take precedence, so scan the registry back-to-front.
  for registered_type, converter, match_subclasses in reversed(
      _TYPE_CONVERSION_FUNCTION_REGISTRY):
    if (type(value) is registered_type or  # pylint: disable=unidiomatic-typecheck
        (match_subclasses and isinstance(value, registered_type))):
      return converter(value)

  return None
644
645_TYPE_CONVERSION_FUNCTION_REGISTRY = []
646
647
def register_type_spec_from_value_converter(type_object, converter_fn,
                                            allow_subclass=False):
  """Registers `converter_fn` for building TypeSpecs from `type_object` values.

  When more than one registered `type_object` matches a value, the converter
  registered most recently wins.  Do not register converters for
  `CompositeTensor`s; those should provide `CompositeTensor._type_spec`.

  Args:
    type_object: The Python `type` whose instances `converter_fn` accepts.
      It is passed through `tf_decorator.unwrap` before being recorded.
    converter_fn: A one-argument callable that takes an instance of
      `type_object` and returns a `TypeSpec`.
    allow_subclass: If true, a value matches when
      `isinstance(value, type_object)`.  If false, it matches only when
      `type(value) is type_object`.
  """
  _decorators, unwrapped_type = tf_decorator.unwrap(type_object)
  registry_entry = (unwrapped_type, converter_fn, allow_subclass)
  _TYPE_CONVERSION_FUNCTION_REGISTRY.append(registry_entry)
667
668
# Make the `TypeSpec` class known to `_pywrap_utils` under the name
# "TypeSpec" (presumably so native utilities can recognize TypeSpec
# instances -- TODO confirm).
_pywrap_utils.RegisterType("TypeSpec", TypeSpec)
670
671
672_TYPE_SPEC_TO_NAME = {}
673_NAME_TO_TYPE_SPEC = {}
674
675
676# Regular expression for valid TypeSpec names.
677_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$")
678
679
# TODO(b/173744905) tf_export this as "tf.register_type_spec".  (And add a
# usage example to the docstring, once the API is public.)
#
# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than
# TypeSpec (once we do refactoring to move to_components/from_components from
# TypeSpec to ExtensionType).
def register(name):
  """Returns a class decorator that registers a TypeSpec subclass as `name`.

  Args:
    name: The globally unique name to register.  Must be a string of the form
      `"{project_name}.{type_name}"`, e.g. `"my_project.MyTypeSpec"`.

  Returns:
    A class decorator.  It records the decorated `TypeSpec` subclass under
    `name` and returns the class unchanged.  It raises `TypeError` when
    applied to anything that is not a `TypeSpec` subclass, and `ValueError`
    when the class or `name` is already registered.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: If `name` does not have the required form.
  """
  if not isinstance(name, str):
    raise TypeError("Expected `name` to be a string; got %r" % (name,))
  if _REGISTERED_NAME_RE.match(name) is None:
    raise ValueError(
        "Registered name must have the form '{project_name}.{type_name}' "
        "(e.g. 'my_project.MyTypeSpec'); got %r." % name)

  def decorator_fn(cls):
    is_type_spec_class = isinstance(cls, type) and issubclass(cls, TypeSpec)
    if not is_type_spec_class:
      raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
    if cls in _TYPE_SPEC_TO_NAME:
      raise ValueError("Class %s.%s has already been registered with name %s."
                       % (cls.__module__, cls.__name__,
                          _TYPE_SPEC_TO_NAME[cls]))
    if name in _NAME_TO_TYPE_SPEC:
      previous_cls = _NAME_TO_TYPE_SPEC[name]
      raise ValueError("Name %s has already been registered for class %s.%s."
                       % (name, previous_cls.__module__,
                          previous_cls.__name__))
    # Both directions are updated together so the registries stay in sync.
    _NAME_TO_TYPE_SPEC[name] = cls
    _TYPE_SPEC_TO_NAME[cls] = name
    return cls

  return decorator_fn
719
720
# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name)
def get_name(cls):
  """Returns the name under which TypeSpec subclass `cls` was registered.

  Args:
    cls: A `TypeSpec` subclass (the class object, not an instance).

  Returns:
    The name that `register()` recorded for `cls`.

  Raises:
    TypeError: If `cls` is not a `TypeSpec` subclass.
    ValueError: If `cls` has not been registered.
  """
  if not isinstance(cls, type) or not issubclass(cls, TypeSpec):
    raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
  missing = object()
  registered_name = _TYPE_SPEC_TO_NAME.get(cls, missing)
  if registered_name is missing:
    raise ValueError("TypeSpec %s.%s has not been registered." %
                     (cls.__module__, cls.__name__))
  return registered_name
730
731
# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name)
def lookup(name):
  """Returns the TypeSpec subclass that was registered as `name`.

  Args:
    name: A registered TypeSpec name (a string).

  Returns:
    The `TypeSpec` subclass recorded for `name` by `register()`.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: If no TypeSpec has been registered as `name`.
  """
  if not isinstance(name, str):
    raise TypeError("Expected `name` to be a string; got %r" % (name,))
  missing = object()
  spec_cls = _NAME_TO_TYPE_SPEC.get(name, missing)
  if spec_cls is missing:
    raise ValueError("No TypeSpec has been registered with name %r" % (name,))
  return spec_cls
740