import torch import torch.nn.functional as F from torch.testing._internal.common_nn import wrap_functional """ `sample_functional` is used by `test_cpp_api_parity.py` to test that Python / C++ API parity test harness works for `torch.nn.functional` functions. When `has_parity=true` is passed to `sample_functional`, behavior of `sample_functional` is the same as the C++ equivalent. When `has_parity=false` is passed to `sample_functional`, behavior of `sample_functional` is different from the C++ equivalent. """ def sample_functional(x, has_parity): if has_parity: return x * 2 else: return x * 4 torch.nn.functional.sample_functional = sample_functional SAMPLE_FUNCTIONAL_CPP_SOURCE = """\n namespace torch { namespace nn { namespace functional { struct C10_EXPORT SampleFunctionalFuncOptions { SampleFunctionalFuncOptions(bool has_parity) : has_parity_(has_parity) {} TORCH_ARG(bool, has_parity); }; Tensor sample_functional(Tensor x, SampleFunctionalFuncOptions options) { return x * 2; } } // namespace functional } // namespace nn } // namespace torch """ functional_tests = [ dict( constructor=wrap_functional(F.sample_functional, has_parity=True), cpp_options_args="F::SampleFunctionalFuncOptions(true)", input_size=(1, 2, 3), fullname="sample_functional_has_parity", has_parity=True, ), dict( constructor=wrap_functional(F.sample_functional, has_parity=False), cpp_options_args="F::SampleFunctionalFuncOptions(false)", input_size=(1, 2, 3), fullname="sample_functional_no_parity", has_parity=False, ), # This is to test that setting the `test_cpp_api_parity=False` flag skips # the C++ API parity test accordingly (otherwise this test would run and # throw a parity error). dict( constructor=wrap_functional(F.sample_functional, has_parity=False), cpp_options_args="F::SampleFunctionalFuncOptions(false)", input_size=(1, 2, 3), fullname="sample_functional_THIS_TEST_SHOULD_BE_SKIPPED", test_cpp_api_parity=False, ), ]