• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6import torch
7from executorch.backends.apple.mps.operators.node_visitor import (
8    NodeVisitor,
9    register_node_visitor,
10)
11from executorch.backends.apple.mps.serialization.mps_graph_schema import (
12    MPSArange,
13    MPSGraph,
14    MPSNode,
15)
16from executorch.backends.apple.mps.utils.mps_utils import edge_dtype_to_mps_dtype
17
18
@register_node_visitor
class ArangeVisitor(NodeVisitor):
    """Lower ``aten.arange.start_step`` into an ``MPSArange`` graph node.

    The start/end/step scalars are converted with ``float()`` and stored
    directly on the serialized node. Bounds must therefore be constants known
    at lowering time, not values produced by other graph nodes.
    """

    target = "aten.arange.start_step"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _scalar_arg(node: torch.fx.Node, value, name: str) -> float:
        """Return ``value`` as a float, rejecting values computed in the graph.

        ``float()`` on a ``torch.fx.Node`` fails with a generic TypeError.
        Raise a TypeError that names the offending arange argument instead.
        """
        if isinstance(value, torch.fx.Node):
            raise TypeError(
                f"MPS arange lowering requires a constant '{name}' argument, "
                f"but node '{node.name}' receives it from graph node "
                f"'{value.name}'"
            )
        return float(value)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSArange`` node for ``node`` to ``mps_graph``.

        Args:
            node: The ``aten.arange.start_step`` call. ``node.args`` is
                ``(start, end[, step])``; a missing or ``None`` step means 1.0.
                An optional ``dtype`` kwarg overrides the dtype taken from
                ``node.meta["val"]``.
            mps_graph: Graph under construction; the new node is appended to
                ``mps_graph.mps_nodes``.

        Raises:
            TypeError: If start, end or step is a graph node, or is otherwise
                not convertible with ``float()``.
        """
        start = self._scalar_arg(node, node.args[0], "start")
        end = self._scalar_arg(node, node.args[1], "end")
        raw_step = node.args[2] if len(node.args) > 2 else None
        step = 1.0 if raw_step is None else self._scalar_arg(node, raw_step, "step")

        # An explicit ``dtype`` kwarg takes precedence. Otherwise fall back to
        # the dtype recorded on the traced output value.
        torch_dtype = node.kwargs.get("dtype")
        if torch_dtype is None:
            torch_dtype = node.meta["val"].dtype
        dtype = edge_dtype_to_mps_dtype(torch_dtype)

        # The id returned by define_tensor identifies this node's output
        # tensor inside the serialized graph.
        output_id = self.define_tensor(node, mps_graph)
        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSArange(
                    output_id=output_id,
                    start=start,
                    end=end,
                    step=step,
                    dtype=dtype,
                )
            )
        )
54