• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["oncall: jit"]
2
3import unittest
4
5import torch
6import torch._C
7
8
9torch.ops.load_library("//caffe2:xnnpack_backend")
10
11
class TestXNNPackBackend(unittest.TestCase):
    """Tests lowering TorchScript modules to the XNNPack delegate backend.

    Each test scripts a small ``torch.nn.Module`` and hands it to
    ``torch._C._jit_to_backend("xnnpack", ...)`` together with a
    ``method_compile_spec`` that carries example inputs/outputs, then checks
    either that the lowered module matches the reference TorchScript module
    numerically or that lowering fails with the expected error.
    """

    def test_xnnpack_constant_data(self):
        """Constant tensors held as module attributes survive lowering."""

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._constant = torch.ones(4, 4, 4)

            def forward(self, x):
                return x + self._constant

        scripted_module = torch.jit.script(Module())

        lowered_module = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.randn(4, 4, 4)],
                    "outputs": [torch.randn(4, 4, 4)],
                }
            },
        )

        # Exercise several random inputs so agreement is not an accident of a
        # single sample. The scripted module is the reference (expected); the
        # lowered module is the one under test (actual).
        for _ in range(20):
            sample_input = torch.randn(4, 4, 4)
            expected_output = scripted_module(sample_input)
            actual_output = lowered_module(sample_input)
            self.assertTrue(
                torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
            )

    def test_xnnpack_lowering(self):
        """Lowering validates the compile spec and rejects malformed ones."""

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        scripted_module = torch.jit.script(Module())

        # A spec without a "forward" entry must be rejected outright.
        faulty_compile_spec = {
            "backward": {
                "inputs": [torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = 'method_compile_spec does not contain the "forward" key.'

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                scripted_module,
                faulty_compile_spec,
            )

        # A spec whose input count disagrees with forward's arity must also fail
        # (forward takes one tensor; the spec supplies two).
        mismatch_compile_spec = {
            "forward": {
                "inputs": [torch.zeros(1), torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = (
            "method_compile_spec inputs do not match expected number of forward inputs"
        )

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack", scripted_module, mismatch_compile_spec
            )

        # A well-formed spec lowers successfully and the result is callable.
        lowered = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.zeros(1)],
                    "outputs": [torch.zeros(1)],
                }
            },
        )
        lowered(torch.zeros(1))

    def test_xnnpack_backend_add(self):
        """A chain of elementwise adds lowers and matches the reference."""

        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y
                z = z + x
                z = z + x
                return z

        add_module = AddModule()
        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    # Clone so lowering cannot alias the tensors we later feed
                    # to forward() for the numerical comparison.
                    "inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self.assertTrue(
            torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_xnnpack_broadcasting(self):
        """Broadcasting adds (mismatched-rank operands) lower correctly."""

        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        add_module = AddModule()
        # (5, 1, 4, 1) + (3, 1, 1) broadcasts to (5, 3, 4, 1).
        sample_inputs = (torch.rand(5, 1, 4, 1), torch.rand(3, 1, 1))
        sample_output = torch.zeros(5, 3, 4, 1)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0], sample_inputs[1]],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self.assertTrue(
            torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_xnnpack_unsupported(self):
        """Lowering a module with unsupported ops reports each offending op."""

        class AddSpliceModule(torch.nn.Module):
            def forward(self, x, y):
                # Tensor slicing/selection is not supported by the delegate;
                # it scripts to aten::select and aten::slice.
                z = x + y[:, :, 1, :]
                return z

        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        error_msg = (
            "the module contains the following unsupported ops:\n"
            "aten::select\n"
            "aten::slice\n"
        )

        add_module = torch.jit.script(AddSpliceModule())
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                add_module,
                {
                    "forward": {
                        "inputs": [sample_inputs[0], sample_inputs[1]],
                        "outputs": [sample_output],
                    }
                },
            )