• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import annotations
2
3import tempfile
4import unittest
5from typing import Any
6from unittest.mock import ANY, Mock, patch
7
8import expecttest
9
10import torchgen
11from torchgen.executorch.api.custom_ops import ComputeNativeFunctionStub
12from torchgen.executorch.model import ETKernelIndex
13from torchgen.gen_executorch import gen_headers
14from torchgen.model import Location, NativeFunction
15from torchgen.selective_build.selector import SelectiveBuilder
16from torchgen.utils import FileManager
17
18
# A run of four spaces. Expected-output strings interpolate it for a line that
# consists only of indentation, presumably so the test source never has to
# carry trailing whitespace (which editors and linters tend to strip).
SPACES = " " * 4
20
21
def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
    """Parse one native-function YAML entry into a ``NativeFunction``.

    The entry is parsed with an empty set of valid tags, and its source
    location is reported as line 1 of this test file.

    Args:
        yaml_obj: A YAML-style mapping, e.g. ``{"func": "<schema string>"}``.

    Returns:
        The first element of the pair returned by ``NativeFunction.from_yaml``;
        the second element is discarded.
    """
    test_file_location = Location(__file__, 1)
    parsed_function, _ignored = NativeFunction.from_yaml(
        yaml_obj, loc=test_file_location, valid_tags=set()
    )
    return parsed_function
29
30
class TestComputeNativeFunctionStub(expecttest.TestCase):
    """
    Tests for ComputeNativeFunctionStub, which turns a custom-op schema into
    the text of a C++ stub kernel definition.

    These cases could be parametrized with torch.testing._internal.common_utils
    to cut boilerplate, but the GitHub CI job runs the tools unit tests without
    building torch, so each case is written out by hand instead.
    """

    def _test_function_schema_generates_correct_kernel(
        self, obj: dict[str, Any], expected: str
    ) -> None:
        """Generate the stub kernel for ``obj`` and compare it with ``expected``."""
        native_function = _get_native_function_from_yaml(obj)
        stub_generator = ComputeNativeFunctionStub()
        generated = stub_generator(native_function)
        # The generator may return None; make that failure explicit first.
        self.assertIsNotNone(generated)
        self.assertExpectedInline(str(generated), expected)

    # A single Tensor out-argument: the stub returns that out-argument.
    def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
        schema_entry = {
            "func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        }
        expected_kernel = """
at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) {
    return out;
}
    """
        self._test_function_schema_generates_correct_kernel(
            schema_entry, expected_kernel
        )

    # No out-argument, but an input has the return type: that input is returned.
    def test_function_schema_generates_correct_kernel_no_out(self) -> None:
        schema_entry = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"}
        expected_kernel = """
at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) {
    return self;
}
    """
        self._test_function_schema_generates_correct_kernel(
            schema_entry, expected_kernel
        )

    # Empty return list: the body is a whitespace-only line, supplied via
    # SPACES so this source contains no trailing whitespace.
    def test_function_schema_generates_correct_kernel_no_return(self) -> None:
        schema_entry = {
            "func": "custom::foo.out(Tensor self, *, Tensor(a!)[] out) -> ()"
        }
        expected_kernel = f"""
void wrapper_CPU_out_foo_out(const at::Tensor & self, at::TensorList out) {{
{SPACES}
}}
    """
        self._test_function_schema_generates_correct_kernel(
            schema_entry, expected_kernel
        )

    # Several returns with no out-arguments: a tuple of default tensors.
    def test_function_schema_generates_correct_kernel_3_returns(self) -> None:
        schema_entry = {
            "func": "custom::foo(Tensor self, Tensor[] other) -> (Tensor, Tensor, Tensor)"
        }
        expected_kernel = """
::std::tuple<at::Tensor,at::Tensor,at::Tensor> wrapper_CPU__foo(const at::Tensor & self, at::TensorList other) {
    return ::std::tuple<at::Tensor, at::Tensor, at::Tensor>(
                at::Tensor(), at::Tensor(), at::Tensor()
            );
}
    """
        self._test_function_schema_generates_correct_kernel(
            schema_entry, expected_kernel
        )

    # One Tensor return and no input of that type: a default tensor is returned.
    def test_function_schema_generates_correct_kernel_1_return_no_out(self) -> None:
        schema_entry = {"func": "custom::foo(Tensor[] a) -> Tensor"}
        expected_kernel = """
at::Tensor wrapper_CPU__foo(at::TensorList a) {
    return at::Tensor();
}
    """
        self._test_function_schema_generates_correct_kernel(
            schema_entry, expected_kernel
        )

    # A non-Tensor return (bool) that no argument can supply must raise.
    def test_schema_has_no_return_type_argument_throws(self) -> None:
        bool_returning_function = _get_native_function_from_yaml(
            {"func": "custom::foo.bool(Tensor self) -> bool"}
        )
        stub_generator = ComputeNativeFunctionStub()
        with self.assertRaisesRegex(Exception, "Can't handle this return type"):
            stub_generator(bool_returning_function)
108
109
class TestGenCustomOpsHeader(unittest.TestCase):
    """Check whether gen_headers emits CustomOpsNativeFunctions.h, as requested."""

    @staticmethod
    def _run_gen_headers(gen_custom_ops_header: bool) -> None:
        """Call gen_headers with empty inputs against a throwaway directory.

        Args:
            gen_custom_ops_header: Forwarded unchanged to gen_headers.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            file_manager = FileManager(tempdir, tempdir, False)
            gen_headers(
                native_functions=[],
                gen_custom_ops_header=gen_custom_ops_header,
                custom_ops_native_functions=[],
                selector=SelectiveBuilder.get_nop_selector(),
                kernel_index=ETKernelIndex(index={}),
                cpu_fm=file_manager,
                use_aten_lib=False,
            )

    # Nested patch decorators inject their mocks bottom-up: the first mock
    # argument replaces FileManager.write, the second replaces
    # FileManager.write_with_template.
    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_writes_custom_ops_header_when_boolean_is_true(
        self, mock_write: Mock, mock_write_with_template: Mock
    ) -> None:
        self._run_gen_headers(gen_custom_ops_header=True)
        # The custom-ops header is rendered from the NativeFunctions.h template.
        mock_write_with_template.assert_called_once_with(
            "CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY
        )

    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false(
        self, mock_write: Mock, mock_write_with_template: Mock
    ) -> None:
        self._run_gen_headers(gen_custom_ops_header=False)
        mock_write_with_template.assert_not_called()
148