# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a tridiagonal matrix."""

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export

__all__ = ['LinearOperatorTridiag',]

# Supported values for the `diagonals_format` constructor argument.
_COMPACT = 'compact'
_MATRIX = 'matrix'
_SEQUENCE = 'sequence'
_DIAGONAL_FORMATS = frozenset({_COMPACT, _MATRIX, _SEQUENCE})


@tf_export('linalg.LinearOperatorTridiag')
@linear_operator.make_composite_tensor
class LinearOperatorTridiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square tridiagonal matrix.

  This operator acts like a [batch] square tridiagonal matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
  batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
  an `N x N` matrix. This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  Example usage:

  Create a 3 x 3 tridiagonal linear operator.

  >>> superdiag = [3., 4., 5.]
  >>> diag = [1., -1., 2.]
  >>> subdiag = [6., 7., 8]
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...    [superdiag, diag, subdiag],
  ...    diagonals_format='sequence')
  >>> operator.to_dense()
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[ 1., 3., 0.],
         [ 7., -1., 4.],
         [ 0., 8., 2.]], dtype=float32)>
  >>> operator.shape
  TensorShape([3, 3])

  Scalar Tensor output.

  >>> operator.log_abs_determinant()
  <tf.Tensor: shape=(), dtype=float32, numpy=4.3307333>

  Create a [2, 3] batch of 4 x 4 linear operators.

  >>> diagonals = tf.random.normal(shape=[2, 3, 3, 4])
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...   diagonals,
  ...   diagonals_format='compact')

  Create a shape [2, 1, 4, 2] vector. Note that this shape is compatible
  since the batch dimensions, [2, 1], are broadcast to
  operator.batch_shape = [2, 3].

  >>> y = tf.random.normal(shape=[2, 1, 4, 2])
  >>> x = operator.solve(y)
  >>> x
  <tf.Tensor: shape=(2, 3, 4, 2), dtype=float32, numpy=...,
  dtype=float32)>

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N], with b >= 0
  x.shape = [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb].
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorTridiag` of shape `[N, N]`,
  and `x.shape = [N, R]`. Then

  * `operator.matmul(x)` will take O(N * R) time.
  * `operator.solve(x)` will take O(N * R) time.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`. This is a promise that should be fulfilled, but is *not* a
    runtime assert. For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               diagonals,
               diagonals_format=_COMPACT,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name='LinearOperatorTridiag'):
    r"""Initialize a `LinearOperatorTridiag`.

    Args:
      diagonals: `Tensor` or list of `Tensor`s depending on `diagonals_format`.

        If `diagonals_format=sequence`, this is a list of three `Tensor`'s each
        with shape `[B1, ..., Bb, N]`, `b >= 0, N >= 0`, representing the
        superdiagonal, diagonal and subdiagonal in that order. Note the
        superdiagonal is padded with an element in the last position, and the
        subdiagonal is padded with an element in the front.

        If `diagonals_format=matrix` this is a `[B1, ... Bb, N, N]` shaped
        `Tensor` representing the full tridiagonal matrix.

        If `diagonals_format=compact` this is a `[B1, ... Bb, 3, N]` shaped
        `Tensor` with the second to last dimension indexing the
        superdiagonal, diagonal and subdiagonal in that order. Note the
        superdiagonal is padded with an element in the last position, and the
        subdiagonal is padded with an element in the front.

        In every case, these `Tensor`s are all floating dtype.
      diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
        `compact`.
      is_non_singular: Expect that this operator is non-singular.
      is_self_adjoint: Expect that this operator is equal to its hermitian
        transpose. If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite: Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`. Note that we do not require the operator to be
        self-adjoint to be positive-definite. See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square: Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError: If `diag.dtype` is not an allowed type.
      ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
    """
    parameters = dict(
        diagonals=diagonals,
        diagonals_format=diagonals_format,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[diagonals]):
      if diagonals_format not in _DIAGONAL_FORMATS:
        raise ValueError(
            f'Argument `diagonals_format` must be one of compact, matrix, or '
            f'sequence. Received : {diagonals_format}.')
      if diagonals_format == _SEQUENCE:
        self._diagonals = [linear_operator_util.convert_nonref_to_tensor(
            d, name='diag_{}'.format(i)) for i, d in enumerate(diagonals)]
        dtype = self._diagonals[0].dtype
      else:
        self._diagonals = linear_operator_util.convert_nonref_to_tensor(
            diagonals, name='diagonals')
        dtype = self._diagonals.dtype
      self._diagonals_format = diagonals_format

      super(LinearOperatorTridiag, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _shape(self):
    # Static shape `[B1,...,Bb, N, N]`, derived per diagonals format.
    if self.diagonals_format == _MATRIX:
      return self.diagonals.shape
    if self.diagonals_format == _COMPACT:
      # Remove the second to last dimension that contains the value 3.
      d_shape = self.diagonals.shape[:-2].concatenate(
          self.diagonals.shape[-1])
    else:
      broadcast_shape = array_ops.broadcast_static_shape(
          self.diagonals[0].shape[:-1],
          self.diagonals[1].shape[:-1])
      broadcast_shape = array_ops.broadcast_static_shape(
          broadcast_shape,
          self.diagonals[2].shape[:-1])
      d_shape = broadcast_shape.concatenate(self.diagonals[1].shape[-1])
    return d_shape.concatenate(d_shape[-1])

  def _shape_tensor(self, diagonals=None):
    # Dynamic counterpart of `_shape`; may be called with adjoint diagonals.
    diagonals = diagonals if diagonals is not None else self.diagonals
    if self.diagonals_format == _MATRIX:
      return array_ops.shape(diagonals)
    if self.diagonals_format == _COMPACT:
      d_shape = array_ops.shape(diagonals[..., 0, :])
    else:
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          array_ops.shape(self.diagonals[0])[:-1],
          array_ops.shape(self.diagonals[1])[:-1])
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          broadcast_shape,
          array_ops.shape(self.diagonals[2])[:-1])
      d_shape = array_ops.concat(
          [broadcast_shape, [array_ops.shape(self.diagonals[1])[-1]]], axis=0)
    return array_ops.concat([d_shape, [d_shape[-1]]], axis=-1)

  def _assert_self_adjoint(self):
    # Check the diagonal has non-zero imaginary, and the super and subdiagonals
    # are conjugate.

    asserts = []
    diag_message = (
        'This tridiagonal operator contained non-zero '
        'imaginary values on the diagonal.')
    off_diag_message = (
        'This tridiagonal operator has non-conjugate '
        'subdiagonal and superdiagonal.')

    if self.diagonals_format == _MATRIX:
      asserts += [check_ops.assert_equal(
          self.diagonals, linalg.adjoint(self.diagonals),
          message='Matrix was not equal to its adjoint.')]
    elif self.diagonals_format == _COMPACT:
      diagonals = ops.convert_to_tensor_v2_with_dispatch(self.diagonals)
      asserts += [linear_operator_util.assert_zero_imag_part(
          diagonals[..., 1, :], message=diag_message)]
      # Roll the subdiagonal so the shifted argument is at the end.
      subdiag = manip_ops.roll(diagonals[..., 2, :], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          diagonals[..., 0, :-1],
          message=off_diag_message)]
    else:
      asserts += [linear_operator_util.assert_zero_imag_part(
          self.diagonals[1], message=diag_message)]
      subdiag = manip_ops.roll(self.diagonals[2], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          self.diagonals[0][..., :-1],
          message=off_diag_message)]
    return control_flow_ops.group(asserts)

  def _construct_adjoint_diagonals(self, diagonals):
    # Constructs adjoint tridiagonal matrix from diagonals.
    if self.diagonals_format == _SEQUENCE:
      diagonals = [math_ops.conj(d) for d in reversed(diagonals)]
      # The subdiag and the superdiag swap places, so we need to shift the
      # padding argument.
      diagonals[0] = manip_ops.roll(diagonals[0], shift=-1, axis=-1)
      diagonals[2] = manip_ops.roll(diagonals[2], shift=1, axis=-1)
      return diagonals
    elif self.diagonals_format == _MATRIX:
      return linalg.adjoint(diagonals)
    else:
      diagonals = math_ops.conj(diagonals)
      superdiag, diag, subdiag = array_ops.unstack(
          diagonals, num=3, axis=-2)
      # The subdiag and the superdiag swap places, so we need
      # to shift all arguments.
      new_superdiag = manip_ops.roll(subdiag, shift=-1, axis=-1)
      new_subdiag = manip_ops.roll(superdiag, shift=1, axis=-1)
      return array_ops.stack([new_superdiag, diag, new_subdiag], axis=-2)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)
    x = linalg.adjoint(x) if adjoint_arg else x
    return linalg.tridiagonal_matmul(
        diagonals, x,
        diagonals_format=self.diagonals_format)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)

    # TODO(b/144860784): Remove the broadcasting code below once
    # tridiagonal_solve broadcasts.

    rhs_shape = array_ops.shape(rhs)
    k = self._shape_tensor(diagonals)[-1]
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
    rhs = array_ops.broadcast_to(
        rhs, array_ops.concat(
            [broadcast_shape, rhs_shape[-2:]], axis=-1))
    if self.diagonals_format == _MATRIX:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [k, k]], axis=-1))
    elif self.diagonals_format == _COMPACT:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [3, k]], axis=-1))
    else:
      diagonals = [
          array_ops.broadcast_to(d, array_ops.concat(
              [broadcast_shape, [k]], axis=-1)) for d in diagonals]

    y = linalg.tridiagonal_solve(
        diagonals, rhs,
        diagonals_format=self.diagonals_format,
        transpose_rhs=adjoint_arg,
        conjugate_rhs=adjoint_arg)
    return y

  def _diag_part(self):
    if self.diagonals_format == _MATRIX:
      return array_ops.matrix_diag_part(self.diagonals)
    elif self.diagonals_format == _SEQUENCE:
      diagonal = self.diagonals[1]
      return array_ops.broadcast_to(
          diagonal, self.shape_tensor()[:-1])
    else:
      return self.diagonals[..., 1, :]

  def _to_dense(self):
    if self.diagonals_format == _MATRIX:
      return self.diagonals

    if self.diagonals_format == _COMPACT:
      return gen_array_ops.matrix_diag_v3(
          self.diagonals,
          k=(-1, 1),
          num_rows=-1,
          num_cols=-1,
          align='LEFT_RIGHT',
          padding_value=0.)

    diagonals = [
        ops.convert_to_tensor_v2_with_dispatch(d) for d in self.diagonals
    ]
    diagonals = array_ops.stack(diagonals, axis=-2)

    return gen_array_ops.matrix_diag_v3(
        diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

  @property
  def diagonals(self):
    # `Tensor` (matrix/compact) or list of three `Tensor`s (sequence).
    return self._diagonals

  @property
  def diagonals_format(self):
    # One of 'compact', 'matrix', 'sequence'.
    return self._diagonals_format

  @property
  def _composite_tensor_fields(self):
    return ('diagonals', 'diagonals_format')

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    diagonal_event_ndims = 2
    if self.diagonals_format == _SEQUENCE:
      # For the diagonal and the super/sub diagonals.
      diagonal_event_ndims = [1, 1, 1]
    return {
        'diagonals': diagonal_event_ndims,
    }