# mypy: allow-untyped-defs
from typing import List

import torch
from torch.nn.parameter import Parameter

__all__: List[str] = []


class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
    r"""Generalized extension of the FakeQuantize module in fake_quantize.py.

    This is an extension of the FakeQuantize module in fake_quantize.py that
    supports more generalized lower-bit quantization and learning of the scale
    and zero point parameters through backpropagation.

    In addition to the attributes in the original FakeQuantize module, the
    _LearnableFakeQuantize module also includes the following attributes to
    support quantization parameter learning.

    * :attr:`channel_len` defines the length of the channel when initializing
      scale and zero point for the per-channel case.

    * :attr:`use_grad_scaling` defines the flag for whether the gradients for
      scale and zero point are normalized by a constant proportional to the
      square root of the number of elements in the tensor. The literature
      justifying the use of this particular constant can be found here:
      https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr:`fake_quant_enabled` defines the flag for enabling fake
      quantization on the output.

    * :attr:`static_enabled` defines the flag for using the observer's static
      estimation of scale and zero point.

    * :attr:`learning_enabled` defines the flag for enabling backpropagation
      for scale and zero point.
    """

    def __init__(
        self,
        observer,
        quant_min=0,
        quant_max=255,
        scale=1.0,
        zero_point=0.0,
        channel_len=-1,
        use_grad_scaling=False,
        **observer_kwargs,
    ):
        super().__init__()
        assert quant_min < quant_max, "quant_min must be strictly less than quant_max."
        self.quant_min = quant_min
        self.quant_max = quant_max
        # also pass quant_min and quant_max to observer
        observer_kwargs["quant_min"] = quant_min
        observer_kwargs["quant_max"] = quant_max
        self.use_grad_scaling = use_grad_scaling
        if channel_len == -1:
            # Per-tensor case: a single learnable scale and zero point.
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert (
                isinstance(channel_len, int) and channel_len > 0
            ), "Channel size must be a positive integer."
            # Per-channel case: one learnable scale and zero point per channel.
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert (
            torch.iinfo(self.activation_post_process.dtype).min <= quant_min
        ), "quant_min out of bound"
        assert (
            quant_max <= torch.iinfo(self.activation_post_process.dtype).max
        ), "quant_max out of bound"
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = (
            self.activation_post_process.ch_axis
            if hasattr(self.activation_post_process, "ch_axis")
            else -1
        )
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("static_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8))

        # Derive the bit width from the number of representable values.
        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enable parameter learning over static observer estimates.

        Enables learning of quantization parameters and disables static
        observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True).toggle_fake_quant(
            enabled=True
        ).toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        """Enable static estimates of quantization parameters.

        Enables static observer estimates and disables learning of
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
            enabled=True
        ).toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        """Enable accumulation of data without updating quantization parameters.

        Enables the static observer to accumulate data from the input without
        updating the quantization parameters. Forward path returns the
        original X.
        """
        self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
            enabled=False
        ).toggle_observer_update(enabled=True)

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)  # type: ignore[operator]
        return self

    @torch.jit.export
    def enable_observer(self, enabled=True):
        self.toggle_observer_update(enabled)

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)  # type: ignore[operator]
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print(f"_LearnableFakeQuantize Scale: {self.scale.detach()}")
        print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}")

    @torch.jit.export
    def calculate_qparams(self):
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]
        scale = self.scale.detach()
        zero_point = (
            self.zero_point.detach()
            .round()
            .clamp(self.quant_min, self.quant_max)
            .long()
        )
        return scale, zero_point

    def forward(self, X):
        if self.static_enabled[0] == 1:  # type: ignore[index]
            # Static estimation: let the observer see the input and copy its
            # scale/zero_point estimates into the learnable parameters.
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale = _scale.to(self.scale.device)
            _zero_point = _zero_point.to(self.zero_point.device)
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)
        else:
            # Learning mode: keep the scale strictly positive.
            self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]

        if self.fake_quant_enabled[0] == 1:
            # Symmetric schemes fix the zero point at 0.
            if self.qscheme in (
                torch.per_channel_symmetric,
                torch.per_tensor_symmetric,
            ):
                self.zero_point.data.zero_()

            if self.use_grad_scaling:
                # Normalize scale/zero_point gradients by sqrt(numel * quant_max),
                # following https://openreview.net/pdf?id=rkgO66VKDS.
                grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
            else:
                grad_factor = 1.0
            if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine):
                X = torch._fake_quantize_learnable_per_channel_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.ch_axis,
                    self.quant_min,
                    self.quant_max,
                    grad_factor,
                )
            else:
                X = torch._fake_quantize_learnable_per_tensor_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.quant_min,
                    self.quant_max,
                    grad_factor,
                )
        return X