• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2#
3# Copyright 2020-2021 Huawei Technologies Co., Ltd
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# ============================================================================
"""Data type for MindSpore.

Defines the MindSpore dtype objects (e.g. ``int8``, ``float32``, ``complex128``) together with their
short aliases, aliases for the underlying type classes, groupings of the numeric dtypes, and helpers
that convert between MindSpore dtypes, numpy types and python types.
"""
18
19import numpy as np
20from .._c_expression import typing, EnvInstance_
21from .._c_expression.typing import Type
22
# Public dtype objects, aliases and type classes re-exported by this module (order is significant
# only for readability; it matches the historical listing).
__dtype__ = [
    # signed integers and their aliases
    "int8", "byte", "int16", "short", "int32", "intc", "int64", "intp",
    # unsigned integers and their aliases
    "uint8", "ubyte", "uint16", "ushort", "uint32", "uintc", "uint64", "uintp",
    # floating point types and their aliases
    "float16", "half", "float32", "single", "float64", "double",
    # other dtype instances
    "bool_", "float_", "list_", "tuple_", "int_", "uint",
    "number", "tensor", "string", "type_none",
    # type classes
    "tensor_type", "Type", "Int",
    # complex types
    "complex64", "complex128",
]

# Public conversion / query helpers re-exported by this module.
__method__ = [
    "dtype_to_nptype", "issubclass_", "dtype_to_pytype",
    "pytype_to_dtype", "get_py_obj_dtype",
]

# "Type" first, followed by every dtype name and every helper name.
__all__ = ["Type", *__dtype__, *__method__]
53
# Type definitions.
# Each short alias (byte, short, half, ...) is bound to the very same object as its canonical name.
bool_ = typing.Bool()

int8 = byte = typing.Int(8)
int16 = short = typing.Int(16)
int32 = intc = typing.Int(32)
int64 = intp = typing.Int(64)

uint8 = ubyte = typing.UInt(8)
uint16 = ushort = typing.UInt(16)
uint32 = uintc = typing.UInt(32)
uint64 = uintp = typing.UInt(64)

float16 = half = typing.Float(16)
float32 = single = typing.Float(32)
float64 = double = typing.Float(64)
complex64 = typing.Complex(64)
complex128 = typing.Complex(128)

# Dtype instances constructed without a bit-width argument, plus container / none types.
number = typing.Number()
int_ = typing.Int()
uint = typing.UInt()
float_ = typing.Float()
string = typing.String()
list_ = typing.List()
tuple_ = typing.Tuple()
type_none = typing.TypeNone()

tensor = typing.TensorType()
index_slices = typing.RowTensorType()
sparse_tensor = typing.SparseTensorType()
undetermined = typing.UndeterminedType()

function = typing.Function()
symbolic_key = typing.SymbolicKeyType()
env_type = typing.EnvType()
type_type = typing.TypeType()
type_refkey = typing.RefKeyType()

# Aliases for the type classes themselves (callable to build new instances), not instances.
Int = typing.Int
Float = typing.Float
Bool = typing.Bool
String = typing.String
List = typing.List
Tuple = typing.Tuple
Dict = typing.Dict
Slice = typing.Slice
function_type = typing.Function
Ellipsis_ = typing.TypeEllipsis
none_type = typing.TypeNone
env_type_type = typing.EnvType
tensor_type = typing.TensorType
anything_type = typing.TypeAnything
ref_type = typing.RefType
119
# Concrete numeric dtypes grouped by kind.
int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64,)
float_type = (float16, float32, float64,)

# Every concrete numeric dtype: signed ints, unsigned ints, floats, then complex types.
number_type = int_type + uint_type + float_type + (complex64, complex128)

