# mypy: allow-untyped-defs
import operator
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
import torch._refs.linalg as linalg
from torch import Tensor
from torch._prims_common import (
    check_fp_or_complex,
    check_is_matrix,
    Dim,
    DimsType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    IntLike,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    elementwise_type_promotion_wrapper,
    out_wrapper,
)


__all__ = [
    "diagonal",
    "matrix_norm",
    "norm",
    "svd",
    "svdvals",
    "vector_norm",
    "vecdot",
    "cross",
]


def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions
    """
    if dtype is not None:
        torch._check(
            utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
            lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        torch._check(
            utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
            lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
                fn_name=fn_name,
                d="complex" if utils.is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        torch._check(
            utils.get_higher_dtype(dtype, x_dtype) == dtype,
            lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            f"without narrowing to the specified dtype ({dtype})",
        )


# Utilities should come BEFORE this import
from torch._decomp import register_decomposition
from torch._decomp.decompositions import pw_cast_for_opmath


@register_decomposition(torch._ops.ops.aten.linalg_cross)
@out_wrapper()
@pw_cast_for_opmath
def cross(a: Tensor, b: Tensor, dim: int = -1):
    torch._check(
        a.ndim == b.ndim,
        lambda: "linalg.cross: inputs must have the same number of dimensions.",
    )
    torch._check(
        a.size(dim) == 3 and b.size(dim) == 3,
        lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}",
    )
    a, b = torch.broadcast_tensors(a, b)
    dim = utils.canonicalize_dim(a.ndim, dim)
    idx = torch.arange(3, device=a.device)
    return a.index_select(dim, (idx + 1) % 3) * b.index_select(
        dim, (idx + 2) % 3
    ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
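# Illustrative sketch (not part of the reference implementation): for the
# standard basis vectors, e1 x e2 = e3, so one would expect
#   >>> a = torch.tensor([1.0, 0.0, 0.0])
#   >>> b = torch.tensor([0.0, 1.0, 0.0])
#   >>> torch.linalg.cross(a, b)
#   tensor([0., 0., 1.])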


def diagonal(
    input: TensorLikeType,
    *,
    offset: int = 0,
    dim1: int = -2,
    dim2: int = -1,
) -> TensorLikeType:
    return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)
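# Illustrative sketch: unlike torch.diagonal, the arguments here are
# keyword-only and default to the last two dimensions, e.g.
#   >>> A = torch.arange(9.0).reshape(3, 3)
#   >>> torch.linalg.diagonal(A)            # main diagonal
#   tensor([0., 4., 8.])
#   >>> torch.linalg.diagonal(A, offset=1)  # first superdiagonal
#   tensor([1., 5.])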


@register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
@out_wrapper(exact_dtype=True)
def vector_norm(
    x: TensorLikeType,
    ord: Union[float, int] = 2,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, Dim):
        dim = [dim]  # type: ignore[assignment]

    if guard_size_oblivious(x.numel() == 0) and (ord < 0.0 or ord == float("inf")):
        torch._check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            torch._check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim))  # type: ignore[return-value,arg-type]
    elif ord == float("-inf"):
        return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim))  # type: ignore[return-value,arg-type]
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = _maybe_convert_to_dtype(x, computation_dtype)  # type: ignore[assignment]
        reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)

        is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0
        if not (is_ord_even and utils.is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))  # type: ignore[return-value]
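# Illustrative sketch of the general branch above: for ord=p it computes
# (sum |x_i|^p)^(1/p), while the inf branches reduce with amax/amin, e.g.
#   >>> x = torch.tensor([3.0, 4.0])
#   >>> torch.linalg.vector_norm(x)          # default ord=2: sqrt(9 + 16)
#   tensor(5.)
#   >>> torch.linalg.vector_norm(x, ord=float("inf"))
#   tensor(4.)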


def _backshift_permutation(dim0, dim1, ndim):
    # Auxiliary function for matrix_norm
    # Computes the permutation that moves the two given dimensions to the back
    ret = [i for i in range(ndim) if i != dim0 and i != dim1]
    ret.extend((dim0, dim1))
    return ret


def _inverse_permutation(perm):
    # Given a permutation, returns its inverse. It's equivalent to argsort on an array
    return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))]
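# Worked example (illustrative): with ndim=4, dim0=1, dim1=2,
# _backshift_permutation returns [0, 3, 1, 2], and
# _inverse_permutation([0, 3, 1, 2]) gives [0, 2, 3, 1], which undoes the
# transpose and restores the original dimension order.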


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def matrix_norm(
    A: TensorLikeType,
    ord: Union[float, str] = "fro",
    dim: DimsType = (-2, -1),
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # shape
    check_is_matrix(A, "linalg.matrix_norm")
    # dim
    dim = utils.canonicalize_dims(A.ndim, dim)
    if isinstance(dim, Dim):
        dim = (dim,)  # type: ignore[assignment]
    torch._check(
        len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
    )
    torch._check(
        dim[0] != dim[1],
        lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
    )
    # dtype arg
    _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")

    if isinstance(ord, str):
        # ord
        torch._check(
            ord in ("fro", "nuc"),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
        )

        if ord == "fro":
            return vector_norm(A, 2, dim, keepdim, dtype=dtype)
        else:  # ord == "nuc"
            if dtype is not None:
                A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment]
            perm = _backshift_permutation(dim[0], dim[1], A.ndim)
            result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
            if keepdim:
                inv_perm = _inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
    else:
        # ord
        abs_ord = abs(ord)
        torch._check(
            abs_ord in (2, 1, float("inf")),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2
        )

        max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)

        if abs_ord == 2.0:
            if dtype is not None:
                A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment]
            perm = _backshift_permutation(dim[0], dim[1], A.ndim)
            result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
            if keepdim:
                inv_perm = _inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
        else:  # 1, -1, inf, -inf
            dim0, dim1 = dim
            if abs_ord == float("inf"):
                dim0, dim1 = dim1, dim0
            if not keepdim and (dim0 < dim1):
                dim1 -= 1
            return max_min(
                vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1
            )
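# Illustrative sketch of the 1/-1/inf/-inf branch: the 1-norm is the maximum
# absolute column sum and the inf-norm the maximum absolute row sum, e.g.
#   >>> A = torch.tensor([[1.0, -2.0], [3.0, 4.0]])
#   >>> torch.linalg.matrix_norm(A, ord=1)             # max(|1|+|3|, |-2|+|4|)
#   tensor(6.)
#   >>> torch.linalg.matrix_norm(A, ord=float("inf"))  # max(|1|+|-2|, |3|+|4|)
#   tensor(7.)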


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def norm(
    A: TensorLikeType,
    ord: Optional[Union[float, str]] = None,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    if dim is not None:
        if isinstance(dim, Dim):
            dim = (dim,)  # type: ignore[assignment]
        torch._check(
            len(dim) in (1, 2),
            lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
        )
    elif ord is not None:
        torch._check(
            A.ndim in (1, 2),
            lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
        )

    if ord is not None and (
        (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2)
    ):
        if dim is None:
            dim = (0, 1)
        return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
    else:
        if ord is None:
            ord = 2.0
        return vector_norm(A, ord, dim, keepdim, dtype=dtype)  # type: ignore[arg-type]
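# Illustrative sketch of the dispatch above: a 2D input with an explicit ord
# routes to matrix_norm, while the default reduces over all entries as a
# vector 2-norm, e.g.
#   >>> A = torch.eye(2)
#   >>> torch.linalg.norm(A)             # vector_norm path: sqrt(1 + 1)
#   tensor(1.4142)
#   >>> torch.linalg.norm(A, ord="fro")  # matrix_norm path, same value here
#   tensor(1.4142)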


# CompositeImplicitAutograd
@out_wrapper("U", "S", "Vh", exact_dtype=True)
def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
    return prims.svd(A, full_matrices=full_matrices)


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def svdvals(A: TensorLikeType) -> Tensor:
    return svd(A, full_matrices=False)[1]
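# Illustrative sketch: svdvals(A) is just the S factor of the reduced SVD, so
# for a diagonal matrix the singular values are the absolute diagonal entries
# in descending order, e.g.
#   >>> A = torch.diag(torch.tensor([-3.0, 2.0]))
#   >>> torch.linalg.svdvals(A)
#   tensor([3., 2.])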


# CompositeImplicitAutograd
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("x", "y"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
    check_fp_or_complex(x.dtype, "linalg.vecdot")
    return (x.conj() * y).sum(dim=dim)
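# Illustrative sketch: the first argument is conjugated, so for complex inputs
# this computes sum(conj(x_i) * y_i), e.g.
#   >>> x = torch.tensor([1.0 + 1.0j])
#   >>> y = torch.tensor([1.0 + 1.0j])
#   >>> torch.linalg.vecdot(x, y)  # (1 - 1j) * (1 + 1j) = 2
#   tensor(2.+0.j)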