# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List

import torch
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNStaticTranspose,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_input_node,
    PERM_NCHW_TO_NHWC,
    PERM_NHWC_TO_NCHW,
)


@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Serialize ``aten.permute_copy.default`` nodes as ``XNNStaticTranspose``."""

    target = "aten.permute_copy.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNStaticTranspose`` node for ``node`` to ``xnn_graph``.

        Args:
            node: The permute node; ``node.args[1]`` holds the permutation.
            xnn_graph: Graph whose ``xnodes`` list receives the new node.
            vals_to_ids: Mapping from fx nodes to serialized tensor ids.
            debug_handle: Debug handle stored on the serialized node.

        Negative entries in the permutation are normalized to non-negative
        indices before serialization. If the node is tagged with
        ``ChannelsLastTaggedReshapePass.XNN_NHWC_NODE`` the permutation must
        have exactly four entries (enforced through ``check_or_raise``) and is
        rewritten with the NCHW/NHWC permutation tables.
        """
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        # Serialized ids of the tensor being permuted and of the result.
        input_id = vals_to_ids[get_input_node(node, 0)]
        output_id = vals_to_ids[node]

        # permute accepts negative dims (e.g. -1 for the last dim); the
        # serialized perm must hold plain indices, so wrap negatives into
        # [0, rank). Non-negative entries are left untouched.
        requested_order = cast(List[int], node.args[1])
        rank = len(requested_order)
        permute_order = [dim + rank if dim < 0 else dim for dim in requested_order]

        # change permute order if under channels last
        is_channels_last = node.meta.get(
            ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
        )
        if is_channels_last:
            check_or_raise(
                len(permute_order) == 4,
                "Internal Error: Permute was tagged in channels last but is not 4D",
            )
            # NOTE(review): presumably this re-expresses the permutation, given
            # in contiguous (NCHW) dim order, in terms of the NHWC layout of
            # both input and output -- confirm against the PERM_* definitions.
            permute_order_in_contiguous = [PERM_NHWC_TO_NCHW[i] for i in permute_order]
            permute_order = [permute_order_in_contiguous[i] for i in PERM_NCHW_TO_NHWC]

        ser_node = XNode(
            xnode_union=XNNStaticTranspose(
                input_id=input_id,
                num_dims=len(permute_order),
                perm=permute_order,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)