# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Base class for linear operators."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import contextlib
23
24import numpy as np
25import six
26
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.framework import type_spec
34from tensorflow.python.module import module
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import check_ops
37from tensorflow.python.ops import linalg_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import resource_variable_ops
40from tensorflow.python.ops import variables
41from tensorflow.python.ops.linalg import linalg_impl as linalg
42from tensorflow.python.ops.linalg import linear_operator_algebra
43from tensorflow.python.ops.linalg import linear_operator_util
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.training.tracking import data_structures
46from tensorflow.python.util import deprecation
47from tensorflow.python.util import dispatch
48from tensorflow.python.util import nest
49from tensorflow.python.util.tf_export import tf_export
50
51__all__ = ["LinearOperator"]
52
53
54# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
55@tf_export("linalg.LinearOperator")
56@six.add_metaclass(abc.ABCMeta)
57class LinearOperator(module.Module, composite_tensor.CompositeTensor):
58  """Base class defining a [batch of] linear operator[s].
59
60  Subclasses of `LinearOperator` provide access to common methods on a
61  (batch) matrix, without the need to materialize the matrix.  This allows:
62
63  * Matrix free computations
64  * Operators that take advantage of special structure, while providing a
65    consistent API to users.
66
67  #### Subclassing
68
69  To enable a public method, subclasses should implement the leading-underscore
70  version of the method.  The argument signature should be identical except for
71  the omission of `name="..."`.  For example, to enable
72  `matmul(x, adjoint=False, name="matmul")` a subclass should implement
73  `_matmul(x, adjoint=False)`.
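
  A minimal sketch of such a subclass, wrapping a dense [batch] matrix
  `Tensor`, might look as follows (`LinearOperatorDense` and its `matrix`
  argument are illustrative, not a library class):

  ```python
  class LinearOperatorDense(LinearOperator):
    """Sketch: `LinearOperator` backed by a dense `Tensor`."""

    def __init__(self, matrix, name="LinearOperatorDense"):
      parameters = dict(matrix=matrix, name=name)
      self._matrix = matrix
      super().__init__(dtype=matrix.dtype, parameters=parameters, name=name)

    def _shape(self):
      return self._matrix.shape

    def _shape_tensor(self):
      return tf.shape(self._matrix)

    def _matmul(self, x, adjoint=False, adjoint_arg=False):
      return tf.matmul(
          self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
  ```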

  #### Performance contract

  Subclasses should only implement the assert methods
  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
  time.

  Class docstrings should contain an explanation of computational complexity.
  Since this is a high-performance library, attention should be paid to detail,
  and explanations can include constants as well as Big-O notation.

  #### Shape compatibility

  `LinearOperator` subclasses should operate on a [batch] matrix with
  compatible shape.  Class docstrings should define what is meant by compatible
  shape.  Some subclasses may not support batching.

  Examples:

  `x` is a batch matrix with compatible shape for `matmul` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  x.shape = [B1,...,Bb] + [N, R]
  ```

  `rhs` is a batch matrix with compatible shape for `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  rhs.shape = [B1,...,Bb] + [M, R]
  ```
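
  For example, `operator.shape = [2, 10, 10]` describes a batch of two
  `10 x 10` matrices; a compatible `x` for `matmul` (and, since `M = N` here,
  a compatible `rhs` for `solve`) has shape `[2, 10, R]` for any `R >= 1`.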

  #### Example docstring for subclasses.

  This operator acts like a (batch) matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.  Again, this matrix `A` may not be materialized, but for
  purposes of identifying and working with compatible arguments the shape is
  relevant.

  Examples:

  ```python
  some_tensor = ... shape = ????
  operator = MyLinOp(some_tensor)

  operator.shape
  ==> [2, 4, 4]

  operator.log_abs_determinant()
  ==> Shape [2] Tensor

  x = ... Shape [2, 4, 5] Tensor

  operator.matmul(x)
  ==> Shape [2, 4, 5] Tensor
  ```

  #### Shape compatibility

  This operator acts on batch matrices with compatible shape.
  FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE

  #### Performance

  FILL THIS IN

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning, as illustrated by the sketch after this
  list:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
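
  For example (using the hypothetical `MyLinOp` from above, and assuming it
  forwards the flags to `LinearOperator.__init__`), one hint can imply
  others, per the checks in `__init__`:

  ```python
  operator = MyLinOp(some_tensor, is_positive_definite=True)

  operator.is_non_singular
  ==> True  # Positive definite implies non-singular.

  operator.is_square
  ==> True  # Non-singular implies square.
  ```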

  #### Initialization parameters

  All subclasses of `LinearOperator` are expected to pass a `parameters`
  argument to `super().__init__()`.  This should be a `dict` containing
  the unadulterated arguments passed to the subclass `__init__`.  For example,
  `MyLinearOperator` with an initializer should look like:

  ```python
  def __init__(self, operator, is_square=False, name=None):
    parameters = dict(
        operator=operator,
        is_square=is_square,
        name=name
    )
    ...
    super().__init__(..., parameters=parameters)
  ```

  Users can then access `my_linear_operator.parameters` to see all arguments
  passed to its initializer.
  """

  # TODO(b/143910018) Remove graph_parents in V3.
  @deprecation.deprecated_args(None, "Do not pass `graph_parents`.  They will "
                               "no longer be used.", "graph_parents")
  def __init__(self,
               dtype,
               graph_parents=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None,
               parameters=None):
    r"""Initialize the `LinearOperator`.

    **This is a private method for subclass use.**
    **Subclasses should copy-paste this `__init__` documentation.**

    Args:
      dtype: The type of this `LinearOperator`.  Arguments to `matmul` and
        `solve` will have to be this type.
      graph_parents: (Deprecated) Python list of graph prerequisites of this
        `LinearOperator`.  Typically tensors that are passed during
        initialization.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `dtype` is real, this is equivalent to being symmetric.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.
      parameters: Python `dict` of parameters used to instantiate this
        `LinearOperator`.

    Raises:
      ValueError:  If any member of graph_parents is `None` or not a `Tensor`.
      ValueError:  If hints are set incorrectly.
    """
    # Check and auto-set flags.
    if is_positive_definite:
      if is_non_singular is False:
        raise ValueError("A positive definite matrix is always non-singular.")
      is_non_singular = True

    if is_non_singular:
      if is_square is False:
        raise ValueError("A non-singular matrix is always square.")
      is_square = True

    if is_self_adjoint:
      if is_square is False:
        raise ValueError("A self-adjoint matrix is always square.")
      is_square = True

    self._is_square_set_or_implied_by_hints = is_square

    if graph_parents is not None:
      self._set_graph_parents(graph_parents)
    else:
      self._graph_parents = []
    self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
    self._is_non_singular = is_non_singular
    self._is_self_adjoint = is_self_adjoint
    self._is_positive_definite = is_positive_definite
    self._parameters = self._no_dependency(parameters)
    self._parameters_sanitized = False
    self._name = name or type(self).__name__

  @contextlib.contextmanager
  def _name_scope(self, name=None):  # pylint: disable=method-hidden
    """Helper function to standardize op scope."""
    full_name = self.name
    if name is not None:
      full_name += "/" + name
    with ops.name_scope(full_name) as scope:
      yield scope

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `LinearOperator`."""
    return dict(self._parameters)

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `LinearOperator`."""
    return self._dtype

  @property
  def name(self):
    """Name prepended to all ops created by this `LinearOperator`."""
    return self._name

  @property
  @deprecation.deprecated(None, "Do not call `graph_parents`.")
  def graph_parents(self):
    """List of graph dependencies of this `LinearOperator`."""
    return self._graph_parents

  @property
  def is_non_singular(self):
    return self._is_non_singular

  @property
  def is_self_adjoint(self):
    return self._is_self_adjoint

  @property
  def is_positive_definite(self):
    return self._is_positive_definite

  @property
  def is_square(self):
    """Return `True/False` depending on whether this operator is square."""
    # Static checks done after __init__.  Why?  Because domain/range dimension
    # sometimes requires lots of work done in the derived class after init.
    auto_square_check = self.domain_dimension == self.range_dimension
    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
      raise ValueError(
          "User set is_square hint to False, but the operator was square.")
    if self._is_square_set_or_implied_by_hints is None:
      return auto_square_check

    return self._is_square_set_or_implied_by_hints

  @abc.abstractmethod
  def _shape(self):
    # Write this in derived class to enable all static shape methods.
    raise NotImplementedError("_shape is not implemented.")

  @property
  def shape(self):
    """`TensorShape` of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.shape`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    return self._shape()

  def _shape_tensor(self):
    # This is not an abstractmethod, since we want derived classes to be able to
    # override this with optional kwargs, which can reduce the number of
    # `convert_to_tensor` calls.  See derived classes for examples.
    raise NotImplementedError("_shape_tensor is not implemented.")

  def shape_tensor(self, name="shape_tensor"):
    """Shape of this `LinearOperator`, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Prefer to use statically defined shape if available.
      if self.shape.is_fully_defined():
        return linear_operator_util.shape_tensor(self.shape.as_list())
      else:
        return self._shape_tensor()

  @property
  def batch_shape(self):
    """`TensorShape` of batch dimensions of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    # Derived classes get this "for free" once .shape is implemented.
    return self.shape[:-2]

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of batch dimensions of this operator, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb]`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._batch_shape_tensor()

  def _batch_shape_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.batch_shape.is_fully_defined():
      return linear_operator_util.shape_tensor(
          self.batch_shape.as_list(), name="batch_shape")
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[:-2]

  @property
  def tensor_rank(self):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Returns:
      Python integer, or None if the tensor rank is undefined.
    """
    # Derived classes get this "for free" once .shape is implemented.
    return self.shape.ndims

  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`, determined at runtime.
    """
    # Derived classes get this "for free" once .shape is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._tensor_rank_tensor()

  def _tensor_rank_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.tensor_rank is not None:
      return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return array_ops.size(shape)

  @property
  def domain_dimension(self):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.rank is None:
      return tensor_shape.Dimension(None)
    else:
      return self.shape.dims[-1]

  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._domain_dimension_tensor()

  def _domain_dimension_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-1]

  @property
  def range_dimension(self):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.dims:
      return self.shape.dims[-2]
    else:
      return tensor_shape.Dimension(None)

  def range_dimension_tensor(self, name="range_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._range_dimension_tensor()

  def _range_dimension_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.range_dimension)
    if dim_value is not None:
      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-2]

  def _assert_non_singular(self):
    """Private default implementation of _assert_non_singular."""
    logging.warn(
        "Using (possibly slow) default implementation of assert_non_singular."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return self.assert_positive_definite()
    else:
      singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
      # TODO(langmore) Add .eig and .cond as methods.
      cond = (math_ops.reduce_max(singular_values, axis=-1) /
              math_ops.reduce_min(singular_values, axis=-1))
      return check_ops.assert_less(
          cond,
          self._max_condition_number_to_be_non_singular(),
          message="Singular matrix up to precision epsilon.")

  def _max_condition_number_to_be_non_singular(self):
    """Return the maximum condition number that we consider nonsingular."""
    with ops.name_scope("max_nonsingular_condition_number"):
      dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
      eps = math_ops.cast(
          math_ops.reduce_max([
              100.,
              math_ops.cast(self.range_dimension_tensor(), self.dtype),
              math_ops.cast(self.domain_dimension_tensor(), self.dtype)
          ]), self.dtype) * dtype_eps
      return 1. / eps

  def assert_non_singular(self, name="assert_non_singular"):
    """Returns an `Op` that asserts this operator is non-singular.

    This operator is considered non-singular if (matching
    `_max_condition_number_to_be_non_singular` above)

    ```
    ConditionNumber < 1 / (max{100, range_dimension, domain_dimension} * eps),
    eps := np.finfo(self.dtype.as_numpy_dtype).eps
    ```

    Args:
      name:  A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is singular.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_non_singular()

  def _assert_positive_definite(self):
    """Default implementation of _assert_positive_definite."""
    logging.warn(
        "Using (possibly slow) default implementation of "
        "assert_positive_definite."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    # If the operator is self-adjoint, then checking that
    # Cholesky decomposition succeeds + results in positive diag is necessary
    # and sufficient.
    if self.is_self_adjoint:
      return check_ops.assert_positive(
          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
          message="Matrix was not positive definite.")
    # We have no generic check for positive definite.
    raise NotImplementedError("assert_positive_definite is not implemented.")

  def assert_positive_definite(self, name="assert_positive_definite"):
    """Returns an `Op` that asserts this operator is positive definite.

    Here, positive definite means that the quadratic form `x^H A x` has positive
    real part for all nonzero `x`.  Note that we do not require the operator to
    be self-adjoint to be positive definite.

    Args:
      name:  A name to give this `Op`.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is not positive definite.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_positive_definite()

  def _assert_self_adjoint(self):
    dense = self.to_dense()
    logging.warn(
        "Using (possibly slow) default implementation of assert_self_adjoint."
        "  Requires conversion to a dense matrix.")
    return check_ops.assert_equal(
        dense,
        linalg.adjoint(dense),
        message="Matrix was not equal to its adjoint.")

  def assert_self_adjoint(self, name="assert_self_adjoint"):
    """Returns an `Op` that asserts this operator is self-adjoint.

    Here we check that this operator is *exactly* equal to its hermitian
    transpose.

    Args:
      name:  A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is not self-adjoint.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_self_adjoint()

  def _check_input_dtype(self, arg):
    """Check that arg.dtype == self.dtype."""
    if arg.dtype.base_dtype != self.dtype:
      raise TypeError(
          "Expected argument to have dtype %s.  Found: %s in tensor %s" %
          (self.dtype, arg.dtype, arg))

  @abc.abstractmethod
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    raise NotImplementedError("_matmul is not implemented.")

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
        `self`. See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`.
    """
    if isinstance(x, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `x` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)

      self_dim = -2 if adjoint else -1
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              x.shape[arg_dim])

      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def __matmul__(self, other):
    return self.matmul(other)

  def _matvec(self, x, adjoint=False):
    x_mat = array_ops.expand_dims(x, axis=-1)
    y_mat = self.matmul(x_mat, adjoint=adjoint)
    return array_ops.squeeze(y_mat, axis=-1)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`.
        `x` is treated as a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      self_dim = -2 if adjoint else -1
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(x.shape[-1])
      return self._matvec(x, adjoint=adjoint)

  def _determinant(self):
    logging.warn(
        "Using (possibly slow) default implementation of determinant."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return math_ops.exp(self.log_abs_determinant())
    return linalg_ops.matrix_determinant(self.to_dense())

  def determinant(self, name="det"):
    """Determinant for every batch member.

    Args:
      name:  A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError:  If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._determinant()

  def _log_abs_determinant(self):
    logging.warn(
        "Using (possibly slow) default implementation of log_abs_determinant."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
      return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
    _, log_abs_det = linalg.slogdet(self.to_dense())
    return log_abs_det

  def log_abs_determinant(self, name="log_abs_det"):
    """Log absolute value of determinant for every batch member.

    Args:
      name:  A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError:  If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._log_abs_determinant()

  def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve by conversion to a dense matrix."""
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise NotImplementedError(
          "Solve is not yet implemented for non-square operators.")
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._can_use_cholesky():
      return linalg_ops.cholesky_solve(
          linalg_ops.cholesky(self.to_dense()), rhs)
    return linear_operator_util.matrix_solve_with_broadcast(
        self.to_dense(), rhs, adjoint=adjoint)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Default implementation of _solve."""
    logging.warn(
        "Using (possibly slow) default implementation of solve."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
        `rhs` is treated like a [batch] matrix meaning for every set of leading
        dimensions, the last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)

      self_dim = -1 if adjoint else -2
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              rhs.shape[arg_dim])

      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _solvevec(self, rhs, adjoint=False):
    """Default implementation of _solvevec."""
    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
    return array_ops.squeeze(solution_mat, axis=-1)

  def solvevec(self, rhs, adjoint=False, name="solvevec"):
    """Solve single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator.
        `rhs` is treated like a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.  See class docstring
        for definition of compatibility regarding batch dimensions.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      self_dim = -1 if adjoint else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(rhs.shape[-1])

      return self._solvevec(rhs, adjoint=adjoint)

  def adjoint(self, name="adjoint"):
    """Returns the adjoint of the current `LinearOperator`.

    Given `A` representing this `LinearOperator`, return `A*`.
    Note that `self.adjoint()` and the property `self.H` are equivalent.
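
    For example (a sketch, using the library's `LinearOperatorFullMatrix`):

    ```python
    operator = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
    operator.adjoint().to_dense()
    ==> [[1., 3.], [2., 4.]]
    ```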

    Args:
      name:  A name for this `Op`.

    Returns:
      `LinearOperator` which represents the adjoint of this `LinearOperator`.
    """
    if self.is_self_adjoint is True:  # pylint: disable=g-bool-id-comparison
      return self
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.adjoint(self)

  # self.H is equivalent to self.adjoint().
  H = property(adjoint, None)

  def inverse(self, name="inverse"):
    """Returns the inverse of this `LinearOperator`.

    Given `A` representing this `LinearOperator`, return a `LinearOperator`
    representing `A^-1`.
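
    For example (a sketch; `matmul` of two `LinearOperator`s returns a
    `LinearOperator`):

    ```python
    operator = LinearOperatorFullMatrix([[2., 0.], [0., 4.]],
                                        is_non_singular=True)
    inv = operator.inverse()

    inv.matmul(operator).to_dense()
    ==> [[1., 0.], [0., 1.]]
    ```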

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `LinearOperator` representing the inverse of this matrix.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be `non_singular`.
    """
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the inverse: This operator represents "
                       "a non-square matrix.")
    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the inverse: This operator represents "
                       "a singular matrix.")

    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.inverse(self)

  def cholesky(self, name="cholesky"):
    """Returns a Cholesky factor as a `LinearOperator`.

    Given `A` representing this `LinearOperator`, if `A` is positive definite
    self-adjoint, return `L`, where `A = L L^H`, i.e. the Cholesky
    decomposition.
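
    For example (a sketch, using a self-adjoint positive definite matrix):

    ```python
    operator = LinearOperatorFullMatrix([[4., 2.], [2., 3.]],
                                        is_self_adjoint=True,
                                        is_positive_definite=True)
    chol = operator.cholesky()  # `LinearOperator` representing `L`.

    chol.matmul(chol.adjoint()).to_dense()
    ==> [[4., 2.], [2., 3.]]
    ```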

    Args:
      name:  A name for this `Op`.

    Returns:
      `LinearOperator` which represents the lower triangular matrix
      in the Cholesky decomposition.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be positive
        definite and self-adjoint.
    """
    if not self._can_use_cholesky():
      raise ValueError("Cannot take the Cholesky decomposition: "
                       "Not a positive definite self-adjoint matrix.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.cholesky(self)

  def _to_dense(self):
    """Generic and often inefficient implementation.  Override often."""
    if self.batch_shape.is_fully_defined():
      batch_shape = self.batch_shape
    else:
      batch_shape = self.batch_shape_tensor()

    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      n = dim_value
    else:
      n = self.domain_dimension_tensor()

    eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
    return self.matmul(eye)

  def to_dense(self, name="to_dense"):
    """Return a dense (batch) matrix representing this operator."""
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._to_dense()

  def _diag_part(self):
    """Generic and often inefficient implementation.  Override often."""
    return array_ops.matrix_diag_part(self.to_dense())

  def diag_part(self, name="diag_part"):
    """Efficiently get the [batch] diagonal part of this operator.

    If this operator has shape `[B1,...,Bb, M, N]`, this returns a
    `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where
    `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.

    ```
    my_operator = LinearOperatorDiag([1., 2.])

    # Efficiently get the diagonal
    my_operator.diag_part()
    ==> [1., 2.]

    # Equivalent, but inefficient method
    tf.linalg.diag_part(my_operator.to_dense())
    ==> [1., 2.]
    ```

    Args:
      name:  A name for this `Op`.

    Returns:
      diag_part:  A `Tensor` of same `dtype` as self.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._diag_part()

  def _trace(self):
    return math_ops.reduce_sum(self.diag_part(), axis=-1)

  def trace(self, name="trace"):
    """Trace of the linear operator, equal to sum of `self.diag_part()`.

    If the operator is square, this is also the sum of the eigenvalues.
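
    For example (a sketch, using the library's `LinearOperatorDiag`):

    ```python
    operator = LinearOperatorDiag([1., 2.])
    operator.trace()
    ==> 3.
    ```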

    Args:
      name:  A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._trace()

  def _add_to_tensor(self, x):
    # Override if a more efficient implementation is available.
    return self.to_dense() + x

  def add_to_tensor(self, x, name="add_to_tensor"):
    """Add matrix represented by this operator to `x`.  Equivalent to `A + x`.

    Args:
      x:  `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      return self._add_to_tensor(x)

  def _eigvals(self):
    return linalg_ops.self_adjoint_eigvals(self.to_dense())

  def eigvals(self, name="eigvals"):
    """Returns the eigenvalues of this linear operator.

    If the operator is marked as self-adjoint (via `is_self_adjoint`)
    this computation can be more efficient.

    Note: This currently only supports self-adjoint operators.
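
    For example (a sketch; `LinearOperatorDiag` is self-adjoint for real
    entries):

    ```python
    operator = LinearOperatorDiag([1., 2.])
    operator.eigvals()
    ==> [1., 2.]
    ```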

    Args:
      name:  A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb, N]` `Tensor` of same `dtype` as `self`.
    """
    if not self.is_self_adjoint:
      raise NotImplementedError("Only self-adjoint matrices are supported.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._eigvals()

  def _cond(self):
    if not self.is_self_adjoint:
      # In general the condition number is the ratio of the
      # absolute value of the largest and smallest singular values.
      vals = linalg_ops.svd(self.to_dense(), compute_uv=False)
    else:
      # For self-adjoint matrices, and in general normal matrices,
      # we can use eigenvalues.
      vals = math_ops.abs(self._eigvals())

    return (math_ops.reduce_max(vals, axis=-1) /
            math_ops.reduce_min(vals, axis=-1))

  def cond(self, name="cond"):
    """Returns the condition number of this linear operator.

    Args:
      name:  A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._cond()

  def _can_use_cholesky(self):
    return self.is_self_adjoint and self.is_positive_definite

  def _set_graph_parents(self, graph_parents):
    """Set self._graph_parents.  Called during derived class init.

    This method allows derived classes to set graph_parents, without triggering
    a deprecation warning (which is invoked if `graph_parents` is passed during
    `__init__`).

    Args:
      graph_parents: Iterable over Tensors.
    """
    # TODO(b/143910018) Remove this function in V3.
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not (linear_operator_util.is_ref(t) or
                           tensor_util.is_tf_type(t)):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._graph_parents = graph_parents

  @property
  def _composite_tensor_fields(self):
    """A tuple of parameter names to rebuild the `LinearOperator`.

    The tuple contains the names of kwargs to the `LinearOperator`'s
    constructor that the `TypeSpec` needs to rebuild the `LinearOperator`
    instance.

    "is_non_singular", "is_self_adjoint", "is_positive_definite", and
    "is_square" are common to all `LinearOperator` subclasses and may be
    omitted.
    """
    return ()

  @property
  def _composite_tensor_prefer_static_fields(self):
    """A tuple of names referring to parameters that may be treated statically.

    This is a subset of `_composite_tensor_fields`, and contains the names of
    `Tensor`-like args to the `LinearOperator`s constructor that may be
    stored as static values, if they are statically known. These are typically
    shapes or axis values.
    """
    return ()

  @property
  def _type_spec(self):
    # This property will be overwritten by the `@make_composite_tensor`
    # decorator. However, we need it so that a valid subclass of the `ABCMeta`
    # class `CompositeTensor` can be constructed and passed to the
    # `@make_composite_tensor` decorator.
    pass


class _LinearOperatorSpec(type_spec.TypeSpec):
  """A tf.TypeSpec for `LinearOperator` objects."""

  __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields")

  def __init__(self, param_specs, non_tensor_params, prefer_static_fields):
    """Initializes a new `_LinearOperatorSpec`.

    Args:
      param_specs: Python `dict` of `tf.TypeSpec` instances that describe
        kwargs to the `LinearOperator`'s constructor that are `Tensor`-like or
        `CompositeTensor` subclasses.
      non_tensor_params: Python `dict` containing non-`Tensor` and non-
        `CompositeTensor` kwargs to the `LinearOperator`'s constructor.
      prefer_static_fields: Python `tuple` of strings corresponding to the names
        of `Tensor`-like args to the `LinearOperator`s constructor that may be
        stored as static values, if known. These are typically shapes, indices,
        or axis values.
    """
    self._param_specs = param_specs
    self._non_tensor_params = non_tensor_params
    self._prefer_static_fields = prefer_static_fields

  @classmethod
  def from_operator(cls, operator):
    """Builds a `_LinearOperatorSpec` from a `LinearOperator` instance.

    Args:
      operator: An instance of `LinearOperator`.

    Returns:
      linear_operator_spec: An instance of `_LinearOperatorSpec` to be used as
        the `TypeSpec` of `operator`.
    """
    validation_fields = ("is_non_singular", "is_self_adjoint",
                         "is_positive_definite", "is_square")
    kwargs = _extract_attrs(
        operator,
        keys=set(operator._composite_tensor_fields + validation_fields))  # pylint: disable=protected-access

    non_tensor_params = {}
    param_specs = {}
    for k, v in list(kwargs.items()):
      type_spec_or_v = _extract_type_spec_recursively(v)
      is_tensor = [isinstance(x, type_spec.TypeSpec)
                   for x in nest.flatten(type_spec_or_v)]
      if all(is_tensor):
        param_specs[k] = type_spec_or_v
      elif not any(is_tensor):
        non_tensor_params[k] = v
      else:
        raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and "
                                  f"non-`Tensor` values.")

    return cls(
        param_specs=param_specs,
        non_tensor_params=non_tensor_params,
        prefer_static_fields=operator._composite_tensor_prefer_static_fields)  # pylint: disable=protected-access

  def _to_components(self, obj):
    return _extract_attrs(obj, keys=list(self._param_specs))

  def _from_components(self, components):
    kwargs = dict(self._non_tensor_params, **components)
    return self.value_type(**kwargs)

  @property
  def _component_specs(self):
    return self._param_specs

  def _serialize(self):
    return (self._param_specs,
            self._non_tensor_params,
            self._prefer_static_fields)


def make_composite_tensor(cls, module_name="tf.linalg"):
  """Class decorator to convert `LinearOperator`s to `CompositeTensor`."""

  spec_name = "{}Spec".format(cls.__name__)
  spec_type = type(spec_name, (_LinearOperatorSpec,), {"value_type": cls})
  type_spec.register("{}.{}".format(module_name, spec_name))(spec_type)
  cls._type_spec = property(spec_type.from_operator)  # pylint: disable=protected-access
  return cls
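

# A sketch of intended usage (illustrative; `MyOperator` is hypothetical):
# concrete operator classes are decorated so that each gets a registered
# `TypeSpec`, e.g.
#
#   @make_composite_tensor
#   class MyOperator(LinearOperator):
#     ...
#
# after which `MyOperator` behaves as a `CompositeTensor`, so instances can be
# flattened and rebuilt (e.g. across `tf.function` boundaries) via the
# generated `MyOperatorSpec`.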


def _extract_attrs(op, keys):
  """Extract constructor kwargs to reconstruct `op`.

  Args:
    op: A `LinearOperator` instance.
    keys: A Python `tuple` of strings indicating the names of the constructor
      kwargs to extract from `op`.

  Returns:
    kwargs: A Python `dict` of kwargs to `op`'s constructor, keyed by `keys`.
  """

  kwargs = {}
  not_found = object()
  for k in keys:
    srcs = [
        getattr(op, k, not_found), getattr(op, "_" + k, not_found),
        getattr(op, "parameters", {}).get(k, not_found),
    ]
    if any(v is not not_found for v in srcs):
      kwargs[k] = [v for v in srcs if v is not not_found][0]
    else:
      raise ValueError(
          f"Could not determine an appropriate value for field `{k}` in object "
          f"`{op}`. Looked for\n"
          f" 1. an attr called `{k}`,\n"
          f" 2. an attr called `_{k}`,\n"
          f" 3. an entry in `op.parameters` with key '{k}'.")
    if k in op._composite_tensor_prefer_static_fields and kwargs[k] is not None:  # pylint: disable=protected-access
      if tensor_util.is_tensor(kwargs[k]):
        static_val = tensor_util.constant_value(kwargs[k])
        if static_val is not None:
          kwargs[k] = static_val
    if isinstance(kwargs[k], (np.ndarray, np.generic)):
      kwargs[k] = kwargs[k].tolist()
  return kwargs


def _extract_type_spec_recursively(value):
  """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s.

  If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
  `value` is a collection containing `Tensor` values, recursively supplant them
  with their respective `TypeSpec`s in a collection of parallel structure.

  If `value` is none of the above, return it unchanged.

  Args:
    value: a Python `object` to (possibly) turn into a (collection of)
      `tf.TypeSpec`(s).

  Returns:
    spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
      or `value`, if no `Tensor`s are found.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  if isinstance(value, variables.Variable):
    return resource_variable_ops.VariableSpec(
        value.shape, dtype=value.dtype, trainable=value.trainable)
  if tensor_util.is_tensor(value):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  # Unwrap trackable data structures to comply with `TypeSpec._serialize`
  # requirements. `ListWrapper`s are converted to `list`s, and for other
  # trackable data structures, the `__wrapped__` attribute is used.
  if isinstance(value, list):
    return list(_extract_type_spec_recursively(v) for v in value)
  if isinstance(value, data_structures.TrackableDataStructure):
    return _extract_type_spec_recursively(value.__wrapped__)
  if isinstance(value, tuple):
    return type(value)(_extract_type_spec_recursively(x) for x in value)
  if isinstance(value, dict):
    return type(value)((k, _extract_type_spec_recursively(v))
                       for k, v in value.items())
  return value


# Overrides for tf.linalg functions. This allows a LinearOperator to be used in
# place of a Tensor.
# For instance tf.linalg.trace(linop) and linop.trace() both work.


@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator)
def _adjoint(matrix, name=None):
  return matrix.adjoint(name)


@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator)
def _cholesky(input, name=None):   # pylint:disable=redefined-builtin
  return input.cholesky(name)


# The signature has to match with the one in python/op/array_ops.py,
# so we have k, padding_value, and align even though we don't use them here.
# pylint:disable=unused-argument
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
  return input.diag_part(name)
# pylint:enable=unused-argument


@dispatch.dispatch_for_types(linalg.det, LinearOperator)
def _det(input, name=None):  # pylint:disable=redefined-builtin
  return input.determinant(name)


@dispatch.dispatch_for_types(linalg.inv, LinearOperator)
def _inverse(input, adjoint=False, name=None):   # pylint:disable=redefined-builtin
  inv = input.inverse(name)
  if adjoint:
    inv = inv.adjoint()
  return inv


@dispatch.dispatch_for_types(linalg.logdet, LinearOperator)
def _logdet(matrix, name=None):
  if matrix.is_positive_definite and matrix.is_self_adjoint:
    return matrix.log_abs_determinant(name)
  raise ValueError("Expected matrix to be self-adjoint positive definite.")


@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(  # pylint:disable=missing-docstring
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    output_type=None,  # pylint: disable=unused-argument
    name=None):
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")
  if not isinstance(a, LinearOperator):
    # We use the identity (B^H A^H)^H = AB.
    adjoint_matmul = b.matmul(
        a,
        adjoint=(not adjoint_b),
        adjoint_arg=(not adjoint_a),
        name=name)
    return linalg.adjoint(adjoint_matmul)
  return a.matmul(
      b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)


@dispatch.dispatch_for_types(linalg.solve, LinearOperator)
def _solve(
    matrix,
    rhs,
    adjoint=False,
    name=None):
  if not isinstance(matrix, LinearOperator):
    raise ValueError("Passing in `matrix` as a Tensor and `rhs` as a "
                     "LinearOperator is not supported.")
  return matrix.solve(rhs, adjoint=adjoint, name=name)


@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
  return x.trace(name)