# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like the identity matrix."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorIdentity",
    "LinearOperatorScaledIdentity",
]


class BaseLinearOperatorIdentity(linear_operator.LinearOperator):
  """Base class for Identity operators.

  Subclasses are expected to set `self._assert_proper_shapes`,
  `self._num_rows` and `self._num_rows_static` before calling
  `_check_num_rows_possibly_add_asserts`.
  """

  def _check_num_rows_possibly_add_asserts(self):
    """Static check of init arg `num_rows`, possibly add asserts.

    If `self._assert_proper_shapes` is `True`, `self._num_rows` is replaced by
    a version that carries runtime rank-0 and non-negativity assertions.

    Raises:
      TypeError: If `num_rows` does not have an integer dtype.
      ValueError: If the statically known value of `num_rows` is not a
        scalar, or is negative.
    """
    # Possibly add asserts.
    if self._assert_proper_shapes:
      self._num_rows = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._num_rows,
              0,
              message="Argument num_rows must be a 0-D Tensor."),
          check_ops.assert_non_negative(
              self._num_rows,
              message="Argument num_rows must be non-negative."),
      ], self._num_rows)

    # Static checks.
    if not self._num_rows.dtype.is_integer:
      raise TypeError("Argument num_rows must be integer type. Found:"
                      " %s" % self._num_rows)

    num_rows_static = self._num_rows_static

    if num_rows_static is None:
      return  # Cannot do any other static checks.

    if num_rows_static.ndim != 0:
      raise ValueError("Argument num_rows must be a 0-D Tensor. Found:"
                       " %s" % num_rows_static)

    if num_rows_static < 0:
      raise ValueError("Argument num_rows must be non-negative. Found:"
                       " %s" % num_rows_static)

  def _min_matrix_dim(self):
    """Minimum of domain/range dimension, if statically available, else None."""
    domain_dim = tensor_shape.dimension_value(self.domain_dimension)
    range_dim = tensor_shape.dimension_value(self.range_dimension)
    if domain_dim is None or range_dim is None:
      return None
    return min(domain_dim, range_dim)

  def _min_matrix_dim_tensor(self):
    """Minimum of domain/range dimension, as a tensor."""
    return math_ops.reduce_min(self.shape_tensor()[-2:])

  def _ones_diag(self):
    """Returns the diagonal of this operator as all ones.

    The result has shape `batch_shape + [min(domain_dim, range_dim)]`. The
    shape is built statically when fully known, else dynamically.
    """
    if self.shape.is_fully_defined():
      d_shape = self.batch_shape.concatenate([self._min_matrix_dim()])
    else:
      d_shape = array_ops.concat(
          [self.batch_shape_tensor(),
           [self._min_matrix_dim_tensor()]], axis=0)

    return array_ops.ones(shape=d_shape, dtype=self.dtype)


