1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""`LinearOperator` acting like the identity matrix.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.framework import tensor_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import check_ops 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops.linalg import linalg_impl as linalg 32from tensorflow.python.ops.linalg import linear_operator 33from tensorflow.python.ops.linalg import linear_operator_util 34from tensorflow.python.util.tf_export import tf_export 35 36__all__ = [ 37 "LinearOperatorIdentity", 38 "LinearOperatorScaledIdentity", 39] 40 41 42class BaseLinearOperatorIdentity(linear_operator.LinearOperator): 43 """Base class for Identity operators.""" 44 45 def _check_num_rows_possibly_add_asserts(self): 46 """Static check of init arg `num_rows`, possibly add asserts.""" 47 # Possibly add asserts. 
    if self._assert_proper_shapes:
      # Attach runtime assertions as control dependencies so they run
      # whenever `self._num_rows` is consumed by downstream ops.
      self._num_rows = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_rows,
              0,
              message="Argument num_rows must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_rows,
              message="Argument num_rows must be non-negative."),
      ], self._num_rows)

    # Static checks.
    if not self._num_rows.dtype.is_integer:
      raise TypeError("Argument num_rows must be integer type. Found:"
                      " %s" % self._num_rows)

    num_rows_static = self._num_rows_static

    if num_rows_static is None:
      return  # Cannot do any other static checks.

    if num_rows_static.ndim != 0:
      raise ValueError("Argument num_rows must be a 0-D Tensor. Found:"
                       " %s" % num_rows_static)

    if num_rows_static < 0:
      raise ValueError("Argument num_rows must be non-negative. Found:"
                       " %s" % num_rows_static)

  def _min_matrix_dim(self):
    """Minimum of domain/range dimension, if statically available, else None."""
    domain_dim = tensor_shape.dimension_value(self.domain_dimension)
    range_dim = tensor_shape.dimension_value(self.range_dimension)
    if domain_dim is None or range_dim is None:
      return None
    return min(domain_dim, range_dim)

  def _min_matrix_dim_tensor(self):
    """Minimum of domain/range dimension, as a tensor."""
    return math_ops.reduce_min(self.shape_tensor()[-2:])

  def _ones_diag(self):
    """Returns the diagonal of this operator as all ones."""
    if self.shape.is_fully_defined():
      # Fully static shape: build the diagonal shape from Python values.
      d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
    else:
      # Dynamic shape: assemble the diagonal shape as a tensor at runtime.
      d_shape = array_ops.concat(
          [self.batch_shape_tensor(),
           [self._min_matrix_dim_tensor()]], axis=0)

    return array_ops.ones(shape=d_shape, dtype=self.dtype)


