# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a lower triangular matrix."""

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowerTriangular",
]


@tf_export("linalg.LinearOperatorLowerTriangular")
@linear_operator.make_composite_tensor
class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square lower triangular matrix.

  This operator acts like a [batch] lower triangular matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.

  `LinearOperatorLowerTriangular` is initialized with a `Tensor` having
  dimensions `[B1,...,Bb, N, N]`. The upper triangle of the last two
  dimensions is ignored.

  ```python
  # Create a 2 x 2 lower-triangular linear operator.
  tril = [[1., 2.], [3., 4.]]
  operator = LinearOperatorLowerTriangular(tril)

  # The upper triangle is ignored.
  operator.to_dense()
  ==> [[1., 0.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  tril = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorLowerTriangular(tril)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorLowerTriangular` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N^2 * R` multiplications.
  * `operator.solve(x)` involves `R` size `N` forward-substitutions
    (`O(N^2 * R)` operations in total).
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               tril,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowerTriangular"):
    r"""Initialize a `LinearOperatorLowerTriangular`.

    Args:
      tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
        The lower triangular part of `tril` defines this operator.  The strictly
        upper triangle is ignored.
      is_non_singular:  Expect that this operator is non-singular.
        This operator is non-singular if and only if its diagonal elements are
        all non-zero.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This operator is self-adjoint only if it is diagonal with
        real-valued diagonal entries.  In this case it is advised to use
        `LinearOperatorDiag`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_square` is `False`, or if `tril` is statically known
        to have fewer than 2 dimensions.
    """
    # Record the constructor arguments as passed (before `is_square` is
    # normalized below); the base class receives this dict via `parameters=`.
    parameters = dict(
        tril=tril,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")
    # Only square operators are supported, so the hint is always True.
    is_square = True

    with ops.name_scope(name, values=[tril]):
      self._tril = linear_operator_util.convert_nonref_to_tensor(tril,
                                                                 name="tril")
      self._check_tril(self._tril)

      super(LinearOperatorLowerTriangular, self).__init__(
          dtype=self._tril.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def tril(self):
    """The lower triangular matrix defining this operator."""
    return self._tril

  def _check_tril(self, tril):
    """Static check of the `tril` argument.

    Raises:
      ValueError: If the static rank of `tril` is known and less than 2.
    """
    if tril.shape.ndims is not None and tril.shape.ndims < 2:
      # Report the offending shape rather than the (possibly large) tensor
      # contents, since the shape is what failed validation.
      raise ValueError(
          "Argument tril must have at least 2 dimensions.  Found: %s"
          % (tril.shape,))

  def _get_tril(self):
    """Gets the `tril` kwarg, with the strictly upper part zeroed out."""
    # num_lower=-1 keeps the entire lower triangle; num_upper=0 keeps no
    # super-diagonals.
    return array_ops.matrix_band_part(self._tril, -1, 0)

  def _get_diag(self):
    """Gets the diagonal part of the `tril` kwarg."""
    # The diagonal is unaffected by masking the upper triangle, so read it
    # directly from `self._tril` without building the masked matrix.
    return array_ops.matrix_diag_part(self._tril)

  def _shape(self):
    """Static shape of the operator, taken from `tril`."""
    return self._tril.shape

  def _shape_tensor(self):
    """Dynamic shape of the operator, taken from `tril`."""
    return array_ops.shape(self._tril)

  def _assert_non_singular(self):
    """Returns an op asserting that no diagonal entry has modulus zero."""
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._get_diag(),
        message="Singular operator:  Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the masked lower-triangular matrix by `x`."""
    return math_ops.matmul(
        self._get_tril(), x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    """Determinant of a triangular matrix: product of its diagonal."""
    return math_ops.reduce_prod(self._get_diag(), axis=[-1])

  def _log_abs_determinant(self):
    """Sum of log-moduli of the diagonal entries."""
    return math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._get_diag())), axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves `A X = rhs` (or `A^H X = rhs`) via a triangular solve."""
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    return linalg.triangular_solve(
        self._get_tril(), rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    """Dense representation: `tril` with its strictly upper part zeroed."""
    return self._get_tril()

  def _diag_part(self):
    """Diagonal of the operator, read directly from `tril`."""
    # Cheaper than a generic fallback that would first materialize the
    # masked dense matrix.
    return self._get_diag()

  def _eigvals(self):
    """Eigenvalues of a triangular matrix are its diagonal entries."""
    return self._get_diag()

  @property
  def _composite_tensor_fields(self):
    return ("tril",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    return {"tril": 2}