• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6from typing import cast, List
7
8import torch
9from executorch.backends.apple.mps.operators.node_visitor import (
10    NodeVisitor,
11    register_node_visitor,
12)
13from executorch.backends.apple.mps.serialization.mps_graph_schema import (
14    MPSGraph,
15    MPSMean,
16)
17
18
@register_node_visitor
class MeanVisitor(NodeVisitor):
    """Lower ``aten.mean.dim`` FX nodes into serialized ``MPSMean`` nodes."""

    target = "aten.mean.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSMean`` node for ``node`` and append it to ``mps_graph``.

        The ATen schema is
        ``mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None)``,
        so ``dim`` and ``keepdim`` may arrive positionally or as FX kwargs.

        Args:
            node: The ``aten.mean.dim`` FX node being lowered.
            mps_graph: The graph under construction; the new node is appended
                to ``mps_graph.mps_nodes``.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSMean)

        dims = self._reduction_dims(node)
        mps_node.mpsnode_union.num_dims = len(dims)
        mps_node.mpsnode_union.dims = dims

        # keepdim is optional in the schema; only override the serialized
        # default when the graph actually supplies a value. Checking kwargs as
        # well prevents a keyword-passed keepdim from being silently dropped.
        if len(node.args) > 2:
            mps_node.mpsnode_union.keep_dims = node.args[2]
        elif "keepdim" in node.kwargs:
            mps_node.mpsnode_union.keep_dims = node.kwargs["keepdim"]

        mps_graph.mps_nodes.append(mps_node)

    @staticmethod
    def _reduction_dims(node: torch.fx.Node) -> List[int]:
        """Return the dimensions ``node`` reduces over, as a plain list of ints.

        Accepts ``dim`` positionally or as a kwarg, as a single int or a
        sequence, and expands ``dim=None`` (reduce over every dimension) into
        an explicit list of all input dimensions.
        """
        if len(node.args) > 1:
            dim_arg = node.args[1]
        else:
            dim_arg = node.kwargs.get("dim")

        if dim_arg is None:
            # NOTE(review): relies on the exporter having populated
            # ``meta["val"]`` (a fake tensor) on the input node, as
            # torch.export does — confirm for every lowering path.
            input_node = cast(torch.fx.Node, node.args[0])
            return list(range(input_node.meta["val"].dim()))
        if isinstance(dim_arg, int):
            # int[1] in the schema permits a bare int in place of a list.
            return [dim_arg]
        return list(cast(List[int], dim_arg))
40