# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import List

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


def _normalize_slice_index(index: int, dim_size: int) -> int:
    """Map a possibly-negative slice bound onto ``[0, dim_size]``.

    Follows Python slicing rules: a negative ``index`` counts back from the
    end of the dimension, and out-of-range values are clamped to the nearest
    valid bound instead of being wrapped with a modulo.
    """
    if index < 0:
        index += dim_size
    return max(0, min(index, dim_size))


@register_node_visitor
class SliceVisitor(NodeVisitor):
    """Lower ``aten.slice_copy.Tensor`` to a TOSA ``SLICE`` operator."""

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a TOSA SLICE equivalent to a one-dimensional aten slice.

        Args:
            node: The FX node being lowered (not read by this visitor).
            tosa_graph: Serializer that the SLICE operator is added to.
            inputs: ``[input, dim, start, end]``; the scalar arguments are
                read through their ``.number`` attribute.
            output: Argument whose ``name`` is used as the SLICE output.
            is_quant_node: Not read by this visitor.
        """
        # aten.slice_copy slices along one dimension at a time; the arguments
        # are the sliced dimension, the start index and the end index.
        assert len(inputs) == 4
        input_node, dim_arg, start_arg, end_arg = inputs

        # Translate and check parameters in PyTorch dim order.
        shape = input_node.shape
        rank = len(shape)
        dim = dim_arg.number
        # A negative dim must be normalized: it is compared against the
        # non-negative entries of dim_order below, so an un-normalized -1
        # would never match and the slice would silently become a full copy.
        if dim < 0:
            dim += rank
        assert 0 <= dim < rank, (
            f"slice dim {dim_arg.number} out of range for rank {rank}"
        )

        dim_size = shape[dim]
        start = _normalize_slice_index(start_arg.number, dim_size)
        end = _normalize_slice_index(end_arg.number, dim_size)
        size = end - start
        assert 0 < size <= dim_size, (
            f"empty or invalid slice [{start_arg.number}:{end_arg.number}] "
            f"on dim {dim} of size {dim_size}"
        )

        # Convert to TOSA's start/size attributes, laid out in TOSA dim order.
        # Each dim_order entry is a PyTorch dim index (it indexes `shape`).
        start_attr = [start if i == dim else 0 for i in input_node.dim_order]
        size_attr = [size if i == dim else shape[i] for i in input_node.dim_order]
        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(start_attr, size_attr)

        tosa_graph.addOperator(
            TosaOp.Op().SLICE, [input_node.name], [output.name], attr
        )