# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data type for MindSpore."""

import numpy as np
from .._c_expression import typing, EnvInstance_
from .._c_expression.typing import Type

# Names of the dtype objects exported through ``__all__``.
__dtype__ = [
    "int8", "byte",
    "int16", "short",
    "int32", "intc",
    "int64", "intp",
    "uint8", "ubyte",
    "uint16", "ushort",
    "uint32", "uintc",
    "uint64", "uintp",
    "float16", "half",
    "float32", "single",
    "float64", "double",
    "bool_", "float_",
    "list_", "tuple_",
    "int_", "uint",
    "number", "tensor",
    "string", "type_none",
    "tensor_type",
    "Type", "Int",
    "complex64", "complex128"
]

# Names of the helper functions exported through ``__all__``.
__method__ = [
    "dtype_to_nptype", "issubclass_", "dtype_to_pytype",
    "pytype_to_dtype", "get_py_obj_dtype"
]

__all__ = ["Type"]
__all__.extend(__dtype__)
__all__.extend(__method__)

# type definition
bool_ = typing.Bool()

# Signed integers; the second name of each pair is an alias of the same object.
int8 = typing.Int(8)
byte = int8
int16 = typing.Int(16)
short = int16
int32 = typing.Int(32)
intc = int32
int64 = typing.Int(64)
intp = int64

# Unsigned integers and their aliases.
uint8 = typing.UInt(8)
ubyte = uint8
uint16 = typing.UInt(16)
ushort = uint16
uint32 = typing.UInt(32)
uintc = uint32
uint64 = typing.UInt(64)
uintp = uint64

# Floating point and complex types, with aliases for the floats.
float16 = typing.Float(16)
half = float16
float32 = typing.Float(32)
single = float32
float64 = typing.Float(64)
double = float64
complex64 = typing.Complex(64)
complex128 = typing.Complex(128)

# Type instances created without a bit width or element type argument.
number = typing.Number()
int_ = typing.Int()
uint = typing.UInt()
float_ = typing.Float()
string = typing.String()
list_ = typing.List()
tuple_ = typing.Tuple()
type_none = typing.TypeNone()

tensor = typing.TensorType()
index_slices = typing.RowTensorType()
sparse_tensor = typing.SparseTensorType()
undetermined = typing.UndeterminedType()

function = typing.Function()
symbolic_key = typing.SymbolicKeyType()
env_type = typing.EnvType()
type_type = typing.TypeType()
type_refkey = typing.RefKeyType()

# Aliases of the type classes themselves (not instances).
Int = typing.Int
Float = typing.Float
Bool = typing.Bool
String = typing.String
List = typing.List
Tuple = typing.Tuple
Dict = typing.Dict
Slice = typing.Slice
function_type = typing.Function
Ellipsis_ = typing.TypeEllipsis
none_type = typing.TypeNone
env_type_type = typing.EnvType
tensor_type = typing.TensorType
anything_type = typing.TypeAnything
ref_type = typing.RefType

# Groupings of the concrete numeric dtypes defined above.
number_type = (int8,
               int16,
               int32,
               int64,
               uint8,
               uint16,
               uint32,
               uint64,
               float16,
               float32,
               float64,
               complex64,
               complex128,)

int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64,)
float_type = (float16, float32, float64,)

# Maps each listed dtype to its position in the tuple below.
# NOTE(review): presumably the position ranks the type for implicit
# conversion -- confirm against the callers of this table.
implicit_conversion_seq = {t: idx for idx, t in enumerate((
    bool_, int8, uint8, int16, int32, int64, float16, float32, float64, complex64, complex128))}

# Python / NumPy type object -> MindSpore dtype, consulted by `pytype_to_dtype`.
# `np.str_` is used rather than the `np.str` alias: NumPy deprecated `np.str`
# in 1.20 and removed it in 1.24, and it was only the builtin `str`, which is
# already a key of this table.
_simple_types = {
    list: list_,
    tuple: tuple_,
    type(None): type_none,
    bool: bool_,
    int: int64,
    float: float64,
    complex: complex128,
    str: string,
    np.bool_: bool_,
    np.str_: string,
    np.int8: int8,
    np.int16: int16,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    EnvInstance_: env_type,
}

