# Owner(s): ["oncall: jit"]

import re
import unittest

import torch
import torch._C


torch.ops.load_library("//caffe2:xnnpack_backend")


class TestXNNPackBackend(unittest.TestCase):
    """Tests for lowering TorchScript modules to the "xnnpack" backend.

    Each test scripts a small module, lowers it with
    ``torch._C._jit_to_backend("xnnpack", ...)`` and either checks that the
    lowered module matches the scripted reference, or that lowering is
    rejected with the expected ``RuntimeError``.
    """

    def _assert_outputs_close(self, actual, expected):
        """Assert ``actual`` (lowered output) matches ``expected`` (scripted reference).

        Tolerances (atol=rtol=1e-3) are the ones every test in this class used.
        """
        self.assertTrue(
            torch.allclose(actual, expected, atol=1e-03, rtol=1e-03)
        )

    def test_xnnpack_constant_data(self):
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._constant = torch.ones(4, 4, 4)

            def forward(self, x):
                return x + self._constant

        scripted_module = torch.jit.script(Module())

        lowered_module = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.randn(4, 4, 4)],
                    "outputs": [torch.randn(4, 4, 4)],
                }
            },
        )

        for _ in range(20):
            sample_input = torch.randn(4, 4, 4)
            # The scripted module is the reference; the lowered module is the
            # one under test (matching the convention of the other tests).
            expected_output = scripted_module(sample_input)
            actual_output = lowered_module(sample_input)
            self._assert_outputs_close(actual_output, expected_output)

    def test_xnnpack_lowering(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        scripted_module = torch.jit.script(Module())

        # A compile spec with no "forward" entry must be rejected.
        faulty_compile_spec = {
            "backward": {
                "inputs": [torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        # The expected messages are literal text, so escape them before they
        # are used as assertRaisesRegex patterns.
        error_msg = re.escape(
            'method_compile_spec does not contain the "forward" key.'
        )

        with self.assertRaisesRegex(RuntimeError, error_msg):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                scripted_module,
                faulty_compile_spec,
            )

        # forward() takes a single input, so a spec listing two inputs must be
        # rejected.
        mismatch_compile_spec = {
            "forward": {
                "inputs": [torch.zeros(1), torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = re.escape(
            "method_compile_spec inputs do not match expected number of forward inputs"
        )

        with self.assertRaisesRegex(RuntimeError, error_msg):
            _ = torch._C._jit_to_backend(
                "xnnpack", scripted_module, mismatch_compile_spec
            )

        # Smoke check: a well-formed spec lowers, and the lowered module runs.
        lowered = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.zeros(1)],
                    "outputs": [torch.zeros(1)],
                }
            },
        )
        lowered(torch.zeros(1))

    def test_xnnpack_backend_add(self):
        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y
                z = z + x
                z = z + x
                return z

        add_module = AddModule()
        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    # Clones keep the spec tensors distinct from the run-time inputs.
                    "inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self._assert_outputs_close(actual_output, expected_output)

    def test_xnnpack_broadcasting(self):
        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        add_module = AddModule()
        # (5, 1, 4, 1) + (3, 1, 1) broadcasts to (5, 3, 4, 1).
        sample_inputs = (torch.rand(5, 1, 4, 1), torch.rand(3, 1, 1))
        sample_output = torch.zeros(5, 3, 4, 1)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0], sample_inputs[1]],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self._assert_outputs_close(actual_output, expected_output)

    def test_xnnpack_unsupported(self):
        class AddSpliceModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y[:, :, 1, :]
                return z

        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        # The indexing of y in forward() is expected to surface as these ops,
        # which lowering reports as unsupported.
        error_msg = re.escape(
            "the module contains the following unsupported ops:\n"
            "aten::select\n"
            "aten::slice\n"
        )

        add_module = torch.jit.script(AddSpliceModule())
        with self.assertRaisesRegex(RuntimeError, error_msg):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                add_module,
                {
                    "forward": {
                        "inputs": [sample_inputs[0], sample_inputs[1]],
                        "outputs": [sample_output],
                    }
                },
            )


if __name__ == "__main__":
    unittest.main()