• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Composes one or more `LinearOperators`."""
16
17from tensorflow.python.framework import common_shapes
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import check_ops
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops.linalg import linear_operator
25from tensorflow.python.util.tf_export import tf_export
26
27__all__ = ["LinearOperatorComposition"]
28
29
@tf_export("linalg.LinearOperatorComposition")
@linear_operator.make_composite_tensor
class LinearOperatorComposition(linear_operator.LinearOperator):
  """Composes one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` with action defined by:

  ```
  op_composed(x) := op1(op2(...(opJ(x))...))
  ```

  If `opj` acts like [batch] matrix `Aj`, then `op_composed` acts like the
  [batch] matrix formed with the multiplication `A1 A2...AJ`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then we must have
  `N_j = M_{j+1}`, in which case the composed operator has shape equal to
  `broadcast_batch_shape + [M_1, N_J]`, where `broadcast_batch_shape` is the
  mutual broadcast of `batch_shape_j`, `j = 1,...,J`, assuming the intermediate
  batch shapes broadcast.  Even if the composed shape is well defined, the
  composed operator's methods may fail due to lack of broadcasting ability in
  the defining operators' methods.

  ```python
  # Create a 2 x 2 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorComposition([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 4 x 6 operators.
  operator_46 = LinearOperatorComposition([operator_45, operator_56])

  # Create a shape [2, 3, 6, 2] Tensor.
  x = tf.random.normal(shape=[2, 3, 6, 2])
  operator_46.matmul(x)
  ==> Shape [2, 3, 4, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorComposition` on any operation is equal to
  the sum of the individual operators' operations.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorComposition`.

    `LinearOperatorComposition` is initialized with a list of operators
    `[op_1,...,op_J]`.  For the `matmul` method to be well defined, the
    composition `op_i.matmul(op_{i+1}(x))` must be defined.  Other methods have
    similar constraints.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_o_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if `is_non_singular` is `False`
        while every operator is hinted to be non-singular.
    """
    # Record the constructor arguments exactly as received; they are handed to
    # the base class through `parameters`.
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Validate dtype: every operator must match the first one.
    dtype = operators[0].dtype
    if any(operator.dtype != dtype for operator in operators[1:]):
      name_type = (str((o.name, o.dtype)) for o in operators)
      raise TypeError(
          "Expected all operators to have the same dtype.  Found %s"
          % "   ".join(name_type))

    # Auto-set and check hints.  A caller-supplied `False` that contradicts
    # the operators' own hints is an error rather than silently overridden.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:  # pylint:disable=g-bool-id-comparison
        raise ValueError(
            "The composition of non-singular operators is always non-singular.")
      is_non_singular = True

    # Initialization.
    if name is None:
      name = "_o_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorComposition, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of composed operators, in the order given at construction."""
    return self._operators

  def _shape(self):
    """Returns the static shape: broadcast batch shape + `[M_1, N_J]`."""
    # Check that adjacent operators are composable (N_j compatible with
    # M_{j+1}) while walking the chain.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension.assert_is_compatible_with(operator.range_dimension)
      domain_dimension = operator.domain_dimension

    # Get final matrix shape.
    matrix_shape = tensor_shape.TensorShape(
        [self.operators[0].range_dimension,
         self.operators[-1].domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Returns the shape as an `int32` `Tensor`, computed at runtime if needed."""
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    # Don't check the matrix dimensions.  That would add unnecessary Asserts to
    # the graph.  Things will fail at runtime naturally if shapes are
    # incompatible.
    matrix_shape = array_ops.stack([
        self.operators[0].range_dimension_tensor(),
        self.operators[-1].domain_dimension_tensor()
    ])

    # Broadcast the batch shape vectors directly instead of adding dummy
    # zero-filled tensors of the full batch shape just to read their shape.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Applies the composed operator (or its adjoint) to `x`.

    Args:
      x: Argument forwarded to the first operator's `matmul`.
      adjoint: If `True`, each operator is applied with `adjoint=True`, in the
        stored order instead of the reversed order.
      adjoint_arg: Forwarded only to the first `matmul` call, the one that
        receives `x` itself.

    Returns:
      The result of chaining `matmul` through every operator.
    """
    # If self.operators = [A, B], and not adjoint, then
    # matmul_order_list = [B, A].
    # As a result, we return A.matmul(B.matmul(x))
    if adjoint:
      matmul_order_list = self.operators
    else:
      matmul_order_list = list(reversed(self.operators))

    # Only the first application sees the caller's `x`, so only it gets
    # `adjoint_arg`; later applications act on intermediate results.
    result = matmul_order_list[0].matmul(
        x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in matmul_order_list[1:]:
      result = operator.matmul(result, adjoint=adjoint)
    return result

  def _determinant(self):
    """Returns the product of the individual operators' determinants."""
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    """Returns the sum of the individual operators' log-abs-determinants."""
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves with the composed operator by chaining the operators' solves.

    Args:
      rhs: Right-hand side forwarded to the first operator's `solve`.
      adjoint: If `True`, each operator solves with `adjoint=True`, in the
        reversed order instead of the stored order.
      adjoint_arg: Forwarded only to the first `solve` call, the one that
        receives `rhs` itself.

    Returns:
      The result of chaining `solve` through every operator.
    """
    # TODO(langmore) Implement solve using solve_ls if some intermediate
    # operator maps to a high dimensional space.
    # In that case, an exact solve may still be possible.

    # If self.operators = [A, B], and not adjoint, then
    # solve_order_list = [A, B].
    # As a result, we return B.solve(A.solve(x))
    if adjoint:
      solve_order_list = list(reversed(self.operators))
    else:
      solve_order_list = self.operators

    # Only the first solve sees the caller's `rhs`, so only it gets
    # `adjoint_arg`.
    solution = solve_order_list[0].solve(
        rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    for operator in solve_order_list[1:]:
      solution = operator.solve(solution, adjoint=adjoint)
    return solution

  def _assert_non_singular(self):
    """Returns an op asserting the composed operator is non-singular.

    When every operator is hinted square, the individual operators' asserts
    are grouped; otherwise the base class implementation is used.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    return super(LinearOperatorComposition, self)._assert_non_singular()

  @property
  def _composite_tensor_fields(self):
    """Names of the constructor arguments exposed as composite tensor fields."""
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    """Maps `"operators"` to a list of zeros, one entry per operator."""
    return {"operators": [0] * len(self.operators)}
295