• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` that wraps a [batch] matrix."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.linalg import linear_operator
26from tensorflow.python.ops.linalg import linear_operator_util
27from tensorflow.python.util.tf_export import tf_export
28
29__all__ = ["LinearOperatorFullMatrix"]
30
31
@tf_export("linalg.LinearOperatorFullMatrix")
class LinearOperatorFullMatrix(linear_operator.LinearOperator):
  """`LinearOperator` that wraps a [batch] matrix.

  This operator wraps a [batch] matrix `A` (which is a `Tensor`) with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.

  ```python
  # Create a 2 x 2 linear operator.
  matrix = [[1., 2.], [3., 4.]]
  operator = LinearOperatorFullMatrix(matrix)

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorFullMatrix(matrix)
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  `LinearOperatorFullMatrix` has exactly the same performance as would be
  achieved by using standard `TensorFlow` matrix ops.  Intelligent choices are
  made based on the following initialization hints.

  * If `dtype` is real, and `is_self_adjoint` and `is_positive_definite`, a
    Cholesky factorization is used for the determinant and solve.

  In all cases, suppose `operator` is a `LinearOperatorFullMatrix` of shape
  `[M, N]`, and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` is `O(M * N * R)`.
  * If `M=N`, `operator.solve(x)` is `O(N^3 * R)`.
  * If `M=N`, `operator.determinant()` is `O(N^3)`.

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               matrix,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorFullMatrix"):
    r"""Initialize a `LinearOperatorFullMatrix`.

    Args:
      matrix:  Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `matrix.dtype` is not an allowed type.
      ValueError:  If `matrix` has a statically known rank less than 2.
    """
    # Record the constructor arguments exactly as received; they are handed
    # to the base class through its `parameters` argument below.
    parameters = dict(
        matrix=matrix,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[matrix]):
      self._matrix = linear_operator_util.convert_nonref_to_tensor(
          matrix, name="matrix")
      # Validate dtype and static rank before initializing the base class.
      self._check_matrix(self._matrix)

      super(LinearOperatorFullMatrix, self).__init__(
          dtype=self._matrix.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
      # TODO(b/143910018) Remove graph_parents in V3.
      self._set_graph_parents([self._matrix])

  def _check_matrix(self, matrix):
    """Static check of the `matrix` argument.

    Args:
      matrix:  The value to validate.  It is converted to a `Tensor` so that
        its `dtype` and static `shape` can be inspected.

    Raises:
      TypeError:  If the dtype of `matrix` is not one of `float16`,
        `float32`, `float64`, `complex64`, `complex128`.
      ValueError:  If the static rank of `matrix` is known and less than 2.
    """
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]

    matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")

    dtype = matrix.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(
          "Argument matrix must have dtype in %s.  Found: %s"
          % (allowed_dtypes, dtype))

    # The rank check only fires when the rank is statically known; an unknown
    # rank (`ndims is None`) is accepted here.
    if matrix.shape.ndims is not None and matrix.shape.ndims < 2:
      raise ValueError(
          "Argument matrix must have at least 2 dimensions.  Found: %s"
          % matrix)

  def _shape(self):
    """Returns the static shape of the wrapped matrix."""
    return self._matrix.shape

  def _shape_tensor(self):
    """Returns the shape of the wrapped matrix via `array_ops.shape`."""
    return array_ops.shape(self._matrix)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the wrapped matrix with `x` using `math_ops.matmul`.

    Args:
      x:  The right-hand operand passed to `math_ops.matmul`.
      adjoint:  Forwarded as `adjoint_a`, i.e. applies to the wrapped matrix.
      adjoint_arg:  Forwarded as `adjoint_b`, i.e. applies to `x`.

    Returns:
      The result of `math_ops.matmul`.
    """
    return math_ops.matmul(
        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves against `rhs` by delegating to `self._dense_solve`.

    `_dense_solve` is not defined in this class; all arguments are forwarded
    to it unchanged.
    """
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _to_dense(self):
    """Returns the wrapped matrix itself; no conversion is performed."""
    return self._matrix
199