• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
11from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
12from executorch.backends.xnnpack.test.tester import RunPasses, Tester
13from executorch.exir.dialects._ops import ops as exir_ops
14
15
class TestActivationFusion(unittest.TestCase):
    """Tests for FuseActivationPass on the XNNPACK backend.

    Each case does the following:

    1. Builds ``op -> activation``.
    2. Lowers it to the edge dialect.
    3. Runs ConvertToLinearPass followed by FuseActivationPass.
    4. Checks that the activation op no longer appears in the graph.
    5. Checks that every node of the producing op carries
       ``FuseActivationPass.FUSED_ACTIVATION_TAG`` in its ``meta``.
    """

    PassStage = RunPasses([ConvertToLinearPass, FuseActivationPass])

    # Graph names of the edge-dialect activation ops. These are passed to
    # ``Tester.check_not`` to assert that the activation was removed.
    _RELU_NAME = "executorch_exir_dialects_edge__ops_aten_relu_default"
    _HARDTANH_NAME = "executorch_exir_dialects_edge__ops_aten_hardtanh_default"

    def check_node_has_tag(self, graph_module, node_target, tag=None):
        """Return whether all ``call_function`` nodes targeting ``node_target`` carry ``tag``.

        Args:
            graph_module: The graph module whose ``graph.nodes`` are inspected.
            node_target: Op to match against ``node.target``.
            tag: Key to look for in each matching node's ``meta``. Defaults to
                ``FuseActivationPass.FUSED_ACTIVATION_TAG``.

        Returns:
            True if at least one matching node exists and every matching node
            has ``tag`` in its ``meta``. False otherwise, including when no
            node matches, so a missing op cannot make a check pass vacuously.
        """
        if tag is None:
            tag = FuseActivationPass.FUSED_ACTIVATION_TAG
        matching = [
            n
            for n in graph_module.graph.nodes
            if n.op == "call_function" and n.target == node_target
        ]
        return bool(matching) and all(tag in n.meta for n in matching)

    @staticmethod
    def _hardtanh():
        """Return a fresh Hardtanh(min_val=-1.0, max_val=1.0) for the hardtanh cases."""
        return torch.nn.Hardtanh(min_val=-1.0, max_val=1.0)

    class OpActivation(torch.nn.Module):
        """Runs ``module`` and then ``activation`` via ``torch.nn.Sequential``."""

        def __init__(self, module: torch.nn.Module, activation):
            super().__init__()
            self.seq = torch.nn.Sequential(module, activation)

        def forward(self, x):
            return self.seq(x)

    class UnaryOps(torch.nn.Module):
        """Applies ``unary_op`` with the single input passed as both operands.

        Despite the name, ``unary_op`` is called as ``unary_op(a, a)``. It must
        therefore accept two tensors, e.g. ``torch.add``, ``torch.mul`` or
        ``torch.sub``.
        """

        def __init__(self, unary_op):
            super().__init__()
            self.unary_op = unary_op

        def forward(self, a):
            return self.unary_op(a, a)

    def _test_op_activation_case(
        self,
        module,
        edge_op,
        inputs,
        quantize=False,
        activation=None,
        activation_name=_RELU_NAME,
    ):
        """Export ``module -> activation`` and assert the activation is fused into ``edge_op``.

        Args:
            module: Producing op, e.g. Conv2d, Linear, or a ``UnaryOps`` wrapper.
            edge_op: Edge-dialect op expected to absorb the activation.
            inputs: Example-inputs tuple handed to ``Tester``.
            quantize: If True, run the Tester's quantize stage before export.
            activation: Activation module. Defaults to ``torch.nn.ReLU()``.
            activation_name: Graph name of the activation op. It must not
                appear after the passes run.
        """
        # Explicit None check rather than ``or``: nn.Module truthiness can be
        # overridden (e.g. an empty Sequential is falsy).
        if activation is None:
            activation = torch.nn.ReLU()
        tester = Tester(self.OpActivation(module, activation).eval(), inputs)
        if quantize:
            tester.quantize()

        artifact = (
            tester.export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_not([activation_name])
            .get_artifact(Tester.stage_name(self.PassStage))
        )

        # Inspect the transformed graph directly. check_node_has_tag requires
        # at least one ``edge_op`` node, so a missing op fails the test instead
        # of passing vacuously.
        graph_module = artifact.exported_program().graph_module
        self.assertTrue(
            self.check_node_has_tag(
                graph_module, edge_op, FuseActivationPass.FUSED_ACTIVATION_TAG
            ),
            f"expected at least one {edge_op} node, each carrying "
            f"{FuseActivationPass.FUSED_ACTIVATION_TAG!r} in its meta",
        )

    def test_activation_fusion_conv_relu(self):
        """Conv2d followed by ReLU, in float and quantized form."""
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
        )
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_linear_relu(self):
        """Linear followed by ReLU, in float and quantized form."""
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Linear(8, 8),
            exir_ops.edge.aten.linear.default,
            inputs,
        )
        self._test_op_activation_case(
            torch.nn.Linear(8, 8),
            exir_ops.edge.aten.linear.default,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_add_relu(self):
        """Elementwise add followed by ReLU, in float and quantized form."""
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.UnaryOps(torch.add),
            exir_ops.edge.aten.add.Tensor,
            inputs,
        )
        self._test_op_activation_case(
            self.UnaryOps(torch.add),
            exir_ops.edge.aten.add.Tensor,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_mul_relu(self):
        """Elementwise mul followed by ReLU, in float and quantized form."""
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.UnaryOps(torch.mul),
            exir_ops.edge.aten.mul.Tensor,
            inputs,
        )
        self._test_op_activation_case(
            self.UnaryOps(torch.mul),
            exir_ops.edge.aten.mul.Tensor,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_sub_relu(self):
        """Elementwise sub followed by ReLU, in float and quantized form."""
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.UnaryOps(torch.sub),
            exir_ops.edge.aten.sub.Tensor,
            inputs,
        )
        self._test_op_activation_case(
            self.UnaryOps(torch.sub),
            exir_ops.edge.aten.sub.Tensor,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_conv_hardtanh(self):
        """Conv2d followed by Hardtanh(-1, 1), in float and quantized form."""
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
            activation=self._hardtanh(),
            activation_name=self._HARDTANH_NAME,
        )
        # Quantized variant, mirroring the paired float/quantized relu cases.
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
            quantize=True,
            activation=self._hardtanh(),
            activation_name=self._HARDTANH_NAME,
        )

    def test_activation_fusion_linear_hardtanh(self):
        """Linear followed by Hardtanh(-1, 1), float only."""
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Linear(8, 8),
            exir_ops.edge.aten.linear.default,
            inputs,
            activation=self._hardtanh(),
            activation_name=self._HARDTANH_NAME,
        )

    def test_activation_fusion_add_hardtanh(self):
        """Elementwise add followed by Hardtanh(-1, 1), float only."""
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.UnaryOps(torch.add),
            exir_ops.edge.aten.add.Tensor,
            inputs,
            activation=self._hardtanh(),
            activation_name=self._HARDTANH_NAME,
        )

    def test_activation_fusion_mul_hardtanh(self):
        """Elementwise mul followed by Hardtanh(-1, 1), float only."""
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.UnaryOps(torch.mul),
            exir_ops.edge.aten.mul.Tensor,
            inputs,
            activation=self._hardtanh(),
            activation_name=self._HARDTANH_NAME,
        )

    def test_activation_fusion_sub_hardtanh(self):
        """Elementwise sub followed by Hardtanh(-1, 1), float only."""
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.UnaryOps(torch.sub),
            exir_ops.edge.aten.sub.Tensor,
            inputs,
            activation=self._hardtanh(),
            activation_name=self._HARDTANH_NAME,
        )
198