# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack.operators.node_visitor import (
    get_input_node,
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNFullyConnected,
    XNNGraph,
    XNode,
)

from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID


@register_node_visitor
class LinearVisitor(NodeVisitor):
    target = "aten.linear.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize an aten.linear.default node as an XNNFullyConnected node."""

        # input: define the activation tensor and look up its value id
        input_node = get_input_node(node, 0)
        input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
        self.define_tensor(
            input_node,
            xnn_graph,
            vals_to_ids,
            quant_params=input_quant_params,
        )
        input_id = vals_to_ids[input_node]

        # filter: weights are serialized as static data
        weight_node = get_input_node(node, 1)
        weight_quant_params = QuantParams.from_weights(
            weight_node, self._exported_program
        )
        self.define_tensor(
            weight_node,
            xnn_graph,
            vals_to_ids,
            quant_params=weight_quant_params,
            fp32_static_weights=True,
        )
        filter_id = vals_to_ids[weight_node]

        # bias: optional third argument; XNN_INVALID_VALUE_ID signals "no bias"
        if len(node.args) > 2:
            bias_node = get_input_node(node, 2)
            bias_quant_params = QuantParams.from_bias(
                bias_node, weight_quant_params, input_quant_params
            )
            self.define_tensor(
                bias_node,
                xnn_graph,
                vals_to_ids,
                quant_params=bias_quant_params,
                fp32_static_weights=True,
            )
            bias_id = vals_to_ids[bias_node]
        else:
            bias_id = XNN_INVALID_VALUE_ID

        # output: pick up any activation fused into this node by FuseActivationPass
        output_min_max = FuseActivationPass.get_fused_activation(node)
        output_quant_params = QuantParams.from_outputs(node)
        self.define_tensor(
            node,
            xnn_graph,
            vals_to_ids,
            quant_params=output_quant_params,
        )
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNFullyConnected(
                input1_id=input_id,
                filter_id=filter_id,
                bias_id=bias_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
            output_min_max=output_min_max,
        )
        xnn_graph.xnodes.append(ser_node)