• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tensor-like objects that are composed from tf.Tensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import six
24
25from tensorflow.python import pywrap_tensorflow
26from tensorflow.python.util import nest
27
28
@six.add_metaclass(abc.ABCMeta)
class CompositeTensor(object):
  """Abstract base class for Tensor-like objects that are composed from Tensors.

  Each `CompositeTensor` can be decomposed into a structured collection of
  component `tf.Tensor`s, and reconstructed from those components.

  The `tensorflow.python.util.nest` module has support for treating composite
  tensors as structure, which makes it easy to flatten and reconstruct
  composite tensors (or larger structures that contain composite tensors).
  The same `expand_composites` setting must be used when flattening and when
  packing, so that both calls agree on the structure. E.g.:

  ```python
  ct = ...  # Create a composite tensor.
  flat_list_of_tensors = nest.flatten(ct, expand_composites=True)
  transformed_list_of_tensors = ...  # do something with the flat tensors.
  result = nest.pack_sequence_as(ct, transformed_list_of_tensors,
                                 expand_composites=True)
  ```
  """

  @abc.abstractmethod
  def _to_components(self):
    """Decomposes this composite tensor into its components.

    Returns:
      The components that comprise this composite tensor: a nested structure
      (as defined by `tf.python.util.nest`) whose values are `tf.Tensor`s or
      `CompositeTensor`s.
    """
    raise NotImplementedError("CompositeTensor._to_components")

  # NOTE(review): declared without `@classmethod` (hence the pylint
  # suppression), presumably for Python 2 compatibility with
  # `abc.abstractmethod`; subclasses appear to be expected to override it as
  # a classmethod -- confirm.
  @abc.abstractmethod
  def _from_components(cls, components):  # pylint: disable=no-self-argument
    """Creates a composite tensor of type `cls` from components.

    Args:
      components: The components that should be used to form the
        composite tensor: a nested structure (as defined by
        `tf.python.util.nest`) whose values are tf.Tensors or composite
        tensors.

    Returns:
      A `CompositeTensor` of type `cls`.
    """
    raise NotImplementedError("CompositeTensor._from_components")

  @abc.abstractmethod
  def _shape_invariant_to_components(self, shape=None):
    """Converts a shape invariant into invariants for individual components.

    Args:
      shape: A `tf.TensorShape` object.  The shape invariant for this
        `CompositeTensor`, or `None` if a default shape invariant should be
        used (based on the value of this `CompositeTensor`).

    Returns:
      A nested structure whose values are `tf.TensorShape` objects, specifying
      the shape invariants for the tensors that comprise this `CompositeTensor`.
    """
    raise NotImplementedError("CompositeTensor._shape_invariant_to_components")

  @abc.abstractproperty
  def _is_graph_tensor(self):
    """Returns True if this tensor's components belong to a TF graph."""
    raise NotImplementedError("CompositeTensor._is_graph_tensor")

  def consumers(self):
    """Returns a list of `Operation`s that consume this `CompositeTensor`.

    The components returned by `_to_components()` are flattened recursively
    (nested composite tensors are expanded into their own components), and
    the consumers of every leaf component that has a non-`None` `graph`
    attribute are collected.

    Returns:
      A list of distinct `Operation`s, in the order they are first
      encountered.

    Raises:
      RuntimeError: If this method is called while executing eagerly.
        NOTE(review): components whose `graph` attribute is missing or `None`
        are skipped below, so eager execution may yield an empty list rather
        than raise -- confirm.
    """
    # `_to_components()` returns an arbitrary nested structure (possibly a
    # dict or containing other composite tensors), so flatten it fully
    # instead of iterating only its top level.
    leaves = nest.flatten(self._to_components(), expand_composites=True)
    result = []
    seen = set()
    for component in leaves:
      # Skip components that are not attached to a graph (presumably eager
      # tensors); they have no graph consumers to report.
      if getattr(component, "graph", None) is None:
        continue
      for op in component.consumers():
        # Deduplicate while keeping a deterministic, first-seen order.
        if op not in seen:
          seen.add(op)
          result.append(op)
    return result
110
111
# Register `CompositeTensor` with `pywrap_tensorflow` under the name
# "CompositeTensor" (presumably so native code can recognize instances of
# this class -- confirm).
pywrap_tensorflow.RegisterType(
    "CompositeTensor", CompositeTensor)
113
114
# TODO(edloper): Can we replace the `convert_to_tensor_or_xyz` functions with
# just `convert_to_tensor_or_composite`?  Alternatively, should composite
# tensors register a dispatch override for `tf.convert_to_tensor`?
118