# Lowering a Model as a Delegate

Audience: ML engineers interested in applying delegates to accelerate their programs at runtime.

Backend delegation is an entry point for backends to process and execute PyTorch
programs, leveraging the performance and efficiency benefits of specialized
backends and hardware while still providing PyTorch users with an experience
close to that of the PyTorch runtime. A backend delegate is usually provided
either by ExecuTorch or by a vendor. Delegation is invoked in your program
through the standard entry point `to_backend`.

## Frontend Interfaces

There are three flows for delegating a program to a backend:

1. Lower the whole module to a backend. This is good for testing backends and
    the preprocessing stage.
1. Lower the whole module to a backend and compose it with another module. This
    is good for reusing lowered modules exported from other flows.
1. Lower parts of a module according to a partitioner. This is good for
    lowering models that include both lowerable and non-lowerable nodes, and is
    the most streamlined process.

### Flow 1: Lowering the whole module

This flow starts from a traced graph module in the Edge Dialect representation. To
lower it, we call the following function, which returns a `LoweredBackendModule` (more documentation on this function can be found in the [Export API reference](export-to-executorch-api-reference.rst)):
28
29```python
30# defined in backend_api.py
31def to_backend(
32    backend_id: str,
33    edge_program: ExportedProgram,
34    compile_spec: List[CompileSpec],
35) -> LoweredBackendModule:
36```

Within this function, the backend's `preprocess()` function is called, which
produces a compiled blob to be emitted into the flatbuffer binary. The lowered
module can be captured directly, or placed back into a parent module and
captured as part of it. Eventually, the captured module is serialized into the
flatbuffer model that can be loaded by the runtime.
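
For reference, the compiled blob comes from the backend's `preprocess()` implementation. Below is a minimal sketch of such an implementation, assuming the `BackendDetails` and `PreprocessResult` interfaces used by the demo backends; `MyBackend` and its blob format are hypothetical:

```python
from typing import List

from torch.export import ExportedProgram
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec

class MyBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # A real backend would compile the program here; this sketch just
        # serializes the names of the call_function nodes into a byte blob.
        op_names = ",".join(
            str(node.target)
            for node in edge_program.graph.nodes
            if node.op == "call_function"
        )
        return PreprocessResult(processed_bytes=op_names.encode("utf-8"))
```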

The following is an example of this flow:

```python
import torch
from torch.export import export
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend

# The submodule runs on a specific backend. In this example, the
# `BackendWithCompilerDemo` backend handles it.
class LowerableSubModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x)

# Convert the lowerable module to the Edge IR representation
to_be_lowered = LowerableSubModel()
example_input = (torch.ones(1), )
to_be_lowered_exir_submodule = to_edge(export(to_be_lowered, example_input))

# Import the backend implementation
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
lowered_module = to_backend('BackendWithCompilerDemo', to_be_lowered_exir_submodule.exported_program(), [])
```
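
As a sanity check, the returned `LoweredBackendModule` can be inspected before serialization. A minimal sketch, assuming the module exposes `backend_id` and `processed_bytes` attributes as the demo backend does:

```python
# Inspect the lowered module: which backend it targets and how large the
# compiled blob is (attribute names are assumptions, not guarantees).
print(lowered_module.backend_id)            # e.g. 'BackendWithCompilerDemo'
print(len(lowered_module.processed_bytes))  # size of the compiled blob in bytes
```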

We can serialize the program to a flatbuffer format by directly running:

```python
# Save the flatbuffer to a local file
save_path = "delegate.pte"
with open(save_path, "wb") as f:
    f.write(lowered_module.buffer())
```

### Flow 2: Lowering the whole module and composing it

Alternatively, after Flow 1, we can compose this lowered module with another
module:

```python
# This submodule runs in the ExecuTorch runtime
class NonLowerableSubModel(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def forward(self, a, b):
        return torch.add(torch.add(a, b), self.bias)


# The composite module, combining the lowered submodule with a non-lowerable part
class CompositeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3)
        self.lowerable = lowered_module

    def forward(self, x):
        a = self.lowerable(x)
        b = self.lowerable(a)
        ret = self.non_lowerable(a, b)
        return a, b, ret

composite_model = CompositeModel()
model_inputs = (torch.ones(1), )
exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch()

# Save the flatbuffer to a local file
save_path = "delegate.pte"
with open(save_path, "wb") as f:
    f.write(exec_prog.buffer)
```
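
To confirm that the lowered submodule shows up as delegate calls in the composite program, we can scan the final graph. A minimal sketch, assuming delegate call sites appear as the `executorch_call_delegate` higher-order op:

```python
import torch

# Each call to self.lowerable should appear as an executorch_call_delegate node.
graph = exec_prog.exported_program().graph
delegate_calls = [
    node
    for node in graph.nodes
    if node.op == "call_function"
    and node.target == torch.ops.higher_order.executorch_call_delegate
]
print(f"Found {len(delegate_calls)} delegate call(s)")  # expect 2 here
```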

### Flow 3: Partitioning

The third flow also starts from a traced graph module with Edge Dialect
representation. To lower certain nodes in this graph module, we can use the
overloaded [`to_backend`
function](https://github.com/pytorch/executorch/blob/d9eef24bb720804aa7b400b05241487510ae0dc2/exir/backend/backend_api.py#L39).

```python
def to_backend(
    edge_program: ExportedProgram,
    partitioner: Partitioner,
) -> ExportedProgram:
```

This function takes in a `Partitioner`, which tags all the nodes that are meant
to be lowered and returns a `partition_tags` dictionary mapping tags to backend
names and module compile specs. The tagged nodes will then be partitioned and
lowered to their mapped backends using Flow 1's process. Available helper
partitioners are documented
[here](./compiler-custom-compiler-passes.md). The lowered modules
will be inserted into the top-level module and serialized. A minimal sketch of
a custom partitioner is shown below.
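
The sketch assumes the `Partitioner`, `PartitionResult`, and `DelegationSpec` interfaces from `executorch.exir.backend.partitioner`, and that nodes are tagged through their `meta["delegation_tag"]` entry; the class itself is hypothetical:

```python
from typing import Dict

import torch
from torch.export import ExportedProgram
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

class AddOnlyPartitionerSketch(Partitioner):
    """Hypothetical partitioner that tags every aten.add node for one backend."""

    def __init__(self):
        super().__init__()
        self.delegation_spec = DelegationSpec("BackendWithCompilerDemo", [])

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags: Dict[str, DelegationSpec] = {}
        for node in exported_program.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                tag = "tag0"  # nodes sharing a tag are grouped into one partition
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec
        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=partition_tags,
        )
```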

The following is an example of this flow:
```python
import torch
from torch.export import export
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.program import (
    EdgeProgramManager,
    to_edge,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        x = x + y
        x = x * y
        x = x - y
        x = x / y
        x = x * y
        x = x + y
        return x

model = Model()
model_inputs = (torch.randn(1, 3), torch.randn(1, 3))

core_aten_ep = export(model, model_inputs)
edge: EdgeProgramManager = to_edge(core_aten_ep)
edge = edge.to_backend(AddMulPartitionerDemo())
exec_prog = edge.to_executorch()

# Save the flatbuffer to a local file
save_path = "delegate.pte"
with open(save_path, "wb") as f:
    f.write(exec_prog.buffer)
```
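
To verify how much of the model was delegated after partitioning, the delegation debug helper can summarize the result. A hedged sketch, assuming `get_delegation_info` is available under `executorch.devtools.backend_debug`:

```python
from executorch.devtools.backend_debug import get_delegation_info

# Summarize delegated subgraphs and the ops left in the top-level graph.
delegation_info = get_delegation_info(edge.exported_program().graph_module)
print(delegation_info.get_summary())
```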

## Runtime

Once we have a program with delegates, running the model with the backend requires the backend to be registered with the runtime.
Depending on the delegate implementation, the backend can be registered either during global variable initialization or
explicitly inside the main function.

- If it is registered during global variable initialization, the backend will be registered as long as it is statically linked. Users only need to include the library as part of the dependency.

- If the vendor provides an API to register the backend, users need to include the library as part of the dependency and call the vendor-provided API to explicitly register the backend in the main function. A quick way to smoke-test a delegated program is shown below.
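
From Python, the pybindings runtime can load and run the serialized program; statically linked backends register themselves when the extension module is loaded. A minimal sketch, assuming the portable pybindings are built and `delegate.pte` was saved from the partitioning example above:

```python
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the serialized program; statically linked backends are registered
# automatically when the pybindings extension is imported.
module = _load_for_executorch("delegate.pte")

# Run with inputs matching the partitioning example's (1, 3) shapes.
outputs = module.forward([torch.randn(1, 3), torch.randn(1, 3)])
print(outputs[0])
```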