# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a diagonal matrix."""

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperatorDiag",]


@tf_export("linalg.LinearOperatorDiag")
@linear_operator.make_composite_tensor
class LinearOperatorDiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square diagonal matrix.

  This operator behaves like a [batch] diagonal matrix `A` of shape
  `[B1,...,Bb, N, N]`, `b >= 0`.  The leading `b` indices select a batch
  member; for each batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is an
  `N x N` matrix.  The matrix itself is never materialized — only its
  diagonal is stored — but the shape above is what broadcasting sees.

  Construction takes the (batch) vector of diagonal entries:

  ```python
  # Create a 2 x 2 diagonal linear operator.
  diag = [1., -1.]
  operator = LinearOperatorDiag(diag)

  operator.to_dense()
  ==> [[1., 0.]
       [0., -1.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  diag = tf.random.normal(shape=[2, 3, 4])
  operator = LinearOperatorDiag(diag)

  # Create a shape [2, 1, 4, 2] vector.  Note that this shape is compatible
  # since the batch dimensions, [2, 1], are broadcast to
  # operator.batch_shape = [2, 3].
  y = tf.random.normal(shape=[2, 1, 4, 2])
  x = operator.solve(y)
  ==> operator.matmul(x) = y
  ```

  #### Shape compatibility

  This operator acts on [batch] matrices of compatible shape.
  `x` is compatible with `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
  ```

  #### Performance

  If `operator` is a `LinearOperatorDiag` of shape `[N, N]` and
  `x.shape = [N, R]`, then

  * `operator.matmul(x)` costs `N * R` multiplications.
  * `operator.solve(x)` costs `N` divisions plus `N * R` multiplications.
  * `operator.determinant()` is a size-`N` `reduce_prod`.

  With batch shapes `[B1,...,Bb, N, N]` and `[B1,...,Bb, N, R]`, each
  operation scales up by a factor of `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` accepts boolean hint flags `is_X`, for
  `X = non_singular, self_adjoint, positive_definite, square`, meaning:

  * `is_X == True`: callers may assume the operator has property `X`.  This
    is a promise, *not* a runtime assert; e.g. finite floating point
    precision may cause it to be violated.
  * `is_X == False`: callers should expect the operator does NOT have `X`.
  * `is_X == None` (default): no expectation either way.
  """

  def __init__(self,
               diag,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorDiag"):
    r"""Initialize a `LinearOperatorDiag`.

    Args:
      diag:  Shape `[B1,...,Bb, N]` `Tensor`, `b >= 0`, `N >= 0`, holding the
        operator's diagonal.  Allowed dtypes: `float16`, `float32`,
        `float64`, `complex64`, `complex128`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator equals its hermitian
        transpose.  Auto-set to `True` when `diag.dtype` is real.
      is_positive_definite:  Expect that this operator is positive definite,
        i.e. the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Self-adjointness is NOT required for this.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `diag.dtype` is not an allowed type.
      ValueError:  If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
    """
    parameters = dict(
        diag=diag,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[diag]):
      self._diag = linear_operator_util.convert_nonref_to_tensor(
          diag, name="diag")
      self._check_diag(self._diag)

      # Check and auto-set hints.  A real diagonal is its own hermitian
      # transpose, so self-adjointness is forced for real dtypes.
      if not self._diag.dtype.is_complex:
        if is_self_adjoint is False:
          raise ValueError("A real diagonal operator is always self adjoint.")
        is_self_adjoint = True

      # Only the square case is implemented; an explicit False is an error.
      if is_square is False:
        raise ValueError("Only square diagonal operators currently supported.")
      is_square = True

      super(LinearOperatorDiag, self).__init__(
          dtype=self._diag.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _check_diag(self, diag):
    """Static check of diag."""
    ndims = diag.shape.ndims
    if ndims is not None and ndims < 1:
      raise ValueError("Argument diag must have at least 1 dimension. "
                       "Found: %s" % diag)

  def _shape(self):
    # A diag of static shape [B1,...,Bb, N] gives operator shape
    # [B1,...,Bb, N, N]: repeat the trailing dimension.
    diag_shape = self._diag.shape
    return diag_shape.concatenate(diag_shape[-1:])

  def _shape_tensor(self):
    # Dynamic counterpart of _shape: append the last dimension again.
    diag_shape = array_ops.shape(self._diag)
    return array_ops.concat((diag_shape, [diag_shape[-1]]), 0)

  @property
  def diag(self):
    # The stored diagonal, shape [B1,...,Bb, N].
    return self._diag

  def _assert_non_singular(self):
    # Singular iff any diagonal entry is exactly zero.
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        message="Singular operator: Diagonal contained zero values.")

  def _assert_positive_definite(self):
    # Positive definite iff every diagonal entry has positive real part.
    message = (
        "Diagonal operator had diagonal entries with non-positive real part, "
        "thus was not positive definite."
        if self.dtype.is_complex else
        "Real diagonal operator had non-positive diagonal entries, "
        "thus was not positive definite.")
    return check_ops.assert_positive(
        math_ops.real(self._diag),
        message=message)

  def _assert_self_adjoint(self):
    # Self-adjoint iff the diagonal is real-valued.
    return linear_operator_util.assert_zero_imag_part(
        self._diag,
        message=(
            "This diagonal operator contained non-zero imaginary values. "
            " Thus it was not self-adjoint."))

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # D @ x == d[..., None] * x; adjoint of a diagonal is its conjugate.
    if adjoint_arg:
      x = linalg.adjoint(x)
    d = math_ops.conj(self._diag) if adjoint else self._diag
    return array_ops.expand_dims(d, -1) * x

  def _matvec(self, x, adjoint=False):
    # Elementwise product with the (possibly conjugated) diagonal.
    if adjoint:
      return math_ops.conj(self._diag) * x
    return self._diag * x

  def _determinant(self):
    # det(D) is the product of the diagonal entries.
    return math_ops.reduce_prod(self._diag, axis=[-1])

  def _log_abs_determinant(self):
    # log|det(D)| = sum(log|d_i|); cast back to the operator dtype when
    # complex, since log-abs of a complex diag is real-valued.
    result = math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])
    if self.dtype.is_complex:
      result = math_ops.cast(result, dtype=self.dtype)
    return result

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # D^{-1} @ rhs == (1/d)[..., None] * rhs.
    if adjoint_arg:
      rhs = linalg.adjoint(rhs)
    d = math_ops.conj(self._diag) if adjoint else self._diag
    inv_d = array_ops.expand_dims(1. / d, -1)
    return rhs * inv_d

  def _to_dense(self):
    # Materialize the full [..., N, N] matrix.
    return array_ops.matrix_diag(self._diag)

  def _diag_part(self):
    return self.diag

  def _add_to_tensor(self, x):
    # Add the operator to a dense tensor by adjusting x's diagonal in place.
    combined_diag = self._diag + array_ops.matrix_diag_part(x)
    return array_ops.matrix_set_diag(x, combined_diag)

  def _eigvals(self):
    # The eigenvalues of a diagonal matrix are its diagonal entries.
    return ops.convert_to_tensor_v2_with_dispatch(self.diag)

  def _cond(self):
    # Condition number: max |d_i| / min |d_i| per batch member.
    magnitudes = math_ops.abs(self.diag)
    return (math_ops.reduce_max(magnitudes, axis=-1) /
            math_ops.reduce_min(magnitudes, axis=-1))

  @property
  def _composite_tensor_fields(self):
    # Only `diag` is needed to rebuild this operator as a composite tensor.
    return ("diag",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # One parameter dim of `diag` maps to one matrix dim of the operator.
    return {"diag": 1}