# The purpose of this test is to check that we have implementation parity between
# a Python `torch.nn.functional` function and its corresponding C++ `torch::nn::functional`
# function. Concretely, this test does the following:
#
# 1. Get a test params dict from common_nn.py, run the forward pass on the Python
# functional created using the test params.
#
# 2. Serialize the Python functional's forward input arguments, deserialize them
# in C++ and use them as input for the C++ functional's forward pass.
#
# 3. Run the forward pass on the C++ functional, and serialize the C++ functional's
# forward output.
#
# 4. Compare the Python/C++ functionals' forward outputs. If they are the same, then
# we have implementation parity between the Python/C++ functional.

import os
import pprint
import re
import tempfile
from string import Template

import torch
from cpp_api_parity.sample_functional import SAMPLE_FUNCTIONAL_CPP_SOURCE
from cpp_api_parity.utils import (
    add_test,
    compile_cpp_code_inline,
    compute_arg_dict,
    compute_cpp_args_construction_stmts_and_forward_arg_symbols,
    compute_temp_file_path,
    decorate_test_fn,
    generate_error_msg,
    is_torch_nn_functional_test,
    move_python_tensors_to_device,
    serialize_arg_dict_as_script_module,
    set_python_tensors_requires_grad,
    TORCH_NN_COMMON_TEST_HARNESS,
    TorchNNFunctionalTestParams,
    try_remove_folder,
)


# C++ source for one forward-parity test function.
#
# Expected substitutions:
#
# ${functional_variant_name}  (e.g. `BCELoss_no_reduce`)
# ${cpp_args_construction_stmts}
# ${cpp_function_call}
#
# The line breaks below are significant: the `//` comments run to the end of
# their line, so collapsing this text onto one line would comment out the code
# that follows each of them.
TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template(
    """
void ${functional_variant_name}_test_forward(
    const std::string& arg_dict_file_path,
    const std::string& forward_output_file_path) {
  pybind11::gil_scoped_release no_gil;

  namespace F = torch::nn::functional;

  // Declare arguments
  auto arg_dict = load_dict_from_file(arg_dict_file_path);
  ${cpp_args_construction_stmts};

  // Some functionals (such as `F::rrelu`) create random tensors in their call path.
  // To make sure the random tensors created are the same in Python/C++, we need
  // to set the RNG seed manually.
  torch::manual_seed(0);

  // Run function with arguments
  auto cpp_output = ${cpp_function_call};

  // Save the output into a file to be compared in Python later
  write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
}
"""
)


def run_forward(unit_test_class, test_params):
    """Run the forward pass of the Python functional under test.

    The positional inputs are taken from ``test_params.arg_dict`` in the order
    ``"input"``, ``"target"``, ``"extra_args"``. Each entry is a
    ``(name, value)`` pair and only the values are used. All values are passed
    through ``move_python_tensors_to_device`` with ``test_params.device``. The
    ``"input"`` values are additionally passed through
    ``set_python_tensors_requires_grad``.

    Args:
        unit_test_class: not used here; kept so the signature mirrors
            ``test_forward``.
        test_params: parameters for one test variant (presumably a
            ``TorchNNFunctionalTestParams``). Must provide ``device``,
            ``arg_dict`` and ``test_instance``.

    Returns:
        The result of calling ``test_params.test_instance.constructor()`` on
        the assembled inputs.
    """
    device = test_params.device
    arg_dict = test_params.arg_dict

    inputs = set_python_tensors_requires_grad(
        move_python_tensors_to_device(
            [arg_value for _, arg_value in arg_dict["input"]], device
        )
    )
    inputs += move_python_tensors_to_device(
        [arg_value for _, arg_value in arg_dict["target"]], device
    )
    inputs += move_python_tensors_to_device(
        [arg_value for _, arg_value in arg_dict["extra_args"]], device
    )

    # Some functionals (such as `F.rrelu`) create random tensors in their call path.
    # To make sure the random tensors created are the same in Python/C++, we need
    # to set the RNG seed manually (the C++ template also seeds with 0).
    torch.manual_seed(0)
    python_output = test_params.test_instance.constructor()(*inputs)

    return python_output


