• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#
5
6from typing import cast
7
8import torch
9from executorch.backends.apple.mps.operators.node_visitor import (
10    NodeVisitor,
11    register_node_visitor,
12)
13from executorch.backends.apple.mps.serialization.mps_graph_schema import (
14    MPSDataType,
15    MPSFull,
16    MPSFullLike,
17    MPSGraph,
18    MPSNode,
19)
20from executorch.backends.apple.mps.utils.mps_utils import (
21    edge_dtype_to_mps_dtype,
22    get_input_node,
23)
24
25from executorch.exir.dialects._ops import ops as exir_ops
26from executorch.exir.sym_util import eval_shape
27
28
@register_node_visitor
class ConstantOpVisitor(NodeVisitor):
    """Lower ops that materialize a constant-filled tensor to ``MPSFull``.

    Covers ``aten.full`` (shape + fill value), ``aten.empty`` (shape only)
    and ``aten.scalar_tensor`` (a single value).
    """

    target = [
        "aten.full.default",
        "aten.empty.memory_format",
        "aten.scalar_tensor.default",
    ]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSFull`` node describing ``node`` to ``mps_graph``.

        Shape is ``eval_shape(node.args[0])``, except for ``scalar_tensor``
        which is emitted with shape ``[1]``. The fill value is ``args[1]``
        for ``full``, ``0`` for ``empty`` and ``args[0]`` for
        ``scalar_tensor``; infinities are replaced by the strings
        ``"-inf"`` / ``"inf"``. The dtype is float32 unless a non-None
        ``dtype`` kwarg is supplied.

        Raises:
            AssertionError: if the node carries more than two positional args.
        """
        if len(node.args) > 2:
            raise AssertionError("Unexpected number of input parameters")

        op = node.target
        aten = exir_ops.edge.aten
        is_scalar_tensor = op == aten.scalar_tensor.default

        shape = [1] if is_scalar_tensor else eval_shape(node.args[0])

        if op == aten.full.default:
            value = cast(float, node.args[1])
        elif op == aten.empty.memory_format:
            # torch.empty leaves its contents unspecified, so zero-fill is valid.
            value = 0
        elif is_scalar_tensor:
            value = cast(float, node.args[0])

        # Infinite fill values are serialized as strings; presumably the
        # serialized graph format cannot carry raw IEEE infinities — confirm
        # against the serializer.
        for sentinel in ("-inf", "inf"):
            if value == float(sentinel):
                value = sentinel
                break

        requested = node.kwargs.get("dtype") if node.kwargs else None
        if requested is None:
            dtype = MPSDataType.mps_data_type_float32
        else:
            dtype = edge_dtype_to_mps_dtype(requested)

        output_id = self.define_tensor(node, mps_graph)
        full_op = MPSFull(
            output_id=output_id,
            shape=shape,
            fill_value=value,
            dtype=dtype,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=full_op))
80
81
@register_node_visitor
class FullLikeVisitor(NodeVisitor):
    """Lower ``aten.full_like.default`` to an ``MPSFullLike`` graph node."""

    target = "aten.full_like.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize a ``full_like(input, fill_value)`` call into ``mps_graph``.

        The output dtype follows the input tensor unless a non-None ``dtype``
        kwarg overrides it. Infinite fill values are encoded as the strings
        ``"-inf"`` / ``"inf"``, matching the encoding ``ConstantOpVisitor``
        applies to ``MPSFull`` fill values.

        Raises:
            AssertionError: if the node does not have exactly two positional
                args (``input`` and ``fill_value``).
        """
        # Validate arity up front, before create_unary_node is handed
        # mps_graph, so a malformed node is rejected without partial work.
        if len(node.args) < 2:
            raise AssertionError("Full-like op requires input & fill_value args")
        if len(node.args) >= 3:
            raise AssertionError("Unexpected number of input parameters")

        fill_value = cast(float, node.args[1])
        # Same infinity encoding as ConstantOpVisitor uses for MPSFull.
        if fill_value == float("-inf"):
            fill_value = "-inf"
        elif fill_value == float("inf"):
            fill_value = "inf"

        mps_node = self.create_unary_node(node, mps_graph, MPSFullLike)
        mps_node.mpsnode_union.fill_value = fill_value

        # Default to the input tensor's dtype; an explicit dtype kwarg wins.
        mps_node.mpsnode_union.dtype = self.get_serialized_dtype(
            get_input_node(node, 0)
        )
        requested_dtype = node.kwargs.get("dtype")
        if requested_dtype is not None:
            mps_node.mpsnode_union.dtype = edge_dtype_to_mps_dtype(requested_dtype)

        mps_graph.mps_nodes.append(mps_node)
110