# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a Block Diagonal operator from one or more `LinearOperators`."""

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorBlockDiag"]


@tf_export("linalg.LinearOperatorBlockDiag")
@linear_operator.make_composite_tensor
class LinearOperatorBlockDiag(linear_operator.LinearOperator):
  """Combines one or more `LinearOperators` into a Block Diagonal matrix.

  This operator combines one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator`, whose underlying matrix representation
  has each operator `opi` on the main diagonal, and zeros elsewhere.

  #### Shape compatibility

  If `opj` acts like a [batch] matrix `Aj`, then `op_combined` acts like
  the [batch] matrix formed by having each matrix `Aj` on the main
  diagonal.

  Each `opj` is required to represent a matrix, and hence will have
  shape `batch_shape_j + [M_j, N_j]`.

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the combined operator
  has shape `broadcast_batch_shape + [sum M_j, sum N_j]`, where
  `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`,
  `j = 1,...,J`, assuming the intermediate batch shapes broadcast.

  Arguments to `matmul`, `matvec`, `solve`, and `solvevec` may either be single
  `Tensor`s or lists of `Tensor`s that are interpreted as blocks. The `j`th
  element of a blockwise list of `Tensor`s must have dimensions that match
  `opj` for the given method. If a list of blocks is input, then a list of
  blocks is returned as well.

  When the `opj` are not guaranteed to be square, this operator's methods might
  fail due to the combined operator not being square and/or lack of efficient
  methods.

  ```python
  # Create a 4 x 4 linear operator combining two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  operator = LinearOperatorBlockDiag([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 2., 0., 0.],
       [3., 4., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x1 = ...  # Shape [2, 2] Tensor
  x2 = ...  # Shape [2, 2] Tensor
  x = tf.concat([x1, x2], 0)  # Shape [4, 2] Tensor
  operator.matmul(x)
  ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)], 0)
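
  # An added aside, not part of the original example: `solve` is also applied
  # block by block, so for these square, non-singular blocks
  operator.solve(x)
  ==> tf.concat([operator_1.solve(x1), operator_2.solve(x2)], 0)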

  # Create a 5 x 4 linear operator combining three blocks.
  operator_1 = LinearOperatorFullMatrix([[1.], [3.]])
  operator_2 = LinearOperatorFullMatrix([[1., 6.]])
  operator_3 = LinearOperatorFullMatrix([[2.], [7.]])
  operator = LinearOperatorBlockDiag([operator_1, operator_2, operator_3])

  operator.to_dense()
  ==> [[1., 0., 0., 0.],
       [3., 0., 0., 0.],
       [0., 1., 6., 0.],
       [0., 0., 0., 2.],
       [0., 0., 0., 7.]]

  operator.shape
  ==> [5, 4]


  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix_44 = tf.random.normal(shape=[2, 3, 4, 4])
  operator_44 = LinearOperatorFullMatrix(matrix_44)

  # Create a [1, 3] batch of 5 x 5 linear operators.
  matrix_55 = tf.random.normal(shape=[1, 3, 5, 5])
  operator_55 = LinearOperatorFullMatrix(matrix_55)

  # Combine to create a [2, 3] batch of 9 x 9 operators.
  operator_99 = LinearOperatorBlockDiag([operator_44, operator_55])

  # Create a shape [2, 3, 9] vector.
  x = tf.random.normal(shape=[2, 3, 9])
  operator_99.matvec(x)
  ==> Shape [2, 3, 9] Tensor

  # Create a blockwise list of vectors.
  x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])]
  operator_99.matvec(x)
  ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor]
  ```

  #### Performance

  The performance of `LinearOperatorBlockDiag` on any operation is equal to
  the sum of the individual operators' operations.


  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
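
  The following is a minimal sketch of hint propagation (it assumes
  `LinearOperatorDiag` blocks; the hints are combined as described above and
  are never verified numerically):

  ```python
  d1 = tf.linalg.LinearOperatorDiag([1., 2.], is_non_singular=True)
  d2 = tf.linalg.LinearOperatorDiag([3.], is_non_singular=True)
  block = tf.linalg.LinearOperatorBlockDiag([d1, d2])
  block.is_non_singular  # ==> True, inferred from the blocks.
  ```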
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               name=None):
    r"""Initialize a `LinearOperatorBlockDiag`.

    `LinearOperatorBlockDiag` is initialized with a list of operators
    `[op_1,...,op_J]`.

    Args:
      operators: Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape.
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
        This is true by default, and will raise a `ValueError` otherwise.
      name: A name for this `LinearOperator`. Default is the individual
        operators' names joined with `_ds_`.

    Raises:
      TypeError: If the operators do not all have the same `dtype`.
      ValueError: If `operators` is empty, or if they are non-square.
    """
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(
          "Expected a non-empty list of operators. Found: %s" % operators)
    self._operators = operators

    # Define diagonal operators, for functions that are shared across blockwise
    # `LinearOperator` types.
    self._diagonal_operators = operators

    # Validate dtype.
    dtype = operators[0].dtype
    for operator in operators:
      if operator.dtype != dtype:
        name_type = (str((o.name, o.dtype)) for o in operators)
        raise TypeError(
            "Expected all operators to have the same dtype. Found %s"
            % " ".join(name_type))

    # Auto-set and check hints.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            "The direct sum of non-singular operators is always non-singular.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            "The direct sum of self-adjoint operators is always self-adjoint.")
      is_self_adjoint = True

    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            "The direct sum of positive definite operators is always "
            "positive definite.")
      is_positive_definite = True

    if name is None:
      # Using ds to mean direct sum.
      name = "_ds_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorBlockDiag, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    return self._operators

  def _block_range_dimensions(self):
    return [op.range_dimension for op in self._diagonal_operators]

  def _block_domain_dimensions(self):
    return [op.domain_dimension for op in self._diagonal_operators]

  def _block_range_dimension_tensors(self):
    return [op.range_dimension_tensor() for op in self._diagonal_operators]

  def _block_domain_dimension_tensors(self):
    return [op.domain_dimension_tensor() for op in self._diagonal_operators]

  def _shape(self):
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor_v2_with_dispatch(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    domain_dimension = sum(self._block_domain_dimension_tensors())
    range_dimension = sum(self._block_range_dimension_tensors())
    matrix_shape = array_ops.stack([range_dimension, domain_dimension])
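
    # (Added clarification) Summing dummy zeros Tensors, one per operator,
    # yields a Tensor whose shape is the broadcast of all the batch shapes;
    # `array_ops.shape` then reads the broadcast batch shape off it without
    # a dedicated broadcasting op.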
    # Dummy Tensor of zeros. Will never be materialized.
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
    batch_shape = array_ops.shape(zeros)

    return array_ops.concat((batch_shape, matrix_shape), 0)

  # TODO(b/188080761): Add a more efficient implementation of `cond` that
  # constructs the condition number from the blockwise singular values.

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ...  # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
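
    # An added blockwise sketch (the block sizes N1, N2 are illustrative):
    # X_blocks = [X1, X2], X1.shape = [..., N1, R], X2.shape = [..., N2, R]
    # operator.matmul(X_blocks)
    # ==> [Y1, Y2], a list of Tensors concatenating to the full result.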
    ```

    Args:
      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
        class docstring for definition of shape compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name: A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
      as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
      concatenate to `[..., M, R]`.
    """
    def _check_operators_agree(r, l, message):
      if (r.range_dimension is not None and
          l.domain_dimension is not None and
          r.range_dimension != l.domain_dimension):
        raise ValueError(message)

    if isinstance(x, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      _check_operators_agree(
          right_operator, left_operator,
          "Operators are incompatible. Expected `x` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))

      # We can efficiently multiply BlockDiag LinearOperators if the number of
      # blocks agree.
      if isinstance(x, LinearOperatorBlockDiag):
        if len(left_operator.operators) != len(right_operator.operators):
          raise ValueError(
              "Can not efficiently multiply two `LinearOperatorBlockDiag`s "
              "together when number of blocks differ.")

        for o1, o2 in zip(left_operator.operators, right_operator.operators):
          _check_operators_agree(
              o2, o1,
              "Blocks are incompatible. Expected `x` to have dimension"
              " {} but got {}.".format(
                  o1.domain_dimension, o2.range_dimension))

      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      arg_dim = -1 if adjoint_arg else -2
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            x[i] = block
      else:
        x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
        self._check_input_dtype(x)
        op_dimension = (self.range_dimension if adjoint
                        else self.domain_dimension)
        op_dimension.assert_is_compatible_with(x.shape[arg_dim])
      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    block_dimensions_fn = (
        self._block_range_dimension_tensors if adjoint
        else self._block_domain_dimension_tensors)
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, x, arg_dim)
    if blockwise_arg:
      split_x = x
    else:
      split_dim = -1 if adjoint_arg else -2
      # Split input by rows normally, and by columns if `adjoint_arg`.
      split_x = linear_operator_util.split_arg_into_blocks(
          block_dimensions, block_dimensions_fn, x, axis=split_dim)

    result_list = []
    for index, operator in enumerate(self.operators):
      result_list += [operator.matmul(
          split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

    if blockwise_arg:
      return result_list

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ...  # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
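
    # An added blockwise sketch (the block sizes N1, N2 are illustrative):
    # X_blocks = [X1, X2], X1.shape = [..., N1], X2.shape = [..., N2]
    # operator.matvec(X_blocks)
    # ==> a list of two [batch] vectors, one per block.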
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`, or an
        iterable of `Tensor`s (for blockwise operators). `Tensor`s are treated
        as [batch] vectors, meaning for every set of leading dimensions, the
        last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      name: A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            x[i] = block
        x_mat = [block[..., array_ops.newaxis] for block in x]
        y_mat = self.matmul(x_mat, adjoint=adjoint)
        return [array_ops.squeeze(y, axis=-1) for y in y_mat]

      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(x.shape[-1])
      x_mat = x[..., array_ops.newaxis]
      y_mat = self.matmul(x_mat, adjoint=adjoint)
      return array_ops.squeeze(y_mat, axis=-1)

  def _determinant(self):
    result = self.operators[0].determinant()
    for operator in self.operators[1:]:
      result *= operator.determinant()
    return result

  def _log_abs_determinant(self):
    result = self.operators[0].log_abs_determinant()
    for operator in self.operators[1:]:
      result += operator.log_abs_determinant()
    return result

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for
    details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ...  # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
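
    # Added note: for a block-diagonal A = diag(A1, A2), the solve decomposes
    # into independent per-block solves, conceptually
    # X = tf.concat([A1.solve(RHS_1), A2.solve(RHS_2)], axis=-2).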
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
        like [batch] matrices, meaning for every set of leading dimensions, the
        last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")

    def _check_operators_agree(r, l, message):
      if (r.range_dimension is not None and
          l.domain_dimension is not None and
          r.range_dimension != l.domain_dimension):
        raise ValueError(message)

    if isinstance(rhs, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      _check_operators_agree(
          right_operator, left_operator,
          "Operators are incompatible. Expected `rhs` to have dimension"
          " {} but got {}.".format(
              left_operator.domain_dimension, right_operator.range_dimension))

      # We can efficiently solve BlockDiag LinearOperators if the number of
      # blocks agree.
      if isinstance(right_operator, LinearOperatorBlockDiag):
        if len(left_operator.operators) != len(right_operator.operators):
          raise ValueError(
              "Can not efficiently solve `LinearOperatorBlockDiag` when "
              "number of blocks differ.")

        for o1, o2 in zip(left_operator.operators, right_operator.operators):
          _check_operators_agree(
              o2, o1,
              "Blocks are incompatible. Expected `rhs` to have dimension"
              " {} but got {}.".format(
                  o1.domain_dimension, o2.range_dimension))

      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      arg_dim = -1 if adjoint_arg else -2
      blockwise_arg = linear_operator_util.arg_is_blockwise(
          block_dimensions, rhs, arg_dim)

      if blockwise_arg:
        split_rhs = rhs
        for i, block in enumerate(split_rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            split_rhs[i] = block
      else:
        rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
        self._check_input_dtype(rhs)
        op_dimension = (self.domain_dimension if adjoint
                        else self.range_dimension)
        op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
        split_dim = -1 if adjoint_arg else -2
        # Split input by rows normally, and by columns if `adjoint_arg`.
        split_rhs = linear_operator_util.split_arg_into_blocks(
            self._block_domain_dimensions(),
            self._block_domain_dimension_tensors,
            rhs, axis=split_dim)

      solution_list = []
      for index, operator in enumerate(self.operators):
        solution_list += [operator.solve(
            split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]

      if blockwise_arg:
        return solution_list

      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
          solution_list)
      return array_ops.concat(solution_list, axis=-2)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve a single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for
    details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ...  # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
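
    # An added blockwise sketch (the block sizes M1, M2 are illustrative):
    # RHS_blocks = [RHS_1, RHS_2], with shapes [..., M1] and [..., M2]
    # operator.solvevec(RHS_blocks)
    # ==> a list of two [batch] vectors, one solution per block.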
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension defines
        a vector. See class docstring for definition of compatibility regarding
        batch dimensions.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
        for i, block in enumerate(rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            rhs[i] = block
        rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
        solution_mat = self.solve(rhs_mat, adjoint=adjoint)
        return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(rhs.shape[-1])
      rhs_mat = array_ops.expand_dims(rhs, axis=-1)
      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
      return array_ops.squeeze(solution_mat, axis=-1)

  def _diag_part(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`diag_part` not implemented for an operator whose blocks are not "
          "square.")
    diag_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      diag_list += [operator.diag_part()[..., array_ops.newaxis]]
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
    diagonal = array_ops.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)

  def _trace(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`trace` not implemented for an operator whose blocks are not "
          "square.")
    result = self.operators[0].trace()
    for operator in self.operators[1:]:
      result += operator.trace()
    return result

  def _to_dense(self):
    num_cols = 0
    rows = []
    broadcasted_blocks = [operator.to_dense() for operator in self.operators]
    broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
        broadcasted_blocks)
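    # (Added clarification) Each block becomes one row-band of the dense
    # matrix: it is padded on the left with `num_cols` zero columns (the total
    # width of the blocks already placed) and on the right with the remaining
    # zero columns; the padded bands are then concatenated along the row axis.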
    for block in broadcasted_blocks:
      batch_row_shape = array_ops.shape(block)[:-1]

      zeros_to_pad_before_shape = array_ops.concat(
          [batch_row_shape, [num_cols]], axis=-1)
      zeros_to_pad_before = array_ops.zeros(
          shape=zeros_to_pad_before_shape, dtype=block.dtype)
      num_cols += array_ops.shape(block)[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=block.dtype)

      rows.append(array_ops.concat(
          [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

    mat = array_ops.concat(rows, axis=-2)
    mat.set_shape(self.shape)
    return mat

  def _assert_non_singular(self):
    return control_flow_ops.group([
        operator.assert_non_singular() for operator in self.operators])

  def _assert_self_adjoint(self):
    return control_flow_ops.group([
        operator.assert_self_adjoint() for operator in self.operators])

  def _assert_positive_definite(self):
    return control_flow_ops.group([
        operator.assert_positive_definite() for operator in self.operators])

  def _eigvals(self):
    if not all(operator.is_square for operator in self.operators):
      raise NotImplementedError(
          "`eigvals` not implemented for an operator whose blocks are not "
          "square.")
    eig_list = []
    for operator in self.operators:
      # Extend the axis for broadcasting.
      eig_list += [operator.eigvals()[..., array_ops.newaxis]]
    eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
    eigs = array_ops.concat(eig_list, axis=-2)
    return array_ops.squeeze(eigs, axis=-1)

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"operators": [0] * len(self.operators)}