# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Constant module for compression"""
import enum
import re
from types import DynamicClassAttribute


__all__ = ["QuantDtype"]


@enum.unique
class QuantDtype(enum.Enum):
    """
    Enumeration of quantization datatypes.

    Signed members range from `INT2` to `INT8`, unsigned members from `UINT2`
    to `UINT8`. Every member's value is the string spelling of its own name
    (for example, ``QuantDtype.INT8.value == "INT8"``).
    """
    INT2 = "INT2"
    INT3 = "INT3"
    INT4 = "INT4"
    INT5 = "INT5"
    INT6 = "INT6"
    INT7 = "INT7"
    INT8 = "INT8"

    UINT2 = "UINT2"
    UINT3 = "UINT3"
    UINT4 = "UINT4"
    UINT5 = "UINT5"
    UINT6 = "UINT6"
    UINT7 = "UINT7"
    UINT8 = "UINT8"

    def __str__(self):
        # Render as the bare member name (e.g. "INT8") rather than the default
        # "QuantDtype.INT8".
        return self.name

    @staticmethod
    def is_signed(dtype):
        """
        Report whether a quant datatype is one of the signed `INT*` members.

        Args:
            dtype (QuantDtype): quant datatype to inspect.

        Returns:
            bool, True for `INT2` ~ `INT8`; False for the unsigned `UINT*`
            members and for any object that is not a signed member.

        Examples:
            >>> QuantDtype.is_signed(QuantDtype.INT8)
            True
            >>> QuantDtype.is_signed(QuantDtype.UINT8)
            False
        """
        signed_members = (
            QuantDtype.INT2,
            QuantDtype.INT3,
            QuantDtype.INT4,
            QuantDtype.INT5,
            QuantDtype.INT6,
            QuantDtype.INT7,
            QuantDtype.INT8,
        )
        return dtype in signed_members

    @staticmethod
    def switch_signed(dtype):
        """
        Return the member with the same bit width but the opposite signedness.

        Args:
            dtype (QuantDtype): quant datatype to flip.

        Returns:
            QuantDtype, `UINTn` for an input of `INTn`, and `INTn` for `UINTn`.

        Raises:
            KeyError: if `dtype` is not a QuantDtype member.

        Examples:
            >>> QuantDtype.switch_signed(QuantDtype.INT8)
            <QuantDtype.UINT8: 'UINT8'>
        """
        # Each (signed, unsigned) pair shares one bit width; the lookup below is
        # filled in both directions so either side maps to its counterpart.
        pairs = (
            (QuantDtype.INT2, QuantDtype.UINT2),
            (QuantDtype.INT3, QuantDtype.UINT3),
            (QuantDtype.INT4, QuantDtype.UINT4),
            (QuantDtype.INT5, QuantDtype.UINT5),
            (QuantDtype.INT6, QuantDtype.UINT6),
            (QuantDtype.INT7, QuantDtype.UINT7),
            (QuantDtype.INT8, QuantDtype.UINT8),
        )
        lookup = dict(pairs)
        lookup.update((unsigned, signed) for signed, unsigned in pairs)
        return lookup[dtype]

    @DynamicClassAttribute
    def _value(self):
        """Bit width parsed from the digits in the member's string value (e.g. 8 for "INT8")."""
        digits = re.search(r"(\d+)", self._value_)
        return int(digits.group(1))

    @DynamicClassAttribute
    def num_bits(self):
        """
        Get the bit width of the QuantDtype member.

        Returns:
            int, the number of bits encoded in the member (2 ~ 8).

        Examples:
            >>> from mindspore.compression.common import QuantDtype
            >>> quant_dtype = QuantDtype.INT8
            >>> num_bits = quant_dtype.num_bits
            >>> print(num_bits)
            8
        """
        return self._value