from typing import Any, Callable

import torch


def setup_baseline():
    from torchao.quantization.utils import recommended_inductor_config_setter

    recommended_inductor_config_setter()
    torch._dynamo.config.automatic_dynamic_shapes = False
    torch._dynamo.config.cache_size_limit = 10000


def torchao_optimize_ctx(quantization: str):
    from torchao.quantization.quant_api import (
        autoquant,
        int4_weight_only,
        int8_dynamic_activation_int8_weight,
        int8_weight_only,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    def inner(model_iter_fn: Callable):
        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
            # Quantize the module once, on the first call; subsequent calls only
            # run the supplied iteration function.
            if getattr(module, "_quantized", None) is None:
                if quantization == "int8dynamic":
                    quantize_(
                        module,
                        int8_dynamic_activation_int8_weight(),
                        set_inductor_config=False,
                    )
                elif quantization == "int8weightonly":
                    quantize_(module, int8_weight_only(), set_inductor_config=False)
                elif quantization == "int4weightonly":
                    quantize_(module, int4_weight_only(), set_inductor_config=False)

                if quantization == "autoquant":
                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
                    # Run one forward pass so autoquant can profile the layers.
                    if isinstance(example_inputs, dict):
                        module(**example_inputs)
                    else:
                        module(*example_inputs)
                    from torchao.quantization.autoquant import AUTOQUANT_CACHE

                    if len(AUTOQUANT_CACHE) == 0:
                        raise Exception(  # noqa: TRY002
                            "NotAutoquantizable"
                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                        )
                else:
                    unwrap_tensor_subclass(module)
                setattr(module, "_quantized", True)  # noqa: B010
            model_iter_fn(module, example_inputs)

        return _torchao_apply

    return inner
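

# Usage sketch (illustrative only, not part of the benchmark harness): the toy
# model, inputs, and `run_forward` helper below are hypothetical stand-ins, and
# the example assumes torchao is installed. `torchao_optimize_ctx` quantizes the
# module on the first call and then delegates to the supplied iteration function.
if __name__ == "__main__":

    def run_forward(mod: torch.nn.Module, inputs: Any) -> None:
        mod(*inputs)

    setup_baseline()
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(8, 64),)

    # Wrap the iteration function for int8 weight-only quantization.
    apply_fn = torchao_optimize_ctx("int8weightonly")(run_forward)
    apply_fn(model, example_inputs)  # quantizes on the first call, then runs forward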