# Owner(s): ["oncall: jit"]

import unittest

import torch
import torch._C


torch.ops.load_library("//caffe2:xnnpack_backend")


class TestXNNPackBackend(unittest.TestCase):
    """Tests for lowering TorchScript modules to the "xnnpack" backend.

    Each test scripts a small ``torch.nn.Module``, lowers it with
    ``torch._C._jit_to_backend("xnnpack", module, compile_spec)``, and then
    either compares the lowered module's output against the scripted module
    (used as the reference) or checks that lowering is rejected with the
    expected ``RuntimeError`` message.
    """

    def _assert_outputs_close(self, actual, expected):
        """Assert the lowered output ``actual`` matches the reference ``expected``.

        Uses the same absolute/relative tolerance (1e-03) for every numeric
        comparison in this suite; on failure ``assert_close`` reports the
        mismatching elements instead of a bare ``False is not true``.
        """
        torch.testing.assert_close(actual, expected, atol=1e-03, rtol=1e-03)

    def test_xnnpack_constant_data(self):
        """A module that adds a tensor attribute to its input lowers correctly."""

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._constant = torch.ones(4, 4, 4)

            def forward(self, x):
                return x + self._constant

        scripted_module = torch.jit.script(Module())

        lowered_module = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.randn(4, 4, 4)],
                    "outputs": [torch.randn(4, 4, 4)],
                }
            },
        )

        # Feed several fresh random inputs (not the compile-spec ones) through
        # both modules; the scripted module is the reference implementation.
        for _ in range(20):
            sample_input = torch.randn(4, 4, 4)
            expected_output = scripted_module(sample_input)
            actual_output = lowered_module(sample_input)
            self._assert_outputs_close(actual_output, expected_output)

    def test_xnnpack_lowering(self):
        """Lowering validates the compile spec before producing a module."""

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        scripted_module = torch.jit.script(Module())

        # A compile spec without a "forward" entry must be rejected.
        faulty_compile_spec = {
            "backward": {
                "inputs": [torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = 'method_compile_spec does not contain the "forward" key.'

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                scripted_module,
                faulty_compile_spec,
            )

        # forward() takes a single input, so a spec declaring two inputs
        # must be rejected.
        mismatch_compile_spec = {
            "forward": {
                "inputs": [torch.zeros(1), torch.zeros(1)],
                "outputs": [torch.zeros(1)],
            }
        }
        error_msg = (
            "method_compile_spec inputs do not match expected number of forward inputs"
        )

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack", scripted_module, mismatch_compile_spec
            )

        # A well-formed spec lowers successfully, and the lowered module must
        # agree with the scripted reference rather than merely not crash.
        lowered = torch._C._jit_to_backend(
            "xnnpack",
            scripted_module,
            {
                "forward": {
                    "inputs": [torch.zeros(1)],
                    "outputs": [torch.zeros(1)],
                }
            },
        )
        sample_input = torch.zeros(1)
        self._assert_outputs_close(lowered(sample_input), scripted_module(sample_input))

    def test_xnnpack_backend_add(self):
        """A chain of element-wise additions lowers and matches the reference."""

        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y
                z = z + x
                z = z + x
                return z

        add_module = AddModule()
        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self._assert_outputs_close(actual_output, expected_output)

    def test_xnnpack_broadcasting(self):
        """An addition whose operands broadcast to (5, 3, 4, 1) lowers correctly."""

        class AddModule(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        add_module = AddModule()
        # (5, 1, 4, 1) + (3, 1, 1) broadcasts to (5, 3, 4, 1).
        sample_inputs = (torch.rand(5, 1, 4, 1), torch.rand(3, 1, 1))
        sample_output = torch.zeros(5, 3, 4, 1)

        add_module = torch.jit.script(add_module)
        expected_output = add_module(sample_inputs[0], sample_inputs[1])

        lowered_add_module = torch._C._jit_to_backend(
            "xnnpack",
            add_module,
            {
                "forward": {
                    "inputs": [sample_inputs[0], sample_inputs[1]],
                    "outputs": [sample_output],
                }
            },
        )

        actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
        self._assert_outputs_close(actual_output, expected_output)

    def test_xnnpack_unsupported(self):
        """Lowering a module that slices/selects reports the unsupported ops."""

        class AddSpliceModule(torch.nn.Module):
            def forward(self, x, y):
                z = x + y[:, :, 1, :]
                return z

        sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3))
        sample_output = torch.zeros(1, 512, 512, 3)

        # The indexing y[:, :, 1, :] is expressed in TorchScript via
        # aten::slice / aten::select, which this test expects to be rejected.
        error_msg = (
            "the module contains the following unsupported ops:\n"
            "aten::select\n"
            "aten::slice\n"
        )

        add_module = torch.jit.script(AddSpliceModule())
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = torch._C._jit_to_backend(
                "xnnpack",
                add_module,
                {
                    "forward": {
                        "inputs": [sample_inputs[0], sample_inputs[1]],
                        "outputs": [sample_output],
                    }
                },
            )