# mypy: allow-untyped-defs
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
import torch._refs.linalg as linalg
from torch import Tensor
from torch._prims_common import (
    check_fp_or_complex,
    check_is_matrix,
    Dim,
    DimsType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    IntLike,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    elementwise_type_promotion_wrapper,
    out_wrapper,
)


__all__ = [
    "diagonal",
    "matrix_norm",
    "norm",
    "svd",
    "svdvals",
    "vector_norm",
    "vecdot",
    "cross",
]


def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions

    Does nothing when ``dtype`` is None. Otherwise requires that ``dtype``:
      * is a floating point or complex dtype,
      * is complex exactly when ``x_dtype`` is complex, and
      * equals ``utils.get_higher_dtype(dtype, x_dtype)``, i.e. converting the
        input to ``dtype`` does not narrow it.

    ``fn_name`` is only used as the prefix of the error messages.
    """
    if dtype is not None:
        torch._check(
            utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
            lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        torch._check(
            utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
            lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
                fn_name=fn_name,
                d="complex" if utils.is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        torch._check(
            utils.get_higher_dtype(dtype, x_dtype) == dtype,
            # Both fragments need the f-prefix; otherwise "{dtype}" is emitted verbatim.
            lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            f"without narrowing to the specified dtype ({dtype})",
        )


import operator

# Utilities should come BEFORE this import
from torch._decomp import register_decomposition
from torch._decomp.decompositions import pw_cast_for_opmath


@register_decomposition(torch._ops.ops.aten.linalg_cross)
@out_wrapper()
@pw_cast_for_opmath
def cross(a: Tensor, b: Tensor, dim: int = -1):
    """
    Cross product of ``a`` and ``b`` along dimension ``dim``.

    Both inputs must have the same number of dimensions and size 3 along
    ``dim``; they are broadcast against each other before the computation.
    """
    torch._check(
        a.ndim == b.ndim,
        lambda: "linalg.cross: inputs must have the same number of dimensions.",
    )
    torch._check(
        a.size(dim) == 3 and b.size(dim) == 3,
        lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}",
    )
    a, b = torch.broadcast_tensors(a, b)
    dim = utils.canonicalize_dim(a.ndim, dim)
    idx = torch.arange(3, device=a.device)
    # Component i of the result is a[i+1] * b[i+2] - a[i+2] * b[i+1], with the
    # indices taken modulo 3; index_select evaluates all three at once.
    return a.index_select(dim, (idx + 1) % 3) * b.index_select(
        dim, (idx + 2) % 3
    ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)


def diagonal(
    input: TensorLikeType,
    *,
    offset: int = 0,
    dim1: int = -2,
    dim2: int = -1,
) -> TensorLikeType:
    """Forward to ``torch.diagonal`` with keyword-only ``offset``, ``dim1``, ``dim2``."""
    return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)


@register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
@out_wrapper(exact_dtype=True)
def vector_norm(
    x: TensorLikeType,
    ord: Union[float, int] = 2,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """
    Vector ``ord``-norm of ``x`` reduced over ``dim``.

    Args:
        x: floating point or complex input.
        ord: order of the norm. ``0`` counts the non-zero entries, ``inf`` /
            ``-inf`` take the max / min absolute value, and any other value
            computes ``sum(|x| ** ord) ** (1 / ord)``.
        dim: dimension(s) to reduce over; a single dimension is wrapped into
            a list, and ``None`` is passed through to the reductions as-is.
        keepdim: forwarded to the underlying reductions.
        dtype: optional dtype, validated by ``_check_norm_dtype`` and used to
            derive the computation and result dtypes.

    Raises (via ``torch._check``) when ``x`` is empty and ``ord`` is negative
    or ``inf`` and either no ``dim`` is given or one of the reduced
    dimensions is empty, since the reduction then has no identity.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, Dim):
        dim = [dim]  # type: ignore[assignment]

    if guard_size_oblivious(x.numel() == 0) and (ord < 0.0 or ord == float("inf")):
        torch._check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            torch._check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim))  # type: ignore[return-value,arg-type]
    elif ord == float("-inf"):
        return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim))  # type: ignore[return-value,arg-type]
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = _maybe_convert_to_dtype(x, computation_dtype)  # type: ignore[assignment]
        reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)

        is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0
        # For an even order on real floating point input, x ** ord is already
        # non-negative, so the abs() can be skipped.
        if not (is_ord_even and utils.is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))  # type: ignore[return-value]


def _backshift_permutation(dim0, dim1, ndim):
    # Auxiliary function for matrix_norm
    # Computes the permutation that moves the two given dimensions to the back
    ret = [i for i in range(ndim) if i != dim0 and i != dim1]
    ret.extend((dim0, dim1))
    return ret