# MindSpore dtype -> NumPy scalar type. Built once at import time so that
# `dtype_to_nptype` does not rebuild the table on every call.
_DTYPE_TO_NPTYPE = {
    bool_: np.bool_,
    int8: np.int8,
    int16: np.int16,
    int32: np.int32,
    int64: np.int64,
    uint8: np.uint8,
    uint16: np.uint16,
    uint32: np.uint32,
    uint64: np.uint64,
    float16: np.float16,
    float32: np.float32,
    float64: np.float64,
    complex64: np.complex64,
    complex128: np.complex128,
}

# MindSpore dtype -> builtin Python type. Built once at import time so that
# `dtype_to_pytype` does not rebuild the table on every call.
_DTYPE_TO_PYTYPE = {
    bool_: bool,
    int_: int,
    int8: int,
    int16: int,
    int32: int,
    int64: int,
    uint8: int,
    uint16: int,
    uint32: int,
    uint64: int,
    float_: float,
    float16: float,
    float32: float,
    float64: float,
    list_: list,
    tuple_: tuple,
    string: str,
    complex64: complex,
    complex128: complex,
    type_none: type(None)
}


def pytype_to_dtype(obj):
    """
    Convert python type to MindSpore type.

    Args:
        obj (type): A python type object. A `numpy.dtype` instance is reduced
            to its scalar type first, and a MindSpore type is returned as is.

    Returns:
        Type of MindSpore type.

    Raises:
        NotImplementedError: If `obj` is neither a MindSpore type nor a type
            listed in the internal python-to-MindSpore table.
    """

    if isinstance(obj, np.dtype):
        # e.g. np.dtype('float32') -> np.float32, which is a key of the table.
        obj = obj.type
    if isinstance(obj, typing.Type):
        return obj
    if isinstance(obj, type) and obj in _simple_types:
        return _simple_types[obj]
    raise NotImplementedError(f"The python type {obj} cannot be converted to MindSpore type.")


def get_py_obj_dtype(obj):
    """
    Get the MindSpore data type, which corresponds to python type or variable.

    Args:
        obj (type): An object of python type, or a variable of python type.

    Returns:
        Type of MindSpore type.

    Raises:
        NotImplementedError: If the type of `obj` cannot be converted by
            `pytype_to_dtype`.
    """
    # Tensor: recognised by having `shape` and a `dtype` that is a MindSpore type.
    if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
        return tensor_type(obj.dtype)
    # Primitive or Cell
    if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
        return function
    # mindspore type
    if isinstance(obj, typing.Type):
        return type_type
    # python type
    if isinstance(obj, type):
        return pytype_to_dtype(obj)
    # others: fall back to the type of the value itself.
    return pytype_to_dtype(type(obj))


def dtype_to_nptype(type_):
    """
    Convert MindSpore dtype to numpy data type.

    Args:
        type_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        The data type of numpy.

    Raises:
        KeyError: If `type_` has no entry in the dtype-to-numpy table.
    """

    return _DTYPE_TO_NPTYPE[type_]


def dtype_to_pytype(type_):
    """
    Convert MindSpore dtype to python data type.

    Args:
        type_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        Type of python.

    Raises:
        KeyError: If `type_` has no entry in the dtype-to-python table.
    """

    return _DTYPE_TO_PYTYPE[type_]


def issubclass_(type_, dtype):
    """
    Determine whether `type_` is a subclass of `dtype`.

    Args:
        type_ (:class:`mindspore.dtype`): Target MindSpore dtype.
        dtype (:class:`mindspore.dtype`): Compare MindSpore dtype.

    Returns:
        bool, True or False. Always False when `type_` is not a MindSpore type.
    """
    if not isinstance(type_, typing.Type):
        return False
    return typing.is_subclass(type_, dtype)