1# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 2# 3# Copyright 2020-2021 Huawei Technologies Co., Ltd 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# ============================================================================ 17"""Data type for MindSpore.""" 18from __future__ import absolute_import 19 20import enum 21from inspect import isfunction 22import numpy as np 23from mindspore._c_expression import typing 24from mindspore._c_expression.typing import Type 25from mindspore._c_expression.np_dtypes import np_version_valid 26if np_version_valid(False): 27 from mindspore._c_expression.np_dtypes import bfloat16 as np_bfloat16 28 29__dtype__ = [ 30 "int8", "byte", 31 "int16", "short", 32 "int32", "intc", 33 "int64", "intp", 34 "uint8", "ubyte", 35 "uint16", "ushort", 36 "uint32", "uintc", 37 "uint64", "uintp", 38 "float16", "half", 39 "float32", "single", 40 "float64", "double", 41 "bool_", "float_", 42 "list_", "tuple_", 43 "int_", "uint", 44 "number", "tensor_type", 45 "string", "type_none", 46 "TensorType", "_null", 47 "Type", "Int", 48 "complex64", "complex128", 49 "bfloat16", "qint4x2" 50] 51 52__method__ = [ 53 "dtype_to_nptype", "dtype_to_pytype", 54 "pytype_to_dtype", "get_py_obj_dtype" 55] 56 57__all__ = ["Type", "QuantDtype"] 58__all__.extend(__dtype__) 59__all__.extend(__method__) 60 61# type definition 62bool_ = typing.Bool() 63 64qint4x2 = typing.Int(4) 65int8 = typing.Int(8) 66byte = int8 
# Remaining Int instances; each is followed by an alias name bound to the
# same object.
int16 = typing.Int(16)
short = int16
int32 = typing.Int(32)
intc = int32
int64 = typing.Int(64)
intp = int64

# UInt instances and their aliases.
uint8 = typing.UInt(8)
ubyte = uint8
uint16 = typing.UInt(16)
ushort = uint16
uint32 = typing.UInt(32)
uintc = uint32
uint64 = typing.UInt(64)
uintp = uint64

# Float, BFloat and Complex instances, with aliases for the Float ones.
float16 = typing.Float(16)
half = float16
float32 = typing.Float(32)
single = float32
float64 = typing.Float(64)
double = float64
bfloat16 = typing.BFloat(16)
complex64 = typing.Complex(64)
complex128 = typing.Complex(128)

# Instances created without a size argument, plus container/string/none types.
number = typing.Number()
int_ = typing.Int()
uint = typing.UInt()
float_ = typing.Float()
string = typing.String()
list_ = typing.List()
tuple_ = typing.Tuple()
type_none = typing.TypeNone()
_null = typing.TypeNull()

# Tensor-related type instances.
tensor_type = typing.TensorType()
index_slices = typing.RowTensorType()
coo_tensor = typing.COOTensorType()
csr_tensor = typing.CSRTensorType()
undetermined = typing.UndeterminedType()

# Other singleton type instances.
function = typing.Function()
symbolic_key = typing.SymbolicKeyType()
env_type = typing.EnvType()
type_type = typing.TypeType()
type_refkey = typing.RefKeyType()

# Aliases for the type classes themselves (not instances).
Int = typing.Int
Float = typing.Float
Bool = typing.Bool
String = typing.String
List = typing.List
Tuple = typing.Tuple
Dict = typing.Dict
Slice = typing.Slice
FunctionType = typing.Function
Ellipsis_ = typing.TypeEllipsis
MsClassType = typing.TypeMsClassType
NoneType = typing.TypeNone
EnvType = typing.EnvType
TensorType = typing.TensorType
CSRTensorType = typing.CSRTensorType
AnythingType = typing.TypeAny
RefType = typing.RefType
_NullType = typing.TypeNull

# Groupings of the dtype instances defined above.
number_type = (
    int8, int16, int32, int64,
    uint8, uint16, uint32, uint64,
    float16, float32, float64, bfloat16,
    complex64, complex128,
    qint4x2,
)

