1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""`LinearOperator` acting like the identity matrix.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.framework import tensor_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import check_ops 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops.linalg import linalg_impl as linalg 32from tensorflow.python.ops.linalg import linear_operator 33from tensorflow.python.ops.linalg import linear_operator_util 34from tensorflow.python.util.tf_export import tf_export 35 36__all__ = [ 37 "LinearOperatorIdentity", 38 "LinearOperatorScaledIdentity", 39] 40 41 42class BaseLinearOperatorIdentity(linear_operator.LinearOperator): 43 """Base class for Identity operators.""" 44 45 def _check_num_rows_possibly_add_asserts(self): 46 """Static check of init arg `num_rows`, possibly add asserts.""" 47 # Possibly add asserts. 48 if self._assert_proper_shapes: 49 self._num_rows = control_flow_ops.with_dependencies([ 50 check_ops.assert_rank( 51 self._num_rows, 52 0, 53 message="Argument num_rows must be a 0-D Tensor."), 54 check_ops.assert_non_negative( 55 self._num_rows, 56 message="Argument num_rows must be non-negative."), 57 ], self._num_rows) 58 59 # Static checks. 60 if not self._num_rows.dtype.is_integer: 61 raise TypeError("Argument num_rows must be integer type. Found:" 62 " %s" % self._num_rows) 63 64 num_rows_static = self._num_rows_static 65 66 if num_rows_static is None: 67 return # Cannot do any other static checks. 68 69 if num_rows_static.ndim != 0: 70 raise ValueError("Argument num_rows must be a 0-D Tensor. Found:" 71 " %s" % num_rows_static) 72 73 if num_rows_static < 0: 74 raise ValueError("Argument num_rows must be non-negative. Found:" 75 " %s" % num_rows_static) 76 77 def _min_matrix_dim(self): 78 """Minimum of domain/range dimension, if statically available, else None.""" 79 domain_dim = tensor_shape.dimension_value(self.domain_dimension) 80 range_dim = tensor_shape.dimension_value(self.range_dimension) 81 if domain_dim is None or range_dim is None: 82 return None 83 return min(domain_dim, range_dim) 84 85 def _min_matrix_dim_tensor(self): 86 """Minimum of domain/range dimension, as a tensor.""" 87 return math_ops.reduce_min(self.shape_tensor()[-2:]) 88 89 def _ones_diag(self): 90 """Returns the diagonal of this operator as all ones.""" 91 if self.shape.is_fully_defined(): 92 d_shape = self.batch_shape.concatenate([self._min_matrix_dim()]) 93 else: 94 d_shape = array_ops.concat( 95 [self.batch_shape_tensor(), 96 [self._min_matrix_dim_tensor()]], axis=0) 97 98 return array_ops.ones(shape=d_shape, dtype=self.dtype) 99 100 101@tf_export("linalg.LinearOperatorIdentity") 102@linear_operator.make_composite_tensor 103class LinearOperatorIdentity(BaseLinearOperatorIdentity): 104 """`LinearOperator` acting like a [batch] square identity matrix. 105 106 This operator acts like a [batch] identity matrix `A` with shape 107 `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a 108 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is 109 an `N x N` matrix. This matrix `A` is not materialized, but for 110 purposes of broadcasting this shape will be relevant. 111 112 `LinearOperatorIdentity` is initialized with `num_rows`, and optionally 113 `batch_shape`, and `dtype` arguments. If `batch_shape` is `None`, this 114 operator efficiently passes through all arguments. If `batch_shape` is 115 provided, broadcasting may occur, which will require making copies. 116 117 ```python 118 # Create a 2 x 2 identity matrix. 119 operator = LinearOperatorIdentity(num_rows=2, dtype=tf.float32) 120 121 operator.to_dense() 122 ==> [[1., 0.] 123 [0., 1.]] 124 125 operator.shape 126 ==> [2, 2] 127 128 operator.log_abs_determinant() 129 ==> 0. 130 131 x = ... Shape [2, 4] Tensor 132 operator.matmul(x) 133 ==> Shape [2, 4] Tensor, same as x. 134 135 y = tf.random.normal(shape=[3, 2, 4]) 136 # Note that y.shape is compatible with operator.shape because operator.shape 137 # is broadcast to [3, 2, 2]. 138 # This broadcast does NOT require copying data, since we can infer that y 139 # will be passed through without changing shape. We are always able to infer 140 # this if the operator has no batch_shape. 141 x = operator.solve(y) 142 ==> Shape [3, 2, 4] Tensor, same as y. 143 144 # Create a 2-batch of 2x2 identity matrices 145 operator = LinearOperatorIdentity(num_rows=2, batch_shape=[2]) 146 operator.to_dense() 147 ==> [[[1., 0.] 148 [0., 1.]], 149 [[1., 0.] 150 [0., 1.]]] 151 152 # Here, even though the operator has a batch shape, the input is the same as 153 # the output, so x can be passed through without a copy. The operator is able 154 # to detect that no broadcast is necessary because both x and the operator 155 # have statically defined shape. 156 x = ... Shape [2, 2, 3] 157 operator.matmul(x) 158 ==> Shape [2, 2, 3] Tensor, same as x 159 160 # Here the operator and x have different batch_shape, and are broadcast. 161 # This requires a copy, since the output is different size than the input. 162 x = ... Shape [1, 2, 3] 163 operator.matmul(x) 164 ==> Shape [2, 2, 3] Tensor, equal to [x, x] 165 ``` 166 167 ### Shape compatibility 168 169 This operator acts on [batch] matrix with compatible shape. 170 `x` is a batch matrix with compatible shape for `matmul` and `solve` if 171 172 ``` 173 operator.shape = [B1,...,Bb] + [N, N], with b >= 0 174 x.shape = [C1,...,Cc] + [N, R], 175 and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] 176 ``` 177 178 ### Performance 179 180 If `batch_shape` initialization arg is `None`: 181 182 * `operator.matmul(x)` is `O(1)` 183 * `operator.solve(x)` is `O(1)` 184 * `operator.determinant()` is `O(1)` 185 186 If `batch_shape` initialization arg is provided, and static checks cannot 187 rule out the need to broadcast: 188 189 * `operator.matmul(x)` is `O(D1*...*Dd*N*R)` 190 * `operator.solve(x)` is `O(D1*...*Dd*N*R)` 191 * `operator.determinant()` is `O(B1*...*Bb)` 192 193 #### Matrix property hints 194 195 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 196 for `X = non_singular, self_adjoint, positive_definite, square`. 197 These have the following meaning: 198 199 * If `is_X == True`, callers should expect the operator to have the 200 property `X`. This is a promise that should be fulfilled, but is *not* a 201 runtime assert. For example, finite floating point precision may result 202 in these promises being violated. 203 * If `is_X == False`, callers should expect the operator to not have `X`. 204 * If `is_X == None` (the default), callers should have no expectation either 205 way. 206 """ 207 208 def __init__(self, 209 num_rows, 210 batch_shape=None, 211 dtype=None, 212 is_non_singular=True, 213 is_self_adjoint=True, 214 is_positive_definite=True, 215 is_square=True, 216 assert_proper_shapes=False, 217 name="LinearOperatorIdentity"): 218 r"""Initialize a `LinearOperatorIdentity`. 219 220 The `LinearOperatorIdentity` is initialized with arguments defining `dtype` 221 and shape. 222 223 This operator is able to broadcast the leading (batch) dimensions, which 224 sometimes requires copying data. If `batch_shape` is `None`, the operator 225 can take arguments of any batch shape without copying. See examples. 226 227 Args: 228 num_rows: Scalar non-negative integer `Tensor`. Number of rows in the 229 corresponding identity matrix. 230 batch_shape: Optional `1-D` integer `Tensor`. The shape of the leading 231 dimensions. If `None`, this operator has no leading dimensions. 232 dtype: Data type of the matrix that this operator represents. 233 is_non_singular: Expect that this operator is non-singular. 234 is_self_adjoint: Expect that this operator is equal to its hermitian 235 transpose. 236 is_positive_definite: Expect that this operator is positive definite, 237 meaning the quadratic form `x^H A x` has positive real part for all 238 nonzero `x`. Note that we do not require the operator to be 239 self-adjoint to be positive-definite. See: 240 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 241 is_square: Expect that this operator acts like square [batch] matrices. 242 assert_proper_shapes: Python `bool`. If `False`, only perform static 243 checks that initialization and method arguments have proper shape. 244 If `True`, and static checks are inconclusive, add asserts to the graph. 245 name: A name for this `LinearOperator` 246 247 Raises: 248 ValueError: If `num_rows` is determined statically to be non-scalar, or 249 negative. 250 ValueError: If `batch_shape` is determined statically to not be 1-D, or 251 negative. 252 ValueError: If any of the following is not `True`: 253 `{is_self_adjoint, is_non_singular, is_positive_definite}`. 254 TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable). 255 """ 256 parameters = dict( 257 num_rows=num_rows, 258 batch_shape=batch_shape, 259 dtype=dtype, 260 is_non_singular=is_non_singular, 261 is_self_adjoint=is_self_adjoint, 262 is_positive_definite=is_positive_definite, 263 is_square=is_square, 264 assert_proper_shapes=assert_proper_shapes, 265 name=name) 266 267 dtype = dtype or dtypes.float32 268 self._assert_proper_shapes = assert_proper_shapes 269 270 with ops.name_scope(name): 271 dtype = dtypes.as_dtype(dtype) 272 if not is_self_adjoint: 273 raise ValueError("An identity operator is always self adjoint.") 274 if not is_non_singular: 275 raise ValueError("An identity operator is always non-singular.") 276 if not is_positive_definite: 277 raise ValueError("An identity operator is always positive-definite.") 278 if not is_square: 279 raise ValueError("An identity operator is always square.") 280 281 super(LinearOperatorIdentity, self).__init__( 282 dtype=dtype, 283 is_non_singular=is_non_singular, 284 is_self_adjoint=is_self_adjoint, 285 is_positive_definite=is_positive_definite, 286 is_square=is_square, 287 parameters=parameters, 288 name=name) 289 290 linear_operator_util.assert_not_ref_type(num_rows, "num_rows") 291 linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape") 292 293 self._num_rows = linear_operator_util.shape_tensor( 294 num_rows, name="num_rows") 295 self._num_rows_static = tensor_util.constant_value(self._num_rows) 296 self._check_num_rows_possibly_add_asserts() 297 298 if batch_shape is None: 299 self._batch_shape_arg = None 300 else: 301 self._batch_shape_arg = linear_operator_util.shape_tensor( 302 batch_shape, name="batch_shape_arg") 303 self._batch_shape_static = tensor_util.constant_value( 304 self._batch_shape_arg) 305 self._check_batch_shape_possibly_add_asserts() 306 307 def _shape(self): 308 matrix_shape = tensor_shape.TensorShape((self._num_rows_static, 309 self._num_rows_static)) 310 if self._batch_shape_arg is None: 311 return matrix_shape 312 313 batch_shape = tensor_shape.TensorShape(self._batch_shape_static) 314 return batch_shape.concatenate(matrix_shape) 315 316 def _shape_tensor(self): 317 matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0) 318 if self._batch_shape_arg is None: 319 return matrix_shape 320 321 return array_ops.concat((self._batch_shape_arg, matrix_shape), 0) 322 323 def _assert_non_singular(self): 324 return control_flow_ops.no_op("assert_non_singular") 325 326 def _assert_positive_definite(self): 327 return control_flow_ops.no_op("assert_positive_definite") 328 329 def _assert_self_adjoint(self): 330 return control_flow_ops.no_op("assert_self_adjoint") 331 332 def _possibly_broadcast_batch_shape(self, x): 333 """Return 'x', possibly after broadcasting the leading dimensions.""" 334 # If we have no batch shape, our batch shape broadcasts with everything! 335 if self._batch_shape_arg is None: 336 return x 337 338 # Static attempt: 339 # If we determine that no broadcast is necessary, pass x through 340 # If we need a broadcast, add to an array of zeros. 341 # 342 # special_shape is the shape that, when broadcast with x's shape, will give 343 # the correct broadcast_shape. Note that 344 # We have already verified the second to last dimension of self.shape 345 # matches x's shape in assert_compatible_matrix_dimensions. 346 # Also, the final dimension of 'x' can have any shape. 347 # Therefore, the final two dimensions of special_shape are 1's. 348 special_shape = self.batch_shape.concatenate([1, 1]) 349 bshape = array_ops.broadcast_static_shape(x.shape, special_shape) 350 if special_shape.is_fully_defined(): 351 # bshape.is_fully_defined iff special_shape.is_fully_defined. 352 if bshape == x.shape: 353 return x 354 # Use the built in broadcasting of addition. 355 zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype) 356 return x + zeros 357 358 # Dynamic broadcast: 359 # Always add to an array of zeros, rather than using a "cond", since a 360 # cond would require copying data from GPU --> CPU. 361 special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0) 362 zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype) 363 return x + zeros 364 365 def _matmul(self, x, adjoint=False, adjoint_arg=False): 366 # Note that adjoint has no effect since this matrix is self-adjoint. 367 x = linalg.adjoint(x) if adjoint_arg else x 368 if self._assert_proper_shapes: 369 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) 370 x = control_flow_ops.with_dependencies([aps], x) 371 return self._possibly_broadcast_batch_shape(x) 372 373 def _determinant(self): 374 return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype) 375 376 def _log_abs_determinant(self): 377 return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype) 378 379 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 380 return self._matmul(rhs, adjoint_arg=adjoint_arg) 381 382 def _trace(self): 383 # Get Tensor of all ones of same shape as self.batch_shape. 384 if self.batch_shape.is_fully_defined(): 385 batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype) 386 else: 387 batch_of_ones = array_ops.ones( 388 shape=self.batch_shape_tensor(), dtype=self.dtype) 389 390 if self._min_matrix_dim() is not None: 391 return self._min_matrix_dim() * batch_of_ones 392 else: 393 return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) * 394 batch_of_ones) 395 396 def _diag_part(self): 397 return self._ones_diag() 398 399 def add_to_tensor(self, mat, name="add_to_tensor"): 400 """Add matrix represented by this operator to `mat`. Equiv to `I + mat`. 401 402 Args: 403 mat: `Tensor` with same `dtype` and shape broadcastable to `self`. 404 name: A name to give this `Op`. 405 406 Returns: 407 A `Tensor` with broadcast shape and same `dtype` as `self`. 408 """ 409 with self._name_scope(name): # pylint: disable=not-callable 410 mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat") 411 mat_diag = array_ops.matrix_diag_part(mat) 412 new_diag = 1 + mat_diag 413 return array_ops.matrix_set_diag(mat, new_diag) 414 415 def _eigvals(self): 416 return self._ones_diag() 417 418 def _cond(self): 419 return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype) 420 421 def _check_num_rows_possibly_add_asserts(self): 422 """Static check of init arg `num_rows`, possibly add asserts.""" 423 # Possibly add asserts. 424 if self._assert_proper_shapes: 425 self._num_rows = control_flow_ops.with_dependencies([ 426 check_ops.assert_rank( 427 self._num_rows, 428 0, 429 message="Argument num_rows must be a 0-D Tensor."), 430 check_ops.assert_non_negative( 431 self._num_rows, 432 message="Argument num_rows must be non-negative."), 433 ], self._num_rows) 434 435 # Static checks. 436 if not self._num_rows.dtype.is_integer: 437 raise TypeError("Argument num_rows must be integer type. Found:" 438 " %s" % self._num_rows) 439 440 num_rows_static = self._num_rows_static 441 442 if num_rows_static is None: 443 return # Cannot do any other static checks. 444 445 if num_rows_static.ndim != 0: 446 raise ValueError("Argument num_rows must be a 0-D Tensor. Found:" 447 " %s" % num_rows_static) 448 449 if num_rows_static < 0: 450 raise ValueError("Argument num_rows must be non-negative. Found:" 451 " %s" % num_rows_static) 452 453 def _check_batch_shape_possibly_add_asserts(self): 454 """Static check of init arg `batch_shape`, possibly add asserts.""" 455 if self._batch_shape_arg is None: 456 return 457 458 # Possibly add asserts 459 if self._assert_proper_shapes: 460 self._batch_shape_arg = control_flow_ops.with_dependencies([ 461 check_ops.assert_rank( 462 self._batch_shape_arg, 463 1, 464 message="Argument batch_shape must be a 1-D Tensor."), 465 check_ops.assert_non_negative( 466 self._batch_shape_arg, 467 message="Argument batch_shape must be non-negative."), 468 ], self._batch_shape_arg) 469 470 # Static checks 471 if not self._batch_shape_arg.dtype.is_integer: 472 raise TypeError("Argument batch_shape must be integer type. Found:" 473 " %s" % self._batch_shape_arg) 474 475 if self._batch_shape_static is None: 476 return # Cannot do any other static checks. 477 478 if self._batch_shape_static.ndim != 1: 479 raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:" 480 " %s" % self._batch_shape_static) 481 482 if np.any(self._batch_shape_static < 0): 483 raise ValueError("Argument batch_shape must be non-negative. Found:" 484 "%s" % self._batch_shape_static) 485 486 @property 487 def _composite_tensor_prefer_static_fields(self): 488 return ("num_rows", "batch_shape") 489 490 @property 491 def _composite_tensor_fields(self): 492 return ("num_rows", "batch_shape", "dtype", "assert_proper_shapes") 493 494 495@tf_export("linalg.LinearOperatorScaledIdentity") 496@linear_operator.make_composite_tensor 497class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): 498 """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`. 499 500 This operator acts like a scaled [batch] identity matrix `A` with shape 501 `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a 502 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is 503 a scaled version of the `N x N` identity matrix. 504 505 `LinearOperatorIdentity` is initialized with `num_rows`, and a `multiplier` 506 (a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the 507 `multiplier` determines the scale for each batch member. 508 509 ```python 510 # Create a 2 x 2 scaled identity matrix. 511 operator = LinearOperatorIdentity(num_rows=2, multiplier=3.) 512 513 operator.to_dense() 514 ==> [[3., 0.] 515 [0., 3.]] 516 517 operator.shape 518 ==> [2, 2] 519 520 operator.log_abs_determinant() 521 ==> 2 * Log[3] 522 523 x = ... Shape [2, 4] Tensor 524 operator.matmul(x) 525 ==> 3 * x 526 527 y = tf.random.normal(shape=[3, 2, 4]) 528 # Note that y.shape is compatible with operator.shape because operator.shape 529 # is broadcast to [3, 2, 2]. 530 x = operator.solve(y) 531 ==> 3 * x 532 533 # Create a 2-batch of 2x2 identity matrices 534 operator = LinearOperatorIdentity(num_rows=2, multiplier=5.) 535 operator.to_dense() 536 ==> [[[5., 0.] 537 [0., 5.]], 538 [[5., 0.] 539 [0., 5.]]] 540 541 x = ... Shape [2, 2, 3] 542 operator.matmul(x) 543 ==> 5 * x 544 545 # Here the operator and x have different batch_shape, and are broadcast. 546 x = ... Shape [1, 2, 3] 547 operator.matmul(x) 548 ==> 5 * x 549 ``` 550 551 ### Shape compatibility 552 553 This operator acts on [batch] matrix with compatible shape. 554 `x` is a batch matrix with compatible shape for `matmul` and `solve` if 555 556 ``` 557 operator.shape = [B1,...,Bb] + [N, N], with b >= 0 558 x.shape = [C1,...,Cc] + [N, R], 559 and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] 560 ``` 561 562 ### Performance 563 564 * `operator.matmul(x)` is `O(D1*...*Dd*N*R)` 565 * `operator.solve(x)` is `O(D1*...*Dd*N*R)` 566 * `operator.determinant()` is `O(D1*...*Dd)` 567 568 #### Matrix property hints 569 570 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 571 for `X = non_singular, self_adjoint, positive_definite, square`. 572 These have the following meaning 573 * If `is_X == True`, callers should expect the operator to have the 574 property `X`. This is a promise that should be fulfilled, but is *not* a 575 runtime assert. For example, finite floating point precision may result 576 in these promises being violated. 577 * If `is_X == False`, callers should expect the operator to not have `X`. 578 * If `is_X == None` (the default), callers should have no expectation either 579 way. 580 """ 581 582 def __init__(self, 583 num_rows, 584 multiplier, 585 is_non_singular=None, 586 is_self_adjoint=None, 587 is_positive_definite=None, 588 is_square=True, 589 assert_proper_shapes=False, 590 name="LinearOperatorScaledIdentity"): 591 r"""Initialize a `LinearOperatorScaledIdentity`. 592 593 The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which 594 determines the size of each identity matrix, and a `multiplier`, 595 which defines `dtype`, batch shape, and scale of each matrix. 596 597 This operator is able to broadcast the leading (batch) dimensions. 598 599 Args: 600 num_rows: Scalar non-negative integer `Tensor`. Number of rows in the 601 corresponding identity matrix. 602 multiplier: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar). 603 is_non_singular: Expect that this operator is non-singular. 604 is_self_adjoint: Expect that this operator is equal to its hermitian 605 transpose. 606 is_positive_definite: Expect that this operator is positive definite, 607 meaning the quadratic form `x^H A x` has positive real part for all 608 nonzero `x`. Note that we do not require the operator to be 609 self-adjoint to be positive-definite. See: 610 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 611 is_square: Expect that this operator acts like square [batch] matrices. 612 assert_proper_shapes: Python `bool`. If `False`, only perform static 613 checks that initialization and method arguments have proper shape. 614 If `True`, and static checks are inconclusive, add asserts to the graph. 615 name: A name for this `LinearOperator` 616 617 Raises: 618 ValueError: If `num_rows` is determined statically to be non-scalar, or 619 negative. 620 """ 621 parameters = dict( 622 num_rows=num_rows, 623 multiplier=multiplier, 624 is_non_singular=is_non_singular, 625 is_self_adjoint=is_self_adjoint, 626 is_positive_definite=is_positive_definite, 627 is_square=is_square, 628 assert_proper_shapes=assert_proper_shapes, 629 name=name) 630 631 self._assert_proper_shapes = assert_proper_shapes 632 633 with ops.name_scope(name, values=[multiplier, num_rows]): 634 self._multiplier = linear_operator_util.convert_nonref_to_tensor( 635 multiplier, name="multiplier") 636 637 # Check and auto-set hints. 638 if not self._multiplier.dtype.is_complex: 639 if is_self_adjoint is False: # pylint: disable=g-bool-id-comparison 640 raise ValueError("A real diagonal operator is always self adjoint.") 641 else: 642 is_self_adjoint = True 643 644 if not is_square: 645 raise ValueError("A ScaledIdentity operator is always square.") 646 647 linear_operator_util.assert_not_ref_type(num_rows, "num_rows") 648 649 super(LinearOperatorScaledIdentity, self).__init__( 650 dtype=self._multiplier.dtype.base_dtype, 651 is_non_singular=is_non_singular, 652 is_self_adjoint=is_self_adjoint, 653 is_positive_definite=is_positive_definite, 654 is_square=is_square, 655 parameters=parameters, 656 name=name) 657 658 self._num_rows = linear_operator_util.shape_tensor( 659 num_rows, name="num_rows") 660 self._num_rows_static = tensor_util.constant_value(self._num_rows) 661 self._check_num_rows_possibly_add_asserts() 662 self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype) 663 self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows, 664 self.dtype.real_dtype) 665 666 def _shape(self): 667 matrix_shape = tensor_shape.TensorShape((self._num_rows_static, 668 self._num_rows_static)) 669 670 batch_shape = self.multiplier.shape 671 return batch_shape.concatenate(matrix_shape) 672 673 def _shape_tensor(self): 674 matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0) 675 676 batch_shape = array_ops.shape(self.multiplier) 677 return array_ops.concat((batch_shape, matrix_shape), 0) 678 679 def _assert_non_singular(self): 680 return check_ops.assert_positive( 681 math_ops.abs(self.multiplier), message="LinearOperator was singular") 682 683 def _assert_positive_definite(self): 684 return check_ops.assert_positive( 685 math_ops.real(self.multiplier), 686 message="LinearOperator was not positive definite.") 687 688 def _assert_self_adjoint(self): 689 imag_multiplier = math_ops.imag(self.multiplier) 690 return check_ops.assert_equal( 691 array_ops.zeros_like(imag_multiplier), 692 imag_multiplier, 693 message="LinearOperator was not self-adjoint") 694 695 def _make_multiplier_matrix(self, conjugate=False): 696 # Shape [B1,...Bb, 1, 1] 697 multiplier_matrix = array_ops.expand_dims( 698 array_ops.expand_dims(self.multiplier, -1), -1) 699 if conjugate: 700 multiplier_matrix = math_ops.conj(multiplier_matrix) 701 return multiplier_matrix 702 703 def _matmul(self, x, adjoint=False, adjoint_arg=False): 704 x = linalg.adjoint(x) if adjoint_arg else x 705 if self._assert_proper_shapes: 706 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) 707 x = control_flow_ops.with_dependencies([aps], x) 708 return x * self._make_multiplier_matrix(conjugate=adjoint) 709 710 def _determinant(self): 711 return self.multiplier**self._num_rows_cast_to_dtype 712 713 def _log_abs_determinant(self): 714 return self._num_rows_cast_to_real_dtype * math_ops.log( 715 math_ops.abs(self.multiplier)) 716 717 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 718 rhs = linalg.adjoint(rhs) if adjoint_arg else rhs 719 if self._assert_proper_shapes: 720 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs) 721 rhs = control_flow_ops.with_dependencies([aps], rhs) 722 return rhs / self._make_multiplier_matrix(conjugate=adjoint) 723 724 def _trace(self): 725 # Get Tensor of all ones of same shape as self.batch_shape. 726 if self.batch_shape.is_fully_defined(): 727 batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype) 728 else: 729 batch_of_ones = array_ops.ones( 730 shape=self.batch_shape_tensor(), dtype=self.dtype) 731 732 if self._min_matrix_dim() is not None: 733 return self.multiplier * self._min_matrix_dim() * batch_of_ones 734 else: 735 return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(), 736 self.dtype) * batch_of_ones) 737 738 def _diag_part(self): 739 return self._ones_diag() * self.multiplier[..., array_ops.newaxis] 740 741 def add_to_tensor(self, mat, name="add_to_tensor"): 742 """Add matrix represented by this operator to `mat`. Equiv to `I + mat`. 743 744 Args: 745 mat: `Tensor` with same `dtype` and shape broadcastable to `self`. 746 name: A name to give this `Op`. 747 748 Returns: 749 A `Tensor` with broadcast shape and same `dtype` as `self`. 750 """ 751 with self._name_scope(name): # pylint: disable=not-callable 752 # Shape [B1,...,Bb, 1] 753 multiplier_vector = array_ops.expand_dims(self.multiplier, -1) 754 755 # Shape [C1,...,Cc, M, M] 756 mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat") 757 758 # Shape [C1,...,Cc, M] 759 mat_diag = array_ops.matrix_diag_part(mat) 760 761 # multiplier_vector broadcasts here. 762 new_diag = multiplier_vector + mat_diag 763 764 return array_ops.matrix_set_diag(mat, new_diag) 765 766 def _eigvals(self): 767 return self._ones_diag() * self.multiplier[..., array_ops.newaxis] 768 769 def _cond(self): 770 # Condition number for a scalar time identity matrix is one, except when the 771 # scalar is zero. 772 return array_ops.where_v2( 773 math_ops.equal(self._multiplier, 0.), 774 math_ops.cast(np.nan, dtype=self.dtype), 775 math_ops.cast(1., dtype=self.dtype)) 776 777 @property 778 def multiplier(self): 779 """The [batch] scalar `Tensor`, `c` in `cI`.""" 780 return self._multiplier 781 782 @property 783 def _composite_tensor_prefer_static_fields(self): 784 return ("num_rows",) 785 786 @property 787 def _composite_tensor_fields(self): 788 return ("num_rows", "multiplier", "assert_proper_shapes") 789