def test_forward(unit_test_class, test_params):
    """Check that the Python and C++ functionals produce the same forward output.

    Steps:

    1. Run the Python forward pass.
    2. Serialize ``test_params.arg_dict`` into a file under
       ``test_params.cpp_tmp_folder``.
    3. Call the compiled C++ function
       ``<functional_variant_name>_test_forward``, which writes its output to
       a second file.
    4. Load that file back with ``torch.load`` and compare it with the Python
       output via ``unit_test_class.assertEqual``.

    The temporary folder is recreated empty at the start and removed at the
    end, including when any step raises.

    Args:
        unit_test_class: test-case object that provides ``assertEqual`` and
            exposes the compiled extension as
            ``functional_impl_check_cpp_module``.
        test_params: parameters for one test variant (uses
            ``functional_variant_name``, ``cpp_tmp_folder``, ``arg_dict`` and
            whatever ``run_forward`` reads).
    """
    functional_variant_name = test_params.functional_variant_name
    cpp_tmp_folder = test_params.cpp_tmp_folder

    # Remove the temporary folder if it exists already, then start from an
    # empty one for the files exchanged with C++.
    try_remove_folder(cpp_tmp_folder)
    os.mkdir(cpp_tmp_folder)
    try:
        # Run forward on the Python functional.
        python_output = run_forward(unit_test_class, test_params)

        # Save the Python arguments so the C++ function can load them.
        arg_dict_file_path = compute_temp_file_path(
            cpp_tmp_folder, functional_variant_name, "arg_dict"
        )
        serialize_arg_dict_as_script_module(test_params.arg_dict).save(
            arg_dict_file_path
        )

        # Run the matching C++ test function; it writes its output to a file.
        cpp_test_fn = getattr(
            unit_test_class.functional_impl_check_cpp_module,
            f"{functional_variant_name}_test_forward",
        )
        forward_output_file_path = compute_temp_file_path(
            cpp_tmp_folder, functional_variant_name, "forward_output"
        )
        cpp_test_fn(arg_dict_file_path, forward_output_file_path)
        cpp_output = torch.load(forward_output_file_path)

        # Check that forward outputs are equal.
        unit_test_class.assertEqual(
            python_output,
            cpp_output,
            msg=generate_error_msg("forward output", cpp_output, python_output),
        )
    finally:
        # Remove the temporary folder that stores C++ inputs/outputs even when
        # a step above raised (e.g. a failed comparison), so it does not leak.
        try_remove_folder(cpp_tmp_folder)


# NOTE(review): the text of this file was corrupted below this point.
#
# What was lost: a span starting inside the regex literal of
# `camel_case_to_snake_case` (right after `r"(?`) and ending just before a `0`.
# It was presumably stripped as if it were an HTML tag, i.e. from a `(?<...`
# lookaround up to a `> 0` comparison. The span held:
#   - the rest of `compute_functional_name`;
#   - whatever definitions followed it. The fragment below calls
#     `generate_test_cpp_sources`, which is not defined in what survives, and
#     the imports `pprint`, `tempfile`, `add_test`, `compute_arg_dict`,
#     `compute_cpp_args_construction_stmts_and_forward_arg_symbols`,
#     `decorate_test_fn`, `is_torch_nn_functional_test` and
#     `TorchNNFunctionalTestParams` are unused by the surviving code;
#   - the `def` line of the function that compiles all C++ tests and sets the
#     `functional_impl_check_cpp_module` attribute read by `test_forward`.
#
# TODO: restore the missing code from version control. Until then, the
# surviving fragments are kept verbatim below, reflowed onto lines, as
# comments. This lets the rest of this module be imported.
#
# --- surviving text before the gap ---
# def compute_functional_name(test_params_dict):
#     def camel_case_to_snake_case(camel_case_str):
#         return re.sub(r"(?
# --- surviving text after the gap (indentation reconstructed) ---
# 0
# cpp_sources = TORCH_NN_COMMON_TEST_HARNESS + SAMPLE_FUNCTIONAL_CPP_SOURCE
# functions = []
# for test_params in unit_test_class.functional_test_params_map.values():
#     cpp_sources += generate_test_cpp_sources(
#         test_params=test_params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD
#     )
#     functions.append(f"{test_params.functional_variant_name}_test_forward")
# if print_cpp_source:
#     print(cpp_sources)
# cpp_module = compile_cpp_code_inline(
#     name="functional_impl_check", cpp_sources=cpp_sources, functions=functions
# )
# unit_test_class.functional_impl_check_cpp_module = cpp_module