• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Composes one or more `LinearOperators`."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import common_shapes
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops.linalg import linear_operator
28from tensorflow.python.util.tf_export import tf_export
29
# Public names exported by this module.
__all__ = ["LinearOperatorComposition"]
31
32
@tf_export("linalg.LinearOperatorComposition")
@linear_operator.make_composite_tensor
class LinearOperatorComposition(linear_operator.LinearOperator):
  """Composes one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` with action defined by:

  ```
  op_composed(x) := op1(op2(...(opJ(x))...))
  ```

  If `opj` acts like [batch] matrix `Aj`, then `op_composed` acts like the
  [batch] matrix formed with the multiplication `A1 A2...AJ`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then we must have
  `N_j = M_{j+1}`, in which case the composed operator has shape equal to
  `broadcast_batch_shape + [M_1, N_J]`, where `broadcast_batch_shape` is the
  mutual broadcast of `batch_shape_j`, `j = 1,...,J`, assuming the intermediate
  batch shapes broadcast.  Even if the composed shape is well defined, the
  composed operator's methods may fail due to lack of broadcasting ability in
  the defining operators' methods.

  ```python
  # Create a 2 x 2 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorComposition([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 4 x 6 operators.
  operator_46 = LinearOperatorComposition([operator_45, operator_56])

  # Create a shape [2, 3, 6, 2] batch of matrices.
  x = tf.random.normal(shape=[2, 3, 6, 2])
  operator_46.matmul(x)
  ==> Shape [2, 3, 4, 2] Tensor
  ```

  #### Performance

  The cost of any operation on a `LinearOperatorComposition` is the sum of the
  costs of that operation on each of the individual operators.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorComposition`.

    `LinearOperatorComposition` is initialized with a list of operators
    `[op_1,...,op_J]`.  For the `matmul` method to be well defined, the
    composition `op_i.matmul(op_{i+1}(x))` must be defined.  Other methods have
    similar constraints.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators' names joined with `_o_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if `is_non_singular` is `False`
        although every operator in `operators` is non-singular.
    """
    # Validate and materialize the operators *before* recording `parameters`,
    # so that a one-shot iterable (e.g. a generator) is stored as a reusable
    # list rather than an exhausted iterator.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # `name` is recorded as passed by the caller (possibly None), before the
    # default name is derived below.
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

    # Validate dtype: every operator must share the dtype of the first one.
    dtype = operators[0].dtype
    if any(operator.dtype != dtype for operator in operators):
      name_type = (str((o.name, o.dtype)) for o in operators)
      raise TypeError(
          "Expected all operators to have the same dtype.  Found %s"
          % "   ".join(name_type))

    # Auto-set and check hints.  A product of non-singular matrices is
    # non-singular, so an explicit `is_non_singular=False` is contradictory.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            "The composition of non-singular operators is always non-singular.")
      is_non_singular = True

    # Initialization.
    graph_parents = []
    for operator in operators:
      graph_parents.extend(operator.graph_parents)

    if name is None:
      name = "_o_".join(operator.name for operator in operators)
    with ops.name_scope(name, values=graph_parents):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
    # TODO(b/143910018) Remove graph_parents in V3.
    self._set_graph_parents(graph_parents)

  @property
  def operators(self):
    """The list of composed `LinearOperator`s, outermost (`op1`) first."""
    return self._operators

  def _shape(self):
    # Check that adjacent operators are composable: the domain of each
    # operator must be compatible with the range of the next one.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension.assert_is_compatible_with(operator.range_dimension)
      domain_dimension = operator.domain_dimension

    # The composed matrix maps the last operator's domain to the first
    # operator's range.
    matrix_shape = tensor_shape.TensorShape(
        [self.operators[0].range_dimension,
         self.operators[-1].domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    # Don't check the matrix dimensions.  That would add unnecessary Asserts to
    # the graph.  Things will fail at runtime naturally if shapes are
    # incompatible.
    matrix_shape = array_ops.stack([
        self.operators[0].range_dimension_tensor(),
        self.operators[-1].domain_dimension_tensor()
    ])

    # Broadcast the batch shape vectors directly.  This avoids building (and
    # summing) zero-filled tensors of each batch shape, which would allocate
    # memory proportional to the batch size when executing eagerly.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # If self.operators = [A, B], and not adjoint, then
    # matmul_order_list = [B, A].
    # As a result, we return A.matmul(B.matmul(x)).
    # With adjoint, (A B)^H = B^H A^H, so the order is [A, B] instead.
    if adjoint:
      matmul_order_list = self.operators
    else:
      matmul_order_list = list(reversed(self.operators))

    # Only the first application sees the caller's `x`, so only it receives
    # `adjoint_arg`; later applications act on intermediate results.
    result = matmul_order_list[0].matmul(
        x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in matmul_order_list[1:]:
      result = operator.matmul(result, adjoint=adjoint)
    return result

  def _determinant(self):
    # det(A1 A2 ... AJ) = det(A1) det(A2) ... det(AJ).
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    # log|det(A1 ... AJ)| = sum_j log|det(Aj)|.
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # TODO(langmore) Implement solve using solve_ls if some intermediate
    # operator maps to a high dimensional space.
    # In that case, an exact solve may still be possible.

    # If self.operators = [A, B], and not adjoint, then
    # solve_order_list = [A, B].
    # As a result, we return B.solve(A.solve(x)), since (A B)^{-1} =
    # B^{-1} A^{-1}.  With adjoint, (A B)^{-H} = A^{-H} B^{-H}, so B is
    # solved first.
    if adjoint:
      solve_order_list = list(reversed(self.operators))
    else:
      solve_order_list = self.operators

    # Only the first solve sees the caller's `rhs`, so only it receives
    # `adjoint_arg`.
    solution = solve_order_list[0].solve(
        rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in solve_order_list[1:]:
      solution = operator.solve(solution, adjoint=adjoint)
    return solution

  @property
  def _composite_tensor_fields(self):
    # Fields used by `make_composite_tensor` to rebuild this operator.
    return ("operators",)
293