# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a blockwise lower-triangular operator from `LinearOperators`."""

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorBlockLowerTriangular"]


@tf_export("linalg.LinearOperatorBlockLowerTriangular")
@linear_operator.make_composite_tensor
class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
  """Combines `LinearOperators` into a blockwise lower-triangular matrix.

  This operator is initialized with a nested list of linear operators, which
  are combined into a new `LinearOperator` whose underlying matrix
  representation is square and has each operator on or below the main
  diagonal, and zeros elsewhere. Each element of the outer list is a list of
  `LinearOperators` corresponding to a row-partition of the blockwise
  structure. The number of `LinearOperator`s in row-partition `i` must be
  equal to `i + 1`.

  For example, a blockwise `3 x 3` `LinearOperatorBlockLowerTriangular` is
  initialized with the list `[[op_00], [op_10, op_11], [op_20, op_21, op_22]]`,
  where the `op_ij`, `i < 3, j <= i`, are `LinearOperator` instances. The
  `LinearOperatorBlockLowerTriangular` behaves as the following blockwise
  matrix, where `0` represents appropriately-sized [batch] matrices of zeros:

  ```none
  [[op_00,     0,     0],
   [op_10, op_11,     0],
   [op_20, op_21, op_22]]
  ```

  Each `op_jj` on the diagonal is required to represent a square matrix, and
  hence will have shape `batch_shape_j + [M_j, M_j]`. `LinearOperator`s in row
  `j` of the blockwise structure must have `range_dimension` equal to that of
  `op_jj`, and `LinearOperator`s in column `j` must have `domain_dimension`
  equal to that of `op_jj`.

  If each `op_jj` on the diagonal has shape `batch_shape_j + [M_j, M_j]`, then
  the combined operator has shape `broadcast_batch_shape + [sum M_j, sum M_j]`,
  where `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`,
  `j = 0, 1, ..., J`, assuming the intermediate batch shapes broadcast.
  Even if the combined shape is well defined, the combined operator's
  methods may fail due to lack of broadcasting ability in the defining
  operators' methods.
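
  For example, if `op_00` has `batch_shape = [2, 3]` and `op_11` has
  `batch_shape = [1, 3]`, the combined operator has `batch_shape = [2, 3]`;
  the batch example under "Shape compatibility" below constructs exactly this
  case.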

  For example, to create a 4 x 4 linear operator combining three 2 x 2
  operators:
  >>> operator_0 = tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  >>> operator_1 = tf.linalg.LinearOperatorFullMatrix([[1., 0.], [0., 1.]])
  >>> operator_2 = tf.linalg.LinearOperatorLowerTriangular([[5., 6.], [7., 8]])
  >>> operator = LinearOperatorBlockLowerTriangular(
  ...   [[operator_0], [operator_1, operator_2]])

  >>> operator.to_dense()
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
  array([[1., 2., 0., 0.],
         [3., 4., 0., 0.],
         [1., 0., 5., 0.],
         [0., 1., 7., 8.]], dtype=float32)>

  >>> operator.shape
  TensorShape([4, 4])

  >>> operator.log_abs_determinant()
  <tf.Tensor: shape=(), dtype=float32, numpy=4.3820267>

  >>> x0 = [[1., 6.], [-3., 4.]]
  >>> x1 = [[0., 2.], [4., 0.]]
  >>> x = tf.concat([x0, x1], 0)  # Shape [4, 2] Tensor
  >>> operator.matmul(x)
  <tf.Tensor: shape=(4, 2), dtype=float32, numpy=
  array([[-5., 14.],
         [-9., 34.],
         [ 1., 16.],
         [29., 18.]], dtype=float32)>

  The above `matmul` is equivalent to:
  >>> tf.concat([operator_0.matmul(x0),
  ...            operator_1.matmul(x0) + operator_2.matmul(x1)], axis=0)
  <tf.Tensor: shape=(4, 2), dtype=float32, numpy=
  array([[-5., 14.],
         [-9., 34.],
         [ 1., 16.],
         [29., 18.]], dtype=float32)>

  #### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N], with b >= 0
  x.shape = [B1,...,Bb] + [N, R], with R >= 0.
  ```

  For example:

  Create a [2, 3] batch of 4 x 4 linear operators:
  >>> matrix_44 = tf.random.normal(shape=[2, 3, 4, 4])
  >>> operator_44 = tf.linalg.LinearOperatorFullMatrix(matrix_44)

  Create a [1, 3] batch of 5 x 4 linear operators:
  >>> matrix_54 = tf.random.normal(shape=[1, 3, 5, 4])
  >>> operator_54 = tf.linalg.LinearOperatorFullMatrix(matrix_54)

  Create a [1, 3] batch of 5 x 5 linear operators:
  >>> matrix_55 = tf.random.normal(shape=[1, 3, 5, 5])
  >>> operator_55 = tf.linalg.LinearOperatorFullMatrix(matrix_55)

  Combine to create a [2, 3] batch of 9 x 9 operators:
  >>> operator_99 = LinearOperatorBlockLowerTriangular(
  ...   [[operator_44], [operator_54, operator_55]])
  >>> operator_99.shape
  TensorShape([2, 3, 9, 9])

  Create a shape [2, 1, 9] batch of vectors and apply the operator to it.
  >>> x = tf.random.normal(shape=[2, 1, 9])
  >>> y = operator_99.matvec(x)
  >>> y.shape
  TensorShape([2, 3, 9])

  Create a blockwise list of vectors and apply the operator to it. A blockwise
  list is returned.
  >>> x4 = tf.random.normal(shape=[2, 1, 4])
  >>> x5 = tf.random.normal(shape=[2, 3, 5])
  >>> y_blockwise = operator_99.matvec([x4, x5])
  >>> y_blockwise[0].shape
  TensorShape([2, 3, 4])
  >>> y_blockwise[1].shape
  TensorShape([2, 3, 5])
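
  The operator is square, so linear systems can be solved with the same batch
  and blockwise semantics. For instance, reusing `x` from above:

  >>> operator_99.solvevec(x).shape
  TensorShape([2, 3, 9])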

  #### Performance

  Suppose `operator` is a `LinearOperatorBlockLowerTriangular` consisting of
  `D` row-partitions and `D` column-partitions, such that the total number of
  operators is `N = D * (D + 1) // 2`.

  * `operator.matmul` has complexity equal to the sum of the `matmul`
    complexities of the individual operators.
  * `operator.solve` has complexity equal to the sum of the `solve`
    complexities of the operators on the diagonal and the `matmul`
    complexities of the operators off the diagonal.
  * `operator.determinant` has complexity equal to the sum of the
    `determinant` complexities of the operators on the diagonal.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation
    either way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorBlockLowerTriangular"):
    r"""Initialize a `LinearOperatorBlockLowerTriangular`.

    `LinearOperatorBlockLowerTriangular` is initialized with a list of lists
    of operators `[[op_0], [op_1, op_2], [op_3, op_4, op_5],...]`.

    Args:
      operators: Iterable of iterables of `LinearOperator` objects, each with
        the same `dtype`. Each element of `operators` corresponds to a row-
        partition, in top-to-bottom order. The operators in each row-partition
        are filled in left-to-right. For example,
        `operators = [[op_0], [op_1, op_2], [op_3, op_4, op_5]]` creates a
        `LinearOperatorBlockLowerTriangular` with full block structure
        `[[op_0, 0, 0], [op_1, op_2, 0], [op_3, op_4, op_5]]`. The number of
        operators in row `i` (zero-indexed) must be equal to `i + 1`, such
        that each operator falls on or below the diagonal of the blockwise
        structure. `LinearOperator`s that fall on the diagonal (the last
        elements of each row) must be square. The other `LinearOperator`s must
        have domain dimension equal to the domain dimension of the
        `LinearOperator`s in the same column-partition, and range dimension
        equal to the range dimension of the `LinearOperator`s in the same
        row-partition.
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
        This will raise a `ValueError` if set to `False`.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError: If all operators do not have the same `dtype`.
      ValueError: If `operators` is empty, contains an erroneous number of
        elements, or contains operators with incompatible shapes.
    """
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
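    # All of the following checks are static: they inspect the nested-list
    # structure, dtypes, row lengths, and statically known dimensions of the
    # constituent operators, without adding any ops to the graph.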
    check_ops.assert_proper_iterable(operators)
    for row in operators:
      check_ops.assert_proper_iterable(row)
    operators = [list(row) for row in operators]

    if not operators:
      raise ValueError(f"Argument `operators` must be a list of >=1 operators. "
                       f"Received: {operators}.")
    self._operators = operators
    self._diagonal_operators = [row[-1] for row in operators]

    dtype = operators[0][0].dtype
    self._validate_dtype(dtype)
    is_non_singular = self._validate_non_singular(is_non_singular)
    self._validate_num_operators()
    self._validate_operator_dimensions()
    is_square = self._validate_square(is_square)
    with ops.name_scope(name):
      super(LinearOperatorBlockLowerTriangular, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _validate_num_operators(self):
    for i, row in enumerate(self.operators):
      if len(row) != i + 1:
        raise ValueError(
            f"Argument `operators[{i}]` must contain `{i + 1}` blocks. "
            f"Received: {len(row)} blocks.")

  def _validate_operator_dimensions(self):
    """Check that `operators` have compatible dimensions."""
    for i in range(1, len(self.operators)):
      for j in range(i):
        op = self.operators[i][j]

        # `above_op` is the operator directly above `op` in the blockwise
        # structure, in row partition `i-1`, column partition `j`. `op` should
        # have the same `domain_dimension` as `above_op`.
        above_op = self.operators[i - 1][j]

        # `right_op` is the operator to the right of `op` in the blockwise
        # structure, in row partition `i`, column partition `j+1`. `op` should
        # have the same `range_dimension` as `right_op`.
        right_op = self.operators[i][j + 1]

        if (op.domain_dimension is not None and
            above_op.domain_dimension is not None):
          if op.domain_dimension != above_op.domain_dimension:
            raise ValueError(f"Argument `operators[{i}][{j}].domain_dimension` "
                             f"({op.domain_dimension}) must be the same as "
                             f"`operators[{i-1}][{j}].domain_dimension` "
                             f"({above_op.domain_dimension}).")
        if (op.range_dimension is not None and
            right_op.range_dimension is not None):
          if op.range_dimension != right_op.range_dimension:
            raise ValueError(f"Argument `operators[{i}][{j}].range_dimension` "
                             f"({op.range_dimension}) must be the same as "
                             f"`operators[{i}][{j + 1}].range_dimension` "
                             f"({right_op.range_dimension}).")

  # pylint: disable=g-bool-id-comparison
  def _validate_non_singular(self, is_non_singular):
    if all(op.is_non_singular for op in self._diagonal_operators):
      if is_non_singular is False:
        raise ValueError(
            f"A blockwise lower-triangular operator with non-singular "
            f"operators on the main diagonal is always non-singular. "
            f"Expected argument `is_non_singular` to be True. "
            f"Received: {is_non_singular}.")
      return True
    if any(op.is_non_singular is False for op in self._diagonal_operators):
      if is_non_singular is True:
        raise ValueError(
            f"A blockwise lower-triangular operator with a singular operator "
            f"on the main diagonal is always singular. Expected argument "
            f"`is_non_singular` to be False. "
            f"Received: {is_non_singular}.")
      return False

  def _validate_square(self, is_square):
    if is_square is False:
      raise ValueError(f"`LinearOperatorBlockLowerTriangular` must be square. "
                       f"Expected argument `is_square` to be True. "
                       f"Received: {is_square}.")
    for i, op in enumerate(self._diagonal_operators):
      if op.is_square is False:
        raise ValueError(
            f"Matrices on the diagonal (the final elements of each "
            f"row-partition in the `operators` list) must be square. Expected "
            f"argument `operators[{i}][-1].is_square` to be True. "
            f"Received: {op.is_square}.")
    return True
  # pylint: enable=g-bool-id-comparison

  def _validate_dtype(self, dtype):
    for i, row in enumerate(self.operators):
      for operator in row:
        if operator.dtype != dtype:
          name_type = [str((o.name, o.dtype)) for o in row]
          raise TypeError(
              "Expected all operators to have the same dtype. Found {} in row "
              "{} and {} in row 0.".format(name_type, i, str(dtype)))

  @property
  def operators(self):
    return self._operators

  def _block_range_dimensions(self):
    return [op.range_dimension for op in self._diagonal_operators]

  def _block_domain_dimensions(self):
    return [op.domain_dimension for op in self._diagonal_operators]

  def _block_range_dimension_tensors(self):
    return [op.range_dimension_tensor() for op in self._diagonal_operators]

  def _block_domain_dimension_tensors(self):
    return [op.domain_dimension_tensor() for op in self._diagonal_operators]

  def _shape(self):
    # Get final matrix shape.
    domain_dimension = sum(self._block_domain_dimensions())
    range_dimension = sum(self._block_range_dimensions())
    matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0][0].batch_shape
    for row in self.operators[1:]:
      for operator in row:
        batch_shape = common_shapes.broadcast_shape(
            batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    # Avoid messy broadcasting if possible.
    if self.shape.is_fully_defined():
      return ops.convert_to_tensor_v2_with_dispatch(
          self.shape.as_list(), dtype=dtypes.int32, name="shape")

    domain_dimension = sum(self._block_domain_dimension_tensors())
    range_dimension = sum(self._block_range_dimension_tensors())
    matrix_shape = array_ops.stack([domain_dimension, range_dimension])

    batch_shape = self.operators[0][0].batch_shape_tensor()
    for row in self.operators[1:]:
      for operator in row:
        batch_shape = array_ops.broadcast_dynamic_shape(
            batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ...  # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```
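
    `x` may also be passed blockwise, as a list of `Tensor`s with one block
    per block-partition of the operator's domain (for the default
    `adjoint=False`), in which case a list of blocks is returned. A sketch for
    a two-partition operator, with `x0` and `x1` standing in for compatible
    blocks:

    ```python
    x0 = ...  # shape [..., N_0, R]
    x1 = ...  # shape [..., N_1, R]
    # Conceptually equivalent to operator.matmul(tf.concat([x0, x1], axis=-2)),
    # except that the result is returned as a list of blocks rather than
    # concatenated.
    y0, y1 = operator.matmul([x0, x1])
    ```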

    Args:
      x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
        `self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
        class docstring for definition of shape compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name: A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
      as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
      concatenate to `[..., M, R]`.
    """
    if isinstance(x, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `x` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      arg_dim = -1 if adjoint_arg else -2
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            x[i] = block
      else:
        x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
        self._check_input_dtype(x)
        op_dimension = (self.range_dimension if adjoint
                        else self.domain_dimension)
        op_dimension.assert_is_compatible_with(x.shape[arg_dim])
      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    arg_dim = -1 if adjoint_arg else -2
    block_dimensions = (self._block_range_dimensions() if adjoint
                        else self._block_domain_dimensions())
    blockwise_arg = linear_operator_util.arg_is_blockwise(
        block_dimensions, x, arg_dim)
    if blockwise_arg:
      split_x = x
    else:
      split_dim = -1 if adjoint_arg else -2
      # Split input by columns if adjoint_arg is True, else rows.
      split_x = linear_operator_util.split_arg_into_blocks(
          self._block_domain_dimensions(),
          self._block_domain_dimension_tensors,
          x, axis=split_dim)

    result_list = []
    # Iterate over row-partitions (i.e. column-partitions of the adjoint).
    if adjoint:
      for index in range(len(self.operators)):
        # Begin with the operator on the diagonal and apply it to the
        # respective `rhs` block.
        result = self.operators[index][index].matmul(
            split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)

        # Iterate top to bottom over the operators in the remainder of the
        # column-partition (i.e. left to right over the row-partition of the
        # adjoint), apply the operator to the respective `rhs` block and
        # accumulate the sum.
        # For example, given the
        # `LinearOperatorBlockLowerTriangular`:
        #
        # op = [[A, 0, 0],
        #       [B, C, 0],
        #       [D, E, F]]
        #
        # if `index = 1`, the following loop calculates:
        # `y_1 = C.matmul(x_1, adjoint=adjoint) +
        #        E.matmul(x_2, adjoint=adjoint)`,
        # where `x_1` and `x_2` are splits of `x`.
        for j in range(index + 1, len(self.operators)):
          result += self.operators[j][index].matmul(
              split_x[j], adjoint=adjoint, adjoint_arg=adjoint_arg)
        result_list.append(result)
    else:
      for row in self.operators:
        # Begin with the left-most operator in the row-partition and apply it
        # to the first `rhs` block.
        result = row[0].matmul(
            split_x[0], adjoint=adjoint, adjoint_arg=adjoint_arg)
        # Iterate left to right over the operators in the remainder of the row
        # partition, apply the operator to the respective `rhs` block, and
        # accumulate the sum.
        for j, operator in enumerate(row[1:]):
          result += operator.matmul(
              split_x[j + 1], adjoint=adjoint, adjoint_arg=adjoint_arg)
        result_list.append(result)

    if blockwise_arg:
      return result_list

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
    return array_ops.concat(result_list, axis=-2)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ...  # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```
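
    `matvec` is implemented by appending a trailing unit dimension to `x`,
    calling `matmul`, and squeezing that dimension from the result;
    conceptually:

    ```python
    # operator.matvec(X) computes the same result as:
    Y = tf.squeeze(operator.matmul(X[..., tf.newaxis]), axis=-1)
    ```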

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`, or an
        iterable of `Tensor`s. `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension
        defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      name: A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`, or if `x`
      is blockwise, a list of `Tensor`s with shapes that concatenate to
      `[..., M]`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_range_dimensions() if adjoint
                          else self._block_domain_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
        for i, block in enumerate(x):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            x[i] = block
        x_mat = [block[..., array_ops.newaxis] for block in x]
        y_mat = self.matmul(x_mat, adjoint=adjoint)
        return [array_ops.squeeze(y, axis=-1) for y in y_mat]

      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      op_dimension = (self.range_dimension if adjoint
                      else self.domain_dimension)
      op_dimension.assert_is_compatible_with(x.shape[-1])
      x_mat = x[..., array_ops.newaxis]
      y_mat = self.matmul(x_mat, adjoint=adjoint)
      return array_ops.squeeze(y_mat, axis=-1)

  def _determinant(self):
    if all(op.is_positive_definite for op in self._diagonal_operators):
      return math_ops.exp(self._log_abs_determinant())
    result = self._diagonal_operators[0].determinant()
    for op in self._diagonal_operators[1:]:
      result *= op.determinant()
    return result

  def _log_abs_determinant(self):
    result = self._diagonal_operators[0].log_abs_determinant()
    for op in self._diagonal_operators[1:]:
      result += op.log_abs_determinant()
    return result

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for
    details.

    Given the blockwise `n + 1`-by-`n + 1` linear operator:

    op = [[A_00     0  ...     0  ...    0],
          [A_10  A_11  ...     0  ...    0],
          ...
          [A_k0  A_k1  ...  A_kk  ...    0],
          ...
          [A_n0  A_n1  ...  A_nk  ...  A_nn]]

    we find `x = op.solve(y)` by observing that

    `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`

    and therefore

    `x_k = A_kk.solve(y_k -
                      A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`

    where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
    and `y` along their appropriate axes.

    We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
    for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.

    The adjoint case is solved similarly, beginning with
    `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ...  # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```
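
    `rhs` may also be passed blockwise, as a list of `Tensor`s with one block
    per block-partition of the operator's range (for the default
    `adjoint=False`), in which case the solution is returned as a list of
    blocks. A sketch for a two-partition operator, with `y0` and `y1` standing
    in for compatible blocks:

    ```python
    y0 = ...  # shape [..., M_0, R]
    y1 = ...  # shape [..., M_1, R]
    # Conceptually equivalent to operator.solve(tf.concat([y0, y1], axis=-2)),
    # except that the solution is returned as a list of blocks rather than
    # concatenated.
    x0, x1 = operator.solve([y0, y1])
    ```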

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape,
        or a list of `Tensor`s. `Tensor`s are treated like [batch] matrices,
        meaning for every set of leading dimensions, the last two dimensions
        define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, solve the system involving the
        adjoint of this `LinearOperator`: `A^H X = rhs`.
      adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`, or if `rhs`
      is blockwise, a list of `Tensor`s with shapes that concatenate to
      `[..., N, R]`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, linear_operator.LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      arg_dim = -1 if adjoint_arg else -2
      blockwise_arg = linear_operator_util.arg_is_blockwise(
          block_dimensions, rhs, arg_dim)
      if blockwise_arg:
        for i, block in enumerate(rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
            rhs[i] = block
        if adjoint_arg:
          split_rhs = [linalg.adjoint(y) for y in rhs]
        else:
          split_rhs = rhs

      else:
        rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
        self._check_input_dtype(rhs)
        op_dimension = (self.domain_dimension if adjoint
                        else self.range_dimension)
        op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])

        rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
        split_rhs = linear_operator_util.split_arg_into_blocks(
            self._block_domain_dimensions(),
            self._block_domain_dimension_tensors,
            rhs, axis=-2)

      solution_list = []
      if adjoint:
        # For an adjoint blockwise lower-triangular linear operator, the
        # system must be solved bottom to top. Iterate backwards over rows of
        # the adjoint (i.e. columns of the non-adjoint operator).
        for index in reversed(range(len(self.operators))):
          y = split_rhs[index]
          # Iterate top to bottom over the operators in the off-diagonal
          # portion of the column-partition (i.e. row-partition of the
          # adjoint), apply the operator to the respective block of the
          # solution found in previous iterations, and subtract the result
          # from the `rhs` block.
          # For example, let `A`, `B`, and `D` be the linear operators in the
          # top row-partition of the adjoint of
          # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
          # and `x_1` and `x_2` be blocks of the solution found in previous
          # iterations of the outer loop. The following loop (when
          # `index == 0`) expresses
          # `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
          # `y_0* = y_0 - Bx_1 - Dx_2`.
          for j in reversed(range(index + 1, len(self.operators))):
            y = y - self.operators[j][index].matmul(
                solution_list[len(self.operators) - 1 - j],
                adjoint=adjoint)
          # Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
          solution_list.append(
              self._diagonal_operators[index].solve(y, adjoint=adjoint))
        solution_list.reverse()
      else:
        # Iterate top to bottom over the row-partitions.
        for row, y in zip(self.operators, split_rhs):
          # Iterate left to right over the operators in the off-diagonal
          # portion of the row-partition, apply the operator to the block of
          # the solution found in previous iterations, and subtract the result
          # from the `rhs` block. For example, let `D`, `E`, and `F` be the
          # linear operators in the bottom row-partition of
          # `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
          # `x_0` and `x_1` be blocks of the solution found in previous
          # iterations of the outer loop. The following loop (when
          # `index == 2`) expresses
          # `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
          # `y_2* = y_2 - Dx_0 - Ex_1`.
          for i, operator in enumerate(row[:-1]):
            y = y - operator.matmul(solution_list[i], adjoint=adjoint)
          # Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
          solution_list.append(row[-1].solve(y, adjoint=adjoint))

      if blockwise_arg:
        return solution_list

      solution_list = linear_operator_util.broadcast_matrix_batch_dims(
          solution_list)
      return array_ops.concat(solution_list, axis=-2)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for
    details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ...  # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```
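
    `solvevec` is implemented by appending a trailing unit dimension to `rhs`,
    calling `solve`, and squeezing that dimension from the result;
    conceptually:

    ```python
    # operator.solvevec(RHS) computes the same result as:
    X = tf.squeeze(operator.solve(RHS[..., tf.newaxis]), axis=-1)
    ```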

    Args:
      rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
        (for blockwise operators). `Tensor`s are treated as [batch] vectors,
        meaning for every set of leading dimensions, the last dimension
        defines a vector. See class docstring for definition of compatibility
        regarding batch dimensions.
      adjoint: Python `bool`. If `True`, solve the system involving the
        adjoint of this `LinearOperator`: `A^H X = rhs`.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`, or if `rhs` is
      blockwise, a list of `Tensor`s with shapes that concatenate to
      `[..., N]`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      block_dimensions = (self._block_domain_dimensions() if adjoint
                          else self._block_range_dimensions())
      if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
        for i, block in enumerate(rhs):
          if not isinstance(block, linear_operator.LinearOperator):
            block = ops.convert_to_tensor_v2_with_dispatch(block)
            self._check_input_dtype(block)
            block_dimensions[i].assert_is_compatible_with(block.shape[-1])
            rhs[i] = block
        rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
        solution_mat = self.solve(rhs_mat, adjoint=adjoint)
        return [array_ops.squeeze(x, axis=-1) for x in solution_mat]

      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      op_dimension = (self.domain_dimension if adjoint
                      else self.range_dimension)
      op_dimension.assert_is_compatible_with(rhs.shape[-1])
      rhs_mat = array_ops.expand_dims(rhs, axis=-1)
      solution_mat = self.solve(rhs_mat, adjoint=adjoint)
      return array_ops.squeeze(solution_mat, axis=-1)

  def _diag_part(self):
    diag_list = []
    for op in self._diagonal_operators:
      # Extend the axis, since `broadcast_matrix_batch_dims` treats all but
      # the final two dimensions as batch dimensions.
      diag_list.append(op.diag_part()[..., array_ops.newaxis])
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
    diagonal = array_ops.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)

  def _trace(self):
    result = self._diagonal_operators[0].trace()
    for op in self._diagonal_operators[1:]:
      result += op.trace()
    return result

  def _to_dense(self):
    num_cols = 0
    dense_rows = []
    flat_broadcast_operators = linear_operator_util.broadcast_matrix_batch_dims(
        [op.to_dense() for row in self.operators for op in row])  # pylint: disable=g-complex-comprehension
    broadcast_operators = [
        flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
        for i in range(len(self.operators))]
    for row_blocks in broadcast_operators:
      batch_row_shape = array_ops.shape(row_blocks[0])[:-1]
      num_cols += array_ops.shape(row_blocks[-1])[-1]
      zeros_to_pad_after_shape = array_ops.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=self.dtype)

      row_blocks.append(zeros_to_pad_after)
      dense_rows.append(array_ops.concat(row_blocks, axis=-1))

    mat = array_ops.concat(dense_rows, axis=-2)
    mat.set_shape(self.shape)
    return mat

  def _assert_non_singular(self):
    return control_flow_ops.group([
        op.assert_non_singular() for op in self._diagonal_operators])

  def _eigvals(self):
    eig_list = []
    for op in self._diagonal_operators:
      # Extend the axis for broadcasting.
      eig_list.append(op.eigvals()[..., array_ops.newaxis])
    eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
    eigs = array_ops.concat(eig_list, axis=-2)
    return array_ops.squeeze(eigs, axis=-1)

  @property
  def _composite_tensor_fields(self):
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # None of the operators contribute to the matrix shape.
    return {"operators": nest.map_structure(lambda _: 0, self.operators)}