1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base class for linear operators.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import contextlib 23 24import numpy as np 25import six 26 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import check_ops 32from tensorflow.python.ops import linalg_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops.linalg import linalg_impl as linalg 35from tensorflow.python.ops.linalg import linear_operator_algebra 36from tensorflow.python.ops.linalg import linear_operator_util 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.util.tf_export import tf_export 39 40__all__ = ["LinearOperator"] 41 42 43# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices. 44@tf_export("linalg.LinearOperator") 45@six.add_metaclass(abc.ABCMeta) 46class LinearOperator(object): 47 """Base class defining a [batch of] linear operator[s]. 48 49 Subclasses of `LinearOperator` provide access to common methods on a 50 (batch) matrix, without the need to materialize the matrix. This allows: 51 52 * Matrix free computations 53 * Operators that take advantage of special structure, while providing a 54 consistent API to users. 55 56 #### Subclassing 57 58 To enable a public method, subclasses should implement the leading-underscore 59 version of the method. The argument signature should be identical except for 60 the omission of `name="..."`. For example, to enable 61 `matmul(x, adjoint=False, name="matmul")` a subclass should implement 62 `_matmul(x, adjoint=False)`. 63 64 #### Performance contract 65 66 Subclasses should only implement the assert methods 67 (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)` 68 time. 69 70 Class docstrings should contain an explanation of computational complexity. 71 Since this is a high-performance library, attention should be paid to detail, 72 and explanations can include constants as well as Big-O notation. 73 74 #### Shape compatibility 75 76 `LinearOperator` subclasses should operate on a [batch] matrix with 77 compatible shape. Class docstrings should define what is meant by compatible 78 shape. Some subclasses may not support batching. 79 80 Examples: 81 82 `x` is a batch matrix with compatible shape for `matmul` if 83 84 ``` 85 operator.shape = [B1,...,Bb] + [M, N], b >= 0, 86 x.shape = [B1,...,Bb] + [N, R] 87 ``` 88 89 `rhs` is a batch matrix with compatible shape for `solve` if 90 91 ``` 92 operator.shape = [B1,...,Bb] + [M, N], b >= 0, 93 rhs.shape = [B1,...,Bb] + [M, R] 94 ``` 95 96 #### Example docstring for subclasses. 97 98 This operator acts like a (batch) matrix `A` with shape 99 `[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a 100 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is 101 an `m x n` matrix. Again, this matrix `A` may not be materialized, but for 102 purposes of identifying and working with compatible arguments the shape is 103 relevant. 104 105 Examples: 106 107 ```python 108 some_tensor = ... shape = ???? 109 operator = MyLinOp(some_tensor) 110 111 operator.shape() 112 ==> [2, 4, 4] 113 114 operator.log_abs_determinant() 115 ==> Shape [2] Tensor 116 117 x = ... Shape [2, 4, 5] Tensor 118 119 operator.matmul(x) 120 ==> Shape [2, 4, 5] Tensor 121 ``` 122 123 #### Shape compatibility 124 125 This operator acts on batch matrices with compatible shape. 126 FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE 127 128 #### Performance 129 130 FILL THIS IN 131 132 #### Matrix property hints 133 134 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 135 for `X = non_singular, self_adjoint, positive_definite, square`. 136 These have the following meaning: 137 138 * If `is_X == True`, callers should expect the operator to have the 139 property `X`. This is a promise that should be fulfilled, but is *not* a 140 runtime assert. For example, finite floating point precision may result 141 in these promises being violated. 142 * If `is_X == False`, callers should expect the operator to not have `X`. 143 * If `is_X == None` (the default), callers should have no expectation either 144 way. 145 """ 146 147 def __init__(self, 148 dtype, 149 graph_parents=None, 150 is_non_singular=None, 151 is_self_adjoint=None, 152 is_positive_definite=None, 153 is_square=None, 154 name=None): 155 r"""Initialize the `LinearOperator`. 156 157 **This is a private method for subclass use.** 158 **Subclasses should copy-paste this `__init__` documentation.** 159 160 Args: 161 dtype: The type of the this `LinearOperator`. Arguments to `matmul` and 162 `solve` will have to be this type. 163 graph_parents: Python list of graph prerequisites of this `LinearOperator` 164 Typically tensors that are passed during initialization. 165 is_non_singular: Expect that this operator is non-singular. 166 is_self_adjoint: Expect that this operator is equal to its hermitian 167 transpose. If `dtype` is real, this is equivalent to being symmetric. 168 is_positive_definite: Expect that this operator is positive definite, 169 meaning the quadratic form `x^H A x` has positive real part for all 170 nonzero `x`. Note that we do not require the operator to be 171 self-adjoint to be positive-definite. See: 172 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 173 is_square: Expect that this operator acts like square [batch] matrices. 174 name: A name for this `LinearOperator`. 175 176 Raises: 177 ValueError: If any member of graph_parents is `None` or not a `Tensor`. 178 ValueError: If hints are set incorrectly. 179 """ 180 # Check and auto-set flags. 181 if is_positive_definite: 182 if is_non_singular is False: 183 raise ValueError("A positive definite matrix is always non-singular.") 184 is_non_singular = True 185 186 if is_non_singular: 187 if is_square is False: 188 raise ValueError("A non-singular matrix is always square.") 189 is_square = True 190 191 if is_self_adjoint: 192 if is_square is False: 193 raise ValueError("A self-adjoint matrix is always square.") 194 is_square = True 195 196 self._is_square_set_or_implied_by_hints = is_square 197 198 graph_parents = [] if graph_parents is None else graph_parents 199 for i, t in enumerate(graph_parents): 200 if t is None or not tensor_util.is_tensor(t): 201 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 202 self._dtype = dtype 203 self._graph_parents = graph_parents 204 self._is_non_singular = is_non_singular 205 self._is_self_adjoint = is_self_adjoint 206 self._is_positive_definite = is_positive_definite 207 self._name = name or type(self).__name__ 208 209 @contextlib.contextmanager 210 def _name_scope(self, name=None, values=None): 211 """Helper function to standardize op scope.""" 212 with ops.name_scope(self.name): 213 with ops.name_scope( 214 name, values=((values or []) + self._graph_parents)) as scope: 215 yield scope 216 217 @property 218 def dtype(self): 219 """The `DType` of `Tensor`s handled by this `LinearOperator`.""" 220 return self._dtype 221 222 @property 223 def name(self): 224 """Name prepended to all ops created by this `LinearOperator`.""" 225 return self._name 226 227 @property 228 def graph_parents(self): 229 """List of graph dependencies of this `LinearOperator`.""" 230 return self._graph_parents 231 232 @property 233 def is_non_singular(self): 234 return self._is_non_singular 235 236 @property 237 def is_self_adjoint(self): 238 return self._is_self_adjoint 239 240 @property 241 def is_positive_definite(self): 242 return self._is_positive_definite 243 244 @property 245 def is_square(self): 246 """Return `True/False` depending on if this operator is square.""" 247 # Static checks done after __init__. Why? Because domain/range dimension 248 # sometimes requires lots of work done in the derived class after init. 249 auto_square_check = self.domain_dimension == self.range_dimension 250 if self._is_square_set_or_implied_by_hints is False and auto_square_check: 251 raise ValueError( 252 "User set is_square hint to False, but the operator was square.") 253 if self._is_square_set_or_implied_by_hints is None: 254 return auto_square_check 255 256 return self._is_square_set_or_implied_by_hints 257 258 @abc.abstractmethod 259 def _shape(self): 260 # Write this in derived class to enable all static shape methods. 261 raise NotImplementedError("_shape is not implemented.") 262 263 @property 264 def shape(self): 265 """`TensorShape` of this `LinearOperator`. 266 267 If this operator acts like the batch matrix `A` with 268 `A.shape = [B1,...,Bb, M, N]`, then this returns 269 `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.get_shape()`. 270 271 Returns: 272 `TensorShape`, statically determined, may be undefined. 273 """ 274 return self._shape() 275 276 @abc.abstractmethod 277 def _shape_tensor(self): 278 raise NotImplementedError("_shape_tensor is not implemented.") 279 280 def shape_tensor(self, name="shape_tensor"): 281 """Shape of this `LinearOperator`, determined at runtime. 282 283 If this operator acts like the batch matrix `A` with 284 `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding 285 `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`. 286 287 Args: 288 name: A name for this `Op`. 289 290 Returns: 291 `int32` `Tensor` 292 """ 293 with self._name_scope(name): 294 # Prefer to use statically defined shape if available. 295 if self.shape.is_fully_defined(): 296 return linear_operator_util.shape_tensor(self.shape.as_list()) 297 else: 298 return self._shape_tensor() 299 300 @property 301 def batch_shape(self): 302 """`TensorShape` of batch dimensions of this `LinearOperator`. 303 304 If this operator acts like the batch matrix `A` with 305 `A.shape = [B1,...,Bb, M, N]`, then this returns 306 `TensorShape([B1,...,Bb])`, equivalent to `A.get_shape()[:-2]` 307 308 Returns: 309 `TensorShape`, statically determined, may be undefined. 310 """ 311 # Derived classes get this "for free" once .shape is implemented. 312 return self.shape[:-2] 313 314 def batch_shape_tensor(self, name="batch_shape_tensor"): 315 """Shape of batch dimensions of this operator, determined at runtime. 316 317 If this operator acts like the batch matrix `A` with 318 `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding 319 `[B1,...,Bb]`. 320 321 Args: 322 name: A name for this `Op`. 323 324 Returns: 325 `int32` `Tensor` 326 """ 327 # Derived classes get this "for free" once .shape() is implemented. 328 with self._name_scope(name): 329 # Prefer to use statically defined shape if available. 330 if self.batch_shape.is_fully_defined(): 331 return linear_operator_util.shape_tensor( 332 self.batch_shape.as_list(), name="batch_shape") 333 else: 334 return self.shape_tensor()[:-2] 335 336 @property 337 def tensor_rank(self, name="tensor_rank"): 338 """Rank (in the sense of tensors) of matrix corresponding to this operator. 339 340 If this operator acts like the batch matrix `A` with 341 `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. 342 343 Args: 344 name: A name for this `Op`. 345 346 Returns: 347 Python integer, or None if the tensor rank is undefined. 348 """ 349 # Derived classes get this "for free" once .shape() is implemented. 350 with self._name_scope(name): 351 return self.shape.ndims 352 353 def tensor_rank_tensor(self, name="tensor_rank_tensor"): 354 """Rank (in the sense of tensors) of matrix corresponding to this operator. 355 356 If this operator acts like the batch matrix `A` with 357 `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. 358 359 Args: 360 name: A name for this `Op`. 361 362 Returns: 363 `int32` `Tensor`, determined at runtime. 364 """ 365 # Derived classes get this "for free" once .shape() is implemented. 366 with self._name_scope(name): 367 # Prefer to use statically defined shape if available. 368 if self.tensor_rank is not None: 369 return ops.convert_to_tensor(self.tensor_rank) 370 else: 371 return array_ops.size(self.shape_tensor()) 372 373 @property 374 def domain_dimension(self): 375 """Dimension (in the sense of vector spaces) of the domain of this operator. 376 377 If this operator acts like the batch matrix `A` with 378 `A.shape = [B1,...,Bb, M, N]`, then this returns `N`. 379 380 Returns: 381 `Dimension` object. 382 """ 383 # Derived classes get this "for free" once .shape is implemented. 384 if self.shape.rank is None: 385 return tensor_shape.Dimension(None) 386 else: 387 return self.shape.dims[-1] 388 389 def domain_dimension_tensor(self, name="domain_dimension_tensor"): 390 """Dimension (in the sense of vector spaces) of the domain of this operator. 391 392 Determined at runtime. 393 394 If this operator acts like the batch matrix `A` with 395 `A.shape = [B1,...,Bb, M, N]`, then this returns `N`. 396 397 Args: 398 name: A name for this `Op`. 399 400 Returns: 401 `int32` `Tensor` 402 """ 403 # Derived classes get this "for free" once .shape() is implemented. 404 with self._name_scope(name): 405 # Prefer to use statically defined shape if available. 406 dim_value = tensor_shape.dimension_value(self.domain_dimension) 407 if dim_value is not None: 408 return ops.convert_to_tensor(dim_value) 409 else: 410 return self.shape_tensor()[-1] 411 412 @property 413 def range_dimension(self): 414 """Dimension (in the sense of vector spaces) of the range of this operator. 415 416 If this operator acts like the batch matrix `A` with 417 `A.shape = [B1,...,Bb, M, N]`, then this returns `M`. 418 419 Returns: 420 `Dimension` object. 421 """ 422 # Derived classes get this "for free" once .shape is implemented. 423 if self.shape.dims: 424 return self.shape.dims[-2] 425 else: 426 return tensor_shape.Dimension(None) 427 428 def range_dimension_tensor(self, name="range_dimension_tensor"): 429 """Dimension (in the sense of vector spaces) of the range of this operator. 430 431 Determined at runtime. 432 433 If this operator acts like the batch matrix `A` with 434 `A.shape = [B1,...,Bb, M, N]`, then this returns `M`. 435 436 Args: 437 name: A name for this `Op`. 438 439 Returns: 440 `int32` `Tensor` 441 """ 442 # Derived classes get this "for free" once .shape() is implemented. 443 with self._name_scope(name): 444 # Prefer to use statically defined shape if available. 445 dim_value = tensor_shape.dimension_value(self.range_dimension) 446 if dim_value is not None: 447 return ops.convert_to_tensor(dim_value) 448 else: 449 return self.shape_tensor()[-2] 450 451 def _assert_non_singular(self): 452 """Private default implementation of _assert_non_singular.""" 453 logging.warn( 454 "Using (possibly slow) default implementation of assert_non_singular." 455 " Requires conversion to a dense matrix and O(N^3) operations.") 456 if self._can_use_cholesky(): 457 return self.assert_positive_definite() 458 else: 459 singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False) 460 # TODO(langmore) Add .eig and .cond as methods. 461 cond = (math_ops.reduce_max(singular_values, axis=-1) / 462 math_ops.reduce_min(singular_values, axis=-1)) 463 return check_ops.assert_less( 464 cond, 465 self._max_condition_number_to_be_non_singular(), 466 message="Singular matrix up to precision epsilon.") 467 468 def _max_condition_number_to_be_non_singular(self): 469 """Return the maximum condition number that we consider nonsingular.""" 470 with ops.name_scope("max_nonsingular_condition_number"): 471 dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps 472 eps = math_ops.cast( 473 math_ops.reduce_max([ 474 100., 475 math_ops.cast(self.range_dimension_tensor(), self.dtype), 476 math_ops.cast(self.domain_dimension_tensor(), self.dtype) 477 ]), self.dtype) * dtype_eps 478 return 1. / eps 479 480 def assert_non_singular(self, name="assert_non_singular"): 481 """Returns an `Op` that asserts this operator is non singular. 482 483 This operator is considered non-singular if 484 485 ``` 486 ConditionNumber < max{100, range_dimension, domain_dimension} * eps, 487 eps := np.finfo(self.dtype.as_numpy_dtype).eps 488 ``` 489 490 Args: 491 name: A string name to prepend to created ops. 492 493 Returns: 494 An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if 495 the operator is singular. 496 """ 497 with self._name_scope(name): 498 return self._assert_non_singular() 499 500 def _assert_positive_definite(self): 501 """Default implementation of _assert_positive_definite.""" 502 logging.warn( 503 "Using (possibly slow) default implementation of " 504 "assert_positive_definite." 505 " Requires conversion to a dense matrix and O(N^3) operations.") 506 # If the operator is self-adjoint, then checking that 507 # Cholesky decomposition succeeds + results in positive diag is necessary 508 # and sufficient. 509 if self.is_self_adjoint: 510 return check_ops.assert_positive( 511 array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())), 512 message="Matrix was not positive definite.") 513 # We have no generic check for positive definite. 514 raise NotImplementedError("assert_positive_definite is not implemented.") 515 516 def assert_positive_definite(self, name="assert_positive_definite"): 517 """Returns an `Op` that asserts this operator is positive definite. 518 519 Here, positive definite means that the quadratic form `x^H A x` has positive 520 real part for all nonzero `x`. Note that we do not require the operator to 521 be self-adjoint to be positive definite. 522 523 Args: 524 name: A name to give this `Op`. 525 526 Returns: 527 An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if 528 the operator is not positive definite. 529 """ 530 with self._name_scope(name): 531 return self._assert_positive_definite() 532 533 def _assert_self_adjoint(self): 534 dense = self.to_dense() 535 logging.warn( 536 "Using (possibly slow) default implementation of assert_self_adjoint." 537 " Requires conversion to a dense matrix.") 538 return check_ops.assert_equal( 539 dense, 540 linalg.adjoint(dense), 541 message="Matrix was not equal to its adjoint.") 542 543 def assert_self_adjoint(self, name="assert_self_adjoint"): 544 """Returns an `Op` that asserts this operator is self-adjoint. 545 546 Here we check that this operator is *exactly* equal to its hermitian 547 transpose. 548 549 Args: 550 name: A string name to prepend to created ops. 551 552 Returns: 553 An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if 554 the operator is not self-adjoint. 555 """ 556 with self._name_scope(name): 557 return self._assert_self_adjoint() 558 559 def _check_input_dtype(self, arg): 560 """Check that arg.dtype == self.dtype.""" 561 if arg.dtype != self.dtype: 562 raise TypeError( 563 "Expected argument to have dtype %s. Found: %s in tensor %s" % 564 (self.dtype, arg.dtype, arg)) 565 566 @abc.abstractmethod 567 def _matmul(self, x, adjoint=False, adjoint_arg=False): 568 raise NotImplementedError("_matmul is not implemented.") 569 570 def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): 571 """Transform [batch] matrix `x` with left multiplication: `x --> Ax`. 572 573 ```python 574 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 575 operator = LinearOperator(...) 576 operator.shape = [..., M, N] 577 578 X = ... # shape [..., N, R], batch matrix, R > 0. 579 580 Y = operator.matmul(X) 581 Y.shape 582 ==> [..., M, R] 583 584 Y[..., :, r] = sum_j A[..., :, j] X[j, r] 585 ``` 586 587 Args: 588 x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as 589 `self`. See class docstring for definition of compatibility. 590 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 591 adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is 592 the hermitian transpose (transposition and complex conjugation). 593 name: A name for this `Op`. 594 595 Returns: 596 A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype` 597 as `self`. 598 """ 599 if isinstance(x, LinearOperator): 600 if adjoint or adjoint_arg: 601 raise ValueError(".matmul not supported with adjoints.") 602 if (x.range_dimension is not None and 603 self.domain_dimension is not None and 604 x.range_dimension != self.domain_dimension): 605 raise ValueError( 606 "Operators are incompatible. Expected `x` to have dimension" 607 " {} but got {}.".format(self.domain_dimension, x.range_dimension)) 608 with self._name_scope(name): 609 return linear_operator_algebra.matmul(self, x) 610 611 with self._name_scope(name, values=[x]): 612 x = ops.convert_to_tensor(x, name="x") 613 self._check_input_dtype(x) 614 615 self_dim = -2 if adjoint else -1 616 arg_dim = -1 if adjoint_arg else -2 617 tensor_shape.dimension_at_index( 618 self.shape, self_dim).assert_is_compatible_with( 619 x.get_shape()[arg_dim]) 620 621 return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) 622 623 def _matvec(self, x, adjoint=False): 624 x_mat = array_ops.expand_dims(x, axis=-1) 625 y_mat = self.matmul(x_mat, adjoint=adjoint) 626 return array_ops.squeeze(y_mat, axis=-1) 627 628 def matvec(self, x, adjoint=False, name="matvec"): 629 """Transform [batch] vector `x` with left multiplication: `x --> Ax`. 630 631 ```python 632 # Make an operator acting like batch matric A. Assume A.shape = [..., M, N] 633 operator = LinearOperator(...) 634 635 X = ... # shape [..., N], batch vector 636 637 Y = operator.matvec(X) 638 Y.shape 639 ==> [..., M] 640 641 Y[..., :] = sum_j A[..., :, j] X[..., j] 642 ``` 643 644 Args: 645 x: `Tensor` with compatible shape and same `dtype` as `self`. 646 `x` is treated as a [batch] vector meaning for every set of leading 647 dimensions, the last dimension defines a vector. 648 See class docstring for definition of compatibility. 649 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 650 name: A name for this `Op`. 651 652 Returns: 653 A `Tensor` with shape `[..., M]` and same `dtype` as `self`. 654 """ 655 with self._name_scope(name, values=[x]): 656 x = ops.convert_to_tensor(x, name="x") 657 self._check_input_dtype(x) 658 self_dim = -2 if adjoint else -1 659 tensor_shape.dimension_at_index( 660 self.shape, self_dim).assert_is_compatible_with(x.get_shape()[-1]) 661 return self._matvec(x, adjoint=adjoint) 662 663 def _determinant(self): 664 logging.warn( 665 "Using (possibly slow) default implementation of determinant." 666 " Requires conversion to a dense matrix and O(N^3) operations.") 667 if self._can_use_cholesky(): 668 return math_ops.exp(self.log_abs_determinant()) 669 return linalg_ops.matrix_determinant(self.to_dense()) 670 671 def determinant(self, name="det"): 672 """Determinant for every batch member. 673 674 Args: 675 name: A name for this `Op`. 676 677 Returns: 678 `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. 679 680 Raises: 681 NotImplementedError: If `self.is_square` is `False`. 682 """ 683 if self.is_square is False: 684 raise NotImplementedError( 685 "Determinant not implemented for an operator that is expected to " 686 "not be square.") 687 with self._name_scope(name): 688 return self._determinant() 689 690 def _log_abs_determinant(self): 691 logging.warn( 692 "Using (possibly slow) default implementation of determinant." 693 " Requires conversion to a dense matrix and O(N^3) operations.") 694 if self._can_use_cholesky(): 695 diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())) 696 return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1]) 697 _, log_abs_det = linalg.slogdet(self.to_dense()) 698 return log_abs_det 699 700 def log_abs_determinant(self, name="log_abs_det"): 701 """Log absolute value of determinant for every batch member. 702 703 Args: 704 name: A name for this `Op`. 705 706 Returns: 707 `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. 708 709 Raises: 710 NotImplementedError: If `self.is_square` is `False`. 711 """ 712 if self.is_square is False: 713 raise NotImplementedError( 714 "Determinant not implemented for an operator that is expected to " 715 "not be square.") 716 with self._name_scope(name): 717 return self._log_abs_determinant() 718 719 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 720 """Default implementation of _solve.""" 721 if self.is_square is False: 722 raise NotImplementedError( 723 "Solve is not yet implemented for non-square operators.") 724 logging.warn( 725 "Using (possibly slow) default implementation of solve." 726 " Requires conversion to a dense matrix and O(N^3) operations.") 727 rhs = linalg.adjoint(rhs) if adjoint_arg else rhs 728 if self._can_use_cholesky(): 729 return linear_operator_util.cholesky_solve_with_broadcast( 730 linalg_ops.cholesky(self.to_dense()), rhs) 731 return linear_operator_util.matrix_solve_with_broadcast( 732 self.to_dense(), rhs, adjoint=adjoint) 733 734 def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): 735 """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. 736 737 The returned `Tensor` will be close to an exact solution if `A` is well 738 conditioned. Otherwise closeness will vary. See class docstring for details. 739 740 Examples: 741 742 ```python 743 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 744 operator = LinearOperator(...) 745 operator.shape = [..., M, N] 746 747 # Solve R > 0 linear systems for every member of the batch. 748 RHS = ... # shape [..., M, R] 749 750 X = operator.solve(RHS) 751 # X[..., :, r] is the solution to the r'th linear system 752 # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r] 753 754 operator.matmul(X) 755 ==> RHS 756 ``` 757 758 Args: 759 rhs: `Tensor` with same `dtype` as this operator and compatible shape. 760 `rhs` is treated like a [batch] matrix meaning for every set of leading 761 dimensions, the last two dimensions defines a matrix. 762 See class docstring for definition of compatibility. 763 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 764 of this `LinearOperator`: `A^H X = rhs`. 765 adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H` 766 is the hermitian transpose (transposition and complex conjugation). 767 name: A name scope to use for ops added by this method. 768 769 Returns: 770 `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. 771 772 Raises: 773 NotImplementedError: If `self.is_non_singular` or `is_square` is False. 774 """ 775 if self.is_non_singular is False: 776 raise NotImplementedError( 777 "Exact solve not implemented for an operator that is expected to " 778 "be singular.") 779 if self.is_square is False: 780 raise NotImplementedError( 781 "Exact solve not implemented for an operator that is expected to " 782 "not be square.") 783 with self._name_scope(name, values=[rhs]): 784 rhs = ops.convert_to_tensor(rhs, name="rhs") 785 self._check_input_dtype(rhs) 786 787 self_dim = -1 if adjoint else -2 788 arg_dim = -1 if adjoint_arg else -2 789 tensor_shape.dimension_at_index( 790 self.shape, self_dim).assert_is_compatible_with( 791 rhs.get_shape()[arg_dim]) 792 793 return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) 794 795 def _solvevec(self, rhs, adjoint=False): 796 """Default implementation of _solvevec.""" 797 rhs_mat = array_ops.expand_dims(rhs, axis=-1) 798 solution_mat = self.solve(rhs_mat, adjoint=adjoint) 799 return array_ops.squeeze(solution_mat, axis=-1) 800 801 def solvevec(self, rhs, adjoint=False, name="solve"): 802 """Solve single equation with best effort: `A X = rhs`. 803 804 The returned `Tensor` will be close to an exact solution if `A` is well 805 conditioned. Otherwise closeness will vary. See class docstring for details. 806 807 Examples: 808 809 ```python 810 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 811 operator = LinearOperator(...) 812 operator.shape = [..., M, N] 813 814 # Solve one linear system for every member of the batch. 815 RHS = ... # shape [..., M] 816 817 X = operator.solvevec(RHS) 818 # X is the solution to the linear system 819 # sum_j A[..., :, j] X[..., j] = RHS[..., :] 820 821 operator.matvec(X) 822 ==> RHS 823 ``` 824 825 Args: 826 rhs: `Tensor` with same `dtype` as this operator. 827 `rhs` is treated like a [batch] vector meaning for every set of leading 828 dimensions, the last dimension defines a vector. See class docstring 829 for definition of compatibility regarding batch dimensions. 830 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 831 of this `LinearOperator`: `A^H X = rhs`. 832 name: A name scope to use for ops added by this method. 833 834 Returns: 835 `Tensor` with shape `[...,N]` and same `dtype` as `rhs`. 836 837 Raises: 838 NotImplementedError: If `self.is_non_singular` or `is_square` is False. 839 """ 840 with self._name_scope(name, values=[rhs]): 841 rhs = ops.convert_to_tensor(rhs, name="rhs") 842 self._check_input_dtype(rhs) 843 self_dim = -1 if adjoint else -2 844 tensor_shape.dimension_at_index( 845 self.shape, self_dim).assert_is_compatible_with( 846 rhs.get_shape()[-1]) 847 848 return self._solvevec(rhs, adjoint=adjoint) 849 850 def adjoint(self, name="adjoint"): 851 """Returns the adjoint of the current `LinearOperator`. 852 853 Given `A` representing this `LinearOperator`, return `A*`. 854 Note that calling `self.adjoint()` and `self.H` are equivalent. 855 856 Args: 857 name: A name for this `Op`. 858 859 Returns: 860 `LinearOperator` which represents the adjoint of this `LinearOperator`. 861 """ 862 if self.is_self_adjoint is True: # pylint: disable=g-bool-id-comparison 863 return self 864 with self._name_scope(name): 865 return linear_operator_algebra.adjoint(self) 866 867 # self.H is equivalent to self.adjoint(). 868 H = property(adjoint, None) 869 870 def inverse(self, name="inverse"): 871 """Returns the Inverse of this `LinearOperator`. 872 873 Given `A` representing this `LinearOperator`, return a `LinearOperator` 874 representing `A^-1`. 875 876 Args: 877 name: A name scope to use for ops added by this method. 878 879 Returns: 880 `LinearOperator` representing inverse of this matrix. 881 882 Raises: 883 ValueError: When the `LinearOperator` is not hinted to be `non_singular`. 884 """ 885 if self.is_square is False: # pylint: disable=g-bool-id-comparison 886 raise ValueError("Cannot take the Inverse: This operator represents " 887 "a non square matrix.") 888 if self.is_non_singular is False: # pylint: disable=g-bool-id-comparison 889 raise ValueError("Cannot take the Inverse: This operator represents " 890 "a singular matrix.") 891 892 with self._name_scope(name): 893 return linear_operator_algebra.inverse(self) 894 895 def cholesky(self, name="cholesky"): 896 """Returns a Cholesky factor as a `LinearOperator`. 897 898 Given `A` representing this `LinearOperator`, if `A` is positive definite 899 self-adjoint, return `L`, where `A = L L^T`, i.e. the cholesky 900 decomposition. 901 902 Args: 903 name: A name for this `Op`. 904 905 Returns: 906 `LinearOperator` which represents the lower triangular matrix 907 in the Cholesky decomposition. 908 909 Raises: 910 ValueError: When the `LinearOperator` is not hinted to be positive 911 definite and self adjoint. 912 """ 913 914 if not self._can_use_cholesky(): 915 raise ValueError("Cannot take the Cholesky decomposition: " 916 "Not a positive definite self adjoint matrix.") 917 with self._name_scope(name): 918 return linear_operator_algebra.cholesky(self) 919 920 def _to_dense(self): 921 """Generic and often inefficient implementation. Override often.""" 922 logging.warn("Using (possibly slow) default implementation of to_dense." 923 " Converts by self.matmul(identity).") 924 if self.batch_shape.is_fully_defined(): 925 batch_shape = self.batch_shape 926 else: 927 batch_shape = self.batch_shape_tensor() 928 929 dim_value = tensor_shape.dimension_value(self.domain_dimension) 930 if dim_value is not None: 931 n = dim_value 932 else: 933 n = self.domain_dimension_tensor() 934 935 eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype) 936 return self.matmul(eye) 937 938 def to_dense(self, name="to_dense"): 939 """Return a dense (batch) matrix representing this operator.""" 940 with self._name_scope(name): 941 return self._to_dense() 942 943 def _diag_part(self): 944 """Generic and often inefficient implementation. Override often.""" 945 return array_ops.matrix_diag_part(self.to_dense()) 946 947 def diag_part(self, name="diag_part"): 948 """Efficiently get the [batch] diagonal part of this operator. 949 950 If this operator has shape `[B1,...,Bb, M, N]`, this returns a 951 `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where 952 `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`. 953 954 ``` 955 my_operator = LinearOperatorDiag([1., 2.]) 956 957 # Efficiently get the diagonal 958 my_operator.diag_part() 959 ==> [1., 2.] 960 961 # Equivalent, but inefficient method 962 tf.matrix_diag_part(my_operator.to_dense()) 963 ==> [1., 2.] 964 ``` 965 966 Args: 967 name: A name for this `Op`. 968 969 Returns: 970 diag_part: A `Tensor` of same `dtype` as self. 971 """ 972 with self._name_scope(name): 973 return self._diag_part() 974 975 def _trace(self): 976 return math_ops.reduce_sum(self.diag_part(), axis=-1) 977 978 def trace(self, name="trace"): 979 """Trace of the linear operator, equal to sum of `self.diag_part()`. 980 981 If the operator is square, this is also the sum of the eigenvalues. 982 983 Args: 984 name: A name for this `Op`. 985 986 Returns: 987 Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`. 988 """ 989 with self._name_scope(name): 990 return self._trace() 991 992 def _add_to_tensor(self, x): 993 # Override if a more efficient implementation is available. 994 return self.to_dense() + x 995 996 def add_to_tensor(self, x, name="add_to_tensor"): 997 """Add matrix represented by this operator to `x`. Equivalent to `A + x`. 998 999 Args: 1000 x: `Tensor` with same `dtype` and shape broadcastable to `self.shape`. 1001 name: A name to give this `Op`. 1002 1003 Returns: 1004 A `Tensor` with broadcast shape and same `dtype` as `self`. 1005 """ 1006 with self._name_scope(name, values=[x]): 1007 x = ops.convert_to_tensor(x, name="x") 1008 self._check_input_dtype(x) 1009 return self._add_to_tensor(x) 1010 1011 def _can_use_cholesky(self): 1012 return self.is_self_adjoint and self.is_positive_definite 1013