# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a Block Diagonal operator from one or more `LinearOperators`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorBlockDiag"]


@tf_export("linalg.LinearOperatorBlockDiag")
@linear_operator.make_composite_tensor
class LinearOperatorBlockDiag(linear_operator.LinearOperator):
  """Combines one or more `LinearOperators` into a block-diagonal matrix.

  This operator combines one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator`, whose underlying matrix representation
  has each operator `opi` on the main diagonal, and zeros elsewhere.

  #### Shape compatibility

  If `opj` acts like a [batch] matrix `Aj`, then `op_combined` acts like
  the [batch] matrix formed by having each matrix `Aj` on the main
  diagonal.

  Each `opj` is required to represent a matrix, and hence will have
  shape `batch_shape_j + [M_j, N_j]`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the combined operator
  has shape `broadcast_batch_shape + [sum M_j, sum N_j]`, where
  `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`,
  `j = 1,...,J`, assuming the intermediate batch shapes broadcast.

  Arguments to `matmul`, `matvec`, `solve`, and `solvevec` may either be single
  `Tensor`s or lists of `Tensor`s that are interpreted as blocks. The `j`th
  element of a blockwise list of `Tensor`s must have dimensions that match
  `opj` for the given method. If a list of blocks is input, then a list of
  blocks is returned as well.

  When the `opj` are not guaranteed to be square, this operator's methods might
  fail due to the combined operator not being square and/or lack of efficient
  methods.

  ```python
  # Create a 4 x 4 linear operator combining two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorBlockDiag([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2., 0., 0.],
       [3., 4., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x1 = ... # Shape [2, 2] Tensor
  x2 = ... # Shape [2, 2] Tensor
  x = tf.concat([x1, x2], 0)  # Shape [4, 2] Tensor
  operator.matmul(x)
  ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)], 0)

  # Create a 5 x 4 linear operator combining three blocks.
  operator_1 = LinearOperatorFullMatrix([[1.], [3.]])
  operator_2 = LinearOperatorFullMatrix([[1., 6.]])
  operator_3 = LinearOperatorFullMatrix([[2.], [7.]])
  operator = LinearOperatorBlockDiag([operator_1, operator_2, operator_3])

  operator.to_dense()
  ==> [[1., 0., 0., 0.],
       [3., 0., 0., 0.],
       [0., 1., 6., 0.],
       [0., 0., 0., 2.],
       [0., 0., 0., 7.]]

  operator.shape
  ==> [5, 4]


  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix_44 = tf.random.normal(shape=[2, 3, 4, 4])
  operator_44 = LinearOperatorFullMatrix(matrix_44)

  # Create a [1, 3] batch of 5 x 5 linear operators.
  matrix_55 = tf.random.normal(shape=[1, 3, 5, 5])
  operator_55 = LinearOperatorFullMatrix(matrix_55)

  # Combine to create a [2, 3] batch of 9 x 9 operators.
  operator_99 = LinearOperatorBlockDiag([operator_44, operator_55])

  # Create a shape [2, 3, 9] vector.
  x = tf.random.normal(shape=[2, 3, 9])
  operator_99.matvec(x)
  ==> Shape [2, 3, 9] Tensor

  # Create a blockwise list of vectors.
  x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])]
  operator_99.matvec(x)
  ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor]
  ```
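
  Blockwise inputs also work with `solve` and `solvevec` when the combined
  operator is square; a sketch continuing the example above:

  ```python
  rhs = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])]
  operator_99.solvevec(rhs)
  ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor]
  ```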

  #### Performance

  The cost of any operation on `LinearOperatorBlockDiag` is the sum of the
  costs of the corresponding operation on each of the individual operators.
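
  For example, `determinant` is computed as the product of the blocks'
  determinants, and `log_abs_determinant` as the sum of the blocks'
  `log_abs_determinant`s.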


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
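
  A hint that holds for every block is propagated to the combined operator.
  A minimal sketch:

  ```python
  operator = LinearOperatorBlockDiag(
      [LinearOperatorFullMatrix([[1., 0.], [0., 1.]], is_self_adjoint=True),
       LinearOperatorFullMatrix([[2.]], is_self_adjoint=True)])
  operator.is_self_adjoint
  ==> True
  ```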
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               name=None):
    r"""Initialize a `LinearOperatorBlockDiag`.

    `LinearOperatorBlockDiag` is initialized with a list of operators
    `[op_1,...,op_J]`.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        This is `True` by default; a `ValueError` is raised otherwise.
      name: A name for this `LinearOperator`.  Default is the individual
        operators' names joined with `_ds_`.

    Raises:
      TypeError:  If the operators do not all have the same `dtype`.
      ValueError:  If `operators` is empty or the operators are non-square.
    """
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Define diagonal operators, for functions that are shared across blockwise
    # `LinearOperator` types.
    self._diagonal_operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            "Expected all operators to have the same dtype.  Found %s"
            % "   ".join(name_type))

    # Auto-set and check hints.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            "The direct sum of non-singular operators is always non-singular.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            "The direct sum of self-adjoint operators is always self-adjoint.")
      is_self_adjoint = True

    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            "The direct sum of positive definite operators is always "
            "positive definite.")
      is_positive_definite = True

    # Initialization.
    graph_parents = []
    for operator in operators:
      graph_parents.extend(operator.graph_parents)

    if name is None:
      # Using ds to mean direct sum.
      name = "_ds_".join(operator.name for operator in operators)
    with ops.name_scope(name, values=graph_parents):
      super(LinearOperatorBlockDiag, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

    # TODO(b/143910018) Remove graph_parents in V3.
    self._set_graph_parents(graph_parents)

  @property
  def operators(self):
    return self._operators

  def _block_range_dimensions(self):
    return [op.range_dimension for op in self._diagonal_operators]

  def _block_domain_dimensions(self):
    return [op.domain_dimension for op in self._diagonal_operators]

  def _block_range_dimension_tensors(self):
    return [op.range_dimension_tensor() for op in self._diagonal_operators]

  def _block_domain_dimension_tensors(self):
    return [op.domain_dimension_tensor() for op in self._diagonal_operators]

  def _shape(self):
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor_v2_with_dispatch(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    domain_dimension = sum(self._block_domain_dimension_tensors())
    range_dimension = sum(self._block_range_dimension_tensors())
    matrix_shape = array_ops.stack([range_dimension, domain_dimension])

    # Dummy Tensor of zeros.  Will never be materialized.
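    # Summing zeros of every operator's batch shape broadcasts those shapes
    # together, so the shape of the result is the combined batch shape.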
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)

  # TODO(b/188080761): Add a more efficient implementation of `cond` that
  # constructs the condition number from the blockwise singular values.

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```
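
    For a blockwise `x`, pass one block per operator; a sketch reusing the
    two-block `operator` from the class docstring:

    ```python
    x_blocks = [x1, x2]  # block j has shape [..., N_j, R]
    operator.matmul(x_blocks)
    ==> [operator_1.matmul(x1), operator_2.matmul(x2)]
    ```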

    Args:
      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
        class docstring for definition of shape compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
        concatenate to `[..., M, R]`.
    """
    def _check_operators_agree(r, l, message):
      if (r.range_dimension is not None and
          l.domain_dimension is not None and
          r.range_dimension != l.domain_dimension):
        raise ValueError(message)

    if isinstance(x, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      _check_operators_agree(
          right_operator, left_operator,
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))

      # We can efficiently multiply BlockDiag LinearOperators if the number of
      # blocks agree.
      if isinstance(x, LinearOperatorBlockDiag):
        if len(left_operator.operators) != len(right_operator.operators):
          raise ValueError(
              "Can not efficiently multiply two `LinearOperatorBlockDiag`s "
              "together when number of blocks differ.")

        for o1, o2 in zip(left_operator.operators, right_operator.operators):
          _check_operators_agree(
              o2, o1,
              "Blocks are incompatible. Expected `x` to have dimension"
              " {} but got {}.".format(
                  o1.domain_dimension, o2.range_dimension))

      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      arg_dim = -1 if adjoint_arg else -2
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            x[i] = block
      else:
        x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
        self._check_input_dtype(x)
        op_dimension = (self.range_dimension if adjoint
                        else self.domain_dimension)
        op_dimension.assert_is_compatible_with(x.shape[arg_dim])
      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    block_dimensions_fn = (
        self._block_range_dimension_tensors if adjoint
        else self._block_domain_dimension_tensors)
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, x, arg_dim)
    if blockwise_arg:
      split_x = x

    else:
      split_dim = -1 if adjoint_arg else -2
      # Split the input into blocks along rows (axis -2), or along
      # columns (axis -1) if adjoint_arg is True.
      split_x = linear_operator_util.split_arg_into_blocks(
          block_dimensions, block_dimensions_fn, x, axis=split_dim)

    result_list = []
    for index, operator in enumerate(self.operators):
      result_list += [operator.matmul(
          split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

    if blockwise_arg:
      return result_list

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```
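
    A blockwise `x` is transformed block by block; a minimal sketch:

    ```python
    x = [x1, x2]            # block j has shape [..., N_j]
    operator.matvec(x)
    ==> a list; block j has shape [..., M_j]
    ```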

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`, or an
        iterable of `Tensor`s (for blockwise operators). `Tensor`s are treated
        as [batch] vectors, meaning for every set of leading dimensions, the
        last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`, or a list of
        `Tensor`s (for blockwise operators).
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            x[i] = block
        x_mat = [block[..., array_ops.newaxis] for block in x]
        y_mat = self.matmul(x_mat, adjoint=adjoint)
        return [array_ops.squeeze(y, axis=-1) for y in y_mat]

      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(x.shape[-1])
      x_mat = x[..., array_ops.newaxis]
      y_mat = self.matmul(x_mat, adjoint=adjoint)
      return array_ops.squeeze(y_mat, axis=-1)

  def _determinant(self):
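    # The determinant of a block-diagonal matrix is the product of the
    # blocks' determinants.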
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```
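
    A blockwise `rhs` is solved block by block; a minimal sketch with
    hypothetical blocks `rhs_1` and `rhs_2`:

    ```python
    rhs = [rhs_1, rhs_2]  # block j has shape [..., M_j, R]
    operator.solve(rhs)
    ==> [operator_1.solve(rhs_1), operator_2.solve(rhs_2)]
    ```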

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        like [batch] matrices, meaning for every set of leading dimensions, the
        last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`, or a list of
        `Tensor`s (for blockwise operators).

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")

    def _check_operators_agree(r, l, message):
      if (r.range_dimension is not None and
          l.domain_dimension is not None and
          r.range_dimension != l.domain_dimension):
        raise ValueError(message)

    if isinstance(rhs, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      _check_operators_agree(
          right_operator, left_operator,
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))

      # We can efficiently solve BlockDiag LinearOperators if the number of
      # blocks agree.
      if isinstance(right_operator, LinearOperatorBlockDiag):
        if len(left_operator.operators) != len(right_operator.operators):
          raise ValueError(
              "Can not efficiently solve `LinearOperatorBlockDiag` when "
              "number of blocks differ.")

        for o1, o2 in zip(left_operator.operators, right_operator.operators):
          _check_operators_agree(
              o2, o1,
              "Blocks are incompatible. Expected `rhs` to have dimension"
              " {} but got {}.".format(
                  o1.domain_dimension, o2.range_dimension))

      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      arg_dim = -1 if adjoint_arg else -2
      blockwise_arg = linear_operator_util.arg_is_blockwise(
          block_dimensions, rhs, arg_dim)

      if blockwise_arg:
        split_rhs = rhs
        for i, block in enumerate(split_rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            split_rhs[i] = block
      else:
        rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
        self._check_input_dtype(rhs)
        op_dimension = (self.domain_dimension if adjoint
                        else self.range_dimension)
        op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
        split_dim = -1 if adjoint_arg else -2
        # Split the input into blocks along rows (axis -2), or along
        # columns (axis -1) if adjoint_arg is True.
        split_rhs = linear_operator_util.split_arg_into_blocks(
            self._block_domain_dimensions(),
            self._block_domain_dimension_tensors,
            rhs, axis=split_dim)

      solution_list = []
      for index, operator in enumerate(self.operators):
        solution_list += [operator.solve(
            split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

      if blockwise_arg:
        return solution_list

      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
          solution_list)
      return array_ops.concat(solution_list, axis=-2)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension defines
        a vector.  See class docstring for definition of compatibility regarding
        batch dimensions.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`, or a list of
        `Tensor`s (for blockwise operators).

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
        for i, block in enumerate(rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            rhs[i] = block
        rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
        solution_mat = self.solve(rhs_mat, adjoint=adjoint)
        return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(rhs.shape[-1])
      rhs_mat = array_ops.expand_dims(rhs, axis=-1)
      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
      return array_ops.squeeze(solution_mat, axis=-1)

  def _diag_part(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`diag_part` not implemented for an operator whose blocks are not "
          "square.")
    diag_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      diag_list += [operator.diag_part()[..., array_ops.newaxis]]
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
    diagonal = array_ops.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)

  def _trace(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`trace` not implemented for an operator whose blocks are not "
          "square.")
    result = self.operators[0].trace()
    for operator in self.operators[1:]:
      result += operator.trace()
    return result

  def _to_dense(self):
    num_cols = 0
    rows = []
    broadcasted_blocks = [operator.to_dense() for operator in self.operators]
    broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
        broadcasted_blocks)
    for block in broadcasted_blocks:
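      # Each block contributes a slab of rows: zeros for the columns of the
      # blocks to its left, then the block itself, then zeros for the
      # remaining columns.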
      batch_row_shape = array_ops.shape(block)[:-1]

      zeros_to_pad_before_shape = array_ops.concat(
          [batch_row_shape, [num_cols]], axis=-1)
      zeros_to_pad_before = array_ops.zeros(
          shape=zeros_to_pad_before_shape, dtype=block.dtype)
      num_cols += array_ops.shape(block)[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=block.dtype)

      rows.append(array_ops.concat(
          [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

    mat = array_ops.concat(rows, axis=-2)
    mat.set_shape(self.shape)
    return mat

  def _assert_non_singular(self):
    return control_flow_ops.group([
        operator.assert_non_singular() for operator in self.operators])

  def _assert_self_adjoint(self):
    return control_flow_ops.group([
        operator.assert_self_adjoint() for operator in self.operators])

  def _assert_positive_definite(self):
    return control_flow_ops.group([
        operator.assert_positive_definite() for operator in self.operators])

  def _eigvals(self):
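    # The eigenvalues of a block-diagonal operator are the concatenation of
    # the blocks' eigenvalues.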
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`eigvals` not implemented for an operator whose blocks are not "
          "square.")
    eig_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      eig_list += [operator.eigvals()[..., array_ops.newaxis]]
    eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
    eigs = array_ops.concat(eig_list, axis=-2)
    return array_ops.squeeze(eigs, axis=-1)

  @property
  def _composite_tensor_fields(self):
    return ("operators",)