from __future__ import annotations

import tempfile
import unittest
from typing import Any
from unittest.mock import ANY, Mock, patch

import expecttest

import torchgen
from torchgen.executorch.api.custom_ops import ComputeNativeFunctionStub
from torchgen.executorch.model import ETKernelIndex
from torchgen.gen_executorch import gen_headers
from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager


# NOTE(review): this file was recovered from a whitespace-collapsed copy, so
# the exact run lengths of spaces in SPACES and inside the expected C++
# snippets below are reconstructed. They are presumably compared verbatim
# against the generator output -- confirm against ComputeNativeFunctionStub.
SPACES = "    "


def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
    """Build a ``NativeFunction`` from a yaml-style dict.

    ``NativeFunction.from_yaml`` returns a pair; only the first element is
    returned to the caller and the second one is discarded.
    """
    parsed_function, _ignored = NativeFunction.from_yaml(
        yaml_obj,
        loc=Location(__file__, 1),
        valid_tags=set(),
    )
    return parsed_function


class TestComputeNativeFunctionStub(expecttest.TestCase):
    """
    Inline-expectation tests for ``ComputeNativeFunctionStub``.

    torch.testing._internal.common_utils could cut down the boilerplate, but
    the GH CI job doesn't build torch before running tools unit tests, so the
    parametrized cases are spelled out by hand.
    """

    def _test_function_schema_generates_correct_kernel(
        self, obj: dict[str, Any], expected: str
    ) -> None:
        """Generate the stub for schema ``obj`` and compare it with ``expected``."""
        native_fn = _get_native_function_from_yaml(obj)

        stub_generator = ComputeNativeFunctionStub()
        generated = stub_generator(native_fn)
        self.assertIsNotNone(generated)
        self.assertExpectedInline(str(generated), expected)

    def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
        """A single Tensor out-argument: the expected stub returns ``out``."""
        schema = {
            "func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        }
        expected_stub = """
at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) {
    return out;
}
    """
        self._test_function_schema_generates_correct_kernel(schema, expected_stub)

    def test_function_schema_generates_correct_kernel_no_out(self) -> None:
        """No out-argument, Tensor return: the expected stub returns ``self``."""
        schema = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"}
        expected_stub = """
at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) {
    return self;
}
    """
        self._test_function_schema_generates_correct_kernel(schema, expected_stub)

    def test_function_schema_generates_correct_kernel_no_return(self) -> None:
        """Empty return list: the expected stub body is just ``SPACES``."""
        schema = {"func": "custom::foo.out(Tensor self, *, Tensor(a!)[] out) -> ()"}
        expected_stub = f"""
void wrapper_CPU_out_foo_out(const at::Tensor & self, at::TensorList out) {{
{SPACES}
}}
    """
        self._test_function_schema_generates_correct_kernel(schema, expected_stub)

    def test_function_schema_generates_correct_kernel_3_returns(self) -> None:
        """Three Tensor returns: the expected stub returns a tuple of empty tensors."""
        schema = {
            "func": "custom::foo(Tensor self, Tensor[] other) -> (Tensor, Tensor, Tensor)"
        }
        expected_stub = """
::std::tuple<at::Tensor,at::Tensor,at::Tensor> wrapper_CPU__foo(const at::Tensor & self, at::TensorList other) {
    return ::std::tuple<at::Tensor, at::Tensor, at::Tensor>(
                at::Tensor(), at::Tensor(), at::Tensor()
            );
}
    """
        self._test_function_schema_generates_correct_kernel(schema, expected_stub)

    def test_function_schema_generates_correct_kernel_1_return_no_out(self) -> None:
        """Tensor return with only a Tensor[] input: expects ``at::Tensor()``."""
        schema = {"func": "custom::foo(Tensor[] a) -> Tensor"}
        expected_stub = """
at::Tensor wrapper_CPU__foo(at::TensorList a) {
    return at::Tensor();
}
    """
        self._test_function_schema_generates_correct_kernel(schema, expected_stub)

    def test_schema_has_no_return_type_argument_throws(self) -> None:
        """A ``bool`` return with no matching argument must raise."""
        bool_returning_fn = _get_native_function_from_yaml(
            {"func": "custom::foo.bool(Tensor self) -> bool"}
        )

        stub_generator = ComputeNativeFunctionStub()
        with self.assertRaisesRegex(Exception, "Can't handle this return type"):
            stub_generator(bool_returning_fn)


class TestGenCustomOpsHeader(unittest.TestCase):
    """Check when ``gen_headers`` calls ``FileManager.write_with_template``.

    Both ``FileManager.write`` and ``FileManager.write_with_template`` are
    patched in every test. ``patch`` decorators apply bottom-up, so the first
    mock argument stands in for ``write`` and the second one for
    ``write_with_template``.
    """

    @staticmethod
    def _generate_headers(output_dir: str, *, emit_custom_ops_header: bool) -> None:
        """Call ``gen_headers`` with empty inputs, targeting ``output_dir``."""
        file_manager = FileManager(output_dir, output_dir, False)
        gen_headers(
            native_functions=[],
            gen_custom_ops_header=emit_custom_ops_header,
            custom_ops_native_functions=[],
            selector=SelectiveBuilder.get_nop_selector(),
            kernel_index=ETKernelIndex(index={}),
            cpu_fm=file_manager,
            use_aten_lib=False,
        )

    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_writes_custom_ops_header_when_boolean_is_true(
        self, unused: Mock, mock_method: Mock
    ) -> None:
        """With the flag on, the custom-ops header template is written once."""
        with tempfile.TemporaryDirectory() as scratch_dir:
            self._generate_headers(scratch_dir, emit_custom_ops_header=True)
            mock_method.assert_called_once_with(
                "CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY
            )

    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false(
        self, unused: Mock, mock_method: Mock
    ) -> None:
        """With the flag off, ``write_with_template`` is never called."""
        with tempfile.TemporaryDirectory() as scratch_dir:
            self._generate_headers(scratch_dir, emit_custom_ops_header=False)
            mock_method.assert_not_called()