# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Utilities for configuring PT2E (PyTorch 2 Export) quantization.

import logging
from dataclasses import dataclass
from typing import List, Optional

import torch

from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


@dataclass
class EmbeddingQuantOptions:
    is_per_channel: bool = True
    group_size: int = -1

    def __post_init__(self):
        if self.group_size != -1:
            raise RuntimeError(
                "PT2E embedding quantizer does not support groupwise at the moment."
            )


@dataclass
class DynamicQuantLinearOptions:
    is_per_channel: bool = True
    is_qc4: bool = False


@dataclass
class PT2EQuantOptions:
    quantize_embedding: Optional[EmbeddingQuantOptions] = None
    quantize_linear: Optional[DynamicQuantLinearOptions] = None

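# Example (a minimal sketch; shows constructing the options directly instead of
# parsing an option string via get_pt2e_quantization_params below):
#
#   options = PT2EQuantOptions(
#       quantize_embedding=EmbeddingQuantOptions(),
#       quantize_linear=DynamicQuantLinearOptions(is_qc4=True),
#   )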

def get_pt2e_quantization_params(
    pt2e_quantize: Optional[str] = None,
    quantization_mode: Optional[str] = None,
) -> Optional[PT2EQuantOptions]:
    if pt2e_quantize is None:
        return None
    if quantization_mode:
        raise ValueError("Cannot specify both quantization_mode and pt2e_quantize")

    quantization_options = pt2e_quantize.split(",")
    quantization_options = [option.strip() for option in quantization_options]
    # TODO: This string-based parsing is ad hoc and should be reworked before
    # release; it exists only to support quick experiments.
    quant_options = None
    if "embedding" in quantization_options:
        quant_options = quant_options or PT2EQuantOptions()
        quant_options.quantize_embedding = EmbeddingQuantOptions()
    if (
        "xnnpack_dynamic" in quantization_options
        and "xnnpack_dynamic_qc4" in quantization_options
    ):
        raise RuntimeError(
            "For dynamic linear quantization via xnnpack quantizer you can chose only qc8 or qc4 option, not both."
        )
    if (
        "xnnpack_dynamic" in quantization_options
        or "xnnpack_dynamic_qc4" in quantization_options
    ):
        quant_options = quant_options or PT2EQuantOptions()
        quant_options.quantize_linear = DynamicQuantLinearOptions()
        if "xnnpack_dynamic_qc4" in quantization_options:
            quant_options.quantize_linear.is_qc4 = True

    return quant_options

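# Example (a sketch of the parsing above; options are given as a
# comma-separated string):
#
#   opts = get_pt2e_quantization_params("embedding,xnnpack_dynamic_qc4")
#   # opts.quantize_embedding == EmbeddingQuantOptions(is_per_channel=True, group_size=-1)
#   # opts.quantize_linear == DynamicQuantLinearOptions(is_per_channel=True, is_qc4=True)
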

def get_pt2e_quantizers(
    quant_params: Optional[PT2EQuantOptions],
    so_library: Optional[str] = None,
) -> List[Quantizer]:
    """
    Get a list of quantizers from quantization params
    Args:
        quant_params: PT2E quantization options.
    Returns:
        A list of quantizers to pass into LlamaBuilder.
    """

    def check_embedding_byte_registered():
        try:
            _ = torch.ops.quantized_decomposed.embedding_byte.out
        except AttributeError:
            if so_library:
                print(f"Loading library {so_library}")
                torch.ops.load_library(so_library)
            else:
                raise RuntimeError(
                    "Need to specify shared library path to register quantized ops (and their out variants) into EXIR.\n"
                    "Follow the following steps to build the needed lib via cmake.\n"
                    'Use `python -c "import torch as _; print(_.__path__)"` to find where torch package is installed.\n'
                    "Set that as TORCH_PACKAGE_DIR.\n"
                    "Then from root executorch dir do the following:\n"
                    "rm -rf cmake-out && mkdir cmake-out && (cd cmake-out && cmake -DBUCK2=<path-to-buck2> -DCMAKE_PREFIX_PATH=$TORCH_PACKAGE_DIR -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON ..) && cmake --build . -j16\n"
                    'To find the location of the lib: find cmake-out -name "libquantized_ops_aot_lib*"\n'
                    "Then specify the said library via -s <path to libquantized_ops_aot_lib.so\n"
                )

    quantizers = []
    if quant_params is not None and quant_params.quantize_embedding is not None:
        logging.info("Apply PT2E embedding quantization.")
        check_embedding_byte_registered()
        quantizers.append(EmbeddingQuantizer())
    if quant_params is not None and quant_params.quantize_linear is not None:
        logging.info("Apply PT2E dynamic linear quantization.")
        dynamic_quantizer = XNNPACKQuantizer()
        assert quant_params.quantize_linear is not None
        if not quant_params.quantize_linear.is_per_channel:
            raise ValueError(
                "At the moment only per channel weight quantization is supported."
            )
        if quant_params.quantize_linear.is_qc4:
            operator_config_dynamic = get_symmetric_quantization_config(
                is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
            )
        else:
            operator_config_dynamic = get_symmetric_quantization_config(
                is_per_channel=True, is_dynamic=True
            )
        dynamic_quantizer.set_global(operator_config_dynamic)
        quantizers.append(dynamic_quantizer)
    return quantizers

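# Example (a sketch; the returned quantizers feed the standard PT2E flow, i.e.
# torch.ao.quantization.quantize_pt2e's prepare_pt2e/convert_pt2e on the
# exported model, though the exact downstream call site is outside this file):
#
#   quant_options = get_pt2e_quantization_params("xnnpack_dynamic")
#   quantizers = get_pt2e_quantizers(quant_options)  # [XNNPACKQuantizer]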

def get_qnn_quantizer(
    pt2e_quantize: str,
    quantization_mode: Optional[str] = None,
    is_qat: bool = False,
):
    try:
        from executorch.backends.qualcomm.quantizer.custom_annotation import (  # pyre-fixme[21]
            custom_annotate_llama_matmul_16a8w,
        )

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
        from executorch.backends.qualcomm.quantizer.quantizer import (
            QnnQuantizer,
            QuantDtype,
        )
        from torch.ao.quantization.observer import MinMaxObserver

    except ImportError:
        raise ImportError(
            "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
        )

    backend, quant_config = pt2e_quantize.split("_")
    assert (
        backend == "qnn"
    ), f"The quantization config is for backend {backend} instead of qnn."
    qnn_quantizer = QnnQuantizer()  # pyre-fixme[16]
    qnn_quantizer.set_per_channel_conv_quant(enable=True)
    qnn_quantizer.set_per_channel_linear_quant(enable=True)
    # More quantization configs are supported, including 16a4w etc.; default to 8-bit quantization.
    custom_annotations = ()
    if quant_config == "8a8w":
        quant_dtype = QuantDtype.use_8a8w  # pyre-fixme[16]
        qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat)
    elif quant_config == "16a16w":
        quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
        # Due to an error with 16a16w on QNN HTP, per-channel linear quantization must be disabled when using 16a16w.
        # TODO: enable it after the issue is fixed
        logging.warning(
            "Disable per channel quantization for linear and conv due to the error with QNN HTP 16a16w."
        )
        qnn_quantizer.set_per_channel_conv_quant(enable=False)
        qnn_quantizer.set_per_channel_linear_quant(enable=False)
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        qnn_quantizer.set_quant_config(
            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
        )
    elif quant_config == "16a4w":
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        quant_dtype = QuantDtype.use_16a4w
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        qnn_quantizer.set_quant_config(
            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
        )
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        custom_annotations = (custom_annotate_llama_matmul_16a8w,)
    else:
        raise AssertionError(
            f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
        )

    assert (
        quantization_mode is None
    ), "Currently the qnn backend only supports QnnQuantizer via the pt2e flow."
    qnn_quantizer.add_custom_quant_annotations(custom_annotations)

    return qnn_quantizer, quant_dtype

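# Example (a sketch; requires the Qualcomm backend to be installed):
#
#   qnn_quantizer, quant_dtype = get_qnn_quantizer("qnn_16a4w")
#   # Per-channel conv/linear quantization is enabled and the custom
#   # 16a8w llama matmul annotation is attached.
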

def get_coreml_quantizer(pt2e_quantize: str):
    try:
        from coremltools.optimize.torch.quantization.quantization_config import (
            LinearQuantizerConfig,
            QuantizationScheme,
        )

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.quantizer`.
        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
    except ImportError:
        raise ImportError(
            "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"
        )

    if pt2e_quantize == "coreml_8a_c8w":
        config = LinearQuantizerConfig.from_dict(
            {
                "global_config": {
                    "quantization_scheme": QuantizationScheme.affine,
                    "activation_dtype": torch.quint8,
                    "weight_dtype": torch.qint8,
                    "weight_per_channel": True,
                }
            }
        )
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`.
        quantizer = CoreMLQuantizer(config)

    elif pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w"):
        raise NotImplementedError("4-bit Core ML quantizer is still under development")

    elif pt2e_quantize == "coreml_baseline_8a_c8w":
        config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=False
        )
        quantizer = XNNPACKQuantizer().set_global(config)

    elif pt2e_quantize == "coreml_baseline_8a_c4w":
        config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=False, weight_qmin=-8, weight_qmax=7
        )
        quantizer = XNNPACKQuantizer().set_global(config)

    else:
        raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")

    return quantizer

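# Example (a sketch; requires coremltools and the Core ML backend, except for
# the "coreml_baseline_*" specs, which fall back to XNNPACKQuantizer as
# implemented above):
#
#   quantizer = get_coreml_quantizer("coreml_8a_c8w")
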

def get_vulkan_quantizer(pt2e_quantize: str):
    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
        get_weight_quantization_config,
        VulkanQuantizer,
    )

    if pt2e_quantize == "vulkan_8w":
        config = get_weight_quantization_config(
            is_per_channel=True,
            weight_qmin=-128,
            weight_qmax=127,
        )
    else:
        raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")

    quantizer = VulkanQuantizer().set_global(config)
    return quantizer

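# Example (a sketch; "vulkan_8w" is the only spec handled above):
#
#   quantizer = get_vulkan_quantizer("vulkan_8w")
#   # Weight-only 8-bit per-channel quantization (weight_qmin=-128, weight_qmax=127).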