# mypy: allow-untyped-defs
"""Various linear algebra utility methods for internal use."""

from typing import Optional, Tuple

import torch
from torch import Tensor


def is_sparse(A):
    """Tell whether tensor ``A`` uses the sparse COO layout.

    Returns the result of comparing ``A.layout`` with ``torch.sparse_coo``.

    Raises:
        TypeError: if ``A`` is not a ``torch.Tensor``.
    """
    if not isinstance(A, torch.Tensor):
        message = "expected Tensor"
        # The offending type is only appended outside of TorchScript
        # scripting mode.
        if not torch.jit.is_scripting():
            message += f" but got {type(A)}"
        raise TypeError(message)

    return A.layout == torch.sparse_coo


def get_floating_dtype(A):
    """Return the floating point dtype to use for tensor ``A``.

    float16, float32 and float64 are returned unchanged; every other
    dtype (integer types, for example) maps to float32.
    """
    preserved = (torch.float16, torch.float32, torch.float64)
    current = A.dtype
    return current if current in preserved else torch.float32


def matmul(A: Optional[Tensor], B: Tensor) -> Tensor:
    """Compute the matrix product ``A @ B``.

    When ``A`` is None it is treated as absent and ``B`` itself is
    returned. A sparse ``A`` (as reported by ``is_sparse``) is multiplied
    with ``torch.sparse.mm``; otherwise ``torch.matmul`` is used. ``B``
    is expected to be dense.
    """
    if A is None:
        return B
    if is_sparse(A):
        product = torch.sparse.mm(A, B)
    else:
        product = torch.matmul(A, B)
    return product


def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor:
    """Evaluate the bilinear form :math:`X^T A Y`.

    ``A`` may be None, in which case the result is :math:`X^T Y`.
    """
    right = matmul(A, Y)
    return matmul(X.mT, right)


def qform(A: Optional[Tensor], S: Tensor):
    """Evaluate the quadratic form :math:`S^T A S`."""
    # A quadratic form is the bilinear form with both sides equal to S.
    return bform(S, A, S)


def basis(A):
    """Return an orthogonal basis of the columns of ``A``.

    The basis is the ``Q`` factor of ``torch.linalg.qr(A)``.
    """
    q_factor, _ = torch.linalg.qr(A)
    return q_factor


def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
    """Return the eigenpairs of ``A`` in the requested order.

    The decomposition comes from ``torch.linalg.eigh`` with ``UPLO="U"``.

    Args:
        A: input matrix.
        largest: when True, the last dimension of both results is
            reversed. None is treated as False.

    Returns:
        A tuple ``(eigenvalues, eigenvectors)``.
    """
    if largest is None:
        largest = False
    eigenvalues, eigenvectors = torch.linalg.eigh(A, UPLO="U")
    if not largest:
        return eigenvalues, eigenvectors
    # Relies on eigh returning the eigenvalues already ordered, so that
    # reversing the last dimension yields the opposite order.
    return (
        torch.flip(eigenvalues, dims=(-1,)),
        torch.flip(eigenvectors, dims=(-1,)),
    )


# The functions below were deprecated and have since been removed; each one
# only raises an error that points at its replacement.
# These explanatory error messages can be removed in version 1.13+
def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor:
    """Removed API; always raises ``RuntimeError`` with migration advice."""
    message = (
        "This function was deprecated since version 1.9 and is now removed.\n"
        "Please use the `torch.linalg.matrix_rank` function instead. "
        "The parameter 'symmetric' was renamed in `torch.linalg.matrix_rank()` to 'hermitian'."
    )
    raise RuntimeError(message)


def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
    """Removed API; always raises ``RuntimeError`` with migration advice."""
    message = (
        "This function was deprecated since version 1.9 and is now removed. "
        "`torch.solve` is deprecated in favor of `torch.linalg.solve`. "
        "`torch.linalg.solve` has its arguments reversed and does not return the LU factorization.\n\n"
        "To get the LU factorization see `torch.lu`, which can be used with `torch.lu_solve` or `torch.lu_unpack`.\n"
        "X = torch.solve(B, A).solution "
        "should be replaced with:\n"
        "X = torch.linalg.solve(A, B)"
    )
    raise RuntimeError(message)


def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
    """Removed API; always raises ``RuntimeError`` with migration advice."""
    message = (
        "This function was deprecated since version 1.9 and is now removed. "
        "`torch.lstsq` is deprecated in favor of `torch.linalg.lstsq`.\n"
        "`torch.linalg.lstsq` has reversed arguments and does not return the QR decomposition in "
        "the returned tuple (although it returns other information about the problem).\n\n"
        "To get the QR decomposition consider using `torch.linalg.qr`.\n\n"
        "The returned solution in `torch.lstsq` stored the residuals of the solution in the "
        "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, "
        "the residuals are in the field 'residuals' of the returned named tuple.\n\n"
        "The unpacking of the solution, as in\n"
        "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n"
        "should be replaced with:\n"
        "X = torch.linalg.lstsq(A, B).solution"
    )
    raise RuntimeError(message)


def _symeig(
    input,
    eigenvectors=False,
    upper=True,
    *,
    out=None,
) -> Tuple[Tensor, Tensor]:
    """Removed API; always raises ``RuntimeError`` with migration advice."""
    message = (
        "This function was deprecated since version 1.9 and is now removed. "
        "The default behavior has changed from using the upper triangular portion of the matrix by default "
        "to using the lower triangular portion.\n\n"
        "L, _ = torch.symeig(A, upper=upper) "
        "should be replaced with:\n"
        "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n\n"
        "and\n\n"
        "L, V = torch.symeig(A, eigenvectors=True) "
        "should be replaced with:\n"
        "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
    )
    raise RuntimeError(message)


def eig(
    self: Tensor,
    eigenvectors: bool = False,
    *,
    e=None,
    v=None,
) -> Tuple[Tensor, Tensor]:
    """Removed API; always raises ``RuntimeError`` with migration advice."""
    message = (
        "This function was deprecated since version 1.9 and is now removed. "
        "`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` rather than real tensors "
        "mimicking complex tensors.\n\n"
        "L, _ = torch.eig(A) "
        "should be replaced with:\n"
        "L_complex = torch.linalg.eigvals(A)\n\n"
        "and\n\n"
        "L, V = torch.eig(A, eigenvectors=True) "
        "should be replaced with:\n"
        "L_complex, V_complex = torch.linalg.eig(A)"
    )
    raise RuntimeError(message)