1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/lib/core/bfloat16.h"
17
18 #include <array>
19 #include <locale>
20 // Place `<locale>` before <Python.h> to avoid a build failure in macOS.
21 #include <Python.h>
22
23 #include "absl/strings/str_cat.h"
24 #include "third_party/eigen3/Eigen/Core"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/python/lib/core/numpy.h"
27
28 namespace tensorflow {
29 namespace {
30
31 using bfloat16 = Eigen::bfloat16;
32
33 struct PyDecrefDeleter {
operator ()tensorflow::__anonea8c089e0111::PyDecrefDeleter34 void operator()(PyObject* p) const { Py_DECREF(p); }
35 };
36
37 // Safe container for an owned PyObject. On destruction, the reference count of
38 // the contained object will be decremented.
39 using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
make_safe(PyObject * object)40 Safe_PyObjectPtr make_safe(PyObject* object) {
41 return Safe_PyObjectPtr(object);
42 }
43
PyLong_CheckNoOverflow(PyObject * object)44 bool PyLong_CheckNoOverflow(PyObject* object) {
45 if (!PyLong_Check(object)) {
46 return false;
47 }
48 int overflow = 0;
49 PyLong_AsLongAndOverflow(object, &overflow);
50 return (overflow == 0);
51 }
52
53 // Registered numpy type ID. Global variable populated by the registration code.
54 // Protected by the GIL.
55 int npy_bfloat16 = NPY_NOTYPE;
56
57 // Forward declaration.
58 extern PyTypeObject bfloat16_type;
59
60 // Pointer to the bfloat16 type object we are using. This is either a pointer
61 // to bfloat16_type, if we choose to register it, or to the bfloat16 type
62 // registered by another system into NumPy.
63 PyTypeObject* bfloat16_type_ptr = nullptr;
64
65 // Representation of a Python bfloat16 object.
66 struct PyBfloat16 {
67 PyObject_HEAD; // Python object header
68 bfloat16 value;
69 };
70
71 // Returns true if 'object' is a PyBfloat16.
PyBfloat16_Check(PyObject * object)72 bool PyBfloat16_Check(PyObject* object) {
73 return PyObject_IsInstance(object,
74 reinterpret_cast<PyObject*>(&bfloat16_type));
75 }
76
77 // Extracts the value of a PyBfloat16 object.
PyBfloat16_Bfloat16(PyObject * object)78 bfloat16 PyBfloat16_Bfloat16(PyObject* object) {
79 return reinterpret_cast<PyBfloat16*>(object)->value;
80 }
81
82 // Constructs a PyBfloat16 object from a bfloat16.
PyBfloat16_FromBfloat16(bfloat16 x)83 Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
84 Safe_PyObjectPtr ref = make_safe(bfloat16_type.tp_alloc(&bfloat16_type, 0));
85 PyBfloat16* p = reinterpret_cast<PyBfloat16*>(ref.get());
86 if (p) {
87 p->value = x;
88 }
89 return ref;
90 }
91
// Converts a Python object to a bfloat16 value, storing it in *output.
// Returns true on success and false on failure. A Python error is set only
// when an underlying CPython/NumPy call fails (PyFloat_AsDouble,
// PyLong_AsLong, PyArray_Cast). When 'arg' is of an unsupported type, or is
// an int that does not fit in a C long, this returns false with NO error
// pending, so callers must raise their own.
//
// Inputs are tried in this order: bfloat16 scalars, Python floats, Python ints
// that fit in a C long, NumPy half/float/double scalars, and zero-dimensional
// arrays (cast to bfloat16 first when their dtype differs).
bool CastToBfloat16(PyObject* arg, bfloat16* output) {
  // Already a bfloat16 scalar: copy the payload directly.
  if (PyBfloat16_Check(arg)) {
    *output = PyBfloat16_Bfloat16(arg);
    return true;
  }
  if (PyFloat_Check(arg)) {
    double d = PyFloat_AsDouble(arg);
    if (PyErr_Occurred()) {
      return false;
    }
    // TODO(phawkins): check for overflow
    *output = bfloat16(d);
    return true;
  }
  // Ints outside the C long range fail this check and fall through to the
  // remaining cases (ultimately reaching the final `return false`).
  if (PyLong_CheckNoOverflow(arg)) {
    long l = PyLong_AsLong(arg);  // NOLINT
    if (PyErr_Occurred()) {
      return false;
    }
    // TODO(phawkins): check for overflow
    // The integer is converted to float first, then to bfloat16.
    *output = bfloat16(static_cast<float>(l));
    return true;
  }
  if (PyArray_IsScalar(arg, Half)) {
    Eigen::half f;
    PyArray_ScalarAsCtype(arg, &f);
    *output = bfloat16(f);
    return true;
  }
  if (PyArray_IsScalar(arg, Float)) {
    float f;
    PyArray_ScalarAsCtype(arg, &f);
    *output = bfloat16(f);
    return true;
  }
  if (PyArray_IsScalar(arg, Double)) {
    double f;
    PyArray_ScalarAsCtype(arg, &f);
    *output = bfloat16(f);
    return true;
  }
  if (PyArray_IsZeroDim(arg)) {
    // 'ref' owns the temporary cast result (if one is made). It must stay
    // alive until the element has been read from its data buffer below.
    Safe_PyObjectPtr ref;
    PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
    if (PyArray_TYPE(arr) != npy_bfloat16) {
      ref = make_safe(PyArray_Cast(arr, npy_bfloat16));
      if (PyErr_Occurred()) {
        return false;
      }
      arg = ref.get();
      arr = reinterpret_cast<PyArrayObject*>(arg);
    }
    *output = *reinterpret_cast<bfloat16*>(PyArray_DATA(arr));
    return true;
  }
  // Unsupported type; note that no Python error is set here.
  return false;
}
151
SafeCastToBfloat16(PyObject * arg,bfloat16 * output)152 bool SafeCastToBfloat16(PyObject* arg, bfloat16* output) {
153 if (PyBfloat16_Check(arg)) {
154 *output = PyBfloat16_Bfloat16(arg);
155 return true;
156 }
157 return false;
158 }
159
160 // Converts a PyBfloat16 into a PyFloat.
PyBfloat16_Float(PyObject * self)161 PyObject* PyBfloat16_Float(PyObject* self) {
162 bfloat16 x = PyBfloat16_Bfloat16(self);
163 return PyFloat_FromDouble(static_cast<double>(x));
164 }
165
166 // Converts a PyBfloat16 into a PyInt.
PyBfloat16_Int(PyObject * self)167 PyObject* PyBfloat16_Int(PyObject* self) {
168 bfloat16 x = PyBfloat16_Bfloat16(self);
169 long y = static_cast<long>(x); // NOLINT
170 return PyLong_FromLong(y);
171 }
172
173 // Negates a PyBfloat16.
PyBfloat16_Negative(PyObject * self)174 PyObject* PyBfloat16_Negative(PyObject* self) {
175 bfloat16 x = PyBfloat16_Bfloat16(self);
176 return PyBfloat16_FromBfloat16(-x).release();
177 }
178
PyBfloat16_Add(PyObject * a,PyObject * b)179 PyObject* PyBfloat16_Add(PyObject* a, PyObject* b) {
180 bfloat16 x, y;
181 if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
182 return PyBfloat16_FromBfloat16(x + y).release();
183 }
184 return PyArray_Type.tp_as_number->nb_add(a, b);
185 }
186
PyBfloat16_Subtract(PyObject * a,PyObject * b)187 PyObject* PyBfloat16_Subtract(PyObject* a, PyObject* b) {
188 bfloat16 x, y;
189 if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
190 return PyBfloat16_FromBfloat16(x - y).release();
191 }
192 return PyArray_Type.tp_as_number->nb_subtract(a, b);
193 }
194
PyBfloat16_Multiply(PyObject * a,PyObject * b)195 PyObject* PyBfloat16_Multiply(PyObject* a, PyObject* b) {
196 bfloat16 x, y;
197 if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
198 return PyBfloat16_FromBfloat16(x * y).release();
199 }
200 return PyArray_Type.tp_as_number->nb_multiply(a, b);
201 }
202
PyBfloat16_TrueDivide(PyObject * a,PyObject * b)203 PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) {
204 bfloat16 x, y;
205 if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
206 return PyBfloat16_FromBfloat16(x / y).release();
207 }
208 return PyArray_Type.tp_as_number->nb_true_divide(a, b);
209 }
210
// Python number methods for PyBfloat16 objects.
// This is positional aggregate initialization, so entries must stay in the
// exact field order of Python 3's PyNumberMethods. Fields after nb_index
// (nb_matrix_multiply, nb_inplace_matrix_multiply) are not listed and are
// therefore value-initialized to null. In-place operators are left null, so
// Python falls back to the corresponding binary slot.
PyNumberMethods PyBfloat16_AsNumber = {
    PyBfloat16_Add,       // nb_add
    PyBfloat16_Subtract,  // nb_subtract
    PyBfloat16_Multiply,  // nb_multiply
    nullptr,              // nb_remainder
    nullptr,              // nb_divmod
    nullptr,              // nb_power
    PyBfloat16_Negative,  // nb_negative
    nullptr,              // nb_positive
    nullptr,              // nb_absolute
    nullptr,              // nb_bool (named nb_nonzero in Python 2)
    nullptr,              // nb_invert
    nullptr,              // nb_lshift
    nullptr,              // nb_rshift
    nullptr,              // nb_and
    nullptr,              // nb_xor
    nullptr,              // nb_or
    PyBfloat16_Int,       // nb_int
    nullptr,              // nb_reserved
    PyBfloat16_Float,     // nb_float

    nullptr,  // nb_inplace_add
    nullptr,  // nb_inplace_subtract
    nullptr,  // nb_inplace_multiply
    nullptr,  // nb_inplace_remainder
    nullptr,  // nb_inplace_power
    nullptr,  // nb_inplace_lshift
    nullptr,  // nb_inplace_rshift
    nullptr,  // nb_inplace_and
    nullptr,  // nb_inplace_xor
    nullptr,  // nb_inplace_or

    nullptr,                // nb_floor_divide
    PyBfloat16_TrueDivide,  // nb_true_divide
    nullptr,                // nb_inplace_floor_divide
    nullptr,                // nb_inplace_true_divide
    nullptr,                // nb_index
};
250
251 // Constructs a new PyBfloat16.
PyBfloat16_New(PyTypeObject * type,PyObject * args,PyObject * kwds)252 PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) {
253 if (kwds && PyDict_Size(kwds)) {
254 PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
255 return nullptr;
256 }
257 Py_ssize_t size = PyTuple_Size(args);
258 if (size != 1) {
259 PyErr_SetString(PyExc_TypeError,
260 "expected number as argument to bfloat16 constructor");
261 return nullptr;
262 }
263 PyObject* arg = PyTuple_GetItem(args, 0);
264
265 bfloat16 value;
266 if (PyBfloat16_Check(arg)) {
267 Py_INCREF(arg);
268 return arg;
269 } else if (CastToBfloat16(arg, &value)) {
270 return PyBfloat16_FromBfloat16(value).release();
271 } else if (PyArray_Check(arg)) {
272 PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
273 if (PyArray_TYPE(arr) != npy_bfloat16) {
274 return PyArray_Cast(arr, npy_bfloat16);
275 } else {
276 Py_INCREF(arg);
277 return arg;
278 }
279 }
280 PyErr_Format(PyExc_TypeError, "expected number, got %s",
281 arg->ob_type->tp_name);
282 return nullptr;
283 }
284
285 // Comparisons on PyBfloat16s.
PyBfloat16_RichCompare(PyObject * a,PyObject * b,int op)286 PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
287 bfloat16 x, y;
288 if (!SafeCastToBfloat16(a, &x) || !SafeCastToBfloat16(b, &y)) {
289 return PyGenericArrType_Type.tp_richcompare(a, b, op);
290 }
291 bool result;
292 switch (op) {
293 case Py_LT:
294 result = x < y;
295 break;
296 case Py_LE:
297 result = x <= y;
298 break;
299 case Py_EQ:
300 result = x == y;
301 break;
302 case Py_NE:
303 result = x != y;
304 break;
305 case Py_GT:
306 result = x > y;
307 break;
308 case Py_GE:
309 result = x >= y;
310 break;
311 default:
312 LOG(FATAL) << "Invalid op type " << op;
313 }
314 return PyBool_FromLong(result);
315 }
316
317 // Implementation of repr() for PyBfloat16.
PyBfloat16_Repr(PyObject * self)318 PyObject* PyBfloat16_Repr(PyObject* self) {
319 bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
320 std::string v = absl::StrCat(static_cast<float>(x));
321 return PyUnicode_FromString(v.c_str());
322 }
323
324 // Implementation of str() for PyBfloat16.
PyBfloat16_Str(PyObject * self)325 PyObject* PyBfloat16_Str(PyObject* self) {
326 bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
327 std::string v = absl::StrCat(static_cast<float>(x));
328 return PyUnicode_FromString(v.c_str());
329 }
330
331 // Hash function for PyBfloat16. We use the identity function, which is a weak
332 // hash function.
PyBfloat16_Hash(PyObject * self)333 Py_hash_t PyBfloat16_Hash(PyObject* self) {
334 bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
335 return x.value;
336 }
337
// Python type for PyBfloat16 objects.
// This is positional aggregate initialization, so entries must follow
// PyTypeObject's field order for the CPython version being compiled against
// (hence the PY_VERSION_HEX switch below). Slots left null here (tp_dealloc,
// tp_base, tp_alloc, tp_free, ...) are presumably inherited or filled in when
// the type is readied during registration, which is outside this chunk
// (TODO confirm).
PyTypeObject bfloat16_type = {
    PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16",  // tp_name
    sizeof(PyBfloat16),                            // tp_basicsize
    0,                                             // tp_itemsize
    nullptr,                                       // tp_dealloc
#if PY_VERSION_HEX < 0x03080000
    nullptr,  // tp_print
#else
    0,  // tp_vectorcall_offset (occupies tp_print's position from 3.8 on)
#endif
    nullptr,                // tp_getattr
    nullptr,                // tp_setattr
    nullptr,                // tp_compare / tp_reserved (tp_as_async in 3.5+)
    PyBfloat16_Repr,        // tp_repr
    &PyBfloat16_AsNumber,   // tp_as_number
    nullptr,                // tp_as_sequence
    nullptr,                // tp_as_mapping
    PyBfloat16_Hash,        // tp_hash
    nullptr,                // tp_call
    PyBfloat16_Str,         // tp_str
    nullptr,                // tp_getattro
    nullptr,                // tp_setattro
    nullptr,                // tp_as_buffer
                            // tp_flags
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
    "bfloat16 floating-point values",  // tp_doc
    nullptr,                           // tp_traverse
    nullptr,                           // tp_clear
    PyBfloat16_RichCompare,            // tp_richcompare
    0,                                 // tp_weaklistoffset
    nullptr,                           // tp_iter
    nullptr,                           // tp_iternext
    nullptr,                           // tp_methods
    nullptr,                           // tp_members
    nullptr,                           // tp_getset
    nullptr,                           // tp_base
    nullptr,                           // tp_dict
    nullptr,                           // tp_descr_get
    nullptr,                           // tp_descr_set
    0,                                 // tp_dictoffset
    nullptr,                           // tp_init
    nullptr,                           // tp_alloc
    PyBfloat16_New,                    // tp_new
    nullptr,                           // tp_free
    nullptr,                           // tp_is_gc
    nullptr,                           // tp_bases
    nullptr,                           // tp_mro
    nullptr,                           // tp_cache
    nullptr,                           // tp_subclasses
    nullptr,                           // tp_weaklist
    nullptr,                           // tp_del
    0,                                 // tp_version_tag
};
392
// Numpy support

