• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Constant module for compression"""
16import enum
17import re
18from types import DynamicClassAttribute
19
20
21__all__ = ["QuantDtype"]
22
23
@enum.unique
class QuantDtype(enum.Enum):
    """
    Enumeration of the supported quantization datatypes.

    Signed members run from `INT2` to `INT8` and unsigned members from `UINT2` to `UINT8`.
    The value of every member is its own name as a string.
    """
    INT2 = "INT2"
    INT3 = "INT3"
    INT4 = "INT4"
    INT5 = "INT5"
    INT6 = "INT6"
    INT7 = "INT7"
    INT8 = "INT8"

    UINT2 = "UINT2"
    UINT3 = "UINT3"
    UINT4 = "UINT4"
    UINT5 = "UINT5"
    UINT6 = "UINT6"
    UINT7 = "UINT7"
    UINT8 = "UINT8"

    def __str__(self):
        # The printable form is just the member name, e.g. "INT8".
        return self.name

    @staticmethod
    def is_signed(dtype):
        """
        Tell whether a quant datatype is one of the signed members.

        Args:
            dtype (QuantDtype): quant datatype to inspect.

        Returns:
            bool, True when `dtype` is one of `INT2` ~ `INT8`, otherwise False.

        Examples:
            >>> quant_dtype = QuantDtype.INT8
            >>> is_signed = QuantDtype.is_signed(quant_dtype)
        """
        # Unsigned member names carry a leading "U"; every other member is signed.
        signed_members = tuple(member for member in QuantDtype if not member.name.startswith("U"))
        return dtype in signed_members

    @staticmethod
    def switch_signed(dtype):
        """
        Return the member with the same bit width but the opposite signedness.

        Args:
            dtype (QuantDtype): quant datatype to flip.

        Returns:
            QuantDtype, the unsigned counterpart of a signed input and vice versa.

        Raises:
            KeyError: If `dtype` is not a member of QuantDtype.

        Examples:
            >>> quant_dtype = QuantDtype.INT8
            >>> quant_dtype = QuantDtype.switch_signed(quant_dtype)
        """
        # Pair each signed member with the unsigned member of the same width ("INT4" -> "UINT4").
        signed_to_unsigned = {
            member: QuantDtype["U" + member.name]
            for member in QuantDtype
            if not member.name.startswith("U")
        }
        # Add the reverse direction so the lookup works for both signed and unsigned inputs.
        flipped = dict(signed_to_unsigned)
        for signed, unsigned in signed_to_unsigned.items():
            flipped[unsigned] = signed
        return flipped[dtype]

    @DynamicClassAttribute
    def _value(self):
        """The integer parsed from the first run of digits in the member's string value."""
        digits = re.search(r"(\d+)", self._value_)
        return int(digits.group(1))

    @DynamicClassAttribute
    def num_bits(self):
        """
        Number of bits of the QuantDtype member.

        Returns:
            int, the bit width encoded in the member's value (for example 8 for `INT8`).

        Examples:
            >>> from mindspore.compression.common import QuantDtype
            >>> quant_dtype = QuantDtype.INT8
            >>> num_bits = quant_dtype.num_bits
            >>> print(num_bits)
            8
        """
        bit_width = self._value
        return bit_width
120