• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Dict
8
9import torch
10from executorch.backends.xnnpack.operators.node_visitor import (
11    NodeVisitor,
12    register_node_visitor,
13)
14from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15    XNNFullyConnected,
16    XNNGraph,
17    XNode,
18)
19from executorch.backends.xnnpack.utils.utils import get_input_node
20
21from executorch.backends.xnnpack.utils.xnnpack_constants import (
22    XNN_FLAG_TRANSPOSE_WEIGHTS,
23    XNN_INVALID_VALUE_ID,
24)
25
26
@register_node_visitor
class MatrixMultiplyVisitor(NodeVisitor):
    """Serialize ``aten.mm.default`` as an XNNPACK fully-connected node.

    Matrix multiply is expressed as a linear (fully-connected) operation with
    no bias: the bias slot is filled with ``XNN_INVALID_VALUE_ID``.
    """

    target = "aten.mm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Build an ``XNNFullyConnected`` for ``node`` and append it to ``xnn_graph``.

        Args:
            node: The ``aten.mm.default`` FX node. Input 0 becomes the
                fully-connected input and input 1 becomes its filter.
            xnn_graph: Serialized graph; the new ``XNode`` is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: Mapping from FX nodes to serialized value ids.
            debug_handle: Attached unchanged to the emitted ``XNode``.
        """
        # Presumably registers this node's input/output tensors in xnn_graph
        # and vals_to_ids; the lookups below rely on those ids existing.
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        lhs_value_id = vals_to_ids[get_input_node(node, 0)]
        rhs_value_id = vals_to_ids[get_input_node(node, 1)]
        result_value_id = vals_to_ids[node]

        fully_connected = XNNFullyConnected(
            input1_id=lhs_value_id,
            filter_id=rhs_value_id,
            # A dummy (invalid) id in the bias slot stands for bias = 0.
            bias_id=XNN_INVALID_VALUE_ID,
            output_id=result_value_id,
            # aten.mm's second operand is laid out (in, out), while
            # torch.nn.Linear weights are (out, in); request a transpose.
            flags=XNN_FLAG_TRANSPOSE_WEIGHTS,
        )
        xnn_graph.xnodes.append(
            XNode(xnode_union=fully_connected, debug_handle=debug_handle)
        )
67