# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` that wraps a [batch] matrix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorFullMatrix"]


@tf_export("linalg.LinearOperatorFullMatrix")
@linear_operator.make_composite_tensor
class LinearOperatorFullMatrix(linear_operator.LinearOperator):
  """`LinearOperator` that wraps a [batch] matrix.

  This operator wraps a [batch] matrix `A` (which is a `Tensor`) with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.

  ```python
  # Create a 2 x 2 linear operator.
  matrix = [[1., 2.], [3., 4.]]
  operator = LinearOperatorFullMatrix(matrix)

  operator.to_dense()
  ==> [[1., 2.],
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  matrix = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorFullMatrix(matrix)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  `LinearOperatorFullMatrix` has exactly the same performance as would be
  achieved by using standard `TensorFlow` matrix ops.  Intelligent choices are
  made based on the following initialization hints.

  * If `dtype` is real, and `is_self_adjoint` and `is_positive_definite`, a
    Cholesky factorization is used for the determinant and solve.

  In all cases, suppose `operator` is a `LinearOperatorFullMatrix` of shape
  `[M, N]`, and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` is `O(M * N * R)`.
  * If `M=N`, `operator.solve(x)` is `O(N^3 + N^2 * R)` (one factorization
    plus `R` triangular solves).
  * If `M=N`, `operator.determinant()` is `O(N^3)`.

  If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               matrix,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorFullMatrix"):
    r"""Initialize a `LinearOperatorFullMatrix`.

    Args:
      matrix:  Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
        `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `matrix.dtype` is not an allowed type.
      ValueError:  If `matrix` has a statically known rank less than 2.
    """
    # Captured verbatim so the operator can be re-created (e.g. when it is
    # handled as a composite tensor).
    parameters = dict(
        matrix=matrix,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[matrix]):
      self._matrix = linear_operator_util.convert_nonref_to_tensor(
          matrix, name="matrix")
      # Validate dtype and static rank before handing off to the base class.
      self._check_matrix(self._matrix)

      super(LinearOperatorFullMatrix, self).__init__(
          dtype=self._matrix.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
      # TODO(b/143910018) Remove graph_parents in V3.
      self._set_graph_parents([self._matrix])

  def _check_matrix(self, matrix):
    """Static check of the `matrix` argument.

    Args:
      matrix: The wrapped matrix (anything convertible to a `Tensor`).

    Raises:
      TypeError: If the dtype of `matrix` is not one of the allowed
        floating-point or complex dtypes.
      ValueError: If the rank of `matrix` is statically known and is less
        than 2.
    """
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]

    matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")

    dtype = matrix.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(
          "Argument matrix must have dtype in %s.  Found: %s"
          % (allowed_dtypes, dtype))

    # Only reject when the rank is known at graph-construction time; an
    # unknown rank (ndims is None) is accepted here.
    if matrix.shape.ndims is not None and matrix.shape.ndims < 2:
      raise ValueError(
          "Argument matrix must have at least 2 dimensions.  Found: %s"
          % matrix)

  def _shape(self):
    # Static shape is simply the wrapped matrix's static shape.
    return self._matrix.shape

  def _shape_tensor(self):
    # Dynamic shape of the wrapped matrix, as a 1-D `Tensor`.
    return array_ops.shape(self._matrix)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # `adjoint` applies to the operator, `adjoint_arg` to `x`.
    return math_ops.matmul(
        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # Delegates to the base class's dense solve on the materialized matrix.
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _to_dense(self):
    # The operator is already dense; return the wrapped matrix unchanged.
    return self._matrix

  @property
  def _composite_tensor_fields(self):
    # Names of constructor parameters that hold tensor-valued components.
    return ("matrix",)