# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Type

from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
    ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
    Conv1dUnsqueezePass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
    ConvertToUpsampleBilinear2d,
)
from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
    TagImplicitQDqPass,
)
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

from executorch.exir.pass_base import ExportPass

from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass

from executorch.exir.program._program import _transform
from torch._export.pass_base import PassType

from torch.export import ExportedProgram


class XNNPACKPassManager:
    def __init__(
        self,
        exported_program: ExportedProgram,
        passes: Optional[List[Type[PassType]]] = None,
    ) -> None:
        """
        A helper class to run multiple XNNPACK passes on a program.

        If ``passes`` is None or empty, the default XNNPACK pass pipeline is
        used. Otherwise only the passes in the given list are run, in the
        given order.

        Args:
            exported_program: The program the passes are applied to.
            passes: Optional list of pass *classes* (not instances). Each entry
                must be a subclass of ``XNNPACKPass`` (constructed with the
                current program) or of ``ExportPass`` (constructed with no
                arguments); this is checked when ``transform`` runs.
        """
        self._exported_program = exported_program

        if not passes:
            # Default XNNPACK pipeline. Passes run in list order, and each one
            # receives the program produced by the previous pass.
            self.passes = [
                # TODO - remove this pass once we have a better support for dim_order ops lowering
                DimOrderOpsRevertPass,
                ConvertToUpsampleBilinear2d,
                ConvertToLinearPass,
                ConvertToSDPAPass,
                ConstPropPass,
                FuseBatchNormWithConvPass,
                FuseActivationPass,
                RemoveGetItemPass,
                Conv1dUnsqueezePass,
                PReLUReshapePass,
                ChannelsLastTaggedReshapePass,
                TagImplicitQDqPass,
            ]
        else:
            # Copy so that later mutation of the caller's list does not
            # silently change which passes this manager runs.
            self.passes = list(passes)

    @property
    def exported_program(self) -> ExportedProgram:
        """The program this manager was constructed with (before any passes)."""
        return self._exported_program

    def transform(self) -> ExportedProgram:
        """
        Apply every configured pass in order and return the resulting program.

        Each pass is instantiated against the program produced by the previous
        pass and applied via ``_transform``; the original ``exported_program``
        is passed in as the starting point.

        Returns:
            The ExportedProgram returned by the last ``_transform`` call.

        Raises:
            RuntimeError: If an entry of ``self.passes`` is not a subclass of
                ``XNNPACKPass`` or ``ExportPass`` (including non-class entries
                such as pass instances).
        """
        ep = self.exported_program
        for pass_ in self.passes:
            # issubclass() itself raises TypeError for non-class arguments
            # (e.g. an already-constructed pass instance), so check for a class
            # first so that such entries reach the descriptive error below.
            is_class = isinstance(pass_, type)
            if is_class and issubclass(pass_, XNNPACKPass):
                # XNNPACK passes are constructed with the *current* program.
                transform_pass = pass_(ep)
            elif is_class and issubclass(pass_, ExportPass):
                transform_pass = pass_()
            else:
                raise RuntimeError(
                    f"Expecting a subclass of XNNPACKPass or ExportPass, but got pass: {pass_} with type: {type(pass_)}"
                )
            ep = _transform(ep, transform_pass)
        return ep