# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` that wraps a [batch] matrix."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorFullMatrix"]


@tf_export("linalg.LinearOperatorFullMatrix")
@linear_operator.make_composite_tensor
class LinearOperatorFullMatrix(linear_operator.LinearOperator):
  """`LinearOperator` that wraps a [batch] matrix.

  This operator wraps a [batch] matrix `A` (which is a `Tensor`) with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.

  ```python
  # Create a 2 x 2 linear operator.
  matrix = [[1., 2.], [3., 4.]]
  operator = LinearOperatorFullMatrix(matrix)

  operator.to_dense()
  ==> [[1., 2.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorFullMatrix(matrix)
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  `LinearOperatorFullMatrix` has exactly the same performance as would be
  achieved by using standard `TensorFlow` matrix ops.  Intelligent choices are
  made based on the following initialization hints.

  * If `dtype` is real, and `is_self_adjoint` and `is_positive_definite`, a
    Cholesky factorization is used for the determinant and solve.

  In all cases, suppose `operator` is a `LinearOperatorFullMatrix` of shape
  `[M, N]`, and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` is `O(M * N * R)`.
  * If `M=N`, `operator.solve(x)` is `O(N^3 * R)`.
  * If `M=N`, `operator.determinant()` is `O(N^3)`.

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               matrix,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorFullMatrix"):
    r"""Initialize a `LinearOperatorFullMatrix`.

    Args:
      matrix:  Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `matrix.dtype` is not an allowed type.
      ValueError:  If `matrix` is statically known to have fewer than 2
        dimensions.
    """
    # Snapshot of the constructor arguments, forwarded to the base class via
    # its `parameters` argument.
    parameters = dict(
        matrix=matrix,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[matrix]):
      self._matrix = linear_operator_util.convert_nonref_to_tensor(
          matrix, name="matrix")
      # Validate dtype and static rank before handing off to the base class.
      self._check_matrix(self._matrix)

    super().__init__(
        dtype=self._matrix.dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters,
        name=name)

  def _check_matrix(self, matrix):
    """Static check of the `matrix` argument.

    Args:
      matrix: The value to validate; it is converted to a tensor before its
        dtype and static shape are inspected.

    Raises:
      TypeError: If the dtype is not one of `float16`, `float32`, `float64`,
        `complex64`, `complex128`.
      ValueError: If the static rank is known and is less than 2.
    """
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]

    matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")

    dtype = matrix.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(f"Argument `matrix` must have dtype in {allowed_dtypes}. "
                      f"Received: {dtype}.")

    # The rank check is skipped when the rank is not statically known.
    if matrix.shape.ndims is not None and matrix.shape.ndims < 2:
      raise ValueError(f"Argument `matrix` must have at least 2 dimensions. "
                       f"Received: {matrix}.")

  @property
  def matrix(self):
    """The matrix defining this operator."""
    return self._matrix

  def _shape(self):
    """Returns the static shape of the wrapped matrix."""
    return self._matrix.shape

  def _shape_tensor(self):
    """Returns the shape of the wrapped matrix as a `Tensor`."""
    return array_ops.shape(self._matrix)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the wrapped matrix with `x` via `math_ops.matmul`.

    `adjoint` and `adjoint_arg` are forwarded as `adjoint_a` and `adjoint_b`.
    """
    return math_ops.matmul(
        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves against `rhs` by delegating to `self._dense_solve`."""
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _to_dense(self):
    """Returns the wrapped matrix; it is already dense."""
    return self._matrix

  @property
  def _composite_tensor_fields(self):
    # NOTE(review): presumably consumed by
    # `linear_operator.make_composite_tensor` to decide which constructor
    # arguments are tensor fields -- confirm against that decorator.
    return ("matrix",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # NOTE(review): looks like a map from parameter name to the number of
    # trailing dimensions that form one matrix -- confirm against the base
    # class.
    return {"matrix": 2}