# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List, Tuple

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    get_tensor_value,
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNStaticConstantPad,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node


def _to_xnnpack_paddings(
    all_paddings: List[int], num_dims: int
) -> Tuple[List[int], List[int]]:
    """Convert a PyTorch ``constant_pad_nd`` pad list to XNNPACK pre/post paddings.

    Let n be the rank of the input and k (k <= n) the number of trailing
    dimensions being padded. PyTorch supplies a flat list of length 2k that
    covers the LAST k dimensions in DESCENDING dimension order::

        [before dim n-1, after dim n-1,
         before dim n-2, after dim n-2,
         ...,
         before dim n-k, after dim n-k]

    e.g. for n = 4 and k = 2 the list is
    ``[before dim 3, after dim 3, before dim 2, after dim 2]``.

    XNNPACK instead expects two lists of length n, holding the amounts
    before / after each dimension in ASCENDING dimension order::

        pre_paddings  = [before dim 0, before dim 1, ..., before dim n-1]
        post_paddings = [after dim 0,  after dim 1,  ..., after dim n-1]

    The conversion is:
      a) append 2(n - k) zeros so the leading n - k dimensions get zero
         padding on both sides (the list now has length 2n), then
      b) take the even-index entries in reverse order as ``pre_paddings``
         and the odd-index entries in reverse order as ``post_paddings``.

    Args:
        all_paddings: The ``pad`` argument of ``aten.constant_pad_nd``.
        num_dims: Rank n of the tensor being padded.

    Returns:
        ``(pre_paddings, post_paddings)``, each of length ``num_dims``.

    Raises:
        Via ``check_or_raise``: if the pad list has odd length, covers more
        than ``num_dims`` dimensions, or contains a negative value (negative
        padding crops in PyTorch, which XNNPACK's static constant pad cannot
        express).
    """
    check_or_raise(
        len(all_paddings) % 2 == 0,
        f"Expected even number of padding values, got {len(all_paddings)}",
    )

    num_padding_values = 2 * num_dims  # 2n
    # Without this check a too-long list would silently produce pre/post
    # lists longer than the input rank ([0] * negative == []).
    check_or_raise(
        len(all_paddings) <= num_padding_values,
        f"Expected at most {num_padding_values} padding values for a "
        f"{num_dims}-D input, got {len(all_paddings)}",
    )
    check_or_raise(
        all(p >= 0 for p in all_paddings),
        "XNNPACK static constant pad does not support negative padding, "
        f"got {list(all_paddings)}",
    )

    # a) zero-pad the leading n - k dimensions.
    padded = list(all_paddings) + [0] * (num_padding_values - len(all_paddings))

    # b) index 2n-2 holds "before dim 0" and index 2n-1 holds "after dim 0";
    # stepping backwards by two walks the dimensions in ascending order.
    pre_paddings = padded[-2::-2]  # even indices, reversed
    post_paddings = padded[::-2]  # odd indices, reversed
    return pre_paddings, post_paddings


@register_node_visitor
class StaticConstantPadVisitor(NodeVisitor):
    """Lower ``aten.constant_pad_nd.default`` to an XNNPACK static constant pad."""

    target = "aten.constant_pad_nd.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize ``node`` as an ``XNNStaticConstantPad`` and append it to ``xnn_graph``.

        Args:
            node: The ``constant_pad_nd`` node. ``args[0]`` is the input,
                ``args[1]`` the PyTorch pad list, and the optional third
                argument (``value``) the fill value.
            xnn_graph: Graph being built; the new XNode is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: Mapping from FX nodes to XNNPACK value ids.
            debug_handle: Debug handle recorded on the serialized node.
        """
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        input_id = vals_to_ids[get_input_node(node, 0)]
        output_id = vals_to_ids[node]

        num_dims = len(get_tensor_value(xnn_graph.xvalues[input_id]).dims)
        pre_paddings, post_paddings = _to_xnnpack_paddings(
            cast(List[int], node.args[1]), num_dims
        )

        # ``value`` is optional in the ATen schema (default 0), so it may be
        # missing from positional args.
        if len(node.args) > 2:
            padding_value = node.args[2]
        else:
            padding_value = node.kwargs.get("value", 0)

        ser_node = XNode(
            xnode_union=XNNStaticConstantPad(
                pre_paddings=pre_paddings,
                post_paddings=post_paddings,
                # float() actually converts (e.g. an int Scalar); cast() would not.
                padding_value=float(padding_value),
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)