# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
from executorch.exir.dialects._ops import ops as exir_ops


class TestActivationFusion(unittest.TestCase):
    PassStage = RunPasses([ConvertToLinearPass, FuseActivationPass])

    def check_node_has_tag(self, graph_module, node_target, tag) -> bool:
        """Return True if the first `node_target` call in the graph carries `tag`."""
        for n in graph_module.graph.nodes:
            if n.op == "call_function" and n.target == node_target:
                return tag in n.meta
        return False

    class OpActivation(torch.nn.Module):
        """Composes an op with an activation so the fusion pass has a pattern to match."""

        def __init__(self, module: torch.nn.Module, activation):
            super().__init__()
            self.seq = torch.nn.Sequential(module, activation)

        def forward(self, x):
            return self.seq(x)

    class BinaryOps(torch.nn.Module):
        """Applies a binary op (add/mul/sub) to a single input, i.e. op(a, a)."""

        def __init__(self, binary_op):
            super().__init__()
            self.binary_op = binary_op

        def forward(self, a):
            return self.binary_op(a, a)

    def _test_op_activation_case(
        self,
        module,
        edge_op,
        inputs,
        quantize=False,
        activation=None,
        activation_name="executorch_exir_dialects_edge__ops_aten_relu_default",
    ):
        """Lower `module` + `activation`, run the fusion passes, then check that the
        standalone activation node is gone and the producing `edge_op` is tagged."""
        activation = activation or torch.nn.ReLU()
        tester = Tester(self.OpActivation(module, activation).eval(), inputs)
        if quantize:
            tester.quantize()

        artifact = (
            tester.export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_not([activation_name])
            .get_artifact(Tester.stage_name(self.PassStage))
        )

        for node in artifact.exported_program().module().graph.nodes:
            if node.op == "call_function" and node.target == edge_op:
                self.assertIn(FuseActivationPass.FUSED_ACTIVATION_TAG, node.meta)

    def test_activation_fusion_conv_relu(self):
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
        )
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_linear_relu(self):
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Linear(8, 8),
            exir_ops.edge.aten.linear.default,
            inputs,
        )
        self._test_op_activation_case(
            torch.nn.Linear(8, 8),
            exir_ops.edge.aten.linear.default,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_add_relu(self):
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.BinaryOps(torch.add),
            exir_ops.edge.aten.add.Tensor,
            inputs,
        )
        self._test_op_activation_case(
            self.BinaryOps(torch.add),
            exir_ops.edge.aten.add.Tensor,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_mul_relu(self):
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.BinaryOps(torch.mul),
            exir_ops.edge.aten.mul.Tensor,
            inputs,
        )
        self._test_op_activation_case(
            self.BinaryOps(torch.mul),
            exir_ops.edge.aten.mul.Tensor,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_sub_relu(self):
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.BinaryOps(torch.sub),
            exir_ops.edge.aten.sub.Tensor,
            inputs,
        )
        self._test_op_activation_case(
            self.BinaryOps(torch.sub),
            exir_ops.edge.aten.sub.Tensor,
            inputs,
            quantize=True,
        )

    def test_activation_fusion_conv_hardtanh(self):
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
            activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0),
            activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        )
        self._test_op_activation_case(
            torch.nn.Conv2d(1, 1, (4, 4)),
            exir_ops.edge.aten.convolution.default,
            inputs,
            quantize=True,
            activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0),
            activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        )

    def test_activation_fusion_linear_hardtanh(self):
        inputs = (torch.randn(1, 1, 8, 8),)
        self._test_op_activation_case(
            torch.nn.Linear(8, 8),
            exir_ops.edge.aten.linear.default,
            inputs,
            activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0),
            activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        )

    def test_activation_fusion_add_hardtanh(self):
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.BinaryOps(torch.add),
            exir_ops.edge.aten.add.Tensor,
            inputs,
            activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0),
            activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        )

    def test_activation_fusion_mul_hardtanh(self):
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.BinaryOps(torch.mul),
            exir_ops.edge.aten.mul.Tensor,
            inputs,
            activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0),
            activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        )

    def test_activation_fusion_sub_hardtanh(self):
        inputs = (torch.randn(1, 1, 8, 8),)

        self._test_op_activation_case(
            self.BinaryOps(torch.sub),
            exir_ops.edge.aten.sub.Tensor,
            inputs,
            activation=torch.nn.Hardtanh(min_val=-1.0, max_val=1.0),
            activation_name="executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        )
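
# Standard unittest entry point so the suite can also be invoked directly
# (e.g. `python <this_file>.py`) rather than only through a test runner.
if __name__ == "__main__":
    unittest.main()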