# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampPass(ExportPass):
    """
    Fuse activations such as ReLU and hardtanh into operators (e.g.
    convolution) that precede them, replacing the pair with a single fused op.
    """
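    # For example (a sketch using edge-dialect op names), the pattern
    #
    #     conv = aten.convolution.default(x, w, b, ...)
    #     out = aten.relu.default(conv)
    #
    # is rewritten into a single fused node:
    #
    #     out = et_vk.conv_with_clamp.default(x, w, b, ..., 0.0, sys.float_info.max)
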
    # Ops whose outputs can absorb a clamp via the fused custom op.
    FUSEABLE_OPS = [
        exir_ops.edge.aten.convolution.default,
    ]
    # Activations that are equivalent to clamping the output.
    FUSEABLE_ACTIVATIONS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
    ]

    def get_output_min_max_from_activation(self, activation_node):
        # Map a fuseable activation to the (output_min, output_max) clamp
        # bounds of the fused op. Assumes the node's target is one of
        # FUSEABLE_ACTIVATIONS; otherwise the bounds would be unbound below.
        assert activation_node.target in self.FUSEABLE_ACTIVATIONS
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0.0
            output_max = sys.float_info.max
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            # hardtanh defaults to min_val=-1.0 and max_val=1.0; explicit
            # bounds, when present, appear as positional args after the input.
            output_min = -1.0
            output_max = 1.0
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
            if len(activation_node.args) > 2:
                output_max = activation_node.args[2]

        return output_min, output_max

    def call(self, graph_module: torch.fx.GraphModule):
        for activation_node in graph_module.graph.nodes:
            if activation_node.op == "call_function":
                if activation_node.target in self.FUSEABLE_ACTIVATIONS:
                    preceding_op = activation_node.args[0]
                    if (
                        preceding_op.op == "call_function"
                        and preceding_op.target in self.FUSEABLE_OPS
                    ):
                        # Compute the clamp bounds, then delete the activation
                        # node, rerouting its users to the preceding op.
                        output_min_max = self.get_output_min_max_from_activation(
                            activation_node
                        )
                        new_args = list(preceding_op.args)
                        new_args.append(output_min_max[0])
                        new_args.append(output_min_max[1])
                        new_args = tuple(new_args)
                        activation_node.replace_all_uses_with(preceding_op)
                        graph_module.graph.erase_node(activation_node)

                        # Create and insert a node for the custom op
                        # `conv_with_clamp`, whose args are the convolution's
                        # args followed by output_min and output_max, then
                        # replace the original convolution with it.
                        with graph_module.graph.inserting_before(preceding_op):
                            conv_activation_node = graph_module.graph.create_node(
                                "call_function",
                                exir_ops.edge.et_vk.conv_with_clamp.default,
                                new_args,
                            )

                            preceding_op.replace_all_uses_with(conv_activation_node)
                            graph_module.graph.erase_node(preceding_op)

        graph_module.recompile()
        # Let the parent ExportPass retrace the transformed graph.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
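
# A minimal usage sketch (illustrative, not part of the pass itself). It
# assumes `graph_module` is an FX GraphModule in the edge dialect containing
# a convolution followed by a ReLU, obtained from some export pipeline:
#
#     result = FuseClampPass()(graph_module)
#     fused = result.graph_module
#     # The conv/relu pair now appears as a single et_vk.conv_with_clamp
#     # node with output_min=0.0 and output_max=sys.float_info.max.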