• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import List, Optional
8
9import torch
10from executorch.backends.xnnpack.partition.config.xnnpack_config import (
11    ConfigPrecisionType,
12    XNNPartitionerConfig,
13)
14from torch.export import ExportedProgram
15
16
class QDQAffineConfigs(XNNPartitionerConfig):
    """Shared base for the affine quantize / dequantize / choose-qparams configs.

    Every node passes ``check_constraints``, and ``get_node_and_deps`` never
    claims any nodes. Per the original note, these configs exist only to
    preserve the op's decomposition.
    """

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        # No node-level constraints: accept unconditionally.
        del node, ep  # unused
        return True

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        # Intentionally contributes nothing; this config is only used to
        # keep the decomposition intact.
        del node, ep  # unused
        no_nodes: List[torch.fx.Node] = []
        return no_nodes

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        # Only the DYNAMIC_QUANT precision type is advertised.
        precisions = [ConfigPrecisionType.DYNAMIC_QUANT]
        return precisions
30
31
class QuantizeAffineConfig(QDQAffineConfigs):
    """Config whose target is the ``quantize_affine.default`` op."""

    target_name = "quantize_affine.default"

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        """Return ``torch.ops.quant.quantize_affine.default``, or ``None``.

        torchao is treated as optional, so this is a best-effort lookup.
        Any ordinary failure yields ``None``, for example torchao not being
        installed or the op not being resolvable.
        """
        try:
            # Imported only for its side effect; presumably this registers
            # the ``torch.ops.quant`` ops (TODO confirm against torchao).
            import torchao.quantization.quant_primitives  # noqa: F401

            return torch.ops.quant.quantize_affine.default
        except Exception:
            # Deliberately broad, but not a bare ``except:``. The lookup stays
            # best-effort while KeyboardInterrupt / SystemExit still propagate.
            return None
42
43
class DeQuantizeAffineConfig(QDQAffineConfigs):
    """Config whose target is the ``dequantize_affine.default`` op."""

    target_name = "dequantize_affine.default"

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        """Return ``torch.ops.quant.dequantize_affine.default``, or ``None``.

        torchao is treated as optional, so this is a best-effort lookup.
        Any ordinary failure yields ``None``, for example torchao not being
        installed or the op not being resolvable.
        """
        try:
            # Imported only for its side effect; presumably this registers
            # the ``torch.ops.quant`` ops (TODO confirm against torchao).
            import torchao.quantization.quant_primitives  # noqa: F401

            return torch.ops.quant.dequantize_affine.default
        except Exception:
            # Deliberately broad, but not a bare ``except:``. The lookup stays
            # best-effort while KeyboardInterrupt / SystemExit still propagate.
            return None
54
55
class ChooseQParamsAffineConfig(QDQAffineConfigs):
    """Config whose target is the ``choose_qparams_affine.default`` op."""

    target_name = "choose_qparams_affine.default"

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        """Return ``torch.ops.quant.choose_qparams_affine.default``, or ``None``.

        torchao is treated as optional, so this is a best-effort lookup.
        Any ordinary failure yields ``None``, for example torchao not being
        installed or the op not being resolvable.
        """
        try:
            # Imported only for its side effect; presumably this registers
            # the ``torch.ops.quant`` ops (TODO confirm against torchao).
            import torchao.quantization.quant_primitives  # noqa: F401

            return torch.ops.quant.choose_qparams_affine.default
        except Exception:
            # Deliberately broad, but not a bare ``except:``. The lookup stays
            # best-effort while KeyboardInterrupt / SystemExit still propagate.
            return None
66