# mypy: allow-untyped-defs
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""

import itertools

import matplotlib.pyplot as plt

import torch
from torch.ao.quantization.experimental.apot_utils import apot_to_float, float_to_apot
from torch.ao.quantization.observer import ObserverBase


# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented


class APoTObserver(ObserverBase):
    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(self, b, k, dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k
        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    def calculate_qparams(self, signed):
        # pass through the min_val and max_val observed by forward;
        # _calculate_qparams also accepts overrides for both
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that can override the min_val internal attribute
            max_val: optional arg that can override the max_val internal attribute

        Returns:
            alpha: quantization parameter, the max of the absolute values observed
            gamma: quantization parameter, defined to ensure that alpha is the maximum of the range
            quantization_levels: nonuniform quantization levels (fp representation)
            level_indices: int representation of quantization_levels indices
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        assert self.k != 0
        assert self.b % self.k == 0

        # compute n (the number of additive terms) and store as member variable
        self.n = self.b // self.k

        # store a list of subtensors (all levels)
        p_all = []

        # create levels: each term contributes zero plus 2**k - 1 powers of two
        for i in range(self.n):
            p_curr = torch.tensor([0])

            for j in range(2**self.k - 1):
                curr_ele = 2 ** (-(i + j * self.n))
                p_append = torch.tensor([curr_ele])
                p_curr = torch.cat((p_curr, p_append))
                # introduce signed numbers
                if signed:
                    p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))

            if signed:
                # sort tensor in descending order before adding to list
                p_sorted, _ = torch.sort(p_curr, descending=True)
                p_all.append(p_sorted)
            else:
                p_all.append(p_curr)

        # gamma calculation:
        # loop through all tensors and accumulate the largest element of each
        # (index 0 if signed and sorted descending, index 1 otherwise);
        # gamma is defined to ensure alpha is at the max of the range
        p_sum = 0.0
        for tens in p_all:
            if signed:
                p_sum += float(tens[0])
            else:
                p_sum += float(tens[1])

        # assign gamma
        gamma = alpha / p_sum

        # calculate cartesian product of the per-term levels
        cartesian_product = list(itertools.product(*p_all))

        quantization_levels_list = []

        # calculate the sum of each row
        for row in cartesian_product:
            row_sum = 0.0
            for ele in row:
                row_sum += ele
            quantization_levels_list.append(row_sum)

        quantization_levels_gamma = [
            float(gamma) * ele for ele in quantization_levels_list
        ]
        quantization_levels = torch.tensor(quantization_levels_gamma)
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)
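
    # Worked example of the level construction above (a sketch, assuming
    # b=4, k=2, signed=False): n = b // k = 2 additive terms, each with
    # 2**k - 1 = 3 nonzero powers of two plus zero:
    #   term i=0: {0, 2**0, 2**-2, 2**-4} = {0, 1, 0.25, 0.0625}
    #   term i=1: {0, 2**-1, 2**-3, 2**-5} = {0, 0.5, 0.125, 0.03125}
    # p_sum = 1 + 0.5 = 1.5, so gamma = alpha / 1.5; the cartesian product
    # yields 4 * 4 = 16 levels, and the largest of them, gamma * (1 + 0.5),
    # equals alpha as intended.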

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``.

        Args:
            x_orig: Tensor to be observed for min and max val
        """
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val, max_val = torch.aminmax(x)
        # fold the new extrema into any previously observed values
        if self.min_val.numel():
            min_val = torch.min(min_val, self.min_val)
        if self.max_val.numel():
            max_val = torch.max(max_val, self.max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig

    def quant_levels_visualization(self, signed=False):
        r"""Displays a visualization of the APoT quantization levels.

        Args:
            signed: bool to indicate whether qparams should be signed/unsigned
        """
        alpha, gamma, quantization_levels, level_indices = self.calculate_qparams(
            signed
        )

        xs = [float(x) / 1000.0 for x in range(1000)]
        ys = [
            apot_to_float(
                float_to_apot(x, quantization_levels, level_indices, alpha),
                quantization_levels,
                level_indices,
            ).item()
            for x in xs
        ]

        plt.figure(figsize=(15, 10))
        plt.plot(xs, ys)
        plt.title("APoT Quantization Plot")
        plt.xlabel("Full Precision")
        plt.ylabel("Quantized")
        plt.show()
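

# Minimal usage sketch (illustrative only, not part of the observer API):
# observe a tensor, then derive the APoT qparams from the recorded range.
# The choices b=4, k=2 are example assumptions (a 4-bit budget split into
# 2 additive power-of-two terms), not values mandated by this module.
if __name__ == "__main__":
    observer = APoTObserver(b=4, k=2)
    x = torch.randn(64)
    observer(x)  # forward() records the running min/max of x
    alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(
        signed=True
    )
    print(f"alpha={alpha.item():.4f}, gamma={gamma.item():.6f}")
    print(f"number of quantization levels: {quantization_levels.numel()}")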