# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
    ExecutorBackendPartitioner,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from torch.export import export
from torch.utils._pytree import tree_flatten
27"""
28Server can be an App Call, and will send delegate to client backend like DSP,
29DSP will reeive the rpc call, calls the ExecuTorch instance (like on DSP),
30and return the result.
31
32
33       +------------+                                         +--------------+
34       | Host (CPU) |                                         | Client (DSP) |
35       +------------+                                         +--------------+
36              |                                                     |
37              |                                                     |
38              v                                                     |
39+--------------------------------+                                  |
40|Instatiate an Executor instance |                                  |
41+--------------------------------+                                  |
42              |                                                     |
43              |                                                     v
44              v               send init call with     +---------------------------------+
45    +--------------------+    delegate                | Unwarp the delegate,            |
46    |                    | -------------------------->| instatiate an Executor instance |
47    |init execution plan |                            +---------------------------------+
48    |                    |                                          |
49    |                    |    finish init call           +----------v----------+
50    +--------------------+<------------------------------| init execution plan |
51              |                                          +---------------------+
52              |                                                     |
53              v                                                     |
54        +------------+        send the execute call                 v
55        |            |---------------------------------------> +---------+
56        |  execute   |        receive the execute result       | execute |
57        +------------+<--------------------------------------  +---------+
58              |                                                     |
59              |                                                     |
60              |                                                     |
61              |                                                     v
62              v
63
64
65
66For example, in some usecases, there are can be three layers MCU -> DSP -> AC
67
68MCU
69——
701. MCU instantiate ExecuTorch instance with DSPBackend
712. In DSPBackend init/execute, it'll invoke the implemented RPC calls on DSP
72
73DSP
74——
753. DSP receives the RPC call and construct the ExecuTorch instance on the DSP
764. When dsp executor runs, it can call any delegate (e.g. Accelerator) as needed.
77
78There’ll negligible overhead in binary size on the MCU, as the executor size is small.
79"""


class TestRPCDemos(unittest.TestCase):
    def get_a_simple_net(self) -> torch.nn.Module:
        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(4, 25)
                self.linear2 = torch.nn.Linear(25, 3)

            def forward(self, x):
                x = torch.sigmoid(self.linear1(x))
                x = self.linear2(x)
                return x

            def get_example_inputs(self):
                return (torch.randn(25, 4),)

        return Net()

    def test_delegate_whole_program(self):
        # This example shows how to delegate the whole simple net to the client
        # executor, e.g. on a DSP:
        # CPU -> delegate (simple net) -> DSP executor runs the simple net

        simple_net = self.get_a_simple_net()
        simple_net_input = simple_net.get_example_inputs()
        exported_program = to_edge(
            export(simple_net, simple_net_input),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        )
        # Delegate the whole graph to the client executor
        lowered_module = to_backend(
            ExecutorBackend.__name__, exported_program.exported_program(), []
        )
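        # to_backend returns a lowered module whose compiled payload (here the
        # entire graph) is what the client executor unwraps in the init call
        # shown in the diagram above.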

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_module

            def forward(self, *args):
                return self.lowered(*args)

        composite_model = CompositeModule()

        exec_prog = to_edge(export(composite_model, simple_net_input)).to_executorch()

        executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)

        # Now the client executor is instantiated
        inputs_flattened, _ = tree_flatten(simple_net_input)

        # Send the inputs from the server executor to the client executor, and
        # receive the result back from the client executor
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
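        # run_method returns a list of output tensors; this net has a single
        # output, hence model_output[0] below.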
        ref_output = composite_model(*simple_net_input)

        # Compare the server executor's final result against the eager model
        self.assertTrue(
            torch.allclose(model_output[0], ref_output, rtol=1e-03, atol=1e-03)
        )

    def test_delegate_partial_program(self):
        # CPU -> delegate (simple net) -> DSP executor runs the simple net
        # TODO: input -> linear (delegated to the DSP executor) -> sigmoid
        # -> linear (delegated to the DSP executor) -> output
        pass
    def test_delegate_program_with_nested_delegate(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                a = x - b  # compute on the DSP
                y = torch.mm(a, x)  # compute on the HTA
                z = y + b  # compute on the HTA
                return z

        model = Model()
        inputs = (torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2))

        exported_program = to_edge(export(model, inputs))

        # First lower to the demo backend
        demo_backend_lowered = exported_program.to_backend(AddMulPartitionerDemo())

        # Then lower to the executor backend
        executor_backend_lowered = demo_backend_lowered.to_backend(
            ExecutorBackendPartitioner()
        )
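        # The second lowering nests the first: the executor backend payload now
        # contains a program that itself carries the AddMul delegate, mirroring
        # the MCU -> DSP -> Accelerator layering described in the module
        # docstring.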

        prog_buffer = executor_backend_lowered.to_executorch()
        buffer = prog_buffer.buffer

        executorch_module = _load_for_executorch_from_buffer(buffer)

        # Now the client executor is instantiated
        inputs_flattened, _ = tree_flatten(inputs)

        # Send the inputs from the server executor to the client executor, and
        # receive the result back from the client executor
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = model(*inputs)

        # Compare the server executor's final result against the eager model
        self.assertTrue(
            torch.allclose(model_output[0], ref_output, rtol=1e-03, atol=1e-03)
        )