• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2#
3# Copyright 2020-2021 Huawei Technologies Co., Ltd
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# ============================================================================
17"""Data type for MindSpore."""
18from __future__ import absolute_import
19
20import enum
21from inspect import isfunction
22import numpy as np
23from mindspore._c_expression import typing
24from mindspore._c_expression.typing import Type
25from mindspore._c_expression.np_dtypes import np_version_valid
26if np_version_valid(False):
27    from mindspore._c_expression.np_dtypes import bfloat16 as np_bfloat16
28
# Names of the dtype instances and type classes exported by this module.
__dtype__ = [
    "int8", "byte",
    "int16", "short",
    "int32", "intc",
    "int64", "intp",
    "uint8", "ubyte",
    "uint16", "ushort",
    "uint32", "uintc",
    "uint64", "uintp",
    "float16", "half",
    "float32", "single",
    "float64", "double",
    "bool_", "float_",
    "list_", "tuple_",
    "int_", "uint",
    "number", "tensor_type",
    "string", "type_none",
    "TensorType", "_null",
    "Type", "Int",
    "complex64", "complex128",
    "bfloat16", "qint4x2"
]

# Names of the public conversion helpers defined in this module.
__method__ = [
    "dtype_to_nptype", "dtype_to_pytype",
    "pytype_to_dtype", "get_py_obj_dtype"
]

# "Type" is listed both in the seed names and in ``__dtype__``. ``dict.fromkeys``
# drops the repeat while keeping first-occurrence order, so every public name is
# exported exactly once and in the same order as before.
__all__ = list(dict.fromkeys(["Type", "QuantDtype", *__dtype__, *__method__]))
60
# Type definitions.
# Every dtype below is created exactly once. The short alias names (``byte``,
# ``half``, ...) are bound to the very same object as the canonical name, so
# for example ``byte is int8`` holds.

bool_ = typing.Bool()

# Signed integers (``qint4x2`` is created as ``typing.Int(4)``).
qint4x2 = typing.Int(4)
int8 = byte = typing.Int(8)
int16 = short = typing.Int(16)
int32 = intc = typing.Int(32)
int64 = intp = typing.Int(64)

# Unsigned integers.
uint8 = ubyte = typing.UInt(8)
uint16 = ushort = typing.UInt(16)
uint32 = uintc = typing.UInt(32)
uint64 = uintp = typing.UInt(64)

# Floating-point and complex types.
float16 = half = typing.Float(16)
float32 = single = typing.Float(32)
float64 = double = typing.Float(64)
bfloat16 = typing.BFloat(16)
complex64 = typing.Complex(64)
complex128 = typing.Complex(128)

# Types constructed without a bit width, and non-numeric types.
number = typing.Number()
int_ = typing.Int()
uint = typing.UInt()
float_ = typing.Float()
string = typing.String()
list_ = typing.List()
tuple_ = typing.Tuple()
type_none = typing.TypeNone()
_null = typing.TypeNull()
102
# Tensor-related and undetermined type instances.
tensor_type = typing.TensorType()
index_slices = typing.RowTensorType()
coo_tensor = typing.COOTensorType()
csr_tensor = typing.CSRTensorType()
undetermined = typing.UndeterminedType()

# Function, symbolic-key, environment, meta-type and ref-key instances.
function = typing.Function()
symbolic_key = typing.SymbolicKeyType()
env_type = typing.EnvType()
type_type = typing.TypeType()
type_refkey = typing.RefKeyType()

# Aliases of the ``typing`` type classes themselves (the classes, not instances).
Int, Float, Bool, String = typing.Int, typing.Float, typing.Bool, typing.String
List, Tuple, Dict, Slice = typing.List, typing.Tuple, typing.Dict, typing.Slice
FunctionType, Ellipsis_ = typing.Function, typing.TypeEllipsis
MsClassType, NoneType = typing.TypeMsClassType, typing.TypeNone
EnvType, TensorType, CSRTensorType = typing.EnvType, typing.TensorType, typing.CSRTensorType
AnythingType, RefType, _NullType = typing.TypeAny, typing.RefType, typing.TypeNull
133
# Groupings of the concrete dtypes defined above.
int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64,)
float_type = (float16, float32, float64, bfloat16,)
complex_type = (complex64, complex128,)

# All concrete numeric dtypes: the groups above in order, followed by qint4x2.
number_type = int_type + uint_type + float_type + complex_type + (qint4x2,)

# Signed dtypes. The alias names (byte, short, ...) refer to the same objects as
# their canonical names, so several dtypes appear twice in this tuple.
signed_type = (int8, byte, int16, short, int32, intc, int64,
               intp, float16, half, float32, single, float64,
               double, bfloat16, complex64, complex128)

