1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tensor-like objects that are composed from tf.Tensors.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import six 24 25from tensorflow.python import pywrap_tensorflow 26from tensorflow.python.util import nest 27 28 29@six.add_metaclass(abc.ABCMeta) 30class CompositeTensor(object): 31 """Abstract base class for Tensor-like objects that are composed from Tensors. 32 33 Each `CompositeTensor` can be decomposed into a structured collection of 34 component `tf.Tensor`s, and reconstructed from those components. 35 36 The `tensorflow.python.util.nest` module has support for treating composite 37 tensors as structure, which makes it easy to flatten and reconstruct 38 composite tensors (or larger structures that contain composite tensors). 39 E.g.: 40 41 ```python 42 ct = ... # Create a composite tensor. 43 flat_list_of_tensors = nest.flatten(ct, expand_composites=True) 44 transformed_list_of_tensors = ... # do something with the flat tensors. 45 result = nest.pack_sequence_as(ct, transformed_list_of_tensors) 46 ``` 47 """ 48 49 @abc.abstractmethod 50 def _to_components(self): 51 """Decomposes this composite tensor into its components. 52 53 Returns: 54 The components that comprise this composite tensor: a nested structure 55 (as defined by `tf.python.util.nest`) whose values are `tf.Tensor`s or 56 `CompositeTensor`s. 57 """ 58 raise NotImplementedError("CompositeTensor._to_components") 59 60 @abc.abstractmethod 61 def _from_components(cls, components): # pylint: disable=no-self-argument 62 """Creates a composite tensor of type `cls` from components. 63 64 Args: 65 components: The components that should be used to form the 66 composite tensor: a nested structure (as defined by 67 `tf.python.util.nest`) whose values are tf.Tensors or composite 68 tensors. 69 70 Returns: 71 A `CompositeTensor` of type `cls`. 72 """ 73 raise NotImplementedError("CompositeTensor._from_components") 74 75 @abc.abstractmethod 76 def _shape_invariant_to_components(self, shape=None): 77 """Converts a shape invariant into invariants for individual components. 78 79 Args: 80 shape: A `tf.TensorShape` object. The shape invariant for this 81 `CompositeTensor`, or `None` if a default shape invariant should be 82 used (based on the value of this `CompositeTensor`). 83 84 Returns: 85 A nested structure whose values are `tf.TensorShape` objects, specifying 86 the shape invariants for the tensors that comprise this `CompositeTensor`. 87 """ 88 raise NotImplementedError("CompositeTensor._shape_invariant_to_components") 89 90 @abc.abstractproperty 91 def _is_graph_tensor(self): 92 """Returns True if this tensor's components belong to a TF graph.""" 93 raise NotImplementedError("CompositeTensor._is_symbolic_tensor") 94 95 def consumers(self): 96 """Returns a list of `Operation`s that consume this `CompositeTensor`. 97 98 Returns: 99 A list of `Operation`s. 100 101 Raises: 102 RuntimeError: If this method is called while executing eagerly. 103 """ 104 consumers = nest.flatten([ 105 component.consumers() 106 for component in self._to_components() 107 if getattr(component, "graph", None) is not None 108 ]) 109 return list(set(consumers)) 110 111 112pywrap_tensorflow.RegisterType("CompositeTensor", CompositeTensor) 113 114 115# @TODO(edloper): Can we replace convert_to_tensor_or_xyz with just 116# convert_to_tensor_or_composite? Alternatively, should composite tensors 117# register a dispatch override for tf.convert_to_tensor? 118