# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for linear operators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib

import numpy as np
import six

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperator"]


# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
@tf_export("linalg.LinearOperator")
@six.add_metaclass(abc.ABCMeta)
class LinearOperator(module.Module, composite_tensor.CompositeTensor):
  """Base class defining a [batch of] linear operator[s].

  Subclasses of `LinearOperator` provide access to common methods on a
  (batch) matrix, without the need to materialize the matrix. This allows:

  * Matrix free computations
  * Operators that take advantage of special structure, while providing a
    consistent API to users.

  #### Subclassing

  To enable a public method, subclasses should implement the leading-underscore
  version of the method. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable
  `matmul(x, adjoint=False, name="matmul")` a subclass should implement
  `_matmul(x, adjoint=False)`.

  #### Performance contract

  Subclasses should only implement the assert methods
  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
  time.

  Class docstrings should contain an explanation of computational complexity.
  Since this is a high-performance library, attention should be paid to detail,
  and explanations can include constants as well as Big-O notation.
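
  For illustration only, a minimal sketch of these conventions (the operator
  below is hypothetical, not part of the API; it acts like `multiplier * I`
  and its methods cost `O(N)` or less):

  ```python
  class LinearOperatorMyScaledIdentity(LinearOperator):
    # Acts like `multiplier * I`. `matmul` is `O(N)`.

    def __init__(self, multiplier, num_rows):
      parameters = dict(multiplier=multiplier, num_rows=num_rows)
      self._multiplier = ops.convert_to_tensor_v2_with_dispatch(multiplier)
      self._num_rows = num_rows
      super().__init__(
          dtype=self._multiplier.dtype,
          is_square=True,
          parameters=parameters)

    def _shape(self):
      return tensor_shape.TensorShape([self._num_rows, self._num_rows])

    def _matmul(self, x, adjoint=False, adjoint_arg=False):
      x = linalg.adjoint(x) if adjoint_arg else x
      c = math_ops.conj(self._multiplier) if adjoint else self._multiplier
      return c * x
  ```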

  #### Shape compatibility

  `LinearOperator` subclasses should operate on a [batch] matrix with
  compatible shape. Class docstrings should define what is meant by compatible
  shape. Some subclasses may not support batching.

  Examples:

  `x` is a batch matrix with compatible shape for `matmul` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  x.shape = [B1,...,Bb] + [N, R]
  ```

  `rhs` is a batch matrix with compatible shape for `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  rhs.shape = [B1,...,Bb] + [M, R]
  ```

  #### Example docstring for subclasses.

  This operator acts like a (batch) matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix. Again, this matrix `A` may not be materialized, but for
  purposes of identifying and working with compatible arguments the shape is
  relevant.

  Examples:

  ```python
  some_tensor = ... shape = ????
  operator = MyLinOp(some_tensor)

  operator.shape
  ==> [2, 4, 4]

  operator.log_abs_determinant()
  ==> Shape [2] Tensor

  x = ... Shape [2, 4, 5] Tensor

  operator.matmul(x)
  ==> Shape [2, 4, 5] Tensor
  ```

  #### Shape compatibility

  This operator acts on batch matrices with compatible shape.
  FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE

  #### Performance

  FILL THIS IN

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.

  #### Initialization parameters

  All subclasses of `LinearOperator` are expected to pass a `parameters`
  argument to `super().__init__()`. This should be a `dict` containing
  the unadulterated arguments passed to the subclass `__init__`. For example,
  `MyLinearOperator` with an initializer should look like:

  ```python
  def __init__(self, operator, is_square=False, name=None):
    parameters = dict(
        operator=operator,
        is_square=is_square,
        name=name
    )
    ...
    super().__init__(..., parameters=parameters)
  ```

  Users can then access `my_linear_operator.parameters` to see all arguments
  passed to its initializer.
  """

  # TODO(b/143910018) Remove graph_parents in V3.
  @deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will "
                               "no longer be used.", "graph_parents")
  def __init__(self,
               dtype,
               graph_parents=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None,
               parameters=None):
    r"""Initialize the `LinearOperator`.

    **This is a private method for subclass use.**
    **Subclasses should copy-paste this `__init__` documentation.**

    Args:
      dtype: The type of this `LinearOperator`. Arguments to `matmul` and
        `solve` will have to be this type.
      graph_parents: (Deprecated) Python list of graph prerequisites of this
        `LinearOperator`. Typically tensors that are passed during
        initialization.
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose. If `dtype` is real, this is equivalent to being symmetric.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.
      parameters: Python `dict` of parameters used to instantiate this
        `LinearOperator`.

    Raises:
      ValueError: If any member of graph_parents is `None` or not a `Tensor`.
      ValueError: If hints are set incorrectly.
    """
    # Check and auto-set flags.
    if is_positive_definite:
      if is_non_singular is False:
        raise ValueError("A positive definite matrix is always non-singular.")
      is_non_singular = True

    if is_non_singular:
      if is_square is False:
        raise ValueError("A non-singular matrix is always square.")
      is_square = True

    if is_self_adjoint:
      if is_square is False:
        raise ValueError("A self-adjoint matrix is always square.")
      is_square = True

    self._is_square_set_or_implied_by_hints = is_square

    if graph_parents is not None:
      self._set_graph_parents(graph_parents)
    else:
      self._graph_parents = []
    self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
    self._is_non_singular = is_non_singular
    self._is_self_adjoint = is_self_adjoint
    self._is_positive_definite = is_positive_definite
    self._parameters = self._no_dependency(parameters)
    self._parameters_sanitized = False
    self._name = name or type(self).__name__

  @contextlib.contextmanager
  def _name_scope(self, name=None):  # pylint: disable=method-hidden
    """Helper function to standardize op scope."""
    full_name = self.name
    if name is not None:
      full_name += "/" + name
    with ops.name_scope(full_name) as scope:
      yield scope

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `LinearOperator`."""
    return dict(self._parameters)

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `LinearOperator`."""
    return self._dtype

  @property
  def name(self):
    """Name prepended to all ops created by this `LinearOperator`."""
    return self._name

  @property
  @deprecation.deprecated(None, "Do not call `graph_parents`.")
  def graph_parents(self):
    """List of graph dependencies of this `LinearOperator`."""
    return self._graph_parents

  @property
  def is_non_singular(self):
    return self._is_non_singular

  @property
  def is_self_adjoint(self):
    return self._is_self_adjoint

  @property
  def is_positive_definite(self):
    return self._is_positive_definite

  @property
  def is_square(self):
    """Return `True/False` depending on whether this operator is square."""
    # Static checks done after __init__. Why? Because domain/range dimension
    # sometimes requires lots of work done in the derived class after init.
    auto_square_check = self.domain_dimension == self.range_dimension
    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
      raise ValueError(
          "User set is_square hint to False, but the operator was square.")
    if self._is_square_set_or_implied_by_hints is None:
      return auto_square_check

    return self._is_square_set_or_implied_by_hints

  @abc.abstractmethod
  def _shape(self):
    # Write this in derived class to enable all static shape methods.
    raise NotImplementedError("_shape is not implemented.")

  @property
  def shape(self):
    """`TensorShape` of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.shape`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    return self._shape()

  def _shape_tensor(self):
    # This is not an abstractmethod, since we want derived classes to be able
    # to override this with optional kwargs, which can reduce the number of
    # `convert_to_tensor` calls. See derived classes for examples.
    raise NotImplementedError("_shape_tensor is not implemented.")

  def shape_tensor(self, name="shape_tensor"):
    """Shape of this `LinearOperator`, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Prefer to use statically defined shape if available.
      if self.shape.is_fully_defined():
        return linear_operator_util.shape_tensor(self.shape.as_list())
      else:
        return self._shape_tensor()

  @property
  def batch_shape(self):
    """`TensorShape` of batch dimensions of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]`.
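
    For example, a hedged illustration (`LinearOperatorDiag` is defined
    elsewhere in `tf.linalg`; a diagonal operator built from a `[2, 3, 4]`
    tensor acts like a `[2, 3]` batch of `4 x 4` matrices):

    ```python
    operator = LinearOperatorDiag(tf.ones([2, 3, 4]))
    operator.shape
    ==> TensorShape([2, 3, 4, 4])

    operator.batch_shape
    ==> TensorShape([2, 3])
    ```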

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    # Derived classes get this "for free" once .shape is implemented.
    return self.shape[:-2]

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of batch dimensions of this operator, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb]`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._batch_shape_tensor()

  def _batch_shape_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.batch_shape.is_fully_defined():
      return linear_operator_util.shape_tensor(
          self.batch_shape.as_list(), name="batch_shape")
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[:-2]

  @property
  def tensor_rank(self):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Returns:
      Python integer, or None if the tensor rank is undefined.
    """
    # Derived classes get this "for free" once .shape is implemented.
    return self.shape.ndims

  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`, determined at runtime.
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._tensor_rank_tensor()

  def _tensor_rank_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.tensor_rank is not None:
      return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return array_ops.size(shape)

  @property
  def domain_dimension(self):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.rank is None:
      return tensor_shape.Dimension(None)
    else:
      return self.shape.dims[-1]

  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Args:
      name: A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._domain_dimension_tensor()

  def _domain_dimension_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-1]

  @property
  def range_dimension(self):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Returns:
      `Dimension` object.
484 """ 485 # Derived classes get this "for free" once .shape is implemented. 486 if self.shape.dims: 487 return self.shape.dims[-2] 488 else: 489 return tensor_shape.Dimension(None) 490 491 def range_dimension_tensor(self, name="range_dimension_tensor"): 492 """Dimension (in the sense of vector spaces) of the range of this operator. 493 494 Determined at runtime. 495 496 If this operator acts like the batch matrix `A` with 497 `A.shape = [B1,...,Bb, M, N]`, then this returns `M`. 498 499 Args: 500 name: A name for this `Op`. 501 502 Returns: 503 `int32` `Tensor` 504 """ 505 # Derived classes get this "for free" once .shape() is implemented. 506 with self._name_scope(name): # pylint: disable=not-callable 507 return self._range_dimension_tensor() 508 509 def _range_dimension_tensor(self, shape=None): 510 # `shape` may be passed in if this can be pre-computed in a 511 # more efficient manner, e.g. without excessive Tensor conversions. 512 dim_value = tensor_shape.dimension_value(self.range_dimension) 513 if dim_value is not None: 514 return ops.convert_to_tensor_v2_with_dispatch(dim_value) 515 else: 516 shape = self.shape_tensor() if shape is None else shape 517 return shape[-2] 518 519 def _assert_non_singular(self): 520 """Private default implementation of _assert_non_singular.""" 521 logging.warn( 522 "Using (possibly slow) default implementation of assert_non_singular." 523 " Requires conversion to a dense matrix and O(N^3) operations.") 524 if self._can_use_cholesky(): 525 return self.assert_positive_definite() 526 else: 527 singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False) 528 # TODO(langmore) Add .eig and .cond as methods. 529 cond = (math_ops.reduce_max(singular_values, axis=-1) / 530 math_ops.reduce_min(singular_values, axis=-1)) 531 return check_ops.assert_less( 532 cond, 533 self._max_condition_number_to_be_non_singular(), 534 message="Singular matrix up to precision epsilon.") 535 536 def _max_condition_number_to_be_non_singular(self): 537 """Return the maximum condition number that we consider nonsingular.""" 538 with ops.name_scope("max_nonsingular_condition_number"): 539 dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps 540 eps = math_ops.cast( 541 math_ops.reduce_max([ 542 100., 543 math_ops.cast(self.range_dimension_tensor(), self.dtype), 544 math_ops.cast(self.domain_dimension_tensor(), self.dtype) 545 ]), self.dtype) * dtype_eps 546 return 1. / eps 547 548 def assert_non_singular(self, name="assert_non_singular"): 549 """Returns an `Op` that asserts this operator is non singular. 550 551 This operator is considered non-singular if 552 553 ``` 554 ConditionNumber < max{100, range_dimension, domain_dimension} * eps, 555 eps := np.finfo(self.dtype.as_numpy_dtype).eps 556 ``` 557 558 Args: 559 name: A string name to prepend to created ops. 560 561 Returns: 562 An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if 563 the operator is singular. 564 """ 565 with self._name_scope(name): # pylint: disable=not-callable 566 return self._assert_non_singular() 567 568 def _assert_positive_definite(self): 569 """Default implementation of _assert_positive_definite.""" 570 logging.warn( 571 "Using (possibly slow) default implementation of " 572 "assert_positive_definite." 573 " Requires conversion to a dense matrix and O(N^3) operations.") 574 # If the operator is self-adjoint, then checking that 575 # Cholesky decomposition succeeds + results in positive diag is necessary 576 # and sufficient. 

    Args:
      name: A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
      the operator is singular.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_non_singular()

  def _assert_positive_definite(self):
    """Default implementation of _assert_positive_definite."""
    logging.warn(
        "Using (possibly slow) default implementation of "
        "assert_positive_definite."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    # If the operator is self-adjoint, then checking that
    # Cholesky decomposition succeeds + results in positive diag is necessary
    # and sufficient.
    if self.is_self_adjoint:
      return check_ops.assert_positive(
          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
          message="Matrix was not positive definite.")
    # We have no generic check for positive definite.
    raise NotImplementedError("assert_positive_definite is not implemented.")

  def assert_positive_definite(self, name="assert_positive_definite"):
    """Returns an `Op` that asserts this operator is positive definite.

    Here, positive definite means that the quadratic form `x^H A x` has
    positive real part for all nonzero `x`. Note that we do not require the
    operator to be self-adjoint to be positive definite.

    Args:
      name: A name to give this `Op`.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
      the operator is not positive definite.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_positive_definite()

  def _assert_self_adjoint(self):
    dense = self.to_dense()
    logging.warn(
        "Using (possibly slow) default implementation of assert_self_adjoint."
        " Requires conversion to a dense matrix.")
    return check_ops.assert_equal(
        dense,
        linalg.adjoint(dense),
        message="Matrix was not equal to its adjoint.")

  def assert_self_adjoint(self, name="assert_self_adjoint"):
    """Returns an `Op` that asserts this operator is self-adjoint.

    Here we check that this operator is *exactly* equal to its hermitian
    transpose.

    Args:
      name: A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
      the operator is not self-adjoint.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_self_adjoint()

  def _check_input_dtype(self, arg):
    """Check that arg.dtype == self.dtype."""
    if arg.dtype.base_dtype != self.dtype:
      raise TypeError(
          "Expected argument to have dtype %s. Found: %s in tensor %s" %
          (self.dtype, arg.dtype, arg))

  @abc.abstractmethod
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    raise NotImplementedError("_matmul is not implemented.")

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication: `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ...  # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
        `self`. See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name: A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`.
666 """ 667 if isinstance(x, LinearOperator): 668 left_operator = self.adjoint() if adjoint else self 669 right_operator = x.adjoint() if adjoint_arg else x 670 671 if (right_operator.range_dimension is not None and 672 left_operator.domain_dimension is not None and 673 right_operator.range_dimension != left_operator.domain_dimension): 674 raise ValueError( 675 "Operators are incompatible. Expected `x` to have dimension" 676 " {} but got {}.".format( 677 left_operator.domain_dimension, right_operator.range_dimension)) 678 with self._name_scope(name): # pylint: disable=not-callable 679 return linear_operator_algebra.matmul(left_operator, right_operator) 680 681 with self._name_scope(name): # pylint: disable=not-callable 682 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 683 self._check_input_dtype(x) 684 685 self_dim = -2 if adjoint else -1 686 arg_dim = -1 if adjoint_arg else -2 687 tensor_shape.dimension_at_index( 688 self.shape, self_dim).assert_is_compatible_with( 689 x.shape[arg_dim]) 690 691 return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) 692 693 def __matmul__(self, other): 694 return self.matmul(other) 695 696 def _matvec(self, x, adjoint=False): 697 x_mat = array_ops.expand_dims(x, axis=-1) 698 y_mat = self.matmul(x_mat, adjoint=adjoint) 699 return array_ops.squeeze(y_mat, axis=-1) 700 701 def matvec(self, x, adjoint=False, name="matvec"): 702 """Transform [batch] vector `x` with left multiplication: `x --> Ax`. 703 704 ```python 705 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 706 operator = LinearOperator(...) 707 708 X = ... # shape [..., N], batch vector 709 710 Y = operator.matvec(X) 711 Y.shape 712 ==> [..., M] 713 714 Y[..., :] = sum_j A[..., :, j] X[..., j] 715 ``` 716 717 Args: 718 x: `Tensor` with compatible shape and same `dtype` as `self`. 719 `x` is treated as a [batch] vector meaning for every set of leading 720 dimensions, the last dimension defines a vector. 721 See class docstring for definition of compatibility. 722 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 723 name: A name for this `Op`. 724 725 Returns: 726 A `Tensor` with shape `[..., M]` and same `dtype` as `self`. 727 """ 728 with self._name_scope(name): # pylint: disable=not-callable 729 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 730 self._check_input_dtype(x) 731 self_dim = -2 if adjoint else -1 732 tensor_shape.dimension_at_index( 733 self.shape, self_dim).assert_is_compatible_with(x.shape[-1]) 734 return self._matvec(x, adjoint=adjoint) 735 736 def _determinant(self): 737 logging.warn( 738 "Using (possibly slow) default implementation of determinant." 739 " Requires conversion to a dense matrix and O(N^3) operations.") 740 if self._can_use_cholesky(): 741 return math_ops.exp(self.log_abs_determinant()) 742 return linalg_ops.matrix_determinant(self.to_dense()) 743 744 def determinant(self, name="det"): 745 """Determinant for every batch member. 746 747 Args: 748 name: A name for this `Op`. 749 750 Returns: 751 `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. 752 753 Raises: 754 NotImplementedError: If `self.is_square` is `False`. 
755 """ 756 if self.is_square is False: 757 raise NotImplementedError( 758 "Determinant not implemented for an operator that is expected to " 759 "not be square.") 760 with self._name_scope(name): # pylint: disable=not-callable 761 return self._determinant() 762 763 def _log_abs_determinant(self): 764 logging.warn( 765 "Using (possibly slow) default implementation of determinant." 766 " Requires conversion to a dense matrix and O(N^3) operations.") 767 if self._can_use_cholesky(): 768 diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())) 769 return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1]) 770 _, log_abs_det = linalg.slogdet(self.to_dense()) 771 return log_abs_det 772 773 def log_abs_determinant(self, name="log_abs_det"): 774 """Log absolute value of determinant for every batch member. 775 776 Args: 777 name: A name for this `Op`. 778 779 Returns: 780 `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. 781 782 Raises: 783 NotImplementedError: If `self.is_square` is `False`. 784 """ 785 if self.is_square is False: 786 raise NotImplementedError( 787 "Determinant not implemented for an operator that is expected to " 788 "not be square.") 789 with self._name_scope(name): # pylint: disable=not-callable 790 return self._log_abs_determinant() 791 792 def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False): 793 """Solve by conversion to a dense matrix.""" 794 if self.is_square is False: # pylint: disable=g-bool-id-comparison 795 raise NotImplementedError( 796 "Solve is not yet implemented for non-square operators.") 797 rhs = linalg.adjoint(rhs) if adjoint_arg else rhs 798 if self._can_use_cholesky(): 799 return linalg_ops.cholesky_solve( 800 linalg_ops.cholesky(self.to_dense()), rhs) 801 return linear_operator_util.matrix_solve_with_broadcast( 802 self.to_dense(), rhs, adjoint=adjoint) 803 804 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 805 """Default implementation of _solve.""" 806 logging.warn( 807 "Using (possibly slow) default implementation of solve." 808 " Requires conversion to a dense matrix and O(N^3) operations.") 809 return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) 810 811 def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): 812 """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. 813 814 The returned `Tensor` will be close to an exact solution if `A` is well 815 conditioned. Otherwise closeness will vary. See class docstring for details. 816 817 Examples: 818 819 ```python 820 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 821 operator = LinearOperator(...) 822 operator.shape = [..., M, N] 823 824 # Solve R > 0 linear systems for every member of the batch. 825 RHS = ... # shape [..., M, R] 826 827 X = operator.solve(RHS) 828 # X[..., :, r] is the solution to the r'th linear system 829 # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r] 830 831 operator.matmul(X) 832 ==> RHS 833 ``` 834 835 Args: 836 rhs: `Tensor` with same `dtype` as this operator and compatible shape. 837 `rhs` is treated like a [batch] matrix meaning for every set of leading 838 dimensions, the last two dimensions defines a matrix. 839 See class docstring for definition of compatibility. 840 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 841 of this `LinearOperator`: `A^H X = rhs`. 842 adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H` 843 is the hermitian transpose (transposition and complex conjugation). 

    Args:
      name: A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError: If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._log_abs_determinant()

  def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve by conversion to a dense matrix."""
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise NotImplementedError(
          "Solve is not yet implemented for non-square operators.")
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._can_use_cholesky():
      return linalg_ops.cholesky_solve(
          linalg_ops.cholesky(self.to_dense()), rhs)
    return linear_operator_util.matrix_solve_with_broadcast(
        self.to_dense(), rhs, adjoint=adjoint)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Default implementation of _solve."""
    logging.warn(
        "Using (possibly slow) default implementation of solve."
        " Requires conversion to a dense matrix and O(N^3) operations.")
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for
    details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ...  # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
        `rhs` is treated like a [batch] matrix meaning for every set of leading
        dimensions, the last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension,
                right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)

      self_dim = -1 if adjoint else -2
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              rhs.shape[arg_dim])

      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _solvevec(self, rhs, adjoint=False):
    """Default implementation of _solvevec."""
    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
    return array_ops.squeeze(solution_mat, axis=-1)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for
    details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ...  # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator.
        `rhs` is treated like a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector. See class docstring
        for definition of compatibility regarding batch dimensions.
      adjoint: Python `bool`. If `True`, solve the system involving the adjoint
        of this `LinearOperator`: `A^H X = rhs`.
      name: A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError: If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      self_dim = -1 if adjoint else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(rhs.shape[-1])

      return self._solvevec(rhs, adjoint=adjoint)

  def adjoint(self, name="adjoint"):
    """Returns the adjoint of the current `LinearOperator`.

    Given `A` representing this `LinearOperator`, return `A*`.
    Note that calling `self.adjoint()` and `self.H` are equivalent.
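
    For example, a brief sketch (`LinearOperatorFullMatrix` is defined
    elsewhere in `tf.linalg`; for a real matrix the adjoint is the transpose):

    ```python
    operator = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
    operator.H.to_dense()
    ==> [[1., 3.],
         [2., 4.]]
    ```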
930 """ 931 with self._name_scope(name): # pylint: disable=not-callable 932 rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs") 933 self._check_input_dtype(rhs) 934 self_dim = -1 if adjoint else -2 935 tensor_shape.dimension_at_index( 936 self.shape, self_dim).assert_is_compatible_with(rhs.shape[-1]) 937 938 return self._solvevec(rhs, adjoint=adjoint) 939 940 def adjoint(self, name="adjoint"): 941 """Returns the adjoint of the current `LinearOperator`. 942 943 Given `A` representing this `LinearOperator`, return `A*`. 944 Note that calling `self.adjoint()` and `self.H` are equivalent. 945 946 Args: 947 name: A name for this `Op`. 948 949 Returns: 950 `LinearOperator` which represents the adjoint of this `LinearOperator`. 951 """ 952 if self.is_self_adjoint is True: # pylint: disable=g-bool-id-comparison 953 return self 954 with self._name_scope(name): # pylint: disable=not-callable 955 return linear_operator_algebra.adjoint(self) 956 957 # self.H is equivalent to self.adjoint(). 958 H = property(adjoint, None) 959 960 def inverse(self, name="inverse"): 961 """Returns the Inverse of this `LinearOperator`. 962 963 Given `A` representing this `LinearOperator`, return a `LinearOperator` 964 representing `A^-1`. 965 966 Args: 967 name: A name scope to use for ops added by this method. 968 969 Returns: 970 `LinearOperator` representing inverse of this matrix. 971 972 Raises: 973 ValueError: When the `LinearOperator` is not hinted to be `non_singular`. 974 """ 975 if self.is_square is False: # pylint: disable=g-bool-id-comparison 976 raise ValueError("Cannot take the Inverse: This operator represents " 977 "a non square matrix.") 978 if self.is_non_singular is False: # pylint: disable=g-bool-id-comparison 979 raise ValueError("Cannot take the Inverse: This operator represents " 980 "a singular matrix.") 981 982 with self._name_scope(name): # pylint: disable=not-callable 983 return linear_operator_algebra.inverse(self) 984 985 def cholesky(self, name="cholesky"): 986 """Returns a Cholesky factor as a `LinearOperator`. 987 988 Given `A` representing this `LinearOperator`, if `A` is positive definite 989 self-adjoint, return `L`, where `A = L L^T`, i.e. the cholesky 990 decomposition. 991 992 Args: 993 name: A name for this `Op`. 994 995 Returns: 996 `LinearOperator` which represents the lower triangular matrix 997 in the Cholesky decomposition. 998 999 Raises: 1000 ValueError: When the `LinearOperator` is not hinted to be positive 1001 definite and self adjoint. 1002 """ 1003 1004 if not self._can_use_cholesky(): 1005 raise ValueError("Cannot take the Cholesky decomposition: " 1006 "Not a positive definite self adjoint matrix.") 1007 with self._name_scope(name): # pylint: disable=not-callable 1008 return linear_operator_algebra.cholesky(self) 1009 1010 def _to_dense(self): 1011 """Generic and often inefficient implementation. 

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `LinearOperator` representing inverse of this matrix.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be `non_singular`.
    """
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the inverse: This operator represents "
                       "a non-square matrix.")
    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the inverse: This operator represents "
                       "a singular matrix.")

    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.inverse(self)

  def cholesky(self, name="cholesky"):
    """Returns a Cholesky factor as a `LinearOperator`.

    Given `A` representing this `LinearOperator`, if `A` is positive definite
    self-adjoint, return `L`, where `A = L L^T`, i.e. the Cholesky
    decomposition.

    Args:
      name: A name for this `Op`.

    Returns:
      `LinearOperator` which represents the lower triangular matrix
      in the Cholesky decomposition.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be positive
        definite and self-adjoint.
    """

    if not self._can_use_cholesky():
      raise ValueError("Cannot take the Cholesky decomposition: "
                       "Not a positive definite self-adjoint matrix.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.cholesky(self)

  def _to_dense(self):
    """Generic and often inefficient implementation. Override often."""
    if self.batch_shape.is_fully_defined():
      batch_shape = self.batch_shape
    else:
      batch_shape = self.batch_shape_tensor()

    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      n = dim_value
    else:
      n = self.domain_dimension_tensor()

    eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
    return self.matmul(eye)

  def to_dense(self, name="to_dense"):
    """Return a dense (batch) matrix representing this operator."""
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._to_dense()

  def _diag_part(self):
    """Generic and often inefficient implementation. Override often."""
    return array_ops.matrix_diag_part(self.to_dense())

  def diag_part(self, name="diag_part"):
    """Efficiently get the [batch] diagonal part of this operator.

    If this operator has shape `[B1,...,Bb, M, N]`, this returns a
    `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where
    `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.

    ```
    my_operator = LinearOperatorDiag([1., 2.])

    # Efficiently get the diagonal
    my_operator.diag_part()
    ==> [1., 2.]

    # Equivalent, but inefficient method
    tf.linalg.diag_part(my_operator.to_dense())
    ==> [1., 2.]
    ```

    Args:
      name: A name for this `Op`.

    Returns:
      diag_part: A `Tensor` of same `dtype` as self.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._diag_part()

  def _trace(self):
    return math_ops.reduce_sum(self.diag_part(), axis=-1)

  def trace(self, name="trace"):
    """Trace of the linear operator, equal to sum of `self.diag_part()`.

    If the operator is square, this is also the sum of the eigenvalues.
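
    For example, a brief sketch (`LinearOperatorDiag` is defined elsewhere in
    `tf.linalg`; its trace is the sum of its diagonal):

    ```python
    operator = LinearOperatorDiag([1., 2.])
    operator.trace()
    ==> 3.
    ```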
1115 """ 1116 if not self.is_self_adjoint: 1117 raise NotImplementedError("Only self-adjoint matrices are supported.") 1118 with self._name_scope(name): # pylint: disable=not-callable 1119 return self._eigvals() 1120 1121 def _cond(self): 1122 if not self.is_self_adjoint: 1123 # In general the condition number is the ratio of the 1124 # absolute value of the largest and smallest singular values. 1125 vals = linalg_ops.svd(self.to_dense(), compute_uv=False) 1126 else: 1127 # For self-adjoint matrices, and in general normal matrices, 1128 # we can use eigenvalues. 1129 vals = math_ops.abs(self._eigvals()) 1130 1131 return (math_ops.reduce_max(vals, axis=-1) / 1132 math_ops.reduce_min(vals, axis=-1)) 1133 1134 def cond(self, name="cond"): 1135 """Returns the condition number of this linear operator. 1136 1137 Args: 1138 name: A name for this `Op`. 1139 1140 Returns: 1141 Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`. 1142 """ 1143 with self._name_scope(name): # pylint: disable=not-callable 1144 return self._cond() 1145 1146 def _can_use_cholesky(self): 1147 return self.is_self_adjoint and self.is_positive_definite 1148 1149 def _set_graph_parents(self, graph_parents): 1150 """Set self._graph_parents. Called during derived class init. 1151 1152 This method allows derived classes to set graph_parents, without triggering 1153 a deprecation warning (which is invoked if `graph_parents` is passed during 1154 `__init__`. 1155 1156 Args: 1157 graph_parents: Iterable over Tensors. 1158 """ 1159 # TODO(b/143910018) Remove this function in V3. 1160 graph_parents = [] if graph_parents is None else graph_parents 1161 for i, t in enumerate(graph_parents): 1162 if t is None or not (linear_operator_util.is_ref(t) or 1163 tensor_util.is_tf_type(t)): 1164 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 1165 self._graph_parents = graph_parents 1166 1167 @property 1168 def _composite_tensor_fields(self): 1169 """A tuple of parameter names to rebuild the `LinearOperator`. 1170 1171 The tuple contains the names of kwargs to the `LinearOperator`'s constructor 1172 that the `TypeSpec` needs to rebuild the `LinearOperator` instance. 1173 1174 "is_non_singular", "is_self_adjoint", "is_positive_definite", and 1175 "is_square" are common to all `LinearOperator` subclasses and may be 1176 omitted. 1177 """ 1178 return () 1179 1180 @property 1181 def _composite_tensor_prefer_static_fields(self): 1182 """A tuple of names referring to parameters that may be treated statically. 1183 1184 This is a subset of `_composite_tensor_fields`, and contains the names of 1185 of `Tensor`-like args to the `LinearOperator`s constructor that may be 1186 stored as static values, if they are statically known. These are typically 1187 shapes or axis values. 1188 """ 1189 return () 1190 1191 @property 1192 def _type_spec(self): 1193 # This property will be overwritten by the `@make_composite_tensor` 1194 # decorator. However, we need it so that a valid subclass of the `ABCMeta` 1195 # class `CompositeTensor` can be constructed and passed to the 1196 # `@make_composite_tensor` decorator. 1197 pass 1198 1199 1200class _LinearOperatorSpec(type_spec.TypeSpec): 1201 """A tf.TypeSpec for `LinearOperator` objects.""" 1202 1203 __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields") 1204 1205 def __init__(self, param_specs, non_tensor_params, prefer_static_fields): 1206 """Initializes a new `_LinearOperatorSpec`. 

    Args:
      name: A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._cond()

  def _can_use_cholesky(self):
    return self.is_self_adjoint and self.is_positive_definite

  def _set_graph_parents(self, graph_parents):
    """Set self._graph_parents. Called during derived class init.

    This method allows derived classes to set graph_parents, without triggering
    a deprecation warning (which is invoked if `graph_parents` is passed during
    `__init__`).

    Args:
      graph_parents: Iterable over Tensors.
    """
    # TODO(b/143910018) Remove this function in V3.
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not (linear_operator_util.is_ref(t) or
                           tensor_util.is_tf_type(t)):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._graph_parents = graph_parents

  @property
  def _composite_tensor_fields(self):
    """A tuple of parameter names to rebuild the `LinearOperator`.

    The tuple contains the names of kwargs to the `LinearOperator`'s
    constructor that the `TypeSpec` needs to rebuild the `LinearOperator`
    instance.

    "is_non_singular", "is_self_adjoint", "is_positive_definite", and
    "is_square" are common to all `LinearOperator` subclasses and may be
    omitted.
    """
    return ()

  @property
  def _composite_tensor_prefer_static_fields(self):
    """A tuple of names referring to parameters that may be treated statically.

    This is a subset of `_composite_tensor_fields`, and contains the names of
    `Tensor`-like args to the `LinearOperator`s constructor that may be
    stored as static values, if they are statically known. These are typically
    shapes or axis values.
    """
    return ()

  @property
  def _type_spec(self):
    # This property will be overwritten by the `@make_composite_tensor`
    # decorator. However, we need it so that a valid subclass of the `ABCMeta`
    # class `CompositeTensor` can be constructed and passed to the
    # `@make_composite_tensor` decorator.
    pass


class _LinearOperatorSpec(type_spec.TypeSpec):
  """A tf.TypeSpec for `LinearOperator` objects."""

  __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields")

  def __init__(self, param_specs, non_tensor_params, prefer_static_fields):
    """Initializes a new `_LinearOperatorSpec`.

    Args:
      param_specs: Python `dict` of `tf.TypeSpec` instances that describe
        kwargs to the `LinearOperator`'s constructor that are `Tensor`-like or
        `CompositeTensor` subclasses.
      non_tensor_params: Python `dict` containing non-`Tensor` and non-
        `CompositeTensor` kwargs to the `LinearOperator`'s constructor.
      prefer_static_fields: Python `tuple` of strings corresponding to the
        names of `Tensor`-like args to the `LinearOperator`s constructor that
        may be stored as static values, if known. These are typically shapes,
        indices, or axis values.
    """
    self._param_specs = param_specs
    self._non_tensor_params = non_tensor_params
    self._prefer_static_fields = prefer_static_fields

  @classmethod
  def from_operator(cls, operator):
    """Builds a `_LinearOperatorSpec` from a `LinearOperator` instance.

    Args:
      operator: An instance of `LinearOperator`.

    Returns:
      linear_operator_spec: An instance of `_LinearOperatorSpec` to be used as
        the `TypeSpec` of `operator`.
    """
    validation_fields = ("is_non_singular", "is_self_adjoint",
                         "is_positive_definite", "is_square")
    kwargs = _extract_attrs(
        operator,
        keys=set(operator._composite_tensor_fields + validation_fields))  # pylint: disable=protected-access

    non_tensor_params = {}
    param_specs = {}
    for k, v in list(kwargs.items()):
      type_spec_or_v = _extract_type_spec_recursively(v)
      is_tensor = [isinstance(x, type_spec.TypeSpec)
                   for x in nest.flatten(type_spec_or_v)]
      if all(is_tensor):
        param_specs[k] = type_spec_or_v
      elif not any(is_tensor):
        non_tensor_params[k] = v
      else:
        raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and "
                                  f"non-`Tensor` values.")

    return cls(
        param_specs=param_specs,
        non_tensor_params=non_tensor_params,
        prefer_static_fields=operator._composite_tensor_prefer_static_fields)  # pylint: disable=protected-access

  def _to_components(self, obj):
    return _extract_attrs(obj, keys=list(self._param_specs))

  def _from_components(self, components):
    kwargs = dict(self._non_tensor_params, **components)
    return self.value_type(**kwargs)

  @property
  def _component_specs(self):
    return self._param_specs

  def _serialize(self):
    return (self._param_specs,
            self._non_tensor_params,
            self._prefer_static_fields)


def make_composite_tensor(cls, module_name="tf.linalg"):
  """Class decorator to convert `LinearOperator`s to `CompositeTensor`.
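
  For illustration, a hypothetical subclass would opt in as:

  ```python
  @make_composite_tensor
  class LinearOperatorMyDiag(LinearOperator):
    ...
  ```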
1296 """ 1297 1298 kwargs = {} 1299 not_found = object() 1300 for k in keys: 1301 srcs = [ 1302 getattr(op, k, not_found), getattr(op, "_" + k, not_found), 1303 getattr(op, "parameters", {}).get(k, not_found), 1304 ] 1305 if any(v is not not_found for v in srcs): 1306 kwargs[k] = [v for v in srcs if v is not not_found][0] 1307 else: 1308 raise ValueError( 1309 f"Could not determine an appropriate value for field `{k}` in object " 1310 f" `{op}`. Looked for \n" 1311 f" 1. an attr called `{k}`,\n" 1312 f" 2. an attr called `_{k}`,\n" 1313 f" 3. an entry in `op.parameters` with key '{k}'.") 1314 if k in op._composite_tensor_prefer_static_fields and kwargs[k] is not None: # pylint: disable=protected-access 1315 if tensor_util.is_tensor(kwargs[k]): 1316 static_val = tensor_util.constant_value(kwargs[k]) 1317 if static_val is not None: 1318 kwargs[k] = static_val 1319 if isinstance(kwargs[k], (np.ndarray, np.generic)): 1320 kwargs[k] = kwargs[k].tolist() 1321 return kwargs 1322 1323 1324def _extract_type_spec_recursively(value): 1325 """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s. 1326 1327 If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If 1328 `value` is a collection containing `Tensor` values, recursively supplant them 1329 with their respective `TypeSpec`s in a collection of parallel stucture. 1330 1331 If `value` is none of the above, return it unchanged. 1332 1333 Args: 1334 value: a Python `object` to (possibly) turn into a (collection of) 1335 `tf.TypeSpec`(s). 1336 1337 Returns: 1338 spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value` 1339 or `value`, if no `Tensor`s are found. 1340 """ 1341 if isinstance(value, composite_tensor.CompositeTensor): 1342 return value._type_spec # pylint: disable=protected-access 1343 if isinstance(value, variables.Variable): 1344 return resource_variable_ops.VariableSpec( 1345 value.shape, dtype=value.dtype, trainable=value.trainable) 1346 if tensor_util.is_tensor(value): 1347 return tensor_spec.TensorSpec(value.shape, value.dtype) 1348 # Unwrap trackable data structures to comply with `Type_Spec._serialize` 1349 # requirements. `ListWrapper`s are converted to `list`s, and for other 1350 # trackable data structures, the `__wrapped__` attribute is used. 1351 if isinstance(value, list): 1352 return list(_extract_type_spec_recursively(v) for v in value) 1353 if isinstance(value, data_structures.TrackableDataStructure): 1354 return _extract_type_spec_recursively(value.__wrapped__) 1355 if isinstance(value, tuple): 1356 return type(value)(_extract_type_spec_recursively(x) for x in value) 1357 if isinstance(value, dict): 1358 return type(value)((k, _extract_type_spec_recursively(v)) 1359 for k, v in value.items()) 1360 return value 1361 1362 1363# Overrides for tf.linalg functions. This allows a LinearOperator to be used in 1364# place of a Tensor. 1365# For instance tf.trace(linop) and linop.trace() both work. 1366 1367 1368@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator) 1369def _adjoint(matrix, name=None): 1370 return matrix.adjoint(name) 1371 1372 1373@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator) 1374def _cholesky(input, name=None): # pylint:disable=redefined-builtin 1375 return input.cholesky(name) 1376 1377 1378# The signature has to match with the one in python/op/array_ops.py, 1379# so we have k, padding_value, and align even though we don't use them here. 


@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator)
def _adjoint(matrix, name=None):
  return matrix.adjoint(name)


@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator)
def _cholesky(input, name=None):  # pylint:disable=redefined-builtin
  return input.cholesky(name)


# The signature has to match with the one in python/op/array_ops.py,
# so we have k, padding_value, and align even though we don't use them here.
# pylint:disable=unused-argument
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
  return input.diag_part(name)
# pylint:enable=unused-argument


@dispatch.dispatch_for_types(linalg.det, LinearOperator)
def _det(input, name=None):  # pylint:disable=redefined-builtin
  return input.determinant(name)


@dispatch.dispatch_for_types(linalg.inv, LinearOperator)
def _inverse(input, adjoint=False, name=None):  # pylint:disable=redefined-builtin
  inv = input.inverse(name)
  if adjoint:
    inv = inv.adjoint()
  return inv


@dispatch.dispatch_for_types(linalg.logdet, LinearOperator)
def _logdet(matrix, name=None):
  if matrix.is_positive_definite and matrix.is_self_adjoint:
    return matrix.log_abs_determinant(name)
  raise ValueError("Expected matrix to be self-adjoint positive definite.")


@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(  # pylint:disable=missing-docstring
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    output_type=None,  # pylint: disable=unused-argument
    name=None):
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")
  if not isinstance(a, LinearOperator):
    # We use the identity (B^H A^H)^H = AB.
    adjoint_matmul = b.matmul(
        a,
        adjoint=(not adjoint_b),
        adjoint_arg=(not adjoint_a),
        name=name)
    return linalg.adjoint(adjoint_matmul)
  return a.matmul(
      b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)


@dispatch.dispatch_for_types(linalg.solve, LinearOperator)
def _solve(
    matrix,
    rhs,
    adjoint=False,
    name=None):
  if not isinstance(matrix, LinearOperator):
    raise ValueError("Passing in `matrix` as a Tensor and `rhs` as a "
                     "LinearOperator is not supported.")
  return matrix.solve(rhs, adjoint=adjoint, name=name)


@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
  return x.trace(name)