# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)

from executorch.extension.pybindings.aten_lib import (  # @manual
    _load_for_executorch_from_buffer,
)

from executorch.extension.pytree import tree_flatten
from torch.export import export


class TestDelegateAtenMode(unittest.TestCase):
    """Run a backend-delegated program through the ATen-mode pybindings."""

    def test_add_xnnpack_and_dqlinear_qnn(self):
        """Lower ``mm`` + ``add`` to the demo backend, execute, compare to eager.

        The module is lowered with ``to_backend`` to ``BackendWithCompilerDemo``,
        wrapped in a composite module, exported to an ExecuTorch program,
        loaded from its buffer via the ATen-mode pybindings, run, and compared
        against the eager (non-lowered) module.

        NOTE(review): despite its name, nothing visible here uses XNNPACK,
        dq-linear or QNN -- only ``BackendWithCompilerDemo``. The name is kept
        unchanged so existing test selectors keep working.
        """

        class AddMulModule(torch.nn.Module):
            # Computes (a @ x) + b.
            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = to_edge(export(add_mul_module, model_inputs))

        # Passed to the demo backend as a single-byte compile spec (here the
        # leading dimension of the first input, i.e. 2). bytes([...]) requires
        # the value to lie in 0..255; how the backend uses it isn't visible here.
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_add_mul = to_backend(
            BackendWithCompilerDemo.__name__,
            edge_graph_module.exported_program(),
            compile_specs,
        )

        class CompositeModule(torch.nn.Module):
            # Wraps the lowered module as a submodule so the whole thing can be
            # exported again as an ordinary top-level module.
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

        # Smoke-run the composite module eagerly before export; the result is
        # intentionally unused.
        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch()

        executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

        # assert_close reports which elements differ (and checks shape/dtype)
        # on failure, unlike assertTrue(torch.allclose(...)).
        torch.testing.assert_close(
            model_output[0], ref_output, atol=1e-03, rtol=1e-03
        )