# Maps a dtype to its position in this sequence; a larger index means a later position.
implicit_conversion_seq = {
    dtype: position
    for position, dtype in enumerate((bool_, int8, uint8, int16, int32, int64,
                                      float16, float32, float64, complex64, complex128))
}
140
# Lookup table used by `pytype_to_dtype`: python / numpy scalar types -> MindSpore dtypes.
_simple_types = {
    list: list_,
    tuple: tuple_,
    type(None): type_none,
    bool: bool_,
    int: int64,
    float: float64,
    complex: complex128,
    str: string,
    np.bool_: bool_,
    # `np.str` was merely an alias of the builtin `str` (already mapped above). It was deprecated in
    # NumPy 1.20 and removed in 1.24, where referencing it raises AttributeError at import time.
    # `np.str_` is numpy's own string scalar type (e.g. `np.dtype('U5').type`).
    np.str_: string,
    np.int8: int8,
    np.int16: int16,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    # Inverse of the complex entries in `dtype_to_nptype`, so numpy complex scalars/dtypes convert too.
    np.complex64: complex64,
    np.complex128: complex128,
    EnvInstance_: env_type,
}
159    np.uint64: uint64,
160    np.float16: float16,
161    np.float32: float32,
162    np.float64: float64,
163    EnvInstance_: env_type,
164}
165
166
def pytype_to_dtype(obj):
    """
    Convert python type to MindSpore type.

    Args:
        obj (type): A python type object. A `numpy.dtype` is also accepted (its scalar `.type`
            is looked up), and a MindSpore `Type` is returned unchanged.

    Returns:
        Type of MindSpore type.

    Raises:
        NotImplementedError: If `obj` is not a type listed in the module's lookup table.
    """
    # A numpy dtype is resolved through its scalar type, e.g. np.dtype('float32') -> np.float32.
    py_type = obj.type if isinstance(obj, np.dtype) else obj

    if isinstance(py_type, typing.Type):
        return py_type

    if not isinstance(py_type, type) or py_type not in _simple_types:
        raise NotImplementedError(f"The python type {py_type} cannot be converted to MindSpore type.")
    return _simple_types[py_type]
185
186
def get_py_obj_dtype(obj):
    """
    Get the MindSpore data type, which corresponds to python type or variable.

    The checks are applied in this order:

    1. An object exposing `shape` and a `dtype` that is a MindSpore `Type` (e.g. a Tensor) yields
       `tensor_type(obj.dtype)`.
    2. An object carrying `__primitive_flag__` or `construct` (Primitive or Cell) yields `function`.
    3. A MindSpore `Type` instance yields `type_type`.
    4. Anything else is converted with `pytype_to_dtype`: a class is passed as-is, and any other
       value is passed as `type(obj)`.

    Args:
        obj (type): An object of python type, or a variable of python type.

    Returns:
        Type of MindSpore type.

    Raises:
        NotImplementedError: Propagated from `pytype_to_dtype` when no MindSpore type matches.
    """
    looks_like_tensor = (hasattr(obj, 'shape') and hasattr(obj, 'dtype')
                         and isinstance(obj.dtype, typing.Type))
    if looks_like_tensor:
        return tensor_type(obj.dtype)

    is_primitive_or_cell = hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct')
    if is_primitive_or_cell:
        return function

    if isinstance(obj, typing.Type):
        return type_type

    # Classes are converted directly; instances are converted through their class.
    return pytype_to_dtype(obj if isinstance(obj, type) else type(obj))
211
212
def dtype_to_nptype(type_):
    """
    Convert MindSpore dtype to numpy data type.

    Only the concrete bool, integer, float and complex dtypes are supported.

    Args:
        type_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        The data type of numpy.

    Raises:
        KeyError: If `type_` is not one of the supported dtypes.
    """
    numpy_type_of = {
        bool_: np.bool_,
        # signed integers
        int8: np.int8,
        int16: np.int16,
        int32: np.int32,
        int64: np.int64,
        # unsigned integers
        uint8: np.uint8,
        uint16: np.uint16,
        uint32: np.uint32,
        uint64: np.uint64,
        # floating point
        float16: np.float16,
        float32: np.float32,
        float64: np.float64,
        # complex
        complex64: np.complex64,
        complex128: np.complex128,
    }
    return numpy_type_of[type_]
240
241
def dtype_to_pytype(type_):
    """
    Convert MindSpore dtype to python data type.

    Every integer dtype (including `int_`) maps to `int`, every float dtype (including `float_`)
    maps to `float`, and both complex dtypes map to `complex`.

    Args:
        type_ (:class:`mindspore.dtype`): MindSpore's dtype.

    Returns:
        Type of python.

    Raises:
        KeyError: If `type_` has no python counterpart in the table.
    """
    python_type_of = {
        bool_: bool,
        list_: list,
        tuple_: tuple,
        string: str,
        type_none: type(None),
    }
    python_type_of.update(dict.fromkeys(
        (int_, int8, int16, int32, int64, uint8, uint16, uint32, uint64), int))
    python_type_of.update(dict.fromkeys((float_, float16, float32, float64), float))
    python_type_of.update(dict.fromkeys((complex64, complex128), complex))
    return python_type_of[type_]
275
276
def issubclass_(type_, dtype):
    """
    Determine whether `type_` is a subclass of `dtype`.

    Anything that is not a MindSpore `Type` is never considered a subclass.

    Args:
        type_ (:class:`mindspore.dtype`): Target MindSpore dtype.
        dtype (:class:`mindspore.dtype`): Compare MindSpore dtype.

    Returns:
        bool, True or False.
    """
    # Short-circuit: typing.is_subclass is only consulted for genuine MindSpore types.
    return isinstance(type_, typing.Type) and typing.is_subclass(type_, dtype)
291