1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
17 #include "include/common/np_dtype/np_dtypes.h"
18 #include <algorithm>
19 #include <string>
20 #include "numpy/arrayobject.h"
21 #include "numpy/ufuncobject.h"
22 #include "base/float16.h"
23 #include "base/bfloat16.h"
24 #include "utils/log_adapter.h"
25
26 #if NPY_API_VERSION < 0x0000000d
27 #error Current Numpy version is too low, the required version is not less than 1.19.3.
28 #endif
29
30 #if NPY_ABI_VERSION < 0x02000000
31 #define PyArray_DescrProto PyArray_Descr
32 #endif
33
34 namespace mindspore {
35 namespace np_dtypes {
36 // A safe PyObject pointer which can decrement the references automatically when destructing.
37 struct PyObjDeleter {
operator ()mindspore::np_dtypes::PyObjDeleter38 void operator()(PyObject *object) const { Py_DECREF(object); }
39 };
40 using PyObjectPtr = std::unique_ptr<PyObject, PyObjDeleter>;
SafePtr(PyObject * object)41 PyObjectPtr SafePtr(PyObject *object) { return PyObjectPtr(object); }
42
43 // Representation of a custom Python type.
44 template <typename T>
45 struct PyType {
46 PyObject_HEAD;
47 T value;
48 };
49
50 // Description of a numpy type.
51 template <typename T>
52 struct NpTypeBaseDescr {
Dtypemindspore::np_dtypes::NpTypeBaseDescr53 static int Dtype() { return np_type_num; }
TypePtrmindspore::np_dtypes::NpTypeBaseDescr54 static PyTypeObject *TypePtr() { return np_type_ptr; }
55 static int np_type_num;
56 static PyTypeObject *np_type_ptr;
57 static PyArray_Descr np_descr;
58 static PyArray_ArrFuncs arr_funcs;
59 static PyNumberMethods number_methods;
60 };
61
62 template <typename T>
63 int NpTypeBaseDescr<T>::np_type_num = NPY_NOTYPE;
64 template <typename T>
65 PyTypeObject *NpTypeBaseDescr<T>::np_type_ptr = nullptr;
66 template <typename T>
67 PyArray_Descr NpTypeBaseDescr<T>::np_descr;
68 template <typename T>
69 PyArray_ArrFuncs NpTypeBaseDescr<T>::arr_funcs;
70
// Primary template: associates a built-in C++ type T with a numpy type number. The value of
// np_type_num is supplied per type by explicit specializations; no generic value is defined here.
template <typename T>
struct NpTypeDescr {
  static int np_type_num;
  static int Dtype() { return np_type_num; }
};
76
77 template <>
78 struct NpTypeDescr<bfloat16> : NpTypeBaseDescr<bfloat16> {
79 static constexpr const char *type_name = "bfloat16";
80 static constexpr const char *type_doc = "BFloat16 type for numpy";
81 static constexpr char kind = 'T';
82 static constexpr char type = 'T';
83 static constexpr char byte_order = '=';
84 };
85
86 template <>
87 int NpTypeDescr<unsigned char>::np_type_num = NPY_UBYTE;
88 template <>
89 int NpTypeDescr<unsigned short>::np_type_num = NPY_USHORT;
90 template <>
91 int NpTypeDescr<unsigned int>::np_type_num = NPY_UINT;
92 template <>
93 int NpTypeDescr<unsigned long>::np_type_num = NPY_ULONG;
94 template <>
95 int NpTypeDescr<unsigned long long>::np_type_num = NPY_ULONGLONG;
96 template <>
97 int NpTypeDescr<char>::np_type_num = NPY_BYTE;
98 template <>
99 int NpTypeDescr<short>::np_type_num = NPY_SHORT;
100 template <>
101 int NpTypeDescr<int>::np_type_num = NPY_INT;
102 template <>
103 int NpTypeDescr<long>::np_type_num = NPY_LONG;
104 template <>
105 int NpTypeDescr<long long>::np_type_num = NPY_LONGLONG;
106 template <>
107 int NpTypeDescr<bool>::np_type_num = NPY_BOOL;
108 template <>
109 int NpTypeDescr<float16>::np_type_num = NPY_HALF;
110 template <>
111 int NpTypeDescr<float>::np_type_num = NPY_FLOAT;
112 template <>
113 int NpTypeDescr<double>::np_type_num = NPY_ULONG;
114 template <>
115 int NpTypeDescr<long double>::np_type_num = NPY_LONGDOUBLE;
116
117 // Check if object is specific numpy custom type.
118 template <typename T>
PyType_CheckType(PyObject * object)119 bool PyType_CheckType(PyObject *object) {
120 return PyObject_IsInstance(object, reinterpret_cast<PyObject *>(NpTypeDescr<T>::TypePtr()));
121 }
122
123 // Get value in the Python type object.
124 template <typename T>
PyType_GetValue(PyObject * object)125 T PyType_GetValue(PyObject *object) {
126 return reinterpret_cast<PyType<T> *>(object)->value;
127 }
128
129 // Create PyTypeObject<T> data from T value.
130 template <typename T>
PyTypeFromValue(T value)131 PyObjectPtr PyTypeFromValue(T value) {
132 PyTypeObject *np_type_p = NpTypeDescr<T>::TypePtr();
133 PyObjectPtr npy_data_p = SafePtr(np_type_p->tp_alloc(np_type_p, 0));
134 PyType<T> *data_p = reinterpret_cast<PyType<T> *>(npy_data_p.get());
135 if (data_p) {
136 data_p->value = value;
137 }
138 return npy_data_p;
139 }
140
141 template <typename T>
PyType_Add(PyObject * a,PyObject * b)142 PyObject *PyType_Add(PyObject *a, PyObject *b) {
143 if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
144 return PyTypeFromValue<T>(PyType_GetValue<T>(a) + PyType_GetValue<T>(b)).release();
145 }
146 return PyArray_Type.tp_as_number->nb_add(a, b);
147 }
148
149 template <typename T>
PyType_Subtract(PyObject * a,PyObject * b)150 PyObject *PyType_Subtract(PyObject *a, PyObject *b) {
151 if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
152 return PyTypeFromValue<T>(PyType_GetValue<T>(a) - PyType_GetValue<T>(b)).release();
153 }
154 return PyArray_Type.tp_as_number->nb_subtract(a, b);
155 }
156
157 template <typename T>
PyType_Multiply(PyObject * a,PyObject * b)158 PyObject *PyType_Multiply(PyObject *a, PyObject *b) {
159 if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
160 return PyTypeFromValue<T>(PyType_GetValue<T>(a) * PyType_GetValue<T>(b)).release();
161 }
162 return PyArray_Type.tp_as_number->nb_multiply(a, b);
163 }
164
165 template <typename T>
PyType_Divide(PyObject * a,PyObject * b)166 PyObject *PyType_Divide(PyObject *a, PyObject *b) {
167 if (PyType_CheckType<T>(a) && PyType_CheckType<T>(b)) {
168 return PyTypeFromValue<T>(PyType_GetValue<T>(a) / PyType_GetValue<T>(b)).release();
169 }
170 return PyArray_Type.tp_as_number->nb_true_divide(a, b);
171 }
172
173 template <typename T>
PyType_Negative(PyObject * self)174 PyObject *PyType_Negative(PyObject *self) {
175 return PyTypeFromValue<T>(-PyType_GetValue<T>(self)).release();
176 }
177
178 template <typename T>
PyType_Int(PyObject * self)179 PyObject *PyType_Int(PyObject *self) {
180 T value = PyType_GetValue<T>(self);
181 return PyLong_FromLong(static_cast<long>(static_cast<float>(value)));
182 }
183
184 template <typename T>
PyType_Float(PyObject * self)185 PyObject *PyType_Float(PyObject *self) {
186 T value = PyType_GetValue<T>(self);
187 return PyFloat_FromDouble(static_cast<double>(static_cast<float>(value)));
188 }
189
// Number-protocol table of the custom scalar type. The initializer is positional, so the entries
// must follow the field order of CPython's PyNumberMethods. Only +, -, *, true division, unary
// minus, int() and float() are implemented; every other listed slot is nullptr. Any fields after
// nb_index are not listed and are therefore value-initialized (nullptr) by aggregate
// initialization.
template <typename T>
PyNumberMethods NpTypeBaseDescr<T>::number_methods = {
  PyType_Add<T>,       // nb_add
  PyType_Subtract<T>,  // nb_subtract
  PyType_Multiply<T>,  // nb_multiply
  nullptr,             // nb_remainder
  nullptr,             // nb_divmod
  nullptr,             // nb_power
  PyType_Negative<T>,  // nb_negative
  nullptr,             // nb_positive
  nullptr,             // nb_absolute
  nullptr,             // nb_bool (called nb_nonzero in Python 2)
  nullptr,             // nb_invert
  nullptr,             // nb_lshift
  nullptr,             // nb_rshift
  nullptr,             // nb_and
  nullptr,             // nb_xor
  nullptr,             // nb_or
  PyType_Int<T>,       // nb_int
  nullptr,             // nb_reserved
  PyType_Float<T>,     // nb_float
  nullptr,             // nb_inplace_add
  nullptr,             // nb_inplace_subtract
  nullptr,             // nb_inplace_multiply
  nullptr,             // nb_inplace_remainder
  nullptr,             // nb_inplace_power
  nullptr,             // nb_inplace_lshift
  nullptr,             // nb_inplace_rshift
  nullptr,             // nb_inplace_and
  nullptr,             // nb_inplace_xor
  nullptr,             // nb_inplace_or
  nullptr,             // nb_floor_divide
  PyType_Divide<T>,    // nb_true_divide
  nullptr,             // nb_inplace_floor_divide
  nullptr,             // nb_inplace_true_divide
  nullptr,             // nb_index
};
227
228 template <typename TypeIn, typename TypeOut, typename Func>
229 struct UnaryUFunc {
Typesmindspore::np_dtypes::UnaryUFunc230 static std::vector<int> Types() { return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeOut>::Dtype()}; }
Fnmindspore::np_dtypes::UnaryUFunc231 static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
232 const char *arg_p = args[0];
233 char *out_p = args[1];
234 for (npy_intp d = 0; d < *dimensions; d++) {
235 auto arg = *reinterpret_cast<const TypeIn *>(arg_p);
236 *reinterpret_cast<TypeOut *>(out_p) = Func()(arg);
237 arg_p += steps[0];
238 out_p += steps[1];
239 }
240 }
241 };
242
243 template <typename TypeIn, typename TypeOut, typename TypeOut2, typename Func>
244 struct UnaryUFunc2 {
Typesmindspore::np_dtypes::UnaryUFunc2245 static std::vector<int> Types() {
246 return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeOut>::Dtype(), NpTypeDescr<TypeOut2>::Dtype()};
247 }
Fnmindspore::np_dtypes::UnaryUFunc2248 static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
249 const char *arg_p = args[0];
250 char *out0_p = args[1];
251 char *out1_p = args[2];
252 for (npy_intp d = 0; d < *dimensions; d++) {
253 auto arg = *reinterpret_cast<const TypeIn *>(arg_p);
254 std::tie(*reinterpret_cast<TypeOut *>(out0_p), *reinterpret_cast<TypeOut2 *>(out1_p)) = Func()(arg);
255 arg_p += steps[0];
256 out0_p += steps[1];
257 out1_p += steps[2];
258 }
259 }
260 };
261
262 template <typename TypeIn, typename TypeOut, typename Func>
263 struct BinaryUFunc {
Typesmindspore::np_dtypes::BinaryUFunc264 static std::vector<int> Types() {
265 return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeOut>::Dtype()};
266 }
Fnmindspore::np_dtypes::BinaryUFunc267 static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
268 const char *arg0_p = args[0];
269 const char *arg1_p = args[1];
270 char *out_p = args[2];
271 for (npy_intp d = 0; d < *dimensions; d++) {
272 auto arg0 = *reinterpret_cast<const TypeIn *>(arg0_p);
273 auto arg1 = *reinterpret_cast<const TypeIn *>(arg1_p);
274 *reinterpret_cast<TypeOut *>(out_p) = Func()(arg0, arg1);
275 arg0_p += steps[0];
276 arg1_p += steps[1];
277 out_p += steps[2];
278 }
279 }
280 };
281
282 template <typename TypeIn, typename TypeIn2, typename TypeOut, typename Func>
283 struct BinaryUFunc2 {
Typesmindspore::np_dtypes::BinaryUFunc2284 static std::vector<int> Types() {
285 return {NpTypeDescr<TypeIn>::Dtype(), NpTypeDescr<TypeIn2>::Dtype(), NpTypeDescr<TypeOut>::Dtype()};
286 }
Fnmindspore::np_dtypes::BinaryUFunc2287 static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
288 const char *arg0_p = args[0];
289 const char *arg1_p = args[1];
290 char *out_p = args[2];
291 for (npy_intp d = 0; d < *dimensions; d++) {
292 auto arg0 = *reinterpret_cast<const TypeIn *>(arg0_p);
293 auto arg1 = *reinterpret_cast<const TypeIn2 *>(arg1_p);
294 *reinterpret_cast<TypeOut *>(out_p) = Func()(arg0, arg1);
295 arg0_p += steps[0];
296 arg1_p += steps[1];
297 out_p += steps[2];
298 }
299 }
300 };
301 namespace ufuncs {
// Implementation of Numpy universal functions.
// Each functor computes a single output element of the numpy ufunc it is named after, for
// element type T.
template <typename T>
struct Add {
  T operator()(T lhs, T rhs) { return lhs + rhs; }
};
template <typename T>
struct Subtract {
  T operator()(T lhs, T rhs) { return lhs - rhs; }
};
template <typename T>
struct Multiply {
  T operator()(T lhs, T rhs) { return lhs * rhs; }
};
template <typename T>
struct Divide {
  // True division; a zero divisor is not special-cased.
  T operator()(T lhs, T rhs) { return lhs / rhs; }
};
// Floor division and modulo of `a` by `b` in float, returned as {quotient, remainder}.
// The quotient is rounded toward negative infinity, and the remainder takes the sign of `b`
// (a zero remainder is a zero signed like `b`). Division by zero yields NaN for both results.
inline std::pair<float, float> divmod(float a, float b) {
  if (b == 0.0f) {
    const float nan = std::numeric_limits<float>::quiet_NaN();
    return {nan, nan};
  }
  float remainder = std::fmod(a, b);
  float quotient = (a - remainder) / b;
  if (remainder != 0.0f) {
    // std::fmod follows the dividend's sign; shift it so that it follows the divisor's sign.
    const bool signs_differ = (b < 0.0f) != (remainder < 0.0f);
    if (signs_differ) {
      remainder += b;
      quotient -= 1.0f;
    }
  } else {
    remainder = std::copysign(0.0f, b);
  }
  if (quotient == 0.0f) {
    // Give a zero quotient the sign of the true quotient a / b.
    return {std::copysign(0.0f, a / b), remainder};
  }
  // (a - remainder) / b is very nearly an integer; snap it to the nearest integral value.
  float floored = std::floor(quotient);
  if (quotient - floored > 0.5f) {
    floored += 1.0f;
  }
  return {floored, remainder};
}
343 template <typename T>
344 struct DivmodUFunc {
Typesmindspore::np_dtypes::ufuncs::DivmodUFunc345 static std::vector<int> Types() {
346 return {NpTypeDescr<T>::Dtype(), NpTypeDescr<T>::Dtype(), NpTypeDescr<T>::Dtype(), NpTypeDescr<T>::Dtype()};
347 }
Fnmindspore::np_dtypes::ufuncs::DivmodUFunc348 static void Fn(char **args, npy_intp const *dimensions, npy_intp const *steps, void *data) {
349 const char *arg0_p = args[0];
350 const char *arg1_p = args[1];
351 char *out0_p = args[2];
352 char *out1_p = args[3];
353 for (npy_intp d = 0; d < *dimensions; d++) {
354 T arg0 = *reinterpret_cast<const T *>(arg0_p);
355 T arg1 = *reinterpret_cast<const T *>(arg1_p);
356 float floordiv, mod;
357 std::tie(floordiv, mod) = divmod(static_cast<float>(arg0), static_cast<float>(arg1));
358 *reinterpret_cast<T *>(out0_p) = T(floordiv);
359 *reinterpret_cast<T *>(out1_p) = T(mod);
360 arg0_p += steps[0];
361 arg1_p += steps[1];
362 out0_p += steps[2];
363 out1_p += steps[3];
364 }
365 }
366 };
367 template <typename T>
368 struct FloorDivide {
operator ()mindspore::np_dtypes::ufuncs::FloorDivide369 T operator()(T a, T b) { return T(divmod(static_cast<float>(a), static_cast<float>(b)).first); }
370 };
371 template <typename T>
372 struct Remainder {
operator ()mindspore::np_dtypes::ufuncs::Remainder373 T operator()(T a, T b) { return T(divmod(static_cast<float>(a), static_cast<float>(b)).second); }
374 };
template <typename T>
struct Fmod {
  // C-style remainder computed in float (sign of the dividend).
  T operator()(T a, T b) { return static_cast<T>(std::fmod(static_cast<float>(a), static_cast<float>(b))); }
};
template <typename T>
struct Negative {
  T operator()(T value) { return -value; }
};
template <typename T>
struct Positive {
  T operator()(T value) { return value; }
};
template <typename T>
struct Power {
  // Unqualified call so that an overload for T found by argument-dependent lookup is used.
  T operator()(T base, T exponent) { return pow(base, exponent); }
};
template <typename T>
struct Abs {
  T operator()(T value) { return abs(value); }
};
template <typename T>
struct Cbrt {
  T operator()(T value) { return static_cast<T>(std::cbrt(static_cast<float>(value))); }
};
template <typename T>
struct Ceil {
  T operator()(T value) { return ceil(value); }
};
template <typename T>
struct CopySign {
  // Magnitude of `a` with the sign of `b`, computed in float.
  T operator()(T a, T b) { return static_cast<T>(std::copysign(static_cast<float>(a), static_cast<float>(b))); }
};
template <typename T>
struct Exp {
  T operator()(T value) { return exp(value); }
};
template <typename T>
struct Exp2 {
  T operator()(T value) { return static_cast<T>(std::exp2(static_cast<float>(value))); }
};
template <typename T>
struct Expm1 {
  T operator()(T value) { return static_cast<T>(std::expm1(static_cast<float>(value))); }
};
template <typename T>
struct Floor {
  T operator()(T value) { return floor(value); }
};
template <typename T>
struct Frexp {
  // Splits `value` into {mantissa, exponent} with value == mantissa * 2^exponent.
  std::pair<T, int> operator()(T value) {
    int exponent = 0;
    const float mantissa = std::frexp(static_cast<float>(value), &exponent);
    return {static_cast<T>(mantissa), exponent};
  }
};
template <typename T>
struct Heaviside {
  // Step function: 1 for x > 0, 0 for x < 0, h0 at x == 0, and NaN when x is NaN.
  T operator()(T x, T h0) {
    if (isnan(x)) {
      return x;
    }
    if (x > T(0)) {
      return T(1);
    }
    return x < T(0) ? T(0) : h0;
  }
};
template <typename T>
struct Conjugate {
  // Real-valued elements are their own conjugate.
  T operator()(T value) { return value; }
};
template <typename T>
struct IsFinite {
  bool operator()(T value) { return isfinite(value); }
};
template <typename T>
struct IsInf {
  bool operator()(T value) { return isinf(value); }
};
template <typename T>
struct IsNan {
  bool operator()(T value) { return isnan(value); }
};
template <typename T>
struct Ldexp {
  // mantissa * 2^exponent, computed in float.
  T operator()(T mantissa, int exponent) { return static_cast<T>(std::ldexp(static_cast<float>(mantissa), exponent)); }
};
template <typename T>
struct Log {
  T operator()(T value) { return log(value); }
};
template <typename T>
struct Log1p {
  T operator()(T value) { return static_cast<T>(std::log1p(static_cast<float>(value))); }
};
template <typename T>
struct Log2 {
  T operator()(T value) { return static_cast<T>(std::log2(static_cast<float>(value))); }
};
template <typename T>
struct Log10 {
  T operator()(T value) { return static_cast<T>(std::log10(static_cast<float>(value))); }
};
template <typename T>
struct LogAddExp {
  // log(exp(a) + exp(b)) in float, evaluated around the larger operand to avoid overflow.
  T operator()(T a, T b) {
    const float x = static_cast<float>(a);
    const float y = static_cast<float>(b);
    if (x == y) {
      // Also handles x == y == +/-inf, where x - y would be NaN.
      return static_cast<T>(x + std::log(2.0f));
    }
    if (x > y) {
      return static_cast<T>(x + std::log1p(std::exp(y - x)));
    }
    if (x < y) {
      return static_cast<T>(y + std::log1p(std::exp(x - y)));
    }
    // Unordered comparison: at least one operand is NaN.
    return static_cast<T>(std::numeric_limits<float>::quiet_NaN());
  }
};
template <typename T>
struct LogAddExp2 {
  // log2(2^a + 2^b) in float, evaluated around the larger operand to avoid overflow.
  T operator()(T a, T b) {
    const float x = static_cast<float>(a);
    const float y = static_cast<float>(b);
    if (x == y) {
      return static_cast<T>(x + 1.0f);
    }
    if (x > y) {
      return static_cast<T>(x + std::log1p(std::exp2(y - x)) / std::log(2.0f));
    }
    if (x < y) {
      return static_cast<T>(y + std::log1p(std::exp2(x - y)) / std::log(2.0f));
    }
    // Unordered comparison: at least one operand is NaN.
    return static_cast<T>(std::numeric_limits<float>::quiet_NaN());
  }
};
template <typename T>
struct Modf {
  // Splits `value` into {fractional part, integral part}, both carrying the sign of `value`.
  std::pair<T, T> operator()(T value) {
    float whole = 0.0f;
    const float fraction = std::modf(static_cast<float>(value), &whole);
    return {static_cast<T>(fraction), static_cast<T>(whole)};
  }
};
template <typename T>
struct Reciprocal {
  T operator()(T value) { return static_cast<T>(1.0f / static_cast<float>(value)); }
};
template <typename T>
struct Rint {
  // Rounds to an integral value using the current floating-point rounding mode.
  T operator()(T value) { return static_cast<T>(std::rint(static_cast<float>(value))); }
};
template <typename T>
struct Sign {
  // -1 for negative and +1 for positive values; zeros (with their sign) and NaN pass through.
  T operator()(T value) {
    if (!isnan(value)) {
      if (value > T(0)) {
        return T(1);
      }
      if (value < T(0)) {
        return T(-1);
      }
    }
    return value;
  }
};
template <typename T>
struct SignBit {
  // True when the sign bit is set, including for -0 and negative NaN.
  bool operator()(T value) { return std::signbit(static_cast<float>(value)); }
};
template <typename T>
struct Sqrt {
  T operator()(T value) { return static_cast<T>(std::sqrt(static_cast<float>(value))); }
};
template <typename T>
struct Square {
  // Squared in float, then converted back to T.
  T operator()(T value) {
    const float v = static_cast<float>(value);
    return static_cast<T>(v * v);
  }
};
template <typename T>
struct Trunc {
  T operator()(T value) { return static_cast<T>(std::trunc(static_cast<float>(value))); }
};
// Trigonometric functions
// Functors using an unqualified call (sin, cos, tan, tanh) rely on overload resolution /
// argument-dependent lookup for T. The others compute in float and convert back to T.
template <typename T>
struct Sin {
  T operator()(T angle) { return sin(angle); }
};
template <typename T>
struct Cos {
  T operator()(T angle) { return cos(angle); }
};
template <typename T>
struct Tan {
  T operator()(T angle) { return tan(angle); }
};
template <typename T>
struct Arcsin {
  T operator()(T x) { return static_cast<T>(std::asin(static_cast<float>(x))); }
};
template <typename T>
struct Arccos {
  T operator()(T x) { return static_cast<T>(std::acos(static_cast<float>(x))); }
};
template <typename T>
struct Arctan {
  T operator()(T x) { return static_cast<T>(std::atan(static_cast<float>(x))); }
};
template <typename T>
struct Arctan2 {
  T operator()(T y, T x) { return static_cast<T>(std::atan2(static_cast<float>(y), static_cast<float>(x))); }
};
template <typename T>
struct Hypot {
  T operator()(T x, T y) { return static_cast<T>(std::hypot(static_cast<float>(x), static_cast<float>(y))); }
};
template <typename T>
struct Sinh {
  T operator()(T x) { return static_cast<T>(std::sinh(static_cast<float>(x))); }
};
template <typename T>
struct Cosh {
  T operator()(T x) { return static_cast<T>(std::cosh(static_cast<float>(x))); }
};
template <typename T>
struct Tanh {
  T operator()(T x) { return tanh(x); }
};
template <typename T>
struct Arcsinh {
  T operator()(T x) { return static_cast<T>(std::asinh(static_cast<float>(x))); }
};
template <typename T>
struct Arccosh {
  T operator()(T x) { return static_cast<T>(std::acosh(static_cast<float>(x))); }
};
template <typename T>
struct Arctanh {
  T operator()(T x) { return static_cast<T>(std::atanh(static_cast<float>(x))); }
};
template <typename T>
struct Deg2rad {
  // Degrees to radians: multiply by pi / 180 (both constants in float).
  T operator()(T degrees) {
    constexpr float kRadiansPerDegree = 3.14159265358979323846f / 180.0f;
    return static_cast<T>(static_cast<float>(degrees) * kRadiansPerDegree);
  }
};
template <typename T>
struct Rad2deg {
  // Radians to degrees: multiply by 180 / pi (both constants in float).
  T operator()(T radians) {
    constexpr float kDegreesPerRadian = 180.0f / 3.14159265358979323846f;
    return static_cast<T>(static_cast<float>(radians) * kDegreesPerRadian);
  }
};
639 template <typename T>
640 struct Eq {
operator ()mindspore::np_dtypes::ufuncs::Eq641 npy_bool operator()(T a, T b) { return a == b; }
642 };
643 template <typename T>
644 struct Ne {
operator ()mindspore::np_dtypes::ufuncs::Ne645 npy_bool operator()(T a, T b) { return a != b; }
646 };
647 template <typename T>
648 struct Lt {
operator ()mindspore::np_dtypes::ufuncs::Lt649 npy_bool operator()(T a, T b) { return a < b; }
650 };
651 template <typename T>
652 struct Le {
operator ()mindspore::np_dtypes::ufuncs::Le653 npy_bool operator()(T a, T b) { return a <= b; }
654 };
655 template <typename T>
656 struct Gt {
operator ()mindspore::np_dtypes::ufuncs::Gt657 npy_bool operator()(T a, T b) { return a > b; }
658 };
659 template <typename T>
660 struct Ge {
operator ()mindspore::np_dtypes::ufuncs::Ge661 npy_bool operator()(T a, T b) { return a >= b; }
662 };
template <typename T>
struct Maximum {
  // Larger operand; a NaN in either operand is propagated.
  T operator()(T a, T b) {
    if (isnan(a)) {
      return a;
    }
    return a > b ? a : b;
  }
};
template <typename T>
struct Minimum {
  // Smaller operand; a NaN in either operand is propagated.
  T operator()(T a, T b) {
    if (isnan(a)) {
      return a;
    }
    return a < b ? a : b;
  }
};
template <typename T>
struct Fmax {
  // Larger operand; a NaN operand is ignored unless both operands are NaN.
  T operator()(T a, T b) {
    if (isnan(b)) {
      return a;
    }
    return a > b ? a : b;
  }
};
template <typename T>
struct Fmin {
  // Smaller operand; a NaN operand is ignored unless both operands are NaN.
  T operator()(T a, T b) {
    if (isnan(b)) {
      return a;
    }
    return a < b ? a : b;
  }
};
679 template <typename T>
680 struct LogicalNot {
operator ()mindspore::np_dtypes::ufuncs::LogicalNot681 npy_bool operator()(T a) { return !static_cast<bool>(a); }
682 };
683 template <typename T>
684 struct LogicalAnd {
operator ()mindspore::np_dtypes::ufuncs::LogicalAnd685 npy_bool operator()(T a, T b) { return static_cast<bool>(a) && static_cast<bool>(b); }
686 };
687 template <typename T>
688 struct LogicalOr {
operator ()mindspore::np_dtypes::ufuncs::LogicalOr689 npy_bool operator()(T a, T b) { return static_cast<bool>(a) || static_cast<bool>(b); }
690 };
691 template <typename T>
692 struct LogicalXor {
operator ()mindspore::np_dtypes::ufuncs::LogicalXor693 npy_bool operator()(T a, T b) { return static_cast<bool>(a) ^ static_cast<bool>(b); }
694 };
// Get unsigned integer type with same size of T.
// GetUnsignedInteger<N>::uint_type is the unsigned integer type exactly N bytes wide; only
// N = 1, 2 and 4 are provided. UIntType<T> selects it by sizeof(T), giving an integer type
// through which the bits of a T can be viewed.
template <int kNumBytes>
struct GetUnsignedInteger;

template <>
struct GetUnsignedInteger<4> {
  using uint_type = uint32_t;
};

template <>
struct GetUnsignedInteger<2> {
  using uint_type = uint16_t;
};

template <>
struct GetUnsignedInteger<1> {
  using uint_type = uint8_t;
};

template <typename T>
using UIntType = typename GetUnsignedInteger<sizeof(T)>::uint_type;
712 template <typename TypeIn, typename TypeOut>
bit_cast(TypeIn value)713 TypeOut bit_cast(TypeIn value) {
714 static_assert(sizeof(TypeIn) == sizeof(TypeOut), "For bit_cast, types must match size.");
715 TypeOut out = TypeOut(0);
716 errno_t ret = memcpy_s(&out, sizeof(TypeOut), &value, sizeof(TypeIn));
717 if (ret != EOK) {
718 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d", ret);
719 return out;
720 }
721 return out;
722 }
723 template <typename T>
724 struct NextAfter {
operator ()mindspore::np_dtypes::ufuncs::NextAfter725 T operator()(T from, T to) {
726 if (isnan(from) || isnan(to)) {
727 return std::numeric_limits<T>::quiet_NaN();
728 }
729 UIntType<T> from_uint = bit_cast<T, UIntType<T>>(from);
730 UIntType<T> to_uint = bit_cast<T, UIntType<T>>(to);
731 if (from_uint == to_uint) {
732 return to;
733 }
734 UIntType<T> sign_mask = UIntType<T>(1) << (sizeof(T) * CHAR_BIT - 1);
735 UIntType<T> from_uint_abs = bit_cast<T, UIntType<T>>(abs(from));
736 UIntType<T> from_uint_sign = from_uint & sign_mask;
737 UIntType<T> to_uint_abs = bit_cast<T, UIntType<T>>(abs(to));
738 UIntType<T> to_uint_sign = to_uint & sign_mask;
739 if (from_uint_abs == 0) {
740 if (to_uint_abs == 0) {
741 return to;
742 } else {
743 // Minimum non-zero value with sign bit of `to`.
744 return bit_cast<UIntType<T>, T>(static_cast<UIntType<T>>(0x01 | to_uint_sign));
745 }
746 }
747 UIntType<T> next_step = (from_uint_abs > to_uint_abs || from_uint_sign != to_uint_sign)
748 ? static_cast<UIntType<T>>(-1)
749 : static_cast<UIntType<T>>(1);
750 UIntType<T> out_uint = from_uint + next_step;
751 return bit_cast<UIntType<T>, T>(out_uint);
752 }
753 };
754 } // namespace ufuncs
755
756 // Cast input object to Python type T.
757 template <typename T>
CastToPyType(PyObject * obj,T * output)758 bool CastToPyType(PyObject *obj, T *output) {
759 // object is an instance of NpTypeDescr
760 if (PyType_CheckType<T>(obj)) {
761 *output = PyType_GetValue<T>(obj);
762 return true;
763 }
764 // object is an instance of int
765 if (PyLong_Check(obj)) {
766 long value = PyLong_AsLong(obj);
767 if (PyErr_Occurred()) {
768 return false;
769 }
770 *output = T(value);
771 return true;
772 }
773 // object is an instance of float
774 if (PyFloat_Check(obj)) {
775 double value = PyFloat_AsDouble(obj);
776 if (PyErr_Occurred()) {
777 return false;
778 }
779 *output = T(value);
780 return true;
781 }
782 // object is an instance of scalar float16
783 if (PyArray_IsScalar(obj, Half)) {
784 float16 value;
785 PyArray_ScalarAsCtype(obj, &value);
786 *output = T(value);
787 return true;
788 }
789 // object is an instance of scalar float
790 if (PyArray_IsScalar(obj, Float)) {
791 float value;
792 PyArray_ScalarAsCtype(obj, &value);
793 *output = T(value);
794 return true;
795 }
796 // object is an instance of scalar double
797 if (PyArray_IsScalar(obj, Double)) {
798 double value;
799 PyArray_ScalarAsCtype(obj, &value);
800 *output = T(value);
801 return true;
802 }
803 // object is an instance of scalar long double
804 if (PyArray_IsScalar(obj, LongDouble)) {
805 long double value;
806 PyArray_ScalarAsCtype(obj, &value);
807 *output = T(value);
808 return true;
809 }
810 // object is an instance of 0-dim array
811 if (PyArray_IsZeroDim(obj)) {
812 PyArrayObject *arr = reinterpret_cast<PyArrayObject *>(obj);
813 // cast value in array to type T
814 if (PyArray_TYPE(arr) != NpTypeDescr<T>::Dtype()) {
815 PyObjectPtr new_arr = SafePtr(PyArray_Cast(arr, NpTypeDescr<T>::Dtype()));
816 if (PyErr_Occurred()) {
817 return false;
818 }
819 arr = reinterpret_cast<PyArrayObject *>(new_arr.get());
820 }
821 *output = *reinterpret_cast<T *>(PyArray_DATA(arr));
822 return true;
823 }
824 return false;
825 }
826
827 // Constructs a new Python type.
828 template <typename T>
PyType_New(PyTypeObject * type,PyObject * args,PyObject * kwds)829 PyObject *PyType_New(PyTypeObject *type, PyObject *args, PyObject *kwds) {
830 if (kwds && PyDict_Size(kwds)) {
831 PyErr_Format(PyExc_TypeError, "No keyword arguments should be provided when constructing %s",
832 NpTypeDescr<T>::type_name);
833 return nullptr;
834 }
835 Py_ssize_t arg_num = PyTuple_Size(args);
836 if (arg_num != 1) {
837 PyErr_Format(PyExc_TypeError, "One argument is expected when constructing %s, but got %d.",
838 NpTypeDescr<T>::type_name, arg_num);
839 return nullptr;
840 }
841 PyObject *arg = PyTuple_GetItem(args, 0);
842 T value;
843 // If arg is already NpTypeDescr<T>, just return it.
844 if (PyType_CheckType<T>(arg)) {
845 Py_INCREF(arg);
846 return arg;
847 }
848 // If arg can be casted to T value, create NpTypeDescr<T> from the value.
849 if (CastToPyType<T>(arg, &value)) {
850 return PyTypeFromValue<T>(value).release();
851 }
852 // If arg is an array, cast it to NpTypeDescr<T>
853 if (PyArray_Check(arg)) {
854 PyArrayObject *arr = reinterpret_cast<PyArrayObject *>(arg);
855 if (PyArray_TYPE(arr) != NpTypeDescr<T>::Dtype()) {
856 return PyArray_Cast(arr, NpTypeDescr<T>::Dtype());
857 } else {
858 Py_INCREF(arg);
859 return arg;
860 }
861 }
862 // If arg is unicodes or bytes, convert it from string to float, then cast the float to T value,
863 // and then create NpTypeDescr<T> from the value.
864 if (PyUnicode_Check(arg) || PyBytes_Check(arg)) {
865 PyObject *value_f = PyFloat_FromString(arg);
866 if (CastToPyType<T>(value_f, &value)) {
867 return PyTypeFromValue<T>(value).release();
868 }
869 }
870 PyErr_Format(PyExc_TypeError, "Only number argument is expected when constructing %s, but got %s.",
871 NpTypeDescr<T>::type_name, Py_TYPE(arg)->tp_name);
872 return nullptr;
873 }
874
875 // Implementation of repr() for PyType.
876 template <typename T>
PyType_Repr(PyObject * self)877 PyObject *PyType_Repr(PyObject *self) {
878 T value = reinterpret_cast<PyType<T> *>(self)->value;
879 std::string value_str = std::to_string(static_cast<float>(value));
880 return PyUnicode_FromString(value_str.c_str());
881 }
882
883 // Overload function _Py_HashDouble to support Python version over 3.10.
HashDouble_(Py_hash_t (* hash_double)(PyObject *,double),PyObject * self,double value)884 inline Py_hash_t HashDouble_(Py_hash_t (*hash_double)(PyObject *, double), PyObject *self, double value) {
885 return hash_double(self, value);
886 }
887
HashDouble_(Py_hash_t (* hash_double)(double),PyObject * self,double value)888 inline Py_hash_t HashDouble_(Py_hash_t (*hash_double)(double), PyObject *self, double value) {
889 return hash_double(value);
890 }
891
892 // Implementation of hash() for PyType.
893 template <typename T>
PyType_Hash(PyObject * self)894 Py_hash_t PyType_Hash(PyObject *self) {
895 T value = reinterpret_cast<PyType<T> *>(self)->value;
896 return HashDouble_(&_Py_HashDouble, self, static_cast<double>(value));
897 }
898
899 // Implementation of str() for PyType.
900 template <typename T>
PyType_Str(PyObject * self)901 PyObject *PyType_Str(PyObject *self) {
902 T value = reinterpret_cast<PyType<T> *>(self)->value;
903 std::string value_str = std::to_string(static_cast<float>(value));
904 return PyUnicode_FromString(value_str.c_str());
905 }
906
907 // Implementation of Comparisons for PyType.
908 template <typename T>
PyType_RichCompare(PyObject * a,PyObject * b,int op)909 PyObject *PyType_RichCompare(PyObject *a, PyObject *b, int op) {
910 if (!PyType_CheckType<T>(a) || !PyType_CheckType<T>(b)) {
911 return PyGenericArrType_Type.tp_richcompare(a, b, op);
912 }
913 T x = PyType_GetValue<T>(a);
914 T y = PyType_GetValue<T>(b);
915 bool result;
916 switch (op) {
917 case Py_EQ:
918 result = (x == y);
919 break;
920 case Py_NE:
921 result = (x != y);
922 break;
923 case Py_LT:
924 result = (x < y);
925 break;
926 case Py_LE:
927 result = (x <= y);
928 break;
929 case Py_GT:
930 result = (x > y);
931 break;
932 case Py_GE:
933 result = (x >= y);
934 break;
935 default:
936 PyErr_Format(PyExc_ValueError, "Got invalid op type %d when comparing %s", op, NpTypeDescr<T>::type_name);
937 return nullptr;
938 }
939 PyObject *ret = PyBool_FromLong(result);
940 Py_INCREF(ret);
941 return ret;
942 }
943
944 // Implementations of NumPy array methods for PyType.
945 template <typename T>
NpType_GetItem(void * data,void * arr)946 PyObject *NpType_GetItem(void *data, void *arr) {
947 T value;
948 errno_t ret = memcpy_s(&value, sizeof(T), data, sizeof(T));
949 if (ret != EOK) {
950 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
951 return nullptr;
952 }
953 return PyTypeFromValue(value).release();
954 }
955
956 template <typename T>
NpType_SetItem(PyObject * item,void * data,void * arr)957 int NpType_SetItem(PyObject *item, void *data, void *arr) {
958 T value;
959 if (!CastToPyType<T>(item, &value)) {
960 PyErr_Format(PyExc_TypeError, "Only number argument is expected for SetItem %s, but got %s.",
961 NpTypeDescr<T>::type_name, Py_TYPE(item)->tp_name);
962 return -1;
963 }
964 errno_t ret = memcpy_s(data, sizeof(T), &value, sizeof(T));
965 if (ret != EOK) {
966 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
967 return -1;
968 }
969 return 0;
970 }
971
// compare: three-way comparison of two T values used by NumPy sorting.
// Ordered values compare normally; NaN sorts after every non-NaN value, and two NaNs compare equal.
template <typename T>
int NpType_Compare(const void *d1, const void *d2, void *arr) {
  T lhs = *reinterpret_cast<const T *>(d1);
  T rhs = *reinterpret_cast<const T *>(d2);
  if (lhs < rhs) {
    return -1;
  }
  if (rhs < lhs) {
    return 1;
  }
  // Neither is less than the other: the values are equal, or at least one of them is NaN.
  const bool lhs_nan = isnan(lhs);
  const bool rhs_nan = isnan(rhs);
  if (lhs_nan == rhs_nan) {
    return 0;
  }
  return lhs_nan ? 1 : -1;
}
990
991 template <typename T>
NpType_CopySwapN(void * dest,npy_intp dstride,void * src,npy_intp sstride,npy_intp n,int swap,void * arr)992 void NpType_CopySwapN(void *dest, npy_intp dstride, void *src, npy_intp sstride, npy_intp n, int swap, void *arr) {
993 static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Swap is not supported");
994 char *dst_p = reinterpret_cast<char *>(dest);
995 char *src_p = reinterpret_cast<char *>(src);
996 if (!src_p) {
997 return;
998 }
999 if (swap && sizeof(T) == sizeof(int16_t)) {
1000 for (npy_intp i = 0; i < n; i++) {
1001 char *r = dst_p + dstride * i;
1002 errno_t ret = memcpy_s(r, sizeof(T), src_p + sstride * i, sizeof(T));
1003 if (ret != EOK) {
1004 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1005 return;
1006 }
1007 std::swap(r[0], r[1]);
1008 }
1009 } else if (dstride == sizeof(T) && sstride == sizeof(T)) {
1010 errno_t ret = memcpy_s(dst_p, n * sizeof(T), src_p, n * sizeof(T));
1011 if (ret != EOK) {
1012 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1013 return;
1014 }
1015 } else {
1016 for (npy_intp i = 0; i < n; i++) {
1017 errno_t ret = memcpy_s(dst_p + dstride * i, sizeof(T), src_p + sstride * i, sizeof(T));
1018 if (ret != EOK) {
1019 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1020 return;
1021 }
1022 }
1023 }
1024 }
1025
1026 template <typename T>
NpType_CopySwap(void * dest,void * src,int swap,void * arr)1027 void NpType_CopySwap(void *dest, void *src, int swap, void *arr) {
1028 static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t), "Swap is not supported");
1029 if (!src) {
1030 return;
1031 }
1032 errno_t ret = memcpy_s(dest, sizeof(T), src, sizeof(T));
1033 if (ret != EOK) {
1034 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1035 return;
1036 }
1037 if (swap && (sizeof(T) == sizeof(int16_t))) {
1038 char *p = reinterpret_cast<char *>(dest);
1039 std::swap(p[0], p[1]);
1040 }
1041 }
1042
1043 template <typename T>
NpType_NonZero(void * data,void * arr)1044 npy_bool NpType_NonZero(void *data, void *arr) {
1045 T value;
1046 errno_t ret = memcpy_s(&value, sizeof(T), data, sizeof(T));
1047 if (ret != EOK) {
1048 PyErr_Format(PyExc_MemoryError, "memcpy_s failed: %d.", ret);
1049 return false;
1050 }
1051 return value != static_cast<T>(0);
1052 }
1053
1054 template <typename T>
NpType_Fill(void * data,npy_intp length,void * arr)1055 int NpType_Fill(void *data, npy_intp length, void *arr) {
1056 T *const buffer = reinterpret_cast<T *>(data);
1057 const T start(buffer[0]);
1058 const T delta = static_cast<T>(buffer[1]) - start;
1059 for (npy_intp i = 2; i < length; i++) {
1060 buffer[i] = static_cast<T>(start + T(i) * delta);
1061 }
1062 return 0;
1063 }
1064
1065 template <typename T>
NpType_Dot(void * ip1,npy_intp is1,void * ip2,npy_intp is2,void * op,npy_intp n,void * arr)1066 void NpType_Dot(void *ip1, npy_intp is1, void *ip2, npy_intp is2, void *op, npy_intp n, void *arr) {
1067 char *p1 = reinterpret_cast<char *>(ip1);
1068 char *p2 = reinterpret_cast<char *>(ip2);
1069 T acc = T(0);
1070 for (npy_intp i = 0; i < n; i++) {
1071 T *const a = reinterpret_cast<T *>(p1);
1072 T *const b = reinterpret_cast<T *>(p2);
1073 acc += static_cast<T>(*a) * static_cast<T>(*b);
1074 p1 += is1;
1075 p2 += is2;
1076 }
1077 T *out = reinterpret_cast<T *>(op);
1078 *out = static_cast<T>(acc);
1079 }
1080
1081 template <typename T>
NpType_ArgMax(void * data,npy_intp n,npy_intp * max_ind,void * arr)1082 int NpType_ArgMax(void *data, npy_intp n, npy_intp *max_ind, void *arr) {
1083 const T *data_p = reinterpret_cast<const T *>(data);
1084 T max_val = static_cast<T>(data_p[0]);
1085 *max_ind = 0;
1086 for (npy_intp i = 0; i < n; i++) {
1087 T val = static_cast<T>(data_p[i]);
1088 if (isnan(val) || val > max_val) {
1089 max_val = val;
1090 *max_ind = i;
1091 // NumPy stops at the first NaN.
1092 if (isnan(val)) {
1093 break;
1094 }
1095 }
1096 }
1097 return 0;
1098 }
1099
1100 template <typename T>
NpType_ArgMin(void * data,npy_intp n,npy_intp * min_ind,void * arr)1101 int NpType_ArgMin(void *data, npy_intp n, npy_intp *min_ind, void *arr) {
1102 const T *data_p = reinterpret_cast<const T *>(data);
1103 T min_val = static_cast<T>(data_p[0]);
1104 *min_ind = 0;
1105 for (npy_intp i = 1; i < n; i++) {
1106 T val = static_cast<T>(data_p[i]);
1107 if (isnan(val) || val < min_val) {
1108 min_val = val;
1109 *min_ind = i;
1110 // NumPy stops at the first NaN.
1111 if (isnan(val)) {
1112 break;
1113 }
1114 }
1115 }
1116 return 0;
1117 }
1118
// Builds the descriptor prototype for T that RegisterNumpyType passes to PyArray_RegisterDataType.
// kind/type/byteorder come from NpTypeDescr<T>, elsize/alignment from T itself, and the array functions from
// NpTypeDescr<T>::arr_funcs. The object's type, typeobj and type_num are placeholders: the caller sets the type
// (Py_SET_TYPE) and typeobj after this returns, and NumPy assigns the type number on registration.
// The initializer is positional, so the entries must stay in the struct's declared field order; the /*name=*/
// comments are only labels.
template <typename T>
PyArray_DescrProto GetNpDescrProto() {
  return {
    PyObject_HEAD_INIT(nullptr)
    /*typeobj=*/nullptr,
    /*kind=*/NpTypeDescr<T>::kind,
    /*type=*/NpTypeDescr<T>::type,
    /*byteorder=*/NpTypeDescr<T>::byte_order,
    // The item functions (getitem/setitem, ...) call the Python C-API, hence NPY_NEEDS_PYAPI.
    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_SETITEM,
    /*type_num=*/0,
    /*elsize=*/sizeof(T),
    /*alignment=*/alignof(T),
    /*subarray=*/nullptr,
    /*fields=*/nullptr,
    /*names=*/nullptr,
    /*f=*/&NpTypeDescr<T>::arr_funcs,
    /*metadata=*/nullptr,
    /*c_metadata=*/nullptr,
    /*hash=*/-1,
  };
}
1140
1141 // Cast a numpy array from type 'From' to 'To'.
1142 template <typename From, typename To>
NpyCast(void * from,void * to,npy_intp n,void * from_arr,void * to_arr)1143 void NpyCast(void *from, void *to, npy_intp n, void *from_arr, void *to_arr) {
1144 const From *from_ptr = static_cast<From *>(from);
1145 To *to_ptr = static_cast<To *>(to);
1146 for (npy_intp i = 0; i < n; i++) {
1147 to_ptr[i] = static_cast<To>(from_ptr[i]);
1148 }
1149 }
1150
1151 // Register a cast between T and other numpy type Y.
1152 template <typename T, typename Y>
RegisterNpTypeCast(int np_type,bool scalar_castable)1153 bool RegisterNpTypeCast(int np_type, bool scalar_castable) {
1154 PyArray_Descr *descr = PyArray_DescrFromType(np_type);
1155 if (PyArray_RegisterCastFunc(descr, NpTypeDescr<T>::Dtype(), NpyCast<Y, T>) < 0) {
1156 return false;
1157 }
1158 if (PyArray_RegisterCastFunc(&NpTypeDescr<T>::np_descr, np_type, NpyCast<T, Y>) < 0) {
1159 return false;
1160 }
1161 if (scalar_castable && PyArray_RegisterCanCast(&NpTypeDescr<T>::np_descr, np_type, NPY_NOSCALAR) < 0) {
1162 return false;
1163 }
1164 return true;
1165 }
1166
1167 // Register casts between T and other numpy types.
1168 template <typename T>
RegisterNpTypeCasts()1169 bool RegisterNpTypeCasts() {
1170 if (!RegisterNpTypeCast<T, bool>(NPY_BOOL, false)) {
1171 return false;
1172 }
1173 if (!RegisterNpTypeCast<T, float16>(NPY_HALF, false)) {
1174 return false;
1175 }
1176 if (!RegisterNpTypeCast<T, float>(NPY_FLOAT, true)) {
1177 return false;
1178 }
1179 if (!RegisterNpTypeCast<T, double>(NPY_DOUBLE, false)) {
1180 return false;
1181 }
1182 if (!RegisterNpTypeCast<T, long double>(NPY_LONGDOUBLE, false)) {
1183 return false;
1184 }
1185 if (!RegisterNpTypeCast<T, unsigned char>(NPY_UBYTE, false)) {
1186 return false;
1187 }
1188 if (!RegisterNpTypeCast<T, unsigned short>(NPY_USHORT, false)) {
1189 return false;
1190 }
1191 if (!RegisterNpTypeCast<T, unsigned int>(NPY_UINT, false)) {
1192 return false;
1193 }
1194 if (!RegisterNpTypeCast<T, unsigned long>(NPY_ULONG, false)) {
1195 return false;
1196 }
1197 if (!RegisterNpTypeCast<T, unsigned long long>(NPY_ULONGLONG, false)) {
1198 return false;
1199 }
1200 if (!RegisterNpTypeCast<T, char>(NPY_BYTE, false)) {
1201 return false;
1202 }
1203 if (!RegisterNpTypeCast<T, short>(NPY_SHORT, false)) {
1204 return false;
1205 }
1206 if (!RegisterNpTypeCast<T, int>(NPY_INT, false)) {
1207 return false;
1208 }
1209 if (!RegisterNpTypeCast<T, long>(NPY_LONG, false)) {
1210 return false;
1211 }
1212 if (!RegisterNpTypeCast<T, long long>(NPY_LONGLONG, false)) {
1213 return false;
1214 }
1215 // Complexs are not support yet.
1216 return true;
1217 }
1218
1219 // Register a Numpy universal function.
1220 template <typename UFunc, typename T>
RegisterNpTypeUFunc(PyObject * numpy,const char * fn_name)1221 bool RegisterNpTypeUFunc(PyObject *numpy, const char *fn_name) {
1222 std::vector<int> types = UFunc::Types();
1223 PyUFuncGenericFunction fn = reinterpret_cast<PyUFuncGenericFunction>(UFunc::Fn);
1224 PyObjectPtr ufunc_p = SafePtr(PyObject_GetAttrString(numpy, fn_name));
1225 if (!ufunc_p) {
1226 return false;
1227 }
1228 PyUFuncObject *ufunc = reinterpret_cast<PyUFuncObject *>(ufunc_p.get());
1229 if (static_cast<int>(types.size()) != ufunc->nargs) {
1230 PyErr_Format(PyExc_AssertionError, "The ufunc %s need %d arguments, but got %lu.", fn_name, ufunc->nargs,
1231 types.size());
1232 return false;
1233 }
1234 if (PyUFunc_RegisterLoopForType(ufunc, NpTypeDescr<T>::Dtype(), fn, const_cast<int *>(types.data()), nullptr) < 0) {
1235 return false;
1236 }
1237 return true;
1238 }
1239
1240 // Register Numpy universal functions of type T.
1241 template <typename T>
RegisterNpTypeUFuncs(PyObject * numpy)1242 bool RegisterNpTypeUFuncs(PyObject *numpy) {
1243 // Math operations
1244 bool ok = RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
1245 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(numpy, "subtract") &&
1246 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(numpy, "multiply") &&
1247 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Divide<T>>, T>(numpy, "divide") &&
1248 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp<T>>, T>(numpy, "logaddexp") &&
1249 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp2<T>>, T>(numpy, "logaddexp2") &&
1250 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Negative<T>>, T>(numpy, "negative") &&
1251 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Positive<T>>, T>(numpy, "positive") &&
1252 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Divide<T>>, T>(numpy, "true_divide") &&
1253 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(numpy, "floor_divide") &&
1254 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Power<T>>, T>(numpy, "power") &&
1255 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "remainder") &&
1256 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "mod") &&
1257 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Fmod<T>>, T>(numpy, "fmod") &&
1258 RegisterNpTypeUFunc<ufuncs::DivmodUFunc<T>, T>(numpy, "divmod") &&
1259 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "absolute") &&
1260 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "fabs") &&
1261 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Rint<T>>, T>(numpy, "rint") &&
1262 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sign<T>>, T>(numpy, "sign") &&
1263 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Heaviside<T>>, T>(numpy, "heaviside") &&
1264 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Conjugate<T>>, T>(numpy, "conjugate") &&
1265 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Exp<T>>, T>(numpy, "exp") &&
1266 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Exp2<T>>, T>(numpy, "exp2") &&
1267 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Expm1<T>>, T>(numpy, "expm1") &&
1268 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log<T>>, T>(numpy, "log") &&
1269 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log1p<T>>, T>(numpy, "log1p") &&
1270 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log2<T>>, T>(numpy, "log2") &&
1271 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Log10<T>>, T>(numpy, "log10") &&
1272 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sqrt<T>>, T>(numpy, "sqrt") &&
1273 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Square<T>>, T>(numpy, "square") &&
1274 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Cbrt<T>>, T>(numpy, "cbrt") &&
1275 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Reciprocal<T>>, T>(numpy, "reciprocal") &&
1276 // Trigonometric functions
1277 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sin<T>>, T>(numpy, "sin") &&
1278 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Cos<T>>, T>(numpy, "cos") &&
1279 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Tan<T>>, T>(numpy, "tan") &&
1280 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arcsin<T>>, T>(numpy, "arcsin") &&
1281 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arccos<T>>, T>(numpy, "arccos") &&
1282 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arctan<T>>, T>(numpy, "arctan") &&
1283 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Arctan2<T>>, T>(numpy, "arctan2") &&
1284 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Hypot<T>>, T>(numpy, "hypot") &&
1285 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Sinh<T>>, T>(numpy, "sinh") &&
1286 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Cosh<T>>, T>(numpy, "cosh") &&
1287 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Tanh<T>>, T>(numpy, "tanh") &&
1288 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arcsinh<T>>, T>(numpy, "arcsinh") &&
1289 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arccosh<T>>, T>(numpy, "arccosh") &&
1290 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Arctanh<T>>, T>(numpy, "arctanh") &&
1291 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Deg2rad<T>>, T>(numpy, "deg2rad") &&
1292 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Rad2deg<T>>, T>(numpy, "rad2deg") &&
1293 // Comparison functions
1294 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Eq<T>>, T>(numpy, "equal") &&
1295 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Ne<T>>, T>(numpy, "not_equal") &&
1296 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Lt<T>>, T>(numpy, "less") &&
1297 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Le<T>>, T>(numpy, "less_equal") &&
1298 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Gt<T>>, T>(numpy, "greater") &&
1299 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::Ge<T>>, T>(numpy, "greater_equal") &&
1300 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Maximum<T>>, T>(numpy, "maximum") &&
1301 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Minimum<T>>, T>(numpy, "minimum") &&
1302 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Fmax<T>>, T>(numpy, "fmax") &&
1303 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::Fmin<T>>, T>(numpy, "fmin") &&
1304 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::LogicalAnd<T>>, T>(numpy, "logical_and") &&
1305 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::LogicalOr<T>>, T>(numpy, "logical_or") &&
1306 RegisterNpTypeUFunc<BinaryUFunc<T, bool, ufuncs::LogicalXor<T>>, T>(numpy, "logical_xor") &&
1307 RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::LogicalNot<T>>, T>(numpy, "logical_not") &&
1308 // Floating point functions
1309 RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::IsFinite<T>>, T>(numpy, "isfinite") &&
1310 RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::IsInf<T>>, T>(numpy, "isinf") &&
1311 RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::IsNan<T>>, T>(numpy, "isnan") &&
1312 RegisterNpTypeUFunc<UnaryUFunc<T, bool, ufuncs::SignBit<T>>, T>(numpy, "signbit") &&
1313 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::CopySign<T>>, T>(numpy, "copysign") &&
1314 RegisterNpTypeUFunc<UnaryUFunc2<T, T, T, ufuncs::Modf<T>>, T>(numpy, "modf") &&
1315 RegisterNpTypeUFunc<BinaryUFunc2<T, int, T, ufuncs::Ldexp<T>>, T>(numpy, "ldexp") &&
1316 RegisterNpTypeUFunc<UnaryUFunc2<T, T, int, ufuncs::Frexp<T>>, T>(numpy, "frexp") &&
1317 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Floor<T>>, T>(numpy, "floor") &&
1318 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Ceil<T>>, T>(numpy, "ceil") &&
1319 RegisterNpTypeUFunc<UnaryUFunc<T, T, ufuncs::Trunc<T>>, T>(numpy, "trunc") &&
1320 RegisterNpTypeUFunc<BinaryUFunc<T, T, ufuncs::NextAfter<T>>, T>(numpy, "nextafter");
1321 return ok;
1322 }
1323
1324 template <typename T>
RegisterNumpyType()1325 bool RegisterNumpyType() {
1326 // Check if current type is already initialized.
1327 if (NpTypeDescr<T>::Dtype() != NPY_NOTYPE) {
1328 return true;
1329 }
1330
1331 // Import Python modules
1332 import_array1(false);
1333 import_umath1(false);
1334 PyObjectPtr numpy_str = SafePtr(PyUnicode_FromString("numpy"));
1335 if (!numpy_str) {
1336 return false;
1337 }
1338 PyObjectPtr numpy_obj = SafePtr(PyImport_Import(numpy_str.get()));
1339 if (!numpy_obj) {
1340 return false;
1341 }
1342 // Initializes the NumPy type.
1343 PyHeapTypeObject *heap_type = reinterpret_cast<PyHeapTypeObject *>(PyType_Type.tp_alloc(&PyType_Type, 0));
1344 if (!heap_type) {
1345 return false;
1346 }
1347 PyObjectPtr name = SafePtr(PyUnicode_FromString(NpTypeDescr<T>::type_name));
1348 PyObjectPtr qualname = SafePtr(PyUnicode_FromString(NpTypeDescr<T>::type_name));
1349 heap_type->ht_name = name.release();
1350 heap_type->ht_qualname = qualname.release();
1351 PyTypeObject *py_type = &heap_type->ht_type;
1352 py_type->tp_name = NpTypeDescr<T>::type_name;
1353 py_type->tp_basicsize = sizeof(PyType<T>);
1354 py_type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
1355 py_type->tp_base = &PyGenericArrType_Type;
1356 py_type->tp_new = PyType_New<T>;
1357 py_type->tp_repr = PyType_Repr<T>;
1358 py_type->tp_hash = PyType_Hash<T>;
1359 py_type->tp_str = PyType_Str<T>;
1360 py_type->tp_doc = const_cast<char *>(NpTypeDescr<T>::type_doc);
1361 py_type->tp_richcompare = PyType_RichCompare<T>;
1362 py_type->tp_as_number = &NpTypeDescr<T>::number_methods;
1363 if (PyType_Ready(py_type) < 0) {
1364 return false;
1365 }
1366 NpTypeDescr<T>::np_type_ptr = py_type;
1367
1368 // Initializes the NumPy descriptor.
1369 PyArray_ArrFuncs &arr_funcs = NpTypeDescr<T>::arr_funcs;
1370 PyArray_InitArrFuncs(&arr_funcs);
1371 arr_funcs.getitem = NpType_GetItem<T>;
1372 arr_funcs.setitem = NpType_SetItem<T>;
1373 arr_funcs.compare = NpType_Compare<T>;
1374 arr_funcs.copyswapn = NpType_CopySwapN<T>;
1375 arr_funcs.copyswap = NpType_CopySwap<T>;
1376 arr_funcs.nonzero = NpType_NonZero<T>;
1377 arr_funcs.fill = NpType_Fill<T>;
1378 arr_funcs.dotfunc = NpType_Dot<T>;
1379 arr_funcs.argmax = NpType_ArgMax<T>;
1380 arr_funcs.argmin = NpType_ArgMin<T>;
1381
1382 // Before NumPy 2.0, we allocate and manage the lifetime of descriptor, and Numpy only stores the pointer.
1383 // After NumPy 2.0, NumPy allocates and manages the lifetime of the descriptor.
1384 #if NPY_ABI_VERSION < 0x02000000
1385 PyArray_DescrProto *descr_proto = &NpTypeDescr<T>::np_descr;
1386 #else
1387 PyArray_DescrProto descr_proto_storage;
1388 PyArray_DescrProto *descr_proto = &descr_proto_storage;
1389 #endif
1390 *descr_proto = GetNpDescrProto<T>();
1391 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
1392 Py_TYPE(descr_proto) = &PyArrayDescr_Type;
1393 #else
1394 Py_SET_TYPE(descr_proto, &PyArrayDescr_Type);
1395 #endif
1396 descr_proto->typeobj = py_type;
1397
1398 NpTypeDescr<T>::np_type_num = PyArray_RegisterDataType(descr_proto);
1399 if (NpTypeDescr<T>::Dtype() < 0) {
1400 return false;
1401 }
1402 #if NPY_ABI_VERSION >= 0x02000000
1403 NpTypeDescr<T>::np_descr = *PyArray_DescrFromType(NpTypeDescr<T>::Dtype());
1404 #endif
1405 if (NpTypeDescr<T>::Dtype() < 0) {
1406 return false;
1407 }
1408
1409 // Support numpy.dtype(type_name)
1410 PyObjectPtr np_type_dict = SafePtr(PyObject_GetAttrString(numpy_obj.get(), "sctypeDict"));
1411 if (!np_type_dict) {
1412 return false;
1413 }
1414 if (PyDict_SetItemString(np_type_dict.get(), NpTypeDescr<T>::type_name,
1415 reinterpret_cast<PyObject *>(NpTypeDescr<T>::TypePtr())) < 0) {
1416 return false;
1417 }
1418
1419 // Support dtype(type_name)
1420 if (PyObject_SetAttrString(reinterpret_cast<PyObject *>(NpTypeDescr<T>::TypePtr()), "dtype",
1421 reinterpret_cast<PyObject *>(&NpTypeDescr<T>::np_descr)) < 0) {
1422 return false;
1423 }
1424
1425 // Register casts
1426 if (!RegisterNpTypeCasts<T>()) {
1427 return false;
1428 }
1429
1430 // Register UFuncs
1431 if (!RegisterNpTypeUFuncs<T>(numpy_obj.get())) {
1432 return false;
1433 }
1434
1435 return true;
1436 }
1437
GetNumpyVersion()1438 std::string GetNumpyVersion() {
1439 static std::string version_str = "";
1440 if (!version_str.empty()) {
1441 return version_str;
1442 }
1443 PyObjectPtr numpy_str = SafePtr(PyUnicode_FromString("numpy"));
1444 if (!numpy_str) {
1445 return version_str;
1446 }
1447 PyObjectPtr numpy_obj = SafePtr(PyImport_Import(numpy_str.get()));
1448 if (!numpy_obj) {
1449 return version_str;
1450 }
1451 PyObject *numpy_dict = PyModule_GetDict(numpy_obj.get());
1452 if (!numpy_dict) {
1453 return version_str;
1454 }
1455 PyObject *numpy_version = PyDict_GetItemString(numpy_dict, "__version__");
1456 if (!numpy_version || !PyUnicode_Check(numpy_version)) {
1457 return version_str;
1458 }
1459 const char *version_c = PyUnicode_AsUTF8(numpy_version);
1460 if (!version_c) {
1461 return version_str;
1462 }
1463 version_str = version_c;
1464 MS_LOG(DEBUG) << "Current numpy version:" << version_str;
1465 return version_str;
1466 }
1467
GetMinimumSupportedNumpyVersion()1468 std::string GetMinimumSupportedNumpyVersion() {
1469 switch (NPY_API_VERSION) {
1470 case 0x0000000d: // 1.19.3+
1471 return "1.19.3";
1472 case 0x0000000e: // 1.20 & 1.21
1473 return "1.20.0";
1474 case 0x0000000f: // 1.22
1475 return "1.22.0";
1476 case 0x00000010: // 1.23 & 1.24
1477 return "1.23.0";
1478 case 0x00000011: // 1.25 & 1.26
1479 return "1.20.0";
1480 case 0x00000012: // 2.0
1481 return "2.0.0";
1482 default: // Values that exceed the macro definition limit.
1483 return (NPY_API_VERSION < 0x0000000d) ? "1.19.3" : "2.0.0";
1484 }
1485 }
1486
NumpyVersionValid(std::string version)1487 bool NumpyVersionValid(std::string version) {
1488 // Get current numpy versions
1489 if (version.empty()) {
1490 return false;
1491 }
1492 std::replace(version.begin(), version.end(), '.', ' ');
1493 std::istringstream iss(version);
1494 std::vector<int> version_parts(3);
1495 // version_parts[i] will be 0 if string is invalid.
1496 iss >> version_parts[0] >> version_parts[1] >> version_parts[2];
1497 // Get minimum supported numpy version
1498 std::string minimum_version = GetMinimumSupportedNumpyVersion();
1499 if (minimum_version.empty()) {
1500 return false;
1501 }
1502 std::replace(minimum_version.begin(), minimum_version.end(), '.', ' ');
1503 std::istringstream minimum_iss(minimum_version);
1504 std::vector<int> minimum_version_parts(3);
1505 minimum_iss >> minimum_version_parts[0] >> minimum_version_parts[1] >> minimum_version_parts[2];
1506 return (version_parts[0] == minimum_version_parts[0]) && (version_parts[1] >= minimum_version_parts[1]);
1507 }
1508
RegisterNumpyTypes()1509 void RegisterNumpyTypes() {
1510 std::string numpy_version = GetNumpyVersion();
1511 std::string minimum_numpy_version = GetMinimumSupportedNumpyVersion();
1512 if (!NumpyVersionValid(numpy_version)) {
1513 MS_LOG(INFO) << "For asnumpy, the numpy bfloat16 data type is supported in Numpy versions " << minimum_numpy_version
1514 << " to " << minimum_numpy_version[0] << ".x.x, but got " << numpy_version
1515 << ", please upgrade numpy version.";
1516 return;
1517 }
1518 if (!RegisterNumpyType<bfloat16>()) {
1519 if (PyErr_Occurred()) {
1520 PyErr_Print();
1521 }
1522 MS_LOG(EXCEPTION) << "Failed to register BFloat16 type!";
1523 }
1524 }
1525 } // namespace np_dtypes
1526
GetBFloat16NpDType()1527 int GetBFloat16NpDType() { return np_dtypes::NpTypeDescr<bfloat16>::Dtype(); }
1528
IsNumpyVersionValid(bool show_warning=false)1529 bool IsNumpyVersionValid(bool show_warning = false) {
1530 std::string numpy_version = np_dtypes::GetNumpyVersion();
1531 std::string minimum_numpy_version = np_dtypes::GetMinimumSupportedNumpyVersion();
1532 if (!np_dtypes::NumpyVersionValid(numpy_version)) {
1533 if (show_warning) {
1534 MS_LOG(WARNING) << "For asnumpy, the numpy bfloat16 data type is supported in Numpy versions "
1535 << minimum_numpy_version << " to " << minimum_numpy_version[0] << ".x.x, but got "
1536 << numpy_version << ", please upgrade numpy version.";
1537 }
1538 return false;
1539 }
1540 return true;
1541 }
1542
RegNumpyTypes(py::module * m)1543 void RegNumpyTypes(py::module *m) {
1544 np_dtypes::RegisterNumpyTypes();
1545 auto m_sub = m->def_submodule("np_dtypes", "types of numpy");
1546 m_sub.add_object("bfloat16", reinterpret_cast<PyObject *>(np_dtypes::NpTypeDescr<bfloat16>::TypePtr()));
1547 (void)m_sub.def("np_version_valid", &IsNumpyVersionValid, "Check whether numpy version is valid");
1548 }
1549 } // namespace mindspore
1550
1551 #if NPY_ABI_VERSION < 0x02000000
1552 #undef PyArray_DescrProto
1553 #endif
1554