# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for linear algebra."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export

# Linear algebra ops.
band_part = array_ops.matrix_band_part
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet))
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(dispatch.add_dispatch_support(logm))
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve


@tf_export('linalg.logdet')
@dispatch.add_dispatch_support
def logdet(matrix, name=None):
  """Computes log of the determinant of a hermitian positive definite matrix.

  ```python
  # Compute the determinant of a matrix while reducing the chance of over- or
  # underflow:
  A = ...  # shape 10 x 10
  det = tf.exp(tf.linalg.logdet(A))  # scalar
  ```

  Args:
    matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name: A name to give this `Op`. Defaults to `logdet`.

  Returns:
    The natural log of the determinant of `matrix`.

  @compatibility(numpy)
  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
  hermitian positive definite matrices are supported.
  @end_compatibility
  """
  # This uses the property that log det(A) = 2 * sum(log(real(diag(C))))
  # where C is the Cholesky decomposition of A.
  with ops.name_scope(name, 'logdet', [matrix]):
    chol = gen_linalg_ops.cholesky(matrix)
    return 2.0 * math_ops.reduce_sum(
        math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
        axis=[-1])
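

# Illustrative usage sketch (not part of the original module): `_logdet_example`
# is a hypothetical helper, never called at import time, showing how `logdet`
# relates to `det` for a small symmetric positive definite matrix; the two
# returned values should agree up to floating-point error.
def _logdet_example():
  a = constant_op.constant([[4., 1.], [1., 3.]])
  # logdet(a) == log(det(a)) == log(11.) for this SPD matrix.
  return logdet(a), math_ops.log(det(a))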


@tf_export('linalg.adjoint')
@dispatch.add_dispatch_support
def adjoint(matrix, name=None):
  """Transposes the last two dimensions of and conjugates tensor `matrix`.

  For example:

  ```python
  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.linalg.adjoint(x)  # [[1 - 1j, 4 - 4j],
                        #  [2 - 2j, 5 - 5j],
                        #  [3 - 3j, 6 - 6j]]
  ```

  Args:
    matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name: A name to give this `Op` (optional).

  Returns:
    The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
    matrix.
  """
  with ops.name_scope(name, 'adjoint', [matrix]):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    return array_ops.matrix_transpose(matrix, conjugate=True)


# This section is ported nearly verbatim from Eigen's implementation:
# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html
def _matrix_exp_pade3(matrix):
  """3rd-order Pade approximant for matrix exponential."""
  b = [120.0, 60.0, 12.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  tmp = matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v
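

# Illustrative sketch (not part of the original module): each `_matrix_exp_padeN`
# helper returns a pair (U, V) that `matrix_exponential` below combines into the
# Pade rational approximant exp(A) ~= (V - U)^{-1} (V + U), i.e.
# `matrix_solve(-u + v, u + v)`. `_pade3_expm_sketch` is a hypothetical helper,
# never called at import time, showing that combination for the 3rd-order
# approximant only, which is accurate only for matrices with small norm.
def _pade3_expm_sketch(matrix):
  u, v = _matrix_exp_pade3(matrix)
  return linalg_ops.matrix_solve(-u + v, u + v)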


def _matrix_exp_pade5(matrix):
  """5th-order Pade approximant for matrix exponential."""
  b = [30240.0, 15120.0, 3360.0, 420.0, 30.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade7(matrix):
  """7th-order Pade approximant for matrix exponential."""
  b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade9(matrix):
  """9th-order Pade approximant for matrix exponential."""
  b = [
      17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0,
      2162160.0, 110880.0, 3960.0, 90.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  matrix_8 = math_ops.matmul(matrix_6, matrix_2)
  tmp = (
      matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 +
      b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = (
      b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 +
      b[0] * ident)
  return matrix_u, matrix_v


def _matrix_exp_pade13(matrix):
  """13th-order Pade approximant for matrix exponential."""
  b = [
      64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
      1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
      33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp_u = (
      math_ops.matmul(matrix_6, matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) +
      b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp_u)
  tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2
  matrix_v = (
      math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 +
      b[2] * matrix_2 + b[0] * ident)
  return matrix_u, matrix_v


@tf_export('linalg.expm')
@dispatch.add_dispatch_support
def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
  r"""Computes the matrix exponential of one or more square matrices.

  $$exp(A) = \sum_{n=0}^\infty A^n/n!$$

  The exponential is computed using a combination of the scaling and squaring
  method and the Pade approximation. Details can be found in:
  Nicholas J. Higham, "The scaling and squaring method for the matrix
  exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.

  The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
  form square matrices. The output is a tensor of the same shape as the input
  containing the exponential for all input submatrices `[..., :, :]`.

  Args:
    input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or
      `complex128` with shape `[..., M, M]`.
    name: A name to give this `Op` (optional).

  Returns:
    the matrix exponential of the input.

  Raises:
    ValueError: An unsupported type is provided as input.
  @compatibility(scipy)
  Equivalent to scipy.linalg.expm
  @end_compatibility
  """
  with ops.name_scope(name, 'matrix_exponential', [input]):
    matrix = ops.convert_to_tensor(input, name='input')
    if matrix.shape[-2:] == [0, 0]:
      return matrix
    batch_shape = matrix.shape[:-2]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(matrix)[:-2]

    # Reshaping the batch makes the where statements work better.
    matrix = array_ops.reshape(
        matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
    l1_norm = math_ops.reduce_max(
        math_ops.reduce_sum(
            math_ops.abs(matrix),
            axis=array_ops.size(array_ops.shape(matrix)) - 2),
        axis=-1)[..., array_ops.newaxis, array_ops.newaxis]

    const = lambda x: constant_op.constant(x, l1_norm.dtype)

    def _nest_where(vals, cases):
      assert len(vals) == len(cases) - 1
      if len(vals) == 1:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
      else:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0],
            _nest_where(vals[1:], cases[1:]))

    if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
      maxnorm = const(3.925724783138660)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (4.258730016922831e-001, 1.880152677804762e+000)
      u = _nest_where(conds, (u3, u5, u7))
      v = _nest_where(conds, (v3, v5, v7))
    elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
      maxnorm = const(5.371920351148152)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(matrix)
      u9, v9 = _matrix_exp_pade9(matrix)
      u13, v13 = _matrix_exp_pade13(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (1.495585217958292e-002, 2.539398330063230e-001,
               9.504178996162932e-001, 2.097847961257068e+000)
      u = _nest_where(conds, (u3, u5, u7, u9, u13))
      v = _nest_where(conds, (v3, v5, v7, v9, v13))
    else:
      raise ValueError('tf.linalg.expm does not support matrices of type %s' %
                       matrix.dtype)

    is_finite = math_ops.is_finite(math_ops.reduce_max(l1_norm))
    nan = constant_op.constant(np.nan, matrix.dtype)
    result = control_flow_ops.cond(
        is_finite, lambda: linalg_ops.matrix_solve(-u + v, u + v),
        lambda: array_ops.fill(array_ops.shape(matrix), nan))
    max_squarings = math_ops.reduce_max(squarings)
    i = const(0.0)

    def c(i, _):
      return control_flow_ops.cond(is_finite,
                                   lambda: math_ops.less(i, max_squarings),
                                   lambda: constant_op.constant(False))

    def b(i, r):
      return i + 1, array_ops.where_v2(
          math_ops.less(i, squarings), math_ops.matmul(r, r), r)

    _, result = control_flow_ops.while_loop(c, b, [i, result])
    if not matrix.shape.is_fully_defined():
      return array_ops.reshape(
          result,
          array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
    return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))
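

# Illustrative usage sketch (not part of the original module): `_expm_example`
# is a hypothetical helper, never called at import time, checking
# `matrix_exponential` on a diagonal matrix, where expm(diag(d)) equals
# diag(exp(d)); the two returned tensors should be close.
def _expm_example():
  d = constant_op.constant([1.0, 2.0])
  return matrix_exponential(diag(d)), diag(math_ops.exp(d))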


@tf_export('linalg.banded_triangular_solve', v1=[])
def banded_triangular_solve(
    bands,
    rhs,
    lower=True,
    adjoint=False,  # pylint: disable=redefined-outer-name
    name=None):
  r"""Solve triangular systems of equations with a banded solver.

  `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number
  of bands stored. This corresponds to a batch of `M` by `M` matrices, whose
  `K` subdiagonals (when `lower` is `True`) are stored.

  This operator broadcasts the batch dimensions of `bands` and the batch
  dimensions of `rhs`.


  Examples:

  Storing 2 bands of a 3x3 matrix.
  Note that the first element in the second row is ignored due to
  the 'LEFT_RIGHT' padding.

  >>> x = [[2., 3., 4.], [1., 2., 3.]]
  >>> x2 = [[2., 3., 4.], [10000., 2., 3.]]
  >>> y = tf.zeros([3, 3])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
  >>> z
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[2., 0., 0.],
         [2., 3., 0.],
         [0., 3., 4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
  >>> soln
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[0.5 ],
         [0.  ],
         [0.25]], dtype=float32)>
  >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True
  >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True

  Storing 2 superdiagonals of a 4x4 matrix. Because of the 'LEFT_RIGHT' padding
  the last element of the first row is ignored.

  >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
  >>> y = tf.zeros([4, 4])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
  >>> z
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
  array([[-1.,  2.,  0.,  0.],
         [ 0., -2.,  3.,  0.],
         [ 0.,  0., -3.,  4.],
         [ 0.,  0., -0., -4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
  >>> soln
  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
  array([[-4.       ],
         [-1.5      ],
         [-0.6666667],
         [-0.25     ]], dtype=float32)>
  >>> are_equal = (soln == tf.linalg.triangular_solve(
  ...   z, tf.ones([4, 1]), lower=False))
  >>> tf.reduce_all(are_equal).numpy()
  True


  Args:
    bands: A `Tensor` describing the bands of the left hand side, with shape
      `[..., K, M]`. The `K` rows correspond to the diagonal to the `K - 1`-th
      diagonal (the diagonal is the top row) when `lower` is `True` and
      otherwise the `K - 1`-th superdiagonal to the diagonal (the diagonal is
      the bottom row) when `lower` is `False`. The bands are stored with
      'LEFT_RIGHT' alignment, where the superdiagonals are padded on the right
      and subdiagonals are padded on the left. This is the alignment cuSPARSE
      uses. See `tf.linalg.set_diag` for more details.
    rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as
      `bands`. Note that if the shape of `rhs` and/or `bands` isn't known
      statically, `rhs` will be treated as a matrix rather than a vector.
    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
      `bands` represents a lower or upper triangular matrix.
    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating
      whether to solve with the matrix's block-wise adjoint.
    name: A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, N] containing the solutions.
  """
  with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]):
    return gen_linalg_ops.banded_triangular_solve(
        bands, rhs, lower=lower, adjoint=adjoint)
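

# Illustrative sketch (not part of the original module): `_dense_to_lower_bands`
# is a hypothetical helper, never called at import time, showing one way to
# extract the `[K, M]` band storage expected by `banded_triangular_solve` from
# a dense lower-triangular matrix, using the 'LEFT_RIGHT' alignment documented
# above (main diagonal in the top row, subdiagonals padded on the left).
def _dense_to_lower_bands(matrix, num_bands):
  return array_ops.matrix_diag_part(
      matrix, k=(-(num_bands - 1), 0), align='LEFT_RIGHT')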


@tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True,
                      perturb_singular=False):
  r"""Solves tridiagonal systems of equations.

  The input can be supplied in various formats: `matrix`, `sequence` and
  `compact`, specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  the two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of the superdiagonal and
  the first element of the subdiagonal will be ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with the last two dimensions containing
  superdiagonals, diagonals, and subdiagonals, in order. Similarly to the
  `sequence` format, elements `diagonals[..., 0, M-1]` and
  `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with the best performance. In
  case you need to cast a tensor into the compact format manually, use
  `tf.gather_nd`. An example for a tensor of shape [m, m]:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals = tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```

  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]`
  or `[..., M, K]`. The latter allows solving K systems with the same
  left-hand sides and K different right-hand sides simultaneously. If
  `transpose_rhs` is set to `True` the expected shape is `[..., M]` or
  `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  **Note**: with large batch sizes, the computation on the GPU may be slow if
  either `partial_pivoting=True` or there are multiple right-hand sides
  (`K > 1`). If this issue arises, consider if it's possible to disable
  pivoting and have `K = 1`, or, alternatively, consider using the CPU.

  On CPU, the solution is computed via Gaussian elimination with or without
  partial pivoting, depending on the `partial_pivoting` parameter.
  On GPU, Nvidia's cuSPARSE library is used:
  https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`. Note that if the shape of `rhs` and/or `diagonals` isn't
      known statically, `rhs` will be treated as a matrix rather than a vector.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name: A name to give this `Op` (optional).
    partial_pivoting: whether to perform partial pivoting. `True` by default.
      Partial pivoting makes the procedure more stable, but slower. Partial
      pivoting is unnecessary in some cases, including diagonally dominant and
      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).
    perturb_singular: whether to perturb singular matrices to return a finite
      result. `False` by default. If true, solutions to systems involving
      a singular matrix will be computed by perturbing near-zero pivots in
      the partially pivoted LU decomposition. Specifically, tiny pivots are
      perturbed by an amount of order `eps * max_{ij} |U(i,j)|` to avoid
      overflow. Here `U` is the upper triangular part of the LU decomposition,
      and `eps` is the machine precision. This is useful for solving
      numerically singular systems when computing eigenvectors by inverse
      iteration.
      If `partial_pivoting` is `False`, `perturb_singular` must be `False` as
      well.

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.
    If the input matrix is singular, the result is undefined.

  Raises:
    ValueError: Raised if any of the following conditions hold:
      1. An unsupported type is provided as input,
      2. the input tensors have incorrect shapes,
      3. `perturb_singular` is `True` but `partial_pivoting` is not.
    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
      XLA, or whenever `perturb_singular` is true and the backend is
      XLA or GPU.

  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical
      Algorithms: Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.
  """
  if perturb_singular and not partial_pivoting:
    raise ValueError('partial_pivoting must be True if perturb_singular is.')

  if diagonals_format == 'compact':
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             perturb_singular, name)

  if diagonals_format == 'sequence':
    if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
      raise ValueError('Expected diagonals to be a sequence of length 3.')

    superdiag, maindiag, subdiag = diagonals
    if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or
        not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])):
      raise ValueError(
          'Tensors representing the three diagonals must have the same shape, '
          'except for the last dimension, got {}, {}, {}'.format(
              subdiag.shape, maindiag.shape, superdiag.shape))

    m = tensor_shape.dimension_value(maindiag.shape[-1])

    def pad_if_necessary(t, name, last_dim_padding):
      n = tensor_shape.dimension_value(t.shape[-1])
      if not n or n == m:
        return t
      if n == m - 1:
        paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                    [last_dim_padding])
        return array_ops.pad(t, paddings)
      raise ValueError('Expected {} to have length {} or {}, got {}.'.format(
          name, m, m - 1, n))

    subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
    superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

    diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             perturb_singular, name)

  if diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))
    m = m1 or m2
    diagonals = array_ops.matrix_diag_part(
        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             perturb_singular, name)

  raise ValueError('Unrecognized diagonals_format: {}'.format(diagonals_format))
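

# Illustrative usage sketch (not part of the original module):
# `_tridiagonal_solve_example` is a hypothetical helper, never called at import
# time, solving a small 3x3 tridiagonal system supplied in the recommended
# `compact` format (rows: superdiagonal, main diagonal, subdiagonal).
def _tridiagonal_solve_example():
  diagonals = constant_op.constant(
      [[1., 1., 0.],   # superdiagonal; the last element is ignored
       [2., 2., 2.],   # main diagonal
       [0., 1., 1.]])  # subdiagonal; the first element is ignored
  rhs = constant_op.constant([[1.], [1.], [1.]])
  return tridiagonal_solve(diagonals, rhs)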


def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                      conjugate_rhs, partial_pivoting,
                                      perturb_singular, name):
  """Helper function used after the input has been cast to compact form."""
  diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank

  # If we know the rank of the diagonal tensor, do some static checking.
  if diags_rank:
    if diags_rank < 2:
      raise ValueError(
          'Expected diagonals to have rank at least 2, got {}'.format(
              diags_rank))
    if rhs_rank and rhs_rank != diags_rank and rhs_rank != diags_rank - 1:
      raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format(
          diags_rank - 1, diags_rank, rhs_rank))
    if (rhs_rank and not diagonals.shape[:-2].is_compatible_with(
        rhs.shape[:diags_rank - 2])):
      raise ValueError('Batch shapes {} and {} are incompatible'.format(
          diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))

  if diagonals.shape[-2] and diagonals.shape[-2] != 3:
    raise ValueError('Expected 3 diagonals got {}'.format(diagonals.shape[-2]))

  def check_num_lhs_matches_num_rhs():
    if (diagonals.shape[-1] and rhs.shape[-2] and
        diagonals.shape[-1] != rhs.shape[-2]):
      raise ValueError('Expected number of left-hand sides and right-hand '
                       'sides to be equal, got {} and {}'.format(
                           diagonals.shape[-1], rhs.shape[-2]))

  if rhs_rank and diags_rank and rhs_rank == diags_rank - 1:
    # Rhs provided as a vector, ignoring transpose_rhs.
    if conjugate_rhs:
      rhs = math_ops.conj(rhs)
    rhs = array_ops.expand_dims(rhs, -1)
    check_num_lhs_matches_num_rhs()
    return array_ops.squeeze(
        linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting,
                                     perturb_singular, name), -1)

  if transpose_rhs:
    rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
  elif conjugate_rhs:
    rhs = math_ops.conj(rhs)

  check_num_lhs_matches_num_rhs()
  return linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting,
                                      perturb_singular, name)


@tf_export('linalg.tridiagonal_matmul')
@dispatch.add_dispatch_support
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a 3-diagonal NxN matrix, whose format
  depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  the two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last
  element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with the last two dimensions containing
  superdiagonals, diagonals, and subdiagonals, in order. Similarly to the
  `sequence` format, elements `diagonals[..., 0, M-1]` and
  `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix on the right-hand side of the multiplication. It has
  shape `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    name: A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: An unsupported type is provided as input, or when the input
      tensors have incorrect shapes.
  """
  if diagonals_format == 'compact':
    superdiag = diagonals[..., 0, :]
    maindiag = diagonals[..., 1, :]
    subdiag = diagonals[..., 2, :]
  elif diagonals_format == 'sequence':
    superdiag, maindiag, subdiag = diagonals
  elif diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))
    diags = array_ops.matrix_diag_part(
        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
    superdiag = diags[..., 0, :]
    maindiag = diags[..., 1, :]
    subdiag = diags[..., 2, :]
  else:
    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)

  # The C++ backend requires matrices, so convert 1-dimensional vectors to
  # matrices with 1 row.
  superdiag = array_ops.expand_dims(superdiag, -2)
  maindiag = array_ops.expand_dims(maindiag, -2)
  subdiag = array_ops.expand_dims(subdiag, -2)

  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)
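

# Illustrative sketch (not part of the original module):
# `_tridiagonal_matmul_check` is a hypothetical helper, never called at import
# time, comparing `tridiagonal_matmul` in `sequence` format against an ordinary
# dense matmul with the same matrix assembled via `set_diag`.
def _tridiagonal_matmul_check():
  superdiag = constant_op.constant([-1., -1., 0.])
  maindiag = constant_op.constant([2., 2., 2.])
  subdiag = constant_op.constant([0., -1., -1.])
  rhs = array_ops.ones([3, 2])
  banded = tridiagonal_matmul(
      [superdiag, maindiag, subdiag], rhs, diagonals_format='sequence')
  dense = set_diag(
      array_ops.zeros([3, 3]),
      array_ops.stack([superdiag, maindiag, subdiag]),
      k=(-1, 1), align='LEFT_RIGHT')
  # The two products should match.
  return banded, math_ops.matmul(dense, rhs)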


def _maybe_validate_matrix(a, validate_args):
  """Checks that input is a `float` matrix."""
  assertions = []
  if not a.dtype.is_floating:
    raise TypeError('Input `a` must have `float`-like `dtype` '
                    '(saw {}).'.format(a.dtype.name))
  if a.shape is not None and a.shape.rank is not None:
    if a.shape.rank < 2:
      raise ValueError('Input `a` must have at least 2 dimensions '
                       '(saw: {}).'.format(a.shape.rank))
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least(
            a, rank=2, message='Input `a` must have at least 2 dimensions.'))
  return assertions


@tf_export('linalg.matrix_rank')
@dispatch.add_dispatch_support
def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Compute the matrix rank of one or more matrices.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
  with ops.name_scope(name or 'matrix_rank'):
    a = ops.convert_to_tensor(a, dtype_hint=dtypes.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        a = array_ops.identity(a)
    s = svd(a, compute_uv=False)
    if tol is None:
      if a.shape[-2:].is_fully_defined():
        m = np.max(a.shape[-2:].as_list())
      else:
        m = math_ops.reduce_max(array_ops.shape(a)[-2:])
      eps = np.finfo(a.dtype.as_numpy_dtype).eps
      tol = (
          eps * math_ops.cast(m, a.dtype) *
          math_ops.reduce_max(s, axis=-1, keepdims=True))
    return math_ops.reduce_sum(math_ops.cast(s > tol, dtypes.int32), axis=-1)
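

# Illustrative usage sketch (not part of the original module):
# `_matrix_rank_example` is a hypothetical helper, never called at import time;
# the second matrix has two identical rows, so its rank drops to 1.
def _matrix_rank_example():
  full_rank = constant_op.constant([[1., 0.], [0., 1.]])
  deficient = constant_op.constant([[1., 2.], [1., 2.]])
  return matrix_rank(full_rank), matrix_rank(deficient)  # ==> 2, 1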


@tf_export('linalg.pinv')
@dispatch.add_dispatch_support
def pinv(a, rcond=None, validate_args=False, name=None):
  """Compute the Moore-Penrose pseudo-inverse of one or more matrices.

  Calculate the [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
  singular-value decomposition (SVD) and including all large singular values.

  The pseudo-inverse of a matrix `A` is defined as: 'the matrix that 'solves'
  [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution,
  then `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown
  that if `U @ Sigma @ V.T = A` is the singular value decomposition of `A`,
  then `A_pinv = V @ inv(Sigma) @ U^T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in the default value of `rcond`. In `numpy.linalg.pinv`, the
  default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs. Singular values smaller
      (in modulus) than `rcond` * largest_singular_value (again, in modulus)
      are set to zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: (Batch of) pseudo-inverse of input `a`. Has same shape as `a`
      except rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[1., 0., 0.],
  #            [0., 1., 0.],
  #            [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
  #            [ 0.37,  0.43, -0.33,  0.02],
  #            [ 0.21, -0.33,  0.81,  0.01],
  #            [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic
       Press, Inc., 1980, pp. 139-142.
  """
  with ops.name_scope(name or 'pinv'):
    a = ops.convert_to_tensor(a, name='a')

    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        a = array_ops.identity(a)

    dtype = a.dtype.as_numpy_dtype

    if rcond is None:

      def get_dim_size(dim):
        dim_val = tensor_shape.dimension_value(a.shape[dim])
        if dim_val is not None:
          return dim_val
        return array_ops.shape(a)[dim]

      num_rows = get_dim_size(-2)
      num_cols = get_dim_size(-1)
      if isinstance(num_rows, int) and isinstance(num_cols, int):
        max_rows_cols = float(max(num_rows, num_cols))
      else:
        max_rows_cols = math_ops.cast(
            math_ops.maximum(num_rows, num_cols), dtype)
      rcond = 10. * max_rows_cols * np.finfo(dtype).eps

    rcond = ops.convert_to_tensor(rcond, dtype=dtype, name='rcond')

    # Calculate pseudo inverse via SVD.
    # Note: if a is Hermitian then u == v. (We might observe additional
    # performance by explicitly setting `v = u` in such cases.)
    [
        singular_values,  # Sigma
        left_singular_vectors,  # U
        right_singular_vectors,  # V
    ] = svd(
        a, full_matrices=False, compute_uv=True)

    # Saturate small singular values to inf. This has the effect of making
    # `1. / s = 0.` while not resulting in `NaN` gradients.
    cutoff = rcond * math_ops.reduce_max(singular_values, axis=-1)
    singular_values = array_ops.where_v2(
        singular_values > array_ops.expand_dims_v2(cutoff, -1), singular_values,
        np.array(np.inf, dtype))

    # By the definition of the SVD, `a == u @ s @ v^H`, and the pseudo-inverse
    # is defined as `pinv(a) == v @ inv(s) @ u^H`.
    a_pinv = math_ops.matmul(
        right_singular_vectors / array_ops.expand_dims_v2(singular_values, -2),
        left_singular_vectors,
        adjoint_b=True)

    if a.shape is not None and a.shape.rank is not None:
      a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))

    return a_pinv
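

# Illustrative sketch (not part of the original module): `_pinv_example` is a
# hypothetical helper, never called at import time; for an invertible matrix
# the pseudo-inverse coincides with the ordinary inverse, so `pinv(a) @ a`
# should be close to the identity, as in the docstring example above.
def _pinv_example():
  a = constant_op.constant([[1., 0.4, 0.5],
                            [0.4, 0.2, 0.25],
                            [0.5, 0.25, 0.35]])
  return math_ops.matmul(pinv(a), a)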


@tf_export('linalg.lu_solve')
@dispatch.add_dispatch_support
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear equations `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U))
      = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[...,
      tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix
      is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_solve'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
        rhs = array_ops.identity(rhs)

    if (rhs.shape.rank == 2 and perm.shape.rank == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = array_ops.shape(rhs)
      broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
          rhs_shape[:-2],
          array_ops.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]],
                                             axis=0)

      # Tile out rhs.
      broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = array_ops.broadcast_to(
          math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = array_ops.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(
            array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return triangular_solve(
        lower_upper,  # Only upper is accessed.
        triangular_solve(lower, permuted_rhs),
        lower=False)
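

# Illustrative sketch (not part of the original module): `_lu_solve_vector_rhs`
# is a hypothetical helper, never called at import time, showing the vector
# right-hand-side idiom from the docstring above: add a trailing unit dimension
# before solving and squeeze it away afterwards.
def _lu_solve_vector_rhs(matrix, rhs_vector):
  lower_upper, perm = lu(matrix)
  return lu_solve(lower_upper, perm, rhs_vector[..., array_ops.newaxis])[..., 0]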


@tf_export('linalg.lu_matrix_inverse')
@dispatch.add_dispatch_support
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
  """Computes the inverse given the LU decomposition(s) of one or more matrices.

  This op is conceptually identical to,

  ```python
  inv_X = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(X))
  tf.assert_near(tf.matrix_inverse(X), inv_X)
  # ==> True
  ```

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U))
      = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix
      is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_matrix_inverse').

  Returns:
    inv_x: The matrix_inv, i.e.,
      `tf.matrix_inverse(tf.linalg.lu_reconstruct(lu, perm))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_matrix_inverse'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
    shape = array_ops.shape(lower_upper)
    return lu_solve(
        lower_upper,
        perm,
        rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype),
        validate_args=False)


@tf_export('linalg.lu_reconstruct')
@dispatch.add_dispatch_support
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """Reconstructs one or more matrices from their LU decomposition(s).

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U))
      = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with ops.name_scope(name or 'lu_reconstruct'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)

    shape = array_ops.shape(lower_upper)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = band_part(lower_upper, num_lower=0, num_upper=-1)
    x = math_ops.matmul(lower, upper)

    if (lower_upper.shape is None or lower_upper.shape.rank is None or
        lower_upper.shape.rank != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = math_ops.reduce_prod(shape[:-2])
      d = shape[-1]
      x = array_ops.reshape(x, [batch_size, d, d])
      perm = array_ops.reshape(perm, [batch_size, d])
      perm = map_fn.map_fn(array_ops.invert_permutation, perm)
      batch_indices = array_ops.broadcast_to(
          math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d])
      x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm],
                                                 axis=-1))
      x = array_ops.reshape(x, shape)
    else:
      x = array_ops.gather(x, array_ops.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x


def lu_reconstruct_assertions(lower_upper, perm, validate_args):
  """Returns list of assertions related to `lu_reconstruct` assumptions."""
  assertions = []

  message = 'Input `lower_upper` must have at least 2 dimensions.'
  if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2:
    raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message))

  message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
  if lower_upper.shape.rank is not None and perm.shape.rank is not None:
    if lower_upper.shape.rank != perm.shape.rank + 1:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank(
            lower_upper, rank=array_ops.rank(perm) + 1, message=message))

  message = '`lower_upper` must be square.'
  if lower_upper.shape[:-2].is_fully_defined():
    if lower_upper.shape[-2] != lower_upper.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    m, n = array_ops.split(
        array_ops.shape(lower_upper)[-2:], num_or_size_splits=2)
    assertions.append(check_ops.assert_equal(m, n, message=message))

  return assertions


def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
  """Returns list of assertions related to `lu_solve` assumptions."""
  assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)

  message = 'Input `rhs` must have at least 2 dimensions.'
  if rhs.shape.ndims is not None:
    if rhs.shape.ndims < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least(rhs, rank=2, message=message))

  message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
  if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None):
    if lower_upper.shape[-1] != rhs.shape[-2]:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_equal(
            array_ops.shape(lower_upper)[-1],
            array_ops.shape(rhs)[-2],
            message=message))

  return assertions


@tf_export('linalg.eigh_tridiagonal')
@dispatch.add_dispatch_support
def eigh_tridiagonal(alpha,
                     beta,
                     eigvals_only=True,
                     select='a',
                     select_range=None,
                     tol=None,
                     name=None):
  """Computes the eigenvalues of a Hermitian tridiagonal matrix.

  Args:
    alpha: A real or complex tensor of shape (n), the diagonal elements of the
      matrix. NOTE: If alpha is complex, the imaginary part is ignored (assumed
      zero) to satisfy the requirement that the matrix be Hermitian.
    beta: A real or complex tensor of shape (n-1), containing the elements of
      the first super-diagonal of the matrix. If beta is complex, the first
      sub-diagonal of the matrix is assumed to be the conjugate of beta to
      satisfy the requirement that the matrix be Hermitian.
    eigvals_only: If False, both eigenvalues and corresponding eigenvectors are
      computed. If True, only eigenvalues are computed. Default is True.
    select: Optional string with values in {'a', 'v', 'i'} (default is 'a')
      that determines which eigenvalues to calculate:
        'a': all eigenvalues.
        'v': eigenvalues in the interval (min, max] given by `select_range`.
        'i': eigenvalues with indices min <= i <= max.
    select_range: Size 2 tuple or list or tensor specifying the range of
      eigenvalues to compute together with select. If select is 'a',
      select_range is ignored.
    tol: Optional scalar. The absolute tolerance to which each eigenvalue is
      required. An eigenvalue (or cluster) is considered to have converged if
      it lies in an interval of this width. If tol is None (default), the value
      eps*|T|_2 is used where eps is the machine precision, and |T|_2 is the
      2-norm of the matrix T.
    name: Optional name of the op.

  Returns:
    eig_vals: The eigenvalues of the matrix in non-decreasing order.
    eig_vectors: If `eigvals_only` is False the eigenvectors are returned in
      the second output argument.

  Raises:
    ValueError: If input values are invalid.
    NotImplemented: Computing eigenvectors for `eigvals_only` = False is
      not implemented yet.

  This op implements a subset of the functionality of
  scipy.linalg.eigh_tridiagonal.

  Note: The result is undefined if the input contains +/-inf or NaN, or if
  any value in beta has a magnitude greater than
  `numpy.sqrt(numpy.finfo(beta.dtype.as_numpy_dtype).max)`.


  TODO(b/187527398):
  Add support for outer batch dimensions.

  #### Examples

  ```python
  import numpy
  eigvals = tf.linalg.eigh_tridiagonal([0.0, 0.0, 0.0], [1.0, 1.0])
  eigvals_expected = [-numpy.sqrt(2.0), 0.0, numpy.sqrt(2.0)]
  tf.assert_near(eigvals_expected, eigvals)
  # ==> True
  ```

  """
  with ops.name_scope(name or 'eigh_tridiagonal'):

    def _compute_eigenvalues(alpha, beta):
      """Computes all eigenvalues of a Hermitian tridiagonal matrix."""

      def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence."""
        with ops.name_scope('sturm'):
          n = alpha.shape[0]
          zeros = array_ops.zeros(array_ops.shape(x), dtype=dtypes.int32)
          ones = array_ops.ones(array_ops.shape(x), dtype=dtypes.int32)

          # The first step in the Sturm sequence recurrence
          # requires special care if x is equal to alpha[0].
          def sturm_step0():
            q = alpha[0] - x
            count = array_ops.where(q < 0, ones, zeros)
            q = array_ops.where(
                math_ops.equal(alpha[0], x), alpha0_perturbation, q)
            return q, count

          # Subsequent steps all take this form:
          def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = array_ops.where(q <= pivmin, count + 1, count)
            q = array_ops.where(q <= pivmin, math_ops.minimum(q, -pivmin), q)
            return q, count

          # The first step initializes q and count.
          q, count = sturm_step0()

          # Peel off ((n-1) % blocksize) steps from the main loop, so we can
          # run the bulk of the iterations unrolled by a factor of blocksize.
          blocksize = 16
          i = 1
          peel = (n - 1) % blocksize
          unroll_cnt = peel

          def unrolled_steps(start, q, count):
            for j in range(unroll_cnt):
              q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

          i, q, count = unrolled_steps(i, q, count)

          # Run the remaining steps of the Sturm sequence using a partially
          # unrolled while loop.
          unroll_cnt = blocksize
          cond = lambda i, q, count: math_ops.less(i, n)
          _, _, count = control_flow_ops.while_loop(
              cond, unrolled_steps, [i, q, count], back_prop=False)
          return count

      with ops.name_scope('compute_eigenvalues'):
        if alpha.dtype.is_complex:
          alpha = math_ops.real(alpha)
          beta_sq = math_ops.real(math_ops.conj(beta) * beta)
          beta_abs = math_ops.sqrt(beta_sq)
        else:
          beta_sq = math_ops.square(beta)
          beta_abs = math_ops.abs(beta)

        # Estimate the largest and smallest eigenvalues of T using the
        # Gershgorin circle theorem.
        finfo = np.finfo(alpha.dtype.as_numpy_dtype)
        off_diag_abs_row_sum = array_ops.concat(
            [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
        lambda_est_max = math_ops.minimum(
            finfo.max, math_ops.reduce_max(alpha + off_diag_abs_row_sum))
        lambda_est_min = math_ops.maximum(
            finfo.min, math_ops.reduce_min(alpha - off_diag_abs_row_sum))
        # Upper bound on 2-norm of T.
        t_norm = math_ops.maximum(
            math_ops.abs(lambda_est_min), math_ops.abs(lambda_est_max))

        # Compute the smallest allowed pivot in the Sturm sequence to avoid
        # overflow.
        one = np.ones([], dtype=alpha.dtype.as_numpy_dtype)
        safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
        pivmin = safemin * math_ops.maximum(one, math_ops.reduce_max(beta_sq))
        alpha0_perturbation = math_ops.square(finfo.eps * beta_abs[0])
        abs_tol = finfo.eps * t_norm
        if tol:
          abs_tol = math_ops.maximum(tol, abs_tol)
        # In the worst case, when the absolute tolerance is eps*lambda_est_max
        # and lambda_est_max = -lambda_est_min, we have to take as many
        # bisection steps as there are bits in the mantissa plus 1.
        max_it = finfo.nmant + 1

        # Determine the indices of the desired eigenvalues, based on select
        # and select_range.
        asserts = None
        if select == 'a':
          target_counts = math_ops.range(n)
        elif select == 'i':
          asserts = check_ops.assert_less_equal(
              select_range[0],
              select_range[1],
              message='Got empty index range in select_range.')
          target_counts = math_ops.range(select_range[0], select_range[1] + 1)
        elif select == 'v':
          asserts = check_ops.assert_less(
              select_range[0],
              select_range[1],
              message='Got empty interval in select_range.')
        else:
          raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")

        if asserts:
          with ops.control_dependencies([asserts]):
            alpha = array_ops.identity(alpha)

        # Run binary search for all desired eigenvalues in parallel, starting
        # from an interval slightly wider than the estimated
        # [lambda_est_min, lambda_est_max].
        fudge = 2.1  # We widen the Gershgorin interval a bit.
        norm_slack = math_ops.cast(n, alpha.dtype) * fudge * finfo.eps * t_norm
        if select in {'a', 'i'}:
          lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
          upper = lambda_est_max + norm_slack + fudge * pivmin
        else:
          # Count the number of eigenvalues in the given range.
          lower = select_range[0] - norm_slack - 2 * fudge * pivmin
          upper = select_range[1] + norm_slack + fudge * pivmin
          first = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, lower)
          last = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, upper)
          target_counts = math_ops.range(first, last)

        # Pre-broadcast the scalars used in the Sturm sequence for improved
        # performance.
        upper = math_ops.minimum(upper, finfo.max)
        lower = math_ops.maximum(lower, finfo.min)
        target_shape = array_ops.shape(target_counts)
        lower = array_ops.broadcast_to(lower, shape=target_shape)
        upper = array_ops.broadcast_to(upper, shape=target_shape)
        pivmin = array_ops.broadcast_to(pivmin, target_shape)
        alpha0_perturbation = array_ops.broadcast_to(alpha0_perturbation,
                                                     target_shape)

        # We compute the midpoint as 0.5*lower + 0.5*upper to avoid overflow in
        # (lower + upper) or (upper - lower) when the matrix has eigenvalues
        # with magnitude greater than finfo.max / 2.
        def midpoint(lower, upper):
          return (0.5 * lower) + (0.5 * upper)

        def continue_binary_search(i, lower, upper):
          return math_ops.logical_and(
              math_ops.less(i, max_it),
              math_ops.less(abs_tol, math_ops.reduce_max(upper - lower)))

        def binary_search_step(i, lower, upper):
          mid = midpoint(lower, upper)
          counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
          lower = array_ops.where(counts <= target_counts, mid, lower)
          upper = array_ops.where(counts > target_counts, mid, upper)
          return i + 1, lower, upper

        # Start parallel binary searches.
        _, lower, upper = control_flow_ops.while_loop(continue_binary_search,
                                                      binary_search_step,
                                                      [0, lower, upper])
        return midpoint(lower, upper)

    def _compute_eigenvectors(alpha, beta, eigvals):
      """Implements inverse iteration to compute eigenvectors."""
      with ops.name_scope('compute_eigenvectors'):
        k = array_ops.size(eigvals)
        n = array_ops.size(alpha)
        alpha = math_ops.cast(alpha, dtype=beta.dtype)

        # Eigenvectors corresponding to a cluster of close eigenvalues are
        # not unique and need to be explicitly orthogonalized. Here we
        # identify such clusters. Note: This function assumes that
        # eigenvalues are sorted in non-decreasing order.
        gap = eigvals[1:] - eigvals[:-1]
        eps = np.finfo(eigvals.dtype.as_numpy_dtype).eps
        t_norm = math_ops.maximum(
            math_ops.abs(eigvals[0]), math_ops.abs(eigvals[-1]))
        gaptol = np.sqrt(eps) * t_norm
        # Find the beginning and end of runs of eigenvectors corresponding
        # to eigenvalues closer than "gaptol", which will need to be
        # orthogonalized against each other.
        close = math_ops.less(gap, gaptol)
        left_neighbor_close = array_ops.concat([[False], close], axis=0)
        right_neighbor_close = array_ops.concat([close, [False]], axis=0)
        ortho_interval_start = math_ops.logical_and(
            math_ops.logical_not(left_neighbor_close), right_neighbor_close)
        ortho_interval_start = array_ops.squeeze(
            array_ops.where_v2(ortho_interval_start), axis=-1)
        ortho_interval_end = math_ops.logical_and(
            left_neighbor_close, math_ops.logical_not(right_neighbor_close))
        ortho_interval_end = array_ops.squeeze(
            array_ops.where_v2(ortho_interval_end), axis=-1) + 1
        num_clusters = array_ops.size(ortho_interval_end)

        # We perform inverse iteration for all eigenvectors in parallel,
        # starting from a random set of vectors, until all have converged.
        v0 = math_ops.cast(
            stateless_random_ops.stateless_random_normal(
                shape=(k, n), seed=[7, 42]),
            dtype=beta.dtype)
        nrm_v = norm(v0, axis=1)
        v0 = v0 / nrm_v[:, array_ops.newaxis]
        zero_nrm = constant_op.constant(0, shape=nrm_v.shape, dtype=nrm_v.dtype)

        # Replicate alpha - eigvals(i) and beta across the k eigenvectors so we
        # can solve the k systems
        #    [T - eigvals(i)*eye(n)] x_i = r_i
        # simultaneously using the batching mechanism.
        eigvals_cast = math_ops.cast(eigvals, dtype=beta.dtype)
        alpha_shifted = (
            alpha[array_ops.newaxis, :] - eigvals_cast[:, array_ops.newaxis])
        beta = array_ops.tile(beta[array_ops.newaxis, :], [k, 1])
        diags = [beta, alpha_shifted, math_ops.conj(beta)]

        def orthogonalize_close_eigenvectors(eigenvectors):
          # Eigenvectors corresponding to a cluster of close eigenvalues are
          # not uniquely defined, but the subspace they span is. To avoid
          # numerical instability, we explicitly mutually orthogonalize such
          # eigenvectors after each step of inverse iteration. It is customary
          # to use modified Gram-Schmidt for this, but this is not very
          # efficient on some platforms, so here we defer to the QR
          # decomposition in TensorFlow.
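          # Sketch of the update performed by `orthogonalize_cluster` below:
          # the rows eigenvectors[start:end, :] (shape [m, n], m >= 2) are
          # transposed to [n, m] and QR-factored; the columns of Q form an
          # orthonormal basis of the same subspace, so transpose(q) is
          # scattered back into rows start..end-1.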
          def orthogonalize_cluster(cluster_idx, eigenvectors):
            start = ortho_interval_start[cluster_idx]
            end = ortho_interval_end[cluster_idx]
            update_indices = array_ops.expand_dims(
                math_ops.range(start, end), -1)
            vectors_in_cluster = eigenvectors[start:end, :]
            # We use the builtin QR factorization to orthonormalize the
            # vectors in the cluster.
            q, _ = qr(transpose(vectors_in_cluster))
            vectors_to_update = transpose(q)
            eigenvectors = array_ops.tensor_scatter_nd_update(
                eigenvectors, update_indices, vectors_to_update)
            return cluster_idx + 1, eigenvectors

          _, eigenvectors = control_flow_ops.while_loop(
              lambda i, ev: math_ops.less(i, num_clusters),
              orthogonalize_cluster, [0, eigenvectors])
          return eigenvectors

        def continue_iteration(i, _, nrm_v, nrm_v_old):
          max_it = 5  # Taken from LAPACK xSTEIN.
          min_norm_growth = 0.1
          norm_growth_factor = constant_op.constant(
              1 + min_norm_growth, dtype=nrm_v.dtype)
          # We stop the inverse iteration when we reach the maximum number of
          # iterations or the norm growth is less than 10%.
          return math_ops.logical_and(
              math_ops.less(i, max_it),
              math_ops.reduce_any(
                  math_ops.greater_equal(
                      math_ops.real(nrm_v),
                      math_ops.real(norm_growth_factor * nrm_v_old))))

        def inverse_iteration_step(i, v, nrm_v, nrm_v_old):
          v = tridiagonal_solve(
              diags,
              v,
              diagonals_format='sequence',
              partial_pivoting=True,
              perturb_singular=True)
          nrm_v_old = nrm_v
          nrm_v = norm(v, axis=1)
          v = v / nrm_v[:, array_ops.newaxis]
          v = orthogonalize_close_eigenvectors(v)
          return i + 1, v, nrm_v, nrm_v_old

        _, v, nrm_v, _ = control_flow_ops.while_loop(continue_iteration,
                                                     inverse_iteration_step,
                                                     [0, v0, nrm_v, zero_nrm])
        return transpose(v)

    alpha = ops.convert_to_tensor(alpha, name='alpha')
    n = alpha.shape[0]
    if n <= 1:
      return math_ops.real(alpha)
    beta = ops.convert_to_tensor(beta, name='beta')

    if alpha.dtype != beta.dtype:
      raise ValueError("'alpha' and 'beta' must have the same type.")

    eigvals = _compute_eigenvalues(alpha, beta)
    if eigvals_only:
      return eigvals

    eigvectors = _compute_eigenvectors(alpha, beta, eigvals)
    return eigvals, eigvectors
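

# A minimal usage sketch for the routine above (illustrative only; the eager
# call style and example values are assumptions, not part of this module):
#
#   import tensorflow as tf
#   alpha = tf.constant([2.0, 2.0, 2.0])  # Diagonal of the tridiagonal T.
#   beta = tf.constant([-1.0, -1.0])      # Off-diagonal of T.
#   eigvals, eigvecs = tf.linalg.eigh_tridiagonal(
#       alpha, beta, eigvals_only=False)
#   # Each column eigvecs[:, i] approximately satisfies
#   #   T @ eigvecs[:, i] = eigvals[i] * eigvecs[:, i].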