# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertSplitToSlicePass(ExportPass):
    """
    Replace a split operation with many slice operations.

    Every node whose target is one of ``split_ops`` has each of its users
    rewired to a new ``slice_copy`` node that reads the matching range of the
    split dimension directly from the split's input. The split node itself is
    left without users and removed by dead code elimination.
    """

    split_ops = (
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split_copy.Tensor,
    )
    slice = exir_ops.edge.aten.slice_copy.Tensor

    @staticmethod
    def _resolve_split_lengths(split_arg, dim_size):
        """Return the per-output lengths described by a split node's argument.

        ``split_arg`` is either a sequence of explicit lengths, which is
        returned as a list, or a single integer chunk size. An integer is
        expanded the way ``torch.split`` does it: equally sized chunks, a
        smaller trailing chunk when ``dim_size`` is not divisible, and always
        at least one chunk.

        Raises:
            ValueError: if an integer chunk size is not positive.
        """
        if not isinstance(split_arg, int):
            return list(split_arg)
        if split_arg <= 0:
            raise ValueError(
                f"Split size must be a positive integer, got {split_arg}."
            )
        # Ceiling division, clamped so an empty dimension still yields one chunk.
        num_chunks = max(-(-dim_size // split_arg), 1)
        lengths = [split_arg] * (num_chunks - 1)
        lengths.append(dim_size - split_arg * (num_chunks - 1))
        return lengths

    @staticmethod
    def _slice_bounds(split_lengths):
        """Turn consecutive lengths into parallel lists of slice starts and ends."""
        ends = []
        offset = 0
        for length in split_lengths:
            offset += length
            ends.append(offset)
        starts = [end - length for end, length in zip(ends, split_lengths)]
        return starts, ends

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every split node in ``graph_module`` into slice nodes.

        Args:
            graph_module: The graph module to transform; its graph is edited
                in place before being handed to ``ExportPass.call``.

        Returns:
            A ``PassResult`` wrapping the graph module returned by
            ``super().call``, always flagged as modified.

        Raises:
            AssertionError: if the split lengths do not sum up to the size of
                the split dimension.
            ValueError: if an integer split size is not positive.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in self.split_ops:
                continue

            # Get useful variables
            split_node = node
            input_node = split_node.all_input_nodes[0]
            # Copy the users: replacing uses below changes the live mapping.
            output_nodes = split_node.users.copy()
            _, shape, _ = extract_tensor_meta(input_node.meta)
            rank = len(shape)
            dim = split_node.args[2] if len(split_node.args) > 2 else 0
            # Map a negative dimension onto its non-negative equivalent.
            dim = (dim + rank) % rank

            # args[1] may be a list of sizes or a single chunk size.
            split_lengths = self._resolve_split_lengths(
                split_node.args[1], shape[dim]
            )

            assert (
                sum(split_lengths) == shape[dim]
            ), "Given split lengths don't sum up to the size of the dimension."

            # Convert split argument 'split_lengths' to slice arguments start and end.
            starts, ends = self._slice_bounds(split_lengths)

            # Output nodes are expected to be getitem nodes whose second
            # argument is the index of the split output they select.
            # Replace them with one slice node for each output node.
            with graph.inserting_before(split_node):
                for output_node in output_nodes:
                    index = output_node.args[1]
                    slice_node = create_node(
                        graph,
                        self.slice,
                        (input_node, dim, starts[index], ends[index]),
                    )
                    # Start from the split's metadata, then narrow "val" to
                    # the entry that belongs to this particular output.
                    slice_node.meta = split_node.meta.copy()
                    slice_node.meta["val"] = slice_node.meta["val"][index]
                    output_node.replace_all_uses_with(slice_node)

        # The split and getitem nodes are now unused and can be dropped.
        graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)