• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
17 #include "include/common/np_dtype/np_dtypes.h"
18 #include <algorithm>
19 #include <string>
20 #include "numpy/arrayobject.h"
21 #include "numpy/ufuncobject.h"
22 #include "base/float16.h"
23 #include "base/bfloat16.h"
24 #include "utils/log_adapter.h"
25 
26 #if NPY_API_VERSION < 0x0000000d
27 #error Current Numpy version is too low, the required version is not less than 1.19.3.
28 #endif
29 
30 #if NPY_ABI_VERSION < 0x02000000
31 #define PyArray_DescrProto PyArray_Descr
32 #endif
33 
34 namespace mindspore {
35 namespace np_dtypes {
36 // A safe PyObject pointer which can decrement the references automatically when destructing.
37 struct PyObjDeleter {
operator ()mindspore::np_dtypes::PyObjDeleter38   void operator()(PyObject *object) const { Py_DECREF(object); }
39 };
40 using PyObjectPtr = std::unique_ptr<PyObject, PyObjDeleter>;
SafePtr(PyObject * object)41 PyObjectPtr SafePtr(PyObject *object) { return PyObjectPtr(object); }
42 
43 // Representation of a custom Python type.
44 template <typename T>
45 struct PyType {
46   PyObject_HEAD;
47   T value;
48 };
49 
// Description of a numpy type.
// Per-type state for a custom numpy dtype (NpTypeDescr<bfloat16> derives from this). The statics start out empty
// (NPY_NOTYPE / nullptr) and are presumably filled in when the type is registered with numpy — the registration
// code is not visible here; confirm.
template <typename T>
struct NpTypeBaseDescr {
  // numpy type number assigned to T (NPY_NOTYPE until set).
  static int Dtype() { return np_type_num; }
  // Python type object whose instances wrap a T (nullptr until set).
  static PyTypeObject *TypePtr() { return np_type_ptr; }
  static int np_type_num;
  static PyTypeObject *np_type_ptr;
  static PyArray_Descr np_descr;
  static PyArray_ArrFuncs arr_funcs;
  // Python number protocol table; its definition appears later in the file, after the PyType_* slot functions.
  static PyNumberMethods number_methods;
};

template <typename T>
int NpTypeBaseDescr<T>::np_type_num = NPY_NOTYPE;
template <typename T>
PyTypeObject *NpTypeBaseDescr<T>::np_type_ptr = nullptr;
template <typename T>
PyArray_Descr NpTypeBaseDescr<T>::np_descr;
template <typename T>
PyArray_ArrFuncs NpTypeBaseDescr<T>::arr_funcs;
70 
// Maps a C++ element type to its numpy type number. For builtin C++ types only np_type_num exists; it is given a
// value by explicit specializations of the static member.
template <typename T>
struct NpTypeDescr {
  static int Dtype() { return np_type_num; }
  static int np_type_num;
};

// bfloat16 is a custom dtype: it reuses the per-type state of NpTypeBaseDescr and adds metadata. type_name is used in
// PyType_New's error messages; the remaining fields are presumably consumed when the dtype is registered — confirm.
template <>
struct NpTypeDescr<bfloat16> : NpTypeBaseDescr<bfloat16> {
  static constexpr const char *type_name = "bfloat16";
  static constexpr const char *type_doc = "BFloat16 type for numpy";
  static constexpr char kind = 'T';
  static constexpr char type = 'T';
  static constexpr char byte_order = '=';
};
85 
86 template <>
87 int NpTypeDescr<unsigned char>::np_type_num = NPY_UBYTE;
88 template <>
89 int NpTypeDescr<unsigned short>::np_type_num = NPY_USHORT;
90 template <>
91 int NpTypeDescr<unsigned int>::np_type_num = NPY_UINT;
92 template <>
93 int NpTypeDescr<unsigned long>::np_type_num = NPY_ULONG;
94 template <>
95 int NpTypeDescr<unsigned long long>::np_type_num = NPY_ULONGLONG;
96 template <>
97 int NpTypeDescr<char>::np_type_num = NPY_BYTE;
98 template <>
99 int NpTypeDescr<short>::np_type_num = NPY_SHORT;
100 template <>
101 int NpTypeDescr<int>::np_type_num = NPY_INT;
102 template <>
103 int NpTypeDescr<long>::np_type_num = NPY_LONG;
104 template <>
105 int NpTypeDescr<long long>::np_type_num = NPY_LONGLONG;
106 template <>
107 int NpTypeDescr<bool>::np_type_num = NPY_BOOL;
108 template <>
109 int NpTypeDescr<float16>::np_type_num = NPY_HALF;
110 template <>
111 int NpTypeDescr<float>::np_type_num = NPY_FLOAT;
112 template <>
113 int NpTypeDescr<double>::np_type_num = NPY_ULONG;
114 template <>
115 int NpTypeDescr<long double>::np_type_num = NPY_LONGDOUBLE;
116 
117 // Check if object is specific numpy custom type.
118 template <typename T>
PyType_CheckType(PyObject * object)119 bool PyType_CheckType(PyObject *object) {
120   return PyObject_IsInstance(object, reinterpret_cast<PyObject *>(NpTypeDescr<T>::TypePtr()));
121 }
122 
123 // Get value in the Python type object.
124 template <typename T>
PyType_GetValue(PyObject * object)125 T PyType_GetValue(PyObject *object) {
126   return reinterpret_cast<PyType<T> *>(object)->value;
127 }
128 
129 // Create PyTypeObject<T> data from T value.
130 template <typename T>
PyTypeFromValue(T value)131 PyObjectPtr PyTypeFromValue(T value) {
132   PyTypeObject *np_type_p = NpTypeDescr<T>::TypePtr();
133   PyObjectPtr npy_data_p = SafePtr(np_type_p->tp_alloc(np_type_p, 0));
134   PyType<T> *data_p = reinterpret_cast<PyType<T> *>(npy_data_p.get());
135   if (data_p) {
136     data_p->value = value;
137   }
138   return npy_data_p;
139 }
140 
141 template <typename T>
PyType_Add(PyObject * a,PyObject * b)142 PyObject *PyType_Add(PyObject *a, PyObject *b) {
143   if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
144     return PyTypeFromValue<T>(PyType_GetValue<T>(a) + PyType_GetValue<T>(b)).release();
145   }
146   return PyArray_Type.tp_as_number->nb_add(a, b);
147 }
148 
149 template <typename T>
PyType_Subtract(PyObject * a,PyObject * b)150 PyObject *PyType_Subtract(PyObject *a, PyObject *b) {
151   if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
152     return PyTypeFromValue<T>(PyType_GetValue<T>(a) - PyType_GetValue<T>(b)).release();
153   }
154   return PyArray_Type.tp_as_number->nb_subtract(a, b);
155 }
156 
157 template <typename T>
PyType_Multiply(PyObject * a,PyObject * b)158 PyObject *PyType_Multiply(PyObject *a, PyObject *b) {
159   if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
160     return PyTypeFromValue<T>(PyType_GetValue<T>(a) * PyType_GetValue<T>(b)).release();
161   }
162   return PyArray_Type.tp_as_number->nb_multiply(a, b);
163 }
164 
165 template <typename T>
PyType_Divide(PyObject * a,PyObject * b)166 PyObject *PyType_Divide(PyObject *a, PyObject *b) {
167   if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
168     return PyTypeFromValue<T>(PyType_GetValue<T>(a) / PyType_GetValue<T>(b)).release();
169   }
170   return PyArray_Type.tp_as_number->nb_true_divide(a, b);
171 }
172 
173 template <typename T>
PyType_Negative(PyObject * self)174 PyObject *PyType_Negative(PyObject *self) {
175   return PyTypeFromValue<T>(-PyType_GetValue<T>(self)).release();
176 }
177 
178 template <typename T>
PyType_Int(PyObject * self)179 PyObject *PyType_Int(PyObject *self) {
180   T value = PyType_GetValue<T>(self);
181   return PyLong_FromLong(static_cast<long>(static_cast<float>(value)));
182 }
183 
184 template <typename T>
PyType_Float(PyObject * self)185 PyObject *PyType_Float(PyObject *self) {
186   T value = PyType_GetValue<T>(self);
187   return PyFloat_FromDouble(static_cast<double>(static_cast<float>(value)));
188 }
189 
// Python number protocol for PyType<T>. The initializer is positional, so entries must follow the member order of
// CPython's PyNumberMethods. Members after nb_index that are not listed (e.g. nb_matrix_multiply in newer Pythons)
// are value-initialized to nullptr. Unsupported slots are nullptr.
template <typename T>
PyNumberMethods NpTypeBaseDescr<T>::number_methods = {
  PyType_Add<T>,       // nb_add
  PyType_Subtract<T>,  // nb_subtract
  PyType_Multiply<T>,  // nb_multiply
  nullptr,             // nb_remainder
  nullptr,             // nb_divmod
  nullptr,             // nb_power
  PyType_Negative<T>,  // nb_negative
  nullptr,             // nb_positive
  nullptr,             // nb_absolute
  nullptr,             // nb_bool
  nullptr,             // nb_invert
  nullptr,             // nb_lshift
  nullptr,             // nb_rshift
  nullptr,             // nb_and
  nullptr,             // nb_xor
  nullptr,             // nb_or
  PyType_Int<T>,       // nb_int
  nullptr,             // reserved
  PyType_Float<T>,     // nb_float
  nullptr,             // nb_inplace_add
  nullptr,             // nb_inplace_subtract
  nullptr,             // nb_inplace_multiply
  nullptr,             // nb_inplace_remainder
  nullptr,             // nb_inplace_power
  nullptr,             // nb_inplace_lshift
  nullptr,             // nb_inplace_rshift
  nullptr,             // nb_inplace_and
  nullptr,             // nb_inplace_xor
  nullptr,             // nb_inplace_or
  nullptr,             // nb_floor_divide
  PyType_Divide<T>,    // nb_true_divide
  nullptr,             // nb_inplace_floor_divide
  nullptr,             // nb_inplace_true_divide
  nullptr,             // nb_index
};
227 
228 template <typename TypeIn, typename TypeOut, typename Func>
229 struct UnaryUFunc {
Typesmindspore::np_dtypes::UnaryUFunc230   static std::vector<int> Types() { return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeOut>::Dtype()}; }
Fnmindspore::np_dtypes::UnaryUFunc231   static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
232     const char *arg_p = args[0];
233     char *out_p = args[1];
234     for (npy_intp d = 0; d < *dimensions; d++) {
235       auto arg = *reinterpret_cast<const TypeIn *>(arg_p);
236       *reinterpret_cast<TypeOut *>(out_p) = Func()(arg);
237       arg_p += steps[0];
238       out_p += steps[1];
239     }
240   }
241 };
242 
243 template <typename TypeIn, typename TypeOut, typename TypeOut2, typename Func>
244 struct UnaryUFunc2 {
Typesmindspore::np_dtypes::UnaryUFunc2245   static std::vector<int> Types() {
246     return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeOut>::Dtype(), NpTypeDescr<TypeOut2>::Dtype()};
247   }
Fnmindspore::np_dtypes::UnaryUFunc2248   static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
249     const char *arg_p = args[0];
250     char *out0_p = args[1];
251     char *out1_p = args[2];
252     for (npy_intp d = 0; d < *dimensions; d++) {
253       auto arg = *reinterpret_cast<const TypeIn *>(arg_p);
254       std::tie(*reinterpret_cast<TypeOut *>(out0_p), *reinterpret_cast<TypeOut2 *>(out1_p)) = Func()(arg);
255       arg_p += steps[0];
256       out0_p += steps[1];
257       out1_p += steps[2];
258     }
259   }
260 };
261 
262 template <typename TypeIn, typename TypeOut, typename Func>
263 struct BinaryUFunc {
Typesmindspore::np_dtypes::BinaryUFunc264   static std::vector<int> Types() {
265     return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeOut>::Dtype()};
266   }
Fnmindspore::np_dtypes::BinaryUFunc267   static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
268     const char *arg0_p = args[0];
269     const char *arg1_p = args[1];
270     char *out_p = args[2];
271     for (npy_intp d = 0; d < *dimensions; d++) {
272       auto arg0 = *reinterpret_cast<const TypeIn *>(arg0_p);
273       auto arg1 = *reinterpret_cast<const TypeIn *>(arg1_p);
274       *reinterpret_cast<TypeOut *>(out_p) = Func()(arg0, arg1);
275       arg0_p += steps[0];
276       arg1_p += steps[1];
277       out_p += steps[2];
278     }
279   }
280 };
281 
282 template <typename TypeIn, typename TypeIn2, typename TypeOut, typename Func>
283 struct BinaryUFunc2 {
Typesmindspore::np_dtypes::BinaryUFunc2284   static std::vector<int> Types() {
285     return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeIn2>::Dtype(), NpTypeDescr<TypeOut>::Dtype()};
286   }
Fnmindspore::np_dtypes::BinaryUFunc2287   static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
288     const char *arg0_p = args[0];
289     const char *arg1_p = args[1];
290     char *out_p = args[2];
291     for (npy_intp d = 0; d < *dimensions; d++) {
292       auto arg0 = *reinterpret_cast<const TypeIn *>(arg0_p);
293       auto arg1 = *reinterpret_cast<const TypeIn2 *>(arg1_p);
294       *reinterpret_cast<TypeOut *>(out_p) = Func()(arg0, arg1);
295       arg0_p += steps[0];
296       arg1_p += steps[1];
297       out_p += steps[2];
298     }
299   }
300 };
301 namespace ufuncs {
302 // Implementation of Numpy universal functions.
template <typename T>
struct Add {
  // Element-wise sum via T's operator+.
  T operator()(T lhs, T rhs) {
    T result = lhs + rhs;
    return result;
  }
};
template <typename T>
struct Subtract {
  // Element-wise difference via T's operator-.
  T operator()(T lhs, T rhs) {
    T result = lhs - rhs;
    return result;
  }
};
template <typename T>
struct Multiply {
  // Element-wise product via T's operator*.
  T operator()(T lhs, T rhs) {
    T result = lhs * rhs;
    return result;
  }
};
template <typename T>
struct Divide {
  // Element-wise true division via T's operator/.
  T operator()(T lhs, T rhs) {
    T result = lhs / rhs;
    return result;
  }
};
// Floor division and modulo on floats, returned as {floor quotient, remainder}, where the remainder takes the sign
// of the divisor `b`. This appears to mirror numpy's npy_divmod. Division by zero returns {NaN, NaN}.
inline std::pair<float, float> divmod(float a, float b) {
  if (b == 0.0f) {
    float nan = std::numeric_limits<float>::quiet_NaN();
    return {nan, nan};
  }
  // fmod gives the truncated-division remainder (sign of `a`); `div` is the matching truncated quotient.
  float mod = std::fmod(a, b);
  float div = (a - mod) / b;
  if (mod == 0.0f) {
    // Exact division: make the zero remainder carry the divisor's sign.
    mod = std::copysign(0.0f, b);
  } else if ((b < 0.0f) != (mod < 0.0f)) {
    // Remainder and divisor differ in sign: shift both results to the floored convention.
    mod += b;
    div -= 1.0f;
  }
  float floor_div;
  if (div != 0.0f) {
    // Snap `div` to an integer, rounding up when its fractional part exceeds 0.5. This compensates for a quotient
    // that landed just below an integer because of rounding in (a - mod) / b.
    floor_div = std::floor(div);
    if (div - floor_div > 0.5f) {
      floor_div += 1.0f;
    }
  } else {
    // Zero quotient: give it the sign of a / b.
    floor_div = std::copysign(0.0f, a / b);
  }
  return {floor_div, mod};
}
343 template <typename T>
344 struct DivmodUFunc {
Typesmindspore::np_dtypes::ufuncs::DivmodUFunc345   static std::vector<int> Types() {
346     return {NpTypeDescr<T>::Dtype(), NpTypeDescr<T>::Dtype(), NpTypeDescr<T>::Dtype(), NpTypeDescr<T>::Dtype()};
347   }
Fnmindspore::np_dtypes::ufuncs::DivmodUFunc348   static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
349     const char *arg0_p = args[0];
350     const char *arg1_p = args[1];
351     char *out0_p = args[2];
352     char *out1_p = args[3];
353     for (npy_intp d = 0; d < *dimensions; d++) {
354       T arg0 = *reinterpret_cast<const T *>(arg0_p);
355       T arg1 = *reinterpret_cast<const T *>(arg1_p);
356       float floordiv, mod;
357       std::tie(floordiv, mod) = divmod(static_cast<float>(arg0), static_cast<float>(arg1));
358       *reinterpret_cast<T *>(out0_p) = T(floordiv);
359       *reinterpret_cast<T *>(out1_p) = T(mod);
360       arg0_p += steps[0];
361       arg1_p += steps[1];
362       out0_p += steps[2];
363       out1_p += steps[3];
364     }
365   }
366 };
367 template <typename T>
368 struct FloorDivide {
operator ()mindspore::np_dtypes::ufuncs::FloorDivide369   T operator()(T a, T b) { return T(divmod(static_cast<float>(a), static_cast<float>(b)).first); }
370 };
371 template <typename T>
372 struct Remainder {
operator ()mindspore::np_dtypes::ufuncs::Remainder373   T operator()(T a, T b) { return T(divmod(static_cast<float>(a), static_cast<float>(b)).second); }
374 };
375 template <typename T>
376 struct Fmod {
operator ()mindspore::np_dtypes::ufuncs::Fmod377   T operator()(T a, T b) { return T(std::fmod(static_cast<float>(a), static_cast<float>(b))); }
378 };
template <typename T>
struct Negative {
  T operator()(T x) {
    T negated = -x;
    return negated;
  }
};
template <typename T>
struct Positive {
  // Identity.
  T operator()(T x) {
    return x;
  }
};
template <typename T>
struct Power {
  // Unqualified call, so an overload found by argument-dependent lookup (presumably one for bfloat16) can apply.
  T operator()(T base, T exponent) {
    return pow(base, exponent);
  }
};
template <typename T>
struct Abs {
  // Unqualified for the same reason as Power.
  T operator()(T x) {
    return abs(x);
  }
};
template <typename T>
struct Cbrt {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::cbrt(widened));
  }
};
template <typename T>
struct Ceil {
  // Unqualified so an argument-dependent overload for T can be selected.
  T operator()(T x) {
    return ceil(x);
  }
};
template <typename T>
struct CopySign {
  // Magnitude of `magnitude` with the sign of `sign_source`, computed in float.
  T operator()(T magnitude, T sign_source) {
    const float result = std::copysign(static_cast<float>(magnitude), static_cast<float>(sign_source));
    return T(result);
  }
};
template <typename T>
struct Exp {
  T operator()(T x) {
    return exp(x);
  }
};
template <typename T>
struct Exp2 {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::exp2(widened));
  }
};
template <typename T>
struct Expm1 {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::expm1(widened));
  }
};
template <typename T>
struct Floor {
  T operator()(T x) {
    return floor(x);
  }
};
template <typename T>
struct Frexp {
  // Splits x into {mantissa, exponent} with x == mantissa * 2^exponent.
  std::pair<T, int> operator()(T x) {
    int exponent = 0;
    const float mantissa = std::frexp(static_cast<float>(x), &exponent);
    return std::make_pair(T(mantissa), exponent);
  }
};
template <typename T>
struct Heaviside {
  // Step function: NaN propagates, 0 for negative x, 1 for positive x, `at_zero` when x is (signed) zero.
  T operator()(T x, T at_zero) {
    if (isnan(x)) {
      return x;
    }
    return x < T(0) ? T(0) : (x > T(0) ? T(1) : at_zero);
  }
};
template <typename T>
struct Conjugate {
  // Real-valued type: conjugation is the identity.
  T operator()(T x) {
    return x;
  }
};
template <typename T>
struct IsFinite {
  bool operator()(T x) {
    return isfinite(x);
  }
};
template <typename T>
struct IsInf {
  bool operator()(T x) {
    return isinf(x);
  }
};
template <typename T>
struct IsNan {
  bool operator()(T x) {
    return isnan(x);
  }
};
template <typename T>
struct Ldexp {
  // mantissa * 2^exponent, computed in float.
  T operator()(T mantissa, int exponent) {
    const float scaled = std::ldexp(static_cast<float>(mantissa), exponent);
    return T(scaled);
  }
};
template <typename T>
struct Log {
  T operator()(T x) {
    return log(x);
  }
};
template <typename T>
struct Log1p {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::log1p(widened));
  }
};
template <typename T>
struct Log2 {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::log2(widened));
  }
};
template <typename T>
struct Log10 {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::log10(widened));
  }
};
template <typename T>
struct LogAddExp {
  // log(exp(a) + exp(b)), evaluated in float around the larger operand to avoid overflow.
  T operator()(T a, T b) {
    const float x = static_cast<float>(a);
    const float y = static_cast<float>(b);
    if (x == y) {
      // Equal inputs (including equal infinities): log(2 * e^x) == x + log(2).
      return T(x + std::log(2.0f));
    }
    if (x > y) {
      return T(x + std::log1p(std::exp(y - x)));
    }
    if (x < y) {
      return T(y + std::log1p(std::exp(x - y)));
    }
    // Unordered comparison: at least one operand is NaN.
    return T(std::numeric_limits<float>::quiet_NaN());
  }
};
template <typename T>
struct LogAddExp2 {
  // log2(2^a + 2^b), evaluated in float around the larger operand.
  T operator()(T a, T b) {
    const float x = static_cast<float>(a);
    const float y = static_cast<float>(b);
    if (x == y) {
      return T(x + 1.0f);
    }
    if (x > y) {
      return T(x + std::log1p(std::exp2(y - x)) / std::log(2.0f));
    }
    if (x < y) {
      return T(y + std::log1p(std::exp2(x - y)) / std::log(2.0f));
    }
    // Unordered comparison: at least one operand is NaN.
    return T(std::numeric_limits<float>::quiet_NaN());
  }
};
template <typename T>
struct Modf {
  // Returns {fractional part, integral part}.
  std::pair<T, T> operator()(T x) {
    float whole;
    const float fraction = std::modf(static_cast<float>(x), &whole);
    return std::make_pair(T(fraction), T(whole));
  }
};
template <typename T>
struct Reciprocal {
  T operator()(T x) {
    const float inverse = 1.f / static_cast<float>(x);
    return T(inverse);
  }
};
template <typename T>
struct Rint {
  // Round to integral value using the current floating-point rounding mode.
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::rint(widened));
  }
};
template <typename T>
struct Sign {
  // -1 for negative, +1 for positive; NaN and (signed) zero are returned unchanged.
  T operator()(T a) {
    if (!isnan(a)) {
      if (a < T(0)) {
        return T(-1);
      }
      if (a > T(0)) {
        return T(1);
      }
    }
    return a;
  }
};
template <typename T>
struct SignBit {
  bool operator()(T x) {
    const float widened = static_cast<float>(x);
    return std::signbit(widened);
  }
};
template <typename T>
struct Sqrt {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::sqrt(widened));
  }
};
template <typename T>
struct Square {
  // x * x, computed in float.
  T operator()(T a) {
    const float widened = static_cast<float>(a);
    return T(widened * widened);
  }
};
template <typename T>
struct Trunc {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::trunc(widened));
  }
};
// Trigonometric functions
// sin/cos/tan/tanh use unqualified calls (so an argument-dependent overload for T can be chosen); the others widen
// to float and call the <cmath> function.
template <typename T>
struct Sin {
  T operator()(T x) {
    return sin(x);
  }
};
template <typename T>
struct Cos {
  T operator()(T x) {
    return cos(x);
  }
};
template <typename T>
struct Tan {
  T operator()(T x) {
    return tan(x);
  }
};
template <typename T>
struct Arcsin {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::asin(widened));
  }
};
template <typename T>
struct Arccos {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::acos(widened));
  }
};
template <typename T>
struct Arctan {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::atan(widened));
  }
};
template <typename T>
struct Arctan2 {
  T operator()(T y, T x) {
    const float result = std::atan2(static_cast<float>(y), static_cast<float>(x));
    return T(result);
  }
};
template <typename T>
struct Hypot {
  T operator()(T x, T y) {
    const float result = std::hypot(static_cast<float>(x), static_cast<float>(y));
    return T(result);
  }
};
template <typename T>
struct Sinh {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::sinh(widened));
  }
};
template <typename T>
struct Cosh {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::cosh(widened));
  }
};
template <typename T>
struct Tanh {
  T operator()(T x) {
    return tanh(x);
  }
};
template <typename T>
struct Arcsinh {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::asinh(widened));
  }
};
template <typename T>
struct Arccosh {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::acosh(widened));
  }
};
template <typename T>
struct Arctanh {
  T operator()(T x) {
    const float widened = static_cast<float>(x);
    return T(std::atanh(widened));
  }
};
template <typename T>
struct Deg2rad {
  // Degrees -> radians, scaled in float.
  T operator()(T a) {
    constexpr float kPi = 3.14159265358979323846f;
    constexpr float kRadiansPerDegree = kPi / 180.0f;
    return T(static_cast<float>(a) * kRadiansPerDegree);
  }
};
template <typename T>
struct Rad2deg {
  // Radians -> degrees, scaled in float.
  T operator()(T a) {
    constexpr float kPi = 3.14159265358979323846f;
    constexpr float kDegreesPerRadian = 180.0f / kPi;
    return T(static_cast<float>(a) * kDegreesPerRadian);
  }
};
639 template <typename T>
640 struct Eq {
operator ()mindspore::np_dtypes::ufuncs::Eq641   npy_bool operator()(T a, T b) { return a == b; }
642 };
643 template <typename T>
644 struct Ne {
operator ()mindspore::np_dtypes::ufuncs::Ne645   npy_bool operator()(T a, T b) { return a != b; }
646 };
647 template <typename T>
648 struct Lt {
operator ()mindspore::np_dtypes::ufuncs::Lt649   npy_bool operator()(T a, T b) { return a < b; }
650 };
651 template <typename T>
652 struct Le {
operator ()mindspore::np_dtypes::ufuncs::Le653   npy_bool operator()(T a, T b) { return a <= b; }
654 };
655 template <typename T>
656 struct Gt {
operator ()mindspore::np_dtypes::ufuncs::Gt657   npy_bool operator()(T a, T b) { return a > b; }
658 };
659 template <typename T>
660 struct Ge {
operator ()mindspore::np_dtypes::ufuncs::Ge661   npy_bool operator()(T a, T b) { return a >= b; }
662 };
template <typename T>
struct Maximum {
  // NaN-propagating maximum: a NaN `a` is returned; a NaN `b` fails `a > b` and is returned.
  T operator()(T a, T b) {
    if (isnan(a) || a > b) {
      return a;
    }
    return b;
  }
};
template <typename T>
struct Minimum {
  // NaN-propagating minimum.
  T operator()(T a, T b) {
    if (isnan(a) || a < b) {
      return a;
    }
    return b;
  }
};
template <typename T>
struct Fmax {
  // NaN-ignoring maximum: a NaN `b` yields `a`; a NaN `a` fails `a > b` and yields `b`.
  T operator()(T a, T b) {
    if (isnan(b) || a > b) {
      return a;
    }
    return b;
  }
};
template <typename T>
struct Fmin {
  // NaN-ignoring minimum.
  T operator()(T a, T b) {
    if (isnan(b) || a < b) {
      return a;
    }
    return b;
  }
};
679 template <typename T>
680 struct LogicalNot {
operator ()mindspore::np_dtypes::ufuncs::LogicalNot681   npy_bool operator()(T a) { return !static_cast<bool>(a); }
682 };
683 template <typename T>
684 struct LogicalAnd {
operator ()mindspore::np_dtypes::ufuncs::LogicalAnd685   npy_bool operator()(T a, T b) { return static_cast<bool>(a) && static_cast<bool>(b); }
686 };
687 template <typename T>
688 struct LogicalOr {
operator ()mindspore::np_dtypes::ufuncs::LogicalOr689   npy_bool operator()(T a, T b) { return static_cast<bool>(a) || static_cast<bool>(b); }
690 };
691 template <typename T>
692 struct LogicalXor {
operator ()mindspore::np_dtypes::ufuncs::LogicalXor693   npy_bool operator()(T a, T b) { return static_cast<bool>(a) ^ static_cast<bool>(b); }
694 };
// Get unsigned integer type with same size of T.
// Maps a byte count to the fixed-width unsigned integer of exactly that size; only 1, 2 and 4 bytes are provided.
template <int kNumBytes>
struct GetUnsignedInteger;
template <>
struct GetUnsignedInteger<4> {
  using uint_type = uint32_t;
};
template <>
struct GetUnsignedInteger<2> {
  using uint_type = uint16_t;
};
template <>
struct GetUnsignedInteger<1> {
  using uint_type = uint8_t;
};
// Unsigned integer with the same size as T (NextAfter uses it to manipulate T's bit pattern).
template <typename T>
using UIntType = typename GetUnsignedInteger<sizeof(T)>::uint_type;
712 template <typename TypeIn, typename TypeOut>
bit_cast(TypeIn value)713 TypeOut bit_cast(TypeIn value) {
714   static_assert(sizeof(TypeIn) == sizeof(TypeOut), "For bit_cast, types must match size.");
715   TypeOut out = TypeOut(0);
716   errno_t ret = memcpy_s(&out, sizeof(TypeOut), &value, sizeof(TypeIn));
717   if (ret != EOK) {
718     PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d", ret);
719     return out;
720   }
721   return out;
722 }
// Next representable T after `from` in the direction of `to` (intended to match std::nextafter), computed by
// stepping T's bit pattern as an unsigned integer of the same size.
template <typename T>
struct NextAfter {
  T operator()(T from, T to) {
    // Any NaN input yields NaN.
    if (isnan(from) || isnan(to)) {
      return std::numeric_limits<T>::quiet_NaN();
    }
    UIntType<T> from_uint = bit_cast<T, UIntType<T>>(from);
    UIntType<T> to_uint = bit_cast<T, UIntType<T>>(to);
    // Identical bit patterns: already at the target.
    if (from_uint == to_uint) {
      return to;
    }
    // Mask with only the most significant bit set (the sign bit).
    UIntType<T> sign_mask = UIntType<T>(1) << (sizeof(T) * CHAR_BIT - 1);
    UIntType<T> from_uint_abs = bit_cast<T, UIntType<T>>(abs(from));
    UIntType<T> from_uint_sign = from_uint & sign_mask;
    UIntType<T> to_uint_abs = bit_cast<T, UIntType<T>>(abs(to));
    UIntType<T> to_uint_sign = to_uint & sign_mask;
    if (from_uint_abs == 0) {
      // `from` is +0 or -0.
      if (to_uint_abs == 0) {
        // Both zeros (of different sign, since the bit patterns differ): return `to`.
        return to;
      } else {
        // Minimum non-zero value with sign bit of `to`.
        return bit_cast<UIntType<T>, T>(static_cast<UIntType<T>>(0x01 | to_uint_sign));
      }
    }
    // With a sign-magnitude layout, +1 on the bit pattern moves away from zero and -1 (via unsigned wrap-around of
    // static_cast<UIntType<T>>(-1)) moves toward zero. Step toward zero when |from| > |to| or the signs differ.
    UIntType<T> next_step = (from_uint_abs > to_uint_abs || from_uint_sign != to_uint_sign)
                              ? static_cast<UIntType<T>>(-1)
                              : static_cast<UIntType<T>>(1);
    UIntType<T> out_uint = from_uint + next_step;
    return bit_cast<UIntType<T>, T>(out_uint);
  }
};
754 }  // namespace ufuncs
755 
756 // Cast input object to Python type T.
757 template <typename T>
CastToPyType(PyObject * obj,T * output)758 bool CastToPyType(PyObject *obj, T *output) {
759   // object is an instance of NpTypeDescr
760   if (PyType_CheckType<T>(obj)) {
761     *output = PyType_GetValue<T>(obj);
762     return true;
763   }
764   // object is an instance of int
765   if (PyLong_Check(obj)) {
766     long value = PyLong_AsLong(obj);
767     if (PyErr_Occurred()) {
768       return false;
769     }
770     *output = T(value);
771     return true;
772   }
773   // object is an instance of float
774   if (PyFloat_Check(obj)) {
775     double value = PyFloat_AsDouble(obj);
776     if (PyErr_Occurred()) {
777       return false;
778     }
779     *output = T(value);
780     return true;
781   }
782   // object is an instance of scalar float16
783   if (PyArray_IsScalar(obj, Half)) {
784     float16 value;
785     PyArray_ScalarAsCtype(obj, &value);
786     *output = T(value);
787     return true;
788   }
789   // object is an instance of scalar float
790   if (PyArray_IsScalar(obj, Float)) {
791     float value;
792     PyArray_ScalarAsCtype(obj, &value);
793     *output = T(value);
794     return true;
795   }
796   // object is an instance of scalar double
797   if (PyArray_IsScalar(obj, Double)) {
798     double value;
799     PyArray_ScalarAsCtype(obj, &value);
800     *output = T(value);
801     return true;
802   }
803   // object is an instance of scalar long double
804   if (PyArray_IsScalar(obj, LongDouble)) {
805     long double value;
806     PyArray_ScalarAsCtype(obj, &value);
807     *output = T(value);
808     return true;
809   }
810   // object is an instance of 0-dim array
811   if (PyArray_IsZeroDim(obj)) {
812     PyArrayObject *arr = reinterpret_cast<PyArrayObject *>(obj);
813     // cast value in array to type T
814     if (PyArray_TYPE(arr) != NpTypeDescr<T>::Dtype()) {
815       PyObjectPtr new_arr = SafePtr(PyArray_Cast(arr, NpTypeDescr<T>::Dtype()));
816       if (PyErr_Occurred()) {
817         return false;
818       }
819       arr = reinterpret_cast<PyArrayObject *>(new_arr.get());
820     }
821     *output = *reinterpret_cast<T *>(PyArray_DATA(arr));
822     return true;
823   }
824   return false;
825 }
826 
827 // Constructs a new Python type.
828 template <typename T>
PyType_New(PyTypeObject * type,PyObject * args,PyObject * kwds)829 PyObject *PyType_New(PyTypeObject *type, PyObject *args, PyObject *kwds) {
830   if (kwds && PyDict_Size(kwds)) {
831     PyErr_Format(PyExc_TypeError, "No keyword arguments should be provided when constructing %s",
832                  NpTypeDescr<T>::type_name);
833     return nullptr;
834   }
835   Py_ssize_t arg_num = PyTuple_Size(args);
836   if (arg_num != 1) {
837     PyErr_Format(PyExc_TypeError, "One argument is expected when constructing %s, but got %d.",
838                  NpTypeDescr<T>::type_name, arg_num);
839     return nullptr;
840   }
841   PyObject *arg = PyTuple_GetItem(args, 0);
842   T value;
843   // If arg is already NpTypeDescr<T>, just return it.
844   if (PyType_CheckType<T>(arg)) {
845     Py_INCREF(arg);
846     return arg;
847   }
848   // If arg can be casted to T value, create NpTypeDescr<T> from the value.
849   if (CastToPyType<T>(arg, &value)) {
850     return PyTypeFromValue<T>(value).release();
851   }
852   // If arg is an array, cast it to NpTypeDescr<T>
853   if (PyArray_Check(arg)) {
854     PyArrayObject *arr = reinterpret_cast<PyArrayObject *>(arg);
855     if (PyArray_TYPE(arr) != NpTypeDescr<T>::Dtype()) {
856       return PyArray_Cast(arr, NpTypeDescr<T>::Dtype());
857     } else {
858       Py_INCREF(arg);
859       return arg;
860     }
861   }
862   // If arg is unicodes or bytes, convert it from string to float, then cast the float to T value,
863   // and then create NpTypeDescr<T> from the value.
864   if (PyUnicode_Check(arg) || PyBytes_Check(arg)) {
865     PyObject *value_f = PyFloat_FromString(arg);
866     if (CastToPyType<T>(value_f, &value)) {
867       return PyTypeFromValue<T>(value).release();
868     }
869   }
870   PyErr_Format(PyExc_TypeError, "Only number argument is expected when constructing %s, but got %s.",
871                NpTypeDescr<T>::type_name, Py_TYPE(arg)->tp_name);
872   return nullptr;
873 }
874 
875 // Implementation of repr() for PyType.
876 template <typename T>
PyType_Repr(PyObject * self)877 PyObject *PyType_Repr(PyObject *self) {
878   T value = reinterpret_cast<PyType<T> *>(self)->value;
879   std::string value_str = std::to_string(static_cast<float>(value));
880   return PyUnicode_FromString(value_str.c_str());
881 }
882 
883 // Overload function _Py_HashDouble to support Python version over 3.10.
HashDouble_(Py_hash_t (* hash_double)(PyObject *,double),PyObject * self,double value)884 inline Py_hash_t HashDouble_(Py_hash_t (*hash_double)(PyObject *, double), PyObject *self, double value) {
885   return hash_double(self, value);
886 }
887 
HashDouble_(Py_hash_t (* hash_double)(double),PyObject * self,double value)888 inline Py_hash_t HashDouble_(Py_hash_t (*hash_double)(double), PyObject *self, double value) {
889   return hash_double(value);
890 }
891 
892 // Implementation of hash() for PyType.
893 template <typename T>
PyType_Hash(PyObject * self)894 Py_hash_t PyType_Hash(PyObject *self) {
895   T value = reinterpret_cast<PyType<T> *>(self)->value;
896   return HashDouble_(&_Py_HashDouble, self, static_cast<double>(value));
897 }
898 
899 // Implementation of str() for PyType.
900 template <typename T>
PyType_Str(PyObject * self)901 PyObject *PyType_Str(PyObject *self) {
902   T value = reinterpret_cast<PyType<T> *>(self)->value;
903   std::string value_str = std::to_string(static_cast<float>(value));
904   return PyUnicode_FromString(value_str.c_str());
905 }
906 
907 // Implementation of Comparisons for PyType.
908 template <typename T>
PyType_RichCompare(PyObject * a,PyObject * b,int op)909 PyObject *PyType_RichCompare(PyObject *a, PyObject *b, int op) {
910   if (!PyType_CheckType<T>(a) || !PyType_CheckType<T>(b)) {
911     return PyGenericArrType_Type.tp_richcompare(a, b, op);
912   }
913   T x = PyType_GetValue<T>(a);
914   T y = PyType_GetValue<T>(b);
915   bool result;
916   switch (op) {
917     case Py_EQ:
918       result = (x == y);
919       break;
920     case Py_NE:
921       result = (x != y);
922       break;
923     case Py_LT:
924       result = (x < y);
925       break;
926     case Py_LE:
927       result = (x <= y);
928       break;
929     case Py_GT:
930       result = (x > y);
931       break;
932     case Py_GE:
933       result = (x >= y);
934       break;
935     default:
936       PyErr_Format(PyExc_ValueError, "Got invalid op type %d when comparing %s", op, NpTypeDescr<T>::type_name);
937       return nullptr;
938   }
939   PyObject *ret = PyBool_FromLong(result);
940   Py_INCREF(ret);
941   return ret;
942 }
943 
944 // Implementations of NumPy array methods for PyType.
945 template <typename T>
NpType_GetItem(void * data,void * arr)946 PyObject *NpType_GetItem(void *data, void *arr) {
947   T value;
948   errno_t ret = memcpy_s(&value, sizeof(T), data, sizeof(T));
949   if (ret != EOK) {
950     PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
951     return nullptr;
952   }
953   return PyTypeFromValue(value).release();
954 }
955 
956 template <typename T>
NpType_SetItem(PyObject * item,void * data,void * arr)957 int NpType_SetItem(PyObject *item, void *data, void *arr) {
958   T value;
959   if (!CastToPyType<T>(item, &value)) {
960     PyErr_Format(PyExc_TypeError, "Only number argument is expected for SetItem %s, but got %s.",
961                  NpTypeDescr<T>::type_name, Py_TYPE(item)->tp_name);
962     return -1;
963   }
964   errno_t ret = memcpy_s(data, sizeof(T), &value, sizeof(T));
965   if (ret != EOK) {
966     PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
967     return -1;
968   }
969   return 0;
970 }
971 
// compare: three-way comparison of the elements at d1 and d2 (-1, 0 or 1).
// NaN orders after every non-NaN value, and two NaNs compare equal.
template <typename T>
int NpType_Compare(const void *d1, const void *d2, void *arr) {
  T lhs = *reinterpret_cast<const T *>(d1);
  T rhs = *reinterpret_cast<const T *>(d2);
  if (lhs < rhs) {
    return -1;
  }
  if (rhs < lhs) {
    return 1;
  }
  // Neither is strictly smaller: the values are equal or at least one of them is NaN.
  const bool lhs_nan = isnan(lhs);
  const bool rhs_nan = isnan(rhs);
  if (lhs_nan == rhs_nan) {
    return 0;
  }
  return lhs_nan ? 1 : -1;
}
990 
991 template <typename T>
NpType_CopySwapN(void * dest,npy_intp dstride,void * src,npy_intp sstride,npy_intp n,int swap,void * arr)992 void NpType_CopySwapN(void *dest, npy_intp dstride, void *src, npy_intp sstride, npy_intp n, int swap, void *arr) {
993   static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Swap is not supported");
994   char *dst_p = reinterpret_cast<char *>(dest);
995   char *src_p = reinterpret_cast<char *>(src);
996   if (!src_p) {
997     return;
998   }
999   if (swap && sizeof(T) == sizeof(int16_t)) {
1000     for (npy_intp i = 0; i < n; i++) {
1001       char *r = dst_p + dstride * i;
1002       errno_t ret = memcpy_s(r, sizeof(T), src_p + sstride * i, sizeof(T));
1003       if (ret != EOK) {
1004         PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1005         return;
1006       }
1007       std::swap(r[0], r[1]);
1008     }
1009   } else if (dstride == sizeof(T) && sstride == sizeof(T)) {
1010     errno_t ret = memcpy_s(dst_p, n * sizeof(T), src_p, n * sizeof(T));
1011     if (ret != EOK) {
1012       PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1013       return;
1014     }
1015   } else {
1016     for (npy_intp i = 0; i < n; i++) {
1017       errno_t ret = memcpy_s(dst_p + dstride * i, sizeof(T), src_p + sstride * i, sizeof(T));
1018       if (ret != EOK) {
1019         PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1020         return;
1021       }
1022     }
1023   }
1024 }
1025 
1026 template <typename T>
NpType_CopySwap(void * dest,void * src,int swap,void * arr)1027 void NpType_CopySwap(void *dest, void *src, int swap, void *arr) {
1028   static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Swap is not supported");
1029   if (!src) {
1030     return;
1031   }
1032   errno_t ret = memcpy_s(dest, sizeof(T), src, sizeof(T));
1033   if (ret != EOK) {
1034     PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1035     return;
1036   }
1037   if (swap && (sizeof(T) == sizeof(int16_t))) {
1038     char *p = reinterpret_cast<char *>(dest);
1039     std::swap(p[0], p[1]);
1040   }
1041 }
1042 
1043 template <typename T>
NpType_NonZero(void * data,void * arr)1044 npy_bool NpType_NonZero(void *data, void *arr) {
1045   T value;
1046   errno_t ret = memcpy_s(&value, sizeof(T), data, sizeof(T));
1047   if (ret != EOK) {
1048     PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1049     return false;
1050   }
1051   return value != static_cast<T>(0);
1052 }
1053 
1054 template <typename T>
NpType_Fill(void * data,npy_intp length,void * arr)1055 int NpType_Fill(void *data, npy_intp length, void *arr) {
1056   T *const buffer = reinterpret_cast<T *>(data);
1057   const T start(buffer[0]);
1058   const T delta = static_cast<T>(buffer[1]) - start;
1059   for (npy_intp i = 2; i < length; i++) {
1060     buffer[i] = static_cast<T>(start + T(i) * delta);
1061   }
1062   return 0;
1063 }
1064 
1065 template <typename T>
NpType_Dot(void * ip1,npy_intp is1,void * ip2,npy_intp is2,void * op,npy_intp n,void * arr)1066 void NpType_Dot(void *ip1, npy_intp is1, void *ip2, npy_intp is2, void *op, npy_intp n, void *arr) {
1067   char *p1 = reinterpret_cast<char *>(ip1);
1068   char *p2 = reinterpret_cast<char *>(ip2);
1069   T acc = T(0);
1070   for (npy_intp i = 0; i < n; i++) {
1071     T *const a = reinterpret_cast<T *>(p1);
1072     T *const b = reinterpret_cast<T *>(p2);
1073     acc += static_cast<T>(*a) * static_cast<T>(*b);
1074     p1 += is1;
1075     p2 += is2;
1076   }
1077   T *out = reinterpret_cast<T *>(op);
1078   *out = static_cast<T>(acc);
1079 }
1080 
1081 template <typename T>
NpType_ArgMax(void * data,npy_intp n,npy_intp * max_ind,void * arr)1082 int NpType_ArgMax(void *data, npy_intp n, npy_intp *max_ind, void *arr) {
1083   const T *data_p = reinterpret_cast<const T *>(data);
1084   T max_val = static_cast<T>(data_p[0]);
1085   *max_ind = 0;
1086   for (npy_intp i = 0; i < n; i++) {
1087     T val = static_cast<T>(data_p[i]);
1088     if (isnan(val) || val > max_val) {
1089       max_val = val;
1090       *max_ind = i;
1091       // NumPy stops at the first NaN.
1092       if (isnan(val)) {
1093         break;
1094       }
1095     }
1096   }
1097   return 0;
1098 }
1099 
1100 template <typename T>
NpType_ArgMin(void * data,npy_intp n,npy_intp * min_ind,void * arr)1101 int NpType_ArgMin(void *data, npy_intp n, npy_intp *min_ind, void *arr) {
1102   const T *data_p = reinterpret_cast<const T *>(data);
1103   T min_val = static_cast<T>(data_p[0]);
1104   *min_ind = 0;
1105   for (npy_intp i = 1; i < n; i++) {
1106     T val = static_cast<T>(data_p[i]);
1107     if (isnan(val) || val < min_val) {
1108       min_val = val;
1109       *min_ind = i;
1110       // NumPy stops at the first NaN.
1111       if (isnan(val)) {
1112         break;
1113       }
1114     }
1115   }
1116   return 0;
1117 }
1118 
1119 template <typename T>
GetNpDescrProto()1120 PyArray_DescrProto GetNpDescrProto() {
1121   return {
1122     PyObject_HEAD_INIT(nullptr)
1123     /*typeobj=*/nullptr,
1124     /*kind=*/NpTypeDescr<T>::kind,
1125     /*type=*/NpTypeDescr<T>::type,
1126     /*byteorder=*/NpTypeDescr<T>::byte_order,
1127     /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM,
1128     /*type_num=*/0,
1129     /*elsize=*/sizeof(T),
1130     /*alignment=*/alignof(T),
1131     /*subarray=*/nullptr,
1132     /*fields=*/nullptr,
1133     /*names=*/nullptr,
1134     /*f=*/&NpTypeDescr<T>::arr_funcs,
1135     /*metadata=*/nullptr,
1136     /*c_metadata=*/nullptr,
1137     /*hash=*/-1,
1138   };
1139 }
1140 
1141 // Cast a numpy array from type 'From' to 'To'.
1142 template <typename From, typename To>
NpyCast(void * from,void * to,npy_intp n,void * from_arr,void * to_arr)1143 void NpyCast(void *from, void *to, npy_intp n, void *from_arr, void *to_arr) {
1144   const From *from_ptr = static_cast<From *>(from);
1145   To *to_ptr = static_cast<To *>(to);
1146   for (npy_intp i = 0; i < n; i++) {
1147     to_ptr[i] = static_cast<To>(from_ptr[i]);
1148   }
1149 }
1150 
1151 // Register a cast between T and other numpy type Y.
1152 template <typename T, typename Y>
RegisterNpTypeCast(int np_type,bool scalar_castable)1153 bool RegisterNpTypeCast(int np_type, bool scalar_castable) {
1154   PyArray_Descr *descr = PyArray_DescrFromType(np_type);
1155   if (PyArray_RegisterCastFunc(descr, NpTypeDescr<T>::Dtype(), NpyCast<Y, T>) < 0) {
1156     return false;
1157   }
1158   if (PyArray_RegisterCastFunc(&NpTypeDescr<T>::np_descr, np_type, NpyCast<T, Y>) < 0) {
1159     return false;
1160   }
1161   if (scalar_castable && PyArray_RegisterCanCast(&NpTypeDescr<T>::np_descr, np_type, NPY_NOSCALAR) < 0) {
1162     return false;
1163   }
1164   return true;
1165 }
1166 
1167 // Register casts between T and other numpy types.
1168 template <typename T>
RegisterNpTypeCasts()1169 bool RegisterNpTypeCasts() {
1170   if (!RegisterNpTypeCast<T, bool>(NPY_BOOL, false)) {
1171     return false;
1172   }
1173   if (!RegisterNpTypeCast<T, float16>(NPY_HALF, false)) {
1174     return false;
1175   }
1176   if (!RegisterNpTypeCast<T, float>(NPY_FLOAT, true)) {
1177     return false;
1178   }
1179   if (!RegisterNpTypeCast<T, double>(NPY_DOUBLE, false)) {
1180     return false;
1181   }
1182   if (!RegisterNpTypeCast<T, long double>(NPY_LONGDOUBLE, false)) {
1183     return false;
1184   }
1185   if (!RegisterNpTypeCast<T, unsigned char>(NPY_UBYTE, false)) {
1186     return false;
1187   }
1188   if (!RegisterNpTypeCast<T, unsigned short>(NPY_USHORT, false)) {
1189     return false;
1190   }
1191   if (!RegisterNpTypeCast<T, unsigned int>(NPY_UINT, false)) {
1192     return false;
1193   }
1194   if (!RegisterNpTypeCast<T, unsigned long>(NPY_ULONG, false)) {
1195     return false;
1196   }
1197   if (!RegisterNpTypeCast<T, unsigned long long>(NPY_ULONGLONG, false)) {
1198     return false;
1199   }
1200   if (!RegisterNpTypeCast<T, char>(NPY_BYTE, false)) {
1201     return false;
1202   }
1203   if (!RegisterNpTypeCast<T, short>(NPY_SHORT, false)) {
1204     return false;
1205   }
1206   if (!RegisterNpTypeCast<T, int>(NPY_INT, false)) {
1207     return false;
1208   }
1209   if (!RegisterNpTypeCast<T, long>(NPY_LONG, false)) {
1210     return false;
1211   }
1212   if (!RegisterNpTypeCast<T, long long>(NPY_LONGLONG, false)) {
1213     return false;
1214   }
1215   // Complexs are not support yet.
1216   return true;
1217 }
1218 
1219 // Register a Numpy universal function.
1220 template <typename UFunc, typename T>
RegisterNpTypeUFunc(PyObject * numpy,const char * fn_name)1221 bool RegisterNpTypeUFunc(PyObject *numpy, const char *fn_name) {
1222   std::vector<int> types = UFunc::Types();
1223   PyUFuncGenericFunction fn = reinterpret_cast<PyUFuncGenericFunction>(UFunc::Fn);
1224   PyObjectPtr ufunc_p = SafePtr(PyObject_GetAttrString(numpy, fn_name));
1225   if (!ufunc_p) {
1226     return false;
1227   }
1228   PyUFuncObject *ufunc = reinterpret_cast<PyUFuncObject *>(ufunc_p.get());
1229   if (static_cast<int>(types.size()) != ufunc->nargs) {
1230     PyErr_Format(PyExc_AssertionError, "The ufunc %s need %d arguments, but got %lu.", fn_name, ufunc->nargs,
1231                  types.size());
1232     return false;
1233   }
1234   if (PyUFunc_RegisterLoopForType(ufunc, NpTypeDescr<T>::Dtype(), fn, const_cast<int *>(types.data()), nullptr) < 0) {
1235     return false;
1236   }
1237   return true;
1238 }
1239 
1240 // Register Numpy universal functions of type T.
1241 template <typename T>
RegisterNpTypeUFuncs(PyObject * numpy)1242 bool RegisterNpTypeUFuncs(PyObject *numpy) {
1243   // Math operations
1244   bool ok = RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
1245             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(numpy, "subtract") &&
1246             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(numpy, "multiply") &&
1247             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Divide<T>>, T>(numpy, "divide") &&
1248             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp<T>>, T>(numpy, "logaddexp") &&
1249             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp2<T>>, T>(numpy, "logaddexp2") &&
1250             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Negative<T>>, T>(numpy, "negative") &&
1251             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Positive<T>>, T>(numpy, "positive") &&
1252             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Divide<T>>, T>(numpy, "true_divide") &&
1253             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(numpy, "floor_divide") &&
1254             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Power<T>>, T>(numpy, "power") &&
1255             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "remainder") &&
1256             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "mod") &&
1257             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Fmod<T>>, T>(numpy, "fmod") &&
1258             RegisterNpTypeUFunc<ufuncs::DivmodUFunc<T>, T>(numpy, "divmod") &&
1259             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "absolute") &&
1260             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "fabs") &&
1261             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Rint<T>>, T>(numpy, "rint") &&
1262             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sign<T>>, T>(numpy, "sign") &&
1263             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Heaviside<T>>, T>(numpy, "heaviside") &&
1264             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Conjugate<T>>, T>(numpy, "conjugate") &&
1265             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Exp<T>>, T>(numpy, "exp") &&
1266             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Exp2<T>>, T>(numpy, "exp2") &&
1267             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Expm1<T>>, T>(numpy, "expm1") &&
1268             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log<T>>, T>(numpy, "log") &&
1269             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log1p<T>>, T>(numpy, "log1p") &&
1270             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log2<T>>, T>(numpy, "log2") &&
1271             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log10<T>>, T>(numpy, "log10") &&
1272             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sqrt<T>>, T>(numpy, "sqrt") &&
1273             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Square<T>>, T>(numpy, "square") &&
1274             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Cbrt<T>>, T>(numpy, "cbrt") &&
1275             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Reciprocal<T>>, T>(numpy, "reciprocal") &&
1276             // Trigonometric functions
1277             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sin<T>>, T>(numpy, "sin") &&
1278             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Cos<T>>, T>(numpy, "cos") &&
1279             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Tan<T>>, T>(numpy, "tan") &&
1280             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arcsin<T>>, T>(numpy, "arcsin") &&
1281             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arccos<T>>, T>(numpy, "arccos") &&
1282             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arctan<T>>, T>(numpy, "arctan") &&
1283             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Arctan2<T>>, T>(numpy, "arctan2") &&
1284             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Hypot<T>>, T>(numpy, "hypot") &&
1285             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sinh<T>>, T>(numpy, "sinh") &&
1286             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Cosh<T>>, T>(numpy, "cosh") &&
1287             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Tanh<T>>, T>(numpy, "tanh") &&
1288             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arcsinh<T>>, T>(numpy, "arcsinh") &&
1289             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arccosh<T>>, T>(numpy, "arccosh") &&
1290             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arctanh<T>>, T>(numpy, "arctanh") &&
1291             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Deg2rad<T>>, T>(numpy, "deg2rad") &&
1292             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Rad2deg<T>>, T>(numpy, "rad2deg") &&
1293             // Comparison functions
1294             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Eq<T>>, T>(numpy, "equal") &&
1295             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Ne<T>>, T>(numpy, "not_equal") &&
1296             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Lt<T>>, T>(numpy, "less") &&
1297             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Le<T>>, T>(numpy, "less_equal") &&
1298             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Gt<T>>, T>(numpy, "greater") &&
1299             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Ge<T>>, T>(numpy, "greater_equal") &&
1300             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Maximum<T>>, T>(numpy, "maximum") &&
1301             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Minimum<T>>, T>(numpy, "minimum") &&
1302             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Fmax<T>>, T>(numpy, "fmax") &&
1303             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Fmin<T>>, T>(numpy, "fmin") &&
1304             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::LogicalAnd<T>>, T>(numpy, "logical_and") &&
1305             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::LogicalOr<T>>, T>(numpy, "logical_or") &&
1306             RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::LogicalXor<T>>, T>(numpy, "logical_xor") &&
1307             RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::LogicalNot<T>>, T>(numpy, "logical_not") &&
1308             // Floating point functions
1309             RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::IsFinite<T>>, T>(numpy, "isfinite") &&
1310             RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::IsInf<T>>, T>(numpy, "isinf") &&
1311             RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::IsNan<T>>, T>(numpy, "isnan") &&
1312             RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::SignBit<T>>, T>(numpy, "signbit") &&
1313             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::CopySign<T>>, T>(numpy, "copysign") &&
1314             RegisterNpTypeUFunc<UnaryUFunc2<T, T, T, ufuncs::Modf<T>>, T>(numpy, "modf") &&
1315             RegisterNpTypeUFunc<BinaryUFunc2<T, int, T, ufuncs::Ldexp<T>>, T>(numpy, "ldexp") &&
1316             RegisterNpTypeUFunc<UnaryUFunc2<T, T, int, ufuncs::Frexp<T>>, T>(numpy, "frexp") &&
1317             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Floor<T>>, T>(numpy, "floor") &&
1318             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Ceil<T>>, T>(numpy, "ceil") &&
1319             RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Trunc<T>>, T>(numpy, "trunc") &&
1320             RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::NextAfter<T>>, T>(numpy, "nextafter");
1321   return ok;
1322 }
1323 
1324 template <typename T>
RegisterNumpyType()1325 bool RegisterNumpyType() {
1326   // Check if current type is already initialized.
1327   if (NpTypeDescr<T>::Dtype() != NPY_NOTYPE) {
1328     return true;
1329   }
1330 
1331   // Import Python modules
1332   import_array1(false);
1333   import_umath1(false);
1334   PyObjectPtr numpy_str = SafePtr(PyUnicode_FromString("numpy"));
1335   if (!numpy_str) {
1336     return false;
1337   }
1338   PyObjectPtr numpy_obj = SafePtr(PyImport_Import(numpy_str.get()));
1339   if (!numpy_obj) {
1340     return false;
1341   }
1342   // Initializes the NumPy type.
1343   PyHeapTypeObject *heap_type = reinterpret_cast<PyHeapTypeObject *>(PyType_Type.tp_alloc(&PyType_Type, 0));
1344   if (!heap_type) {
1345     return false;
1346   }
1347   PyObjectPtr name = SafePtr(PyUnicode_FromString(NpTypeDescr<T>::type_name));
1348   PyObjectPtr qualname = SafePtr(PyUnicode_FromString(NpTypeDescr<T>::type_name));
1349   heap_type->ht_name = name.release();
1350   heap_type->ht_qualname = qualname.release();
1351   PyTypeObject *py_type = &heap_type->ht_type;
1352   py_type->tp_name = NpTypeDescr<T>::type_name;
1353   py_type->tp_basicsize = sizeof(PyType<T>);
1354   py_type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
1355   py_type->tp_base = &PyGenericArrType_Type;
1356   py_type->tp_new = PyType_New<T>;
1357   py_type->tp_repr = PyType_Repr<T>;
1358   py_type->tp_hash = PyType_Hash<T>;
1359   py_type->tp_str = PyType_Str<T>;
1360   py_type->tp_doc = const_cast<char *>(NpTypeDescr<T>::type_doc);
1361   py_type->tp_richcompare = PyType_RichCompare<T>;
1362   py_type->tp_as_number = &NpTypeDescr<T>::number_methods;
1363   if (PyType_Ready(py_type) < 0) {
1364     return false;
1365   }
1366   NpTypeDescr<T>::np_type_ptr = py_type;
1367 
1368   // Initializes the NumPy descriptor.
1369   PyArray_ArrFuncs &arr_funcs = NpTypeDescr<T>::arr_funcs;
1370   PyArray_InitArrFuncs(&arr_funcs);
1371   arr_funcs.getitem = NpType_GetItem<T>;
1372   arr_funcs.setitem = NpType_SetItem<T>;
1373   arr_funcs.compare = NpType_Compare<T>;
1374   arr_funcs.copyswapn = NpType_CopySwapN<T>;
1375   arr_funcs.copyswap = NpType_CopySwap<T>;
1376   arr_funcs.nonzero = NpType_NonZero<T>;
1377   arr_funcs.fill = NpType_Fill<T>;
1378   arr_funcs.dotfunc = NpType_Dot<T>;
1379   arr_funcs.argmax = NpType_ArgMax<T>;
1380   arr_funcs.argmin = NpType_ArgMin<T>;
1381 
1382   // Before NumPy 2.0, we allocate and manage the lifetime of descriptor, and Numpy only stores the pointer.
1383   // After NumPy 2.0, NumPy allocates and manages the lifetime of the descriptor.
1384 #if NPY_ABI_VERSION < 0x02000000
1385   PyArray_DescrProto *descr_proto = &NpTypeDescr<T>::np_descr;
1386 #else
1387   PyArray_DescrProto descr_proto_storage;
1388   PyArray_DescrProto *descr_proto = &descr_proto_storage;
1389 #endif
1390   *descr_proto = GetNpDescrProto<T>();
1391 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
1392   Py_TYPE(descr_proto) = &PyArrayDescr_Type;
1393 #else
1394   Py_SET_TYPE(descr_proto, &PyArrayDescr_Type);
1395 #endif
1396   descr_proto->typeobj = py_type;
1397 
1398   NpTypeDescr<T>::np_type_num = PyArray_RegisterDataType(descr_proto);
1399   if (NpTypeDescr<T>::Dtype() < 0) {
1400     return false;
1401   }
1402 #if NPY_ABI_VERSION >= 0x02000000
1403   NpTypeDescr<T>::np_descr = *PyArray_DescrFromType(NpTypeDescr<T>::Dtype());
1404 #endif
1405   if (NpTypeDescr<T>::Dtype() < 0) {
1406     return false;
1407   }
1408 
1409   // Support numpy.dtype(type_name)
1410   PyObjectPtr np_type_dict = SafePtr(PyObject_GetAttrString(numpy_obj.get(), "sctypeDict"));
1411   if (!np_type_dict) {
1412     return false;
1413   }
1414   if (PyDict_SetItemString(np_type_dict.get(), NpTypeDescr<T>::type_name,
1415                            reinterpret_cast<PyObject *>(NpTypeDescr<T>::TypePtr())) < 0) {
1416     return false;
1417   }
1418 
1419   // Support dtype(type_name)
1420   if (PyObject_SetAttrString(reinterpret_cast<PyObject *>(NpTypeDescr<T>::TypePtr()), "dtype",
1421                              reinterpret_cast<PyObject *>(&NpTypeDescr<T>::np_descr)) < 0) {
1422     return false;
1423   }
1424 
1425   // Register casts
1426   if (!RegisterNpTypeCasts<T>()) {
1427     return false;
1428   }
1429 
1430   // Register UFuncs
1431   if (!RegisterNpTypeUFuncs<T>(numpy_obj.get())) {
1432     return false;
1433   }
1434 
1435   return true;
1436 }
1437 
GetNumpyVersion()1438 std::string GetNumpyVersion() {
1439   static std::string version_str = "";
1440   if (!version_str.empty()) {
1441     return version_str;
1442   }
1443   PyObjectPtr numpy_str = SafePtr(PyUnicode_FromString("numpy"));
1444   if (!numpy_str) {
1445     return version_str;
1446   }
1447   PyObjectPtr numpy_obj = SafePtr(PyImport_Import(numpy_str.get()));
1448   if (!numpy_obj) {
1449     return version_str;
1450   }
1451   PyObject *numpy_dict = PyModule_GetDict(numpy_obj.get());
1452   if (!numpy_dict) {
1453     return version_str;
1454   }
1455   PyObject *numpy_version = PyDict_GetItemString(numpy_dict, "__version__");
1456   if (!numpy_version || !PyUnicode_Check(numpy_version)) {
1457     return version_str;
1458   }
1459   const char *version_c = PyUnicode_AsUTF8(numpy_version);
1460   if (!version_c) {
1461     return version_str;
1462   }
1463   version_str = version_c;
1464   MS_LOG(DEBUG) << "Current numpy version:" << version_str;
1465   return version_str;
1466 }
1467 
GetMinimumSupportedNumpyVersion()1468 std::string GetMinimumSupportedNumpyVersion() {
1469   switch (NPY_API_VERSION) {
1470     case 0x0000000d:  // 1.19.3+
1471       return "1.19.3";
1472     case 0x0000000e:  // 1.20 & 1.21
1473       return "1.20.0";
1474     case 0x0000000f:  // 1.22
1475       return "1.22.0";
1476     case 0x00000010:  // 1.23 & 1.24
1477       return "1.23.0";
1478     case 0x00000011:  // 1.25 & 1.26
1479       return "1.20.0";
1480     case 0x00000012:  // 2.0
1481       return "2.0.0";
1482     default:  // Values that exceed the macro definition limit.
1483       return (NPY_API_VERSION < 0x0000000d) ? "1.19.3" : "2.0.0";
1484   }
1485 }
1486 
NumpyVersionValid(std::string version)1487 bool NumpyVersionValid(std::string version) {
1488   // Get current numpy versions
1489   if (version.empty()) {
1490     return false;
1491   }
1492   std::replace(version.begin(), version.end(), '.', ' ');
1493   std::istringstream iss(version);
1494   std::vector<int> version_parts(3);
1495   // version_parts[i] will be 0 if string is invalid.
1496   iss >> version_parts[0] >> version_parts[1] >> version_parts[2];
1497   // Get minimum supported numpy version
1498   std::string minimum_version = GetMinimumSupportedNumpyVersion();
1499   if (minimum_version.empty()) {
1500     return false;
1501   }
1502   std::replace(minimum_version.begin(), minimum_version.end(), '.', ' ');
1503   std::istringstream minimum_iss(minimum_version);
1504   std::vector<int> minimum_version_parts(3);
1505   minimum_iss >> minimum_version_parts[0] >> minimum_version_parts[1] >> minimum_version_parts[2];
1506   return (version_parts[0] == minimum_version_parts[0]) && (version_parts[1] >= minimum_version_parts[1]);
1507 }
1508 
RegisterNumpyTypes()1509 void RegisterNumpyTypes() {
1510   std::string numpy_version = GetNumpyVersion();
1511   std::string minimum_numpy_version = GetMinimumSupportedNumpyVersion();
1512   if (!NumpyVersionValid(numpy_version)) {
1513     MS_LOG(INFO) << "For asnumpy, the numpy bfloat16 data type is supported in Numpy versions " << minimum_numpy_version
1514                  << " to " << minimum_numpy_version[0] << ".x.x, but got " << numpy_version
1515                  << ", please upgrade numpy version.";
1516     return;
1517   }
1518   if (!RegisterNumpyType<bfloat16>()) {
1519     if (PyErr_Occurred()) {
1520       PyErr_Print();
1521     }
1522     MS_LOG(EXCEPTION) << "Failed to register BFloat16 type!";
1523   }
1524 }
1525 }  // namespace np_dtypes
1526 
GetBFloat16NpDType()1527 int GetBFloat16NpDType() { return np_dtypes::NpTypeDescr<bfloat16>::Dtype(); }
1528 
IsNumpyVersionValid(bool show_warning=false)1529 bool IsNumpyVersionValid(bool show_warning = false) {
1530   std::string numpy_version = np_dtypes::GetNumpyVersion();
1531   std::string minimum_numpy_version = np_dtypes::GetMinimumSupportedNumpyVersion();
1532   if (!np_dtypes::NumpyVersionValid(numpy_version)) {
1533     if (show_warning) {
1534       MS_LOG(WARNING) << "For asnumpy, the numpy bfloat16 data type is supported in Numpy versions "
1535                       << minimum_numpy_version << " to " << minimum_numpy_version[0] << ".x.x, but got "
1536                       << numpy_version << ", please upgrade numpy version.";
1537     }
1538     return false;
1539   }
1540   return true;
1541 }
1542 
RegNumpyTypes(py::module * m)1543 void RegNumpyTypes(py::module *m) {
1544   np_dtypes::RegisterNumpyTypes();
1545   auto m_sub = m->def_submodule("np_dtypes", "types of numpy");
1546   m_sub.add_object("bfloat16", reinterpret_cast<PyObject *>(np_dtypes::NpTypeDescr<bfloat16>::TypePtr()));
1547   (void)m_sub.def("np_version_valid", &IsNumpyVersionValid, "Check whether numpy version is valid");
1548 }
1549 }  // namespace mindspore
1550 
1551 #if NPY_ABI_VERSION < 0x02000000
1552 #undef PyArray_DescrProto
1553 #endif
1554