# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization utils."""

import numpy as np
from mindspore._checkparam import Validator
from ... import nn

__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"]


def cal_quantization_params(input_min,
                            input_max,
                            quant_min,
                            quant_max,
                            data_type,
                            symmetric=False):
    r"""
    Calculate the quantization parameters, scale and zero point.

    Args:
        input_min (numpy.ndarray): Minimum of the data, with shape (channel,) or (1,).
        input_max (numpy.ndarray): Maximum of the data, with shape (channel,) or (1,).
        quant_min (int): The minimum quantization integer.
        quant_max (int): The maximum quantization integer.
        data_type (numpy type): Can be numpy.int8 or numpy.uint8.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.

    Returns:
        scale (numpy.ndarray): quantization parameter.
        zero point (numpy.ndarray): quantization parameter.
    """
    input_max = np.maximum(0.0, input_max)
    input_min = np.minimum(0.0, input_min)

    if input_min.shape != input_max.shape:
        raise ValueError("The shape of input_min should be equal to the shape of input_max.")
    if len(input_min.shape) > 1:
        raise ValueError("input_min and input_max should be one-dimensional.")
    if (input_min > input_max).all():
        raise ValueError("input_min should be less than input_max.")
    if (input_max == input_min).all():
        return np.ones(input_min.shape), np.zeros(input_min.shape)

    # calculate scale
    if symmetric:
        input_max = np.maximum(-input_min, input_max)
        input_min = -input_max
    scale = (input_max - input_min) / (quant_max - quant_min)

    # calculate zero point
    if data_type == np.int8 and symmetric:
        zp = np.zeros(input_min.shape)
    else:
        zp_double = quant_min - input_min / scale
        zp = np.floor(zp_double + 0.5)

    return scale, zp


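# A minimal, runnable sketch (an editorial addition, not part of the original
# module) showing `cal_quantization_params` on an asymmetric uint8 range; the
# sample min/max values below are illustrative assumptions.
def _example_cal_quantization_params():
    """Illustrative only: quantize the float range [-1.0, 3.0] to uint8."""
    input_min = np.array([-1.0])
    input_max = np.array([3.0])
    # full uint8 integer range: [0, 255]
    scale, zp = cal_quantization_params(input_min, input_max,
                                        quant_min=0, quant_max=255,
                                        data_type=np.uint8)
    # scale = (3.0 - (-1.0)) / 255 ~= 0.0157, zero point = round(0 + 1.0/scale) = 64
    return scale, zp

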
def get_quant_min_max(data_type, num_bits=8, narrow_range=False):
    """Calculate the minimum/maximum quantization integers for the given data type."""
    if data_type == np.int8:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    elif data_type == np.uint8:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    else:
        raise ValueError("Unsupported datatype({})".format(data_type))
    if narrow_range:
        quant_min = quant_min + 1
    return quant_min, quant_max


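# An illustrative check (editorial addition) of the integer ranges returned by
# `get_quant_min_max` for the two supported data types:
def _example_get_quant_min_max():
    """Illustrative only: int8 is [-128, 127], or [-127, 127] when narrow."""
    assert get_quant_min_max(np.int8) == (-128, 127)
    assert get_quant_min_max(np.int8, narrow_range=True) == (-127, 127)
    assert get_quant_min_max(np.uint8) == (0, 255)

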
def weight2int(data, scale, zero_point, quant_min, quant_max):
    r"""
    Calculate int8/uint8 weight from fp32. The formula is defined as:

    .. math::
        w_{int} = round(w_{float} / scale) + zero\_point

    Args:
        data (numpy.ndarray): The weight data, in NCHW layout.
        scale (numpy.ndarray): The scale, with shape (channel,) or (1,).
        zero_point (numpy.ndarray): The zero point, with shape (channel,) or (1,).
        quant_min (int): The minimum quantization integer.
        quant_max (int): The maximum quantization integer.

    Returns:
        weight (numpy.ndarray): The quantized weight, clipped to [quant_min, quant_max].
    """
    if scale.shape != zero_point.shape:
        raise ValueError("`scale` and `zero_point` should have the same shape.")
    if scale.shape[0] == 0:
        raise ValueError("The first dimension of `scale` and `zero_point` should be greater than zero.")
    if len(scale.shape) >= 1 and scale.shape[0] > 1:
        # for perchannel
        if scale.shape[0] == data.shape[0]:
            # `Conv2d` or `Dense` op weight
            shape_list = [-1] + [1] * len(data.shape[1:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        elif scale.shape[0] == data.shape[1]:
            # `DepthwiseConv2d` op weight
            shape_list = [1, -1] + [1] * len(data.shape[2:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        else:
            raise ValueError("Unsupported weight shape({})".format(data.shape))

    weight_int = np.round((data / scale) + zero_point)
    weight_int[weight_int > quant_max] = quant_max
    weight_int[weight_int < quant_min] = quant_min
    return weight_int


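# A runnable sketch (editorial addition) of per-channel weight quantization
# with `weight2int`; the random NCHW weight and its ranges are illustrative.
def _example_weight2int():
    """Illustrative only: quantize a 4-channel conv weight to int8."""
    rng = np.random.default_rng(0)
    weight = rng.standard_normal((4, 3, 3, 3)).astype(np.float32)  # NCHW
    # per-channel min/max over the non-channel axes
    w_min = weight.min(axis=(1, 2, 3))
    w_max = weight.max(axis=(1, 2, 3))
    quant_min, quant_max = get_quant_min_max(np.int8)
    scale, zp = cal_quantization_params(w_min, w_max, quant_min, quant_max,
                                        np.int8, symmetric=True)
    return weight2int(weight, scale, zp, quant_min, quant_max)

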
def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
    """Calculate the scale, zero point, max and min quantization params from a `FakeQuantWithMinMaxObserver` cell."""
    minq = cell.minq.data.asnumpy()
    maxq = cell.maxq.data.asnumpy()
    # make sure maxq > 0 and minq <= 0
    if cell.mode == 'LEARNED_SCALE':
        maxq = np.abs(maxq)
        minq = -np.abs(minq)
    quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range)
    symmetric = cell.symmetric and not cell.neg_trunc
    scale, zp = cal_quantization_params(
        minq, maxq,
        quant_min, quant_max, data_type,
        symmetric=symmetric)
    return scale, zp, maxq, minq


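# A hypothetical usage sketch (editorial addition); it needs a concrete
# `FakeQuantWithMinMaxObserver` cell from a quantized network to run, so it is
# shown as comments, and the attribute path is an assumption:
#
#     fq_cell = net_qat.conv1.fake_quant_weight
#     scale, zp, maxq, minq = scale_zp_max_min_from_fake_quant_cell(fq_cell, np.int8)

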
def fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBnFoldQuant` into the weight.

    Calculated from `FakeQuantWithMinMax`'s Parameters or the fake quant primitive.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.
    """
    variance = cell_quant.moving_variance.data.asnumpy()
    mean = cell_quant.moving_mean.data.asnumpy()
    gamma = cell_quant.gamma.data.asnumpy()
    beta = cell_quant.beta.data.asnumpy()
    epsilon = cell_quant.eps
    sigma = np.sqrt(variance + epsilon)

    if gamma.shape[0] == weight.shape[0]:
        # `Conv2d` or `Dense` op weight
        shape_list = [-1] + [1] * len(weight.shape[1:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    elif gamma.shape[0] == weight.shape[1]:
        # `DepthwiseConv2d` op weight
        shape_list = [1, -1] + [1] * len(weight.shape[2:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    else:
        raise ValueError("Unsupported weight shape({})".format(weight.shape))

    weight = weight * _gamma / _sigma
    bias = beta - gamma * mean / sigma
    return weight, bias


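# A runnable numpy sketch (editorial addition) of the arithmetic that
# `fold_batchnorm` performs, written out without a MindSpore cell; all array
# values are illustrative.
def _example_fold_arithmetic():
    """Illustrative only: fold BN stats into a 2-channel conv weight."""
    weight = np.ones((2, 1, 3, 3), dtype=np.float32)
    gamma = np.array([1.0, 2.0], dtype=np.float32)
    beta = np.array([0.5, -0.5], dtype=np.float32)
    mean = np.array([0.1, 0.2], dtype=np.float32)
    variance = np.array([1.0, 4.0], dtype=np.float32)
    sigma = np.sqrt(variance + 1e-5)
    # w' = w * gamma / sigma, b' = beta - gamma * mean / sigma
    folded_weight = weight * gamma.reshape(-1, 1, 1, 1) / sigma.reshape(-1, 1, 1, 1)
    folded_bias = beta - gamma * mean / sigma
    return folded_weight, folded_bias

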
def without_fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBnWithoutFoldQuant` into the weight.

    Calculated from `FakeQuantWithMinMax`'s Parameters or the fake quant primitive.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.

    Returns:
        weight (numpy.ndarray): Weight with batchnorm folded in.
        bias (numpy.ndarray): Bias with batchnorm folded in.
    """
    variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
    mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
    gamma = cell_quant.batchnorm.gamma.data.asnumpy()
    beta = cell_quant.batchnorm.beta.data.asnumpy()
    epsilon = cell_quant.batchnorm.eps
    sigma = np.sqrt(variance + epsilon)

    if gamma.shape[0] == weight.shape[0]:
        # `Conv2d` or `Dense` op weight
        shape_list = [-1] + [1] * len(weight.shape[1:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    elif gamma.shape[0] == weight.shape[1]:
        # `DepthwiseConv2d` op weight
        shape_list = [1, -1] + [1] * len(weight.shape[2:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    else:
        raise ValueError("Unsupported weight shape({})".format(weight.shape))

    weight = weight * _gamma / _sigma
    bias = beta - gamma * mean / sigma
    return weight, bias


def compute_kl_threshold(data, bitwidth):
    r"""
    Use the KL-J distance to calculate the clip threshold.

    Args:
        data (numpy.ndarray): Data observed to calculate the threshold for quantization.
        bitwidth (QuantDtype): The datatype of quantization.

    Returns:
        float: The threshold to clip the data.
    """
    data_max = np.abs(data).max()
    if data_max < 1e-5:
        return 1e-5
    hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True)
    # For the sake of efficiency, we limit the maximum number of bins to 1024 in `sqrt` mode.
    # If that size is exceeded, fall back to the default bins config.
    largest_bin_size = 1024
    if hist.shape[0] > largest_bin_size:
        hist, bin_edges = np.histogram(np.abs(data), range=(0, data_max), density=True)
    hist = hist / np.sum(hist)
    cumsum = np.cumsum(hist)
    bit_pow_range = pow(2, int(bitwidth.num_bits) - 1)
    threshold = []
    scaling_factor = []
    kl = []
    if bit_pow_range + 1 > len(bin_edges) - 1:
        th_layer_out = bin_edges[-1]
        return float(th_layer_out)
    for i in range(bit_pow_range + 1, len(bin_edges), 1):
        threshold_tmp = (i + 0.5) * (bin_edges[1] - bin_edges[0])
        threshold = np.concatenate((threshold, [threshold_tmp]))
        scaling_factor_tmp = threshold_tmp / (bit_pow_range - 1)
        scaling_factor = np.concatenate((scaling_factor, [scaling_factor_tmp]))
        # forward interpolation
        cumsum_tmp = np.copy(cumsum)
        cumsum_tmp[(i - 1):] = 1
        fwd_x = np.linspace(0.0, 1.0, bit_pow_range)
        fwd_xp = np.linspace(0.0, 1.0, i)
        fwd_fp = cumsum_tmp[:i]
        forward_interp = np.interp(fwd_x, fwd_xp, fwd_fp)
        # backward interpolation
        bwd_x = np.linspace(0.0, 1.0, i)
        bwd_xp = np.linspace(0.0, 1.0, bit_pow_range)
        bwd_fp = forward_interp
        backward_interp = np.interp(bwd_x, bwd_xp, bwd_fp)
        cumsum_tmp[:i] = backward_interp
        kl_tmp = np.sum((cumsum - cumsum_tmp) * np.log2(cumsum / cumsum_tmp))  # Kullback-Leibler-J
        kl = np.concatenate((kl, [kl_tmp]))
    th_layer_out = threshold[np.argmin(kl)]
    threshold = float(th_layer_out)
    if threshold < 1e-5:
        threshold = 1e-5
    return threshold


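# A runnable sketch (editorial addition) of `compute_kl_threshold` on random
# data; `SimpleNamespace(num_bits=8)` is a hypothetical stand-in for the
# `QuantDtype` object that the real callers pass in.
def _example_compute_kl_threshold():
    """Illustrative only: compute an 8-bit clip threshold for Gaussian data."""
    from types import SimpleNamespace
    rng = np.random.default_rng(0)
    data = rng.standard_normal(100000).astype(np.float32)
    return compute_kl_threshold(data, SimpleNamespace(num_bits=8))

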
def query_quant_layers(network):
    r"""
    Query the quantization strategy of each quantized layer in the network and print it to the screen.
    Note that in graph mode all quantized layers are queried before the graph compilation optimization,
    so some redundant quantized layers that do not exist in the practical execution may appear.

    Args:
        network (Cell): The network to query.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)
    tplt = "{0:60}\t{1:10}"
    for cell_and_name in network.cells_and_names():
        cell_name = cell_and_name[0]
        cell = cell_and_name[1]
        if isinstance(cell, nn.FakeQuantWithMinMaxObserver):
            print(tplt.format(cell_name, cell.quant_dtype))


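# A hypothetical usage sketch (editorial addition); it needs a concrete
# quantization-aware model to run, so it is shown as comments. `LeNet5` and
# the `QuantizationAwareTraining` usage are assumptions:
#
#     from mindspore.compression.quant import QuantizationAwareTraining
#     quantizer = QuantizationAwareTraining()
#     net_qat = quantizer.quantize(LeNet5())
#     query_quant_layers(net_qat)

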
def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
    r"""
    Load fp32 model parameters into a quantization model.

    Args:
        quant_model (Cell): Quantization model.
        params_dict (dict): Parameter dict that stores fp32 parameters.
        quant_new_params (list): Parameters that exist in the quantization network but not in the
            non-quantization network. Default: None.

    Raises:
        TypeError: If `quant_new_params` is not None and is not a list.
        ValueError: If there are parameters in the `quant_model` that are neither in `params_dict`
            nor in `quant_new_params`.
    """
    if quant_new_params is not None and not isinstance(quant_new_params, list):
        raise TypeError("quant_new_params must be list or None.")
    iterable_dict = {
        'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))),
        'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))),
        'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items())))
    }
    for param in params_dict.items():
        key_name = param[0].split(".")[-1]
        if key_name not in iterable_dict:
            iterable_dict[key_name] = iter(list(filter(lambda item, value=key_name: item[0].endswith(value),
                                                       params_dict.items())))

    for name, param in quant_model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict.keys():
            if quant_new_params is None or key_name not in quant_new_params:
                raise ValueError(f"Can't find match parameter in ckpt, param name = {name}")
            continue
        value_param = next(iterable_dict[key_name], None)
        if value_param:
            param.set_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')

    # Perform KL_init when learned scale quantization is executed.
    for cell_and_name in quant_model.cells_and_names():
        cell = cell_and_name[1]
        if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant,
                             nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE":
            subcell_weight_para = cell.weight.data.asnumpy()
            if hasattr(cell, 'gamma'):
                scale_factor = (cell.gamma.data.asnumpy() /
                                np.sqrt(cell.moving_variance.data.asnumpy() + 1e-5))
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)

            if cell.fake_quant_weight.per_channel:
                max_init = [compute_kl_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype)
                            for weight_para_each in subcell_weight_para]
                min_init = [-x for x in max_init]
            else:
                max_init = [compute_kl_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)]
                min_init = [-x for x in max_init]

            cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype,
                                         min_init=min_init, max_init=max_init)
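

# A hypothetical usage sketch (editorial addition); `load_checkpoint` is the
# standard MindSpore checkpoint loader, while the checkpoint path and the
# quantization-aware model `net_qat` are assumptions:
#
#     from mindspore.train.serialization import load_checkpoint
#     params_dict = load_checkpoint("lenet_fp32.ckpt")
#     load_nonquant_param_into_quant_net(net_qat, params_dict)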