int_type = (int8, int16, int32, int64)
uint_type = (uint8, uint16, uint32, uint64)
float_type = (float16, float32, float64, bfloat16,)
signed_type = (int8, byte, int16, short, int32, intc, int64,
               intp, float16, half, float32, single, float64,
               double, bfloat16, complex64, complex128)
complex_type = (complex64, complex128,)
all_types = (bool_, int8, uint8, int16, int32, int64, float16, float32, float64, bfloat16, complex64, complex128)
# Maps each entry of `all_types` to its position in that tuple.
implicit_conversion_seq = {t: idx for idx, t in enumerate(all_types)}

# Python / numpy scalar type -> MindSpore dtype, used by `pytype_to_dtype`.
_simple_types = {
    list: list_,
    tuple: tuple_,
    type(None): type_none,
    bool: bool_,
    int: int64,
    float: float64,
    complex: complex128,
    str: string,
    np.bool_: bool_,
    np.str_: string,
    np.int8: int8,
    np.int16: int16,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
}

# MindSpore dtype -> numpy scalar type, used by `dtype_to_nptype`.
# Built once at import time instead of on every call.
_dtype_to_nptype_map = {
    bool_: np.bool_,
    int8: np.int8,
    int16: np.int16,
    int32: np.int32,
    int64: np.int64,
    uint8: np.uint8,
    uint16: np.uint16,
    uint32: np.uint32,
    uint64: np.uint64,
    float16: np.float16,
    float32: np.float32,
    float64: np.float64,
    complex64: np.complex64,
    complex128: np.complex128,
}
# `np_bfloat16` is only imported at the top of the file under this same
# guard, so the entry is only added when that name exists.
if np_version_valid(False):
    _dtype_to_nptype_map[bfloat16] = np_bfloat16

# MindSpore dtype -> builtin Python type, used by `dtype_to_pytype`.
_dtype_to_pytype_map = {
    bool_: bool,
    int_: int,
    int8: int,
    int16: int,
    int32: int,
    int64: int,
    uint8: int,
    uint16: int,
    uint32: int,
    uint64: int,
    float_: float,
    float16: float,
    float32: float,
    float64: float,
    bfloat16: float,
    list_: list,
    tuple_: tuple,
    string: str,
    complex64: complex,
    complex128: complex,
    type_none: type(None),
}


def pytype_to_dtype(obj):
    """
    Convert python type to MindSpore type.

    A ``numpy.dtype`` instance is first reduced to its scalar type
    (``obj.type``); an object that is already a MindSpore type is returned
    unchanged.

    Args:
        obj (type): A python type object.

    Returns:
        Type of MindSpore type.

    Raises:
        TypeError: If `obj` is neither a MindSpore type nor a python type object.
        NotImplementedError: If the python type cannot be converted to MindSpore type.

    Examples:
        >>> import mindspore as ms
        >>> out = ms.pytype_to_dtype(bool)
        >>> print(out)
        Bool
    """

    if isinstance(obj, np.dtype):
        obj = obj.type
    if isinstance(obj, typing.Type):
        return obj
    if not isinstance(obj, type):
        # The two literals are concatenated implicitly, so the separating
        # space has to be written inside one of them.
        raise TypeError("For 'pytype_to_dtype', the argument 'obj' must be a python type object, "
                        "such as int, float, str, etc. But got type {}.".format(type(obj)))
    if obj in _simple_types:
        return _simple_types[obj]
    raise NotImplementedError(f"The python type {obj} cannot be converted to MindSpore type.")


def get_py_obj_dtype(obj):
    """
    Get the MindSpore data type, which corresponds to python type or variable.

    The checks run in a fixed order: tensor-like objects, Primitive/Cell-like
    objects, plain python functions, MindSpore types, python types, and
    finally the type of any other object.

    Args:
        obj (type): An object of python type, or a variable of python type.

    Returns:
        Type of MindSpore type.

    Raises:
        NotImplementedError: Propagated from `pytype_to_dtype` when the python
            type has no MindSpore equivalent.

    Examples:
        >>> import mindspore as ms
        >>> ms.get_py_obj_dtype(1)
        mindspore.int64
    """
    # Tensor
    if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
        return TensorType(obj.dtype)
    # Primitive or Cell
    if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
        return function
    # python function type
    if isfunction(obj):
        return function
    # mindspore type
    if isinstance(obj, typing.Type):
        return type_type
    # python type
    if isinstance(obj, type):
        return pytype_to_dtype(obj)
    # others
    return pytype_to_dtype(type(obj))


