/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_values.h"

#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/python/lib/core/numpy.h"

namespace py = pybind11;

namespace xla {

namespace {

using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
    py::handle, PjRtDevice*, const DevicePutOptions& options)>;

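// Converts a Python scalar (bool, float, or complex) to a rank-0 device
// buffer. T is the full-width native type used to read the value; SquashedT
// is the narrower type used instead when options.squash_64bit_types is set
// (e.g. double is narrowed to float). Python scalars produce weak types.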
template <typename T, typename SquashedT>
StatusOr<DevicePutResult> HandlePythonScalar(py::handle obj,
                                             PjRtDevice* to_device,
                                             const DevicePutOptions& options) {
  T data;

  try {
    data = py::cast<T>(obj);
  } catch (const std::exception& e) {
    return InvalidArgument(
        "Unable to convert Python scalar to %s. This most likely means the "
        "value (%s) overflows the range of the type.",
        PrimitiveType_Name(primitive_util::NativeToPrimitiveType<T>()),
        py::repr(obj));
  }

  void* ptr;
  SquashedT squashed_data;
  Shape shape;
  PrimitiveType type;
  if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
    ptr = &data;
    type = primitive_util::NativeToPrimitiveType<T>();
  } else {
    // TODO(phawkins): we should check for overflow here, e.g., because of bugs
    // like https://github.com/google/jax/issues/2006
    squashed_data = static_cast<SquashedT>(data);
    ptr = &squashed_data;
    type = primitive_util::NativeToPrimitiveType<SquashedT>();
  }
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/true);
}

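// Converts a Python int to a rank-0 S32 or S64 device buffer, depending on
// whether options.squash_64bit_types is set. Unlike HandlePythonScalar, the
// value is range-checked against the squashed type directly (py::cast<int32>)
// rather than being narrowed after the fact.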
StatusOr<DevicePutResult> HandlePythonInt(py::handle obj, PjRtDevice* to_device,
                                          const DevicePutOptions& options) {
  void* ptr;
  PrimitiveType type;
  int64_t data_int64;
  int32_t data_int32;

  if (options.squash_64bit_types) {
    try {
      data_int32 = py::cast<int32>(obj);
    } catch (const std::exception& e) {
      return InvalidArgument(
          "Unable to convert Python scalar to %s. This most likely means the "
          "value (%s) overflows the range of the type.",
          PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int32>()),
          py::repr(obj));
    }
    ptr = &data_int32;
    type = S32;
  } else {
    try {
      data_int64 = py::cast<int64>(obj);
    } catch (const std::exception& e) {
      return InvalidArgument(
          "Unable to convert Python scalar to %s. This most likely means the "
          "value (%s) overflows the range of the type.",
          PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int64>()),
          py::repr(obj));
    }
    ptr = &data_int64;
    type = S64;
  }
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/true);
}

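// Converts a NumPy scalar (e.g. np.int32(1), np.float64(2.0)) to a rank-0
// device buffer. SquashedT, when it differs from T, is the narrower type used
// if options.squash_64bit_types is set. NumPy scalars are not weakly typed.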
template <typename T, typename SquashedT = T>
StatusOr<DevicePutResult> HandleNumpyScalar(py::handle h, PjRtDevice* to_device,
                                            const DevicePutOptions& options) {
  T data;
  SquashedT data_squashed;
  void* ptr;
  PrimitiveType type;
  if (std::is_same<T, bfloat16>()) {
    // For extension types, ScalarAsCtype returns a pointer to the data.
    PyArray_ScalarAsCtype(h.ptr(), &ptr);
    type = BF16;
  } else if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
    PyArray_ScalarAsCtype(h.ptr(), &data);
    ptr = &data;
    type = primitive_util::NativeToPrimitiveType<T>();
  } else {
    PyArray_ScalarAsCtype(h.ptr(), &data);
    data_squashed = static_cast<SquashedT>(data);
    ptr = &data_squashed;
    type = primitive_util::NativeToPrimitiveType<SquashedT>();
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer> buffer,
      to_device->client()->BufferFromHostBuffer(
          ptr, type, /*dims=*/{}, /*byte_strides=*/{},
          PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
          /*on_done_with_host_buffer=*/nullptr, to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

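// Transfers a NumPy ndarray to a device buffer. If options.squash_64bit_types
// is set, 64-bit arrays are first cast to their 32-bit counterparts. If
// options.allow_zero_copy is set, the transfer may alias the NumPy buffer,
// keeping a reference to the array alive until the runtime releases it.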
StatusOr<DevicePutResult> HandleNumpyArray(py::handle h, PjRtDevice* to_device,
                                           const DevicePutOptions& options) {
  py::array array = py::cast<py::array>(h);
  TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));

  PrimitiveType squashed_type;
  if (options.squash_64bit_types) {
    squashed_type = Squash64BitTypes(type);
    if (squashed_type != type) {
      TF_ASSIGN_OR_RETURN(py::dtype squashed_dtype,
                          PrimitiveTypeToDtype(squashed_type));
      array = py::reinterpret_steal<py::array>(PyArray_CastToType(
          reinterpret_cast<PyArrayObject*>(array.ptr()),
          reinterpret_cast<PyArray_Descr*>(squashed_dtype.release().ptr()),
          /*fortran=*/0));
    }
  } else {
    squashed_type = type;
  }

  absl::InlinedVector<int64, 4> dims(array.ndim());
  absl::InlinedVector<int64, 4> byte_strides(array.ndim());
  for (int i = 0; i < array.ndim(); ++i) {
    dims[i] = array.shape(i);
    byte_strides[i] = array.strides(i);
  }
  const void* data = array.data();
  PjRtClient::HostBufferSemantics host_buffer_semantics =
      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall;
  std::function<void()> on_done_with_host_buffer;
  if (options.allow_zero_copy) {
    std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
        GlobalPyRefManager()->ManageReference(std::move(array));
    on_done_with_host_buffer =
        [py_buffer_ref{
            std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
    host_buffer_semantics = PjRtClient::HostBufferSemantics::kZeroCopy;
  }
  TF_ASSIGN_OR_RETURN(
      auto buffer,
      to_device->client()->BufferFromHostBuffer(
          data, squashed_type, dims, byte_strides, host_buffer_semantics,
          std::move(on_done_with_host_buffer), to_device));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

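// Shared helper for PyBuffer and DeviceArray handling: reuses the existing
// PjRtBuffer if it already lives on `to_device`, otherwise copies it there.
// `obj` is the user-visible object (consulted for `aval.weak_type` when the
// buffer does not record a weak_type); `py_buffer` is the object owning
// `buffer`.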
StatusOr<DevicePutResult> PyBufferHelper(py::handle obj, py::handle py_buffer,
                                         PyBuffer* buffer,
                                         PjRtDevice* to_device) {
  bool weak_type = buffer->weak_type()
                       ? *buffer->weak_type()
                       : py::cast<bool>(obj.attr("aval").attr("weak_type"));
  if (buffer->buffer()->device() == to_device) {
    return DevicePutResult(
        buffer->buffer(), weak_type,
        /*owning_pybuffer=*/py::reinterpret_borrow<py::object>(py_buffer));
  } else {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copied_buffer,
                        buffer->buffer()->CopyToDevice(to_device));
    return DevicePutResult(std::move(copied_buffer), weak_type);
  }
}

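// Handles arguments that are already xla::PyBuffer objects (the fast path in
// DevicePut below).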
StatusOr<DevicePutResult> HandlePyBuffer(py::handle obj, PjRtDevice* to_device,
                                         const DevicePutOptions& options) {
  return PyBufferHelper(obj, obj, PyBuffer::AsPyBufferUnchecked(obj),
                        to_device);
}

StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
                                            PjRtDevice* to_device,
                                            const DevicePutOptions& options) {
  // Handle Python DeviceArray objects provided they have a .device_buffer
  // field. Otherwise, fall back to handling as a NumPy array, since we do not
  // understand how to get a buffer object out. For example, ShardedDeviceArray
  // in JAX is handled by this path.
  py::object buffer = py::getattr(obj, "device_buffer", py::none());
  if (buffer.is_none()) {
    return HandleNumpyArray(obj, to_device, options);
  }

  // Force buffers with a non-trivial lazy expression.
  py::object forced;
  if (!py::getattr(obj, "_lazy_expr").is_none()) {
    if (!options.force_lazy_arrays) {
      return InvalidArgument("Lazy arrays are not supported by device_put");
    }
    static py::function& force = *[]() {
      const auto xla_module = py::module::import("jax.interpreters.xla");
      return new py::function(
          py::cast<py::function>(xla_module.attr("_force")));
    }();
    forced = force(obj);
    buffer = forced.attr("device_buffer");
    obj = forced;
  }

  return PyBufferHelper(obj, buffer, py::cast<PyBuffer*>(buffer), to_device);
}

}  // namespace

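// Copies `arg` to `to_device`, dispatching on the exact Python type of `arg`
// (falling back to its MRO) through a lazily constructed table of handlers.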
StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
                                    const DevicePutOptions& options) {
  tensorflow::profiler::TraceMe traceme("DevicePut");
  static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
      [] {
        auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
        // Python scalar types.
        static_assert(sizeof(bool) == 1,
                      "Conversion code assumes bool is 1 byte");
        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] =
            HandlePythonScalar<bool, bool>;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandlePythonInt;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] =
            HandlePythonScalar<double, float>;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
            HandlePythonScalar<complex128, complex64>;

        // Generic subclasses of DeviceArray, e.g., ShardedDeviceArray.
        (*p)[PyBuffer::base_type()] = HandleDeviceArray;

        try {
          py::object xla_module = py::module::import("jax.interpreters.xla");
          py::object device_array =
              py::getattr(xla_module, "_DeviceArray", py::none());
          if (!device_array.is_none()) {
            (*p)[device_array.ptr()] = HandleDeviceArray;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        try {
          py::object pxla_module = py::module::import("jax.interpreters.pxla");
          py::object sda =
              py::getattr(pxla_module, "ShardedDeviceArray", py::none());
          if (!sda.is_none()) {
            (*p)[sda.ptr()] = HandleDeviceArray;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        const auto numpy = py::module::import("numpy");
        (*p)[numpy.attr("ndarray").ptr()] = HandleNumpyArray;

        // Numpy scalar types. For some of them, we share the handler with
        // Python types (np_int64, np_float64, np_complex128).
        (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar<bool>;
        (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar<int8_t>;
        (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar<int16_t>;
        (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar<int32_t>;
        (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
        (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar<uint8_t>;
        (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
        (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
        (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
        (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar<bfloat16>;
        (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar<half>;
        (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar<float>;
        (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar<double, float>;
        (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar<complex64>;
        (*p)[dtypes.np_complex128.ptr()] =
            HandleNumpyScalar<complex128, complex64>;
        static_assert(sizeof(long long) == sizeof(int64_t),  // NOLINT
                      "long long must be the same size as int64_t");
        (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
        static_assert(sizeof(int) == sizeof(int32_t),
                      "int must be the same size as int32_t");
        (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar<int32_t>;

        return p;
      }();

  // Fast-path for the most common case of PyBuffer.
  if (arg.get_type().ptr() == PyBuffer::type()) {
    return HandlePyBuffer(arg, to_device, options);
  }

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    for (auto base_class : arg.get_type().attr("mro")()) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, to_device, options);
      }
    }
    return InvalidArgument(
        "%s", absl::StrCat(
                  "Not supported: The C++ jax jit execution path only accepts "
                  "DeviceArray, NumPy arrays, scalars of supported types "
                  "(see implementation), or Python scalars. Got type ",
                  py::cast<std::string>(py::str(arg.get_type()))));
  }
  return res->second(arg, to_device, options);
}

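// Returns true if `arg`'s dtype is jax.dtypes.float0.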
bool IsFloat0(py::array arg) {
  static const auto* dtypes_module =
      new py::module(py::module::import("jax.dtypes"));
  static const auto* float0_dtype =
      new py::handle(dtypes_module->attr("float0"));
  return float0_dtype->is(arg.attr("dtype"));
}

std::string PyArgSignature::DebugString() const {
  std::string result = "";
  if (weak_type) {
    absl::StrAppend(&result, "weak_");
  }
  absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
  absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
  return result;
}

using ToPyArgSignatureHandler =
    std::function<StatusOr<PyArgSignature>(py::handle, bool)>;

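// Computes the (dtype, shape, weak_type) signature of `arg` without copying
// any data, dispatching on the exact Python type (falling back to the MRO)
// the same way DevicePut does. `jax_enable_x64` controls whether 64-bit
// Python/NumPy values keep their width or are narrowed to 32 bits.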
StatusOr<PyArgSignature> PyArgSignatureOfValue(pybind11::handle arg,
                                               bool jax_enable_x64) {
  static const absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>* const
      handlers = [] {
        auto p = new absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>();

        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();

        // The 4 Python native types.
        ToPyArgSignatureHandler bool_handler =
            [](py::handle, bool) -> StatusOr<PyArgSignature> {
          return PyArgSignature(PrimitiveType::PRED, {}, true);
        };
        ToPyArgSignatureHandler int_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // TODO(phawkins): we should consider checking for integer overflow.
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::S64, {}, true);
          } else {
            return PyArgSignature(PrimitiveType::S32, {}, true);
          }
        };
        ToPyArgSignatureHandler float_handler =
            [&dtypes](py::handle h,
                      bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Only Python native types have a true weak_type.
          bool weak_type = !py::isinstance(h, dtypes.np_float64);
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::F64, {}, weak_type);
          } else {
            return PyArgSignature(PrimitiveType::F32, {}, weak_type);
          }
        };
        ToPyArgSignatureHandler complex_handler =
            [&dtypes](py::handle h,
                      bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // Note that this branch is also taken for np.complex128:
          // isinstance(np.complex128(3), complex) returns True
          // isinstance(np.complex64(3), complex) returns False
          bool weak_type = !py::isinstance(h, dtypes.np_complex128);
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::C128, {}, weak_type);
          } else {
            return PyArgSignature(PrimitiveType::C64, {}, weak_type);
          }
        };

        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = bool_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = int_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;

        // The Buffer types except for fast-path PyBuffer.
        ToPyArgSignatureHandler device_array_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          py::handle aval = h.attr("aval");
          TF_ASSIGN_OR_RETURN(auto dtype,
                              DtypeToPrimitiveType(aval.attr("dtype")));
          return PyArgSignature(
              dtype, py::cast<std::vector<int64>>(aval.attr("shape")),
              py::cast<py::bool_>(aval.attr("weak_type")));
        };
        (*p)[PyBuffer::base_type()] = device_array_handler;

        try {
          py::object xla_module = py::module::import("jax.interpreters.xla");
          py::object device_array =
              py::getattr(xla_module, "_DeviceArray", py::none());
          if (!device_array.is_none()) {
            (*p)[device_array.ptr()] = device_array_handler;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        try {
          py::object pxla_module = py::module::import("jax.interpreters.pxla");
          py::object sda =
              py::getattr(pxla_module, "ShardedDeviceArray", py::none());
          if (!sda.is_none()) {
            (*p)[sda.ptr()] = device_array_handler;
          }
        } catch (const py::error_already_set& e) {
          // Ignore; jax may not be present.
        }

        ToPyArgSignatureHandler numpy_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          py::array numpy_array = py::cast<py::array>(h);
          TF_ASSIGN_OR_RETURN(PrimitiveType dtype,
                              DtypeToPrimitiveType(numpy_array.dtype()));
          if (!jax_enable_x64) {
            dtype = Squash64BitTypes(dtype);
          }
          // We use reinterpret_cast<> to defend against environments where
          // ssize_t may not be precisely the same type as int64_t, even if it
          // is the same size (long vs long long).
          static_assert(sizeof(int64_t) == sizeof(ssize_t),
                        "Code assumes ssize_t is the same as int64_t");
          return PyArgSignature(
              dtype,
              absl::MakeConstSpan(
                  reinterpret_cast<const int64_t*>(numpy_array.shape()),
                  numpy_array.ndim()),
              /*weak_type=*/false);
        };
        const auto numpy = py::module::import("numpy");
        (*p)[numpy.attr("ndarray").ptr()] = numpy_handler;

        ToPyArgSignatureHandler np_uint64_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
          } else {
            return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
          }
        };
        ToPyArgSignatureHandler np_int_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          if (jax_enable_x64) {
            return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
          } else {
            return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
          }
        };
        ToPyArgSignatureHandler numpy_array_handler =
            [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
          // This block deals with all numpy scalar types, except for int64_dt,
          // float64_dt and complex128_dt which are taken care of in previous
          // if blocks.
          TF_ASSIGN_OR_RETURN(auto dtype,
                              DtypeToPrimitiveType(h.attr("dtype")));
          return PyArgSignature(dtype, {}, /*weak_type=*/false);
        };

        // This block deals with all numpy scalar types, except for int64_dt,
        // float64_dt and complex128_dt which are taken care of in previous if
        // blocks.
        (*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_int64.ptr()] = np_int_handler;
        (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
        (*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_float64.ptr()] = float_handler;
        (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler;
        (*p)[dtypes.np_complex128.ptr()] = complex_handler;
        (*p)[dtypes.np_longlong.ptr()] = np_int_handler;
        (*p)[dtypes.np_intc.ptr()] = numpy_array_handler;

        return p;
      }();

  // Fast-path for the most common case of PyBuffer.
  if (arg.get_type().ptr() == PyBuffer::type()) {
    // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
    TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(arg));
    bool weak_type = buffer->weak_type().has_value()
                         ? *buffer->weak_type()
                         : py::cast<bool>(arg.attr("aval").attr("weak_type"));
    return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
                          buffer->buffer()->on_device_shape().dimensions(),
                          weak_type);
  }

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    // We attempt to look at the MRO classes
    for (auto base_class : arg.get_type().attr("mro")()) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, jax_enable_x64);
      }
    }
    return InvalidArgument(
        "%s",
        absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts "
                     "Buffer/DeviceArray/ShardedDeviceArray, NumPy arrays, "
                     "scalars of supported types "
                     "(see implementation), or Python scalars. Got type ",
                     py::cast<std::string>(py::str(arg.get_type()))));
  }
  return res->second(arg, jax_enable_x64);
}

}  // namespace xla