1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7from typing import Dict 8 9import torch 10from executorch.backends.xnnpack.operators.node_visitor import ( 11 NodeVisitor, 12 register_node_visitor, 13) 14from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( 15 XNNFullyConnected, 16 XNNGraph, 17 XNode, 18) 19from executorch.backends.xnnpack.utils.utils import get_input_node 20 21from executorch.backends.xnnpack.utils.xnnpack_constants import ( 22 XNN_FLAG_TRANSPOSE_WEIGHTS, 23 XNN_INVALID_VALUE_ID, 24) 25 26 27@register_node_visitor 28class MatrixMultiplyVisitor(NodeVisitor): 29 target = "aten.mm.default" 30 31 def __init__(self, *args) -> None: 32 super().__init__(*args) 33 34 def define_node( 35 self, 36 node: torch.fx.Node, 37 xnn_graph: XNNGraph, 38 vals_to_ids: Dict[torch.fx.Node, int], 39 debug_handle: int, 40 ) -> None: 41 self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) 42 43 # input 44 input_id = vals_to_ids[get_input_node(node, 0)] 45 46 # filter 47 filter_id = vals_to_ids[get_input_node(node, 1)] 48 49 # output 50 output = vals_to_ids[node] 51 52 # Matrix Multiply is handled by using linear with bias = 0. XNNPACK performs 53 # this by giving a dummy id as the bias in the fully-connected node. 54 ser_node = XNode( 55 xnode_union=XNNFullyConnected( 56 input1_id=input_id, 57 filter_id=filter_id, 58 bias_id=XNN_INVALID_VALUE_ID, # Dummy Bias id for bias = 0 59 output_id=output, 60 # We are taking from Aten::mm which holds weights as (in, out) 61 # instead of (out, in) which is what torch.nn.linear uses 62 flags=XNN_FLAG_TRANSPOSE_WEIGHTS, 63 ), 64 debug_handle=debug_handle, 65 ) 66 xnn_graph.xnodes.append(ser_node) 67