# Ordered dtype sequence; ``implicit_conversion_seq`` maps each dtype to its
# position in ``all_types``.
all_types = (bool_, int8, uint8, int16, int32, int64, float16, float32, float64, bfloat16, complex64, complex128)
implicit_conversion_seq = dict(zip(all_types, range(len(all_types))))
159
# Python and NumPy scalar types that ``pytype_to_dtype`` converts directly.
# The NumPy complex (and, when available, bfloat16) entries mirror the mapping
# used by ``dtype_to_nptype``, so the two conversions agree in both directions.
_simple_types = {
    list: list_,
    tuple: tuple_,
    type(None): type_none,
    bool: bool_,
    int: int64,
    float: float64,
    complex: complex128,
    str: string,
    np.bool_: bool_,
    np.str_: string,
    np.int8: int8,
    np.int16: int16,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    np.complex64: complex64,
    np.complex128: complex128,
}
# ``np_bfloat16`` is only imported by this module when ``np_version_valid(False)``
# is true, so its entry is added under the same condition.
if np_version_valid(False):
    _simple_types[np_bfloat16] = bfloat16
183
184
def pytype_to_dtype(obj):
    """
    Convert python type to MindSpore type.

    Args:
        obj (type): A python type object. A ``numpy.dtype`` is also accepted and is converted
            through its scalar type (``obj.type``); a MindSpore type is returned unchanged.

    Returns:
        Type of MindSpore type.

    Raises:
        TypeError: If `obj` is not a python type object.
        NotImplementedError: If the python type cannot be converted to MindSpore type.

    Examples:
        >>> import mindspore as ms
        >>> out = ms.pytype_to_dtype(bool)
        >>> print(out)
        Bool
    """

    if isinstance(obj, np.dtype):
        # Normalize e.g. np.dtype('float32') to its scalar type np.float32.
        obj = obj.type
    if isinstance(obj, typing.Type):
        # Already a MindSpore type: nothing to convert.
        return obj
    if not isinstance(obj, type):
        raise TypeError("For 'pytype_to_dtype', the argument 'obj' must be a python type object, "
                        "such as int, float, str, etc. But got type {}.".format(type(obj)))
    if obj in _simple_types:
        return _simple_types[obj]
    raise NotImplementedError(f"The python type {obj} cannot be converted to MindSpore type.")
215
216
def get_py_obj_dtype(obj):
    """
    Get the MindSpore data type, which corresponds to python type or variable.

    Args:
        obj (Any): Either a python type object (such as ``int``), or any python object/variable
            whose MindSpore type should be inferred.

    Returns:
        Type of MindSpore type.

    Raises:
        NotImplementedError: If the python type (of `obj`, or `obj` itself when it is a type)
            cannot be converted by ``pytype_to_dtype``.

    Examples:
        >>> import mindspore as ms
        >>> ms.get_py_obj_dtype(1)
        mindspore.int64
    """
    # Tensor-like objects: they expose a ``shape`` and a MindSpore ``dtype``.
    if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
        return TensorType(obj.dtype)

    # Primitive or Cell (recognized by their marker attributes), or a plain python function.
    if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct') or isfunction(obj):
        return function

    # A MindSpore type object is itself a value of the "type" type.
    if isinstance(obj, typing.Type):
        return type_type

    # Python types convert directly; any other value is converted through its type.
    py_type = obj if isinstance(obj, type) else type(obj)
    return pytype_to_dtype(py_type)
249
250
# MindSpore dtype -> numpy scalar type table used by ``dtype_to_nptype``. It is
# built once at import time instead of on every call.
_DTYPE_TO_NPTYPE = {
    bool_: np.bool_,
    int8: np.int8,
    int16: np.int16,
    int32: np.int32,
    int64: np.int64,
    uint8: np.uint8,
    uint16: np.uint16,
    uint32: np.uint32,
    uint64: np.uint64,
    float16: np.float16,
    float32: np.float32,
    float64: np.float64,
    complex64: np.complex64,
    complex128: np.complex128,
}
# ``np_bfloat16`` is only imported by this module when ``np_version_valid(False)``
# is true, so the bfloat16 entry is added under the same condition.
if np_version_valid(False):
    _DTYPE_TO_NPTYPE[bfloat16] = np_bfloat16


def dtype_to_nptype(type_):
    r"""
    Convert MindSpore dtype to numpy data type.

    Args:
        type\_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        The data type of numpy.

    Raises:
        KeyError: If `type_` has no numpy counterpart (for example ``bfloat16`` when
            ``np_version_valid(False)`` is false).

    Examples:
        >>> import mindspore as ms
        >>> ms.dtype_to_nptype(ms.int8)
        <class 'numpy.int8'>
    """
    try:
        return _DTYPE_TO_NPTYPE[type_]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what went wrong.
        raise KeyError(f"For 'dtype_to_nptype', the MindSpore dtype {type_} has no corresponding "
                       f"numpy data type.") from None
