1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations for linear algebra."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import check_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import gen_linalg_ops
31from tensorflow.python.ops import linalg_ops
32from tensorflow.python.ops import map_fn
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import special_math_ops
35from tensorflow.python.ops import stateless_random_ops
36from tensorflow.python.util import dispatch
37from tensorflow.python.util.tf_export import tf_export
38
39# Linear algebra ops.
40band_part = array_ops.matrix_band_part
41cholesky = linalg_ops.cholesky
42cholesky_solve = linalg_ops.cholesky_solve
43det = linalg_ops.matrix_determinant
44slogdet = gen_linalg_ops.log_matrix_determinant
45tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet))
46diag = array_ops.matrix_diag
47diag_part = array_ops.matrix_diag_part
48eigh = linalg_ops.self_adjoint_eig
49eigvalsh = linalg_ops.self_adjoint_eigvals
50einsum = special_math_ops.einsum
51eye = linalg_ops.eye
52inv = linalg_ops.matrix_inverse
53logm = gen_linalg_ops.matrix_logarithm
54lu = gen_linalg_ops.lu
55tf_export('linalg.logm')(dispatch.add_dispatch_support(logm))
56lstsq = linalg_ops.matrix_solve_ls
57norm = linalg_ops.norm
58qr = linalg_ops.qr
59set_diag = array_ops.matrix_set_diag
60solve = linalg_ops.matrix_solve
61sqrtm = linalg_ops.matrix_square_root
62svd = linalg_ops.svd
63tensordot = math_ops.tensordot
64trace = math_ops.trace
65transpose = array_ops.matrix_transpose
66triangular_solve = linalg_ops.matrix_triangular_solve
67
68
69@tf_export('linalg.logdet')
70@dispatch.add_dispatch_support
71def logdet(matrix, name=None):
72  """Computes log of the determinant of a hermitian positive definite matrix.
73
74  ```python
75  # Compute the determinant of a matrix while reducing the chance of over- or
76  # underflow:
77  A = ... # shape 10 x 10
78  det = tf.exp(tf.linalg.logdet(A))  # scalar
79  ```
80
81  Args:
82    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
83      or `complex128` with shape `[..., M, M]`.
84    name:  A name to give this `Op`.  Defaults to `logdet`.
85
86  Returns:
87    The natural log of the determinant of `matrix`.
88
89  @compatibility(numpy)
90  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
91  Hermitian positive definite matrices are supported.
92  @end_compatibility
93  """
94  # This uses the property that log(det(A)) = 2 * sum(log(real(diag(C)))),
95  # where C is the Cholesky decomposition of A.
96  with ops.name_scope(name, 'logdet', [matrix]):
97    chol = gen_linalg_ops.cholesky(matrix)
98    return 2.0 * math_ops.reduce_sum(
99        math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
100        axis=[-1])
101
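# Illustrative usage sketch (an editorial addition, not part of the original
# module): `logdet` on a symmetric positive definite matrix, compared against
# `slogdet`. Assumes TensorFlow 2.x eager execution.
#
#   import tensorflow as tf
#   a = tf.constant([[4.0, 1.0], [1.0, 3.0]])   # SPD, det(a) = 11
#   ld = tf.linalg.logdet(a)                    # log(11.0)
#   sign, logabsdet = tf.linalg.slogdet(a)      # sign == 1.0, logabsdet == log(11.0)
#   det = tf.exp(ld)                            # recovers det(a) without under/overflow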
102
103@tf_export('linalg.adjoint')
104@dispatch.add_dispatch_support
105def adjoint(matrix, name=None):
106  """Transposes the last two dimensions of and conjugates tensor `matrix`.
107
108  For example:
109
110  ```python
111  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
112                   [4 + 4j, 5 + 5j, 6 + 6j]])
113  tf.linalg.adjoint(x)  # [[1 - 1j, 4 - 4j],
114                        #  [2 - 2j, 5 - 5j],
115                        #  [3 - 3j, 6 - 6j]]
116  ```
117
118  Args:
119    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
120      or `complex128` with shape `[..., M, M]`.
121    name:  A name to give this `Op` (optional).
122
123  Returns:
124    The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
125    matrix.
126  """
127  with ops.name_scope(name, 'adjoint', [matrix]):
128    matrix = ops.convert_to_tensor(matrix, name='matrix')
129    return array_ops.matrix_transpose(matrix, conjugate=True)
130
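# Illustrative usage sketch (editorial addition): `adjoint` is the conjugate
# transpose, so a Hermitian product `x @ adjoint(x)` equals its own adjoint.
# Assumes TensorFlow 2.x eager execution.
#
#   import tensorflow as tf
#   x = tf.constant([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
#   xh = tf.linalg.adjoint(x)                   # [[1-1j, 3-3j], [2-2j, 4-4j]]
#   h = tf.matmul(x, xh)                        # Hermitian by construction
#   tf.reduce_all(tf.linalg.adjoint(h) == h)    # ==> True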
131
132# This section is ported nearly verbatim from Eigen's implementation:
133# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html
134def _matrix_exp_pade3(matrix):
135  """3rd-order Pade approximant for matrix exponential."""
136  b = [120.0, 60.0, 12.0]
137  b = [constant_op.constant(x, matrix.dtype) for x in b]
138  ident = linalg_ops.eye(
139      array_ops.shape(matrix)[-2],
140      batch_shape=array_ops.shape(matrix)[:-2],
141      dtype=matrix.dtype)
142  matrix_2 = math_ops.matmul(matrix, matrix)
143  tmp = matrix_2 + b[1] * ident
144  matrix_u = math_ops.matmul(matrix, tmp)
145  matrix_v = b[2] * matrix_2 + b[0] * ident
146  return matrix_u, matrix_v
147
148
149def _matrix_exp_pade5(matrix):
150  """5th-order Pade approximant for matrix exponential."""
151  b = [30240.0, 15120.0, 3360.0, 420.0, 30.0]
152  b = [constant_op.constant(x, matrix.dtype) for x in b]
153  ident = linalg_ops.eye(
154      array_ops.shape(matrix)[-2],
155      batch_shape=array_ops.shape(matrix)[:-2],
156      dtype=matrix.dtype)
157  matrix_2 = math_ops.matmul(matrix, matrix)
158  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
159  tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident
160  matrix_u = math_ops.matmul(matrix, tmp)
161  matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
162  return matrix_u, matrix_v
163
164
165def _matrix_exp_pade7(matrix):
166  """7th-order Pade approximant for matrix exponential."""
167  b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0]
168  b = [constant_op.constant(x, matrix.dtype) for x in b]
169  ident = linalg_ops.eye(
170      array_ops.shape(matrix)[-2],
171      batch_shape=array_ops.shape(matrix)[:-2],
172      dtype=matrix.dtype)
173  matrix_2 = math_ops.matmul(matrix, matrix)
174  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
175  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
176  tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident
177  matrix_u = math_ops.matmul(matrix, tmp)
178  matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
179  return matrix_u, matrix_v
180
181
182def _matrix_exp_pade9(matrix):
183  """9th-order Pade approximant for matrix exponential."""
184  b = [
185      17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0,
186      2162160.0, 110880.0, 3960.0, 90.0
187  ]
188  b = [constant_op.constant(x, matrix.dtype) for x in b]
189  ident = linalg_ops.eye(
190      array_ops.shape(matrix)[-2],
191      batch_shape=array_ops.shape(matrix)[:-2],
192      dtype=matrix.dtype)
193  matrix_2 = math_ops.matmul(matrix, matrix)
194  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
195  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
196  matrix_8 = math_ops.matmul(matrix_6, matrix_2)
197  tmp = (
198      matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 +
199      b[1] * ident)
200  matrix_u = math_ops.matmul(matrix, tmp)
201  matrix_v = (
202      b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 +
203      b[0] * ident)
204  return matrix_u, matrix_v
205
206
207def _matrix_exp_pade13(matrix):
208  """13th-order Pade approximant for matrix exponential."""
209  b = [
210      64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
211      1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
212      33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0
213  ]
214  b = [constant_op.constant(x, matrix.dtype) for x in b]
215  ident = linalg_ops.eye(
216      array_ops.shape(matrix)[-2],
217      batch_shape=array_ops.shape(matrix)[:-2],
218      dtype=matrix.dtype)
219  matrix_2 = math_ops.matmul(matrix, matrix)
220  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
221  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
222  tmp_u = (
223      math_ops.matmul(matrix_6, matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) +
224      b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident)
225  matrix_u = math_ops.matmul(matrix, tmp_u)
226  tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2
227  matrix_v = (
228      math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 +
229      b[2] * matrix_2 + b[0] * ident)
230  return matrix_u, matrix_v
231
232
233@tf_export('linalg.expm')
234@dispatch.add_dispatch_support
235def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
236  r"""Computes the matrix exponential of one or more square matrices.
237
238  $$exp(A) = \sum_{n=0}^\infty A^n/n!$$
239
240  The exponential is computed using a combination of the scaling and squaring
241  method and the Pade approximation. Details can be found in:
242  Nicholas J. Higham, "The scaling and squaring method for the matrix
243  exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
244
245  The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
246  form square matrices. The output is a tensor of the same shape as the input
247  containing the exponential for all input submatrices `[..., :, :]`.
248
249  Args:
250    input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or
251      `complex128` with shape `[..., M, M]`.
252    name:  A name to give this `Op` (optional).
253
254  Returns:
255    the matrix exponential of the input.
256
257  Raises:
258    ValueError: An unsupported type is provided as input.
259
260  @compatibility(scipy)
261  Equivalent to scipy.linalg.expm
262  @end_compatibility
263  """
264  with ops.name_scope(name, 'matrix_exponential', [input]):
265    matrix = ops.convert_to_tensor(input, name='input')
266    if matrix.shape[-2:] == [0, 0]:
267      return matrix
268    batch_shape = matrix.shape[:-2]
269    if not batch_shape.is_fully_defined():
270      batch_shape = array_ops.shape(matrix)[:-2]
271
272    # reshaping the batch makes the where statements work better
273    matrix = array_ops.reshape(
274        matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
275    l1_norm = math_ops.reduce_max(
276        math_ops.reduce_sum(
277            math_ops.abs(matrix),
278            axis=array_ops.size(array_ops.shape(matrix)) - 2),
279        axis=-1)[..., array_ops.newaxis, array_ops.newaxis]
280
281    const = lambda x: constant_op.constant(x, l1_norm.dtype)
282
283    def _nest_where(vals, cases):
284      assert len(vals) == len(cases) - 1
285      if len(vals) == 1:
286        return array_ops.where_v2(
287            math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
288      else:
289        return array_ops.where_v2(
290            math_ops.less(l1_norm, const(vals[0])), cases[0],
291            _nest_where(vals[1:], cases[1:]))
292
293    if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
294      maxnorm = const(3.925724783138660)
295      squarings = math_ops.maximum(
296          math_ops.floor(
297              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
298      u3, v3 = _matrix_exp_pade3(matrix)
299      u5, v5 = _matrix_exp_pade5(matrix)
300      u7, v7 = _matrix_exp_pade7(
301          matrix /
302          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
303      conds = (4.258730016922831e-001, 1.880152677804762e+000)
304      u = _nest_where(conds, (u3, u5, u7))
305      v = _nest_where(conds, (v3, v5, v7))
306    elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
307      maxnorm = const(5.371920351148152)
308      squarings = math_ops.maximum(
309          math_ops.floor(
310              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
311      u3, v3 = _matrix_exp_pade3(matrix)
312      u5, v5 = _matrix_exp_pade5(matrix)
313      u7, v7 = _matrix_exp_pade7(matrix)
314      u9, v9 = _matrix_exp_pade9(matrix)
315      u13, v13 = _matrix_exp_pade13(
316          matrix /
317          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
318      conds = (1.495585217958292e-002, 2.539398330063230e-001,
319               9.504178996162932e-001, 2.097847961257068e+000)
320      u = _nest_where(conds, (u3, u5, u7, u9, u13))
321      v = _nest_where(conds, (v3, v5, v7, v9, v13))
322    else:
323      raise ValueError('tf.linalg.expm does not support matrices of type %s' %
324                       matrix.dtype)
325
326    is_finite = math_ops.is_finite(math_ops.reduce_max(l1_norm))
327    nan = constant_op.constant(np.nan, matrix.dtype)
328    result = control_flow_ops.cond(
329        is_finite, lambda: linalg_ops.matrix_solve(-u + v, u + v),
330        lambda: array_ops.fill(array_ops.shape(matrix), nan))
331    max_squarings = math_ops.reduce_max(squarings)
332    i = const(0.0)
333
334    def c(i, _):
335      return control_flow_ops.cond(is_finite,
336                                   lambda: math_ops.less(i, max_squarings),
337                                   lambda: constant_op.constant(False))
338
339    def b(i, r):
340      return i + 1, array_ops.where_v2(
341          math_ops.less(i, squarings), math_ops.matmul(r, r), r)
342
343    _, result = control_flow_ops.while_loop(c, b, [i, result])
344    if not matrix.shape.is_fully_defined():
345      return array_ops.reshape(
346          result,
347          array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
348    return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))
349
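# Illustrative usage sketch (editorial addition): the matrix exponential of a
# 2x2 skew-symmetric matrix is a rotation matrix, which gives a quick sanity
# check for `expm`. Assumes TensorFlow 2.x eager execution.
#
#   import numpy as np
#   import tensorflow as tf
#   theta = 0.5
#   skew = tf.constant([[0.0, -theta], [theta, 0.0]])
#   rot = tf.linalg.expm(skew)
#   # rot is approximately [[cos(theta), -sin(theta)], [sin(theta), cos(theta)]].
#   np.allclose(rot.numpy(), [[np.cos(theta), -np.sin(theta)],
#                             [np.sin(theta), np.cos(theta)]])   # ==> True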
350
351@tf_export('linalg.banded_triangular_solve', v1=[])
352def banded_triangular_solve(
353    bands,
354    rhs,
355    lower=True,
356    adjoint=False,  # pylint: disable=redefined-outer-name
357    name=None):
358  r"""Solve triangular systems of equations with a banded solver.
359
360  `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number
361  of bands stored. This corresponds to a batch of `M` by `M` matrices, whose
362  `K` subdiagonals (when `lower` is `True`) are stored.
363
364  This operator broadcasts the batch dimensions of `bands` and the batch
365  dimensions of `rhs`.
366
367
368  Examples:
369
370  Storing 2 bands of a 3x3 matrix.
371  Note that the first element in the second row is ignored due to
372  the 'LEFT_RIGHT' padding.
373
374  >>> x = [[2., 3., 4.], [1., 2., 3.]]
375  >>> x2 = [[2., 3., 4.], [10000., 2., 3.]]
376  >>> y = tf.zeros([3, 3])
377  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
378  >>> z
379  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
380  array([[2., 0., 0.],
381         [2., 3., 0.],
382         [0., 3., 4.]], dtype=float32)>
383  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
384  >>> soln
385  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
386  array([[0.5 ],
387         [0.  ],
388         [0.25]], dtype=float32)>
389  >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
390  >>> tf.reduce_all(are_equal).numpy()
391  True
392  >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
393  >>> tf.reduce_all(are_equal).numpy()
394  True
395
396  Storing the 2 upper bands (diagonal and first superdiagonal) of a 4x4 matrix.
397  Because of the 'LEFT_RIGHT' padding the last element of the first row is ignored.
398
399  >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
400  >>> y = tf.zeros([4, 4])
401  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
402  >>> z
403  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
404  array([[-1.,  2.,  0.,  0.],
405         [ 0., -2.,  3.,  0.],
406         [ 0.,  0., -3.,  4.],
407         [ 0.,  0., -0., -4.]], dtype=float32)>
408  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
409  >>> soln
410  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
411  array([[-4.       ],
412         [-1.5      ],
413         [-0.6666667],
414         [-0.25     ]], dtype=float32)>
415  >>> are_equal = (soln == tf.linalg.triangular_solve(
416  ...   z, tf.ones([4, 1]), lower=False))
417  >>> tf.reduce_all(are_equal).numpy()
418  True
419
420
421  Args:
422    bands: A `Tensor` describing the bands of the left hand side, with shape
423      `[..., K, M]`. When `lower` is `True`, the `K` rows run from the main
424      diagonal down to the `(K - 1)`-th subdiagonal (the main diagonal is the
425      top row); when `lower` is `False`, they run from the `(K - 1)`-th
426      superdiagonal down to the main diagonal (the main diagonal is the bottom
427      row). The bands are stored with 'LEFT_RIGHT' alignment, where
428      superdiagonals are padded on the right and subdiagonals on the left. This
429      is the alignment cuSPARSE uses. See `tf.linalg.set_diag` for more details.
430    rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as
431      `bands`. Note that if the shape of `rhs` and/or `bands` isn't known
432      statically, `rhs` will be treated as a matrix rather than a vector.
433    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
434      `bands` represents a lower or upper triangular matrix.
435    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether
436      to solve with the matrix's block-wise adjoint.
437    name:  A name to give this `Op` (optional).
438
439  Returns:
440    A `Tensor` of shape [..., M] or [..., M, N] containing the solutions.
441  """
442  with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]):
443    return gen_linalg_ops.banded_triangular_solve(
444        bands, rhs, lower=lower, adjoint=adjoint)
445
446
447@tf_export('linalg.tridiagonal_solve')
448@dispatch.add_dispatch_support
449def tridiagonal_solve(diagonals,
450                      rhs,
451                      diagonals_format='compact',
452                      transpose_rhs=False,
453                      conjugate_rhs=False,
454                      name=None,
455                      partial_pivoting=True,
456                      perturb_singular=False):
457  r"""Solves tridiagonal systems of equations.
458
459  The input can be supplied in various formats: `matrix`, `sequence` and
460  `compact`, specified by the `diagonals_format` arg.
461
462  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
463  two inner-most dimensions representing the square tridiagonal matrices.
464  Elements outside of the three diagonals will be ignored.
465
466  In `sequence` format, `diagonals` are supplied as a tuple or list of three
467  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
468  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
469  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
470  first element of subdiagonal will be ignored.
471
472  In `compact` format the three diagonals are brought together into one tensor
473  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
474  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
475  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.
476
477  The `compact` format is recommended as the one with the best performance. If
478  you need to cast a tensor into the compact format manually, use `tf.gather_nd`.
479  An example for a tensor of shape `[m, m]`:
480
481  ```python
482  rhs = tf.constant([...])
483  matrix = tf.constant([[...]])
484  m = matrix.shape[0]
485  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
486  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
487           [[i, i] for i in range(m)],                          # Diagonal
488           [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]    # Subdiagonal
489  diagonals = tf.gather_nd(matrix, indices)
490  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
491  ```
492
493  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]` or
494  `[..., M, K]`. The latter allows solving K systems simultaneously with the
495  same left-hand sides and K different right-hand sides. If `transpose_rhs`
496  is set to `True`, the expected shape is `[..., M]` or `[..., K, M]`.
497
498  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
499  `rhs`.
500
501  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
502  `[..., M, K]`.
503
504  The op isn't guaranteed to raise an error if the input matrix is not
505  invertible. `tf.debugging.check_numerics` can be applied to the output to
506  detect invertibility problems.
507
508  **Note**: with large batch sizes, the computation on the GPU may be slow if
509  either `partial_pivoting=True` or there are multiple right-hand sides
510  (`K > 1`). If this issue arises, consider whether it is possible to disable
511  pivoting and use `K = 1`, or, alternatively, consider using the CPU.
512
513  On CPU, the solution is computed via Gaussian elimination with or without
514  partial pivoting, depending on the `partial_pivoting` parameter. On GPU,
515  Nvidia's cuSPARSE library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
516
517  Args:
518    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
519      shape depends on `diagonals_format`; see the description above. Must be
520      `float32`, `float64`, `complex64`, or `complex128`.
521    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
522      `diagonals`. Note that if the shape of `rhs` and/or `diagonals` isn't known
523      statically, `rhs` will be treated as a matrix rather than a vector.
524    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
525      `compact`.
526    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
527      if the shape of rhs is [..., M]).
528    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
529    name:  A name to give this `Op` (optional).
530    partial_pivoting: whether to perform partial pivoting. `True` by default.
531      Partial pivoting makes the procedure more stable, but slower. Partial
532      pivoting is unnecessary in some cases, including diagonally dominant and
533      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).
534    perturb_singular: whether to perturb singular matrices to return a finite
535      result. `False` by default. If true, solutions to systems involving
536      a singular matrix will be computed by perturbing near-zero pivots in
537      the partially pivoted LU decomposition. Specifically, tiny pivots are
538      perturbed by an amount of order `eps * max_{ij} |U(i,j)|` to avoid
539      overflow. Here `U` is the upper triangular part of the LU decomposition,
540      and `eps` is the machine precision. This is useful for solving
541      numerically singular systems when computing eigenvectors by inverse
542      iteration.
543      If `partial_pivoting` is `False`, `perturb_singular` must be `False` as
544      well.
545
546  Returns:
547    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.
548    If the input matrix is singular, the result is undefined.
549
550  Raises:
551    ValueError: Is raised if any of the following conditions hold:
552      1. An unsupported type is provided as input,
553      2. the input tensors have incorrect shapes,
554      3. `perturb_singular` is `True` but `partial_pivoting` is not.
555    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
556      XLA, or whenever `perturb_singular` is true and the backend is
557      XLA or GPU.
558
559  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
560  Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.
561
562  """
563  if perturb_singular and not partial_pivoting:
564    raise ValueError('partial_pivoting must be True if perturb_singular is.')
565
566  if diagonals_format == 'compact':
567    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
568                                             conjugate_rhs, partial_pivoting,
569                                             perturb_singular, name)
570
571  if diagonals_format == 'sequence':
572    if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
573      raise ValueError('Expected diagonals to be a sequence of length 3.')
574
575    superdiag, maindiag, subdiag = diagonals
576    if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or
577        not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])):
578      raise ValueError(
579          'Tensors representing the three diagonals must have the same shape, '
580          'except for the last dimension, got {}, {}, {}'.format(
581              subdiag.shape, maindiag.shape, superdiag.shape))
582
583    m = tensor_shape.dimension_value(maindiag.shape[-1])
584
585    def pad_if_necessary(t, name, last_dim_padding):
586      n = tensor_shape.dimension_value(t.shape[-1])
587      if not n or n == m:
588        return t
589      if n == m - 1:
590        paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
591                    [last_dim_padding])
592        return array_ops.pad(t, paddings)
593      raise ValueError('Expected {} to have length {} or {}, got {}.'.format(
594          name, m, m - 1, n))
595
596    subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
597    superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])
598
599    diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
600    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
601                                             conjugate_rhs, partial_pivoting,
602                                             perturb_singular, name)
603
604  if diagonals_format == 'matrix':
605    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
606    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
607    if m1 and m2 and m1 != m2:
608      raise ValueError(
609          'Expected last two dimensions of diagonals to be the same, got {} and {}'
610          .format(m1, m2))
611    m = m1 or m2
612    diagonals = array_ops.matrix_diag_part(
613        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
614    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
615                                             conjugate_rhs, partial_pivoting,
616                                             perturb_singular, name)
617
618  raise ValueError('Unrecognized diagonals_format: {}'.format(diagonals_format))
619
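# Illustrative usage sketch (editorial addition): solving a small system in the
# `compact` format described above. The matrix is [[2, -1, 0], [-1, 2, -1],
# [0, -1, 2]]; its solution for an all-ones right-hand side is [1.5, 2, 1.5].
# Assumes TensorFlow 2.x eager execution.
#
#   import tensorflow as tf
#   diagonals = tf.constant([[-1., -1.,  0.],   # superdiagonal (last entry ignored)
#                            [ 2.,  2.,  2.],   # main diagonal
#                            [ 0., -1., -1.]])  # subdiagonal (first entry ignored)
#   rhs = tf.constant([1., 1., 1.])
#   x = tf.linalg.tridiagonal_solve(diagonals, rhs)   # ==> [1.5, 2., 1.5]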
620
621def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
622                                      conjugate_rhs, partial_pivoting,
623                                      perturb_singular, name):
624  """Helper function used after the input has been cast to compact form."""
625  diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank
626
627  # If we know the rank of the diagonal tensor, do some static checking.
628  if diags_rank:
629    if diags_rank < 2:
630      raise ValueError(
631          'Expected diagonals to have rank at least 2, got {}'.format(
632              diags_rank))
633    if rhs_rank and rhs_rank != diags_rank and rhs_rank != diags_rank - 1:
634      raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format(
635          diags_rank - 1, diags_rank, rhs_rank))
636    if (rhs_rank and not diagonals.shape[:-2].is_compatible_with(
637        rhs.shape[:diags_rank - 2])):
638      raise ValueError('Batch shapes {} and {} are incompatible'.format(
639          diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))
640
641  if diagonals.shape[-2] and diagonals.shape[-2] != 3:
642    raise ValueError('Expected 3 diagonals, got {}'.format(diagonals.shape[-2]))
643
644  def check_num_lhs_matches_num_rhs():
645    if (diagonals.shape[-1] and rhs.shape[-2] and
646        diagonals.shape[-1] != rhs.shape[-2]):
647      raise ValueError('Expected number of left-hand sides and right-hand '
648                       'sides to be equal, got {} and {}'.format(
649                           diagonals.shape[-1], rhs.shape[-2]))
650
651  if rhs_rank and diags_rank and rhs_rank == diags_rank - 1:
652    # Rhs provided as a vector, ignoring transpose_rhs
653    if conjugate_rhs:
654      rhs = math_ops.conj(rhs)
655    rhs = array_ops.expand_dims(rhs, -1)
656    check_num_lhs_matches_num_rhs()
657    return array_ops.squeeze(
658        linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting,
659                                     perturb_singular, name), -1)
660
661  if transpose_rhs:
662    rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
663  elif conjugate_rhs:
664    rhs = math_ops.conj(rhs)
665
666  check_num_lhs_matches_num_rhs()
667  return linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting,
668                                      perturb_singular, name)
669
670
671@tf_export('linalg.tridiagonal_matmul')
672@dispatch.add_dispatch_support
673def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
674  r"""Multiplies tridiagonal matrix by matrix.
675
676  `diagonals` is a representation of a tridiagonal matrix, whose format depends
677  on `diagonals_format`.
678
679  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
680  two inner-most dimensions representing the square tridiagonal matrices.
681  Elements outside of the three diagonals will be ignored.
682
683  In `sequence` format, `diagonals` is a list or tuple of three tensors:
684  `[superdiag, maindiag, subdiag]`, each having shape `[..., M]`. The last
685  element of `superdiag` and the first element of `subdiag` are ignored.
686
687  In `compact` format the three diagonals are brought together into one tensor
688  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
689  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
690  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.
691
692  The `sequence` format is recommended as the one with the best performance.
693
694  `rhs` is the matrix to the right of the multiplication. It has shape `[..., M, N]`.
695
696  Example:
697
698  ```python
699  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
700  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
701  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
702  diagonals = [superdiag, maindiag, subdiag]
703  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
704  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
705  ```
706
707  Args:
708    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
709      shape depends on `diagonals_format`; see the description above. Must be
710      `float32`, `float64`, `complex64`, or `complex128`.
711    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
712    diagonals_format: one of `sequence`, or `compact`. Default is `compact`.
713    name:  A name to give this `Op` (optional).
714
715  Returns:
716    A `Tensor` of shape [..., M, N] containing the result of multiplication.
717
718  Raises:
719    ValueError: An unsupported type is provided as input, or the input
720      tensors have incorrect shapes.
721  """
722  if diagonals_format == 'compact':
723    superdiag = diagonals[..., 0, :]
724    maindiag = diagonals[..., 1, :]
725    subdiag = diagonals[..., 2, :]
726  elif diagonals_format == 'sequence':
727    superdiag, maindiag, subdiag = diagonals
728  elif diagonals_format == 'matrix':
729    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
730    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
731    if m1 and m2 and m1 != m2:
732      raise ValueError(
733          'Expected last two dimensions of diagonals to be the same, got {} and {}'
734          .format(m1, m2))
735    diags = array_ops.matrix_diag_part(
736        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
737    superdiag = diags[..., 0, :]
738    maindiag = diags[..., 1, :]
739    subdiag = diags[..., 2, :]
740  else:
741    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)
742
743  # C++ backend requires matrices.
744  # Converting 1-dimensional vectors to matrices with 1 row.
745  superdiag = array_ops.expand_dims(superdiag, -2)
746  maindiag = array_ops.expand_dims(maindiag, -2)
747  subdiag = array_ops.expand_dims(subdiag, -2)
748
749  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)
750
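# Illustrative usage sketch (editorial addition): `tridiagonal_matmul` in the
# `matrix` format agrees with a dense matmul when the operand is tridiagonal.
# Assumes TensorFlow 2.x eager execution.
#
#   import numpy as np
#   import tensorflow as tf
#   dense = tf.constant([[2., -1., 0.], [-1., 2., -1.], [0., -1., 2.]])
#   rhs = tf.ones([3, 2])
#   y = tf.linalg.tridiagonal_matmul(dense, rhs, diagonals_format='matrix')
#   np.allclose(y.numpy(), tf.matmul(dense, rhs).numpy())   # ==> True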
751
752def _maybe_validate_matrix(a, validate_args):
753  """Checks that input is a `float` matrix."""
754  assertions = []
755  if not a.dtype.is_floating:
756    raise TypeError('Input `a` must have `float`-like `dtype` '
757                    '(saw {}).'.format(a.dtype.name))
758  if a.shape is not None and a.shape.rank is not None:
759    if a.shape.rank < 2:
760      raise ValueError('Input `a` must have at least 2 dimensions '
761                       '(saw: {}).'.format(a.shape.rank))
762  elif validate_args:
763    assertions.append(
764        check_ops.assert_rank_at_least(
765            a, rank=2, message='Input `a` must have at least 2 dimensions.'))
766  return assertions
767
768
769@tf_export('linalg.matrix_rank')
770@dispatch.add_dispatch_support
771def matrix_rank(a, tol=None, validate_args=False, name=None):
772  """Compute the matrix rank of one or more matrices.
773
774  Args:
775    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
776      computed.
777    tol: Threshold below which the singular value is counted as 'zero'.
778      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
779    validate_args: When `True`, additional assertions might be embedded in the
780      graph.
781      Default value: `False` (i.e., no graph assertions are added).
782    name: Python `str` prefixed to ops created by this function.
783      Default value: 'matrix_rank'.
784
785  Returns:
786    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
787      singular values.
788  """
789  with ops.name_scope(name or 'matrix_rank'):
790    a = ops.convert_to_tensor(a, dtype_hint=dtypes.float32, name='a')
791    assertions = _maybe_validate_matrix(a, validate_args)
792    if assertions:
793      with ops.control_dependencies(assertions):
794        a = array_ops.identity(a)
795    s = svd(a, compute_uv=False)
796    if tol is None:
797      if (a.shape[-2:]).is_fully_defined():
798        m = np.max(a.shape[-2:].as_list())
799      else:
800        m = math_ops.reduce_max(array_ops.shape(a)[-2:])
801      eps = np.finfo(a.dtype.as_numpy_dtype).eps
802      tol = (
803          eps * math_ops.cast(m, a.dtype) *
804          math_ops.reduce_max(s, axis=-1, keepdims=True))
805    return math_ops.reduce_sum(math_ops.cast(s > tol, dtypes.int32), axis=-1)
806
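# Illustrative usage sketch (editorial addition): the rank of a matrix whose
# rows are linearly dependent, versus a full-rank identity. Assumes
# TensorFlow 2.x eager execution.
#
#   import tensorflow as tf
#   a = tf.constant([[1., 2.], [2., 4.]])   # second row is twice the first
#   tf.linalg.matrix_rank(a)                # ==> 1
#   tf.linalg.matrix_rank(tf.eye(3))        # ==> 3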
807
808@tf_export('linalg.pinv')
809@dispatch.add_dispatch_support
810def pinv(a, rcond=None, validate_args=False, name=None):
811  """Compute the Moore-Penrose pseudo-inverse of one or more matrices.
812
813  Calculate the [generalized inverse of a matrix](
814  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
815  singular-value decomposition (SVD) and including all large singular values.
816
817  The pseudo-inverse of a matrix `A` is defined as 'the matrix that "solves"
818  [the least-squares problem] `A @ x = b`', i.e., if `x_hat` is a solution, then
819  `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if
820  `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
821  `A_pinv = V @ inv(Sigma) @ U.T`. [(Strang, 1980)][1]
822
823  This function is analogous to [`numpy.linalg.pinv`](
824  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
825  It differs only in the default value of `rcond`. In `numpy.linalg.pinv`, the
826  default `rcond` is `1e-15`. Here the default is
827  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.
828
829  Args:
830    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
831      pseudo-inverted.
832    rcond: `Tensor` of small singular value cutoffs.  Singular values smaller
833      (in modulus) than `rcond` * largest_singular_value (again, in modulus) are
834      set to zero. Must broadcast against `tf.shape(a)[:-2]`.
835      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
836    validate_args: When `True`, additional assertions might be embedded in the
837      graph.
838      Default value: `False` (i.e., no graph assertions are added).
839    name: Python `str` prefixed to ops created by this function.
840      Default value: 'pinv'.
841
842  Returns:
843    a_pinv: (Batch of) pseudo-inverse of input `a`. Has same shape as `a` except
844      rightmost two dimensions are transposed.
845
846  Raises:
847    TypeError: if input `a` does not have `float`-like `dtype`.
848    ValueError: if input `a` has fewer than 2 dimensions.
849
850  #### Examples
851
852  ```python
853  import tensorflow as tf
854  import tensorflow_probability as tfp
855
856  a = tf.constant([[1.,  0.4,  0.5],
857                   [0.4, 0.2,  0.25],
858                   [0.5, 0.25, 0.35]])
859  tf.matmul(tf.linalg.pinv(a), a)
860  # ==> array([[1., 0., 0.],
861               [0., 1., 0.],
862               [0., 0., 1.]], dtype=float32)
863
864  a = tf.constant([[1.,  0.4,  0.5,  1.],
865                   [0.4, 0.2,  0.25, 2.],
866                   [0.5, 0.25, 0.35, 3.]])
867  tf.matmul(tf.linalg.pinv(a), a)
868  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
869               [ 0.37,  0.43, -0.33,  0.02],
870               [ 0.21, -0.33,  0.81,  0.01],
871               [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
872  ```
873
874  #### References
875
876  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press,
877       Inc., 1980, pp. 139-142.
878  """
879  with ops.name_scope(name or 'pinv'):
880    a = ops.convert_to_tensor(a, name='a')
881
882    assertions = _maybe_validate_matrix(a, validate_args)
883    if assertions:
884      with ops.control_dependencies(assertions):
885        a = array_ops.identity(a)
886
887    dtype = a.dtype.as_numpy_dtype
888
889    if rcond is None:
890
891      def get_dim_size(dim):
892        dim_val = tensor_shape.dimension_value(a.shape[dim])
893        if dim_val is not None:
894          return dim_val
895        return array_ops.shape(a)[dim]
896
897      num_rows = get_dim_size(-2)
898      num_cols = get_dim_size(-1)
899      if isinstance(num_rows, int) and isinstance(num_cols, int):
900        max_rows_cols = float(max(num_rows, num_cols))
901      else:
902        max_rows_cols = math_ops.cast(
903            math_ops.maximum(num_rows, num_cols), dtype)
904      rcond = 10. * max_rows_cols * np.finfo(dtype).eps
905
906    rcond = ops.convert_to_tensor(rcond, dtype=dtype, name='rcond')
907
908    # Calculate pseudo inverse via SVD.
909    # Note: if a is Hermitian then u == v. (We might observe additional
910    # performance by explicitly setting `v = u` in such cases.)
911    [
912        singular_values,  # Sigma
913        left_singular_vectors,  # U
914        right_singular_vectors,  # V
915    ] = svd(
916        a, full_matrices=False, compute_uv=True)
917
918    # Saturate small singular values to inf. This has the effect of making
919    # `1. / s = 0.` while not resulting in `NaN` gradients.
920    cutoff = rcond * math_ops.reduce_max(singular_values, axis=-1)
921    singular_values = array_ops.where_v2(
922        singular_values > array_ops.expand_dims_v2(cutoff, -1), singular_values,
923        np.array(np.inf, dtype))
924
925    # By the definition of the SVD, `a == u @ s @ v^H`, and the pseudo-inverse
926    # is defined as `pinv(a) == v @ inv(s) @ u^H`.
927    a_pinv = math_ops.matmul(
928        right_singular_vectors / array_ops.expand_dims_v2(singular_values, -2),
929        left_singular_vectors,
930        adjoint_b=True)
931
932    if a.shape is not None and a.shape.rank is not None:
933      a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))
934
935    return a_pinv
936
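# Illustrative usage sketch (editorial addition): for a full-column-rank
# overdetermined system, `pinv(a) @ b` gives the least-squares solution and
# agrees with `tf.linalg.lstsq`. Assumes TensorFlow 2.x eager execution.
#
#   import numpy as np
#   import tensorflow as tf
#   a = tf.constant([[1., 0.], [0., 1.], [1., 1.]])   # 3 equations, 2 unknowns
#   b = tf.constant([[1.], [2.], [4.]])
#   x = tf.matmul(tf.linalg.pinv(a), b)               # least-squares solution
#   np.allclose(x.numpy(), tf.linalg.lstsq(a, b).numpy(), atol=1e-5)   # ==> True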
937
938@tf_export('linalg.lu_solve')
939@dispatch.add_dispatch_support
940def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
941  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.
942
943  Note: this function does not verify the implied matrix is actually invertible
944  nor is this condition checked even when `validate_args=True`.
945
946  Args:
947    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
948      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
949    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
950      X` then `perm = argmax(P)`.
951    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
952      `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[...,
953        tf.newaxis])[..., 0]`.
954    validate_args: Python `bool` indicating whether arguments should be checked
955      for correctness. Note: this function does not verify the implied matrix is
956        actually invertible, even when `validate_args=True`.
957      Default value: `False` (i.e., don't validate arguments).
958    name: Python `str` name given to ops managed by this object.
959      Default value: `None` (i.e., 'lu_solve').
960
961  Returns:
962    x: The `X` in `A @ X = RHS`.
963
964  #### Examples
965
966  ```python
967  import numpy as np
968  import tensorflow as tf
969  import tensorflow_probability as tfp
970
971  x = [[[1., 2],
972        [3, 4]],
973       [[7, 8],
974        [3, 4]]]
975  inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
976  tf.assert_near(tf.matrix_inverse(x), inv_x)
977  # ==> True
978  ```
979
980  """
981
982  with ops.name_scope(name or 'lu_solve'):
983    lower_upper = ops.convert_to_tensor(
984        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
985    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
986    rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')
987
988    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
989    if assertions:
990      with ops.control_dependencies(assertions):
991        lower_upper = array_ops.identity(lower_upper)
992        perm = array_ops.identity(perm)
993        rhs = array_ops.identity(rhs)
994
995    if (rhs.shape.rank == 2 and perm.shape.rank == 1):
996      # Both rhs and perm have scalar batch_shape.
997      permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
998    else:
999      # Either rhs or perm have non-scalar batch_shape or we can't determine
1000      # this information statically.
1001      rhs_shape = array_ops.shape(rhs)
1002      broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
1003          rhs_shape[:-2],
1004          array_ops.shape(perm)[:-1])
1005      d, m = rhs_shape[-2], rhs_shape[-1]
1006      rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]],
1007                                             axis=0)
1008
1009      # Tile out rhs.
1010      broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
1011      broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])
1012
1013      # Tile out perm and add batch indices.
1014      broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
1015      broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
1016      broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
1017      broadcast_batch_indices = array_ops.broadcast_to(
1018          math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
1019          [broadcast_batch_size, d])
1020      broadcast_perm = array_ops.stack(
1021          [broadcast_batch_indices, broadcast_perm], axis=-1)
1022
1023      permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
1024      permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)
1025
1026    lower = set_diag(
1027        band_part(lower_upper, num_lower=-1, num_upper=0),
1028        array_ops.ones(
1029            array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
1030    return triangular_solve(
1031        lower_upper,  # Only upper is accessed.
1032        triangular_solve(lower, permuted_rhs),
1033        lower=False)
1034
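# Illustrative usage sketch (editorial addition): reusing one LU factorization
# to solve against a right-hand side, and checking against `tf.linalg.solve`.
# Assumes TensorFlow 2.x eager execution.
#
#   import numpy as np
#   import tensorflow as tf
#   a = tf.constant([[4., 3.], [6., 3.]])
#   rhs = tf.constant([[1.], [0.]])
#   lu, p = tf.linalg.lu(a)
#   x = tf.linalg.lu_solve(lu, p, rhs)
#   np.allclose(x.numpy(), tf.linalg.solve(a, rhs).numpy())   # ==> True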
1035
1036@tf_export('linalg.lu_matrix_inverse')
1037@dispatch.add_dispatch_support
1038def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
1039  """Computes the inverse given the LU decomposition(s) of one or more matrices.
1040
1041  This op is conceptually identical to,
1042
1043  ```python
1044  inv_X = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(X))
1045  tf.assert_near(tf.matrix_inverse(X), inv_X)
1046  # ==> True
1047  ```
1048
1049  Note: this function does not verify the implied matrix is actually invertible
1050  nor is this condition checked even when `validate_args=True`.
1051
1052  Args:
1053    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
1054      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
1055    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
1056      X` then `perm = argmax(P)`.
1057    validate_args: Python `bool` indicating whether arguments should be checked
1058      for correctness. Note: this function does not verify the implied matrix is
1059        actually invertible, even when `validate_args=True`.
1060      Default value: `False` (i.e., don't validate arguments).
1061    name: Python `str` name given to ops managed by this object.
1062      Default value: `None` (i.e., 'lu_matrix_inverse').
1063
1064  Returns:
1065    inv_x: The matrix inverse, i.e.,
1066      `tf.matrix_inverse(tf.linalg.lu_reconstruct(lu, perm))`.
1067
1068  #### Examples
1069
1070  ```python
1071  import numpy as np
1072  import tensorflow as tf
1073  import tensorflow_probability as tfp
1074
1075  x = [[[3., 4], [1, 2]],
1076       [[7., 8], [3, 4]]]
1077  inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x))
1078  tf.assert_near(tf.matrix_inverse(x), inv_x)
1079  # ==> True
1080  ```
1081
1082  """
1083
1084  with ops.name_scope(name or 'lu_matrix_inverse'):
1085    lower_upper = ops.convert_to_tensor(
1086        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
1087    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
1088    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
1089    if assertions:
1090      with ops.control_dependencies(assertions):
1091        lower_upper = array_ops.identity(lower_upper)
1092        perm = array_ops.identity(perm)
1093    shape = array_ops.shape(lower_upper)
1094    return lu_solve(
1095        lower_upper,
1096        perm,
1097        rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype),
1098        validate_args=False)
1099
1100
1101@tf_export('linalg.lu_reconstruct')
1102@dispatch.add_dispatch_support
1103def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
1104  """The reconstruct one or more matrices from their LU decomposition(s).
1105
1106  Args:
1107    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
1108      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
1109    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
1110      X` then `perm = argmax(P)`.
1111    validate_args: Python `bool` indicating whether arguments should be checked
1112      for correctness.
1113      Default value: `False` (i.e., don't validate arguments).
1114    name: Python `str` name given to ops managed by this object.
1115      Default value: `None` (i.e., 'lu_reconstruct').
1116
1117  Returns:
1118    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
1119      `lu_reconstruct(*tf.linalg.lu(x))`.
1120
1121  #### Examples
1122
1123  ```python
1124  import numpy as np
1125  import tensorflow as tf
1126  import tensorflow_probability as tfp
1127
1128  x = [[[3., 4], [1, 2]],
1129       [[7., 8], [3, 4]]]
1130  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
1131  tf.assert_near(x, x_reconstructed)
1132  # ==> True
1133  ```
1134
1135  """
1136  with ops.name_scope(name or 'lu_reconstruct'):
1137    lower_upper = ops.convert_to_tensor(
1138        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
1139    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
1140
1141    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
1142    if assertions:
1143      with ops.control_dependencies(assertions):
1144        lower_upper = array_ops.identity(lower_upper)
1145        perm = array_ops.identity(perm)
1146
1147    shape = array_ops.shape(lower_upper)
1148
1149    lower = set_diag(
1150        band_part(lower_upper, num_lower=-1, num_upper=0),
1151        array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
1152    upper = band_part(lower_upper, num_lower=0, num_upper=-1)
1153    x = math_ops.matmul(lower, upper)
1154
1155    if (lower_upper.shape is None or lower_upper.shape.rank is None or
1156        lower_upper.shape.rank != 2):
1157      # We either don't know the batch rank or there are >0 batch dims.
1158      batch_size = math_ops.reduce_prod(shape[:-2])
1159      d = shape[-1]
1160      x = array_ops.reshape(x, [batch_size, d, d])
1161      perm = array_ops.reshape(perm, [batch_size, d])
1162      perm = map_fn.map_fn(array_ops.invert_permutation, perm)
1163      batch_indices = array_ops.broadcast_to(
1164          math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d])
1165      x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm],
1166                                                 axis=-1))
1167      x = array_ops.reshape(x, shape)
1168    else:
1169      x = array_ops.gather(x, array_ops.invert_permutation(perm))
1170
1171    x.set_shape(lower_upper.shape)
1172    return x
1173
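# Illustrative sketch (editorial addition): recovering the explicit `L` and `U`
# factors from the packed `lower_upper` representation, mirroring the code in
# `lu_reconstruct` above. Assumes TensorFlow 2.x eager execution.
#
#   import tensorflow as tf
#   a = tf.constant([[4., 3.], [6., 3.]])
#   lu, p = tf.linalg.lu(a)
#   l = tf.linalg.set_diag(tf.linalg.band_part(lu, -1, 0),
#                          tf.ones(tf.shape(lu)[:-1], dtype=lu.dtype))
#   u = tf.linalg.band_part(lu, 0, -1)
#   # Applying the inverse of permutation `p` to the rows of l @ u recovers `a`,
#   # which is exactly what `tf.linalg.lu_reconstruct(lu, p)` computes.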
1174
1175def lu_reconstruct_assertions(lower_upper, perm, validate_args):
1176  """Returns list of assertions related to `lu_reconstruct` assumptions."""
1177  assertions = []
1178
1179  message = 'Input `lower_upper` must have at least 2 dimensions.'
1180  if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2:
1181    raise ValueError(message)
1182  elif validate_args:
1183    assertions.append(
1184        check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message))
1185
1186  message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
1187  if lower_upper.shape.rank is not None and perm.shape.rank is not None:
1188    if lower_upper.shape.rank != perm.shape.rank + 1:
1189      raise ValueError(message)
1190  elif validate_args:
1191    assertions.append(
1192        check_ops.assert_rank(
1193            lower_upper, rank=array_ops.rank(perm) + 1, message=message))
1194
1195  message = '`lower_upper` must be square.'
1196  if lower_upper.shape[:-2].is_fully_defined():
1197    if lower_upper.shape[-2] != lower_upper.shape[-1]:
1198      raise ValueError(message)
1199  elif validate_args:
1200    m, n = array_ops.split(
1201        array_ops.shape(lower_upper)[-2:], num_or_size_splits=2)
1202    assertions.append(check_ops.assert_equal(m, n, message=message))
1203
1204  return assertions
1205
1206
1207def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
1208  """Returns list of assertions related to `lu_solve` assumptions."""
1209  assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
1210
1211  message = 'Input `rhs` must have at least 2 dimensions.'
1212  if rhs.shape.ndims is not None:
1213    if rhs.shape.ndims < 2:
1214      raise ValueError(message)
1215  elif validate_args:
1216    assertions.append(
1217        check_ops.assert_rank_at_least(rhs, rank=2, message=message))
1218
1219  message = '`lower_upper.shape[-1]` must equal `rhs.shape[-1]`.'
1220  if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None):
1221    if lower_upper.shape[-1] != rhs.shape[-2]:
1222      raise ValueError(message)
1223  elif validate_args:
1224    assertions.append(
1225        check_ops.assert_equal(
1226            array_ops.shape(lower_upper)[-1],
1227            array_ops.shape(rhs)[-2],
1228            message=message))
1229
1230  return assertions
1231
1232
1233@tf_export('linalg.eigh_tridiagonal')
1234@dispatch.add_dispatch_support
1235def eigh_tridiagonal(alpha,
1236                     beta,
1237                     eigvals_only=True,
1238                     select='a',
1239                     select_range=None,
1240                     tol=None,
1241                     name=None):
1242  """Computes the eigenvalues of a Hermitian tridiagonal matrix.
1243
1244  Args:
1245    alpha: A real or complex tensor of shape (n), the diagonal elements of the
1246      matrix. NOTE: If alpha is complex, the imaginary part is ignored (assumed
1247        zero) to satisfy the requirement that the matrix be Hermitian.
1248    beta: A real or complex tensor of shape (n-1), containing the elements of
1249      the first super-diagonal of the matrix. If beta is complex, the first
1250      sub-diagonal of the matrix is assumed to be the conjugate of beta to
1251      satisfy the requirement that the matrix be Hermitian
1252    eigvals_only: If False, both eigenvalues and corresponding eigenvectors are
1253      computed. If True, only eigenvalues are computed. Default is True.
1254    select: Optional string with values in {'a', 'v', 'i'} (default is 'a') that
1255      determines which eigenvalues to calculate:
1256        'a': all eigenvalues.
1257        'v': eigenvalues in the interval (min, max] given by `select_range`.
1258        'i': eigenvalues with indices min <= i <= max.
1259    select_range: Size 2 tuple or list or tensor specifying the range of
1260      eigenvalues to compute together with select. If select is 'a',
1261      select_range is ignored.
1262    tol: Optional scalar. The absolute tolerance to which each eigenvalue is
1263      required. An eigenvalue (or cluster) is considered to have converged if it
1264      lies in an interval of this width. If tol is None (default), the value
1265      eps*|T|_2 is used where eps is the machine precision, and |T|_2 is the
1266      2-norm of the matrix T.
1267    name: Optional name of the op.
1268
1269  Returns:
1270    eig_vals: The eigenvalues of the matrix in non-decreasing order.
1271    eig_vectors: If `eigvals_only` is False the eigenvectors are returned in
1272      the second output argument.
1273
1274  Raises:
1275     ValueError: If input values are invalid.
1276     NotImplementedError: Computing eigenvectors for `eigvals_only` = False is
1277       not implemented yet.
1278
1279  This op implements a subset of the functionality of
1280  scipy.linalg.eigh_tridiagonal.
1281
1282  Note: The result is undefined if the input contains +/-inf or NaN, or if
1283  any value in beta has a magnitude greater than
1284  `numpy.sqrt(numpy.finfo(beta.dtype.as_numpy_dtype).max)`.
1285
1286
1287  TODO(b/187527398):
1288    Add support for outer batch dimensions.
1289
1290  #### Examples
1291
1292  ```python
1293  import numpy
1294  eigvals = tf.linalg.eigh_tridiagonal([0.0, 0.0, 0.0], [1.0, 1.0])
1295  eigvals_expected = [-numpy.sqrt(2.0), 0.0, numpy.sqrt(2.0)]
1296  tf.assert_near(eigvals_expected, eigvals)
1297  # ==> True
1298  ```
1299
1300  """
1301  with ops.name_scope(name or 'eigh_tridiagonal'):
1302
1303    def _compute_eigenvalues(alpha, beta):
1304      """Computes all eigenvalues of a Hermitian tridiagonal matrix."""
1305
1306      def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
1307        """Implements the Sturm sequence recurrence."""
        with ops.name_scope('sturm'):
          n = alpha.shape[0]
          zeros = array_ops.zeros(array_ops.shape(x), dtype=dtypes.int32)
          ones = array_ops.ones(array_ops.shape(x), dtype=dtypes.int32)

          # The first step in the Sturm sequence recurrence
          # requires special care if x is equal to alpha[0].
          def sturm_step0():
            q = alpha[0] - x
            count = array_ops.where(q < 0, ones, zeros)
            q = array_ops.where(
                math_ops.equal(alpha[0], x), alpha0_perturbation, q)
            return q, count

          # Subsequent steps all take this form:
          def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = array_ops.where(q <= pivmin, count + 1, count)
            q = array_ops.where(q <= pivmin, math_ops.minimum(q, -pivmin), q)
            return q, count

          # The first step initializes q and count.
          q, count = sturm_step0()

          # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
          # the bulk of the iterations unrolled by a factor of blocksize.
          blocksize = 16
          i = 1
          peel = (n - 1) % blocksize
          unroll_cnt = peel

          def unrolled_steps(start, q, count):
            for j in range(unroll_cnt):
              q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

          i, q, count = unrolled_steps(i, q, count)

          # Run the remaining steps of the Sturm sequence using a partially
          # unrolled while loop.
          unroll_cnt = blocksize
          cond = lambda i, q, count: math_ops.less(i, n)
          _, _, count = control_flow_ops.while_loop(
              cond, unrolled_steps, [i, q, count], back_prop=False)
          return count

      with ops.name_scope('compute_eigenvalues'):
        if alpha.dtype.is_complex:
          alpha = math_ops.real(alpha)
          beta_sq = math_ops.real(math_ops.conj(beta) * beta)
          beta_abs = math_ops.sqrt(beta_sq)
        else:
          beta_sq = math_ops.square(beta)
          beta_abs = math_ops.abs(beta)

        # Estimate the largest and smallest eigenvalues of T using the
        # Gershgorin circle theorem.
        finfo = np.finfo(alpha.dtype.as_numpy_dtype)
        off_diag_abs_row_sum = array_ops.concat(
            [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
        lambda_est_max = math_ops.minimum(
            finfo.max, math_ops.reduce_max(alpha + off_diag_abs_row_sum))
        lambda_est_min = math_ops.maximum(
            finfo.min, math_ops.reduce_min(alpha - off_diag_abs_row_sum))
        # Upper bound on 2-norm of T.
        t_norm = math_ops.maximum(
            math_ops.abs(lambda_est_min), math_ops.abs(lambda_est_max))

        # Compute the smallest allowed pivot in the Sturm sequence to avoid
        # overflow.
        one = np.ones([], dtype=alpha.dtype.as_numpy_dtype)
        safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
        pivmin = safemin * math_ops.maximum(one, math_ops.reduce_max(beta_sq))
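        # If x coincides exactly with alpha[0], the first pivot would be zero;
        # alpha0_perturbation is the tiny positive value substituted in that
        # case so that the division in the next Sturm step stays finite.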
        alpha0_perturbation = math_ops.square(finfo.eps * beta_abs[0])
        abs_tol = finfo.eps * t_norm
        if tol is not None:
          abs_tol = math_ops.maximum(tol, abs_tol)
        # In the worst case, when the absolute tolerance is eps*lambda_est_max
        # and lambda_est_max = -lambda_est_min, we have to take as many
        # bisection steps as there are bits in the mantissa plus 1.
        max_it = finfo.nmant + 1

        # Determine the indices of the desired eigenvalues, based on select
        # and select_range.
        asserts = None
        if select == 'a':
          target_counts = math_ops.range(n)
        elif select == 'i':
          asserts = check_ops.assert_less_equal(
              select_range[0],
              select_range[1],
              message='Got empty index range in select_range.')
          target_counts = math_ops.range(select_range[0], select_range[1] + 1)
        elif select == 'v':
          asserts = check_ops.assert_less(
              select_range[0],
              select_range[1],
              message='Got empty interval in select_range.')
        else:
          raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")

        if asserts:
          with ops.control_dependencies([asserts]):
            alpha = array_ops.identity(alpha)

        # Run binary search for all desired eigenvalues in parallel, starting
        # from an interval slightly wider than the estimated
        # [lambda_est_min, lambda_est_max].
        fudge = 2.1  # We widen the starting Gershgorin interval a bit.
        norm_slack = math_ops.cast(n, alpha.dtype) * fudge * finfo.eps * t_norm
        if select in {'a', 'i'}:
          lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
          upper = lambda_est_max + norm_slack + fudge * pivmin
        else:
          # Count the number of eigenvalues in the given range.
          lower = select_range[0] - norm_slack - 2 * fudge * pivmin
          upper = select_range[1] + norm_slack + fudge * pivmin
          first = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, lower)
          last = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, upper)
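          # `first` and `last` are the Sturm counts at the interval endpoints,
          # i.e. the number of eigenvalues below `lower` and `upper`, so the
          # eigenvalues inside the interval have indices range(first, last).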
          target_counts = math_ops.range(first, last)

        # Pre-broadcast the scalars used in the Sturm sequence for improved
        # performance.
        upper = math_ops.minimum(upper, finfo.max)
        lower = math_ops.maximum(lower, finfo.min)
        target_shape = array_ops.shape(target_counts)
        lower = array_ops.broadcast_to(lower, shape=target_shape)
        upper = array_ops.broadcast_to(upper, shape=target_shape)
        pivmin = array_ops.broadcast_to(pivmin, target_shape)
        alpha0_perturbation = array_ops.broadcast_to(alpha0_perturbation,
                                                     target_shape)

        # We compute the midpoint as 0.5*lower + 0.5*upper to avoid overflow in
        # (lower + upper) or (upper - lower) when the matrix has eigenvalues
        # with magnitude greater than finfo.max / 2.
        def midpoint(lower, upper):
          return (0.5 * lower) + (0.5 * upper)

        def continue_binary_search(i, lower, upper):
          return math_ops.logical_and(
              math_ops.less(i, max_it),
              math_ops.less(abs_tol, math_ops.reduce_max(upper - lower)))

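        # Bisection invariant: for each requested eigenvalue, the corresponding
        # entries of `lower` and `upper` bracket it. The Sturm counts at the
        # midpoints decide, elementwise, which half of each interval to keep.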
        def binary_search_step(i, lower, upper):
          mid = midpoint(lower, upper)
          counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
          lower = array_ops.where(counts <= target_counts, mid, lower)
          upper = array_ops.where(counts > target_counts, mid, upper)
          return i + 1, lower, upper

        # Start parallel binary searches.
        _, lower, upper = control_flow_ops.while_loop(continue_binary_search,
                                                      binary_search_step,
                                                      [0, lower, upper])
        return midpoint(lower, upper)

    def _compute_eigenvectors(alpha, beta, eigvals):
      """Implements inverse iteration to compute eigenvectors."""
      with ops.name_scope('compute_eigenvectors'):
        k = array_ops.size(eigvals)
        n = array_ops.size(alpha)
        alpha = math_ops.cast(alpha, dtype=beta.dtype)

        # Eigenvectors corresponding to a cluster of close eigenvalues are
        # not unique and need to be explicitly orthogonalized. Here we
        # identify such clusters. Note: This function assumes that
        # eigenvalues are sorted in non-decreasing order.
        gap = eigvals[1:] - eigvals[:-1]
        eps = np.finfo(eigvals.dtype.as_numpy_dtype).eps
        t_norm = math_ops.maximum(
            math_ops.abs(eigvals[0]), math_ops.abs(eigvals[-1]))
        gaptol = np.sqrt(eps) * t_norm
        # Find the beginning and end of runs of eigenvectors corresponding
        # to eigenvalues closer than "gaptol", which will need to be
        # orthogonalized against each other.
        close = math_ops.less(gap, gaptol)
        left_neighbor_close = array_ops.concat([[False], close], axis=0)
        right_neighbor_close = array_ops.concat([close, [False]], axis=0)
        ortho_interval_start = math_ops.logical_and(
            math_ops.logical_not(left_neighbor_close), right_neighbor_close)
        ortho_interval_start = array_ops.squeeze(
            array_ops.where_v2(ortho_interval_start), axis=-1)
        ortho_interval_end = math_ops.logical_and(
            left_neighbor_close, math_ops.logical_not(right_neighbor_close))
        ortho_interval_end = array_ops.squeeze(
            array_ops.where_v2(ortho_interval_end), axis=-1) + 1
        num_clusters = array_ops.size(ortho_interval_end)

        # We perform inverse iteration for all eigenvectors in parallel,
        # starting from a random set of vectors, until all have converged.
        v0 = math_ops.cast(
            stateless_random_ops.stateless_random_normal(
                shape=(k, n), seed=[7, 42]),
            dtype=beta.dtype)
        nrm_v = norm(v0, axis=1)
        v0 = v0 / nrm_v[:, array_ops.newaxis]
        zero_nrm = constant_op.constant(0, shape=nrm_v.shape, dtype=nrm_v.dtype)
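        # Starting the "previous norm" at zero guarantees that the convergence
        # test in continue_iteration lets the first iteration proceed.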

        # Replicate alpha - eigvals[i] and beta across the k eigenvectors so we
        # can solve the k systems
        #    [T - eigvals[i]*eye(n)] x_i = r_i
        # simultaneously using the batching mechanism.
        eigvals_cast = math_ops.cast(eigvals, dtype=beta.dtype)
        alpha_shifted = (
            alpha[array_ops.newaxis, :] - eigvals_cast[:, array_ops.newaxis])
        beta = array_ops.tile(beta[array_ops.newaxis, :], [k, 1])
        diags = [beta, alpha_shifted, math_ops.conj(beta)]

        def orthogonalize_close_eigenvectors(eigenvectors):
          # Eigenvectors corresponding to a cluster of close eigenvalues are not
          # uniquely defined, but the subspace they span is. To avoid numerical
          # instability, we explicitly mutually orthogonalize such eigenvectors
          # after each step of inverse iteration. It is customary to use
          # modified Gram-Schmidt for this, but this is not very efficient
          # on some platforms, so here we defer to the QR decomposition in
          # TensorFlow.
          def orthogonalize_cluster(cluster_idx, eigenvectors):
            start = ortho_interval_start[cluster_idx]
            end = ortho_interval_end[cluster_idx]
            update_indices = array_ops.expand_dims(
                math_ops.range(start, end), -1)
            vectors_in_cluster = eigenvectors[start:end, :]
            # We use the builtin QR factorization to orthonormalize the
            # vectors in the cluster.
            q, _ = qr(transpose(vectors_in_cluster))
            vectors_to_update = transpose(q)
            eigenvectors = array_ops.tensor_scatter_nd_update(
                eigenvectors, update_indices, vectors_to_update)
            return cluster_idx + 1, eigenvectors

          _, eigenvectors = control_flow_ops.while_loop(
              lambda i, ev: math_ops.less(i, num_clusters),
              orthogonalize_cluster, [0, eigenvectors])
          return eigenvectors

        def continue_iteration(i, _, nrm_v, nrm_v_old):
          max_it = 5  # Taken from LAPACK xSTEIN.
          min_norm_growth = 0.1
          norm_growth_factor = constant_op.constant(
              1 + min_norm_growth, dtype=nrm_v.dtype)
          # We stop the inverse iteration when we reach the maximum number of
          # iterations or the norm growth is less than 10%.
          return math_ops.logical_and(
              math_ops.less(i, max_it),
              math_ops.reduce_any(
                  math_ops.greater_equal(
                      math_ops.real(nrm_v),
                      math_ops.real(norm_growth_factor * nrm_v_old))))

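        # Each step solves the shifted systems (T - eigvals[i]*I) v_new = v for
        # all i at once with the batched tridiagonal solver. The systems are
        # nearly singular by construction, which is what makes the solution
        # grow along the desired eigenvector; perturb_singular=True lets the
        # solver cope with the (near-)zero pivots this produces.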
        def inverse_iteration_step(i, v, nrm_v, nrm_v_old):
          v = tridiagonal_solve(
              diags,
              v,
              diagonals_format='sequence',
              partial_pivoting=True,
              perturb_singular=True)
          nrm_v_old = nrm_v
          nrm_v = norm(v, axis=1)
          v = v / nrm_v[:, array_ops.newaxis]
          v = orthogonalize_close_eigenvectors(v)
          return i + 1, v, nrm_v, nrm_v_old

        _, v, nrm_v, _ = control_flow_ops.while_loop(continue_iteration,
                                                     inverse_iteration_step,
                                                     [0, v0, nrm_v, zero_nrm])
        return transpose(v)

    alpha = ops.convert_to_tensor(alpha, name='alpha')
    n = alpha.shape[0]
    if n <= 1:
      return math_ops.real(alpha)
    beta = ops.convert_to_tensor(beta, name='beta')

    if alpha.dtype != beta.dtype:
      raise ValueError("'alpha' and 'beta' must have the same type.")

    eigvals = _compute_eigenvalues(alpha, beta)
    if eigvals_only:
      return eigvals

    eigvectors = _compute_eigenvectors(alpha, beta, eigvals)
    return eigvals, eigvectors