# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a lower triangular matrix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowerTriangular",
]


@tf_export("linalg.LinearOperatorLowerTriangular")
@linear_operator.make_composite_tensor
class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square lower triangular matrix.

  This operator acts like a [batch] lower triangular matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.

  `LinearOperatorLowerTriangular` is initialized with a `Tensor` having
  dimensions `[B1,...,Bb, N, N]`.  The upper triangle of the last two
  dimensions is ignored.

  ```python
  # Create a 2 x 2 lower-triangular linear operator.
  tril = [[1., 2.], [3., 4.]]
  operator = LinearOperatorLowerTriangular(tril)

  # The upper triangle is ignored.
  operator.to_dense()
  ==> [[1., 0.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  tril = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorLowerTriangular(tril)
  ```

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorLowerTriangular` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N^2 * R` multiplications.
  * `operator.solve(x)` involves `N * R` size `N` back-substitutions.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               tril,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowerTriangular"):
    r"""Initialize a `LinearOperatorLowerTriangular`.

    Args:
      tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
        The lower triangular part of `tril` defines this operator.  The
        strictly upper triangle is ignored.
      is_non_singular:  Expect that this operator is non-singular.
        This operator is non-singular if and only if its diagonal elements are
        all non-zero.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This operator is self-adjoint only if it is diagonal with
        real-valued diagonal entries.  In this case it is advised to use
        `LinearOperatorDiag`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_square` is `False`.
    """
    # Snapshot the constructor arguments exactly as the caller supplied them.
    # This must happen before `is_square` is overwritten below.
    parameters = {
        "tril": tril,
        "is_non_singular": is_non_singular,
        "is_self_adjoint": is_self_adjoint,
        "is_positive_definite": is_positive_definite,
        "is_square": is_square,
        "name": name,
    }

    # `None` means "no hint" and is accepted; only an explicit `False` is
    # rejected, hence the identity comparison rather than a truthiness test.
    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")
    is_square = True

    with ops.name_scope(name, values=[tril]):
      tril_tensor = linear_operator_util.convert_nonref_to_tensor(
          tril, name="tril")
      self._tril = tril_tensor
      self._check_tril(tril_tensor)

      super(LinearOperatorLowerTriangular, self).__init__(
          dtype=tril_tensor.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)
      self._set_graph_parents([self._tril])

  def _check_tril(self, tril):
    """Statically validates that `tril` has rank of at least 2.

    Args:
      tril: The converted `tril` tensor.

    Raises:
      ValueError: If the static rank of `tril` is known and is less than 2.
    """
    rank = tril.shape.ndims
    # An unknown static rank cannot be validated here, so it is let through.
    if rank is None or rank >= 2:
      return
    raise ValueError(
        "Argument tril must have at least 2 dimensions. Found: %s" % tril)

  def _get_tril(self):
    """Returns the `tril` kwarg with its strictly upper part zeroed out."""
    # Band (-1, 0): keep the whole lower triangle and nothing above the
    # diagonal.
    return array_ops.matrix_band_part(self._tril, -1, 0)

  def _get_diag(self):
    """Returns the main diagonal of the `tril` kwarg."""
    return array_ops.matrix_diag_part(self._tril)

  def _shape(self):
    """Returns the static shape of the stored `tril` tensor."""
    return self._tril.shape

  def _shape_tensor(self):
    """Returns the shape of the stored `tril` tensor as a tensor."""
    return array_ops.shape(self._tril)

  def _assert_non_singular(self):
    """Returns an assertion that no diagonal entry has zero modulus."""
    diag = self._get_diag()
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        diag,
        message="Singular operator:  Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the zero-masked triangular matrix with `x`."""
    lower = self._get_tril()
    return math_ops.matmul(
        lower, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    """Returns the product of the diagonal entries over the last axis."""
    diag = self._get_diag()
    return math_ops.reduce_prod(diag, axis=[-1])

  def _log_abs_determinant(self):
    """Returns the sum of `log|d|` over the diagonal entries `d`."""
    abs_diag = math_ops.abs(self._get_diag())
    log_abs_diag = math_ops.log(abs_diag)
    return math_ops.reduce_sum(log_abs_diag, axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves the lower triangular system against `rhs`."""
    # `triangular_solve` is only told about the operator's adjoint, so the
    # adjoint of the right-hand side is applied explicitly beforehand.
    if adjoint_arg:
      rhs = linalg.adjoint(rhs)
    return linalg.triangular_solve(
        self._get_tril(), rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    """Returns the dense matrix, i.e. `tril` with the upper part zeroed."""
    return self._get_tril()

  def _eigvals(self):
    """Returns the diagonal entries as the eigenvalues."""
    return self._get_diag()

  @property
  def _composite_tensor_fields(self):
    """Names of the constructor arguments listed as composite fields."""
    return ("tril",)