# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a lower triangular matrix."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "LinearOperatorLowerTriangular",
]


@tf_export("linalg.LinearOperatorLowerTriangular")
class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square lower triangular matrix.

  This operator acts like a [batch] lower triangular matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.

  `LinearOperatorLowerTriangular` is initialized with a `Tensor` having
  dimensions `[B1,...,Bb, N, N]`. The strictly upper triangle of the last two
  dimensions is ignored.

  ```python
  # Create a 2 x 2 lower-triangular linear operator.
  tril = [[1., 2.], [3., 4.]]
  operator = LinearOperatorLowerTriangular(tril)

  # The upper triangle is ignored.
  operator.to_dense()
  ==> [[1., 0.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  tril = tf.random_normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorLowerTriangular(tril)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorLowerTriangular` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N^2 * R` multiplications.
  * `operator.solve(x)` involves `R` size `N` forward substitutions, i.e.
    `O(N^2 * R)` work.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               tril,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowerTriangular"):
    r"""Initialize a `LinearOperatorLowerTriangular`.

    Args:
      tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
        The lower triangular part of `tril` defines this operator.  The strictly
        upper triangle is ignored.
      is_non_singular:  Expect that this operator is non-singular.
        This operator is non-singular if and only if its diagonal elements are
        all non-zero.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This operator is self-adjoint only if it is diagonal with
        real-valued diagonal entries.  In this case it is advised to use
        `LinearOperatorDiag`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_square` is `False`, or if `tril` has a statically
        known rank smaller than 2.
      TypeError:  If `tril` does not have a floating point or complex dtype.
    """

    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")
    # Only square operators are supported, so the hint is always set.
    is_square = True

    with ops.name_scope(name, values=[tril]):
      self._tril = ops.convert_to_tensor(tril, name="tril")
      self._check_tril(self._tril)
      # Zero out the strictly upper triangle.  Band the already converted and
      # validated tensor, not the raw `tril` argument, so that a non-`Tensor`
      # input (e.g. a Python list) is not converted to a tensor a second time.
      self._tril = array_ops.matrix_band_part(self._tril, -1, 0)
      self._diag = array_ops.matrix_diag_part(self._tril)

    super(LinearOperatorLowerTriangular, self).__init__(
        dtype=self._tril.dtype,
        graph_parents=[self._tril],
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name)

  def _check_tril(self, tril):
    """Static check of the `tril` argument.

    Args:
      tril: The `Tensor` produced by converting the constructor argument.

    Raises:
      TypeError: If `tril.dtype` is not a float16/32/64 or complex64/128 type.
      ValueError: If the static rank of `tril` is known and less than 2.
    """
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]
    dtype = tril.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(
          "Argument tril must have dtype in %s.  Found: %s"
          % (allowed_dtypes, dtype))

    # Only checked when the rank is statically known; unknown ranks pass.
    if tril.get_shape().ndims is not None and tril.get_shape().ndims < 2:
      raise ValueError(
          "Argument tril must have at least 2 dimensions.  Found: %s"
          % tril)

  def _shape(self):
    # Static shape of the (banded) lower triangular tensor.
    return self._tril.get_shape()

  def _shape_tensor(self):
    # Dynamic shape of the (banded) lower triangular tensor.
    return array_ops.shape(self._tril)

  def _assert_non_singular(self):
    # A triangular matrix is singular exactly when a diagonal entry is zero.
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        message="Singular operator:  Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Dense (broadcasting) matmul against the banded tensor; the zeroed upper
    # triangle contributes nothing.
    return linear_operator_util.matmul_with_broadcast(
        self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    # The determinant of a triangular matrix is the product of its diagonal.
    return math_ops.reduce_prod(self._diag, axis=[-1])

  def _log_abs_determinant(self):
    # Sum of logs of the diagonal moduli, rather than log of the product,
    # which can overflow or underflow.
    return math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # `adjoint_arg` is applied to `rhs` up front; `adjoint` is forwarded to
    # the triangular solver, which handles it for the operator itself.
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        self._tril, rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    # `self._tril` already has its strictly upper triangle zeroed.
    return self._tril

  def _add_to_tensor(self, x):
    return self._tril + x