@tf_export("linalg.LinearOperatorIdentity")
@linear_operator.make_composite_tensor
class LinearOperatorIdentity(BaseLinearOperatorIdentity):
  """`LinearOperator` acting like a [batch] square identity matrix.

  This operator acts like a [batch] identity matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  `LinearOperatorIdentity` is initialized with `num_rows`, and optionally
  `batch_shape`, and `dtype` arguments.  If `batch_shape` is `None`, this
  operator efficiently passes through all arguments.  If `batch_shape` is
  provided, broadcasting may occur, which will require making copies.

  ```python
  # Create a 2 x 2 identity matrix.
  operator = LinearOperatorIdentity(num_rows=2, dtype=tf.float32)

  operator.to_dense()
  ==> [[1., 0.]
       [0., 1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 0.

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor, same as x.

  y = tf.random.normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  # This broadcast does NOT require copying data, since we can infer that y
  # will be passed through without changing shape.  We are always able to infer
  # this if the operator has no batch_shape.
  x = operator.solve(y)
  ==> Shape [3, 2, 4] Tensor, same as y.

  # Create a 2-batch of 2x2 identity matrices
  operator = LinearOperatorIdentity(num_rows=2, batch_shape=[2])
  operator.to_dense()
  ==> [[[1., 0.]
        [0., 1.]],
       [[1., 0.]
        [0., 1.]]]

  # Here, even though the operator has a batch shape, the input is the same as
  # the output, so x can be passed through without a copy.  The operator is able
  # to detect that no broadcast is necessary because both x and the operator
  # have statically defined shape.
  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, same as x

  # Here the operator and x have different batch_shape, and are broadcast.
  # This requires a copy, since the output is different size than the input.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> Shape [2, 2, 3] Tensor, equal to [x, x]
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  If `batch_shape` initialization arg is `None`:

  * `operator.matmul(x)` is `O(1)`
  * `operator.solve(x)` is `O(1)`
  * `operator.determinant()` is `O(1)`

  If `batch_shape` initialization arg is provided, and static checks cannot
  rule out the need to broadcast:

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(B1*...*Bb)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               num_rows,
               batch_shape=None,
               dtype=None,
               is_non_singular=True,
               is_self_adjoint=True,
               is_positive_definite=True,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorIdentity"):
    r"""Initialize a `LinearOperatorIdentity`.

    The `LinearOperatorIdentity` is initialized with arguments defining `dtype`
    and shape.

    This operator is able to broadcast the leading (batch) dimensions, which
    sometimes requires copying data.  If `batch_shape` is `None`, the operator
    can take arguments of any batch shape without copying.  See examples.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding identity matrix.
      batch_shape:  Optional `1-D` integer `Tensor`.  The shape of the leading
        dimensions.  If `None`, this operator has no leading dimensions.
      dtype:  Data type of the matrix that this operator represents.
        Defaults to `float32` when `None`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `batch_shape` is determined statically to not be 1-D, or
        negative.
      ValueError:  If any of the following is not `True`:
        `{is_self_adjoint, is_non_singular, is_positive_definite, is_square}`.
      TypeError:  If `num_rows` or `batch_shape` is ref-type (e.g. Variable),
        or does not have an integer dtype.
    """
    parameters = dict(
        num_rows=num_rows,
        batch_shape=batch_shape,
        dtype=dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name)

    dtype = dtype or dtypes.float32
    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name):
      dtype = dtypes.as_dtype(dtype)
      if not is_self_adjoint:
        raise ValueError("An identity operator is always self adjoint.")
      if not is_non_singular:
        raise ValueError("An identity operator is always non-singular.")
      if not is_positive_definite:
        raise ValueError("An identity operator is always positive-definite.")
      if not is_square:
        raise ValueError("An identity operator is always square.")

      super(LinearOperatorIdentity, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
      linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
      self._check_num_rows_possibly_add_asserts()

      if batch_shape is None:
        self._batch_shape_arg = None
      else:
        self._batch_shape_arg = linear_operator_util.shape_tensor(
            batch_shape, name="batch_shape_arg")
        self._batch_shape_static = tensor_util.constant_value(
            self._batch_shape_arg)
        self._check_batch_shape_possibly_add_asserts()

  def _shape(self):
    """Static shape: `[batch_shape] + [num_rows, num_rows]`."""
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_rows_static))
    if self._batch_shape_arg is None:
      return matrix_shape

    batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape: `[batch_shape] + [num_rows, num_rows]` as a Tensor."""
    matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0)
    if self._batch_shape_arg is None:
      return matrix_shape

    return array_ops.concat((self._batch_shape_arg, matrix_shape), 0)

  def _assert_non_singular(self):
    # The identity is non-singular by construction; nothing to check.
    return control_flow_ops.no_op("assert_non_singular")

  def _assert_positive_definite(self):
    # The identity is positive definite by construction; nothing to check.
    return control_flow_ops.no_op("assert_positive_definite")

  def _assert_self_adjoint(self):
    # The identity is self-adjoint by construction; nothing to check.
    return control_flow_ops.no_op("assert_self_adjoint")

  def _possibly_broadcast_batch_shape(self, x):
    """Return 'x', possibly after broadcasting the leading dimensions."""
    # If we have no batch shape, our batch shape broadcasts with everything!
    if self._batch_shape_arg is None:
      return x

    # Static attempt:
    #   If we determine that no broadcast is necessary, pass x through.
    #   If we need a broadcast, add to an array of zeros.
    #
    # special_shape is the shape that, when broadcast with x's shape, will give
    # the correct broadcast_shape.  Note that:
    #   We have already verified the second to last dimension of self.shape
    #   matches x's shape in assert_compatible_matrix_dimensions.
    #   Also, the final dimension of 'x' can have any shape.
    #   Therefore, the final two dimensions of special_shape are 1's.
    special_shape = self.batch_shape.concatenate([1, 1])
    bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
    if special_shape.is_fully_defined():
      # When x has known rank, bshape == x.shape proves that broadcasting
      # against special_shape cannot change x's shape, so x passes through.
      # When x has unknown rank, bshape has unknown rank too and the two
      # compare equal without proving anything, so fall through to the
      # (always correct) broadcast-by-addition below.
      if x.shape.rank is not None and bshape == x.shape:
        return x
      # Use the built in broadcasting of addition.
      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
      return x + zeros

    # Dynamic broadcast:
    #   Always add to an array of zeros, rather than using a "cond", since a
    #   cond would require copying data from GPU --> CPU.
    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Note that adjoint has no effect since this matrix is self-adjoint.
    x = linalg.adjoint(x) if adjoint_arg else x
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    return self._possibly_broadcast_batch_shape(x)

  def _determinant(self):
    return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _log_abs_determinant(self):
    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # The identity is its own inverse (and adjoint), so solve == matmul.
    return self._matmul(rhs, adjoint_arg=adjoint_arg)

  def _trace(self):
    # Get Tensor of all ones of same shape as self.batch_shape.
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    if self._min_matrix_dim() is not None:
      return self._min_matrix_dim() * batch_of_ones
    else:
      return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) *
              batch_of_ones)

  def _diag_part(self):
    return self._ones_diag()

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
      mat_diag = array_ops.matrix_diag_part(mat)
      new_diag = 1 + mat_diag
      return array_ops.matrix_set_diag(mat, new_diag)

  def _eigvals(self):
    return self._ones_diag()

  def _cond(self):
    return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)

  def _check_batch_shape_possibly_add_asserts(self):
    """Static check of init arg `batch_shape`, possibly add asserts.

    If `self._assert_proper_shapes` is `True`, `self._batch_shape_arg` is
    replaced by a version that carries runtime rank-1 and non-negativity
    assertions.

    Raises:
      TypeError: If `batch_shape` does not have an integer dtype.
      ValueError: If the statically known value of `batch_shape` is not 1-D,
        or has a negative entry.
    """
    if self._batch_shape_arg is None:
      return

    # Possibly add asserts
    if self._assert_proper_shapes:
      self._batch_shape_arg = control_flow_ops.with_dependencies([
          check_ops.assert_rank(
              self._batch_shape_arg,
              1,
              message="Argument batch_shape must be a 1-D Tensor."),
          check_ops.assert_non_negative(
              self._batch_shape_arg,
              message="Argument batch_shape must be non-negative."),
      ], self._batch_shape_arg)

    # Static checks
    if not self._batch_shape_arg.dtype.is_integer:
      raise TypeError("Argument batch_shape must be integer type. Found:"
                      " %s" % self._batch_shape_arg)

    if self._batch_shape_static is None:
      return  # Cannot do any other static checks.

    if self._batch_shape_static.ndim != 1:
      raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:"
                       " %s" % self._batch_shape_static)

    if np.any(self._batch_shape_static < 0):
      raise ValueError("Argument batch_shape must be non-negative. Found:"
                       " %s" % self._batch_shape_static)

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows", "batch_shape")

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "batch_shape", "dtype", "assert_proper_shapes")

  def __getitem__(self, slices):
    """Slice the batch dimensions, returning a new `LinearOperatorIdentity`."""
    # Use a proxy Tensor of ones with the operator's batch shape, slice it,
    # and use the sliced proxy's shape as the new batch shape.
    batch_shape = self._batch_shape_arg
    if batch_shape is None:
      # No batch_shape was given, i.e. the batch shape is empty.  Without
      # this, `array_ops.ones(None)` below would raise.
      batch_shape = []
    new_batch_shape = array_ops.shape(array_ops.ones(batch_shape)[slices])
    parameters = dict(self.parameters, batch_shape=new_batch_shape)
    return LinearOperatorIdentity(**parameters)


