• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Dict
8
9import torch
10from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
11from executorch.backends.xnnpack.operators.node_visitor import (
12    get_input_node,
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.xnnpack.operators.quant_params import QuantParams
17from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
18    XNNFullyConnected,
19    XNNGraph,
20    XNode,
21)
22
23from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
24
25
@register_node_visitor
class LinearVisitor(NodeVisitor):
    """Lowers an ``aten.linear.default`` FX node to an XNNPACK
    fully-connected (``XNNFullyConnected``) serialized node.
    """

    target = "aten.linear.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Define the input/weight/bias/output tensors for ``node`` and
        append a fully-connected XNode to ``xnn_graph``.

        Args:
            node: the ``aten.linear.default`` node being serialized.
            xnn_graph: serialized graph that receives the tensor values
                and the new XNode.
            vals_to_ids: mapping from FX nodes to their XNNPACK value ids;
                populated as a side effect of ``define_tensor``.
            debug_handle: debug identifier attached to the serialized node.
        """

        # input (arg 0)
        input_node = get_input_node(node, 0)
        input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
        self.define_tensor(
            input_node,
            xnn_graph,
            vals_to_ids,
            quant_params=input_quant_params,
        )
        input_id = vals_to_ids[input_node]

        # filter / weight (arg 1)
        weight_node = get_input_node(node, 1)
        weight_quant_params = QuantParams.from_weights(
            weight_node, self._exported_program
        )
        self.define_tensor(
            weight_node,
            xnn_graph,
            vals_to_ids,
            quant_params=weight_quant_params,
            fp32_static_weights=True,
        )
        filter_id = vals_to_ids[weight_node]

        # bias (optional arg 2); absent bias is signalled with the
        # XNN_INVALID_VALUE_ID sentinel.
        if len(node.args) > 2:
            bias_node = get_input_node(node, 2)
            bias_quant_params = QuantParams.from_bias(
                bias_node, weight_quant_params, input_quant_params
            )
            self.define_tensor(
                # Reuse bias_node fetched above rather than calling
                # get_input_node(node, 2) a second time.
                bias_node,
                xnn_graph,
                vals_to_ids,
                quant_params=bias_quant_params,
                fp32_static_weights=True,
            )
            bias_id = vals_to_ids[bias_node]
        else:
            bias_id = XNN_INVALID_VALUE_ID

        # output: the linear node itself is the output value; any fused
        # activation clamp is carried on the XNode as output_min_max.
        output_min_max = FuseActivationPass.get_fused_activation(node)
        output_quant_params = QuantParams.from_outputs(node)
        self.define_tensor(
            node,
            xnn_graph,
            vals_to_ids,
            quant_params=output_quant_params,
        )
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNFullyConnected(
                input1_id=input_id,
                filter_id=filter_id,
                bias_id=bias_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
            output_min_max=output_min_max,
        )
        xnn_graph.xnodes.append(ser_node)
106