def dtype_to_nptype(type_):
    r"""
    Convert MindSpore dtype to numpy data type.

    Args:
        type\_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        The data type of numpy.

    Raises:
        KeyError: If `type_` has no entry in the dtype-to-numpy table.

    Examples:
        >>> import mindspore as ms
        >>> ms.dtype_to_nptype(ms.int8)
        <class 'numpy.int8'>
    """
    return _dtype_to_nptype_map[type_]


def dtype_to_pytype(type_):
    r"""
    Convert MindSpore dtype to python data type.

    Args:
        type\_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        Type of python.

    Raises:
        KeyError: If `type_` has no entry in the dtype-to-python table.

    Examples:
        >>> import mindspore as ms
        >>> out = ms.dtype_to_pytype(ms.bool_)
        >>> print(out)
        <class 'bool'>
    """

    return _dtype_to_pytype_map[type_]


def _issubclass_(type_, dtype):
    """Return False when `type_` is not a MindSpore type; otherwise defer to `typing.is_subclass`."""
    if not isinstance(type_, typing.Type):
        return False
    return typing.is_subclass(type_, dtype)


def type_size_in_bytes(dtype):
    """
    Return type size in bytes.

    Args:
        dtype (:class:`mindspore.dtype`): MindSpore dtype.

    Returns:
        Type size in bytes.

    Raises:
        TypeError: If `dtype` is not an instance of `typing.Type`.
    """

    if not isinstance(dtype, typing.Type):
        # Build a single message string; passing the class as a second
        # positional argument would make the error text render as a tuple.
        raise TypeError(f"The argument `dtype` should be instance of {typing.Type}, but got {type(dtype)}.")
    return typing.type_size_in_bytes(dtype)


@enum.unique
class QuantDtype(enum.Enum):
    """
    An enum for quant datatype, contains `INT1` ~ `INT16`, `UINT1` ~ `UINT16`.

    `QuantDtype` is defined in
    `dtype.py <https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/common/dtype.py>`_ ,
    use command below to import:

    .. code-block::

        from mindspore import QuantDtype

    Tutorial Examples:
        - `Quantization algorithm in Golden Stick
          <https://www.mindspore.cn/golden_stick/docs/en/master/quantization/slb.html
          #applying-the-quantization-algorithm>`_
    """
    # INTn members are numbered n - 1.
    INT1 = 0
    INT2 = 1
    INT3 = 2
    INT4 = 3
    INT5 = 4
    INT6 = 5
    INT7 = 6
    INT8 = 7
    INT9 = 8
    INT10 = 9
    INT11 = 10
    INT12 = 11
    INT13 = 12
    INT14 = 13
    INT15 = 14
    INT16 = 15

    # UINTn members are numbered 100 + (n - 1).
    UINT1 = 100
    UINT2 = 101
    UINT3 = 102
    UINT4 = 103
    UINT5 = 104
    UINT6 = 105
    UINT7 = 106
    UINT8 = 107
    UINT9 = 108
    UINT10 = 109
    UINT11 = 110
    UINT12 = 111
    UINT13 = 112
    UINT14 = 113
    UINT15 = 114
    UINT16 = 115

    def __str__(self):
        """Return the member name, e.g. ``INT8``."""
        return self.name

    # Defined as a method, so callers write `member.value()` (with
    # parentheses) rather than using the usual Enum `value` attribute.
    def value(self) -> int:
        """
        Return value of `QuantDtype`. This interface is currently used to serialize or deserialize `QuantDtype`
        primarily.

        Returns:
            An int as value of `QuantDtype`.

        Examples:
            >>> from mindspore import QuantDtype
            >>> print(QuantDtype.INT8.value())
            7
            >>> print(QuantDtype.UINT16.value())
            115
        """
        return self._value_