# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampPass(ExportPass):
    """
    Some activations, such as ReLU and hardtanh, can be fused into certain
    operators (e.g. convolution) that precede them.
    """

    FUSEABLE_OPS = [
        exir_ops.edge.aten.convolution.default,
    ]
    FUSEABLE_ACTIVATIONS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
    ]

    def get_output_min_max_from_activation(self, activation_node):
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0.0
            output_max = sys.float_info.max
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            # hardtanh clamps to [-1, 1] by default; explicit min/max
            # arguments, when present, override the defaults. Guard each
            # index separately so a call that only overrides min_val does
            # not raise an IndexError.
            output_min = -1.0
            output_max = 1.0
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
            if len(activation_node.args) > 2:
                output_max = activation_node.args[2]
        else:
            raise RuntimeError(
                f"Unsupported activation node {activation_node.target}"
            )

        return output_min, output_max

    def call(self, graph_module: torch.fx.GraphModule):
        for activation_node in graph_module.graph.nodes:
            if activation_node.op == "call_function":
                if activation_node.target in self.FUSEABLE_ACTIVATIONS:
                    preceding_op = activation_node.args[0]
                    if (
                        preceding_op.op == "call_function"
                        and preceding_op.target in self.FUSEABLE_OPS
                    ):
                        # Delete the activation node and carry its clamp
                        # range over to the fused op as two extra trailing
                        # arguments.
                        output_min_max = self.get_output_min_max_from_activation(
                            activation_node
                        )
                        new_args = list(preceding_op.args)
                        new_args.append(output_min_max[0])
                        new_args.append(output_min_max[1])
                        new_args = tuple(new_args)
                        activation_node.replace_all_uses_with(preceding_op)
                        graph_module.graph.erase_node(activation_node)

                        # Create and insert a node for the custom op
                        # `conv_with_clamp`, then swap it in for the
                        # original convolution.
                        with graph_module.graph.inserting_before(preceding_op):
                            conv_activation_node = graph_module.graph.create_node(
                                "call_function",
                                exir_ops.edge.et_vk.conv_with_clamp.default,
                                new_args,
                            )

                            preceding_op.replace_all_uses_with(conv_activation_node)
                            graph_module.graph.erase_node(preceding_op)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
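

if __name__ == "__main__":
    # A minimal usage sketch, not part of the pass itself: it shows the
    # typical flow of exporting a conv + relu model to the edge dialect and
    # running FuseClampPass over it. The ConvReLU model, input shape, and
    # export flow below are illustrative assumptions; only FuseClampPass
    # comes from this file.
    from executorch.exir import to_edge

    class ConvReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)

        def forward(self, x):
            return torch.nn.functional.relu(self.conv(x))

    ep = torch.export.export(ConvReLU(), (torch.randn(1, 3, 16, 16),))
    edge = to_edge(ep).transform([FuseClampPass()])

    # After the pass, the convolution + relu pair should appear as a single
    # et_vk.conv_with_clamp call with output_min=0.0 and
    # output_max=sys.float_info.max.
    edge.exported_program().graph_module.print_readable()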