// Array-function table for the bfloat16 dtype. At namespace scope it is
// zero-initialized; presumably the registration code (outside this chunk)
// installs the NPyBfloat16_* functions defined below (TODO confirm).
PyArray_ArrFuncs NPyBfloat16_ArrFuncs;

// NumPy dtype descriptor for bfloat16. Fields are positional and must follow
// PyArray_Descr's declaration order; the /*name=*/ markers document each one.
PyArray_Descr NPyBfloat16_Descr = {
    PyObject_HEAD_INIT(nullptr)  //
    /*typeobj=*/
    (&bfloat16_type),
    // We must register bfloat16 with a kind other than "f", because numpy
    // considers two types with the same kind and size to be equal, but
    // float16 != bfloat16.
    // The downside of this is that NumPy scalar promotion does not work with
    // bfloat16 values.
    /*kind=*/'V',
    // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
    // character is unique.
    /*type=*/'E',
    /*byteorder=*/'=',
    /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
    // Placeholder; presumably replaced when the dtype is registered with NumPy
    // (TODO confirm against the registration code).
    /*type_num=*/0,
    /*elsize=*/sizeof(bfloat16),
    /*alignment=*/alignof(bfloat16),
    /*subarray=*/nullptr,
    /*fields=*/nullptr,
    /*names=*/nullptr,
    /*f=*/&NPyBfloat16_ArrFuncs,
    /*metadata=*/nullptr,
    /*c_metadata=*/nullptr,
    /*hash=*/-1,  // -1 means "not computed yet".
};
423
424 // Implementations of NumPy array methods.
425
NPyBfloat16_GetItem(void * data,void * arr)426 PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
427 bfloat16 x;
428 memcpy(&x, data, sizeof(bfloat16));
429 return PyBfloat16_FromBfloat16(x).release();
430 }
431
NPyBfloat16_SetItem(PyObject * item,void * data,void * arr)432 int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
433 bfloat16 x;
434 if (!CastToBfloat16(item, &x)) {
435 PyErr_Format(PyExc_TypeError, "expected number, got %s",
436 item->ob_type->tp_name);
437 return -1;
438 }
439 memcpy(data, &x, sizeof(bfloat16));
440 return 0;
441 }
442
// Reverses the order of the two bytes at 'value' in place.
void ByteSwap16(void* value) {
  char* const bytes = static_cast<char*>(value);
  const char low = bytes[0];
  bytes[0] = bytes[1];
  bytes[1] = low;
}
447
NPyBfloat16_Compare(const void * a,const void * b,void * arr)448 int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
449 bfloat16 x;
450 memcpy(&x, a, sizeof(bfloat16));
451
452 bfloat16 y;
453 memcpy(&y, b, sizeof(bfloat16));
454
455 if (x < y) {
456 return -1;
457 }
458 if (y < x) {
459 return 1;
460 }
461 // NaNs sort to the end.
462 if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
463 return -1;
464 }
465 if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
466 return 1;
467 }
468 return 0;
469 }
470
NPyBfloat16_CopySwapN(void * dstv,npy_intp dstride,void * srcv,npy_intp sstride,npy_intp n,int swap,void * arr)471 void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
472 npy_intp sstride, npy_intp n, int swap, void* arr) {
473 char* dst = reinterpret_cast<char*>(dstv);
474 char* src = reinterpret_cast<char*>(srcv);
475 if (!src) {
476 return;
477 }
478 if (swap) {
479 for (npy_intp i = 0; i < n; i++) {
480 char* r = dst + dstride * i;
481 memcpy(r, src + sstride * i, sizeof(uint16_t));
482 ByteSwap16(r);
483 }
484 } else if (dstride == sizeof(uint16_t) && sstride == sizeof(uint16_t)) {
485 memcpy(dst, src, n * sizeof(uint16_t));
486 } else {
487 for (npy_intp i = 0; i < n; i++) {
488 memcpy(dst + dstride * i, src + sstride * i, sizeof(uint16_t));
489 }
490 }
491 }
492
NPyBfloat16_CopySwap(void * dst,void * src,int swap,void * arr)493 void NPyBfloat16_CopySwap(void* dst, void* src, int swap, void* arr) {
494 if (!src) {
495 return;
496 }
497 memcpy(dst, src, sizeof(uint16_t));
498 if (swap) {
499 ByteSwap16(dst);
500 }
501 }
502
NPyBfloat16_NonZero(void * data,void * arr)503 npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
504 bfloat16 x;
505 memcpy(&x, data, sizeof(x));
506 return x != static_cast<bfloat16>(0);
507 }
508
NPyBfloat16_Fill(void * buffer_raw,npy_intp length,void * ignored)509 int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
510 bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
511 const float start(buffer[0]);
512 const float delta = static_cast<float>(buffer[1]) - start;
513 for (npy_intp i = 2; i < length; ++i) {
514 buffer[i] = static_cast<bfloat16>(start + i * delta);
515 }
516 return 0;
517 }
518
NPyBfloat16_DotFunc(void * ip1,npy_intp is1,void * ip2,npy_intp is2,void * op,npy_intp n,void * arr)519 void NPyBfloat16_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
520 void* op, npy_intp n, void* arr) {
521 char* c1 = reinterpret_cast<char*>(ip1);
522 char* c2 = reinterpret_cast<char*>(ip2);
523 float acc = 0.0f;
524 for (npy_intp i = 0; i < n; ++i) {
525 bfloat16* const b1 = reinterpret_cast<bfloat16*>(c1);
526 bfloat16* const b2 = reinterpret_cast<bfloat16*>(c2);
527 acc += static_cast<float>(*b1) * static_cast<float>(*b2);
528 c1 += is1;
529 c2 += is2;
530 }
531 bfloat16* out = reinterpret_cast<bfloat16*>(op);
532 *out = static_cast<bfloat16>(acc);
533 }
534
NPyBfloat16_CompareFunc(const void * v1,const void * v2,void * arr)535 int NPyBfloat16_CompareFunc(const void* v1, const void* v2, void* arr) {
536 bfloat16 b1 = *reinterpret_cast<const bfloat16*>(v1);
537 bfloat16 b2 = *reinterpret_cast<const bfloat16*>(v2);
538 if (b1 < b2) {
539 return -1;
540 }
541 if (b1 > b2) {
542 return 1;
543 }
544 return 0;
545 }
546
NPyBfloat16_ArgMaxFunc(void * data,npy_intp n,npy_intp * max_ind,void * arr)547 int NPyBfloat16_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
548 void* arr) {
549 const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
550 float max_val = -std::numeric_limits<float>::infinity();
551 for (npy_intp i = 0; i < n; ++i) {
552 if (static_cast<float>(bdata[i]) > max_val) {
553 max_val = static_cast<float>(bdata[i]);
554 *max_ind = i;
555 }
556 }
557 return 0;
558 }
559
NPyBfloat16_ArgMinFunc(void * data,npy_intp n,npy_intp * min_ind,void * arr)560 int NPyBfloat16_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
561 void* arr) {
562 const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
563 float min_val = std::numeric_limits<float>::infinity();
564 for (npy_intp i = 0; i < n; ++i) {
565 if (static_cast<float>(bdata[i]) < min_val) {
566 min_val = static_cast<float>(bdata[i]);
567 *min_ind = i;
568 }
569 }
570 return 0;
571 }
572
573 // NumPy casts
574
575 template <typename T, typename Enable = void>
576 struct TypeDescriptor {
577 // typedef ... T; // Representation type in memory for NumPy values of type
578 // static int Dtype() { return NPY_...; } // Numpy type number for T.
579 };
580
581 template <>
582 struct TypeDescriptor<bfloat16> {
583 typedef bfloat16 T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor584 static int Dtype() { return npy_bfloat16; }
585 };
586
587 template <>
588 struct TypeDescriptor<uint8> {
589 typedef uint8 T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor590 static int Dtype() { return NPY_UINT8; }
591 };
592
593 template <>
594 struct TypeDescriptor<uint16> {
595 typedef uint16 T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor596 static int Dtype() { return NPY_UINT16; }
597 };
598
599 // We register "int", "long", and "long long" types for portability across
600 // Linux, where "int" and "long" are the same type, and Windows, where "long"
601 // and "longlong" are the same type.
602 template <>
603 struct TypeDescriptor<unsigned int> {
604 typedef unsigned int T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor605 static int Dtype() { return NPY_UINT; }
606 };
607
608 template <>
609 struct TypeDescriptor<unsigned long> { // NOLINT
610 typedef unsigned long T; // NOLINT
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor611 static int Dtype() { return NPY_ULONG; }
612 };
613
614 template <>
615 struct TypeDescriptor<unsigned long long> { // NOLINT
616 typedef unsigned long long T; // NOLINT
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor617 static int Dtype() { return NPY_ULONGLONG; }
618 };
619
620 template <>
621 struct TypeDescriptor<int8> {
622 typedef int8 T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor623 static int Dtype() { return NPY_INT8; }
624 };
625
626 template <>
627 struct TypeDescriptor<int16> {
628 typedef int16 T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor629 static int Dtype() { return NPY_INT16; }
630 };
631
632 template <>
633 struct TypeDescriptor<int> {
634 typedef int T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor635 static int Dtype() { return NPY_INT; }
636 };
637
638 template <>
639 struct TypeDescriptor<long> { // NOLINT
640 typedef long T; // NOLINT
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor641 static int Dtype() { return NPY_LONG; }
642 };
643
644 template <>
645 struct TypeDescriptor<long long> { // NOLINT
646 typedef long long T; // NOLINT
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor647 static int Dtype() { return NPY_LONGLONG; }
648 };
649
650 template <>
651 struct TypeDescriptor<bool> {
652 typedef int8 T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor653 static int Dtype() { return NPY_BOOL; }
654 };
655
656 template <>
657 struct TypeDescriptor<Eigen::half> {
658 typedef Eigen::half T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor659 static int Dtype() { return NPY_HALF; }
660 };
661
662 template <>
663 struct TypeDescriptor<float> {
664 typedef float T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor665 static int Dtype() { return NPY_FLOAT; }
666 };
667
668 template <>
669 struct TypeDescriptor<double> {
670 typedef double T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor671 static int Dtype() { return NPY_DOUBLE; }
672 };
673
674 template <>
675 struct TypeDescriptor<std::complex<float>> {
676 typedef std::complex<float> T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor677 static int Dtype() { return NPY_COMPLEX64; }
678 };
679
680 template <>
681 struct TypeDescriptor<std::complex<double>> {
682 typedef std::complex<double> T;
Dtypetensorflow::__anonea8c089e0111::TypeDescriptor683 static int Dtype() { return NPY_COMPLEX128; }
684 };
685
686 // Performs a NumPy array cast from type 'From' to 'To'.
687 template <typename From, typename To>
NPyCast(void * from_void,void * to_void,npy_intp n,void * fromarr,void * toarr)688 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
689 void* toarr) {
690 const auto* from =
691 reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
692 auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
693 for (npy_intp i = 0; i < n; ++i) {
694 to[i] =
695 static_cast<typename TypeDescriptor<To>::T>(static_cast<To>(from[i]));
696 }
697 }
698
699 // Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
700 // type corresponding to 'T'.
701 template <typename T>
RegisterBfloat16Cast(int numpy_type)702 bool RegisterBfloat16Cast(int numpy_type) {
703 PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
704 if (PyArray_RegisterCastFunc(descr, npy_bfloat16, NPyCast<T, bfloat16>) < 0) {
705 return false;
706 }
707 if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
708 NPyCast<bfloat16, T>) < 0) {
709 return false;
710 }
711 return true;
712 }
713
714 template <typename InType, typename OutType, typename Functor>
715 struct UnaryUFunc {
Typestensorflow::__anonea8c089e0111::UnaryUFunc716 static std::vector<int> Types() {
717 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype()};
718 }
Calltensorflow::__anonea8c089e0111::UnaryUFunc719 static void Call(char** args, const npy_intp* dimensions,
720 const npy_intp* steps, void* data) {
721 const char* i0 = args[0];
722 char* o = args[1];
723 for (npy_intp k = 0; k < *dimensions; k++) {
724 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
725 *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) = Functor()(x);
726 i0 += steps[0];
727 o += steps[1];
728 }
729 }
730 };
731
732 template <typename InType, typename OutType, typename OutType2,
733 typename Functor>
734 struct UnaryUFunc2 {
Typestensorflow::__anonea8c089e0111::UnaryUFunc2735 static std::vector<int> Types() {
736 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<OutType>::Dtype(),
737 TypeDescriptor<OutType2>::Dtype()};
738 }
Calltensorflow::__anonea8c089e0111::UnaryUFunc2739 static void Call(char** args, const npy_intp* dimensions,
740 const npy_intp* steps, void* data) {
741 const char* i0 = args[0];
742 char* o0 = args[1];
743 char* o1 = args[2];
744 for (npy_intp k = 0; k < *dimensions; k++) {
745 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
746 std::tie(*reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o0),
747 *reinterpret_cast<typename TypeDescriptor<OutType2>::T*>(o1)) =
748 Functor()(x);
749 i0 += steps[0];
750 o0 += steps[1];
751 o1 += steps[2];
752 }
753 }
754 };
755
756 template <typename InType, typename OutType, typename Functor>
757 struct BinaryUFunc {
Typestensorflow::__anonea8c089e0111::BinaryUFunc758 static std::vector<int> Types() {
759 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType>::Dtype(),
760 TypeDescriptor<OutType>::Dtype()};
761 }
Calltensorflow::__anonea8c089e0111::BinaryUFunc762 static void Call(char** args, const npy_intp* dimensions,
763 const npy_intp* steps, void* data) {
764 const char* i0 = args[0];
765 const char* i1 = args[1];
766 char* o = args[2];
767 for (npy_intp k = 0; k < *dimensions; k++) {
768 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
769 auto y = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i1);
770 *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
771 Functor()(x, y);
772 i0 += steps[0];
773 i1 += steps[1];
774 o += steps[2];
775 }
776 }
777 };
778
779 template <typename InType, typename InType2, typename OutType, typename Functor>
780 struct BinaryUFunc2 {
Typestensorflow::__anonea8c089e0111::BinaryUFunc2781 static std::vector<int> Types() {
782 return {TypeDescriptor<InType>::Dtype(), TypeDescriptor<InType2>::Dtype(),
783 TypeDescriptor<OutType>::Dtype()};
784 }
Calltensorflow::__anonea8c089e0111::BinaryUFunc2785 static void Call(char** args, const npy_intp* dimensions,
786 const npy_intp* steps, void* data) {
787 const char* i0 = args[0];
788 const char* i1 = args[1];
789 char* o = args[2];
790 for (npy_intp k = 0; k < *dimensions; k++) {
791 auto x = *reinterpret_cast<const typename TypeDescriptor<InType>::T*>(i0);
792 auto y =
793 *reinterpret_cast<const typename TypeDescriptor<InType2>::T*>(i1);
794 *reinterpret_cast<typename TypeDescriptor<OutType>::T*>(o) =
795 Functor()(x, y);
796 i0 += steps[0];
797 i1 += steps[1];
798 o += steps[2];
799 }
800 }
801 };
802
803 template <typename UFunc>
RegisterUFunc(PyObject * numpy,const char * name)804 bool RegisterUFunc(PyObject* numpy, const char* name) {
805 std::vector<int> types = UFunc::Types();
806 PyUFuncGenericFunction fn =
807 reinterpret_cast<PyUFuncGenericFunction>(UFunc::Call);
808 Safe_PyObjectPtr ufunc_obj = make_safe(PyObject_GetAttrString(numpy, name));
809 if (!ufunc_obj) {
810 return false;
811 }
812 PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
813 if (static_cast<int>(types.size()) != ufunc->nargs) {
814 PyErr_Format(PyExc_AssertionError,
815 "ufunc %s takes %d arguments, loop takes %lu", name,
816 ufunc->nargs, types.size());
817 return false;
818 }
819 if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16, fn,
820 const_cast<int*>(types.data()),
821 nullptr) < 0) {
822 return false;
823 }
824 return true;
825 }
826
827 namespace ufuncs {
828
829 struct Add {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Add830 bfloat16 operator()(bfloat16 a, bfloat16 b) { return a + b; }
831 };
832 struct Subtract {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Subtract833 bfloat16 operator()(bfloat16 a, bfloat16 b) { return a - b; }
834 };
835 struct Multiply {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Multiply836 bfloat16 operator()(bfloat16 a, bfloat16 b) { return a * b; }
837 };
838 struct TrueDivide {
operator ()tensorflow::__anonea8c089e0111::ufuncs::TrueDivide839 bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
840 };
841
// Floor division with remainder in float: returns {floor(a / b), a mod b},
// where the remainder carries the sign of the divisor (as Python's divmod
// does). The quotient is snapped to the nearest integer to absorb rounding in
// (a - remainder) / b. A zero divisor yields {NaN, NaN}.
std::pair<float, float> divmod(float a, float b) {
  if (b == 0.0f) {
    const float kNaN = std::numeric_limits<float>::quiet_NaN();
    return std::make_pair(kNaN, kNaN);
  }

  float remainder = std::fmod(a, b);
  float quotient = (a - remainder) / b;
  if (remainder == 0.0f) {
    // Exact division: give the zero remainder the divisor's sign.
    remainder = std::copysign(0.0f, b);
  } else if ((remainder < 0.0f) != (b < 0.0f)) {
    // fmod follows the dividend's sign; shift it to the divisor's sign.
    remainder += b;
    quotient -= 1.0f;
  }

  if (quotient == 0.0f) {
    return {std::copysign(0.0f, a / b), remainder};
  }
  float floored = std::floor(quotient);
  if (quotient - floored > 0.5f) {
    floored += 1.0f;
  }
  return {floored, remainder};
}
869
870 struct FloorDivide {
operator ()tensorflow::__anonea8c089e0111::ufuncs::FloorDivide871 bfloat16 operator()(bfloat16 a, bfloat16 b) {
872 return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
873 }
874 };
875 struct Remainder {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Remainder876 bfloat16 operator()(bfloat16 a, bfloat16 b) {
877 return bfloat16(
878 divmod(static_cast<float>(a), static_cast<float>(b)).second);
879 }
880 };
881 struct DivmodUFunc {
Typestensorflow::__anonea8c089e0111::ufuncs::DivmodUFunc882 static std::vector<int> Types() {
883 return {npy_bfloat16, npy_bfloat16, npy_bfloat16, npy_bfloat16};
884 }
Calltensorflow::__anonea8c089e0111::ufuncs::DivmodUFunc885 static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
886 void* data) {
887 const char* i0 = args[0];
888 const char* i1 = args[1];
889 char* o0 = args[2];
890 char* o1 = args[3];
891 for (npy_intp k = 0; k < *dimensions; k++) {
892 bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
893 bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
894 float floordiv, mod;
895 std::tie(floordiv, mod) =
896 divmod(static_cast<float>(x), static_cast<float>(y));
897 *reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
898 *reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
899 i0 += steps[0];
900 i1 += steps[1];
901 o0 += steps[2];
902 o1 += steps[3];
903 }
904 }
905 };
906 struct Fmod {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Fmod907 bfloat16 operator()(bfloat16 a, bfloat16 b) {
908 return bfloat16(std::fmod(static_cast<float>(a), static_cast<float>(b)));
909 }
910 };
911 struct Negative {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Negative912 bfloat16 operator()(bfloat16 a) { return -a; }
913 };
914 struct Positive {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Positive915 bfloat16 operator()(bfloat16 a) { return a; }
916 };
917 struct Power {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Power918 bfloat16 operator()(bfloat16 a, bfloat16 b) {
919 return bfloat16(std::pow(static_cast<float>(a), static_cast<float>(b)));
920 }
921 };
922 struct Abs {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Abs923 bfloat16 operator()(bfloat16 a) {
924 return bfloat16(std::abs(static_cast<float>(a)));
925 }
926 };
927 struct Cbrt {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Cbrt928 bfloat16 operator()(bfloat16 a) {
929 return bfloat16(std::cbrt(static_cast<float>(a)));
930 }
931 };
932 struct Ceil {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Ceil933 bfloat16 operator()(bfloat16 a) {
934 return bfloat16(std::ceil(static_cast<float>(a)));
935 }
936 };
937 struct CopySign {
operator ()tensorflow::__anonea8c089e0111::ufuncs::CopySign938 bfloat16 operator()(bfloat16 a, bfloat16 b) {
939 return bfloat16(
940 std::copysign(static_cast<float>(a), static_cast<float>(b)));
941 }
942 };
943 struct Exp {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Exp944 bfloat16 operator()(bfloat16 a) {
945 return bfloat16(std::exp(static_cast<float>(a)));
946 }
947 };
948 struct Exp2 {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Exp2949 bfloat16 operator()(bfloat16 a) {
950 return bfloat16(std::exp2(static_cast<float>(a)));
951 }
952 };
953 struct Expm1 {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Expm1954 bfloat16 operator()(bfloat16 a) {
955 return bfloat16(std::expm1(static_cast<float>(a)));
956 }
957 };
958 struct Floor {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Floor959 bfloat16 operator()(bfloat16 a) {
960 return bfloat16(std::floor(static_cast<float>(a)));
961 }
962 };
963 struct Frexp {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Frexp964 std::pair<bfloat16, int> operator()(bfloat16 a) {
965 int exp;
966 float f = std::frexp(static_cast<float>(a), &exp);
967 return {bfloat16(f), exp};
968 }
969 };
970 struct Heaviside {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Heaviside971 bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
972 float x = static_cast<float>(bx);
973 if (Eigen::numext::isnan(x)) {
974 return bx;
975 }
976 if (x < 0) {
977 return bfloat16(0.0f);
978 }
979 if (x > 0) {
980 return bfloat16(1.0f);
981 }
982 return h0; // x == 0
983 }
984 };
985 struct Conjugate {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Conjugate986 bfloat16 operator()(bfloat16 a) { return a; }
987 };
988 struct IsFinite {
operator ()tensorflow::__anonea8c089e0111::ufuncs::IsFinite989 bool operator()(bfloat16 a) { return std::isfinite(static_cast<float>(a)); }
990 };
991 struct IsInf {
operator ()tensorflow::__anonea8c089e0111::ufuncs::IsInf992 bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
993 };
994 struct IsNan {
operator ()tensorflow::__anonea8c089e0111::ufuncs::IsNan995 bool operator()(bfloat16 a) {
996 return Eigen::numext::isnan(static_cast<float>(a));
997 }
998 };
999 struct Ldexp {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Ldexp1000 bfloat16 operator()(bfloat16 a, int exp) {
1001 return bfloat16(std::ldexp(static_cast<float>(a), exp));
1002 }
1003 };
1004 struct Log {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Log1005 bfloat16 operator()(bfloat16 a) {
1006 return bfloat16(std::log(static_cast<float>(a)));
1007 }
1008 };
1009 struct Log2 {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Log21010 bfloat16 operator()(bfloat16 a) {
1011 return bfloat16(std::log2(static_cast<float>(a)));
1012 }
1013 };
1014 struct Log10 {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Log101015 bfloat16 operator()(bfloat16 a) {
1016 return bfloat16(std::log10(static_cast<float>(a)));
1017 }
1018 };
1019 struct Log1p {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Log1p1020 bfloat16 operator()(bfloat16 a) {
1021 return bfloat16(std::log1p(static_cast<float>(a)));
1022 }
1023 };
1024 struct LogAddExp {
operator ()tensorflow::__anonea8c089e0111::ufuncs::LogAddExp1025 bfloat16 operator()(bfloat16 bx, bfloat16 by) {
1026 float x = static_cast<float>(bx);
1027 float y = static_cast<float>(by);
1028 if (x == y) {
1029 // Handles infinities of the same sign.
1030 return bfloat16(x + std::log(2.0f));
1031 }
1032 float out = std::numeric_limits<float>::quiet_NaN();
1033 if (x > y) {
1034 out = x + std::log1p(std::exp(y - x));
1035 } else if (x < y) {
1036 out = y + std::log1p(std::exp(x - y));
1037 }
1038 return bfloat16(out);
1039 }
1040 };
1041 struct LogAddExp2 {
operator ()tensorflow::__anonea8c089e0111::ufuncs::LogAddExp21042 bfloat16 operator()(bfloat16 bx, bfloat16 by) {
1043 float x = static_cast<float>(bx);
1044 float y = static_cast<float>(by);
1045 if (x == y) {
1046 // Handles infinities of the same sign.
1047 return bfloat16(x + 1.0f);
1048 }
1049 float out = std::numeric_limits<float>::quiet_NaN();
1050 if (x > y) {
1051 out = x + std::log1p(std::exp2(y - x)) / std::log(2.0f);
1052 } else if (x < y) {
1053 out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
1054 }
1055 return bfloat16(out);
1056 }
1057 };
1058 struct Modf {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Modf1059 std::pair<bfloat16, bfloat16> operator()(bfloat16 a) {
1060 float integral;
1061 float f = std::modf(static_cast<float>(a), &integral);
1062 return {bfloat16(f), bfloat16(integral)};
1063 }
1064 };
1065
1066 struct Reciprocal {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Reciprocal1067 bfloat16 operator()(bfloat16 a) {
1068 return bfloat16(1.f / static_cast<float>(a));
1069 }
1070 };
1071 struct Rint {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Rint1072 bfloat16 operator()(bfloat16 a) {
1073 return bfloat16(std::rint(static_cast<float>(a)));
1074 }
1075 };
1076 struct Sign {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Sign1077 bfloat16 operator()(bfloat16 a) {
1078 float f(a);
1079 if (f < 0) {
1080 return bfloat16(-1);
1081 }
1082 if (f > 0) {
1083 return bfloat16(1);
1084 }
1085 return a;
1086 }
1087 };
1088 struct SignBit {
operator ()tensorflow::__anonea8c089e0111::ufuncs::SignBit1089 bool operator()(bfloat16 a) { return std::signbit(static_cast<float>(a)); }
1090 };
1091 struct Sqrt {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Sqrt1092 bfloat16 operator()(bfloat16 a) {
1093 return bfloat16(std::sqrt(static_cast<float>(a)));
1094 }
1095 };
1096 struct Square {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Square1097 bfloat16 operator()(bfloat16 a) {
1098 float f(a);
1099 return bfloat16(f * f);
1100 }
1101 };
1102 struct Trunc {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Trunc1103 bfloat16 operator()(bfloat16 a) {
1104 return bfloat16(std::trunc(static_cast<float>(a)));
1105 }
1106 };
1107
1108 // Trigonometric functions
1109 struct Sin {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Sin1110 bfloat16 operator()(bfloat16 a) {
1111 return bfloat16(std::sin(static_cast<float>(a)));
1112 }
1113 };
1114 struct Cos {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Cos1115 bfloat16 operator()(bfloat16 a) {
1116 return bfloat16(std::cos(static_cast<float>(a)));
1117 }
1118 };
1119 struct Tan {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Tan1120 bfloat16 operator()(bfloat16 a) {
1121 return bfloat16(std::tan(static_cast<float>(a)));
1122 }
1123 };
1124 struct Arcsin {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Arcsin1125 bfloat16 operator()(bfloat16 a) {
1126 return bfloat16(std::asin(static_cast<float>(a)));
1127 }
1128 };
1129 struct Arccos {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Arccos1130 bfloat16 operator()(bfloat16 a) {
1131 return bfloat16(std::acos(static_cast<float>(a)));
1132 }
1133 };
1134 struct Arctan {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Arctan1135 bfloat16 operator()(bfloat16 a) {
1136 return bfloat16(std::atan(static_cast<float>(a)));
1137 }
1138 };
1139 struct Arctan2 {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Arctan21140 bfloat16 operator()(bfloat16 a, bfloat16 b) {
1141 return bfloat16(std::atan2(static_cast<float>(a), static_cast<float>(b)));
1142 }
1143 };
1144 struct Hypot {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Hypot1145 bfloat16 operator()(bfloat16 a, bfloat16 b) {
1146 return bfloat16(std::hypot(static_cast<float>(a), static_cast<float>(b)));
1147 }
1148 };
1149 struct Sinh {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Sinh1150 bfloat16 operator()(bfloat16 a) {
1151 return bfloat16(std::sinh(static_cast<float>(a)));
1152 }
1153 };
1154 struct Cosh {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Cosh1155 bfloat16 operator()(bfloat16 a) {
1156 return bfloat16(std::cosh(static_cast<float>(a)));
1157 }
1158 };
1159 struct Tanh {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Tanh1160 bfloat16 operator()(bfloat16 a) {
1161 return bfloat16(std::tanh(static_cast<float>(a)));
1162 }
1163 };
1164 struct Arcsinh {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Arcsinh1165 bfloat16 operator()(bfloat16 a) {
1166 return bfloat16(std::asinh(static_cast<float>(a)));
1167 }
1168 };
1169 struct Arccosh {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Arccosh1170 bfloat16 operator()(bfloat16 a) {
1171 return bfloat16(std::acosh(static_cast<float>(a)));
1172 }
1173 };
1174 struct Arctanh {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Arctanh1175 bfloat16 operator()(bfloat16 a) {
1176 return bfloat16(std::atanh(static_cast<float>(a)));
1177 }
1178 };
1179 struct Deg2rad {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Deg2rad1180 bfloat16 operator()(bfloat16 a) {
1181 static constexpr float radians_per_degree = M_PI / 180.0f;
1182 return bfloat16(static_cast<float>(a) * radians_per_degree);
1183 }
1184 };
1185 struct Rad2deg {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Rad2deg1186 bfloat16 operator()(bfloat16 a) {
1187 static constexpr float degrees_per_radian = 180.0f / M_PI;
1188 return bfloat16(static_cast<float>(a) * degrees_per_radian);
1189 }
1190 };
1191
1192 struct Eq {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Eq1193 npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
1194 };
1195 struct Ne {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Ne1196 npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
1197 };
1198 struct Lt {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Lt1199 npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
1200 };
1201 struct Gt {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Gt1202 npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
1203 };
1204 struct Le {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Le1205 npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
1206 };
1207 struct Ge {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Ge1208 npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
1209 };
1210 struct Maximum {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Maximum1211 bfloat16 operator()(bfloat16 a, bfloat16 b) {
1212 float fa(a), fb(b);
1213 return Eigen::numext::isnan(fa) || fa > fb ? a : b;
1214 }
1215 };
1216 struct Minimum {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Minimum1217 bfloat16 operator()(bfloat16 a, bfloat16 b) {
1218 float fa(a), fb(b);
1219 return Eigen::numext::isnan(fa) || fa < fb ? a : b;
1220 }
1221 };
1222 struct Fmax {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Fmax1223 bfloat16 operator()(bfloat16 a, bfloat16 b) {
1224 float fa(a), fb(b);
1225 return Eigen::numext::isnan(fb) || fa > fb ? a : b;
1226 }
1227 };
1228 struct Fmin {
operator ()tensorflow::__anonea8c089e0111::ufuncs::Fmin1229 bfloat16 operator()(bfloat16 a, bfloat16 b) {
1230 float fa(a), fb(b);
1231 return Eigen::numext::isnan(fb) || fa < fb ? a : b;
1232 }
1233 };
1234
1235 struct LogicalNot {
operator ()tensorflow::__anonea8c089e0111::ufuncs::LogicalNot1236 npy_bool operator()(bfloat16 a) { return !a; }
1237 };
1238 struct LogicalAnd {
operator ()tensorflow::__anonea8c089e0111::ufuncs::LogicalAnd1239 npy_bool operator()(bfloat16 a, bfloat16 b) { return a && b; }
1240 };
1241 struct LogicalOr {
operator ()tensorflow::__anonea8c089e0111::ufuncs::LogicalOr1242 npy_bool operator()(bfloat16 a, bfloat16 b) { return a || b; }
1243 };
1244 struct LogicalXor {
operator ()tensorflow::__anonea8c089e0111::ufuncs::LogicalXor1245 npy_bool operator()(bfloat16 a, bfloat16 b) {
1246 return static_cast<bool>(a) ^ static_cast<bool>(b);
1247 }
1248 };
1249
1250 struct NextAfter {
operator ()tensorflow::__anonea8c089e0111::ufuncs::NextAfter1251 bfloat16 operator()(bfloat16 from, bfloat16 to) {
1252 uint16_t from_as_int, to_as_int;
1253 const uint16_t sign_mask = 1 << 15;
1254 float from_as_float(from), to_as_float(to);
1255 memcpy(&from_as_int, &from, sizeof(bfloat16));
1256 memcpy(&to_as_int, &to, sizeof(bfloat16));
1257 if (Eigen::numext::isnan(from_as_float) ||
1258 Eigen::numext::isnan(to_as_float)) {
1259 return bfloat16(std::numeric_limits<float>::quiet_NaN());
1260 }
1261 if (from_as_int == to_as_int) {
1262 return to;
1263 }
1264 if (from_as_float == 0) {
1265 if (to_as_float == 0) {
1266 return to;
1267 } else {
1268 // Smallest subnormal signed like `to`.
1269 uint16_t out_int = (to_as_int & sign_mask) | 1;
1270 bfloat16 out;
1271 memcpy(&out, &out_int, sizeof(bfloat16));
1272 return out;
1273 }
1274 }
1275 uint16_t from_sign = from_as_int & sign_mask;
1276 uint16_t to_sign = to_as_int & sign_mask;
1277 uint16_t from_abs = from_as_int & ~sign_mask;
1278 uint16_t to_abs = to_as_int & ~sign_mask;
1279 uint16_t magnitude_adjustment =
1280 (from_abs > to_abs || from_sign != to_sign) ? 0xFFFF : 0x0001;
1281 uint16_t out_int = from_as_int + magnitude_adjustment;
1282 bfloat16 out;
1283 memcpy(&out, &out_int, sizeof(bfloat16));
1284 return out;
1285 }
1286 };
1287
1288 // TODO(phawkins): implement spacing
1289
1290 } // namespace ufuncs
1291
1292 } // namespace
1293
1294 // Initializes the module.
Initialize()1295 bool Initialize() {
1296 ImportNumpy();
1297 import_umath1(false);
1298
1299 Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
1300 if (!numpy_str) {
1301 return false;
1302 }
1303 Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
1304 if (!numpy) {
1305 return false;
1306 }
1307
1308 // If another module (presumably either TF or JAX) has registered a bfloat16
1309 // type, use it. We don't want two bfloat16 types if we can avoid it since it
1310 // leads to confusion if we have two different types with the same name. This
1311 // assumes that the other module has a sufficiently complete bfloat16
1312 // implementation. The only known NumPy bfloat16 extension at the time of
1313 // writing is this one (distributed in TF and JAX).
1314 // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
1315 // so we can unambiguously refer to a single canonical definition of bfloat16.
1316 int typenum = PyArray_TypeNumFromName(const_cast<char*>("bfloat16"));
1317 if (typenum != NPY_NOTYPE) {
1318 PyArray_Descr* descr = PyArray_DescrFromType(typenum);
1319 // The test for an argmax function here is to verify that the
1320 // bfloat16 implementation is sufficiently new, and, say, not from
1321 // an older version of TF or JAX.
1322 if (descr && descr->f && descr->f->argmax) {
1323 npy_bfloat16 = typenum;
1324 bfloat16_type_ptr = descr->typeobj;
1325 return true;
1326 }
1327 }
1328
1329 bfloat16_type.tp_base = &PyGenericArrType_Type;
1330
1331 if (PyType_Ready(&bfloat16_type) < 0) {
1332 return false;
1333 }
1334
1335 // Initializes the NumPy descriptor.
1336 PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
1337 NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
1338 NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
1339 NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
1340 NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
1341 NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
1342 NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
1343 NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
1344 NPyBfloat16_ArrFuncs.dotfunc = NPyBfloat16_DotFunc;
1345 NPyBfloat16_ArrFuncs.compare = NPyBfloat16_CompareFunc;
1346 NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc;
1347 NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc;
1348
1349 Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
1350 npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr);
1351 bfloat16_type_ptr = &bfloat16_type;
1352 if (npy_bfloat16 < 0) {
1353 return false;
1354 }
1355
1356 // Support dtype(bfloat16)
1357 if (PyDict_SetItemString(bfloat16_type.tp_dict, "dtype",
1358 reinterpret_cast<PyObject*>(&NPyBfloat16_Descr)) <
1359 0) {
1360 return false;
1361 }
1362
1363 // Register casts
1364 if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF)) {
1365 return false;
1366 }
1367
1368 if (!RegisterBfloat16Cast<float>(NPY_FLOAT)) {
1369 return false;
1370 }
1371 if (!RegisterBfloat16Cast<double>(NPY_DOUBLE)) {
1372 return false;
1373 }
1374 if (!RegisterBfloat16Cast<bool>(NPY_BOOL)) {
1375 return false;
1376 }
1377 if (!RegisterBfloat16Cast<uint8>(NPY_UINT8)) {
1378 return false;
1379 }
1380 if (!RegisterBfloat16Cast<uint16>(NPY_UINT16)) {
1381 return false;
1382 }
1383 if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT)) {
1384 return false;
1385 }
1386 if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG)) { // NOLINT
1387 return false;
1388 }
1389 if (!RegisterBfloat16Cast<unsigned long long>(NPY_ULONGLONG)) { // NOLINT
1390 return false;
1391 }
1392 if (!RegisterBfloat16Cast<uint64>(NPY_UINT64)) {
1393 return false;
1394 }
1395 if (!RegisterBfloat16Cast<int8>(NPY_INT8)) {
1396 return false;
1397 }
1398 if (!RegisterBfloat16Cast<int16>(NPY_INT16)) {
1399 return false;
1400 }
1401 if (!RegisterBfloat16Cast<int>(NPY_INT)) {
1402 return false;
1403 }
1404 if (!RegisterBfloat16Cast<long>(NPY_LONG)) { // NOLINT
1405 return false;
1406 }
1407 if (!RegisterBfloat16Cast<long long>(NPY_LONGLONG)) { // NOLINT
1408 return false;
1409 }
1410 // Following the numpy convention. imag part is dropped when converting to
1411 // float.
1412 if (!RegisterBfloat16Cast<std::complex<float>>(NPY_COMPLEX64)) {
1413 return false;
1414 }
1415 if (!RegisterBfloat16Cast<std::complex<double>>(NPY_COMPLEX128)) {
1416 return false;
1417 }
1418
1419 // Safe casts from bfloat16 to other types
1420 if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_FLOAT, NPY_NOSCALAR) <
1421 0) {
1422 return false;
1423 }
1424 if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_DOUBLE, NPY_NOSCALAR) <
1425 0) {
1426 return false;
1427 }
1428 if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX64, NPY_NOSCALAR) <
1429 0) {
1430 return false;
1431 }
1432 if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_COMPLEX128,
1433 NPY_NOSCALAR) < 0) {
1434 return false;
1435 }
1436
1437 // Safe casts to bfloat16 from other types
1438 if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), npy_bfloat16,
1439 NPY_NOSCALAR) < 0) {
1440 return false;
1441 }
1442 if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UINT8), npy_bfloat16,
1443 NPY_NOSCALAR) < 0) {
1444 return false;
1445 }
1446 if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_INT8), npy_bfloat16,
1447 NPY_NOSCALAR) < 0) {
1448 return false;
1449 }
1450
1451 bool ok =
1452 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Add>>(numpy.get(),
1453 "add") &&
1454 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Subtract>>(
1455 numpy.get(), "subtract") &&
1456 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Multiply>>(
1457 numpy.get(), "multiply") &&
1458 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
1459 numpy.get(), "divide") &&
1460 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp>>(
1461 numpy.get(), "logaddexp") &&
1462 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp2>>(
1463 numpy.get(), "logaddexp2") &&
1464 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Negative>>(
1465 numpy.get(), "negative") &&
1466 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Positive>>(
1467 numpy.get(), "positive") &&
1468 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
1469 numpy.get(), "true_divide") &&
1470 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::FloorDivide>>(
1471 numpy.get(), "floor_divide") &&
1472 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Power>>(numpy.get(),
1473 "power") &&
1474 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
1475 numpy.get(), "remainder") &&
1476 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
1477 numpy.get(), "mod") &&
1478 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmod>>(numpy.get(),
1479 "fmod") &&
1480 RegisterUFunc<ufuncs::DivmodUFunc>(numpy.get(), "divmod") &&
1481 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
1482 "absolute") &&
1483 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
1484 "fabs") &&
1485 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rint>>(numpy.get(),
1486 "rint") &&
1487 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sign>>(numpy.get(),
1488 "sign") &&
1489 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Heaviside>>(
1490 numpy.get(), "heaviside") &&
1491 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Conjugate>>(
1492 numpy.get(), "conjugate") &&
1493 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp>>(numpy.get(),
1494 "exp") &&
1495 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp2>>(numpy.get(),
1496 "exp2") &&
1497 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Expm1>>(numpy.get(),
1498 "expm1") &&
1499 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log>>(numpy.get(),
1500 "log") &&
1501 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log2>>(numpy.get(),
1502 "log2") &&
1503 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log10>>(numpy.get(),
1504 "log10") &&
1505 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log1p>>(numpy.get(),
1506 "log1p") &&
1507 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sqrt>>(numpy.get(),
1508 "sqrt") &&
1509 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Square>>(numpy.get(),
1510 "square") &&
1511 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cbrt>>(numpy.get(),
1512 "cbrt") &&
1513 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Reciprocal>>(
1514 numpy.get(), "reciprocal") &&
1515
1516 // Trigonometric functions
1517 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sin>>(numpy.get(),
1518 "sin") &&
1519 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cos>>(numpy.get(),
1520 "cos") &&
1521 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tan>>(numpy.get(),
1522 "tan") &&
1523 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsin>>(numpy.get(),
1524 "arcsin") &&
1525 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccos>>(numpy.get(),
1526 "arccos") &&
1527 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctan>>(numpy.get(),
1528 "arctan") &&
1529 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Arctan2>>(
1530 numpy.get(), "arctan2") &&
1531 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Hypot>>(numpy.get(),
1532 "hypot") &&
1533 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sinh>>(numpy.get(),
1534 "sinh") &&
1535 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cosh>>(numpy.get(),
1536 "cosh") &&
1537 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tanh>>(numpy.get(),
1538 "tanh") &&
1539 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsinh>>(
1540 numpy.get(), "arcsinh") &&
1541 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccosh>>(
1542 numpy.get(), "arccosh") &&
1543 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctanh>>(
1544 numpy.get(), "arctanh") &&
1545 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Deg2rad>>(
1546 numpy.get(), "deg2rad") &&
1547 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rad2deg>>(
1548 numpy.get(), "rad2deg") &&
1549
1550 // Comparison functions
1551 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Eq>>(numpy.get(),
1552 "equal") &&
1553 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ne>>(numpy.get(),
1554 "not_equal") &&
1555 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Lt>>(numpy.get(),
1556 "less") &&
1557 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Gt>>(numpy.get(),
1558 "greater") &&
1559 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Le>>(numpy.get(),
1560 "less_equal") &&
1561 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ge>>(numpy.get(),
1562 "greater_equal") &&
1563 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Maximum>>(
1564 numpy.get(), "maximum") &&
1565 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Minimum>>(
1566 numpy.get(), "minimum") &&
1567 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmax>>(numpy.get(),
1568 "fmax") &&
1569 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmin>>(numpy.get(),
1570 "fmin") &&
1571 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalAnd>>(
1572 numpy.get(), "logical_and") &&
1573 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalOr>>(
1574 numpy.get(), "logical_or") &&
1575 RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalXor>>(
1576 numpy.get(), "logical_xor") &&
1577 RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::LogicalNot>>(
1578 numpy.get(), "logical_not") &&
1579
1580 // Floating point functions
1581 RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsFinite>>(numpy.get(),
1582 "isfinite") &&
1583 RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsInf>>(numpy.get(),
1584 "isinf") &&
1585 RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsNan>>(numpy.get(),
1586 "isnan") &&
1587 RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::SignBit>>(numpy.get(),
1588 "signbit") &&
1589 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::CopySign>>(
1590 numpy.get(), "copysign") &&
1591 RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, bfloat16, ufuncs::Modf>>(
1592 numpy.get(), "modf") &&
1593 RegisterUFunc<BinaryUFunc2<bfloat16, int, bfloat16, ufuncs::Ldexp>>(
1594 numpy.get(), "ldexp") &&
1595 RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, int, ufuncs::Frexp>>(
1596 numpy.get(), "frexp") &&
1597 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Floor>>(numpy.get(),
1598 "floor") &&
1599 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
1600 "ceil") &&
1601 RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
1602 "trunc") &&
1603 RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
1604 numpy.get(), "nextafter");
1605
1606 return ok;
1607 }
1608
RegisterNumpyBfloat16()1609 bool RegisterNumpyBfloat16() {
1610 if (npy_bfloat16 != NPY_NOTYPE) {
1611 // Already initialized.
1612 return true;
1613 }
1614 if (!Initialize()) {
1615 if (!PyErr_Occurred()) {
1616 PyErr_SetString(PyExc_RuntimeError, "cannot load bfloat16 module.");
1617 }
1618 PyErr_Print();
1619 return false;
1620 }
1621 return true;
1622 }
1623
Bfloat16Dtype()1624 PyObject* Bfloat16Dtype() {
1625 return reinterpret_cast<PyObject*>(bfloat16_type_ptr);
1626 }
1627
Bfloat16NumpyType()1628 int Bfloat16NumpyType() { return npy_bfloat16; }
1629
1630 } // namespace tensorflow
1631