• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import cast, Dict, List
8
9import torch
10from executorch.backends.xnnpack.operators.node_visitor import (
11    get_tensor_value,
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
16    XNNGraph,
17    XNNStaticConstantPad,
18    XNode,
19)
20from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
21
22
@register_node_visitor
class StaticConstantPadVisitor(NodeVisitor):
    """Serializes an ``aten.constant_pad_nd.default`` node into an
    ``XNNStaticConstantPad`` entry of the XNNPACK graph."""

    target = "aten.constant_pad_nd.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append a serialized static-constant-pad node to ``xnn_graph``.

        Args:
            node: The FX node for ``constant_pad_nd``; ``args[0]`` is the input
                tensor, ``args[1]`` the PyTorch-style padding list, and
                ``args[2]`` (optional, schema default 0) the fill value.
            xnn_graph: Serialized graph being built; receives the new XNode.
            vals_to_ids: Maps FX nodes to their ids in ``xnn_graph``.
            debug_handle: Debug handle recorded on the serialized node.

        Raises:
            Whatever ``check_or_raise`` raises when the padding list has an
            odd number of entries.
        """
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # input
        input_id = vals_to_ids[get_input_node(node, 0)]

        # output
        output_id = vals_to_ids[node]

        # args[1] may arrive as a list or tuple; materialize as a list so the
        # zero-extension below (list concatenation) is always valid.
        all_paddings = list(cast(List[int], node.args[1]))

        check_or_raise(
            len(all_paddings) % 2 == 0,
            f"Expected even number of padding values, got {len(all_paddings)}",
        )

        # Explanation of padding as given by PyTorch vs as expected by XNNPACK:
        #
        # Let n be the number of dimensions in the input, and let k be the
        # number of dimensions which we want to pad (k <= n).
        # The list all_paddings, given by PyTorch, has length 2k and contains
        # the padding amounts for before and after each of the LAST k input
        # dimensions, but in descending order by dimension. i.e.
        # [
        #     padding before dim n - 1,
        #     padding after dim n - 1,
        #     padding before dim n - 2,
        #     padding after dim n - 2,
        #     ...
        #     padding before dim n - k,
        #     padding after dim n - k,
        # ]
        #
        # Ex. if n = 4 and k = 2, all_paddings will look like:
        # [
        #     padding before dim 3,
        #     padding after dim 3,
        #     padding before dim 2,
        #     padding after dim 2,
        # ]
        #
        # The way that XNNPACK expects padding amounts to be passed in is in
        # two lists, pre_paddings and post_paddings. pre_paddings should contain
        # n elements, which are the padding amounts for before each of the n
        # dimensions of the input in ascending order by dimensions.
        # post_paddings is the same but for the padding amounts after each
        # dimension. i.e. we want pre and post paddings to look like:
        # pre_paddings = [
        #     padding before dim 0,
        #     padding before dim 1,
        #     ...
        #     padding before dim n - 1,
        # ]
        # post_paddings = [
        #     padding after dim 0,
        #     padding after dim 1,
        #     ...
        #     padding after dim n - 1,
        # ]
        #
        # To get pre and post paddings in this form, we need to
        # a) Append 2(n - k) zeros to the end of all_paddings as the padding
        #    amounts for before and after each of the leading n - k input
        #    dimensions
        # b) Extract the even index elements of all_paddings in reverse order
        #    as pre_paddings, and same for the odd index elements as
        #    post_paddings

        # a)
        num_padding_dims = 2 * len(
            get_tensor_value(xnn_graph.xvalues[input_id]).dims
        )  # 2n
        num_zero_padding_dims = num_padding_dims - len(all_paddings)  # 2(n - k)
        all_paddings = all_paddings + (
            [0] * num_zero_padding_dims
        )  # zeros have been appended

        # b)
        # Even indices (pre-pads) and odd indices (post-pads), each taken in
        # reverse so the result is ordered ascending by input dimension.
        pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
        post_paddings = all_paddings[::-2]  # odd index elements in reverse order

        # The schema of constant_pad_nd declares `Scalar value=0`, so the fill
        # value may be omitted from node.args; default to 0.0 in that case
        # instead of raising IndexError.
        padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0

        ser_node = XNode(
            xnode_union=XNNStaticConstantPad(
                pre_paddings=pre_paddings,
                post_paddings=post_paddings,
                padding_value=padding_value,
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
131