• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#
5
6import torch
7from executorch.backends.apple.mps.operators.node_visitor import (
8    NodeVisitor,
9    register_node_visitor,
10)
11from executorch.backends.apple.mps.serialization.mps_graph_schema import (
12    MPSAddmm,
13    MPSGraph,
14    MPSMatMul,
15)
16
17
@register_node_visitor
class MatMulVisitor(NodeVisitor):
    """Visitor that serializes matrix-multiply FX nodes as ``MPSMatMul`` nodes.

    Registered for both ``aten.mm.default`` and ``aten.bmm.default``; the two
    targets share one lowering path through ``create_binary_node``.
    """

    target = ["aten.mm.default", "aten.bmm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSMatMul`` for ``node`` and append it to ``mps_graph``.

        Args:
            node: The FX node targeting ``aten.mm`` or ``aten.bmm``.
            mps_graph: The graph being serialized; the new node is appended
                to its ``mps_nodes`` list.
        """
        matmul_node = self.create_binary_node(node, mps_graph, MPSMatMul)
        mps_graph.mps_nodes.append(matmul_node)
31
32
@register_node_visitor
class AddmmVisitor(NodeVisitor):
    """Visitor that serializes ``aten.addmm.default`` nodes as ``MPSAddmm``.

    ``addmm(input, mat1, mat2, *, beta=1, alpha=1)`` computes
    ``beta * input + alpha * (mat1 @ mat2)``. The three tensor operands go
    through ``create_tertiary_node``; the optional ``beta``/``alpha`` scalars
    are copied onto the serialized node when the FX node supplies them.
    """

    target = "aten.addmm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSAddmm`` for ``node`` and append it to ``mps_graph``.

        Args:
            node: The FX node targeting ``aten.addmm.default``.
            mps_graph: The graph being serialized; the new node is appended
                to its ``mps_nodes`` list.

        ``beta`` and ``alpha`` are each set only when present on ``node``
        (positionally or as a keyword). Otherwise they keep whatever value
        ``MPSAddmm`` initializes them to.
        """
        mps_node = self.create_tertiary_node(node, mps_graph, MPSAddmm)

        # Each scalar is read independently, so a node that carries both
        # (5 positional args) keeps beta as well as alpha.
        beta = self._get_scalar_arg(node, 3, "beta")
        if beta is not None:
            mps_node.mpsnode_union.beta = beta

        alpha = self._get_scalar_arg(node, 4, "alpha")
        if alpha is not None:
            mps_node.mpsnode_union.alpha = alpha

        mps_graph.mps_nodes.append(mps_node)

    @staticmethod
    def _get_scalar_arg(node: torch.fx.Node, position: int, name: str):
        """Return the argument at ``position``, else ``node.kwargs[name]``, else None.

        ``beta``/``alpha`` are keyword-only in the aten ``addmm`` schema, so
        FX normally records them in ``node.kwargs``. The positional lookup is
        kept for graphs that pass them as extra positional args.
        """
        if len(node.args) > position:
            return node.args[position]
        return node.kwargs.get(name)
53