• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import cast, Dict
8
9import torch
10from executorch.backends.transforms import get_shape
11from executorch.backends.xnnpack.operators.node_visitor import (
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
16    XNNGraph,
17    XNNStaticSlice,
18    XNode,
19)
20from executorch.backends.xnnpack.utils.utils import (
21    check_or_raise,
22    get_input_node,
23    PERM_NCHW_TO_NHWC,
24    PERM_NHWC_TO_NCHW,
25)
26
27
@register_node_visitor
class SliceCopyVisitor(NodeVisitor):
    """Serialize ``aten.slice_copy.Tensor`` nodes as an ``XNNStaticSlice``.

    Only unit-step slices are supported; any other step is rejected through
    ``check_or_raise`` during serialization.
    """

    target = "aten.slice_copy.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _get_arg(node: torch.fx.Node, index: int, name: str, default):
        """Return argument ``name`` of ``node``: positional first, then keyword.

        Falls back to ``default`` when the argument was omitted entirely, so
        schema defaults (e.g. ``dim=0``, ``step=1``) are honoured.
        """
        if len(node.args) > index:
            return node.args[index]
        return node.kwargs.get(name, default)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNStaticSlice`` XNode describing ``node`` to ``xnn_graph``.

        Args:
            node: The ``slice_copy`` node. Its arguments follow the ATen
                schema ``(self, dim=0, start=None, end=None, step=1)``;
                ``end`` is not read because the slice size is taken from the
                output tensor's shape metadata.
            xnn_graph: Graph whose ``xnodes`` list receives the new node.
            vals_to_ids: Mapping from fx nodes to serialized value ids.
            debug_handle: Debug handle attached to the serialized XNode.

        Raises:
            Whatever ``check_or_raise`` raises when tensor metadata ("val") is
            missing, the slice dim is out of range, or the step is not 1.
        """
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        input_node = get_input_node(node, 0)

        # Serialized value ids of the sliced input and the produced output.
        input_id = vals_to_ids[input_node]
        output_id = vals_to_ids[node]

        # Shapes come from the "val" entry of the fx node metadata.
        check_or_raise(
            "val" in input_node.meta,
            "Missing val in tensor metadata for input when serializing XNNStaticSlice",
        )
        input_shape = get_shape(input_node)

        check_or_raise(
            "val" in node.meta,
            "Missing val in tensor metadata for output when serializing XNNStaticSlice",
        )
        output_shape = get_shape(node)

        num_dims = len(input_shape)

        # Normalize a negative dim (e.g. -1) against the input rank so every
        # later index is a plain non-negative position.
        requested_dim = cast(int, self._get_arg(node, 1, "dim", 0))
        dim_of_slice = (
            requested_dim + num_dims if requested_dim < 0 else requested_dim
        )
        check_or_raise(
            0 <= dim_of_slice < num_dims,
            f"Slice dim {requested_dim} is out of range for a rank-{num_dims} "
            "input when serializing XNNStaticSlice",
        )

        # XNNStaticSlice has no step parameter; only contiguous slices map.
        # Reading via _get_arg also catches a step passed as a keyword.
        step = cast(int, self._get_arg(node, 4, "step", 1))
        check_or_raise(
            step == 1, "XNNPACK Static Slice only supports slices with stride 1"
        )

        if "XNN_NHWC_NODE" in node.meta:
            # Shapes in metadata are in NCHW order; reorder them to NHWC and
            # map the NCHW slice dim to its position in the NHWC layout.
            input_shape = [input_shape[i] for i in PERM_NCHW_TO_NHWC]
            output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
            dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]

        dim_size = input_shape[dim_of_slice]

        # ATen allows start=None (meaning 0) and negative starts; torch clamps
        # the resolved start into [0, dim_size], so mirror that here to keep
        # the XNNPACK offset valid.
        start = self._get_arg(node, 2, "start", None)
        slice_begin_index = 0 if start is None else cast(int, start)
        if slice_begin_index < 0:
            slice_begin_index += dim_size
        slice_begin_index = min(max(slice_begin_index, 0), dim_size)

        # Offsets are zero everywhere except along the sliced dim; the
        # per-dim sizes are exactly the output tensor's extents.
        offsets = [0] * num_dims
        offsets[dim_of_slice] = slice_begin_index
        sizes = list(output_shape)

        ser_node = XNode(
            xnode_union=XNNStaticSlice(
                num_dims=num_dims,
                offsets=offsets,
                sizes=sizes,
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
99