# pyre-unsafe

import logging
from typing import Any, Callable, Dict, Optional, Type

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch
import torch.nn.functional as F

from torchao.quantization.GPTQ import _check_linear_int4_k
from torchao.quantization.unified import Quantizer
from torchao.quantization.utils import groupwise_affine_quantize_tensor


# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
# changes at the annotated lines.
class VkWeightOnlyInt4Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        # TODO: remove dtype field, not used
        bias=False,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        if self.padding:
            from torchao.utils import find_multiple

            self.origin_in_features = in_features
            in_features = find_multiple(in_features, (1024,))

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.device = device
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.precision = precision
        self.scales_precision = scales_precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"
        # In the original implementation, the weight buffer is registered with the packed
        # sizes, i.e. the result of calling the _convert_weight_to_int4pack operator.
        # However, the Vulkan implementation does not expect the weights to be packed,
        # therefore the weight tensor is registered with the unpacked sizes instead.
        # Note that in_features is divided by 2 because each `uint8` tensor element
        # contains 2 4-bit packed values.
        self.register_buffer(
            "weight",
            torch.empty(
                (out_features, in_features // 2),
                dtype=torch.uint8,
                device=device,
            ),
        )
        self.dtype = dtype
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2),
                dtype=self.scales_precision,
                device=device,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.padding:
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        # The forward method is replaced. In the original implementation, the forward
        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
        # operator is called instead.
        return torch.ops.et_vk.linear_weight_int4(
            input,
            self.weight,
            self.groupsize,
            self.scales_and_zeros,
            self.inner_k_tiles,
        )
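
# A minimal usage sketch (illustrative only, not part of the original module):
# it shows the buffer layout VkWeightOnlyInt4Linear registers, assuming
# in_features=256, out_features=32, groupsize=64, and inner_k_tiles=8, which are
# hypothetical values chosen to satisfy the constructor's divisibility checks.
# The forward pass itself dispatches to the torch.ops.et_vk.linear_weight_int4
# custom op registered by custom_ops_lib.
def _example_vk_int4_linear_buffers() -> None:
    layer = VkWeightOnlyInt4Linear(
        in_features=256,
        out_features=32,
        groupsize=64,
        inner_k_tiles=8,
        precision=torch.float32,
        scales_precision=torch.float32,
    )
    # Weights are stored unpacked, two 4-bit values per uint8 element, so the
    # second dimension is in_features // 2.
    assert layer.weight.shape == (32, 256 // 2)
    assert layer.weight.dtype == torch.uint8
    # One (scale, zero point) pair per group per output channel.
    assert layer.scales_and_zeros.shape == (256 // 64, 32, 2)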

# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
# with small changes at the annotated locations.
def _vk_replace_linear_int4(
    module: torch.nn.Module,
    groupsize: int,
    inner_k_tiles: Optional[int],
    padding_allowed: bool,
    skip_layer_func: Optional[Callable] = None,
    precision: torch.dtype = torch.bfloat16,
    scales_precision: torch.dtype = torch.bfloat16,
    # Use custom vulkan linear layer as default
    linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
    copy_weights: bool = False,
    # Serves the same purpose as `tensor_dim_limit` in
    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
    feature_limit: int = 16384,
):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear) and (
            skip_layer_func is None or not skip_layer_func(child.weight)
        ):
            # Add an additional condition that the out/in features must not exceed the
            # `feature_limit` argument.
            if (
                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                or padding_allowed
            ) and (
                child.out_features < feature_limit
                and child.in_features < feature_limit
            ):
                new_linear = linear_class(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    device=child.weight.device,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
                    precision=precision,
                    scales_precision=scales_precision,
                )
                if copy_weights and child.weight.device != torch.device("meta"):
                    # pyre-fixme[16]: `Module` has no attribute `weight`.
                    new_linear.weight = child.weight
                setattr(module, name, new_linear)
        else:
            _vk_replace_linear_int4(
                child,
                groupsize,
                inner_k_tiles,
                padding_allowed,
                skip_layer_func,
                precision,
                scales_precision,
                linear_class,
                copy_weights,
            )


# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer
# with some changes at the annotated lines.
class VkInt4WeightOnlyQuantizer(Quantizer):
    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = True,
        inner_k_tiles: Optional[int] = 8,
        device: torch.device = torch.device("cpu"),  # noqa
        precision: torch.dtype = torch.float32,
        feature_limit: int = 16384,
    ) -> None:
        super().__init__()
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]

        self.inner_k_tiles = inner_k_tiles
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.device: torch.device = device
        self.precision: torch.dtype = precision
        # Serves the same purpose as `tensor_dim_limit` in
        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
        self.feature_limit = feature_limit

    @torch.no_grad()
    def _create_quantized_state_dict(
        self, model: torch.nn.Module
    ) -> Dict[str, torch.Tensor]:
        cur_state_dict = model.state_dict()
        for fqn, mod in model.named_modules():
            # Add an additional check to make sure the features do not exceed the
            # feature limit
            if isinstance(mod, torch.nn.Linear) and (
                mod.out_features < self.feature_limit
                and mod.in_features < self.feature_limit
            ):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")

                assert (
                    in_features % self.groupsize == 0
                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding_allowed:
                        import torch.nn.functional as F

                        from torchao.utils import find_multiple

                        logging.warning(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, (1024,))
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        logging.warning(
                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                        )
                        continue
                (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor(
                    weight,
                    4,  # n_bit
                    self.groupsize,
                    self.precision,  # dtype for scales_and_zeros
                )
                # In the original implementation, w_int4x8 is packed via calling the
                # _convert_weight_to_int4pack operator before storing the weight. However,
                # the Vulkan implementation does not expect the weights to be packed, so
                # the w_int4x8 tensor is stored as the weight instead.
                cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device)
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                    self.device
                )

        return cur_state_dict

    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
        _vk_replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            self.padding_allowed,
            skip_layer_func=None,
            precision=self.precision,
            scales_precision=self.precision,
        )
        return model

    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        state_dict = self._create_quantized_state_dict(model)
        model = self._convert_for_runtime(model)
        model.load_state_dict(state_dict, strict=False)
        return model
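

# A minimal end-to-end sketch (illustrative only, not part of the original
# module), assuming a hypothetical two-layer model whose in/out features are
# divisible by groupsize and inner_k_tiles * 16. quantize() builds the
# groupwise-quantized state dict, swaps each eligible torch.nn.Linear for a
# VkWeightOnlyInt4Linear, and loads the quantized weights into the new modules.
def _example_quantize_toy_model() -> torch.nn.Module:
    model = torch.nn.Sequential(
        torch.nn.Linear(256, 512, bias=False),
        torch.nn.ReLU(),
        torch.nn.Linear(512, 256, bias=False),
    )
    quantizer = VkInt4WeightOnlyQuantizer(groupsize=64, inner_k_tiles=8)
    quantized = quantizer.quantize(model)
    # Both linear layers now call the Vulkan int4 custom op in forward().
    assert isinstance(quantized[0], VkWeightOnlyInt4Linear)
    assert isinstance(quantized[2], VkWeightOnlyInt4Linear)
    return quantized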