#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSAddmm,
    MPSGraph,
    MPSMatMul,
)


@register_node_visitor
class MatMulVisitor(NodeVisitor):
    """Serialize ``aten.mm`` / ``aten.bmm`` FX nodes as ``MPSMatMul`` graph nodes."""

    target = ["aten.mm.default", "aten.bmm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSMatMul`` node via ``create_binary_node`` and append it.

        Args:
            node: The FX node for the ``mm``/``bmm`` call.
            mps_graph: The graph being serialized; its ``mps_nodes`` list is
                extended in place.
        """
        mps_graph.mps_nodes.append(self.create_binary_node(node, mps_graph, MPSMatMul))


@register_node_visitor
class AddmmVisitor(NodeVisitor):
    """Serialize ``aten.addmm`` FX nodes as ``MPSAddmm`` graph nodes.

    The aten schema is ``addmm(self, mat1, mat2, *, beta=1, alpha=1)``. The
    optional ``beta``/``alpha`` scalars are copied onto the serialized node
    when the FX node supplies them.
    """

    target = "aten.addmm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _scalar_arg(node: torch.fx.Node, index: int, name: str):
        """Return the argument at positional *index*, else keyword *name*, else None."""
        if len(node.args) > index:
            return node.args[index]
        # beta/alpha are keyword-only in the aten schema, so a traced graph
        # may carry them in kwargs rather than args.
        return node.kwargs.get(name)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSAddmm`` node (with beta/alpha if given) and append it.

        Args:
            node: The FX node for the ``addmm`` call.
            mps_graph: The graph being serialized; its ``mps_nodes`` list is
                extended in place.
        """
        mps_node = self.create_tertiary_node(node, mps_graph, MPSAddmm)

        # The checks are cumulative. A node with both scalars must set both
        # (the previous `== 4` / `== 5` checks dropped beta when alpha was
        # present). Scalars that are not supplied keep MPSAddmm's defaults,
        # presumably 1 to match the aten schema; confirm in mps_graph_schema.
        beta = self._scalar_arg(node, 3, "beta")
        if beta is not None:
            mps_node.mpsnode_union.beta = beta
        alpha = self._scalar_arg(node, 4, "alpha")
        if alpha is not None:
            mps_node.mpsnode_union.alpha = alpha

        mps_graph.mps_nodes.append(mps_node)