# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from functools import wraps
from typing import Callable, Optional, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
from torch._decomp import register_decomposition
from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    NumberType,
    ShapeType,
    TensorLike,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    elementwise_type_promotion_wrapper,
    elementwise_unary_scalar_wrapper,
    out_wrapper,
)
from torch._refs import _make_inplace


__all__ = [
    "alpha_dropout",
    "celu",
    "celu_",
    "channel_shuffle",
    "dropout",
    "elu",
    "elu_",
    "gelu",
    "glu",
    "group_norm",
    "hardshrink",
    "hardtanh",
    "hinge_embedding_loss",
    "huber_loss",
    "l1_loss",
    "layer_norm",
    "leaky_relu",
    "log_softmax",
    "margin_ranking_loss",
    "mish",
    "mish_",
    "mse_loss",
    "nll_loss",
    "pairwise_distance",
    "pdist",
    "poisson_nll_loss",
    "prelu",
    "relu",
    "relu6",
    "selu",
    "selu_",
    "smooth_l1_loss",
    "softmax",
    "softmin",
    "softplus",
    "softshrink",
    "tanhshrink",
    "threshold",
    "threshold_",
    "triplet_margin_loss",
]

Tensor = torch.Tensor
aten = torch._ops.ops.aten
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]


def _dropout_helper(
    self: TensorLikeType,
    val: float,
) -> TensorLikeType:
    """
    Helper function for all dropout-type operators. During training,
    some of the elements of the input tensor are randomly masked.

    Returns a boolean mask tensor: True marks elements that are kept,
    each independently with probability ``val``.

    """

    return (
        refs._uniform_helper(
            self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device
        )
        < val
    )

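# Illustrative usage sketch (assumes eager tensor inputs): with val = 1 - p,
# roughly a `val` fraction of the mask is True, i.e. marked as kept.
#
#   >>> x = torch.randn(4, 4)
#   >>> mask = _dropout_helper(x, 1 - 0.25)
#   >>> mask.dtype, mask.shape
#   (torch.bool, torch.Size([4, 4]))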

@register_decomposition(aten.alpha_dropout)
def alpha_dropout(
    self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False
) -> TensorLikeType:
    if inplace:
        raise NotImplementedError

    if not training:
        return self

    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got {p}",
    )

    if p == 1:
        return torch.zeros_like(self)

    if p == 0:
        return self

    dropout_mask = _dropout_helper(self, 1 - p)

    # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf)
    # alpha = - SELU.alpha * SELU.scale, here
    # SELU.alpha = 1.6732632423543772848170429916717 and
    # SELU.scale = 1.0507009873554804934193349852946
    alpha = -1.7580993408473766

    a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p))
    b = torch.logical_not(dropout_mask)
    b = b * (alpha * a) + alpha * a * p
    dropout_mask = a * dropout_mask

    return self * dropout_mask + b

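# Illustrative reading of the arithmetic above (a sketch, not a behavioral
# guarantee): unlike `dropout`, which zeroes dropped entries and rescales the
# survivors, every entry here goes through an affine map -- kept entries become
# a * x + alpha * a * p and dropped entries collapse to the constant
# alpha * a * (1 + p) -- in the spirit of the SELU paper's alpha dropout, which
# aims to keep activations roughly self-normalizing.
#
#   >>> x = torch.randn(8)
#   >>> y = alpha_dropout(x, p=0.5, training=True)  # same shape; dropped slots share one constant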

def _inplace_wrapper(fn):
    """
    Given a nn.functional non-linearity, implements its `inplace: bool` argument
    """

    # nb. We use the name of the first argument used in the unary references
    @wraps(fn)
    def _fn(a, *args, inplace=False, **kwargs):
        if inplace:
            torch._check(
                "out" not in kwargs,
                lambda: "Cannot set inplace=True and pass out= at the same time",
            )
            return fn(a, *args, inplace=False, out=a, **kwargs)
        else:
            return fn(a, *args, inplace=False, **kwargs)

    return _fn

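# Illustrative usage sketch (assumes a ref decorated with both `_inplace_wrapper`
# and `out_wrapper`, e.g. `relu` below): `inplace=True` is rewritten into an
# `out=` call that aliases the input, so the result lands back in the input
# tensor.
#
#   >>> t = torch.randn(3)
#   >>> y = relu(t, inplace=True)  # equivalent to relu(t, out=t)
#   >>> torch.equal(y, t)
#   True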

# celu is implemented specially because it has an alpha argument
# celu is very similar to elu
@register_decomposition(aten.celu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def celu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.celu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        rhs = alpha * torch.expm1(torch.true_divide(a, alpha))  # type: ignore[arg-type]
    else:
        rhs = torch.expm1(a)

    return torch.where(a > 0, a, rhs)


@_inplace_wrapper
@out_wrapper()
def dropout(
    a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:
    if inplace:
        raise NotImplementedError

    if not training:
        return a

    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got {p}",
    )

    if p == 1:
        return torch.zeros_like(a)

    if p == 0:
        return a

    scale = 1 / (1 - p)
    dropout_mask = _dropout_helper(a, 1 - p)

    return a * dropout_mask * scale

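# Illustrative usage sketch (assumes eager tensor inputs): this is "inverted"
# dropout -- kept entries are scaled by 1 / (1 - p), so the expected value of
# the output matches the input.
#
#   >>> x = torch.ones(10)
#   >>> y = dropout(x, p=0.5, training=True)
#   >>> set(y.tolist()) <= {0.0, 2.0}  # each entry is either dropped or scaled by 2
#   True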

@register_decomposition(aten.elu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def elu(
    a: TensorLikeType,
    alpha: NumberType = 1.0,
    scale: NumberType = 1.0,
    input_scale: NumberType = 1.0,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """
    if inplace:
        raise NotImplementedError

    # nb. This should be factored out into a can_cast aux function
    python_type = utils.dtype_to_type(a.dtype)
    torch._check(
        utils.is_weakly_lesser_type(type(input_scale), python_type),
        lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
    )
    torch._check(
        utils.is_weakly_lesser_type(type(scale), python_type),
        lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
    )
    torch._check(
        utils.is_weakly_lesser_type(type(alpha), python_type),
        lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
    )

    return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale))


@register_decomposition(aten.relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu
    """

    if inplace:
        raise NotImplementedError

    return torch.where(torch.le(a, 0), 0, a)


@register_decomposition(aten.channel_shuffle)
@out_wrapper()
def channel_shuffle(input: TensorLikeType, groups: int) -> TensorLikeType:
    """
    Reference implementation of :func:`torch.nn.functional.channel_shuffle`.
    """
    from torch._meta_registrations import device_hint

    torch._check(
        input.dim() > 2,
        lambda: f"channel_shuffle expects input with > 2 dims, but got input with sizes {list(input.size())}",
    )
    c = input.shape[1]
    torch._check(
        groups > 0,
        lambda: f"Number of groups to divide channels in must be positive. Value of groups:{groups}",
    )
    torch._check(
        (c % groups) == 0,
        lambda: f"Number of channels must be divisible by groups. Got {c} channels and {groups} groups.",
    )
    n = input.shape[0]
    cg = c // groups
    dhw = input.shape[2:]

    if input.numel() == 0 or (
        device_hint(input) == "cuda" and (groups == 1 or groups == c)
    ):
        return input.view(input.shape)

    return (
        input.reshape(n, groups, cg, *dhw)
        .transpose(1, 2)
        .reshape(input.shape)
        .contiguous()
    )

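# Illustrative usage sketch (assumes eager tensor inputs): with 4 channels and
# 2 groups, the reshape/transpose above turns group-major channel order into
# channel-major order.
#
#   >>> x = torch.arange(4.0).reshape(1, 4, 1, 1)
#   >>> channel_shuffle(x, groups=2).flatten().tolist()
#   [0.0, 2.0, 1.0, 3.0]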

def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Reference implementation of :func:`torch.nn.functional.group_norm`.
    """
    torch._check(
        input.ndim >= 2,
        lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
    )

    batch_size = input.shape[0]
    num_channels = input.shape[1]
    torch._check(
        num_channels % num_groups == 0,
        lambda: "Expected number of channels in input to be divisible by num_groups, "
        + f"but got input of shape {input.shape} and num_groups = {num_groups}",
    )

    # input shape is (N, C, *), so we flatten all inner dimensions except (N, C)
    flattened_inner_size = 1
    for dim_length in input.shape[2:]:
        flattened_inner_size *= dim_length

    return torch.native_group_norm(
        input,
        weight,
        bias,
        batch_size,
        num_channels,
        flattened_inner_size,
        num_groups,
        eps,
    )[0]

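# Illustrative usage sketch (assumes eager tensor inputs): the (N, C, *) input
# is described to native_group_norm by its batch size, channel count, and
# flattened inner size; the normalized output keeps the input shape.
#
#   >>> x = torch.randn(2, 6, 4, 4)
#   >>> group_norm(x, num_groups=3).shape
#   torch.Size([2, 6, 4, 4])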

def layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Reference implementation of :func:`torch.nn.functional.layer_norm`.
    """
    return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0]


@register_decomposition(aten.leaky_relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def leaky_relu(
    a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu
    """

    if inplace:
        raise NotImplementedError

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(negative_slope), python_type):
        msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)
    return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope))


@register_decomposition(aten.mish)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mish
    """

    if inplace:
        raise NotImplementedError
    return a * torch.tanh(torch.nn.functional.softplus(a))


@register_decomposition(aten.selu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu
    """
    if inplace:
        raise NotImplementedError

    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    rhs = alpha * torch.expm1(a)

    return scale * torch.where(a > 0, a, rhs)


# Forwarding alias: the functional variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def softmax(
    a: TensorLikeType,
    dim: Optional[int] = None,
    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # The error is for compat with regular PyTorch, which has this behavior
    # deprecated.  For PrimTorch, it's fine to drop support for deprecated
    # behavior because it requires explicit opt in.  This error is to inform
    # users how to update their calls.
    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
    return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


# CompositeImplicitAutograd - don't register decomp
def softmin(
    a: TensorLikeType,
    dim: Optional[int] = None,
    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # The error is for compat with regular PyTorch, which has this behavior
    # deprecated.  For PrimTorch, it's fine to drop support for deprecated
    # behavior because it requires explicit opt in.  This error is to inform
    # users how to update their calls.
    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
    return torch.softmax(a=-a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


# softplus is implemented specially because it has beta and threshold arguments
@register_decomposition(aten.softplus)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if beta is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), python_type):
            msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        scaled_input = a * beta
        rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta)  # type: ignore[arg-type]

    else:
        scaled_input = a
        rhs = torch.log1p(torch.exp(scaled_input))

    return torch.where(scaled_input > threshold, a, rhs)


@aten.hardshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.hardshrink)
@out_wrapper()
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # hardshrink(x) = x if x > lambd
    #               = x if x < -lambd
    #               = 0 otherwise
    return torch.where(torch.abs(a) <= lambd, 0, a)


@aten.softshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.softshrink)
@out_wrapper()
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #               = x + lambd if x < -lambd
    #               = 0 otherwise
    torch._check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    # We implement this in one torch.where to generate better code in the backward
    # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211
    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0)

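# Illustrative check of the single-`where` form above (assumes eager tensor
# inputs): it matches the piecewise definition in the comment.
#
#   >>> x = torch.tensor([-1.0, -0.2, 0.0, 0.3, 2.0])
#   >>> softshrink(x, lambd=0.5).tolist()
#   [-0.5, 0.0, 0.0, 0.0, 1.5]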

# Losses
def _reduction_int_to_str(reduction: int) -> str:
    from torch._decomp.decompositions import Reduction

    if reduction == Reduction.NONE.value:
        return "none"
    elif reduction == Reduction.MEAN.value:
        return "mean"
    elif reduction == Reduction.SUM.value:
        return "sum"
    else:
        raise ValueError(f"{reduction} is not a valid value for reduction")


def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType:
    if reduction == "sum":
        return torch.sum(loss)
    elif reduction == "mean":
        return torch.mean(loss)
    else:  # reduction == "none"
        return loss


def _check_reduction_value(reduction: str):
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"{reduction} is not a valid value for reduction")


# This helper function maps deprecated arguments, "size_average" and "reduce",
# to their corresponding "reduction" string argument
def _get_string_reduction_arg(
    *, size_average: Optional[bool], reduce: Optional[bool]
) -> str:
    if size_average is None:
        size_average = True
    if reduce is None:
        reduce = True
    if size_average and reduce:
        ret = "mean"
    elif reduce:
        ret = "sum"
    else:
        ret = "none"
    return ret

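# Illustrative mapping implemented above (a None argument is treated as True):
#
#   >>> _get_string_reduction_arg(size_average=None, reduce=None)
#   'mean'
#   >>> _get_string_reduction_arg(size_average=False, reduce=True)
#   'sum'
#   >>> _get_string_reduction_arg(size_average=True, reduce=False)
#   'none'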

# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value.  This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.abs(input - target)
    return _apply_loss_reduction(loss, reduction)


@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def smooth_l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    beta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.smooth_l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value.  This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)

    if beta == 0.0:
        return torch.nn.functional.l1_loss(
            input, target, size_average=size_average, reduce=reduce, reduction=reduction
        )
    else:
        loss = torch.abs(input - target)
        loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
        return _apply_loss_reduction(loss, reduction)


# Forwarding alias: the functional variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def log_softmax(
    a: TensorLikeType,
    dim: Optional[int] = None,
    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # The error is for compat with regular PyTorch, which has this behavior
    # deprecated.  For PrimTorch, it's fine to drop support for deprecated
    # behavior because it requires explicit opt in.  This error is to inform
    # users how to update their calls.
    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
    return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


@register_decomposition(aten.margin_ranking_loss)
def margin_ranking_loss(
    input1: TensorLikeType,
    input2: TensorLikeType,
    target: TensorLikeType,
    margin: float = 0.0,
    reduction: str = "mean",
) -> TensorLikeType:
    # loss_without_reduction = max(0, -target * (input1 - input2) + margin)
    if input1.ndim != input2.ndim or input1.ndim != target.ndim:
        raise RuntimeError(
            "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} "
        )
    _check_reduction_value(reduction)
    loss = torch.clamp_min(-target * (input1 - input2) + margin, 0)
    return _apply_loss_reduction(loss, reduction)


@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def mse_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value.  This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.pow(input - target, 2)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(aten.hinge_embedding_loss)
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    # loss_without_reduction = input if y == 1
    #                        = max(0, margin - input) if y == -1
    _check_reduction_value(reduction)
    margin_clamp = torch.clamp_min(margin - input, 0)
    output_margin = torch.where(target != 1, margin_clamp, 0)
    output_self = torch.where(target != -1, input, 0)
    loss = output_margin + output_self
    return _apply_loss_reduction(loss, reduction)


def _nll_loss_nd(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    torch._check(
        input.ndim > 0 and input.ndim <= 3,
        lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
    )

    torch._check(
        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
    )

    _check_reduction_value(reduction)

    flat_target = torch.flatten(target)
    ignore_classes_mask = torch.eq(flat_target, ignore_index)

    # TODO: Enable data-dependent checks with debug mode
    # TODO: This check does not work with FakeTensor inputs; See Issue #85834
    # Explicit cast for class_check to bool; See Issue #78071
    """
    from torch._subclasses.fake_tensor import FakeTensor
    num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
    valid_classes_mask = torch.logical_and(
        (flat_target >= 0), (flat_target < num_classes)
    )
    class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
    torch._check(
        isinstance(target, FakeTensor) or bool(class_check.item()),
        lambda: "A target class is out-of-bounds and not the ignore index.",
    )
    """

    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
    class_weight = (
        torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
        if weight is None
        else weight[flat_target]
    )
    current_weight = torch.where(
        ignore_classes_mask,
        ignore_class_weight,
        class_weight,
    )

    if input.ndim == 1:
        # implicit batch size = 1
        # input (1 batch size, C classes)
        loss = -input[target] * current_weight
    elif input.ndim == 2:
        # input (N batch size, C classes)
        batch_size = input.shape[0]
        loss = -input[torch.arange(batch_size), target] * current_weight
    else:
        # 3D case (N batch size, C classes, K dimensions)
        # input (N batch size, C classes, K)
        batch_size = input.shape[0]
        extent = input.shape[2]
        numel = batch_size * extent
        indices = torch.arange(numel)
        bdx = indices // extent
        kdx = indices % extent
        loss = -input[bdx, flat_target, kdx] * current_weight
    loss = torch.reshape(loss, target.shape)

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # calculate weighted mean of the loss function
        return torch.sum(loss) / torch.sum(current_weight)

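# Illustrative sketch of the 3-D gather above (assumes eager tensor inputs): for
# input (N, C, K) and target (N, K), entry (n, k) of the loss is
# -input[n, target[n, k], k], which the flat bdx/kdx indexing reproduces and
# which agrees with the eager kernel.
#
#   >>> inp = torch.randn(2, 3, 4).log_softmax(dim=1)
#   >>> tgt = torch.randint(0, 3, (2, 4))
#   >>> ref = torch.nn.functional.nll_loss(inp, tgt, reduction="none")
#   >>> torch.allclose(_nll_loss_nd(inp, tgt, None, "none", -100), ref)
#   True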

@register_decomposition(aten.nll_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.nll_loss
    """
    torch._check(
        input.ndim > 0,
        lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
    )

    # TODO: raise exception instead of converting value
    # msg = "size_average and reduce args are deprecated, please use reduction argument."
    # Convert these options for consistency with the eager mode
    if size_average is not None or reduce is not None:
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # The expected behavior when the target and input have zero elements:
    #   reduction = 'none' --- tensor([])
    #   reduction = 'sum'  --- tensor(0.)
    #   reduction = 'mean' --- tensor(nan)
    # Mean reduction on empty tensors produces NaN. See the discussion in
    # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if input.numel() == 0 and target.numel() == 0:
        if reduction == "none":
            return torch.zeros_like(target)
        elif reduction == "sum":
            return torch.empty_like(target)
        else:
            return torch.full_like(target, float("nan"))

    # The _nll_loss_nd helper function handles the most common cases.
    # ndim == 1 (Single Example)
    #   => Batch Size: 1, Input: (C), Target: ()
    # ndim == 2 (k = 1)
    #   => Batch Size: N, Input: (N, C), Target: (N)
    # ndim == 3 (k > 1)
    #   => Batch Size: N, Input: (N, C, K), Target: (N, K)
    if input.ndim <= 3:
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)

    # For ndim > 3, we reshape the input and target to 3-D case.
    # Input (N batch-size, C classes, k-dimensions)
    # Target (N batch-size, k-dimensions)
    torch._check(
        input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
        lambda: (
            "Expected input and target to both have ndim > 0 and "
            "target.shape[1:] == input.shape[2:], but got "
            f"target.shape {target.shape} and input.shape {input.shape}"
        ),
    )

    batch_size = input.shape[0]
    num_classes = input.shape[1]
    out_size = [batch_size] + list(target.shape[1:])

    input = torch.reshape(input, [batch_size, num_classes, -1])
    target = torch.reshape(target, [batch_size, -1])
    if reduction != "none":
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
    else:
        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
        # reshape flattened inner-dim to original k-dimensions
        return torch.reshape(result, out_size)


# TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
# https://github.com/pytorch/pytorch/issues/83931
# TODO: Could be rewritten to support complex:
# https://github.com/pytorch/pytorch/pull/85041
@register_decomposition(aten.huber_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def huber_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    reduction: Union[str, int] = "mean",
    delta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.huber_loss
    """
    if type(reduction) is int:
        reduction = _reduction_int_to_str(reduction)
    _check_reduction_value(reduction)  # type: ignore[arg-type]
    torch._check(
        delta > 0,
        lambda: "huber_loss does not support non-positive values for delta.",
    )
    z = (input - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return _apply_loss_reduction(loss, reduction)  # type: ignore[arg-type]


# tanhshrink does not use _make_elementwise_unary_reference because it does not support out
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def tanhshrink(a: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.tanhshrink
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    return a - torch.tanh(a)


@register_decomposition(aten.threshold)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def threshold(
    a: TensorLikeType,
    threshold: NumberType,
    value: Union[bool, int, float],
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.threshold
    """

    if inplace:
        raise NotImplementedError

    return torch.where(a <= threshold, value, a)


# CompositeImplicitAutograd - don't register decomp
# No elementwise type promotion - core op doesn't explicitly type promote
def triplet_margin_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    margin: float = 1.0,
    p: float = 2,
    eps: float = 1e-6,
    swap: bool = False,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value.  This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    if margin <= 0:
        raise ValueError(f"margin must be greater than 0, got {margin}")

    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
    # since it's a pure Python implementation.  Use this helper instead.
    return _triplet_margin_with_distance_loss(
        anchor=anchor,
        positive=positive,
        negative=negative,
        distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
        margin=margin,
        swap=swap,
        reduction=reduction,
    )


# Pure Python impl - don't register decomp and don't add a ref.  Defined as a
# helper here since triplet_margin_loss can be nicely implemented with it.
def _triplet_margin_with_distance_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    *,
    distance_function: Optional[
        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
    ] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> TensorLikeType:
    _check_reduction_value(reduction)

    a_dim = anchor.ndim
    p_dim = positive.ndim
    n_dim = negative.ndim
    torch._check(
        a_dim == p_dim and p_dim == n_dim,
        lambda: (
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {a_dim}D, "
            f"positive {p_dim}D, and negative {n_dim}D inputs"
        ),
    )

    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)
    # The distance swap is described in the paper "Learning shallow
    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
    # Riba et al.  If True, and if the positive example is closer to the
    # negative example than the anchor is, swaps the positive example and the
    # anchor in the loss computation.
    if swap:
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)
    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(aten.hardtanh)
@_inplace_wrapper
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def hardtanh(
    a: TensorLikeType,
    min_val: NumberType = -1,
    max_val: NumberType = 1,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.hardtanh
    """
    if inplace:
        raise NotImplementedError
    if utils.is_boolean_dtype(a.dtype):
        raise RuntimeError("Bool inputs not supported for hardtanh")

    # preserve legacy behavior of boundaries not causing type promotion
    if utils.is_integer_dtype(a.dtype):
        min_val = int(min_val)  # type: ignore[arg-type]
        max_val = int(max_val)  # type: ignore[arg-type]
        if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)):
            raise RuntimeError(
                "Cannot do hardtanh on an unsigned type with negative limits"
            )

    if min_val > max_val:  # type: ignore[operator]
        raise ValueError("min_val cannot be greater than max_val")

    return torch.clamp(a, min_val, max_val)  # type: ignore[arg-type]


@register_decomposition(aten.gelu)
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.gelu
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        a_cube = a * a * a
        inner = kBeta * (a + kKappa * a_cube)
        return 0.5 * a * (1 + torch.tanh(inner))
    elif approximate == "none":
        kAlpha = M_SQRT1_2
        return a * 0.5 * (1 + torch.erf(a * kAlpha))
    else:
        raise RuntimeError("approximate argument must be either none or tanh.")

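# Illustrative sketch (assumes eager tensor inputs): the "tanh" branch above is
# the Gaussian-CDF approximation from the GELU paper and stays close to the
# exact erf-based formula for moderate inputs.
#
#   >>> x = torch.linspace(-3, 3, 101)
#   >>> (gelu(x, approximate="tanh") - gelu(x)).abs().max() < 1e-2
#   tensor(True)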

# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def poisson_nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    log_input: bool = True,
    full: bool = False,
    size_average: Optional[bool] = None,
    eps: float = 1e-8,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.poisson_nll_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value.  This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    if log_input:
        loss = torch.exp(input) - target * input
    else:
        loss = input - target * torch.log(input + eps)

    if full:
        stirling_term = (
            target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target)
        )
        # avoid inplace add
        loss = loss + stirling_term.masked_fill(target <= 1, 0)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(aten.prelu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "weight"),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    torch._check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    torch._check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        torch._check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    torch._check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    if a.ndim == 0:
        weight = weight[0] if weight.ndim == 1 else weight
    else:
        weight = prims.broadcast_in_dim(
            weight, a.shape, () if weight.ndim == 0 else (0 if a.ndim == 1 else 1,)
        )

    return torch.where(a > 0, a, a * weight)


@register_decomposition(aten.relu6)
@_inplace_wrapper
@out_wrapper()
def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu6
    """
    if inplace:
        raise NotImplementedError

    # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126
    # It may be better to use clamp here, but we use hardtanh to replicate
    # the behavior of the existing implementation
    return torch.nn.functional.hardtanh(a, 0, 6)


@register_decomposition(aten.glu)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    dim = utils.canonicalize_dims(a.ndim, dim)
    torch._check(
        a.shape[dim] % 2 == 0,
        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    b, c = torch.tensor_split(a, 2, dim)

    return b * torch.sigmoid(c)


@register_decomposition(aten.pairwise_distance)
@out_wrapper()
def pairwise_distance(
    x1: TensorLikeType,
    x2: TensorLikeType,
    p: NumberType = 2.0,
    eps: NumberType = 1e-6,
    keepdim=False,
) -> TensorLikeType:
    return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim)


@register_decomposition(aten.pdist)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
    # For p == 2 we can use an efficient implementation, but other values of p
    # require creating a much bigger tensor for an intermediate step
    if p == 2:
        aTa = torch.mm(a, a.T)
        aTa_diag = torch.diag(aTa)
        t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0))
    else:
        t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)
    i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device)
    return t.flatten().index_select(0, i[0] * t.shape[0] + i[1])

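# Illustrative sketch of the p == 2 fast path above (assumes eager tensor
# inputs): it relies on ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, so only the
# Gram matrix a @ a.T is materialized, and it matches the naive pairwise norms.
#
#   >>> a = torch.randn(5, 3, dtype=torch.float64)
#   >>> full = torch.linalg.vector_norm(a.unsqueeze(1) - a, dim=2)
#   >>> i = torch.triu_indices(5, 5, offset=1)
#   >>> torch.allclose(pdist(a, p=2), full[i[0], i[1]])
#   True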

@register_decomposition(aten.pixel_shuffle)
@out_wrapper()
def pixel_shuffle(self: Tensor, upscale_factor: int):
    torch._check(
        self.dim() >= 3,
        lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim()} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] // upscale_factor**2
    HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor)
    n = len(batch)
    B_dims = range(n)
    C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            C_out,
            upscale_factor,
            upscale_factor,
            self.shape[-2],
            self.shape[-1],
        )
        .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )

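# Illustrative usage sketch (assumes eager tensor inputs): pixel_shuffle moves a
# factor of upscale_factor**2 from the channel dimension into the spatial
# dimensions, e.g. with upscale_factor = 2:
#
#   >>> x = torch.randn(1, 8, 3, 3)
#   >>> pixel_shuffle(x, upscale_factor=2).shape
#   torch.Size([1, 2, 6, 6])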

@register_decomposition(aten.pixel_unshuffle)
@out_wrapper()
def pixel_unshuffle(self: Tensor, downscale_factor: int):
    torch._check(
        self.dim() >= 3,
        lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim()} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] * downscale_factor**2
    HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor)
    n = len(batch)
    B_dims = range(n)
    C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            self.shape[-3],
            HW_out[0],
            downscale_factor,
            HW_out[1],
            downscale_factor,
        )
        .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )


# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg)
celu_ = _make_inplace(celu)
elu_ = _make_inplace(elu)
mish_ = _make_inplace(mish)
selu_ = _make_inplace(selu)
threshold_ = _make_inplace(threshold)