# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Dict

import torch
from executorch.backends.transforms import get_shape
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNStaticSlice,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_input_node,
    PERM_NCHW_TO_NHWC,
    PERM_NHWC_TO_NCHW,
)


@register_node_visitor
class SliceCopyVisitor(NodeVisitor):
    """Serializes ``aten.slice_copy.Tensor`` nodes as XNNPACK static slices."""

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _get_arg(node: torch.fx.Node, index: int, name: str, default: Any) -> Any:
        """Return positional arg ``index`` (or kwarg ``name``) of ``node``, else ``default``."""
        if len(node.args) > index:
            return node.args[index]
        return node.kwargs.get(name, default)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNStaticSlice`` XNode describing ``node`` to ``xnn_graph``.

        The slice is expressed as a per-dimension ``offsets`` list (zero
        everywhere except the sliced dim) plus the full output shape as
        ``sizes``. If the node carries the ``XNN_NHWC_NODE`` meta flag, shapes
        are permuted from NCHW to NHWC order and the sliced dim is remapped.

        Args:
            node: The ``slice_copy`` node to serialize.
            xnn_graph: Graph whose ``xnodes`` list receives the new node.
            vals_to_ids: Mapping from FX nodes to serialized tensor ids.
            debug_handle: Debug handle attached to the serialized node.

        Raises:
            Whatever ``check_or_raise`` raises when tensor metadata is missing
            or the slice step is not 1.
        """
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        input_node = get_input_node(node, 0)
        input_id = vals_to_ids[input_node]
        output_id = vals_to_ids[node]

        check_or_raise(
            "val" in input_node.meta,
            "Missing val in tensor metadata for input when serializing XNNStaticSlice",
        )
        input_shape = get_shape(input_node)

        check_or_raise(
            "val" in node.meta,
            "Missing val in tensor metadata for output when serializing XNNStaticSlice",
        )
        output_shape = get_shape(node)
        num_dims = len(input_shape)

        # Schema: slice_copy.Tensor(self, dim=0, start=None, end=None, step=1).
        # Trailing arguments may be omitted or passed by keyword, so fall back
        # to the schema defaults instead of indexing node.args blindly.
        dim_of_slice = cast(int, self._get_arg(node, 1, "dim", 0))
        start = self._get_arg(node, 2, "start", None)
        step = self._get_arg(node, 4, "step", 1)

        check_or_raise(
            step == 1, "XNNPACK Static Slice only supports slices with stride 1"
        )

        # Normalize a negative dim so it indexes the shape/permutation lists
        # explicitly rather than relying on Python's negative indexing.
        if dim_of_slice < 0:
            dim_of_slice += num_dims

        if "XNN_NHWC_NODE" in node.meta:
            input_shape = [input_shape[i] for i in PERM_NCHW_TO_NHWC]
            output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
            # Locate the sliced axis in the permuted shape (assumes
            # PERM_NHWC_TO_NCHW is the inverse of PERM_NCHW_TO_NHWC).
            dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]

        dim_size = input_shape[dim_of_slice]
        slice_begin_index = 0 if start is None else cast(int, start)
        if slice_begin_index < 0:
            slice_begin_index += dim_size
        # PyTorch clamps out-of-range starts into [0, dim_size]; the output
        # shape recorded in the metadata reflects that, so the offset must too.
        slice_begin_index = max(0, min(slice_begin_index, dim_size))

        offsets = [0] * num_dims
        offsets[dim_of_slice] = slice_begin_index
        sizes = list(output_shape)

        ser_node = XNode(
            xnode_union=XNNStaticSlice(
                num_dims=num_dims,
                offsets=offsets,
                sizes=sizes,
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)