• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import List, Optional, Type
8
9from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
10    ChannelsLastTaggedReshapePass,
11)
12from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
13    Conv1dUnsqueezePass,
14)
15from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
16from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
17from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
18    ConvertToUpsampleBilinear2d,
19)
20from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
21from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
22    FuseBatchNormWithConvPass,
23)
24from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
25from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
26from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
27    TagImplicitQDqPass,
28)
29from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
30
31from executorch.exir.pass_base import ExportPass
32
33from executorch.exir.passes.const_prop_pass import ConstPropPass
34from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
35
36from executorch.exir.program._program import _transform
37from torch._export.pass_base import PassType
38
39from torch.export import ExportedProgram
40
41
42class XNNPACKPassManager:
43    def __init__(
44        self,
45        exported_program: ExportedProgram,
46        passes: Optional[List[Type[PassType]]] = None,
47    ) -> None:
48        """
49        A helper class to run multiple XNNPACK passes on a program
50        If passes list is empty, all passes in XNNPACK will be run.
51        Else only run passes in the list will be run.
52        """
53        self._exported_program = exported_program
54
55        if not passes:
56            # All the XNNPACK passes
57            self.passes = [
58                # TODO - remove this pass once we have a better support for dim_order ops lowering
59                DimOrderOpsRevertPass,
60                ConvertToUpsampleBilinear2d,
61                ConvertToLinearPass,
62                ConvertToSDPAPass,
63                ConstPropPass,
64                FuseBatchNormWithConvPass,
65                FuseActivationPass,
66                RemoveGetItemPass,
67                Conv1dUnsqueezePass,
68                PReLUReshapePass,
69                ChannelsLastTaggedReshapePass,
70                TagImplicitQDqPass,
71            ]
72        else:
73            self.passes = passes
74
75    @property
76    def exported_program(self) -> ExportedProgram:
77        return self._exported_program
78
79    def transform(self) -> ExportedProgram:
80        """
81        Returns a transformed ExportedProgram
82        """
83        ep = self.exported_program
84        for pass_ in self.passes:
85            if issubclass(pass_, XNNPACKPass):
86                transform_pass = pass_(ep)
87            elif issubclass(pass_, ExportPass):
88                transform_pass = pass_()
89            else:
90                raise RuntimeError(
91                    f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}"
92                )
93            ep = _transform(ep, transform_pass)
94        return ep
95