# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type specifications for TensorFlow APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import numpy as np
import six

from tensorflow.python import _pywrap_utils
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# Use LazyLoader to avoid circular dependencies.
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ops = LazyLoader(
    "ops", globals(),
    "tensorflow.python.framework.ops")


@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
@six.add_metaclass(abc.ABCMeta)
class TypeSpec(object):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs.  Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

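  For example (an illustrative sketch; `double` is a hypothetical function):

  ```
  @tf.function(
      input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def double(x):
    return x * 2
  ```
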
  Creating new subclasses of TypeSpec (outside of TensorFlow core) is not
  currently supported.  In particular, we may make breaking changes to the
  private methods and properties defined by this base class.
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  #   * A "component encoding" for values.
  #   * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value.  Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute.  Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive.  In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`.  See the documentation for
  # `_serialize()` for more information.
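  #
  # As an illustrative sketch only (hypothetical, not part of TensorFlow), a
  # minimal subclass wrapping a single dense tensor might look like:
  #
  #   class WrappedTensorSpec(TypeSpec):
  #     __slots__ = ["_shape", "_dtype"]
  #
  #     def __init__(self, shape, dtype):
  #       self._shape = tensor_shape.as_shape(shape)
  #       self._dtype = dtypes.as_dtype(dtype)
  #
  #     value_type = property(lambda self: ops.Tensor)
  #
  #     def _to_components(self, value):
  #       return (value,)
  #
  #     def _from_components(self, components):
  #       return components[0]
  #
  #     @property
  #     def _component_specs(self):
  #       return (tensor_spec.TensorSpec(self._shape, self._dtype),)
  #
  #     def _serialize(self):
  #       return (self._shape, self._dtype)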

  __slots__ = []

  @abc.abstractproperty
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec."""
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
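    # For example (illustrative), `tf.TensorSpec([None], tf.float32)` is
    # compatible with `tf.TensorSpec([5], tf.float32)`, but not with
    # `tf.TensorSpec([5], tf.int32)`.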
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(),
                                spec_or_value._serialize())  # pylint: disable=protected-access

  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to raise a
    # `ValueError` if `self` and `other` have different types, or if their type
    # serializations differ by anything other than `TensorShape`s.  Otherwise,
    # the two type serializations are combined (using
    # `most_specific_compatible_shape` to combine `TensorShape`s), and the
    # result is used to construct and return a new `TypeSpec`.
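    # For example (illustrative), the most specific type compatible with both
    # `tf.TensorSpec([None, 5])` and `tf.TensorSpec([3, None])` is
    # `tf.TensorSpec([None, None])`.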
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    merged = self.__most_specific_compatible_type_serialization(
        self._serialize(), other._serialize())  # pylint: disable=protected-access
    return self._deserialize(merged)

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`.  (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @abc.abstractproperty
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method.  In particular, for a
      TypeSpec `spec` and a compatible value `value`, each entry of the
      following structure should evaluate to `True`:

      ```
      nest.map_structure(lambda t, c: t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value):
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`.  However, subclasses may override this to return a
    different tensor encoding for values.  In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched.  See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list):
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
        `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self, tensor_list):
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.  (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(nest.pack_sequence_as(
        self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)
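
  # Illustrative invariant: for a TypeSpec `spec` and a compatible value `v`,
  # `spec._from_tensor_list(spec._to_tensor_list(v))` reconstructs `v`, and
  # each tensor in `spec._to_tensor_list(v)` is compatible with the
  # corresponding entry of `spec._flat_tensor_specs`.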

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `_serialize`."""
    return cls(*serialization)
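
  # As an illustrative expectation of this contract: for any TypeSpec `spec`,
  # `type(spec)._deserialize(spec._serialize())` should be equal to `spec`.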

  # === Operators ===

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    return type(self), self._serialize()

  def __repr__(self):
    return "%s%r" % (type(self).__name__, self._serialize())

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  def __check_tensor_list(self, tensor_list):
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError("Incompatible input: wrong number of tensors")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError("Incompatible input: tensor %d (%s) is incompatible "
                         "with %s" % (i, tensor_list[i], s2))

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key."""
    if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, tensor_shape.TensorShape):
      if value.ndims is None:
        # Note: we include a type object in the tuple, to ensure we can't get
        # false-positive matches (since users can't include type objects).
        return (tensor_shape.TensorShape, None)
      return (tensor_shape.TensorShape, tuple(value.as_list()))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError("Unsupported value type %s returned by "
                     "%s._serialize" %
                     (type(value).__name__, type(self).__name__))
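
  # For example (illustrative), `tf.TensorShape([None, 5])` is keyed as
  # `(TensorShape, (None, 5))`, and a Python list is keyed as `(list, ...)`
  # with its elements converted recursively.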

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if type(a) is not type(b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
          TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b

  @staticmethod
  def __most_specific_compatible_type_serialization(a, b):
    """Helper for most_specific_compatible_type.

    Combines two type serializations as follows:

    * If they are both tuples of the same length, then recursively combine
      the respective tuple elements.
    * If they are both dicts with the same keys, then recursively combine
      the respective dict elements.
    * If they are both TypeSpecs, then combine using
      TypeSpec.most_specific_compatible_type.
    * If they are both TensorShapes, then combine using
      TensorShape.most_specific_compatible_shape.
    * If they are both TensorSpecs with the same dtype, then combine using
      TensorShape.most_specific_compatible_shape to combine shapes.
    * If they are equal, then return a.
    * If none of the above, then raise a ValueError.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A value with the same type and structure as `a` and `b`.

    Raises:
      ValueError: If `a` and `b` are incompatible.
    """
    if type(a) is not type(b):
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    if isinstance(a, (list, tuple)):
      if len(a) != len(b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
                   for (x, y) in zip(a, b))
    if isinstance(a, dict):
      a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return {
          k: TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])
          for k in a_keys
      }
    if isinstance(a, tensor_shape.TensorShape):
      return a.most_specific_compatible_shape(b)
    if isinstance(a, list):
      raise AssertionError("_serialize() should not return list values.")
    if isinstance(a, TypeSpec):
      return a.most_specific_compatible_type(b)
    if a != b:
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    return a


class BatchableTypeSpec(TypeSpec):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching.  In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists.  Unlike the component encoding
  (returned by `self._to_components()`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  @abc.abstractmethod
  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._batch" % type(self).__name__)

  @abc.abstractmethod
  def _unbatch(self):
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of a batch described by this
      TypeSpec.
    """
    raise NotImplementedError("%s._unbatch" % type(self).__name__)
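
  # For example (an illustrative sketch using `tf.RaggedTensorSpec`, which is
  # a batchable spec): `tf.RaggedTensorSpec([None, None], tf.int32)._batch(8)`
  # describes a batch of 8 such ragged tensors (with shape `[8, None, None]`),
  # and calling `_unbatch()` on the result recovers a spec for a single
  # element.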

  def _to_batched_tensor_list(self, value):
    """Returns a tensor list encoding for `value` with rank > 0."""
    tensor_list = self._to_tensor_list(value)
    if any(t.shape.ndims == 0 for t in tensor_list):
      raise ValueError("Value %s has insufficient rank for batching." % value)
    return tensor_list


def type_spec_from_value(value):
  """Returns a `TypeSpec` that represents the given `value`.

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs.

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If a TypeSpec cannot be built for `value`, because its type
      is not supported.
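
  Example (illustrative):

  ```
  spec = type_spec_from_value(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  # `spec` is a `tf.TensorSpec` with shape `[2, 2]` and dtype `tf.float32`.
  ```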
  """
  spec = _type_spec_from_value(value)
  if spec is not None:
    return spec

  # Fallback: try converting value to a tensor.
  try:
    tensor = ops.convert_to_tensor(value)
    spec = _type_spec_from_value(tensor)
    if spec is not None:
      return spec
  except (ValueError, TypeError) as e:
    logging.vlog(
        3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))

  raise TypeError("Could not build a TypeSpec for %r with type %s" %
                  (value, type(value).__name__))


def _type_spec_from_value(value):
  """Returns a `TypeSpec` that represents the given `value`."""
  if isinstance(value, ops.Tensor):
    # Note: we do not include Tensor names when constructing TypeSpecs.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access

  # If `value` is a list and all of its elements can be represented by the same
  # batchable type spec, then we can represent the entire list using a single
  # type spec that captures the type accurately (unlike the `convert_to_tensor`
  # fallback).
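  # For example (illustrative), a list of two `tf.RaggedTensor`s with
  # compatible specs is represented by a single `tf.RaggedTensorSpec` that has
  # been batched to length 2.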
  if isinstance(value, list) and value:
    subspecs = [_type_spec_from_value(v) for v in value]
    if isinstance(subspecs[0], BatchableTypeSpec):
      merged_subspec = subspecs[0]
      try:
        for subspec in subspecs[1:]:
          merged_subspec = merged_subspec.most_specific_compatible_type(subspec)
        return merged_subspec._batch(len(subspecs))  # pylint: disable=protected-access
      except (ValueError, TypeError):
        pass  # incompatible subspecs

  for entry in reversed(_TYPE_CONVERSION_FUNCTION_REGISTRY):
    type_object, converter_fn, allow_subclass = entry
    if ((type(value) is type_object) or  # pylint: disable=unidiomatic-typecheck
        (allow_subclass and isinstance(value, type_object))):
      return converter_fn(value)

  return None

_TYPE_CONVERSION_FUNCTION_REGISTRY = []


def register_type_spec_from_value_converter(type_object, converter_fn,
                                            allow_subclass=False):
  """Registers a function for converting values with a given type to TypeSpecs.

  If multiple registered `type_object`s match a value, then the most recent
  registration takes precedence.  Custom converters should not be defined for
  `CompositeTensor`s; use `CompositeTensor._type_spec` instead.

  Args:
    type_object: A Python `type` object representing the type of values
      accepted by `converter_fn`.
    converter_fn: A function that takes one argument (an instance of the
      type represented by `type_object`) and returns a `TypeSpec`.
    allow_subclass: If true, then use `isinstance(value, type_object)` to
      check for matches.  If false, then use `type(value) is type_object`.
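
  Example (an illustrative sketch; `MyValue` and `MyValueSpec.from_value` are
  hypothetical):

  ```
  register_type_spec_from_value_converter(
      MyValue, lambda value: MyValueSpec.from_value(value))
  ```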
  """
  _, type_object = tf_decorator.unwrap(type_object)
  _TYPE_CONVERSION_FUNCTION_REGISTRY.append(
      (type_object, converter_fn, allow_subclass))


_pywrap_utils.RegisterType("TypeSpec", TypeSpec)