@tf_export("linalg.LinearOperatorIdentity")
class LinearOperatorIdentity(BaseLinearOperatorIdentity):
  """`LinearOperator` acting like a [batch] square identity matrix.

  This operator acts like a [batch] identity matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
  an `N x N` matrix. This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorIdentity` is initialized with `num_rows`, and optionally
  `batch_shape`, and `dtype` arguments. If `batch_shape` is `None`, this
  operator efficiently passes through all arguments. If `batch_shape` is
  provided, broadcasting may occur, which will require making copies.

  ```python
  # Create a 2 x 2 identity matrix.
  operator = LinearOperatorIdentity(num_rows=2, dtype=tf.float32)

  operator.to_dense()
  ==> [[1., 0.]
       [0., 1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 0.

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor, same as x.

  y = tf.random_normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  # This broadcast does NOT require copying data, since we can infer that y
  # will be passed through without changing shape. We are always able to infer
  # this if the operator has no batch_shape.
  x = operator.solve(y)
  ==> Shape [3, 2, 4] Tensor, same as y.

  # Create a 2-batch of 2x2 identity matrices
  operator = LinearOperatorIdentity(num_rows=2, batch_shape=[2])
  operator.to_dense()
  ==> [[[1., 0.]
        [0., 1.]],
       [[1., 0.]
        [0., 1.]]]

  # Here, even though the operator has a batch shape, the input is the same as
  # the output, so x can be passed through without a copy. The operator is able
  # to detect that no broadcast is necessary because both x and the operator
  # have statically defined shape.
  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, same as x

  # Here the operator and x have different batch_shape, and are broadcast.
  # This requires a copy, since the output is different size than the input.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, equal to [x, x]
  ```

  ### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape = [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  If `batch_shape` initialization arg is `None`:

  * `operator.matmul(x)` is `O(1)`
  * `operator.solve(x)` is `O(1)`
  * `operator.determinant()` is `O(1)`

  If `batch_shape` initialization arg is provided, and static checks cannot
  rule out the need to broadcast:

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(B1*...*Bb)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
205 """ 206 207 def __init__(self, 208 num_rows, 209 batch_shape=None, 210 dtype=None, 211 is_non_singular=True, 212 is_self_adjoint=True, 213 is_positive_definite=True, 214 is_square=True, 215 assert_proper_shapes=False, 216 name="LinearOperatorIdentity"): 217 r"""Initialize a `LinearOperatorIdentity`. 218 219 The `LinearOperatorIdentity` is initialized with arguments defining `dtype` 220 and shape. 221 222 This operator is able to broadcast the leading (batch) dimensions, which 223 sometimes requires copying data. If `batch_shape` is `None`, the operator 224 can take arguments of any batch shape without copying. See examples. 225 226 Args: 227 num_rows: Scalar non-negative integer `Tensor`. Number of rows in the 228 corresponding identity matrix. 229 batch_shape: Optional `1-D` integer `Tensor`. The shape of the leading 230 dimensions. If `None`, this operator has no leading dimensions. 231 dtype: Data type of the matrix that this operator represents. 232 is_non_singular: Expect that this operator is non-singular. 233 is_self_adjoint: Expect that this operator is equal to its hermitian 234 transpose. 235 is_positive_definite: Expect that this operator is positive definite, 236 meaning the quadratic form `x^H A x` has positive real part for all 237 nonzero `x`. Note that we do not require the operator to be 238 self-adjoint to be positive-definite. See: 239 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 240 is_square: Expect that this operator acts like square [batch] matrices. 241 assert_proper_shapes: Python `bool`. If `False`, only perform static 242 checks that initialization and method arguments have proper shape. 243 If `True`, and static checks are inconclusive, add asserts to the graph. 244 name: A name for this `LinearOperator` 245 246 Raises: 247 ValueError: If `num_rows` is determined statically to be non-scalar, or 248 negative. 
      ValueError: If `batch_shape` is determined statically to not be 1-D, or
        negative.
      ValueError: If any of the following is not `True`:
        `{is_self_adjoint, is_non_singular, is_positive_definite}`.
    """
    # Default dtype is float32.
    dtype = dtype or dtypes.float32
    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name):
      dtype = dtypes.as_dtype(dtype)
      # An identity matrix has all these properties by construction, so
      # reject any hint that contradicts them.
      if not is_self_adjoint:
        raise ValueError("An identity operator is always self adjoint.")
      if not is_non_singular:
        raise ValueError("An identity operator is always non-singular.")
      if not is_positive_definite:
        raise ValueError("An identity operator is always positive-definite.")
      if not is_square:
        raise ValueError("An identity operator is always square.")

      super(LinearOperatorIdentity, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      # Static (Python) value of num_rows, or None if unavailable.
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
      self._check_num_rows_possibly_add_asserts()

      if batch_shape is None:
        self._batch_shape_arg = None
      else:
        self._batch_shape_arg = linear_operator_util.shape_tensor(
            batch_shape, name="batch_shape_arg")
        self._batch_shape_static = tensor_util.constant_value(
            self._batch_shape_arg)
        self._check_batch_shape_possibly_add_asserts()

  def _shape(self):
    """Static `TensorShape`: batch_shape + [num_rows, num_rows]."""
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_rows_static))
    if self._batch_shape_arg is None:
      return matrix_shape

    batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape tensor: batch_shape + [num_rows, num_rows]."""
    matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0)
    if self._batch_shape_arg is None:
      return matrix_shape

    return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)

  # The identity is trivially non-singular, positive-definite, and
  # self-adjoint (enforced in __init__), so these asserts are no-ops.
  def _assert_non_singular(self):
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    return control_flow_ops.no_op("assert_positive_definite")

  def _assert_self_adjoint(self):
    return control_flow_ops.no_op("assert_self_adjoint")

  def _possibly_broadcast_batch_shape(self, x):
    """Return 'x', possibly after broadcasting the leading dimensions."""
    # If we have no batch shape, our batch shape broadcasts with everything!
    if self._batch_shape_arg is None:
      return x

    # Static attempt:
    #   If we determine that no broadcast is necessary, pass x through
    #   If we need a broadcast, add to an array of zeros.
    #
    # special_shape is the shape that, when broadcast with x's shape, will give
    # the correct broadcast_shape. Note that
    #   We have already verified the second to last dimension of self.shape
    #   matches x's shape in assert_compatible_matrix_dimensions.
    #   Also, the final dimension of 'x' can have any shape.
    #   Therefore, the final two dimensions of special_shape are 1's.
    special_shape = self.batch_shape.concatenate([1, 1])
    bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
    if special_shape.is_fully_defined():
      # bshape.is_fully_defined iff special_shape.is_fully_defined.
      if bshape == x.get_shape():
        return x
      # Use the built in broadcasting of addition.
      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
      return x + zeros

    # Dynamic broadcast:
    #   Always add to an array of zeros, rather than using a "cond", since a
    #   cond would require copying data from GPU --> CPU.
    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Note that adjoint has no effect since this matrix is self-adjoint.
    x = linalg.adjoint(x) if adjoint_arg else x
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    return self._possibly_broadcast_batch_shape(x)

  def _determinant(self):
    # det(I) == 1 for every batch member.
    return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _log_abs_determinant(self):
    # log|det(I)| == 0 for every batch member.
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # I x = rhs has solution x = rhs; `adjoint` is ignored because the
    # identity is self-adjoint, so solving reduces to matmul.
    return self._matmul(rhs, adjoint_arg=adjoint_arg)

  def _trace(self):
    # Get Tensor of all ones of same shape as self.batch_shape.
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    # trace(I) == N, broadcast over the batch shape.
    if self._min_matrix_dim() is not None:
      return self._min_matrix_dim() * batch_of_ones
    else:
      return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) *
              batch_of_ones)

  def _diag_part(self):
    return self._ones_diag()

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`. Equiv to `I + mat`.

    Args:
      mat: `Tensor` with same `dtype` and shape broadcastable to `self`.
      name: A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
391 """ 392 with self._name_scope(name, values=[mat]): 393 mat = ops.convert_to_tensor(mat, name="mat") 394 mat_diag = array_ops.matrix_diag_part(mat) 395 new_diag = 1 + mat_diag 396 return array_ops.matrix_set_diag(mat, new_diag) 397 398 def _check_num_rows_possibly_add_asserts(self): 399 """Static check of init arg `num_rows`, possibly add asserts.""" 400 # Possibly add asserts. 401 if self._assert_proper_shapes: 402 self._num_rows = control_flow_ops.with_dependencies([ 403 check_ops.assert_rank( 404 self._num_rows, 405 0, 406 message="Argument num_rows must be a 0-D Tensor."), 407 check_ops.assert_non_negative( 408 self._num_rows, 409 message="Argument num_rows must be non-negative."), 410 ], self._num_rows) 411 412 # Static checks. 413 if not self._num_rows.dtype.is_integer: 414 raise TypeError("Argument num_rows must be integer type. Found:" 415 " %s" % self._num_rows) 416 417 num_rows_static = self._num_rows_static 418 419 if num_rows_static is None: 420 return # Cannot do any other static checks. 421 422 if num_rows_static.ndim != 0: 423 raise ValueError("Argument num_rows must be a 0-D Tensor. Found:" 424 " %s" % num_rows_static) 425 426 if num_rows_static < 0: 427 raise ValueError("Argument num_rows must be non-negative. 
Found:" 428 " %s" % num_rows_static) 429 430 def _check_batch_shape_possibly_add_asserts(self): 431 """Static check of init arg `batch_shape`, possibly add asserts.""" 432 if self._batch_shape_arg is None: 433 return 434 435 # Possibly add asserts 436 if self._assert_proper_shapes: 437 self._batch_shape_arg = control_flow_ops.with_dependencies([ 438 check_ops.assert_rank( 439 self._batch_shape_arg, 440 1, 441 message="Argument batch_shape must be a 1-D Tensor."), 442 check_ops.assert_non_negative( 443 self._batch_shape_arg, 444 message="Argument batch_shape must be non-negative."), 445 ], self._batch_shape_arg) 446 447 # Static checks 448 if not self._batch_shape_arg.dtype.is_integer: 449 raise TypeError("Argument batch_shape must be integer type. Found:" 450 " %s" % self._batch_shape_arg) 451 452 if self._batch_shape_static is None: 453 return # Cannot do any other static checks. 454 455 if self._batch_shape_static.ndim != 1: 456 raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:" 457 " %s" % self._batch_shape_static) 458 459 if np.any(self._batch_shape_static < 0): 460 raise ValueError("Argument batch_shape must be non-negative. Found:" 461 "%s" % self._batch_shape_static) 462 463 464@tf_export("linalg.LinearOperatorScaledIdentity") 465class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): 466 """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`. 467 468 This operator acts like a scaled [batch] identity matrix `A` with shape 469 `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a 470 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is 471 a scaled version of the `N x N` identity matrix. 472 473 `LinearOperatorIdentity` is initialized with `num_rows`, and a `multiplier` 474 (a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the 475 `multiplier` determines the scale for each batch member. 476 477 ```python 478 # Create a 2 x 2 scaled identity matrix. 
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)

  operator.to_dense()
  ==> [[3., 0.]
       [0., 3.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 2 * Log[3]

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> 3 * x

  y = tf.random_normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  x = operator.solve(y)
  ==> y / 3

  # Create a 2-batch of 2x2 scaled identity matrices
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=[5., 5.])
  operator.to_dense()
  ==> [[[5., 0.]
        [0., 5.]],
       [[5., 0.]
        [0., 5.]]]

  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> 5 * x

  # Here the operator and x have different batch_shape, and are broadcast.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> 5 * x
  ```

  ### Shape compatibility

  This operator acts on [batch] matrices with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape = [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(D1*...*Dd)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:
  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               num_rows,
               multiplier,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorScaledIdentity"):
    r"""Initialize a `LinearOperatorScaledIdentity`.

    The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
    determines the size of each identity matrix, and a `multiplier`,
    which defines `dtype`, batch shape, and scale of each matrix.

    This operator is able to broadcast the leading (batch) dimensions.

    Args:
      num_rows: Scalar non-negative integer `Tensor`. Number of rows in the
        corresponding identity matrix.
      multiplier: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes: Python `bool`. If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError: If `num_rows` is determined statically to be non-scalar, or
        negative.
588 """ 589 self._assert_proper_shapes = assert_proper_shapes 590 591 with ops.name_scope(name, values=[multiplier, num_rows]): 592 self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier") 593 594 # Check and auto-set hints. 595 if not self._multiplier.dtype.is_complex: 596 if is_self_adjoint is False: # pylint: disable=g-bool-id-comparison 597 raise ValueError("A real diagonal operator is always self adjoint.") 598 else: 599 is_self_adjoint = True 600 601 if not is_square: 602 raise ValueError("A ScaledIdentity operator is always square.") 603 604 super(LinearOperatorScaledIdentity, self).__init__( 605 dtype=self._multiplier.dtype, 606 is_non_singular=is_non_singular, 607 is_self_adjoint=is_self_adjoint, 608 is_positive_definite=is_positive_definite, 609 is_square=is_square, 610 name=name) 611 612 # Shape [B1,...Bb, 1, 1] 613 self._multiplier_matrix = array_ops.expand_dims( 614 array_ops.expand_dims(self.multiplier, -1), -1) 615 self._multiplier_matrix_conj = math_ops.conj(self._multiplier_matrix) 616 self._abs_multiplier = math_ops.abs(self.multiplier) 617 618 self._num_rows = linear_operator_util.shape_tensor( 619 num_rows, name="num_rows") 620 self._num_rows_static = tensor_util.constant_value(self._num_rows) 621 self._check_num_rows_possibly_add_asserts() 622 self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype) 623 self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows, 624 self.dtype.real_dtype) 625 626 def _shape(self): 627 matrix_shape = tensor_shape.TensorShape((self._num_rows_static, 628 self._num_rows_static)) 629 630 batch_shape = self.multiplier.get_shape() 631 return batch_shape.concatenate(matrix_shape) 632 633 def _shape_tensor(self): 634 matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0) 635 636 batch_shape = array_ops.shape(self.multiplier) 637 return array_ops.concat((batch_shape, matrix_shape), 0) 638 639 def _assert_non_singular(self): 640 return check_ops.assert_positive( 641 
math_ops.abs(self.multiplier), message="LinearOperator was singular") 642 643 def _assert_positive_definite(self): 644 return check_ops.assert_positive( 645 math_ops.real(self.multiplier), 646 message="LinearOperator was not positive definite.") 647 648 def _assert_self_adjoint(self): 649 imag_multiplier = math_ops.imag(self.multiplier) 650 return check_ops.assert_equal( 651 array_ops.zeros_like(imag_multiplier), 652 imag_multiplier, 653 message="LinearOperator was not self-adjoint") 654 655 def _matmul(self, x, adjoint=False, adjoint_arg=False): 656 x = linalg.adjoint(x) if adjoint_arg else x 657 if adjoint: 658 matrix = self._multiplier_matrix_conj 659 else: 660 matrix = self._multiplier_matrix 661 if self._assert_proper_shapes: 662 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) 663 x = control_flow_ops.with_dependencies([aps], x) 664 return x * matrix 665 666 def _determinant(self): 667 return self.multiplier**self._num_rows_cast_to_dtype 668 669 def _log_abs_determinant(self): 670 return self._num_rows_cast_to_real_dtype * math_ops.log( 671 self._abs_multiplier) 672 673 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 674 rhs = linalg.adjoint(rhs) if adjoint_arg else rhs 675 if adjoint: 676 matrix = self._multiplier_matrix_conj 677 else: 678 matrix = self._multiplier_matrix 679 if self._assert_proper_shapes: 680 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs) 681 rhs = control_flow_ops.with_dependencies([aps], rhs) 682 return rhs / matrix 683 684 def _trace(self): 685 # Get Tensor of all ones of same shape as self.batch_shape. 
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    # trace(c I_N) == c * N, broadcast over the batch shape.
    if self._min_matrix_dim() is not None:
      return self.multiplier * self._min_matrix_dim() * batch_of_ones
    else:
      return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(),
                                              self.dtype) * batch_of_ones)

  def _diag_part(self):
    return self._ones_diag() * self.multiplier[..., array_ops.newaxis]

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`. Equiv to `c I + mat`.

    Args:
      mat: `Tensor` with same `dtype` and shape broadcastable to `self`.
      name: A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name, values=[mat]):
      # Shape [B1,...,Bb, 1]
      multiplier_vector = array_ops.expand_dims(self.multiplier, -1)

      # Shape [C1,...,Cc, M, M]
      mat = ops.convert_to_tensor(mat, name="mat")

      # Shape [C1,...,Cc, M]
      mat_diag = array_ops.matrix_diag_part(mat)

      # multiplier_vector broadcasts here.
      new_diag = multiplier_vector + mat_diag

      return array_ops.matrix_set_diag(mat, new_diag)

  @property
  def multiplier(self):
    """The [batch] scalar `Tensor`, `c` in `cI`."""
    return self._multiplier