# mypy: ignore-errors import unittest from functools import partial from itertools import product from typing import List import numpy as np import torch from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( precisionOverride, tol, toleranceOverride, ) from torch.testing._internal.common_dtype import all_types_and, floating_types from torch.testing._internal.common_utils import ( TEST_SCIPY, TEST_WITH_ROCM, torch_to_numpy_dtype_dict, ) from torch.testing._internal.opinfo.core import ( BinaryUfuncInfo, DecorateInfo, L, NumericsFilter, OpInfo, S, SampleInput, UnaryUfuncInfo, ) from torch.testing._internal.opinfo.refs import ( ElementwiseBinaryPythonRefInfo, ElementwiseUnaryPythonRefInfo, ) from torch.testing._internal.opinfo.utils import ( np_unary_ufunc_integer_promotion_wrapper, ) if TEST_SCIPY: import scipy.special # TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`, # supports `exclude` argument. # For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617 def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs): exclude_zero = requires_grad and op_info.op == torch.special.i0e make_arg = partial( make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=exclude_zero, ) yield SampleInput(make_arg((S,))) yield SampleInput(make_arg(())) if requires_grad and not exclude_zero: # Special Case for gradient # Sample with `0` in the input t = make_arg((S,)) t[0] = 0 yield SampleInput(t) def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial( make_tensor, device=device, # TODO: eliminate low after gh-106692 is fixed: low=(1 if dtype in {torch.int32, torch.int64} else None), dtype=dtype, requires_grad=requires_grad, ) tensor_shapes = ((S, S), ()) ns = (1, 2, 3, 4, 5) for shape, n in product(tensor_shapes, ns): yield SampleInput(make_arg(shape), args=(n,)) def reference_polygamma(x, n): # WEIRD `scipy.special.polygamma` behavior # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype # dtype('float64') # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype # dtype('float32') # # Thus we cast output to the default torch dtype or preserve double result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] if x.dtype == np.double: result_dtype = np.double return scipy.special.polygamma(n, x).astype(result_dtype) def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs): low, _ = op_info.domain if requires_grad: low = 0 + op_info._domain_eps make_arg = partial( make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad ) yield SampleInput(make_arg((L,))) yield SampleInput(make_arg(())) def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): for shape in ((L,), (1, 0, 3), ()): yield SampleInput( make_tensor( shape, device=device, dtype=dtype, low=-5, requires_grad=requires_grad, ), ) op_db: List[OpInfo] = [ UnaryUfuncInfo( "special.i0e", aten_name="special_i0e", ref=scipy.special.i0e if TEST_SCIPY else None, decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), backward_dtypes=floating_types(), sample_inputs_func=sample_inputs_i0_i1, supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), UnaryUfuncInfo( "special.i1", aten_name="special_i1", ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1) if TEST_SCIPY else None, dtypes=all_types_and(torch.bool), dtypesIfCUDA=all_types_and(torch.bool), sample_inputs_func=sample_inputs_i0_i1, decorators=( DecorateInfo( toleranceOverride( { torch.float32: tol(atol=1e-4, rtol=0), torch.bool: tol(atol=1e-4, rtol=0), } ) ), ), skips=( DecorateInfo( unittest.skip("Incorrect result!"), "TestUnaryUfuncs", "test_reference_numerics_large", dtypes=(torch.int8,), ), ), supports_fwgrad_bwgrad=True, supports_forward_ad=True, ), UnaryUfuncInfo( "special.i1e", aten_name="special_i1e", ref=scipy.special.i1e if TEST_SCIPY else None, dtypes=all_types_and(torch.bool), dtypesIfCUDA=all_types_and(torch.bool), sample_inputs_func=sample_inputs_i0_i1, supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), UnaryUfuncInfo( "special.ndtr", aten_name="special_ndtr", decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),), ref=scipy.special.ndtr if TEST_SCIPY else None, dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( # Dispatch stub: unsupported device typemeta DecorateInfo( unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad", device_type="meta", ), ), ), # A separate OpInfo entry for special.polygamma is needed to reorder the arguments # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939 UnaryUfuncInfo( "special.polygamma", op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs), variant_test_name="special_polygamma_n_0", ref=reference_polygamma if TEST_SCIPY else None, dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_polygamma, skips=( # lambda impl DecorateInfo( unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" ), DecorateInfo( unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive", ), ), sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}), # polygamma functions have multiple singularities at x having non-positive integer value reference_numerics_filter=NumericsFilter( condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1 ), ), BinaryUfuncInfo( "special.xlog1py", aten_name="special_xlog1py", dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), promotes_int_to_float=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_one_python_scalar=True, # We don't test -1 as the gradient will be NaN and it'll break rhs_make_tensor_kwargs=dict(low=-0.99), ), BinaryUfuncInfo( "special.zeta", aten_name="special_zeta", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, supports_autograd=False, supports_one_python_scalar=True, skips=( # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), ), # TODO: FIXME # OpInfo entry to verify the gradient formula of `other`/`q` # BinaryUfuncInfo('special.zeta', # op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs), # aten_name='special_zeta', # variant_test_name='grad', # dtypes=all_types_and(torch.bool), # promotes_int_to_float=True, # supports_autograd=True, # supports_rhs_python_scalar=False, # decorators=[ # # Derivative wrt first tensor not implemented # DecorateInfo(unittest.expectedFailure, "TestCommon", # "test_floating_inputs_are_differentiable") # ], # skips=( # # Lambda doesn't work in JIT test # # AssertionError: JIT Test does not execute any logic # DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"), # )), UnaryUfuncInfo( "special.entr", ref=scipy.special.entr if TEST_SCIPY else None, aten_name="special_entr", supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),), dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), skips=( DecorateInfo( unittest.skip("Skipped!"), "TestUnaryUfuncs", "test_reference_numerics_large", dtypes=[torch.bfloat16, torch.float16], ), ), supports_inplace_autograd=False, sample_inputs_func=sample_inputs_entr, ), UnaryUfuncInfo( "special.ndtri", ref=scipy.special.ndtri if TEST_SCIPY else None, domain=(0, 1), aten_name="special_ndtri", dtypes=all_types_and(torch.bool), supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), UnaryUfuncInfo( "special.log_ndtr", aten_name="special_log_ndtr", ref=scipy.special.log_ndtr if TEST_SCIPY else None, dtypes=all_types_and(torch.bool), supports_forward_ad=True, supports_fwgrad_bwgrad=True, ), UnaryUfuncInfo( "special.erfcx", ref=scipy.special.erfcx if TEST_SCIPY else None, aten_name="special_erfcx", decorators=( toleranceOverride( { torch.float32: tol(atol=0, rtol=4e-6), } ), ), dtypes=all_types_and(torch.bool), supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_erfcx, ), UnaryUfuncInfo( "special.airy_ai", decorators=( precisionOverride( { torch.float32: 1e-03, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=lambda x: scipy.special.airy(x)[0] if TEST_SCIPY else None, skips=( DecorateInfo( unittest.skip("Skipped!"), "TestUnaryUfuncs", "test_reference_numerics_large", ), ), supports_autograd=False, ), UnaryUfuncInfo( "special.bessel_j0", decorators=( precisionOverride( { torch.float32: 1e-04, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.j0 if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.bessel_j1", decorators=( precisionOverride( { torch.float32: 1e-04, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.j1 if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.bessel_y0", decorators=( precisionOverride( { torch.float32: 1e-04, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.y0 if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.bessel_y1", decorators=( precisionOverride( { torch.float32: 1e-04, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.y1 if TEST_SCIPY else None, supports_autograd=False, ), BinaryUfuncInfo( "special.chebyshev_polynomial_t", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.chebyshev_polynomial_u", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.chebyshev_polynomial_v", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo( unittest.skip( "Skipping - testing takes an unreasonably long time, #79528" ) ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.chebyshev_polynomial_w", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo( unittest.skip( "Skipping - testing takes an unreasonably long time, #79528" ) ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.hermite_polynomial_h", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: inf DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), DecorateInfo(unittest.skip("Hangs on ROCm 6.1"), active_if=TEST_WITH_ROCM), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.hermite_polynomial_he", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.laguerre_polynomial_l", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.legendre_polynomial_p", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo( unittest.skip( "Skipping - testing takes an unreasonably long time, #79528" ) ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), UnaryUfuncInfo( "special.modified_bessel_i0", decorators=( precisionOverride( { torch.float32: 1e-03, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.i0 if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.modified_bessel_i1", decorators=( precisionOverride( { torch.float32: 1e-03, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.i1 if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.modified_bessel_k0", decorators=( precisionOverride( { torch.float32: 1e-03, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.k0 if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.modified_bessel_k1", decorators=( precisionOverride( { torch.float32: 1e-03, torch.float64: 1e-05, }, ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.k1 if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.scaled_modified_bessel_k0", decorators=( toleranceOverride( { torch.float32: tol(atol=1e-03, rtol=1e-03), torch.float64: tol(atol=1e-05, rtol=1e-03), } ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.k0e if TEST_SCIPY else None, supports_autograd=False, ), UnaryUfuncInfo( "special.scaled_modified_bessel_k1", decorators=( toleranceOverride( { torch.float32: tol(atol=1e-03, rtol=1e-03), torch.float64: tol(atol=1e-05, rtol=1e-03), } ), ), dtypes=all_types_and(torch.bool), ref=scipy.special.k1e if TEST_SCIPY else None, supports_autograd=False, ), BinaryUfuncInfo( "special.shifted_chebyshev_polynomial_t", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo( unittest.skip( "Skipping - testing takes an unreasonably long time, #79528" ) ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.shifted_chebyshev_polynomial_u", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo( unittest.skip( "Skipping - testing takes an unreasonably long time, #79528" ) ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.shifted_chebyshev_polynomial_v", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo( unittest.skip( "Skipping - testing takes an unreasonably long time, #79528" ) ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), BinaryUfuncInfo( "special.shifted_chebyshev_polynomial_w", dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( DecorateInfo( unittest.skip( "Skipping - testing takes an unreasonably long time, #79528" ) ), DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo( unittest.skip("testing takes an unreasonably long time, #79528"), "TestCommon", "test_compare_cpu", ), ), supports_one_python_scalar=True, supports_autograd=False, ), UnaryUfuncInfo( "special.spherical_bessel_j0", decorators=( toleranceOverride( { torch.float32: tol(atol=1e-03, rtol=1e-03), torch.float64: tol(atol=1e-05, rtol=1e-03), } ), ), dtypes=all_types_and(torch.bool), ref=lambda x: scipy.special.spherical_jn(0, x) if TEST_SCIPY else None, supports_autograd=False, ), ] python_ref_db: List[OpInfo] = [ # # Elementwise Unary Special OpInfos # ElementwiseUnaryPythonRefInfo( "_refs.special.bessel_j0", torch_opinfo_name="special.bessel_j0", op_db=op_db, decorators=( precisionOverride( { torch.float32: 1e-04, torch.float64: 1e-05, }, ), ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.bessel_j1", torch_opinfo_name="special.bessel_j1", op_db=op_db, decorators=( precisionOverride( { torch.float32: 1e-04, torch.float64: 1e-05, }, ), ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.entr", torch_opinfo_name="special.entr", op_db=op_db, decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),), skips=( DecorateInfo( unittest.skip("Skipped!"), "TestUnaryUfuncs", "test_reference_numerics_large", dtypes=[torch.bfloat16, torch.float16], ), ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.erfcx", torch_opinfo_name="special.erfcx", op_db=op_db, decorators=( toleranceOverride( { torch.float32: tol(atol=0, rtol=4e-6), } ), ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.i0e", torch_opinfo_name="special.i0e", op_db=op_db, decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), ), ElementwiseUnaryPythonRefInfo( "_refs.special.i1", torch_opinfo_name="special.i1", op_db=op_db, decorators=( DecorateInfo( toleranceOverride( { torch.float32: tol(atol=1e-4, rtol=0), torch.bool: tol(atol=1e-4, rtol=0), } ) ), ), skips=( DecorateInfo( unittest.skip("Incorrect result!"), "TestUnaryUfuncs", "test_reference_numerics_large", dtypes=(torch.int8,), ), ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.i1e", torch_opinfo_name="special.i1e", op_db=op_db, ), ElementwiseUnaryPythonRefInfo( "_refs.special.log_ndtr", torch_opinfo_name="special.log_ndtr", op_db=op_db, ), ElementwiseUnaryPythonRefInfo( "_refs.special.ndtr", torch_opinfo_name="special.ndtr", op_db=op_db, ), ElementwiseUnaryPythonRefInfo( "_refs.special.ndtri", torch_opinfo_name="special.ndtri", op_db=op_db, ), ElementwiseUnaryPythonRefInfo( "_refs.special.spherical_bessel_j0", torch_opinfo_name="special.spherical_bessel_j0", op_db=op_db, decorators=( toleranceOverride( { torch.float32: tol(atol=1e-03, rtol=1e-03), torch.float64: tol(atol=1e-05, rtol=1e-03), } ), ), ), # # Elementwise Binary Special OpInfos # ElementwiseBinaryPythonRefInfo( "_refs.special.zeta", torch_opinfo_name="special.zeta", supports_one_python_scalar=True, op_db=op_db, skips=( # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), ), ), ]