#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph
from executorch.backends.apple.mps.utils.mps_utils import get_input_node, get_scalar_val


@register_node_visitor
class GetItemVisitor(NodeVisitor):
    """Node visitor for ``getitem`` nodes.

    A ``getitem`` node does not add anything to the serialized MPS graph.
    Instead, it maps the node to an id already recorded for its input node
    in ``self.tensor_to_id``.
    """

    # NOTE(review): presumably matched against FX nodes whose target is
    # ``getitem`` (e.g. selecting one output of a multi-output op) -- the
    # matching logic lives in the registry, outside this file.
    target = "getitem"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Record the tensor id that ``node`` refers to.

        Args:
            node: The ``getitem`` FX node. Argument 0 is the node being
                indexed; argument 1 is the index, read via ``get_scalar_val``.
            mps_graph: Unused. No new MPS graph node is emitted; ``node`` only
                aliases an id that already exists.

        Raises:
            Whatever the lookups raise if the input node has no entry in
            ``self.tensor_to_id`` or the index is out of range for that entry.
        """
        # assumes the input's entry is an indexable collection of ids (one per
        # output of the source op) -- TODO confirm against the producing visitor
        source_ids = self.tensor_to_id[get_input_node(node, 0)]
        self.tensor_to_id[node] = source_ids[get_scalar_val(node, 1)]