#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, List, Optional

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSGraph,
    MPSMean,
)


@register_node_visitor
class MeanVisitor(NodeVisitor):
    """Serialize ``aten.mean.dim`` FX nodes into ``MPSMean`` graph nodes."""

    target = "aten.mean.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSMean`` node for ``node`` and append it to ``mps_graph``.

        Args:
            node: FX node calling ``aten.mean.dim``. Its arguments follow the
                ATen schema ``mean.dim(Tensor self, int[1]? dim,
                bool keepdim=False, *, ScalarType? dtype=None)``.
            mps_graph: Graph being serialized; the new node is appended to
                ``mps_graph.mps_nodes``.

        Raises:
            RuntimeError: if ``dim`` is ``None`` and the input node carries no
                ``meta["val"]`` from which the rank can be determined.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSMean)

        dims = self._get_dims(node)
        mps_node.mpsnode_union.num_dims = len(dims)
        mps_node.mpsnode_union.dims = dims

        keep_dims = self._get_keep_dims(node)
        # Only override the MPSMean default when keepdim was actually supplied
        # (positionally or by keyword); otherwise leave the default untouched.
        if keep_dims is not None:
            mps_node.mpsnode_union.keep_dims = keep_dims

        mps_graph.mps_nodes.append(mps_node)

    @staticmethod
    def _get_dims(node: torch.fx.Node) -> List[int]:
        """Return the reduction dimensions of ``node`` as a list of ints."""
        if len(node.args) > 1:
            dims = node.args[1]
        else:
            dims = node.kwargs.get("dim")

        if dims is None:
            # ``dim`` is optional in the ATen schema: ``None`` means "reduce over
            # every dimension", which requires the input rank. Assumes the input
            # node carries a (fake) tensor in meta["val"], as exported graphs
            # normally do — NOTE(review): confirm for all lowering paths.
            input_val = cast(torch.fx.Node, node.args[0]).meta.get("val")
            if input_val is None:
                raise RuntimeError(
                    "aten.mean.dim with dim=None requires input meta['val'] "
                    "to determine the tensor rank"
                )
            return list(range(input_val.dim()))

        if isinstance(dims, int):
            # An ``int[1]`` argument may be given as a bare int.
            return [dims]
        return list(cast(List[int], dims))

    @staticmethod
    def _get_keep_dims(node: torch.fx.Node) -> Optional[bool]:
        """Return ``keepdim`` if it was supplied, else ``None``."""
        if len(node.args) > 2:
            return cast(bool, node.args[2])
        keep_dims = node.kwargs.get("keepdim")
        return None if keep_dims is None else cast(bool, keep_dims)