1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Create a Block Diagonal operator from one or more `LinearOperators`.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import common_shapes 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import check_ops 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops.linalg import linear_operator 29from tensorflow.python.ops.linalg import linear_operator_algebra 30from tensorflow.python.ops.linalg import linear_operator_util 31from tensorflow.python.util.tf_export import tf_export 32 33__all__ = ["LinearOperatorBlockDiag"] 34 35 36@tf_export("linalg.LinearOperatorBlockDiag") 37@linear_operator.make_composite_tensor 38class LinearOperatorBlockDiag(linear_operator.LinearOperator): 39 """Combines one or more `LinearOperators` in to a Block Diagonal matrix. 40 41 This operator combines one or more linear operators `[op1,...,opJ]`, 42 building a new `LinearOperator`, whose underlying matrix representation 43 has each operator `opi` on the main diagonal, and zero's elsewhere. 44 45 #### Shape compatibility 46 47 If `opj` acts like a [batch] matrix `Aj`, then `op_combined` acts like 48 the [batch] matrix formed by having each matrix `Aj` on the main 49 diagonal. 50 51 Each `opj` is required to represent a matrix, and hence will have 52 shape `batch_shape_j + [M_j, N_j]`. 53 54 If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the combined operator 55 has shape `broadcast_batch_shape + [sum M_j, sum N_j]`, where 56 `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`, 57 `j = 1,...,J`, assuming the intermediate batch shapes broadcast. 58 59 Arguments to `matmul`, `matvec`, `solve`, and `solvevec` may either be single 60 `Tensor`s or lists of `Tensor`s that are interpreted as blocks. The `j`th 61 element of a blockwise list of `Tensor`s must have dimensions that match 62 `opj` for the given method. If a list of blocks is input, then a list of 63 blocks is returned as well. 64 65 When the `opj` are not guaranteed to be square, this operator's methods might 66 fail due to the combined operator not being square and/or lack of efficient 67 methods. 68 69 ```python 70 # Create a 4 x 4 linear operator combined of two 2 x 2 operators. 71 operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]]) 72 operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]]) 73 operator = LinearOperatorBlockDiag([operator_1, operator_2]) 74 75 operator.to_dense() 76 ==> [[1., 2., 0., 0.], 77 [3., 4., 0., 0.], 78 [0., 0., 1., 0.], 79 [0., 0., 0., 1.]] 80 81 operator.shape 82 ==> [4, 4] 83 84 operator.log_abs_determinant() 85 ==> scalar Tensor 86 87 x1 = ... # Shape [2, 2] Tensor 88 x2 = ... # Shape [2, 2] Tensor 89 x = tf.concat([x1, x2], 0) # Shape [2, 4] Tensor 90 operator.matmul(x) 91 ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)]) 92 93 # Create a 5 x 4 linear operator combining three blocks. 94 operator_1 = LinearOperatorFullMatrix([[1.], [3.]]) 95 operator_2 = LinearOperatorFullMatrix([[1., 6.]]) 96 operator_3 = LinearOperatorFullMatrix([[2.], [7.]]) 97 operator = LinearOperatorBlockDiag([operator_1, operator_2, operator_3]) 98 99 operator.to_dense() 100 ==> [[1., 0., 0., 0.], 101 [3., 0., 0., 0.], 102 [0., 1., 6., 0.], 103 [0., 0., 0., 2.]] 104 [0., 0., 0., 7.]] 105 106 operator.shape 107 ==> [5, 4] 108 109 110 # Create a [2, 3] batch of 4 x 4 linear operators. 111 matrix_44 = tf.random.normal(shape=[2, 3, 4, 4]) 112 operator_44 = LinearOperatorFullMatrix(matrix) 113 114 # Create a [1, 3] batch of 5 x 5 linear operators. 115 matrix_55 = tf.random.normal(shape=[1, 3, 5, 5]) 116 operator_55 = LinearOperatorFullMatrix(matrix_55) 117 118 # Combine to create a [2, 3] batch of 9 x 9 operators. 119 operator_99 = LinearOperatorBlockDiag([operator_44, operator_55]) 120 121 # Create a shape [2, 3, 9] vector. 122 x = tf.random.normal(shape=[2, 3, 9]) 123 operator_99.matmul(x) 124 ==> Shape [2, 3, 9] Tensor 125 126 # Create a blockwise list of vectors. 127 x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])] 128 operator_99.matmul(x) 129 ==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor] 130 ``` 131 132 #### Performance 133 134 The performance of `LinearOperatorBlockDiag` on any operation is equal to 135 the sum of the individual operators' operations. 136 137 138 #### Matrix property hints 139 140 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 141 for `X = non_singular, self_adjoint, positive_definite, square`. 142 These have the following meaning: 143 144 * If `is_X == True`, callers should expect the operator to have the 145 property `X`. This is a promise that should be fulfilled, but is *not* a 146 runtime assert. For example, finite floating point precision may result 147 in these promises being violated. 148 * If `is_X == False`, callers should expect the operator to not have `X`. 149 * If `is_X == None` (the default), callers should have no expectation either 150 way. 151 """ 152 153 def __init__(self, 154 operators, 155 is_non_singular=None, 156 is_self_adjoint=None, 157 is_positive_definite=None, 158 is_square=True, 159 name=None): 160 r"""Initialize a `LinearOperatorBlockDiag`. 161 162 `LinearOperatorBlockDiag` is initialized with a list of operators 163 `[op_1,...,op_J]`. 164 165 Args: 166 operators: Iterable of `LinearOperator` objects, each with 167 the same `dtype` and composable shape. 168 is_non_singular: Expect that this operator is non-singular. 169 is_self_adjoint: Expect that this operator is equal to its hermitian 170 transpose. 171 is_positive_definite: Expect that this operator is positive definite, 172 meaning the quadratic form `x^H A x` has positive real part for all 173 nonzero `x`. Note that we do not require the operator to be 174 self-adjoint to be positive-definite. See: 175 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 176 is_square: Expect that this operator acts like square [batch] matrices. 177 This is true by default, and will raise a `ValueError` otherwise. 178 name: A name for this `LinearOperator`. Default is the individual 179 operators names joined with `_o_`. 180 181 Raises: 182 TypeError: If all operators do not have the same `dtype`. 183 ValueError: If `operators` is empty or are non-square. 184 """ 185 parameters = dict( 186 operators=operators, 187 is_non_singular=is_non_singular, 188 is_self_adjoint=is_self_adjoint, 189 is_positive_definite=is_positive_definite, 190 is_square=is_square, 191 name=name 192 ) 193 194 # Validate operators. 195 check_ops.assert_proper_iterable(operators) 196 operators = list(operators) 197 if not operators: 198 raise ValueError( 199 "Expected a non-empty list of operators. Found: %s" % operators) 200 self._operators = operators 201 202 # Define diagonal operators, for functions that are shared across blockwise 203 # `LinearOperator` types. 204 self._diagonal_operators = operators 205 206 # Validate dtype. 207 dtype = operators[0].dtype 208 for operator in operators: 209 if operator.dtype != dtype: 210 name_type = (str((o.name, o.dtype)) for o in operators) 211 raise TypeError( 212 "Expected all operators to have the same dtype. Found %s" 213 % " ".join(name_type)) 214 215 # Auto-set and check hints. 216 if all(operator.is_non_singular for operator in operators): 217 if is_non_singular is False: 218 raise ValueError( 219 "The direct sum of non-singular operators is always non-singular.") 220 is_non_singular = True 221 222 if all(operator.is_self_adjoint for operator in operators): 223 if is_self_adjoint is False: 224 raise ValueError( 225 "The direct sum of self-adjoint operators is always self-adjoint.") 226 is_self_adjoint = True 227 228 if all(operator.is_positive_definite for operator in operators): 229 if is_positive_definite is False: 230 raise ValueError( 231 "The direct sum of positive definite operators is always " 232 "positive definite.") 233 is_positive_definite = True 234 235 # Initialization. 236 graph_parents = [] 237 for operator in operators: 238 graph_parents.extend(operator.graph_parents) 239 240 if name is None: 241 # Using ds to mean direct sum. 242 name = "_ds_".join(operator.name for operator in operators) 243 with ops.name_scope(name, values=graph_parents): 244 super(LinearOperatorBlockDiag, self).__init__( 245 dtype=dtype, 246 is_non_singular=is_non_singular, 247 is_self_adjoint=is_self_adjoint, 248 is_positive_definite=is_positive_definite, 249 is_square=is_square, 250 parameters=parameters, 251 name=name) 252 253 # TODO(b/143910018) Remove graph_parents in V3. 254 self._set_graph_parents(graph_parents) 255 256 @property 257 def operators(self): 258 return self._operators 259 260 def _block_range_dimensions(self): 261 return [op.range_dimension for op in self._diagonal_operators] 262 263 def _block_domain_dimensions(self): 264 return [op.domain_dimension for op in self._diagonal_operators] 265 266 def _block_range_dimension_tensors(self): 267 return [op.range_dimension_tensor() for op in self._diagonal_operators] 268 269 def _block_domain_dimension_tensors(self): 270 return [op.domain_dimension_tensor() for op in self._diagonal_operators] 271 272 def _shape(self): 273 # Get final matrix shape. 274 domain_dimension = sum(self._block_domain_dimensions()) 275 range_dimension = sum(self._block_range_dimensions()) 276 matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension]) 277 278 # Get broadcast batch shape. 279 # broadcast_shape checks for compatibility. 280 batch_shape = self.operators[0].batch_shape 281 for operator in self.operators[1:]: 282 batch_shape = common_shapes.broadcast_shape( 283 batch_shape, operator.batch_shape) 284 285 return batch_shape.concatenate(matrix_shape) 286 287 def _shape_tensor(self): 288 # Avoid messy broadcasting if possible. 289 if self.shape.is_fully_defined(): 290 return ops.convert_to_tensor_v2_with_dispatch( 291 self.shape.as_list(), dtype=dtypes.int32, name="shape") 292 293 domain_dimension = sum(self._block_domain_dimension_tensors()) 294 range_dimension = sum(self._block_range_dimension_tensors()) 295 matrix_shape = array_ops.stack([range_dimension, domain_dimension]) 296 297 # Dummy Tensor of zeros. Will never be materialized. 298 zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor()) 299 for operator in self.operators[1:]: 300 zeros += array_ops.zeros(shape=operator.batch_shape_tensor()) 301 batch_shape = array_ops.shape(zeros) 302 303 return array_ops.concat((batch_shape, matrix_shape), 0) 304 305 # TODO(b/188080761): Add a more efficient implementation of `cond` that 306 # constructs the condition number from the blockwise singular values. 307 308 def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): 309 """Transform [batch] matrix `x` with left multiplication: `x --> Ax`. 310 311 ```python 312 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 313 operator = LinearOperator(...) 314 operator.shape = [..., M, N] 315 316 X = ... # shape [..., N, R], batch matrix, R > 0. 317 318 Y = operator.matmul(X) 319 Y.shape 320 ==> [..., M, R] 321 322 Y[..., :, r] = sum_j A[..., :, j] X[j, r] 323 ``` 324 325 Args: 326 x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as 327 `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See 328 class docstring for definition of shape compatibility. 329 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 330 adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is 331 the hermitian transpose (transposition and complex conjugation). 332 name: A name for this `Op`. 333 334 Returns: 335 A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype` 336 as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that 337 concatenate to `[..., M, R]`. 338 """ 339 def _check_operators_agree(r, l, message): 340 if (r.range_dimension is not None and 341 l.domain_dimension is not None and 342 r.range_dimension != l.domain_dimension): 343 raise ValueError(message) 344 345 if isinstance(x, linear_operator.LinearOperator): 346 left_operator = self.adjoint() if adjoint else self 347 right_operator = x.adjoint() if adjoint_arg else x 348 349 _check_operators_agree( 350 right_operator, left_operator, 351 "Operators are incompatible. Expected `x` to have dimension" 352 " {} but got {}.".format( 353 left_operator.domain_dimension, right_operator.range_dimension)) 354 355 # We can efficiently multiply BlockDiag LinearOperators if the number of 356 # blocks agree. 357 if isinstance(x, LinearOperatorBlockDiag): 358 if len(left_operator.operators) != len(right_operator.operators): 359 raise ValueError( 360 "Can not efficiently multiply two `LinearOperatorBlockDiag`s " 361 "together when number of blocks differ.") 362 363 for o1, o2 in zip(left_operator.operators, right_operator.operators): 364 _check_operators_agree( 365 o2, o1, 366 "Blocks are incompatible. Expected `x` to have dimension" 367 " {} but got {}.".format( 368 o1.domain_dimension, o2.range_dimension)) 369 370 with self._name_scope(name): # pylint: disable=not-callable 371 return linear_operator_algebra.matmul(left_operator, right_operator) 372 373 with self._name_scope(name): # pylint: disable=not-callable 374 arg_dim = -1 if adjoint_arg else -2 375 block_dimensions = (self._block_range_dimensions() if adjoint 376 else self._block_domain_dimensions()) 377 if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim): 378 for i, block in enumerate(x): 379 if not isinstance(block, linear_operator.LinearOperator): 380 block = ops.convert_to_tensor_v2_with_dispatch(block) 381 self._check_input_dtype(block) 382 block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim]) 383 x[i] = block 384 else: 385 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 386 self._check_input_dtype(x) 387 op_dimension = (self.range_dimension if adjoint 388 else self.domain_dimension) 389 op_dimension.assert_is_compatible_with(x.shape[arg_dim]) 390 return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) 391 392 def _matmul(self, x, adjoint=False, adjoint_arg=False): 393 arg_dim = -1 if adjoint_arg else -2 394 block_dimensions = (self._block_range_dimensions() if adjoint 395 else self._block_domain_dimensions()) 396 block_dimensions_fn = ( 397 self._block_range_dimension_tensors if adjoint 398 else self._block_domain_dimension_tensors) 399 blockwise_arg = linear_operator_util.arg_is_blockwise( 400 block_dimensions, x, arg_dim) 401 if blockwise_arg: 402 split_x = x 403 404 else: 405 split_dim = -1 if adjoint_arg else -2 406 # Split input by rows normally, and otherwise columns. 407 split_x = linear_operator_util.split_arg_into_blocks( 408 block_dimensions, block_dimensions_fn, x, axis=split_dim) 409 410 result_list = [] 411 for index, operator in enumerate(self.operators): 412 result_list += [operator.matmul( 413 split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] 414 415 if blockwise_arg: 416 return result_list 417 418 result_list = linear_operator_util.broadcast_matrix_batch_dims( 419 result_list) 420 return array_ops.concat(result_list, axis=-2) 421 422 def matvec(self, x, adjoint=False, name="matvec"): 423 """Transform [batch] vector `x` with left multiplication: `x --> Ax`. 424 425 ```python 426 # Make an operator acting like batch matric A. Assume A.shape = [..., M, N] 427 operator = LinearOperator(...) 428 429 X = ... # shape [..., N], batch vector 430 431 Y = operator.matvec(X) 432 Y.shape 433 ==> [..., M] 434 435 Y[..., :] = sum_j A[..., :, j] X[..., j] 436 ``` 437 438 Args: 439 x: `Tensor` with compatible shape and same `dtype` as `self`, or an 440 iterable of `Tensor`s (for blockwise operators). `Tensor`s are treated 441 a [batch] vectors, meaning for every set of leading dimensions, the last 442 dimension defines a vector. 443 See class docstring for definition of compatibility. 444 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 445 name: A name for this `Op`. 446 447 Returns: 448 A `Tensor` with shape `[..., M]` and same `dtype` as `self`. 449 """ 450 with self._name_scope(name): # pylint: disable=not-callable 451 block_dimensions = (self._block_range_dimensions() if adjoint 452 else self._block_domain_dimensions()) 453 if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1): 454 for i, block in enumerate(x): 455 if not isinstance(block, linear_operator.LinearOperator): 456 block = ops.convert_to_tensor_v2_with_dispatch(block) 457 self._check_input_dtype(block) 458 block_dimensions[i].assert_is_compatible_with(block.shape[-1]) 459 x[i] = block 460 x_mat = [block[..., array_ops.newaxis] for block in x] 461 y_mat = self.matmul(x_mat, adjoint=adjoint) 462 return [array_ops.squeeze(y, axis=-1) for y in y_mat] 463 464 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 465 self._check_input_dtype(x) 466 op_dimension = (self.range_dimension if adjoint 467 else self.domain_dimension) 468 op_dimension.assert_is_compatible_with(x.shape[-1]) 469 x_mat = x[..., array_ops.newaxis] 470 y_mat = self.matmul(x_mat, adjoint=adjoint) 471 return array_ops.squeeze(y_mat, axis=-1) 472 473 def _determinant(self): 474 result = self.operators[0].determinant() 475 for operator in self.operators[1:]: 476 result *= operator.determinant() 477 return result 478 479 def _log_abs_determinant(self): 480 result = self.operators[0].log_abs_determinant() 481 for operator in self.operators[1:]: 482 result += operator.log_abs_determinant() 483 return result 484 485 def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): 486 """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. 487 488 The returned `Tensor` will be close to an exact solution if `A` is well 489 conditioned. Otherwise closeness will vary. See class docstring for details. 490 491 Examples: 492 493 ```python 494 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 495 operator = LinearOperator(...) 496 operator.shape = [..., M, N] 497 498 # Solve R > 0 linear systems for every member of the batch. 499 RHS = ... # shape [..., M, R] 500 501 X = operator.solve(RHS) 502 # X[..., :, r] is the solution to the r'th linear system 503 # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r] 504 505 operator.matmul(X) 506 ==> RHS 507 ``` 508 509 Args: 510 rhs: `Tensor` with same `dtype` as this operator and compatible shape, 511 or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated 512 like a [batch] matrices meaning for every set of leading dimensions, the 513 last two dimensions defines a matrix. 514 See class docstring for definition of compatibility. 515 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 516 of this `LinearOperator`: `A^H X = rhs`. 517 adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H` 518 is the hermitian transpose (transposition and complex conjugation). 519 name: A name scope to use for ops added by this method. 520 521 Returns: 522 `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. 523 524 Raises: 525 NotImplementedError: If `self.is_non_singular` or `is_square` is False. 526 """ 527 if self.is_non_singular is False: 528 raise NotImplementedError( 529 "Exact solve not implemented for an operator that is expected to " 530 "be singular.") 531 if self.is_square is False: 532 raise NotImplementedError( 533 "Exact solve not implemented for an operator that is expected to " 534 "not be square.") 535 536 def _check_operators_agree(r, l, message): 537 if (r.range_dimension is not None and 538 l.domain_dimension is not None and 539 r.range_dimension != l.domain_dimension): 540 raise ValueError(message) 541 542 if isinstance(rhs, linear_operator.LinearOperator): 543 left_operator = self.adjoint() if adjoint else self 544 right_operator = rhs.adjoint() if adjoint_arg else rhs 545 546 _check_operators_agree( 547 right_operator, left_operator, 548 "Operators are incompatible. Expected `x` to have dimension" 549 " {} but got {}.".format( 550 left_operator.domain_dimension, right_operator.range_dimension)) 551 552 # We can efficiently solve BlockDiag LinearOperators if the number of 553 # blocks agree. 554 if isinstance(right_operator, LinearOperatorBlockDiag): 555 if len(left_operator.operators) != len(right_operator.operators): 556 raise ValueError( 557 "Can not efficiently solve `LinearOperatorBlockDiag` when " 558 "number of blocks differ.") 559 560 for o1, o2 in zip(left_operator.operators, right_operator.operators): 561 _check_operators_agree( 562 o2, o1, 563 "Blocks are incompatible. Expected `x` to have dimension" 564 " {} but got {}.".format( 565 o1.domain_dimension, o2.range_dimension)) 566 567 with self._name_scope(name): # pylint: disable=not-callable 568 return linear_operator_algebra.solve(left_operator, right_operator) 569 570 with self._name_scope(name): # pylint: disable=not-callable 571 block_dimensions = (self._block_domain_dimensions() if adjoint 572 else self._block_range_dimensions()) 573 arg_dim = -1 if adjoint_arg else -2 574 blockwise_arg = linear_operator_util.arg_is_blockwise( 575 block_dimensions, rhs, arg_dim) 576 577 if blockwise_arg: 578 split_rhs = rhs 579 for i, block in enumerate(split_rhs): 580 if not isinstance(block, linear_operator.LinearOperator): 581 block = ops.convert_to_tensor_v2_with_dispatch(block) 582 self._check_input_dtype(block) 583 block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim]) 584 split_rhs[i] = block 585 else: 586 rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs") 587 self._check_input_dtype(rhs) 588 op_dimension = (self.domain_dimension if adjoint 589 else self.range_dimension) 590 op_dimension.assert_is_compatible_with(rhs.shape[arg_dim]) 591 split_dim = -1 if adjoint_arg else -2 592 # Split input by rows normally, and otherwise columns. 593 split_rhs = linear_operator_util.split_arg_into_blocks( 594 self._block_domain_dimensions(), 595 self._block_domain_dimension_tensors, 596 rhs, axis=split_dim) 597 598 solution_list = [] 599 for index, operator in enumerate(self.operators): 600 solution_list += [operator.solve( 601 split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] 602 603 if blockwise_arg: 604 return solution_list 605 606 solution_list = linear_operator_util.broadcast_matrix_batch_dims( 607 solution_list) 608 return array_ops.concat(solution_list, axis=-2) 609 610 def solvevec(self, rhs, adjoint=False, name="solve"): 611 """Solve single equation with best effort: `A X = rhs`. 612 613 The returned `Tensor` will be close to an exact solution if `A` is well 614 conditioned. Otherwise closeness will vary. See class docstring for details. 615 616 Examples: 617 618 ```python 619 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 620 operator = LinearOperator(...) 621 operator.shape = [..., M, N] 622 623 # Solve one linear system for every member of the batch. 624 RHS = ... # shape [..., M] 625 626 X = operator.solvevec(RHS) 627 # X is the solution to the linear system 628 # sum_j A[..., :, j] X[..., j] = RHS[..., :] 629 630 operator.matvec(X) 631 ==> RHS 632 ``` 633 634 Args: 635 rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s 636 (for blockwise operators). `Tensor`s are treated as [batch] vectors, 637 meaning for every set of leading dimensions, the last dimension defines 638 a vector. See class docstring for definition of compatibility regarding 639 batch dimensions. 640 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 641 of this `LinearOperator`: `A^H X = rhs`. 642 name: A name scope to use for ops added by this method. 643 644 Returns: 645 `Tensor` with shape `[...,N]` and same `dtype` as `rhs`. 646 647 Raises: 648 NotImplementedError: If `self.is_non_singular` or `is_square` is False. 649 """ 650 with self._name_scope(name): # pylint: disable=not-callable 651 block_dimensions = (self._block_domain_dimensions() if adjoint 652 else self._block_range_dimensions()) 653 if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1): 654 for i, block in enumerate(rhs): 655 if not isinstance(block, linear_operator.LinearOperator): 656 block = ops.convert_to_tensor_v2_with_dispatch(block) 657 self._check_input_dtype(block) 658 block_dimensions[i].assert_is_compatible_with(block.shape[-1]) 659 rhs[i] = block 660 rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs] 661 solution_mat = self.solve(rhs_mat, adjoint=adjoint) 662 return [array_ops.squeeze(x, axis=-1) for x in solution_mat] 663 664 rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs") 665 self._check_input_dtype(rhs) 666 op_dimension = (self.domain_dimension if adjoint 667 else self.range_dimension) 668 op_dimension.assert_is_compatible_with(rhs.shape[-1]) 669 rhs_mat = array_ops.expand_dims(rhs, axis=-1) 670 solution_mat = self.solve(rhs_mat, adjoint=adjoint) 671 return array_ops.squeeze(solution_mat, axis=-1) 672 673 def _diag_part(self): 674 if not all(operator.is_square for operator in self.operators): 675 raise NotImplementedError( 676 "`diag_part` not implemented for an operator whose blocks are not " 677 "square.") 678 diag_list = [] 679 for operator in self.operators: 680 # Extend the axis for broadcasting. 681 diag_list += [operator.diag_part()[..., array_ops.newaxis]] 682 diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list) 683 diagonal = array_ops.concat(diag_list, axis=-2) 684 return array_ops.squeeze(diagonal, axis=-1) 685 686 def _trace(self): 687 if not all(operator.is_square for operator in self.operators): 688 raise NotImplementedError( 689 "`trace` not implemented for an operator whose blocks are not " 690 "square.") 691 result = self.operators[0].trace() 692 for operator in self.operators[1:]: 693 result += operator.trace() 694 return result 695 696 def _to_dense(self): 697 num_cols = 0 698 rows = [] 699 broadcasted_blocks = [operator.to_dense() for operator in self.operators] 700 broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims( 701 broadcasted_blocks) 702 for block in broadcasted_blocks: 703 batch_row_shape = array_ops.shape(block)[:-1] 704 705 zeros_to_pad_before_shape = array_ops.concat( 706 [batch_row_shape, [num_cols]], axis=-1) 707 zeros_to_pad_before = array_ops.zeros( 708 shape=zeros_to_pad_before_shape, dtype=block.dtype) 709 num_cols += array_ops.shape(block)[-1] 710 zeros_to_pad_after_shape = array_ops.concat( 711 [batch_row_shape, 712 [self.domain_dimension_tensor() - num_cols]], axis=-1) 713 zeros_to_pad_after = array_ops.zeros( 714 shape=zeros_to_pad_after_shape, dtype=block.dtype) 715 716 rows.append(array_ops.concat( 717 [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1)) 718 719 mat = array_ops.concat(rows, axis=-2) 720 mat.set_shape(self.shape) 721 return mat 722 723 def _assert_non_singular(self): 724 return control_flow_ops.group([ 725 operator.assert_non_singular() for operator in self.operators]) 726 727 def _assert_self_adjoint(self): 728 return control_flow_ops.group([ 729 operator.assert_self_adjoint() for operator in self.operators]) 730 731 def _assert_positive_definite(self): 732 return control_flow_ops.group([ 733 operator.assert_positive_definite() for operator in self.operators]) 734 735 def _eigvals(self): 736 if not all(operator.is_square for operator in self.operators): 737 raise NotImplementedError( 738 "`eigvals` not implemented for an operator whose blocks are not " 739 "square.") 740 eig_list = [] 741 for operator in self.operators: 742 # Extend the axis for broadcasting. 743 eig_list += [operator.eigvals()[..., array_ops.newaxis]] 744 eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list) 745 eigs = array_ops.concat(eig_list, axis=-2) 746 return array_ops.squeeze(eigs, axis=-1) 747 748 @property 749 def _composite_tensor_fields(self): 750 return ("operators",) 751