@tf_export("linalg.LinearOperatorScaledIdentity")
@linear_operator.make_composite_tensor
class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
  """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`.

  This operator acts like a scaled [batch] identity matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  a scaled version of the `N x N` identity matrix.

  `LinearOperatorScaledIdentity` is initialized with `num_rows`, and a
  `multiplier` (a `Tensor`) of shape `[B1,...,Bb]`.  `N` is set to `num_rows`,
  and the `multiplier` determines the scale for each batch member.

  ```python
  # Create a 2 x 2 scaled identity matrix.
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)

  operator.to_dense()
  ==> [[3., 0.]
       [0., 3.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> 2 * Log[3]

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> 3 * x

  y = tf.random.normal(shape=[3, 2, 4])
  # Note that y.shape is compatible with operator.shape because operator.shape
  # is broadcast to [3, 2, 2].
  x = operator.solve(y)
  ==> y / 3

  # Create a 2-batch of 2x2 scaled identity matrices
  operator = LinearOperatorScaledIdentity(num_rows=2, multiplier=[5., 5.])
  operator.to_dense()
  ==> [[[5., 0.]
        [0., 5.]],
       [[5., 0.]
        [0., 5.]]]

  x = ... Shape [2, 2, 3]
  operator.matmul(x)
  ==> 5 * x

  # Here the operator and x have different batch_shape, and are broadcast.
  x = ... Shape [1, 2, 3]
  operator.matmul(x)
  ==> 5 * x, broadcast to Shape [2, 2, 3]
  ```

  ### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  ### Performance

  * `operator.matmul(x)` is `O(D1*...*Dd*N*R)`
  * `operator.solve(x)` is `O(D1*...*Dd*N*R)`
  * `operator.determinant()` is `O(B1*...*Bb)`

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               num_rows,
               multiplier,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=True,
               assert_proper_shapes=False,
               name="LinearOperatorScaledIdentity"):
    r"""Initialize a `LinearOperatorScaledIdentity`.

    The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
    determines the size of each identity matrix, and a `multiplier`,
    which defines `dtype`, batch shape, and scale of each matrix.

    This operator is able to broadcast the leading (batch) dimensions.

    Args:
      num_rows:  Scalar non-negative integer `Tensor`.  Number of rows in the
        corresponding identity matrix.
      multiplier:  `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar).
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  Automatically set to `True` for a real `multiplier`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      assert_proper_shapes:  Python `bool`.  If `False`, only perform static
        checks that initialization and method arguments have proper shape.
        If `True`, and static checks are inconclusive, add asserts to the graph.
      name: A name for this `LinearOperator`

    Raises:
      ValueError:  If `num_rows` is determined statically to be non-scalar, or
        negative.
      ValueError:  If `is_square` is `False`, or if `multiplier` is real and
        `is_self_adjoint` is `False`.
      TypeError:  If `num_rows` is ref-type (e.g. Variable), or does not have
        an integer dtype.
    """
    parameters = dict(
        num_rows=num_rows,
        multiplier=multiplier,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        assert_proper_shapes=assert_proper_shapes,
        name=name)

    self._assert_proper_shapes = assert_proper_shapes

    with ops.name_scope(name, values=[multiplier, num_rows]):
      self._multiplier = linear_operator_util.convert_nonref_to_tensor(
          multiplier, name="multiplier")

      # Check and auto-set hints.
      if not self._multiplier.dtype.is_complex:
        if is_self_adjoint is False:  # pylint: disable=g-bool-id-comparison
          raise ValueError("A real diagonal operator is always self adjoint.")
        else:
          is_self_adjoint = True

      if not is_square:
        raise ValueError("A ScaledIdentity operator is always square.")

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")

      super(LinearOperatorScaledIdentity, self).__init__(
          dtype=self._multiplier.dtype.base_dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
      self._check_num_rows_possibly_add_asserts()
      # Cached casts of num_rows, used by _determinant / _log_abs_determinant.
      self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype)
      self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows,
                                                        self.dtype.real_dtype)

  def _shape(self):
    """Static shape: `multiplier.shape + [num_rows, num_rows]`."""
    matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
                                             self._num_rows_static))

    batch_shape = self.multiplier.shape
    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic shape: `shape(multiplier) + [num_rows, num_rows]`."""
    matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0)

    batch_shape = array_ops.shape(self.multiplier)
    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _assert_non_singular(self):
    return check_ops.assert_positive(
        math_ops.abs(self.multiplier), message="LinearOperator was singular")

  def _assert_positive_definite(self):
    return check_ops.assert_positive(
        math_ops.real(self.multiplier),
        message="LinearOperator was not positive definite.")

  def _assert_self_adjoint(self):
    imag_multiplier = math_ops.imag(self.multiplier)
    return check_ops.assert_equal(
        array_ops.zeros_like(imag_multiplier),
        imag_multiplier,
        message="LinearOperator was not self-adjoint")

  def _make_multiplier_matrix(self, conjugate=False):
    """Returns `multiplier` reshaped to `[B1,...,Bb, 1, 1]` for broadcasting.

    Args:
      conjugate: Python `bool`.  If `True`, return the complex conjugate.
    """
    # Shape [B1,...Bb, 1, 1]
    multiplier_matrix = array_ops.expand_dims(
        array_ops.expand_dims(self.multiplier, -1), -1)
    if conjugate:
      multiplier_matrix = math_ops.conj(multiplier_matrix)
    return multiplier_matrix

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    x = linalg.adjoint(x) if adjoint_arg else x
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    return x * self._make_multiplier_matrix(conjugate=adjoint)

  def _determinant(self):
    # det(c I_N) = c**N.
    return self.multiplier**self._num_rows_cast_to_dtype

  def _log_abs_determinant(self):
    # log|det(c I_N)| = N * log|c|.
    return self._num_rows_cast_to_real_dtype * math_ops.log(
        math_ops.abs(self.multiplier))

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._assert_proper_shapes:
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
      rhs = control_flow_ops.with_dependencies([aps], rhs)
    return rhs / self._make_multiplier_matrix(conjugate=adjoint)

  def _trace(self):
    # Get Tensor of all ones of same shape as self.batch_shape.
    if self.batch_shape.is_fully_defined():
      batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype)
    else:
      batch_of_ones = array_ops.ones(
          shape=self.batch_shape_tensor(), dtype=self.dtype)

    if self._min_matrix_dim() is not None:
      return self.multiplier * self._min_matrix_dim() * batch_of_ones
    else:
      return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(),
                                              self.dtype) * batch_of_ones)

  def _diag_part(self):
    return self._ones_diag() * self.multiplier[..., array_ops.newaxis]

  def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `cI + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Shape [B1,...,Bb, 1]
      multiplier_vector = array_ops.expand_dims(self.multiplier, -1)

      # Shape [C1,...,Cc, M, M]
      mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")

      # Shape [C1,...,Cc, M]
      mat_diag = array_ops.matrix_diag_part(mat)

      # multiplier_vector broadcasts here.
      new_diag = multiplier_vector + mat_diag

      return array_ops.matrix_set_diag(mat, new_diag)

  def _eigvals(self):
    return self._ones_diag() * self.multiplier[..., array_ops.newaxis]

  def _cond(self):
    # Condition number for a scalar times identity matrix is one, except when
    # the scalar is zero, in which case it is reported as NaN.
    return array_ops.where_v2(
        math_ops.equal(self._multiplier, 0.),
        math_ops.cast(np.nan, dtype=self.dtype),
        math_ops.cast(1., dtype=self.dtype))

  @property
  def multiplier(self):
    """The [batch] scalar `Tensor`, `c` in `cI`."""
    return self._multiplier

  @property
  def _composite_tensor_prefer_static_fields(self):
    return ("num_rows",)

  @property
  def _composite_tensor_fields(self):
    return ("num_rows", "multiplier", "assert_proper_shapes")

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"multiplier": 0}