# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampPass(ExportPass):
    """
    Fuse clamp-like activations (ReLU, hardtanh) into the operator that
    produces their input (currently only convolution).

    A matching ``convolution -> activation`` pair is replaced by a single
    ``et_vk.conv_with_clamp`` node whose arguments are the convolution's
    arguments followed by the clamp bounds ``(output_min, output_max)``.
    Fusion is skipped when the convolution output has any user other than
    the activation, because those users must keep seeing the unclamped value.
    """

    FUSEABLE_OPS = [
        exir_ops.edge.aten.convolution.default,
    ]
    FUSEABLE_ACTIVATIONS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
    ]

    def get_output_min_max_from_activation(self, activation_node):
        """
        Return the ``(output_min, output_max)`` clamp bounds for an activation.

        Args:
            activation_node: An FX node whose target is one of
                ``FUSEABLE_ACTIVATIONS``.

        Returns:
            Tuple ``(output_min, output_max)``. ReLU maps to
            ``(0.0, sys.float_info.max)``; hardtanh uses its ``min_val`` /
            ``max_val`` arguments (positional or keyword), defaulting to
            ``-1.0`` / ``1.0`` when omitted.

        Raises:
            ValueError: If the node's target is not a supported activation.
        """
        if activation_node.target == exir_ops.edge.aten.relu.default:
            return 0.0, sys.float_info.max

        if activation_node.target == exir_ops.edge.aten.hardtanh.default:
            args = activation_node.args
            kwargs = activation_node.kwargs
            # hardtanh(self, min_val=-1, max_val=1): each bound may be omitted
            # or passed by keyword independently of the other, so never assume
            # max_val is present just because min_val is.
            output_min = args[1] if len(args) > 1 else kwargs.get("min_val", -1.0)
            output_max = args[2] if len(args) > 2 else kwargs.get("max_val", 1.0)
            return output_min, output_max

        raise ValueError(
            f"Unsupported activation for clamp fusion: {activation_node.target}"
        )

    def _can_fuse_into(self, producer) -> bool:
        """Return True if ``producer`` is a fuseable op consumed only by the activation."""
        return (
            isinstance(producer, torch.fx.Node)
            and producer.op == "call_function"
            and producer.target in self.FUSEABLE_OPS
            # Fusing rewrites every use of the producer, so any additional
            # consumer would start receiving clamped values.
            and len(producer.users) == 1
        )

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every fuseable ``convolution -> activation`` pair in
        ``graph_module`` into a single ``et_vk.conv_with_clamp`` node, then
        retrace the module through ``ExportPass``.

        Returns:
            ``PassResult`` wrapping the retraced graph module.
        """
        graph = graph_module.graph
        # Iterate over a snapshot because nodes are erased inside the loop.
        for activation_node in list(graph.nodes):
            if (
                activation_node.op != "call_function"
                or activation_node.target not in self.FUSEABLE_ACTIVATIONS
                or not activation_node.args
            ):
                continue

            preceding_op = activation_node.args[0]
            if not self._can_fuse_into(preceding_op):
                continue

            output_min, output_max = self.get_output_min_max_from_activation(
                activation_node
            )
            # conv_with_clamp takes the convolution's positional arguments
            # followed by the two clamp bounds.
            new_args = (*preceding_op.args, output_min, output_max)

            # Delete the activation; its users now read the convolution output.
            activation_node.replace_all_uses_with(preceding_op)
            graph.erase_node(activation_node)

            # Create and insert the custom `conv_with_clamp` node in place of
            # the convolution.
            with graph.inserting_before(preceding_op):
                conv_activation_node = graph.create_node(
                    "call_function",
                    exir_ops.edge.et_vk.conv_with_clamp.default,
                    new_args,
                )
            # Carry over the convolution's metadata (e.g. stack trace); the
            # clamped output has the same shape/dtype as the unclamped one.
            conv_activation_node.meta.update(preceding_op.meta)

            preceding_op.replace_all_uses_with(conv_activation_node)
            graph.erase_node(preceding_op)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)