# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from functools import wraps
from typing import Callable, Optional, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
from torch._decomp import register_decomposition
from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    NumberType,
    ShapeType,
    TensorLike,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    elementwise_type_promotion_wrapper,
    elementwise_unary_scalar_wrapper,
    out_wrapper,
)
from torch._refs import _make_inplace


__all__ = [
    "alpha_dropout",
    "celu",
    "celu_",
    "channel_shuffle",
    "dropout",
    "elu",
    "elu_",
    "gelu",
    "glu",
    "group_norm",
    "hardshrink",
    "hardtanh",
    "hinge_embedding_loss",
    "huber_loss",
    "l1_loss",
    "layer_norm",
    "leaky_relu",
    "log_softmax",
    "margin_ranking_loss",
    "mish",
    "mish_",
    "mse_loss",
    "nll_loss",
    "pairwise_distance",
    "pdist",
    "poisson_nll_loss",
    "prelu",
    "relu",
    "relu6",
    "selu",
    "selu_",
    "smooth_l1_loss",
    "softmax",
    "softmin",
    "softplus",
    "softshrink",
    "tanhshrink",
    "threshold",
    "threshold_",
    "triplet_margin_loss",
]

Tensor = torch.Tensor
aten = torch._ops.ops.aten
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]


def _dropout_helper(
    self: TensorLikeType,
    val: float,
) -> TensorLikeType:
    """
    Helper function for all dropout-type operators. During training,
    some of the elements of the input tensor are randomly masked.

    Returns a boolean mask tensor in which True marks the elements to keep.
    """

    return (
        refs._uniform_helper(
            self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device
        )
        < val
    )


@register_decomposition(aten.alpha_dropout)
def alpha_dropout(
    self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False
) -> TensorLikeType:
    if inplace:
        raise NotImplementedError

    if not training:
        return self

    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got {p}",
    )

    if p == 1:
        return torch.zeros_like(self)

    if p == 0:
        return self

    dropout_mask = _dropout_helper(self, 1 - p)

    # From the paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf)
    # alpha = - SELU.alpha * SELU.scale, where
    # SELU.alpha = 1.6732632423543772848170429916717 and
    # SELU.scale = 1.0507009873554804934193349852946
    alpha = -1.7580993408473766

    a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p))
    b = torch.logical_not(dropout_mask)
    b = b * (alpha * a) + alpha * a * p
    dropout_mask = a * dropout_mask

    return self * dropout_mask + b


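# Illustrative sketch of the self-normalizing property the constants above are
# chosen for: for a standard normal input, alpha dropout is designed to keep
# the mean near 0 and the variance near 1, e.g.
#
#   x = torch.randn(100_000)
#   y = torch.nn.functional.alpha_dropout(x, p=0.2, training=True)
#   # y.mean() is close to 0 and y.var() is close to 1, up to sampling noise

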
def _inplace_wrapper(fn):
    """
    Given an nn.functional non-linearity, implements its `inplace: bool` argument.
    """

    # nb. We use the same name for the first argument as the unary references do
    @wraps(fn)
    def _fn(a, *args, inplace=False, **kwargs):
        if inplace:
            torch._check(
                "out" not in kwargs,
                lambda: "Cannot set inplace=True and pass out= at the same time",
            )
            return fn(a, *args, inplace=False, out=a, **kwargs)
        else:
            return fn(a, *args, inplace=False, **kwargs)

    return _fn


# celu is implemented specially because it has an alpha argument
# celu is very similar to elu
@register_decomposition(aten.celu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def celu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.celu
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if alpha is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        rhs = alpha * torch.expm1(torch.true_divide(a, alpha))  # type: ignore[arg-type]
    else:
        rhs = torch.expm1(a)

    return torch.where(a > 0, a, rhs)


@_inplace_wrapper
@out_wrapper()
def dropout(
    a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:
    if inplace:
        raise NotImplementedError

    if not training:
        return a

    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got {p}",
    )

    if p == 1:
        return torch.zeros_like(a)

    if p == 0:
        return a

    scale = 1 / (1 - p)
    dropout_mask = _dropout_helper(a, 1 - p)

    return a * dropout_mask * scale


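# Illustrative sketch of the "inverted dropout" convention used above: kept
# elements are scaled by 1 / (1 - p), so the output matches the input in
# expectation, e.g.
#
#   x = torch.ones(1_000_000)
#   y = torch.nn.functional.dropout(x, p=0.25)
#   # roughly 25% of y is zero, and y.mean() is close to 1.0

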
@register_decomposition(aten.elu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def elu(
    a: TensorLikeType,
    alpha: NumberType = 1.0,
    scale: NumberType = 1.0,
    input_scale: NumberType = 1.0,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """
    if inplace:
        raise NotImplementedError

    # nb. This should be factored out into a can_cast aux function
    python_type = utils.dtype_to_type(a.dtype)
    torch._check(
        utils.is_weakly_lesser_type(type(input_scale), python_type),
        lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
    )
    torch._check(
        utils.is_weakly_lesser_type(type(scale), python_type),
        lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
    )
    torch._check(
        utils.is_weakly_lesser_type(type(alpha), python_type),
        lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
    )

    return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale))


@register_decomposition(aten.relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu
    """

    if inplace:
        raise NotImplementedError

    return torch.where(torch.le(a, 0), 0, a)


@register_decomposition(aten.channel_shuffle)
@out_wrapper()
def channel_shuffle(input: TensorLikeType, groups: int) -> TensorLikeType:
    """
    Reference implementation of :func:`torch.nn.functional.channel_shuffle`.
    """
    from torch._meta_registrations import device_hint

    torch._check(
        input.dim() > 2,
        lambda: f"channel_shuffle expects input with > 2 dims, but got input with sizes {list(input.size())}",
    )
    c = input.shape[1]
    torch._check(
        groups > 0,
        lambda: f"Number of groups to divide channels in must be positive. Value of groups:{groups}",
    )
    torch._check(
        (c % groups) == 0,
        lambda: f"Number of channels must be divisible by groups. Got {c} channels and {groups} groups.",
    )
    n = input.shape[0]
    cg = c // groups
    dhw = input.shape[2:]

    if input.numel() == 0 or (
        device_hint(input) == "cuda" and (groups == 1 or groups == c)
    ):
        return input.view(input.shape)

    return (
        input.reshape(n, groups, cg, *dhw)
        .transpose(1, 2)
        .reshape(input.shape)
        .contiguous()
    )


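# Small worked example of the reshape/transpose/reshape above: with 4 channels
# and groups=2, channels [c0, c1, c2, c3] are regrouped as [c0, c2, c1, c3].
#
#   x = torch.arange(4.0).reshape(1, 4, 1, 1)
#   torch.nn.functional.channel_shuffle(x, 2).flatten()  # tensor([0., 2., 1., 3.])

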
def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Reference implementation of :func:`torch.nn.functional.group_norm`.
    """
    torch._check(
        input.ndim >= 2,
        lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
    )

    batch_size = input.shape[0]
    num_channels = input.shape[1]
    torch._check(
        num_channels % num_groups == 0,
        lambda: "Expected number of channels in input to be divisible by num_groups, "
        + f"but got input of shape {input.shape} and num_groups = {num_groups}",
    )

    # input shape is (N, C, *), so we flatten all inner dimensions except (N, C)
    flattened_inner_size = 1
    for dim_length in input.shape[2:]:
        flattened_inner_size *= dim_length

    return torch.native_group_norm(
        input,
        weight,
        bias,
        batch_size,
        num_channels,
        flattened_inner_size,
        num_groups,
        eps,
    )[0]


def layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Reference implementation of :func:`torch.nn.functional.layer_norm`.
    """
    return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0]


@register_decomposition(aten.leaky_relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def leaky_relu(
    a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu
    """

    if inplace:
        raise NotImplementedError

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(negative_slope), python_type):
        msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)
    return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope))


@register_decomposition(aten.mish)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mish
    """

    if inplace:
        raise NotImplementedError
    return a * torch.tanh(torch.nn.functional.softplus(a))


@register_decomposition(aten.selu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu
    """
    if inplace:
        raise NotImplementedError

    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    rhs = alpha * torch.expm1(a)

    return scale * torch.where(a > 0, a, rhs)


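# For reference, selu is elu with the fixed constants above baked in, so (up to
# floating-point error) the following two expressions should agree:
#
#   x = torch.randn(8)
#   torch.nn.functional.selu(x)
#   1.0507009873554805 * torch.nn.functional.elu(x, alpha=1.6732632423543772)

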
# Forwarding alias: the functional variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def softmax(
    a: TensorLikeType,
    dim: Optional[int] = None,
    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # The error is for compat with regular PyTorch, which has this behavior
    # deprecated. For PrimTorch, it's fine to drop support for deprecated
    # behavior because it requires explicit opt in. This error is to inform
    # users how to update their calls.
    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
    return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


# CompositeImplicitAutograd - don't register decomp
def softmin(
    a: TensorLikeType,
    dim: Optional[int] = None,
    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # The error is for compat with regular PyTorch, which has this behavior
    # deprecated. For PrimTorch, it's fine to drop support for deprecated
    # behavior because it requires explicit opt in. This error is to inform
    # users how to update their calls.
    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
    return torch.softmax(a=-a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


# softplus is implemented specially because it has beta and threshold arguments
@register_decomposition(aten.softplus)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus
    """

    if inplace:
        raise NotImplementedError

    rhs: TensorLikeType
    if beta is not None:
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), python_type):
            msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        scaled_input = a * beta
        rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta)  # type: ignore[arg-type]
    else:
        scaled_input = a
        rhs = torch.log1p(torch.exp(scaled_input))

    return torch.where(scaled_input > threshold, a, rhs)


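# Note on the `threshold` branch above: once scaled_input exceeds the threshold,
# log1p(exp(scaled_input)) / beta is numerically indistinguishable from the
# input, so the reference reverts to a linear function, e.g.
#
#   x = torch.tensor([0.5, 30.0])
#   torch.nn.functional.softplus(x)  # ~ tensor([0.9741, 30.0000])

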
@aten.hardshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.hardshrink)
@out_wrapper()
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # hardshrink(x) = x if x > lambd
    #               = x if x < -lambd
    #               = 0 otherwise
    return torch.where(torch.abs(a) <= lambd, 0, a)


@aten.softshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.softshrink)
@out_wrapper()
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #               = x + lambd if x < -lambd
    #               = 0 otherwise
    torch._check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    # We implement this in one torch.where to generate better code in the backward
    # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211
    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0)


# Losses
def _reduction_int_to_str(reduction: int) -> str:
    from torch._decomp.decompositions import Reduction

    if reduction == Reduction.NONE.value:
        return "none"
    elif reduction == Reduction.MEAN.value:
        return "mean"
    elif reduction == Reduction.SUM.value:
        return "sum"
    else:
        raise ValueError(f"{reduction} is not a valid value for reduction")


def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType:
    if reduction == "sum":
        return torch.sum(loss)
    elif reduction == "mean":
        return torch.mean(loss)
    else:  # reduction == "none"
        return loss


def _check_reduction_value(reduction: str):
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"{reduction} is not a valid value for reduction")


# This helper function maps the deprecated arguments "size_average" and "reduce"
# to their corresponding "reduction" string argument
def _get_string_reduction_arg(
    *, size_average: Optional[bool], reduce: Optional[bool]
) -> str:
    if size_average is None:
        size_average = True
    if reduce is None:
        reduce = True
    if size_average and reduce:
        ret = "mean"
    elif reduce:
        ret = "sum"
    else:
        ret = "none"
    return ret


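# For reference, a few illustrative mappings from the legacy flags to the
# `reduction` string (None defaults to True for both flags):
#
#   _get_string_reduction_arg(size_average=None,  reduce=None)   # "mean"
#   _get_string_reduction_arg(size_average=False, reduce=True)   # "sum"
#   _get_string_reduction_arg(size_average=True,  reduce=False)  # "none"

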
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.abs(input - target)
    return _apply_loss_reduction(loss, reduction)


@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def smooth_l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    beta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.smooth_l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)

    if beta == 0.0:
        return torch.nn.functional.l1_loss(
            input, target, size_average=size_average, reduce=reduce, reduction=reduction
        )
    else:
        loss = torch.abs(input - target)
        loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
        return _apply_loss_reduction(loss, reduction)


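# Worked example of the piecewise branch in smooth_l1_loss with beta=1.0: an
# absolute error of 0.5 lands in the quadratic region (0.5 * 0.5**2 = 0.125),
# while an error of 2.0 lands in the linear region (2.0 - 0.5 = 1.5).
#
#   input = torch.tensor([0.5, 2.0])
#   target = torch.zeros(2)
#   torch.nn.functional.smooth_l1_loss(input, target, reduction="none")
#   # tensor([0.1250, 1.5000])

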
# Forwarding alias: the functional variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def log_softmax(
    a: TensorLikeType,
    dim: Optional[int] = None,
    _stacklevel: int = 3,  # for compat when using TorchRefsMode(strict=True)
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # The error is for compat with regular PyTorch, which has this behavior
    # deprecated. For PrimTorch, it's fine to drop support for deprecated
    # behavior because it requires explicit opt in. This error is to inform
    # users how to update their calls.
    torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
    return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


@register_decomposition(aten.margin_ranking_loss)
def margin_ranking_loss(
    input1: TensorLikeType,
    input2: TensorLikeType,
    target: TensorLikeType,
    margin: float = 0.0,
    reduction: str = "mean",
) -> TensorLikeType:
    # loss_without_reduction = max(0, -target * (input1 - input2) + margin)
    if input1.ndim != input2.ndim or input1.ndim != target.ndim:
        raise RuntimeError(
            "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} "
        )
    _check_reduction_value(reduction)
    loss = torch.clamp_min(-target * (input1 - input2) + margin, 0)
    return _apply_loss_reduction(loss, reduction)


@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def mse_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = torch.pow(input - target, 2)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(aten.hinge_embedding_loss)
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    # loss_without_reduction = input if y == 1
    #                        = max(0, margin - input) if y == -1
    _check_reduction_value(reduction)
    margin_clamp = torch.clamp_min(margin - input, 0)
    output_margin = torch.where(target != 1, margin_clamp, 0)
    output_self = torch.where(target != -1, input, 0)
    loss = output_margin + output_self
    return _apply_loss_reduction(loss, reduction)


def _nll_loss_nd(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    torch._check(
        input.ndim > 0 and input.ndim <= 3,
        lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
    )

    torch._check(
        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
    )

    _check_reduction_value(reduction)

    flat_target = torch.flatten(target)
    ignore_classes_mask = torch.eq(flat_target, ignore_index)

    # TODO: Enable data-dependent checks with debug mode
    # TODO: This check does not work with FakeTensor inputs; See Issue #85834
    # Explicit cast for class_check to bool; See Issue #78071
    """
    from torch._subclasses.fake_tensor import FakeTensor
    num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
    valid_classes_mask = torch.logical_and(
        (flat_target >= 0), (flat_target < num_classes)
    )
    class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
    torch._check(
        isinstance(target, FakeTensor) or bool(class_check.item()),
        lambda: "A target class is out-of-bounds and not the ignore index.",
    )
    """

    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
    class_weight = (
        torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
        if weight is None
        else weight[flat_target]
    )
    current_weight = torch.where(
        ignore_classes_mask,
        ignore_class_weight,
        class_weight,
    )

    if input.ndim == 1:
        # implicit batch size = 1
        # input (1 batch size, C classes)
        loss = -input[target] * current_weight
    elif input.ndim == 2:
        # input (N batch size, C classes)
        batch_size = input.shape[0]
        loss = -input[torch.arange(batch_size), target] * current_weight
    else:
        # 3D case (N batch size, C classes, K dimensions)
        # input (N batch size, C classes, K)
        batch_size = input.shape[0]
        extent = input.shape[2]
        numel = batch_size * extent
        indices = torch.arange(numel)
        bdx = indices // extent
        kdx = indices % extent
        loss = -input[bdx, flat_target, kdx] * current_weight
        loss = torch.reshape(loss, target.shape)

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # calculate the weighted mean of the loss function
        return torch.sum(loss) / torch.sum(current_weight)


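# The indexing above just selects, for every (batch, k) position, the input
# value at the target class, e.g. in the 2D case (ignoring weights):
#
#   logp = torch.log_softmax(torch.randn(2, 3), dim=1)
#   target = torch.tensor([2, 0])
#   torch.nn.functional.nll_loss(logp, target, reduction="none")
#   # equals -logp[torch.arange(2), target]

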
@register_decomposition(aten.nll_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.nll_loss
    """
    torch._check(
        input.ndim > 0,
        lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
    )

    # TODO: raise exception instead of converting value
    # msg = "size_average and reduce args are deprecated, please use reduction argument."
    # Convert these options for consistency with the eager mode
    if size_average is not None or reduce is not None:
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # The expected behavior when the target and input have zero elements:
    # reduction = 'none' --- tensor([])
    # reduction = 'sum' --- tensor(0.)
    # reduction = 'mean' --- tensor(nan)
    # Mean reduction on empty tensors produces NaN. See the discussion in
    # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if input.numel() == 0 and target.numel() == 0:
        if reduction == "none":
            return torch.zeros_like(target)
        elif reduction == "sum":
            return torch.empty_like(target)
        else:
            return torch.full_like(target, float("nan"))

    # The _nll_loss_nd helper function handles the most common cases.
    # ndim == 1 (Single Example)
    #   => Batch Size: 1, Input: (C), Target: ()
    # ndim == 2 (k = 1)
    #   => Batch Size: N, Input: (N, C), Target: (N)
    # ndim == 3 (k > 1)
    #   => Batch Size: N, Input: (N, C, K), Target: (N, K)
    if input.ndim <= 3:
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)

    # For ndim > 3, we reshape the input and target to the 3-D case.
    # Input (N batch-size, C classes, k-dimensions)
    # Target (N batch-size, k-dimensions)
    torch._check(
        input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
        lambda: (
            "Expected input and target to both have ndim > 0 and "
            "target.shape[1:] == input.shape[2:], but got "
            f"target.shape {target.shape} and input.shape {input.shape}"
        ),
    )

    batch_size = input.shape[0]
    num_classes = input.shape[1]
    out_size = [batch_size] + list(target.shape[1:])

    input = torch.reshape(input, [batch_size, num_classes, -1])
    target = torch.reshape(target, [batch_size, -1])
    if reduction != "none":
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
    else:
        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
        # reshape the flattened inner dim back to the original k-dimensions
        return torch.reshape(result, out_size)


# TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
# https://github.com/pytorch/pytorch/issues/83931
# TODO: Could be rewritten to support complex:
# https://github.com/pytorch/pytorch/pull/85041
@register_decomposition(aten.huber_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def huber_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    reduction: Union[str, int] = "mean",
    delta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.huber_loss
    """
    if type(reduction) is int:
        reduction = _reduction_int_to_str(reduction)
    _check_reduction_value(reduction)  # type: ignore[arg-type]
    torch._check(
        delta > 0,
        lambda: "huber_loss does not support non-positive values for delta.",
    )
    z = (input - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return _apply_loss_reduction(loss, reduction)  # type: ignore[arg-type]


# tanhshrink does not use _make_elementwise_unary_reference because it does not support out
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def tanhshrink(a: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.tanhshrink
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    return a - torch.tanh(a)


@register_decomposition(aten.threshold)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def threshold(
    a: TensorLikeType,
    threshold: NumberType,
    value: Union[bool, int, float],
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.threshold
    """

    if inplace:
        raise NotImplementedError

    return torch.where(a <= threshold, value, a)


# CompositeImplicitAutograd - don't register decomp
# No elementwise type promotion - core op doesn't explicitly type promote
def triplet_margin_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    margin: float = 1.0,
    p: float = 2,
    eps: float = 1e-6,
    swap: bool = False,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    if margin <= 0:
        raise ValueError(f"margin must be greater than 0, got {margin}")

    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
    # since it's a pure Python implementation. Use this helper instead.
    return _triplet_margin_with_distance_loss(
        anchor=anchor,
        positive=positive,
        negative=negative,
        distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
        margin=margin,
        swap=swap,
        reduction=reduction,
    )


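# Quick sanity check of the margin formulation used above: with an anchor
# equidistant from the positive and the negative, the per-sample loss is
# exactly the margin.
#
#   a = torch.zeros(1, 2)
#   p = torch.tensor([[1.0, 0.0]])
#   n = torch.tensor([[0.0, 1.0]])
#   torch.nn.functional.triplet_margin_loss(a, p, n, margin=1.0, reduction="none")
#   # tensor([1.])

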
# Pure Python impl - don't register decomp and don't add a ref. Defined as a
# helper here since triplet_margin_loss can be nicely implemented with it.
def _triplet_margin_with_distance_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    *,
    distance_function: Optional[
        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
    ] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> TensorLikeType:
    _check_reduction_value(reduction)

    a_dim = anchor.ndim
    p_dim = positive.ndim
    n_dim = negative.ndim
    torch._check(
        a_dim == p_dim and p_dim == n_dim,
        lambda: (
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {a_dim}D, "
            f"positive {p_dim}D, and negative {n_dim}D inputs"
        ),
    )

    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)
    # The distance swap is described in the paper "Learning shallow
    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
    # Riba et al. If True, and if the positive example is closer to the
    # negative example than the anchor is, swaps the positive example and the
    # anchor in the loss computation.
    if swap:
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)
    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(aten.hardtanh)
@_inplace_wrapper
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def hardtanh(
    a: TensorLikeType,
    min_val: NumberType = -1,
    max_val: NumberType = 1,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.hardtanh
    """
    if inplace:
        raise NotImplementedError
    if utils.is_boolean_dtype(a.dtype):
        raise RuntimeError("Bool inputs not supported for hardtanh")

    # preserve legacy behavior of boundaries not causing type promotion
    if utils.is_integer_dtype(a.dtype):
        min_val = int(min_val)  # type: ignore[arg-type]
        max_val = int(max_val)  # type: ignore[arg-type]
        if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)):
            raise RuntimeError(
                "Cannot do hardtanh on an unsigned type with negative limits"
            )

    if min_val > max_val:  # type: ignore[operator]
        raise ValueError("min_val cannot be greater than max_val")

    return torch.clamp(a, min_val, max_val)  # type: ignore[arg-type]


@register_decomposition(aten.gelu)
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.gelu
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        a_cube = a * a * a
        inner = kBeta * (a + kKappa * a_cube)
        return 0.5 * a * (1 + torch.tanh(inner))
    elif approximate == "none":
        kAlpha = M_SQRT1_2
        return a * 0.5 * (1 + torch.erf(a * kAlpha))
    else:
        raise RuntimeError("approximate argument must be either none or tanh.")


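# The two branches above are close but not identical; the tanh form is an
# approximation to the exact erf-based gelu, e.g.
#
#   x = torch.linspace(-3, 3, 7)
#   exact = torch.nn.functional.gelu(x)
#   approx = torch.nn.functional.gelu(x, approximate="tanh")
#   # (exact - approx).abs().max() stays below 1e-3 on this grid

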
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def poisson_nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    log_input: bool = True,
    full: bool = False,
    size_average: Optional[bool] = None,
    eps: float = 1e-8,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.poisson_nll_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    if log_input:
        loss = torch.exp(input) - target * input
    else:
        loss = input - target * torch.log(input + eps)

    if full:
        stirling_term = (
            target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target)
        )
        # avoid inplace add
        loss = loss + stirling_term.masked_fill(target <= 1, 0)
    return _apply_loss_reduction(loss, reduction)


@register_decomposition(aten.prelu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "weight"),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    torch._check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    torch._check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        torch._check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    torch._check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    if a.ndim == 0:
        weight = weight[0] if weight.ndim == 1 else weight
    else:
        weight = prims.broadcast_in_dim(
            weight, a.shape, () if weight.ndim == 0 else (0 if a.ndim == 1 else 1,)
        )

    return torch.where(a > 0, a, a * weight)


@register_decomposition(aten.relu6)
@_inplace_wrapper
@out_wrapper()
def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu6
    """
    if inplace:
        raise NotImplementedError

    # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126
    # It may be better to use clamp here, but we use hardtanh to replicate
    # the behavior of the existing implementation
    return torch.nn.functional.hardtanh(a, 0, 6)


@register_decomposition(aten.glu)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    dim = utils.canonicalize_dims(a.ndim, dim)
    torch._check(
        a.shape[dim] % 2 == 0,
        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    b, c = torch.tensor_split(a, 2, dim)

    return b * torch.sigmoid(c)


@register_decomposition(aten.pairwise_distance)
@out_wrapper()
def pairwise_distance(
    x1: TensorLikeType,
    x2: TensorLikeType,
    p: NumberType = 2.0,
    eps: NumberType = 1e-6,
    keepdim=False,
) -> TensorLikeType:
    return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim)


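# Minimal example of the distance above (note that eps is added to the
# difference before taking the norm, so the result is only approximately the
# plain p-norm):
#
#   x1 = torch.tensor([[0.0, 0.0]])
#   x2 = torch.tensor([[3.0, 4.0]])
#   torch.nn.functional.pairwise_distance(x1, x2)  # ~ tensor([5.0000])

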
@register_decomposition(aten.pdist)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
    # For p == 2 we can use an efficient implementation, but other values of p
    # require creating a much bigger tensor for an intermediate step
    if p == 2:
        aTa = torch.mm(a, a.T)
        aTa_diag = torch.diag(aTa)
        t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0))
    else:
        t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)
    i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device)
    return t.flatten().index_select(0, i[0] * t.shape[0] + i[1])


@register_decomposition(aten.pixel_shuffle)
@out_wrapper()
def pixel_shuffle(self: Tensor, upscale_factor: int):
    torch._check(
        self.dim() >= 3,
        lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim()} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] // upscale_factor**2
    HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor)
    n = len(batch)
    B_dims = range(n)
    C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            C_out,
            upscale_factor,
            upscale_factor,
            self.shape[-2],
            self.shape[-1],
        )
        .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )


@register_decomposition(aten.pixel_unshuffle)
@out_wrapper()
def pixel_unshuffle(self: Tensor, downscale_factor: int):
    torch._check(
        self.dim() >= 3,
        lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim()} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] * downscale_factor**2
    HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor)
    n = len(batch)
    B_dims = range(n)
    C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            self.shape[-3],
            HW_out[0],
            downscale_factor,
            HW_out[1],
            downscale_factor,
        )
        .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )


# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg)
celu_ = _make_inplace(celu)
elu_ = _make_inplace(elu)
mish_ = _make_inplace(mish)
selu_ = _make_inplace(selu)
threshold_ = _make_inplace(threshold)


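# Round-trip sanity check for the pixel_shuffle / pixel_unshuffle pair defined
# above: unshuffling with the same factor recovers the original tensor exactly,
# since both ops are pure permutations of the data.
#
#   x = torch.randn(1, 8, 4, 4)
#   y = torch.nn.functional.pixel_shuffle(x, 2)                 # shape (1, 2, 8, 8)
#   torch.equal(torch.nn.functional.pixel_unshuffle(y, 2), x)   # True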