• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` that wraps a [batch] matrix."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.linalg import linear_operator
26from tensorflow.python.ops.linalg import linear_operator_util
27from tensorflow.python.util.tf_export import tf_export
28
# Names exported by `from <this module> import *`.
__all__ = ["LinearOperatorFullMatrix"]
30
31
@tf_export("linalg.LinearOperatorFullMatrix")
@linear_operator.make_composite_tensor
class LinearOperatorFullMatrix(linear_operator.LinearOperator):
  """`LinearOperator` that wraps a [batch] matrix.

  This operator wraps a [batch] matrix `A` (which is a `Tensor`) with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.

  ```python
  # Create a 2 x 2 linear operator.
  matrix = [[1., 2.], [3., 4.]]
  operator = LinearOperatorFullMatrix(matrix)

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorFullMatrix(matrix)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  `LinearOperatorFullMatrix` has exactly the same performance as would be
  achieved by using standard `TensorFlow` matrix ops.  Intelligent choices are
  made based on the following initialization hints.

  * If `dtype` is real, and `is_self_adjoint` and `is_positive_definite`, a
    Cholesky factorization is used for the determinant and solve.

  In all cases, suppose `operator` is a `LinearOperatorFullMatrix` of shape
  `[M, N]`, and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` is `O(M * N * R)`.
  * If `M=N`, `operator.solve(x)` is `O(N^3 * R)`.
  * If `M=N`, `operator.determinant()` is `O(N^3)`.

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               matrix,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorFullMatrix"):
    r"""Initialize a `LinearOperatorFullMatrix`.

    Args:
      matrix:  Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `matrix.dtype` is not an allowed type.
      ValueError:  If the rank of `matrix` is statically known and is less
        than 2.
    """
    # The constructor arguments, exactly as given, are handed to the base
    # class as `parameters`.
    parameters = dict(
        matrix=matrix,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[matrix]):
      # NOTE(review): judging by its name, `convert_nonref_to_tensor` converts
      # plain values to a `Tensor` but may leave reference-typed inputs (e.g.
      # variables) unconverted -- confirm against linear_operator_util.
      self._matrix = linear_operator_util.convert_nonref_to_tensor(
          matrix, name="matrix")
      self._check_matrix(self._matrix)

      super(LinearOperatorFullMatrix, self).__init__(
          dtype=self._matrix.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
      # TODO(b/143910018) Remove graph_parents in V3.
      self._set_graph_parents([self._matrix])

  def _check_matrix(self, matrix):
    """Static check of the `matrix` argument's dtype and rank.

    Args:
      matrix: The wrapped matrix, or a value convertible to a `Tensor`.

    Raises:
      TypeError: If `matrix.dtype` is not one of the allowed dtypes.
      ValueError: If the rank of `matrix` is statically known and is less
        than 2.  An unknown rank is not rejected here.
    """
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]

    # Convert explicitly so `.dtype` and `.shape` below are read from a
    # `Tensor` (the caller's value may not be one -- see `__init__`).
    matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")

    dtype = matrix.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(
          "Argument matrix must have dtype in %s.  Found: %s"
          % (allowed_dtypes, dtype))

    # Report the offending shape rather than formatting the whole tensor,
    # which could print its values.
    if matrix.shape.ndims is not None and matrix.shape.ndims < 2:
      raise ValueError(
          "Argument matrix must have at least 2 dimensions.  Found shape: %s"
          % matrix.shape)

  def _shape(self):
    """Returns the static shape of the wrapped matrix."""
    return self._matrix.shape

  def _shape_tensor(self):
    """Returns the dynamic shape of the wrapped matrix as a `Tensor`."""
    return array_ops.shape(self._matrix)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the wrapped matrix (optionally adjointed) by `x`.

    `adjoint` applies to the wrapped matrix, and `adjoint_arg` applies to `x`.
    """
    return math_ops.matmul(
        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves against `rhs` by delegating to `self._dense_solve`.

    `_dense_solve` is not defined in this class; presumably it is inherited
    from `LinearOperator` -- confirm there for the solve strategy.
    """
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _to_dense(self):
    """Returns the wrapped matrix itself (no copy is made here)."""
    return self._matrix

  @property
  def _composite_tensor_fields(self):
    """Names of constructor parameters treated as composite-tensor fields.

    Presumably consumed by `linear_operator.make_composite_tensor` (applied as
    a decorator above) -- confirm in linear_operator.
    """
    return ("matrix",)
204