#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSArange,
    MPSGraph,
    MPSNode,
)
from executorch.backends.apple.mps.utils.mps_utils import edge_dtype_to_mps_dtype


@register_node_visitor
class ArangeVisitor(NodeVisitor):
    """Serialize ``aten.arange.start_step`` FX nodes as ``MPSArange`` graph nodes."""

    target = "aten.arange.start_step"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSArange`` node describing ``node`` to ``mps_graph``.

        Positional arguments of ``node`` are read as ``(start, end[, step])``.
        ``step`` falls back to ``1.0`` when it is absent or ``None``. All three
        values are coerced with ``float()`` before serialization.

        The output dtype is taken from ``node.meta["val"].dtype``. A non-None
        ``dtype`` keyword argument on the node takes precedence over it.

        Args:
            node: The FX node being lowered.
            mps_graph: Graph being built; its ``mps_nodes`` list is extended
                in place.
        """
        positional = node.args

        # step is optional in the start_step overload; None means "use default".
        raw_step = positional[2] if len(positional) > 2 else None
        range_step = 1.0 if raw_step is None else float(raw_step)

        # NOTE(review): float() coercion may lose precision for integer
        # bounds beyond 2**53 -- confirm MPSArange's field types expect floats.
        range_start = float(positional[0])
        range_end = float(positional[1])

        # The metadata dtype is always converted (even when overridden below),
        # so any error from a missing/unsupported meta dtype still surfaces.
        mps_dtype = edge_dtype_to_mps_dtype(node.meta["val"].dtype)
        requested_dtype = (node.kwargs or {}).get("dtype")
        if requested_dtype is not None:
            mps_dtype = edge_dtype_to_mps_dtype(requested_dtype)

        out_id = self.define_tensor(node, mps_graph)

        arange_payload = MPSArange(
            output_id=out_id,
            start=range_start,
            end=range_end,
            step=range_step,
            dtype=mps_dtype,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=arange_payload))