# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a blockwise lower-triangular operator from `LinearOperators`."""

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorBlockLowerTriangular"]


@tf_export("linalg.LinearOperatorBlockLowerTriangular")
@linear_operator.make_composite_tensor
class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
38  """Combines `LinearOperators` into a blockwise lower-triangular matrix.
39
40  This operator is initialized with a nested list of linear operators, which
41  are combined into a new `LinearOperator` whose underlying matrix
42  representation is square and has each operator on or below the main diagonal,
43  and zero's elsewhere. Each element of the outer list is a list of
44  `LinearOperators` corresponding to a row-partition of the blockwise structure.
45  The number of `LinearOperator`s in row-partion `i` must be equal to `i`.

  For example, a blockwise `3 x 3` `LinearOperatorBlockLowerTriangular` is
  initialized with the list `[[op_00], [op_10, op_11], [op_20, op_21, op_22]]`,
  where the `op_ij`, `i < 3, j <= i`, are `LinearOperator` instances. The
  `LinearOperatorBlockLowerTriangular` behaves as the following blockwise
  matrix, where `0` represents appropriately-sized [batch] matrices of zeros:

  ```none
  [[op_00,     0,     0],
   [op_10, op_11,     0],
   [op_20, op_21, op_22]]
  ```

  Each `op_jj` on the diagonal is required to represent a square matrix, and
  hence will have shape `batch_shape_j + [M_j, M_j]`. `LinearOperator`s in row
  `j` of the blockwise structure must have `range_dimension` equal to that of
  `op_jj`, and `LinearOperators` in column `j` must have `domain_dimension`
  equal to that of `op_jj`.

  If each `op_jj` on the diagonal has shape `batch_shape_j + [M_j, M_j]`, then
  the combined operator has shape `broadcast_batch_shape + [sum M_j, sum M_j]`,
  where `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`,
  `j = 0, 1, ..., J`, assuming the intermediate batch shapes broadcast.
  Even if the combined shape is well defined, the combined operator's
  methods may fail due to lack of broadcasting ability in the defining
  operators' methods.
  For example, to create a 4 x 4 linear operator combining three 2 x 2
  operators:
  >>> operator_0 = tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  >>> operator_1 = tf.linalg.LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  >>> operator_2 = tf.linalg.LinearOperatorLowerTriangular([[5., 6.], [7., 8.]])
  >>> operator = tf.linalg.LinearOperatorBlockLowerTriangular(
  ...   [[operator_0], [operator_1, operator_2]])

  >>> operator.to_dense()
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
  array([[1., 2., 0., 0.],
         [3., 4., 0., 0.],
         [1., 0., 5., 0.],
         [0., 1., 7., 8.]], dtype=float32)>

  >>> operator.shape
  TensorShape([4, 4])

  >>> operator.log_abs_determinant()
  <tf.Tensor: shape=(), dtype=float32, numpy=4.3820267>

  >>> x0 = [[1., 6.], [-3., 4.]]
  >>> x1 = [[0., 2.], [4., 0.]]
  >>> x = tf.concat([x0, x1], 0)  # Shape [4, 2] Tensor
  >>> operator.matmul(x)
  <tf.Tensor: shape=(4, 2), dtype=float32, numpy=
  array([[-5., 14.],
         [-9., 34.],
         [ 1., 16.],
         [29., 18.]], dtype=float32)>

  The above `matmul` is equivalent to:
  >>> tf.concat([operator_0.matmul(x0),
  ...   operator_1.matmul(x0) + operator_2.matmul(x1)], axis=0)
  <tf.Tensor: shape=(4, 2), dtype=float32, numpy=
  array([[-5., 14.],
         [-9., 34.],
         [ 1., 16.],
         [29., 18.]], dtype=float32)>

  #### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  For example:

  Create a [2, 3] batch of 4 x 4 linear operators:
  >>> matrix_44 = tf.random.normal(shape=[2, 3, 4, 4])
  >>> operator_44 = tf.linalg.LinearOperatorFullMatrix(matrix_44)

  Create a [1, 3] batch of 5 x 4 linear operators:
  >>> matrix_54 = tf.random.normal(shape=[1, 3, 5, 4])
  >>> operator_54 = tf.linalg.LinearOperatorFullMatrix(matrix_54)

  Create a [1, 3] batch of 5 x 5 linear operators:
  >>> matrix_55 = tf.random.normal(shape=[1, 3, 5, 5])
  >>> operator_55 = tf.linalg.LinearOperatorFullMatrix(matrix_55)

  Combine to create a [2, 3] batch of 9 x 9 operators:
  >>> operator_99 = tf.linalg.LinearOperatorBlockLowerTriangular(
  ...   [[operator_44], [operator_54, operator_55]])
  >>> operator_99.shape
  TensorShape([2, 3, 9, 9])

  Create a shape [2, 1, 9] batch of vectors and apply the operator to it.
  >>> x = tf.random.normal(shape=[2, 1, 9])
  >>> y = operator_99.matvec(x)
  >>> y.shape
  TensorShape([2, 3, 9])

  Create a blockwise list of vectors and apply the operator to it. A blockwise
  list is returned.
  >>> x4 = tf.random.normal(shape=[2, 1, 4])
  >>> x5 = tf.random.normal(shape=[2, 3, 5])
  >>> y_blockwise = operator_99.matvec([x4, x5])
  >>> y_blockwise[0].shape
  TensorShape([2, 3, 4])
  >>> y_blockwise[1].shape
  TensorShape([2, 3, 5])

  #### Performance

  Suppose `operator` is a `LinearOperatorBlockLowerTriangular` consisting of `D`
  row-partitions and `D` column-partitions, such that the total number of
  operators is `N = D * (D + 1) // 2`.

  * `operator.matmul` has complexity equal to the sum of the `matmul`
    complexities of the individual operators.
  * `operator.solve` has complexity equal to the sum of the `solve` complexities
    of the operators on the diagonal and the `matmul` complexities of the
    operators off the diagonal.
  * `operator.determinant` has complexity equal to the sum of the `determinant`
    complexities of the operators on the diagonal.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorBlockLowerTriangular"):
    r"""Initialize a `LinearOperatorBlockLowerTriangular`.

    `LinearOperatorBlockLowerTriangular` is initialized with a list of lists of
    operators `[[op_0], [op_1, op_2], [op_3, op_4, op_5],...]`.

    Args:
      operators:  Iterable of iterables of `LinearOperator` objects, each with
        the same `dtype`. Each element of `operators` corresponds to a row-
        partition, in top-to-bottom order. The operators in each row-partition
        are filled in left-to-right. For example,
        `operators = [[op_0], [op_1, op_2], [op_3, op_4, op_5]]` creates a
        `LinearOperatorBlockLowerTriangular` with full block structure
        `[[op_0, 0, 0], [op_1, op_2, 0], [op_3, op_4, op_5]]`. Row-partition
        `i` must contain exactly `i + 1` operators, such that each operator
        falls on or below the diagonal of the blockwise structure.
        `LinearOperator`s that fall on the diagonal (the last elements of each
        row) must be square. The other `LinearOperator`s must have domain
        dimension equal to the domain dimension of the `LinearOperator`s in the
        same column-partition, and range dimension equal to the range dimension
        of the `LinearOperator`s in the same row-partition.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        This will raise a `ValueError` if set to `False`.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, contains an erroneous number of
        elements, or contains operators with incompatible shapes.
    """
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    for row in operators:
      check_ops.assert_proper_iterable(row)
    operators = [list(row) for row in operators]

    if not operators:
      raise ValueError(f"Argument `operators` must be a list of >=1 operators. "
                       f"Received: {operators}.")
    self._operators = operators
    self._diagonal_operators = [row[-1] for row in operators]

    dtype = operators[0][0].dtype
    self._validate_dtype(dtype)
    is_non_singular = self._validate_non_singular(is_non_singular)
    self._validate_num_operators()
    self._validate_operator_dimensions()
    is_square = self._validate_square(is_square)
    with ops.name_scope(name):
      super(LinearOperatorBlockLowerTriangular, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _validate_num_operators(self):
    for i, row in enumerate(self.operators):
      if len(row) != i + 1:
        raise ValueError(
            f"Argument `operators[{i}]` must contain `{i + 1}` blocks. "
            f"Received: {len(row)} blocks.")

  def _validate_operator_dimensions(self):
    """Check that `operators` have compatible dimensions."""
    for i in range(1, len(self.operators)):
      for j in range(i):
        op = self.operators[i][j]

        # `above_op` is the operator directly above `op` in the blockwise
        # structure, in row partition `i-1`, column partition `j`. `op` should
        # have the same `domain_dimension` as `above_op`.
        above_op = self.operators[i - 1][j]

        # `right_op` is the operator to the right of `op` in the blockwise
        # structure, in row partition `i`, column partition `j+1`. `op` should
        # have the same `range_dimension` as `right_op`.
        right_op = self.operators[i][j + 1]

        if (op.domain_dimension is not None and
            above_op.domain_dimension is not None):
          if op.domain_dimension != above_op.domain_dimension:
            raise ValueError(f"Argument `operators[{i}][{j}].domain_dimension` "
                             f"({op.domain_dimension}) must be the same as "
                             f"`operators[{i-1}][{j}].domain_dimension` "
                             f"({above_op.domain_dimension}).")
        if (op.range_dimension is not None and
            right_op.range_dimension is not None):
          if op.range_dimension != right_op.range_dimension:
            raise ValueError(f"Argument `operators[{i}][{j}].range_dimension` "
                             f"({op.range_dimension}) must be the same as "
                             f"`operators[{i}][{j + 1}].range_dimension` "
                             f"({right_op.range_dimension}).")

  # pylint: disable=g-bool-id-comparison
  def _validate_non_singular(self, is_non_singular):
    if all(op.is_non_singular for op in self._diagonal_operators):
      if is_non_singular is False:
        raise ValueError(
            f"A blockwise lower-triangular operator with non-singular "
            f"operators on the main diagonal is always non-singular. "
            f"Expected argument `is_non_singular` to be True. "
            f"Received: {is_non_singular}.")
      return True
    if any(op.is_non_singular is False for op in self._diagonal_operators):
      if is_non_singular is True:
        raise ValueError(
            f"A blockwise lower-triangular operator with a singular operator "
            f"on the main diagonal is always singular. Expected argument "
            f"`is_non_singular` to be False. Received: {is_non_singular}.")
      return False

  def _validate_square(self, is_square):
    if is_square is False:
      raise ValueError(f"`LinearOperatorBlockLowerTriangular` must be square. "
                       f"Expected argument `is_square` to be True. "
                       f"Received: {is_square}.")
    for i, op in enumerate(self._diagonal_operators):
      if op.is_square is False:
        raise ValueError(
            f"Matrices on the diagonal (the final elements of each "
            f"row-partition in the `operators` list) must be square. Expected "
            f"argument `operators[{i}][-1].is_square` to be True. "
            f"Received: {op.is_square}.")
    return True
  # pylint: enable=g-bool-id-comparison

  def _validate_dtype(self, dtype):
    for i, row in enumerate(self.operators):
      for operator in row:
        if operator.dtype != dtype:
          name_type = [str((o.name, o.dtype)) for o in row]
          raise TypeError(
              "Expected all operators to have the same dtype.  Found {} in row "
              "{} and {} in row 0.".format(name_type, i, str(dtype)))

  @property
  def operators(self):
    return self._operators

  def _block_range_dimensions(self):
    return [op.range_dimension for op in self._diagonal_operators]

  def _block_domain_dimensions(self):
    return [op.domain_dimension for op in self._diagonal_operators]

  def _block_range_dimension_tensors(self):
    return [op.range_dimension_tensor() for op in self._diagonal_operators]

  def _block_domain_dimension_tensors(self):
    return [op.domain_dimension_tensor() for op in self._diagonal_operators]

  def _shape(self):
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0][0].batch_shape
    for row in self.operators[1:]:
      for operator in row:
        batch_shape = common_shapes.broadcast_shape(
            batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor_v2_with_dispatch(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    domain_dimension = sum(self._block_domain_dimension_tensors())
    range_dimension = sum(self._block_range_dimension_tensors())
    matrix_shape = array_ops.stack([domain_dimension, range_dimension])

    batch_shape = self.operators[0][0].batch_shape_tensor()
    for row in self.operators[1:]:
      for operator in row:
        batch_shape = array_ops.broadcast_dynamic_shape(
            batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

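    If `x` is a blockwise iterable, a blockwise list of results is returned.
    As a minimal sketch (assuming `operator`, `x0` and `x1` are the blockwise
    operator and compatibly-shaped blocks from the class docstring, not
    defined here):

    ```python
    # Blockwise input gives a blockwise result; the blocks concatenate to the
    # same value as operator.matmul(tf.concat([x0, x1], 0)).
    y0, y1 = operator.matmul([x0, x1])
    ```
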
    Args:
      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
        class docstring for definition of shape compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
        concatenate to `[..., M, R]`.
    """
    if isinstance(x, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `x` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      arg_dim = -1 if adjoint_arg else -2
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            x[i] = block
      else:
        x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
        self._check_input_dtype(x)
        op_dimension = (self.range_dimension if adjoint
                        else self.domain_dimension)
        op_dimension.assert_is_compatible_with(x.shape[arg_dim])
      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, x, arg_dim)
    if blockwise_arg:
      split_x = x
    else:
      split_dim = -1 if adjoint_arg else -2
      # Split input by columns if adjoint_arg is True, else rows.
      split_x = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          x, axis=split_dim)

    result_list = []
    # Iterate over row-partitions (i.e. column-partitions of the adjoint).
    if adjoint:
      for index in range(len(self.operators)):
        # Begin with the operator on the diagonal and apply it to the
        # respective `rhs` block.
        result = self.operators[index][index].matmul(
            split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)

        # Iterate top to bottom over the operators in the remainder of the
        # column-partition (i.e. left to right over the row-partition of the
        # adjoint), apply the operator to the respective `rhs` block and
        # accumulate the sum. For example, given the
        # `LinearOperatorBlockLowerTriangular`:
        #
        # op = [[A, 0, 0],
        #       [B, C, 0],
        #       [D, E, F]]
        #
        # if `index = 1`, the following loop calculates:
        # `y_1 = (C.matmul(x_1, adjoint=adjoint) +
        #         E.matmul(x_2, adjoint=adjoint))`,
        # where `x_1` and `x_2` are splits of `x`.
        for j in range(index + 1, len(self.operators)):
          result += self.operators[j][index].matmul(
              split_x[j], adjoint=adjoint, adjoint_arg=adjoint_arg)
        result_list.append(result)
    else:
      for row in self.operators:
        # Begin with the left-most operator in the row-partition and apply it
        # to the first `rhs` block.
        result = row[0].matmul(
            split_x[0], adjoint=adjoint, adjoint_arg=adjoint_arg)
        # Iterate left to right over the operators in the remainder of the row
        # partition, apply the operator to the respective `rhs` block, and
        # accumulate the sum.
        for j, operator in enumerate(row[1:]):
          result += operator.matmul(
              split_x[j + 1], adjoint=adjoint, adjoint_arg=adjoint_arg)
        result_list.append(result)

    if blockwise_arg:
      return result_list

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

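    As with `matmul`, `x` may be a blockwise iterable of [batch] vectors, in
    which case a blockwise list of vectors is returned. A minimal sketch,
    assuming `operator` is a blockwise operator with two row-partitions and
    `x0`, `x1` are hypothetical vectors matching the block domain dimensions:

    ```python
    y0, y1 = operator.matvec([x0, x1])  # blockwise input, blockwise output
    ```
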
    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`, or an
        iterable of `Tensor`s. `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension defines
        a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`, or if `x` is
        blockwise, a list of `Tensor`s with shapes that concatenate to
        `[..., M]`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            x[i] = block
        x_mat = [block[..., array_ops.newaxis] for block in x]
        y_mat = self.matmul(x_mat, adjoint=adjoint)
        return [array_ops.squeeze(y, axis=-1) for y in y_mat]

      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(x.shape[-1])
      x_mat = x[..., array_ops.newaxis]
      y_mat = self.matmul(x_mat, adjoint=adjoint)
      return array_ops.squeeze(y_mat, axis=-1)

  def _determinant(self):
    if all(op.is_positive_definite for op in self._diagonal_operators):
      return math_ops.exp(self._log_abs_determinant())
    result = self._diagonal_operators[0].determinant()
    for op in self._diagonal_operators[1:]:
      result *= op.determinant()
    return result

  def _log_abs_determinant(self):
    result = self._diagonal_operators[0].log_abs_determinant()
    for op in self._diagonal_operators[1:]:
      result += op.log_abs_determinant()
    return result

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Given the blockwise `n + 1`-by-`n + 1` linear operator:

    op = [[A_00     0  ...     0  ...    0],
          [A_10  A_11  ...     0  ...    0],
          ...
          [A_k0  A_k1  ...  A_kk  ...    0],
          ...
          [A_n0  A_n1  ...  A_nk  ... A_nn]]

    we find `x = op.solve(y)` by observing that

    `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

    and therefore

    `x_k = A_kk.solve(y_k -
                      A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

    where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
    and `y` along their appropriate axes.

    We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
    for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

    The adjoint case is solved similarly, beginning with
    `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

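    A blockwise `rhs` is also accepted, in which case a blockwise list of
    solutions is returned. A minimal sketch, assuming `operator` is the 4 x 4
    blockwise operator from the class docstring and `rhs0`, `rhs1` are
    hypothetical compatibly-shaped blocks of a [batch] matrix `rhs`:

    ```python
    x0, x1 = operator.solve([rhs0, rhs1])
    # operator.matmul([x0, x1]) recovers [rhs0, rhs1], up to numerical error.
    ```
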
    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        like [batch] matrices, meaning for every set of leading dimensions, the
        last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`, or if `rhs`
        is blockwise, a list of `Tensor`s with shapes that concatenate to
        `[..., N, R]`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      arg_dim = -1 if adjoint_arg else -2
      blockwise_arg = linear_operator_util.arg_is_blockwise(
          block_dimensions, rhs, arg_dim)
      if blockwise_arg:
        for i, block in enumerate(rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            rhs[i] = block
        if adjoint_arg:
          split_rhs = [linalg.adjoint(y) for y in rhs]
        else:
          split_rhs = rhs

      else:
        rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
        self._check_input_dtype(rhs)
        op_dimension = (self.domain_dimension if adjoint
                        else self.range_dimension)
        op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])

        rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
        split_rhs = linear_operator_util.split_arg_into_blocks(
            self._block_domain_dimensions(),
            self._block_domain_dimension_tensors,
            rhs, axis=-2)

      solution_list = []
      if adjoint:
        # For an adjoint blockwise lower-triangular linear operator, the system
        # must be solved bottom to top. Iterate backwards over rows of the
        # adjoint (i.e. columns of the non-adjoint operator).
        for index in reversed(range(len(self.operators))):
          y = split_rhs[index]
          # Iterate top to bottom over the operators in the off-diagonal portion
          # of the column-partition (i.e. row-partition of the adjoint), apply
          # the operator to the respective block of the solution found in
          # previous iterations, and subtract the result from the `rhs` block.
          # For example, let `A`, `B`, and `D` be the linear operators in the
          # top row-partition of the adjoint of
          # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
          # and `x_1` and `x_2` be blocks of the solution found in previous
          # iterations of the outer loop. The following loop (when `index == 0`)
          # expresses
          # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
          # `y_0* = y_0 - Bx_1 - Dx_2`.
          for j in reversed(range(index + 1, len(self.operators))):
            y = y - self.operators[j][index].matmul(
                solution_list[len(self.operators) - 1 - j],
                adjoint=adjoint)
          # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
          solution_list.append(
              self._diagonal_operators[index].solve(y, adjoint=adjoint))
        solution_list.reverse()
      else:
        # Iterate top to bottom over the row-partitions.
        for row, y in zip(self.operators, split_rhs):
          # Iterate left to right over the operators in the off-diagonal portion
          # of the row-partition, apply the operator to the block of the
          # solution found in previous iterations, and subtract the result from
          # the `rhs` block. For example, let `D`, `E`, and `F` be the linear
          # operators in the bottom row-partition of
          # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
          # `x_0` and `x_1` be blocks of the solution found in previous
          # iterations of the outer loop. The following loop (when `index == 2`)
          # expresses
          # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
          # `y_2* = y_2 - Dx_0 - Ex_1`.
          for i, operator in enumerate(row[:-1]):
            y = y - operator.matmul(solution_list[i], adjoint=adjoint)
          # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
          solution_list.append(row[-1].solve(y, adjoint=adjoint))

      if blockwise_arg:
        return solution_list

      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
          solution_list)
      return array_ops.concat(solution_list, axis=-2)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

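    A blockwise `rhs` is also accepted. A minimal sketch, assuming `operator`
    is a blockwise operator with two row-partitions and `rhs0`, `rhs1` are
    hypothetical [batch] vectors matching the block dimensions:

    ```python
    x0, x1 = operator.solvevec([rhs0, rhs1])  # blockwise list of solutions
    ```
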
    Args:
      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension defines
        a vector.  See class docstring for definition of compatibility regarding
        batch dimensions.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`, or if `rhs` is
        blockwise, a list of `Tensor`s with shapes that concatenate to
        `[..., N]`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
        for i, block in enumerate(rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            rhs[i] = block
        rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
        solution_mat = self.solve(rhs_mat, adjoint=adjoint)
        return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(rhs.shape[-1])
      rhs_mat = array_ops.expand_dims(rhs, axis=-1)
      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
      return array_ops.squeeze(solution_mat, axis=-1)

  def _diag_part(self):
    diag_list = []
    for op in self._diagonal_operators:
      # Extend the axis, since `broadcast_matrix_batch_dims` treats all but the
      # final two dimensions as batch dimensions.
      diag_list.append(op.diag_part()[..., array_ops.newaxis])
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
    diagonal = array_ops.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)

  def _trace(self):
    result = self._diagonal_operators[0].trace()
    for op in self._diagonal_operators[1:]:
      result += op.trace()
    return result

  def _to_dense(self):
    num_cols = 0
    dense_rows = []
    flat_broadcast_operators = linear_operator_util.broadcast_matrix_batch_dims(
        [op.to_dense() for row in self.operators for op in row])  # pylint: disable=g-complex-comprehension
    broadcast_operators = [
        flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
        for i in range(len(self.operators))]
    for row_blocks in broadcast_operators:
      batch_row_shape = array_ops.shape(row_blocks[0])[:-1]
      num_cols += array_ops.shape(row_blocks[-1])[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=self.dtype)

      row_blocks.append(zeros_to_pad_after)
      dense_rows.append(array_ops.concat(row_blocks, axis=-1))

    mat = array_ops.concat(dense_rows, axis=-2)
    mat.set_shape(self.shape)
    return mat

  def _assert_non_singular(self):
    return control_flow_ops.group([
        op.assert_non_singular() for op in self._diagonal_operators])

  def _eigvals(self):
    eig_list = []
    for op in self._diagonal_operators:
      # Extend the axis for broadcasting.
      eig_list.append(op.eigvals()[..., array_ops.newaxis])
    eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
    eigs = array_ops.concat(eig_list, axis=-2)
    return array_ops.squeeze(eigs, axis=-1)

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # None of the operators contribute to the matrix shape.
    return {"operators": nest.map_structure(lambda _: 0, self.operators)}