def _inverse_permutation(perm):
    # Given a permutation, returns its inverse. It's equivalent to argsort on an array
    return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))]


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def matrix_norm(
    A: TensorLikeType,
    ord: Union[float, str] = "fro",
    dim: DimsType = (-2, -1),
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """
    Matrix norm of ``A`` over the two dimensions given by ``dim``.

    Args:
        A: input accepted by ``check_is_matrix``.
        ord: ``"fro"``, ``"nuc"``, or a number whose absolute value is 1, 2
            or ``inf``.
        dim: the two distinct dimensions that form the matrices.
        keepdim: whether the reduced dimensions are kept in the result.
        dtype: optional dtype, validated by ``_check_norm_dtype``.

    ``"fro"`` is computed as the vector 2-norm over ``dim``. ``"nuc"`` and
    ``+-2`` are computed from the singular values (sum, max or min). ``+-1``
    and ``+-inf`` take the max/min over one dimension of the vector 1-norm
    taken over the other dimension.
    """
    # shape
    check_is_matrix(A, "linalg.matrix_norm")
    # dim
    dim = utils.canonicalize_dims(A.ndim, dim)
    if isinstance(dim, Dim):
        dim = (dim,)  # type: ignore[assignment]
    torch._check(
        len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
    )
    torch._check(
        dim[0] != dim[1],
        lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
    )
    # dtype arg
    _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")

    if isinstance(ord, str):
        # ord
        torch._check(
            ord in ("fro", "nuc"),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
        )

        if ord == "fro":
            return vector_norm(A, 2, dim, keepdim, dtype=dtype)
        else:  # ord == "nuc"
            if dtype is not None:
                A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment]
            # svdvals works on the last two dimensions, so move `dim` there first.
            perm = _backshift_permutation(dim[0], dim[1], A.ndim)
            result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
            if keepdim:
                # Restore the second reduced dimension and undo the permutation.
                inv_perm = _inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
    else:
        # ord
        abs_ord = abs(ord)
        torch._check(
            abs_ord in (2, 1, float("inf")),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2
        )

        # Positive orders reduce with max, negative orders with min.
        max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)

        if abs_ord == 2.0:
            if dtype is not None:
                A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment]
            # svdvals works on the last two dimensions, so move `dim` there first.
            perm = _backshift_permutation(dim[0], dim[1], A.ndim)
            result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
            if keepdim:
                # Restore the second reduced dimension and undo the permutation.
                inv_perm = _inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
        else:  # 1, -1, inf, -inf
            dim0, dim1 = dim
            if abs_ord == float("inf"):
                # The inf variants sum over the other dimension of the pair.
                dim0, dim1 = dim1, dim0
            if not keepdim and (dim0 < dim1):
                # Reducing dim0 without keepdim removes that axis, which
                # shifts every later axis (including dim1) down by one.
                dim1 -= 1
            return max_min(
                vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1
            )


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def norm(
    A: TensorLikeType,
    ord: Optional[Union[float, str]] = None,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """
    Dispatch to ``matrix_norm`` or ``vector_norm`` depending on ``ord``/``dim``.

    ``matrix_norm`` is used when ``ord`` is given and either ``dim`` has
    length 2 or ``dim`` is None and ``A`` is 2D (in which case ``dim``
    becomes ``(0, 1)``). Every other case goes to ``vector_norm``, with
    ``ord`` defaulting to ``2.0``.

    Raises (via ``torch._check``) when ``dim`` is given with a length other
    than 1 or 2, or when ``ord`` is given without ``dim`` and ``A`` is not
    1D or 2D.
    """
    if dim is not None:
        if isinstance(dim, Dim):
            dim = (dim,)  # type: ignore[assignment]
        torch._check(
            len(dim) in (1, 2),
            lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
        )
    elif ord is not None:
        torch._check(
            A.ndim in (1, 2),
            lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
        )

    if ord is not None and (
        (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2)
    ):
        if dim is None:
            dim = (0, 1)
        return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
    else:
        if ord is None:
            ord = 2.0
        return vector_norm(A, ord, dim, keepdim, dtype=dtype)  # type: ignore[arg-type]


# CompositeImplicitAutograd
@out_wrapper("U", "S", "Vh", exact_dtype=True)
def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
    """Return the ``(U, S, Vh)`` triple produced by ``prims.svd``."""
    return prims.svd(A, full_matrices=full_matrices)


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def svdvals(A: TensorLikeType) -> Tensor:
    """Return the second output (``S``) of ``svd(A, full_matrices=False)``."""
    return svd(A, full_matrices=False)[1]


# CompositeImplicitAutograd
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("x", "y"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
    """Sum of ``x.conj() * y`` over ``dim``; ``x`` must be floating point or complex."""
    check_fp_or_complex(x.dtype, "linalg.vecdot")
    return (x.conj() * y).sum(dim=dim)