# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)

from executorch.extension.pybindings.aten_lib import (  # @manual
    _load_for_executorch_from_buffer,
)

from executorch.extension.pytree import tree_flatten
from torch.export import export


class TestDelegateAtenMode(unittest.TestCase):
    def test_delegate_add_mul_aten_mode(self):
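        # Simple eager module that computes torch.mm(a, x) + b.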
        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

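        # Export the eager module with example inputs and convert it to the Edge dialect.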
        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = to_edge(export(add_mul_module, model_inputs))
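        # Build a "max_value" compile spec: the leading input dimension encoded as a
        # single byte, presumably consumed by the demo backend's preprocess step.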
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
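        # Lower the exported Edge program to BackendWithCompilerDemo, producing a
        # lowered module that can be embedded in another nn.Module.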
        lowered_add_mul = to_backend(
            BackendWithCompilerDemo.__name__,
            edge_graph_module.exported_program(),
            compile_specs,
        )

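        # Wrap the lowered module so it can be re-exported as part of a larger graph.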
        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

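        # Run the composite module eagerly once as a quick sanity check (output unused).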
        composite_model(*model_inputs)

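        # Export the composite module, convert it to Edge, and serialize it to an
        # ExecuTorch program buffer.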
        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch()

        buff = exec_prog.buffer

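        # Load the serialized program through the ATen-mode pybindings and run it on
        # the flattened inputs.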
        executorch_module = _load_for_executorch_from_buffer(buff)
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

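        # The delegated output should match the eager reference within tolerance.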
        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
        )