285
286
# MindSpore dtype -> python builtin type table used by ``dtype_to_pytype``. It is
# built once at import time instead of on every call. The width-less ``int_``,
# ``uint`` and ``float_`` map to the same builtins as their sized counterparts.
_DTYPE_TO_PYTYPE = {
    bool_: bool,
    int_: int,
    int8: int,
    int16: int,
    int32: int,
    int64: int,
    uint: int,
    uint8: int,
    uint16: int,
    uint32: int,
    uint64: int,
    float_: float,
    float16: float,
    float32: float,
    float64: float,
    bfloat16: float,
    list_: list,
    tuple_: tuple,
    string: str,
    complex64: complex,
    complex128: complex,
    type_none: type(None)
}


def dtype_to_pytype(type_):
    r"""
    Convert MindSpore dtype to python data type.

    Args:
        type\_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        Type of python.

    Raises:
        KeyError: If `type_` has no corresponding python type.

    Examples:
        >>> import mindspore as ms
        >>> out = ms.dtype_to_pytype(ms.bool_)
        >>> print(out)
        <class 'bool'>
    """
    try:
        return _DTYPE_TO_PYTYPE[type_]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what went wrong.
        raise KeyError(f"For 'dtype_to_pytype', the MindSpore dtype {type_} has no corresponding "
                       f"python type.") from None
327
328
def _issubclass_(type_, dtype):
    """Return whether `type_` is a MindSpore type that ``typing.is_subclass`` places under `dtype`.

    Inputs that are not MindSpore types yield ``False`` without consulting ``typing.is_subclass``.
    """
    return isinstance(type_, typing.Type) and typing.is_subclass(type_, dtype)
333
334
335
def type_size_in_bytes(dtype):
    """
    Return type size in bytes.

    Args:
        dtype (:class:`mindspore.dtype`): MindSpore dtype.

    Returns:
        Type size in bytes.

    Raises:
        TypeError: If `dtype` is not an instance of the MindSpore ``Type`` class.
    """

    if not isinstance(dtype, typing.Type):
        # Format the message; passing the class as a second argument to TypeError
        # would render the message as a tuple.
        raise TypeError(f"The argument `dtype` should be instance of {typing.Type}, but got {type(dtype)}.")
    return typing.type_size_in_bytes(dtype)
350
351
@enum.unique
class QuantDtype(enum.Enum):
    """
    An enum for quant datatype, contains `INT1` ~ `INT16`, `UINT1` ~ `UINT16`.

    `QuantDtype` is defined in
    `dtype.py <https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/common/dtype.py>`_ ,
    use command below to import:

    .. code-block::

        from mindspore import QuantDtype

    Tutorial Examples:
        - `Quantization algorithm in Golden Stick
          <https://www.mindspore.cn/golden_stick/docs/en/master/quantization/slb.html
          #applying-the-quantization-algorithm>`_
    """
    # Signed members: INTn has the value n - 1.
    INT1 = 0
    INT2 = 1
    INT3 = 2
    INT4 = 3
    INT5 = 4
    INT6 = 5
    INT7 = 6
    INT8 = 7
    INT9 = 8
    INT10 = 9
    INT11 = 10
    INT12 = 11
    INT13 = 12
    INT14 = 13
    INT15 = 14
    INT16 = 15

    # Unsigned members: UINTn has the value 99 + n.
    UINT1 = 100
    UINT2 = 101
    UINT3 = 102
    UINT4 = 103
    UINT5 = 104
    UINT6 = 105
    UINT7 = 106
    UINT8 = 107
    UINT9 = 108
    UINT10 = 109
    UINT11 = 110
    UINT12 = 111
    UINT13 = 112
    UINT14 = 113
    UINT15 = 114
    UINT16 = 115

    def __str__(self):
        # Print only the member name (e.g. "INT8"), not "QuantDtype.INT8".
        return self.name

    # NOTE: being a method, this shadows the standard ``Enum.value`` property, so
    # ``QuantDtype.INT8.value`` is a bound method and the raw value is ``value()``.
    def value(self) -> int:
        """
        Return value of `QuantDtype`. This interface is currently used to serialize or deserialize `QuantDtype`
        primarily.

        Returns:
            An int as value of `QuantDtype`.

        Examples:
            >>> from mindspore import QuantDtype
            >>> print(QuantDtype.INT8.value())
            7
            >>> print(QuantDtype.UINT16.value())
            115
        """
        return self._value_