# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
    XNNPartitionerConfig,
)
from torch.export import ExportedProgram


def _torchao_quant_op(op_name: str) -> Optional[torch._ops.OpOverload]:
    """Return ``torch.ops.quant.<op_name>.default``, or ``None`` if unavailable.

    Args:
        op_name: Name of the op in the ``torch.ops.quant`` namespace,
            e.g. ``"quantize_affine"``.

    Returns:
        The ``.default`` overload of the op. Returns ``None`` if torchao
        cannot be imported or the op/overload cannot be found.
    """
    try:
        # Imported only for its side effect: the lookup below expects the
        # torchao ``quant`` ops to be reachable once this module is loaded.
        import torchao.quantization.quant_primitives  # noqa: F401
    except ImportError:
        # torchao is optional; without it these configs simply have no op.
        return None
    try:
        return getattr(torch.ops.quant, op_name).default
    except AttributeError:
        # Op (or its default overload) is not registered in this torchao build.
        return None


class QDQAffineConfigs(XNNPartitionerConfig):
    """Shared base for the torchao affine quantize/dequantize/choose-qparams configs.

    These configs accept every matching node and claim no nodes for
    partitioning. Per the original comment, they exist only so the
    decomposition of these ops is preserved.
    """

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """Accept every node; no additional constraints are checked."""
        return True

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        """Return no nodes.

        Do not return anything from this because we only use this to
        preserve the decomposition.
        """
        return []

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        """Report that only dynamic quantization is supported."""
        return [ConfigPrecisionType.DYNAMIC_QUANT]


class QuantizeAffineConfig(QDQAffineConfigs):
    """Config for torchao's ``quantize_affine`` op."""

    target_name = "quantize_affine.default"

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        """Return ``torch.ops.quant.quantize_affine.default`` or ``None`` if unavailable."""
        return _torchao_quant_op("quantize_affine")


class DeQuantizeAffineConfig(QDQAffineConfigs):
    """Config for torchao's ``dequantize_affine`` op."""

    target_name = "dequantize_affine.default"

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        """Return ``torch.ops.quant.dequantize_affine.default`` or ``None`` if unavailable."""
        return _torchao_quant_op("dequantize_affine")


class ChooseQParamsAffineConfig(QDQAffineConfigs):
    """Config for torchao's ``choose_qparams_affine`` op."""

    target_name = "choose_qparams_affine.default"

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        """Return ``torch.ops.quant.choose_qparams_affine.default`` or ``None`` if unavailable."""
        return _torchao_quant_op("choose_qparams_affine")