• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Create a Block Diagonal operator from one or more `LinearOperators`."""
16
17from tensorflow.python.framework import common_shapes
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import check_ops
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops.linalg import linear_operator
25from tensorflow.python.ops.linalg import linear_operator_algebra
26from tensorflow.python.ops.linalg import linear_operator_util
27from tensorflow.python.util.tf_export import tf_export
28
# Names exported by `from ... import *`; only the block-diagonal operator is
# public in this module.
__all__ = [
    "LinearOperatorBlockDiag",
]
30
31
@tf_export("linalg.LinearOperatorBlockDiag")
@linear_operator.make_composite_tensor
class LinearOperatorBlockDiag(linear_operator.LinearOperator):
  """Combines one or more `LinearOperators` into a Block Diagonal matrix.

  This operator combines one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator`, whose underlying matrix representation
  has each operator `opi` on the main diagonal, and zeros elsewhere.

  #### Shape compatibility

  If `opj` acts like a [batch] matrix `Aj`, then `op_combined` acts like
  the [batch] matrix formed by having each matrix `Aj` on the main
  diagonal.

  Each `opj` is required to represent a matrix, and hence will have
  shape `batch_shape_j + [M_j, N_j]`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the combined operator
  has shape `broadcast_batch_shape + [sum M_j, sum N_j]`, where
  `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`,
  `j = 1,...,J`, assuming the intermediate batch shapes broadcast.

  Arguments to `matmul`, `matvec`, `solve`, and `solvevec` may either be single
  `Tensor`s or lists of `Tensor`s that are interpreted as blocks. The `j`th
  element of a blockwise list of `Tensor`s must have dimensions that match
  `opj` for the given method. If a list of blocks is input, then a list of
  blocks is returned as well. The input list itself is not modified.

  When the `opj` are not guaranteed to be square, this operator's methods might
  fail due to the combined operator not being square and/or lack of efficient
  methods.

  ```python
  # Create a 4 x 4 linear operator combined of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorBlockDiag([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2., 0., 0.],
       [3., 4., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x1 = ... # Shape [2, 2] Tensor
  x2 = ... # Shape [2, 2] Tensor
  x = tf.concat([x1, x2], 0)  # Shape [4, 2] Tensor
  operator.matmul(x)
  ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)], 0)

  # Create a 5 x 4 linear operator combining three blocks.  The combined
  # operator is not square, so the default `is_square=True` hint is overridden.
  operator_1 = LinearOperatorFullMatrix([[1.], [3.]])
  operator_2 = LinearOperatorFullMatrix([[1., 6.]])
  operator_3 = LinearOperatorFullMatrix([[2.], [7.]])
  operator = LinearOperatorBlockDiag(
      [operator_1, operator_2, operator_3], is_square=False)

  operator.to_dense()
  ==> [[1., 0., 0., 0.],
       [3., 0., 0., 0.],
       [0., 1., 6., 0.],
       [0., 0., 0., 2.],
       [0., 0., 0., 7.]]

  operator.shape
  ==> [5, 4]


  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix_44 = tf.random.normal(shape=[2, 3, 4, 4])
  operator_44 = LinearOperatorFullMatrix(matrix_44)

  # Create a [1, 3] batch of 5 x 5 linear operators.
  matrix_55 = tf.random.normal(shape=[1, 3, 5, 5])
  operator_55 = LinearOperatorFullMatrix(matrix_55)

  # Combine to create a [2, 3] batch of 9 x 9 operators.
  operator_99 = LinearOperatorBlockDiag([operator_44, operator_55])

  # Create a shape [2, 3, 9] vector.
  x = tf.random.normal(shape=[2, 3, 9])
  operator_99.matvec(x)
  ==> Shape [2, 3, 9] Tensor

  # Create a blockwise list of vectors.
  x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])]
  operator_99.matvec(x)
  ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor]
  ```

  #### Performance

  The performance of `LinearOperatorBlockDiag` on any operation is equal to
  the sum of the individual operators' operations.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               name=None):
    r"""Initialize a `LinearOperatorBlockDiag`.

    `LinearOperatorBlockDiag` is initialized with a list of operators
    `[op_1,...,op_J]`.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        Defaults to `True`.  This hint is not checked against the block shapes
        here, so pass `False` when the combined operator is not square.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_ds_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if `is_non_singular`,
        `is_self_adjoint` or `is_positive_definite` is `False` while every
        operator has that property.
    """
    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    if not isinstance(operators, (list, tuple)):
      # Materialize one-shot iterables (e.g. generators) before recording them
      # in `parameters`; otherwise the recorded value would be an exhausted
      # iterator once `list(operators)` below has consumed it.
      operators = list(operators)

    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Define diagonal operators, for functions that are shared across blockwise
    # `LinearOperator` types.
    self._diagonal_operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            "Expected all operators to have the same dtype.  Found %s"
            % "   ".join(name_type))

    # Auto-set and check hints.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            "The direct sum of non-singular operators is always non-singular.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            "The direct sum of self-adjoint operators is always self-adjoint.")
      is_self_adjoint = True

    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            "The direct sum of positive definite operators is always "
            "positive definite.")
      is_positive_definite = True

    if name is None:
      # Using ds to mean direct sum.
      name = "_ds_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorBlockDiag, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    return self._operators

  def _block_range_dimensions(self):
    return [op.range_dimension for op in self._diagonal_operators]

  def _block_domain_dimensions(self):
    return [op.domain_dimension for op in self._diagonal_operators]

  def _block_range_dimension_tensors(self):
    return [op.range_dimension_tensor() for op in self._diagonal_operators]

  def _block_domain_dimension_tensors(self):
    return [op.domain_dimension_tensor() for op in self._diagonal_operators]

  @staticmethod
  def _check_operators_agree(r, l, message):
    """Raises `ValueError(message)` if `r`'s range differs from `l`'s domain.

    The check is skipped when either dimension is `None`.
    """
    if (r.range_dimension is not None and
        l.domain_dimension is not None and
        r.range_dimension != l.domain_dimension):
      raise ValueError(message)

  def _convert_blockwise_arg(self, blocks, block_dimensions, arg_dim):
    """Returns a new list of `blocks` with non-operator blocks as `Tensor`s.

    Each block that is not a `LinearOperator` is converted to a `Tensor`, its
    dtype is checked against this operator, and its `arg_dim` dimension is
    checked for compatibility with the matching entry of `block_dimensions`.
    A fresh list is returned so the caller's sequence is never mutated (and
    immutable sequences such as tuples are accepted).

    Args:
      blocks: Sequence of blockwise arguments, one per diagonal block.
      block_dimensions: Static dimensions each block must be compatible with.
      arg_dim: Axis of each block that is checked against `block_dimensions`.

    Returns:
      A new `list` holding the converted blocks.
    """
    converted = []
    for i, block in enumerate(blocks):
      if not isinstance(block, linear_operator.LinearOperator):
        block = ops.convert_to_tensor_v2_with_dispatch(block)
        self._check_input_dtype(block)
        block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
      converted.append(block)
    return converted

  def _shape(self):
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor_v2_with_dispatch(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    domain_dimension = sum(self._block_domain_dimension_tensors())
    range_dimension = sum(self._block_range_dimension_tensors())
    matrix_shape = array_ops.stack([range_dimension, domain_dimension])

    # Dummy Tensor of zeros.  Will never be materialized.
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)

  # TODO(b/188080761): Add a more efficient implementation of `cond` that
  # constructs the condition number from the blockwise singular values.

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
        class docstring for definition of shape compatibility.  A blockwise
        input is not modified.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
        concatenate to `[..., M, R]`.

    Raises:
      ValueError: If `x` is a `LinearOperator` whose dimensions are known to be
        incompatible with this operator, or a `LinearOperatorBlockDiag` with a
        different number of blocks.
    """
    if isinstance(x, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      self._check_operators_agree(
          right_operator, left_operator,
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))

      # We can efficiently multiply BlockDiag LinearOperators if the number of
      # blocks agree.  Check the operator whose `.operators` are read below.
      if isinstance(right_operator, LinearOperatorBlockDiag):
        if len(left_operator.operators) != len(right_operator.operators):
          raise ValueError(
              "Can not efficiently multiply two `LinearOperatorBlockDiag`s "
              "together when number of blocks differ.")

        for o1, o2 in zip(left_operator.operators, right_operator.operators):
          self._check_operators_agree(
              o2, o1,
              "Blocks are incompatible. Expected `x` to have dimension"
              " {} but got {}.".format(
                  o1.domain_dimension, o2.range_dimension))

      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      arg_dim = -1 if adjoint_arg else -2
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
        x = self._convert_blockwise_arg(x, block_dimensions, arg_dim)
      else:
        x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
        self._check_input_dtype(x)
        op_dimension = (self.range_dimension if adjoint
                        else self.domain_dimension)
        op_dimension.assert_is_compatible_with(x.shape[arg_dim])
      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Rows of `x` are split normally, columns when `adjoint_arg` is set.
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    block_dimensions_fn = (
        self._block_range_dimension_tensors if adjoint
        else self._block_domain_dimension_tensors)
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, x, arg_dim)
    if blockwise_arg:
      split_x = x
    else:
      split_x = linear_operator_util.split_arg_into_blocks(
          block_dimensions, block_dimensions_fn, x, axis=arg_dim)

    result_list = [
        operator.matmul(split_x[index], adjoint=adjoint,
                        adjoint_arg=adjoint_arg)
        for index, operator in enumerate(self.operators)]

    if blockwise_arg:
      return result_list

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`, or an
        iterable of `Tensor`s (for blockwise operators). `Tensor`s are treated
        as [batch] vectors, meaning for every set of leading dimensions, the
        last dimension defines a vector.  A blockwise input is not modified.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`, or a list of
      `Tensor`s if `x` is blockwise.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
        x = self._convert_blockwise_arg(x, block_dimensions, -1)
        x_mat = [block[..., array_ops.newaxis] for block in x]
        y_mat = self.matmul(x_mat, adjoint=adjoint)
        return [array_ops.squeeze(y, axis=-1) for y in y_mat]

      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(x.shape[-1])
      x_mat = x[..., array_ops.newaxis]
      y_mat = self.matmul(x_mat, adjoint=adjoint)
      return array_ops.squeeze(y_mat, axis=-1)

  def _determinant(self):
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        like [batch] matrices meaning for every set of leading dimensions, the
        last two dimensions define a matrix.  A blockwise input is not
        modified.  See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`, or a list of
      `Tensor`s if `rhs` is blockwise.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
      ValueError: If `rhs` is a `LinearOperator` whose dimensions are known to
        be incompatible with this operator, or a `LinearOperatorBlockDiag`
        with a different number of blocks.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")

    if isinstance(rhs, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      self._check_operators_agree(
          right_operator, left_operator,
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))

      # We can efficiently solve BlockDiag LinearOperators if the number of
      # blocks agree.
      if isinstance(right_operator, LinearOperatorBlockDiag):
        if len(left_operator.operators) != len(right_operator.operators):
          raise ValueError(
              "Can not efficiently solve `LinearOperatorBlockDiag` when "
              "number of blocks differ.")

        for o1, o2 in zip(left_operator.operators, right_operator.operators):
          self._check_operators_agree(
              o2, o1,
              "Blocks are incompatible. Expected `x` to have dimension"
              " {} but got {}.".format(
                  o1.domain_dimension, o2.range_dimension))

      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      # `rhs` lives in the range of A (or the domain, for the adjoint system).
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      block_dimensions_fn = (self._block_domain_dimension_tensors if adjoint
                             else self._block_range_dimension_tensors)
      arg_dim = -1 if adjoint_arg else -2
      blockwise_arg = linear_operator_util.arg_is_blockwise(
          block_dimensions, rhs, arg_dim)

      if blockwise_arg:
        split_rhs = self._convert_blockwise_arg(rhs, block_dimensions, arg_dim)
      else:
        rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
        self._check_input_dtype(rhs)
        op_dimension = (self.domain_dimension if adjoint
                        else self.range_dimension)
        op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
        # Split input by rows normally, and otherwise columns, using the same
        # block dimensions that `rhs` was validated against above.
        split_rhs = linear_operator_util.split_arg_into_blocks(
            block_dimensions, block_dimensions_fn, rhs, axis=arg_dim)

      solution_list = [
          operator.solve(split_rhs[index], adjoint=adjoint,
                         adjoint_arg=adjoint_arg)
          for index, operator in enumerate(self.operators)]

      if blockwise_arg:
        return solution_list

      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
          solution_list)
      return array_ops.concat(solution_list, axis=-2)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension defines
        a vector.  A blockwise input is not modified.  See class docstring for
        definition of compatibility regarding batch dimensions.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[...,N]` and same `dtype` as `rhs`, or a list of
      `Tensor`s if `rhs` is blockwise.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
        rhs = self._convert_blockwise_arg(rhs, block_dimensions, -1)
        rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
        solution_mat = self.solve(rhs_mat, adjoint=adjoint)
        return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(rhs.shape[-1])
      rhs_mat = array_ops.expand_dims(rhs, axis=-1)
      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
      return array_ops.squeeze(solution_mat, axis=-1)

  def _diag_part(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`diag_part` not implemented for an operator whose blocks are not "
          "square.")
    diag_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      diag_list += [operator.diag_part()[..., array_ops.newaxis]]
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
    diagonal = array_ops.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)

  def _trace(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`trace` not implemented for an operator whose blocks are not "
          "square.")
    result = self.operators[0].trace()
    for operator in self.operators[1:]:
      result += operator.trace()
    return result

  def _to_dense(self):
    num_cols = 0
    rows = []
    broadcasted_blocks = [operator.to_dense() for operator in self.operators]
    broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
        broadcasted_blocks)
    # Total column count of the dense matrix; loop-invariant, so compute once.
    total_cols = self.domain_dimension_tensor()
    for block in broadcasted_blocks:
      batch_row_shape = array_ops.shape(block)[:-1]

      # Each block row is [zeros | block | zeros], padded out to `total_cols`.
      zeros_to_pad_before_shape = array_ops.concat(
          [batch_row_shape, [num_cols]], axis=-1)
      zeros_to_pad_before = array_ops.zeros(
          shape=zeros_to_pad_before_shape, dtype=block.dtype)
      num_cols += array_ops.shape(block)[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape, [total_cols - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=block.dtype)

      rows.append(array_ops.concat(
          [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

    mat = array_ops.concat(rows, axis=-2)
    mat.set_shape(self.shape)
    return mat

  def _assert_non_singular(self):
    return control_flow_ops.group([
        operator.assert_non_singular() for operator in self.operators])

  def _assert_self_adjoint(self):
    return control_flow_ops.group([
        operator.assert_self_adjoint() for operator in self.operators])

  def _assert_positive_definite(self):
    return control_flow_ops.group([
        operator.assert_positive_definite() for operator in self.operators])

  def _eigvals(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`eigvals` not implemented for an operator whose blocks are not "
          "square.")
    eig_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      eig_list += [operator.eigvals()[..., array_ops.newaxis]]
    eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
    eigs = array_ops.concat(eig_list, axis=-2)
    return array_ops.squeeze(eigs, axis=-1)

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"operators": [0] * len(self.operators)}