• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/py_values.h"
17 
18 #include "pybind11/pybind11.h"
19 #include "pybind11/pytypes.h"
20 #include "tensorflow/compiler/xla/primitive_util.h"
21 #include "tensorflow/compiler/xla/python/py_buffer.h"
22 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
23 #include "tensorflow/compiler/xla/python/types.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 #include "tensorflow/core/profiler/lib/traceme.h"
26 #include "tensorflow/python/lib/core/numpy.h"
27 
28 namespace py = pybind11;
29 
30 namespace xla {
31 
32 namespace {
33 
34 using DevicePutFunc = std::function<StatusOr<DevicePutResult>(
35     py::handle, PjRtDevice*, const DevicePutOptions& options)>;
36 
37 template <typename T, typename SquashedT>
HandlePythonScalar(py::handle obj,PjRtDevice * to_device,const DevicePutOptions & options)38 StatusOr<DevicePutResult> HandlePythonScalar(py::handle obj,
39                                              PjRtDevice* to_device,
40                                              const DevicePutOptions& options) {
41   T data;
42 
43   try {
44     data = py::cast<T>(obj);
45   } catch (const std::exception& e) {
46     return InvalidArgument(
47         "Unable to convert Python scalar to %s. This most likely means the "
48         "value (%s) overflows the range of the type.",
49         PrimitiveType_Name(primitive_util::NativeToPrimitiveType<T>()),
50         py::repr(obj));
51   }
52 
53   void* ptr;
54   SquashedT squashed_data;
55   Shape shape;
56   PrimitiveType type;
57   if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
58     ptr = &data;
59     type = primitive_util::NativeToPrimitiveType<T>();
60   } else {
61     // TODO(phawkins): we should check for overflow here, e.g., because of bugs
62     // like https://github.com/google/jax/issues/2006
63     squashed_data = static_cast<SquashedT>(data);
64     ptr = &squashed_data;
65     type = primitive_util::NativeToPrimitiveType<SquashedT>();
66   }
67   TF_ASSIGN_OR_RETURN(
68       auto buffer,
69       to_device->client()->BufferFromHostBuffer(
70           ptr, type, /*dims=*/{}, /*byte_strides=*/{},
71           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
72           /*on_done_with_host_buffer=*/nullptr, to_device));
73   return DevicePutResult(std::move(buffer), /*weak_type=*/true);
74 }
75 
HandlePythonInt(py::handle obj,PjRtDevice * to_device,const DevicePutOptions & options)76 StatusOr<DevicePutResult> HandlePythonInt(py::handle obj, PjRtDevice* to_device,
77                                           const DevicePutOptions& options) {
78   void* ptr;
79   PrimitiveType type;
80   int64_t data_int64;
81   int32_t data_int32;
82 
83   if (options.squash_64bit_types) {
84     try {
85       data_int32 = py::cast<int32>(obj);
86     } catch (const std::exception& e) {
87       return InvalidArgument(
88           "Unable to convert Python scalar to %s. This most likely means the "
89           "value (%s) overflows the range of the type.",
90           PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int32>()),
91           py::repr(obj));
92     }
93     ptr = &data_int32;
94     type = S32;
95   } else {
96     try {
97       data_int64 = py::cast<int64>(obj);
98     } catch (const std::exception& e) {
99       return InvalidArgument(
100           "Unable to convert Python scalar to %s. This most likely means the "
101           "value (%s) overflows the range of the type.",
102           PrimitiveType_Name(primitive_util::NativeToPrimitiveType<int64>()),
103           py::repr(obj));
104     }
105     ptr = &data_int64;
106     type = S64;
107   }
108   TF_ASSIGN_OR_RETURN(
109       auto buffer,
110       to_device->client()->BufferFromHostBuffer(
111           ptr, type, /*dims=*/{}, /*byte_strides=*/{},
112           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
113           /*on_done_with_host_buffer=*/nullptr, to_device));
114   return DevicePutResult(std::move(buffer), /*weak_type=*/true);
115 }
116 
117 template <typename T, typename SquashedT = T>
HandleNumpyScalar(py::handle h,PjRtDevice * to_device,const DevicePutOptions & options)118 StatusOr<DevicePutResult> HandleNumpyScalar(py::handle h, PjRtDevice* to_device,
119                                             const DevicePutOptions& options) {
120   T data;
121   SquashedT data_squashed;
122   void* ptr;
123   PrimitiveType type;
124   if (std::is_same<T, bfloat16>()) {
125     // For extension types, ScalarAsCtype returns a pointer to the data.
126     PyArray_ScalarAsCtype(h.ptr(), &ptr);
127     type = BF16;
128   } else if (std::is_same<T, SquashedT>() || !options.squash_64bit_types) {
129     PyArray_ScalarAsCtype(h.ptr(), &data);
130     ptr = &data;
131     type = primitive_util::NativeToPrimitiveType<T>();
132   } else {
133     PyArray_ScalarAsCtype(h.ptr(), &data);
134     data_squashed = static_cast<SquashedT>(data);
135     ptr = &data_squashed;
136     type = primitive_util::NativeToPrimitiveType<SquashedT>();
137   }
138   TF_ASSIGN_OR_RETURN(
139       std::unique_ptr<PjRtBuffer> buffer,
140       to_device->client()->BufferFromHostBuffer(
141           ptr, type, /*dims=*/{}, /*byte_strides=*/{},
142           PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
143           /*on_done_with_host_buffer=*/nullptr, to_device));
144   return DevicePutResult(std::move(buffer), /*weak_type=*/false);
145 }
146 
HandleNumpyArray(py::handle h,PjRtDevice * to_device,const DevicePutOptions & options)147 StatusOr<DevicePutResult> HandleNumpyArray(py::handle h, PjRtDevice* to_device,
148                                            const DevicePutOptions& options) {
149   py::array array = py::cast<py::array>(h);
150   TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype()));
151 
152   PrimitiveType squashed_type;
153   if (options.squash_64bit_types) {
154     squashed_type = Squash64BitTypes(type);
155     if (squashed_type != type) {
156       TF_ASSIGN_OR_RETURN(py::dtype squashed_dtype,
157                           PrimitiveTypeToDtype(squashed_type));
158       array = py::reinterpret_steal<py::array>(PyArray_CastToType(
159           reinterpret_cast<PyArrayObject*>(array.ptr()),
160           reinterpret_cast<PyArray_Descr*>(squashed_dtype.release().ptr()),
161           /*fortran=*/0));
162     }
163   } else {
164     squashed_type = type;
165   }
166 
167   absl::InlinedVector<int64, 4> dims(array.ndim());
168   absl::InlinedVector<int64, 4> byte_strides(array.ndim());
169   for (int i = 0; i < array.ndim(); ++i) {
170     dims[i] = array.shape(i);
171     byte_strides[i] = array.strides(i);
172   }
173   const void* data = array.data();
174   PjRtClient::HostBufferSemantics host_buffer_semantics =
175       PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall;
176   std::function<void()> on_done_with_host_buffer;
177   if (options.allow_zero_copy) {
178     std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
179         GlobalPyRefManager()->ManageReference(std::move(array));
180     on_done_with_host_buffer =
181         [py_buffer_ref{
182             std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
183     host_buffer_semantics = PjRtClient::HostBufferSemantics::kZeroCopy;
184   }
185   TF_ASSIGN_OR_RETURN(
186       auto buffer,
187       to_device->client()->BufferFromHostBuffer(
188           data, squashed_type, dims, byte_strides, host_buffer_semantics,
189           std::move(on_done_with_host_buffer), to_device));
190   return DevicePutResult(std::move(buffer), /*weak_type=*/false);
191 }
192 
PyBufferHelper(py::handle obj,py::handle py_buffer,PyBuffer * buffer,PjRtDevice * to_device)193 StatusOr<DevicePutResult> PyBufferHelper(py::handle obj, py::handle py_buffer,
194                                          PyBuffer* buffer,
195                                          PjRtDevice* to_device) {
196   bool weak_type = buffer->weak_type()
197                        ? *buffer->weak_type()
198                        : py::cast<bool>(obj.attr("aval").attr("weak_type"));
199   if (buffer->buffer()->device() == to_device) {
200     return DevicePutResult(
201         buffer->buffer(), weak_type,
202         /*owning_pybuffer=*/py::reinterpret_borrow<py::object>(py_buffer));
203   } else {
204     TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> copied_buffer,
205                         buffer->buffer()->CopyToDevice(to_device));
206     return DevicePutResult(std::move(copied_buffer), weak_type);
207   }
208 }
209 
HandlePyBuffer(py::handle obj,PjRtDevice * to_device,const DevicePutOptions & options)210 StatusOr<DevicePutResult> HandlePyBuffer(py::handle obj, PjRtDevice* to_device,
211                                          const DevicePutOptions& options) {
212   return PyBufferHelper(obj, obj, PyBuffer::AsPyBufferUnchecked(obj),
213                         to_device);
214 }
215 
HandleDeviceArray(py::handle obj,PjRtDevice * to_device,const DevicePutOptions & options)216 StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
217                                             PjRtDevice* to_device,
218                                             const DevicePutOptions& options) {
219   // Handle Python DeviceArray objects provided they have a .device_buffer field
220   // Otherwise, fallback to handling as a NumPy array, since we do not
221   // understand how to get a buffer object out. For example, ShardedDeviceArray
222   // in JAX is handled by this path.
223   py::object buffer = py::getattr(obj, "device_buffer", py::none());
224   if (buffer.is_none()) {
225     return HandleNumpyArray(obj, to_device, options);
226   }
227 
228   // Force buffers with a non-trivial lazy expression.
229   py::object forced;
230   if (!py::getattr(obj, "_lazy_expr").is_none()) {
231     if (!options.force_lazy_arrays) {
232       return InvalidArgument("Lazy arrays are not supported by device_put");
233     }
234     static py::function& force = *[]() {
235       const auto xla_module = py::module::import("jax.interpreters.xla");
236       return new py::function(
237           py::cast<py::function>(xla_module.attr("_force")));
238     }();
239     forced = force(obj);
240     buffer = forced.attr("device_buffer");
241     obj = forced;
242   }
243 
244   return PyBufferHelper(obj, buffer, py::cast<PyBuffer*>(buffer), to_device);
245 }
246 
247 }  // namespace
248 
DevicePut(pybind11::handle arg,PjRtDevice * to_device,const DevicePutOptions & options)249 StatusOr<DevicePutResult> DevicePut(pybind11::handle arg, PjRtDevice* to_device,
250                                     const DevicePutOptions& options) {
251   tensorflow::profiler::TraceMe traceme("DevicePut");
252   static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
253       [] {
254         auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();
255         const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
256         // Python scalar types.
257         static_assert(sizeof(bool) == 1,
258                       "Conversion code assumes bool is 1 byte");
259         (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] =
260             HandlePythonScalar<bool, bool>;
261         (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandlePythonInt;
262         (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] =
263             HandlePythonScalar<double, float>;
264         (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
265             HandlePythonScalar<complex128, complex64>;
266 
267         // Generic subclasses of DeviceArray, e.g., ShardedDeviceArray.
268         (*p)[PyBuffer::base_type()] = HandleDeviceArray;
269 
270         try {
271           py::object xla_module = py::module::import("jax.interpreters.xla");
272           py::object device_array =
273               py::getattr(xla_module, "_DeviceArray", py::none());
274           if (!device_array.is_none()) {
275             (*p)[device_array.ptr()] = HandleDeviceArray;
276           }
277         } catch (const py::error_already_set& e) {
278           // Ignore; jax may not be present.
279         }
280 
281         try {
282           py::object pxla_module = py::module::import("jax.interpreters.pxla");
283           py::object sda =
284               py::getattr(pxla_module, "ShardedDeviceArray", py::none());
285           if (!sda.is_none()) {
286             (*p)[sda.ptr()] = HandleDeviceArray;
287           }
288         } catch (const py::error_already_set& e) {
289           // Ignore; jax may not be present.
290         }
291 
292         const auto numpy = py::module::import("numpy");
293         (*p)[numpy.attr("ndarray").ptr()] = HandleNumpyArray;
294 
295         // Numpy scalar types. For some of them, we share the handler with
296         // Python types (np_int64, np_float64, np_complex128).
297         (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar<bool>;
298         (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar<int8_t>;
299         (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar<int16_t>;
300         (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar<int32_t>;
301         (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
302         (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar<uint8_t>;
303         (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar<uint16_t>;
304         (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar<uint32_t>;
305         (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar<uint64_t, uint32_t>;
306         (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar<bfloat16>;
307         (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar<half>;
308         (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar<float>;
309         (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar<double, float>;
310         (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar<complex64>;
311         (*p)[dtypes.np_complex128.ptr()] =
312             HandleNumpyScalar<complex128, complex64>;
313         static_assert(sizeof(long long) == sizeof(int64_t),  // NOLINT
314                       "long long must be the same size as int64_t");
315         (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar<int64_t, int32_t>;
316         static_assert(sizeof(int) == sizeof(int32_t),
317                       "int must be the same size as int32_t");
318         (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar<int32_t>;
319 
320         return p;
321       }();
322 
323   // Fast-path for the most common case of PyBuffer.
324   if (arg.get_type().ptr() == PyBuffer::type()) {
325     return HandlePyBuffer(arg, to_device, options);
326   }
327 
328   auto res = handlers->find(arg.get_type().ptr());
329   if (res == handlers->end()) {
330     for (auto base_class : arg.get_type().attr("mro")()) {
331       res = handlers->find(base_class.ptr());
332       if (res != handlers->end()) {
333         return res->second(arg, to_device, options);
334       }
335     }
336     return InvalidArgument(
337         "%s", absl::StrCat(
338                   "Not supported: The C++ jax jit execution path, only accepts "
339                   "DeviceArray, Numpy arrays scalars of supported types "
340                   "(see implementation), or Python scalars. Got type ",
341                   py::cast<std::string>(py::str(arg.get_type()))));
342   }
343   return res->second(arg, to_device, options);
344 }
345 
IsFloat0(py::array arg)346 bool IsFloat0(py::array arg) {
347   static const auto* dtypes_module =
348       new py::module(py::module::import("jax.dtypes"));
349   static const auto* float0_dtype =
350       new py::handle(dtypes_module->attr("float0"));
351   return float0_dtype->is(arg.attr("dtype"));
352 }
353 
DebugString() const354 std::string PyArgSignature::DebugString() const {
355   std::string result = "";
356   if (weak_type) {
357     absl::StrAppend(&result, "weak_");
358   }
359   absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
360   absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
361   return result;
362 }
363 
364 using ToPyArgSignatureHandler =
365     std::function<StatusOr<PyArgSignature>(py::handle, bool)>;
366 
PyArgSignatureOfValue(pybind11::handle arg,bool jax_enable_x64)367 StatusOr<PyArgSignature> PyArgSignatureOfValue(pybind11::handle arg,
368                                                bool jax_enable_x64) {
369   static const absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>* const
370       handlers = [] {
371         auto p = new absl::flat_hash_map<PyObject*, ToPyArgSignatureHandler>();
372 
373         const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
374 
375         // The 4 Python native types.
376         ToPyArgSignatureHandler bool_handler =
377             [](py::handle, bool) -> StatusOr<PyArgSignature> {
378           return PyArgSignature(PrimitiveType::PRED, {}, true);
379         };
380         ToPyArgSignatureHandler int_handler =
381             [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
382           // TODO(phawkins): we should consider checking for integer overflow.
383           if (jax_enable_x64) {
384             return PyArgSignature(PrimitiveType::S64, {}, true);
385           } else {
386             return PyArgSignature(PrimitiveType::S32, {}, true);
387           }
388         };
389         ToPyArgSignatureHandler float_handler =
390             [&dtypes](py::handle h,
391                       bool jax_enable_x64) -> StatusOr<PyArgSignature> {
392           // Only Python native types has a True weak_type.
393           bool weak_type = !py::isinstance(h, dtypes.np_float64);
394           if (jax_enable_x64) {
395             return PyArgSignature(PrimitiveType::F64, {}, weak_type);
396           } else {
397             return PyArgSignature(PrimitiveType::F32, {}, weak_type);
398           }
399         };
400         ToPyArgSignatureHandler complex_handler =
401             [&dtypes](py::handle h,
402                       bool jax_enable_x64) -> StatusOr<PyArgSignature> {
403           // Note that this branch is also taken  for np.complex128:
404           // isinstance(np.complex128(3), complex) returns True
405           // isinstance(np.complex64(3), complex) returns False
406           bool weak_type = !py::isinstance(h, dtypes.np_complex128);
407           if (jax_enable_x64) {
408             return PyArgSignature(PrimitiveType::C128, {}, weak_type);
409           } else {
410             return PyArgSignature(PrimitiveType::C64, {}, weak_type);
411           }
412         };
413 
414         (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = bool_handler;
415         (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = int_handler;
416         (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
417         (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;
418 
419         // The Buffer types except for fast-path PyBuffer.
420         ToPyArgSignatureHandler device_array_handler =
421             [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
422           py::handle aval = h.attr("aval");
423           TF_ASSIGN_OR_RETURN(auto dtype,
424                               DtypeToPrimitiveType(aval.attr("dtype")));
425           return PyArgSignature(
426               dtype, py::cast<std::vector<int64>>(aval.attr("shape")),
427               py::cast<py::bool_>(aval.attr("weak_type")));
428         };
429         (*p)[PyBuffer::base_type()] = device_array_handler;
430 
431         try {
432           py::object xla_module = py::module::import("jax.interpreters.xla");
433           py::object device_array =
434               py::getattr(xla_module, "_DeviceArray", py::none());
435           if (!device_array.is_none()) {
436             (*p)[device_array.ptr()] = device_array_handler;
437           }
438         } catch (const py::error_already_set& e) {
439           // Ignore; jax may not be present.
440         }
441 
442         try {
443           py::object pxla_module = py::module::import("jax.interpreters.pxla");
444           py::object sda =
445               py::getattr(pxla_module, "ShardedDeviceArray", py::none());
446           if (!sda.is_none()) {
447             (*p)[sda.ptr()] = device_array_handler;
448           }
449         } catch (const py::error_already_set& e) {
450           // Ignore; jax may not be present.
451         }
452 
453         ToPyArgSignatureHandler numpy_handler =
454             [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
455           py::array numpy_array = py::cast<py::array>(h);
456           TF_ASSIGN_OR_RETURN(PrimitiveType dtype,
457                               DtypeToPrimitiveType(numpy_array.dtype()));
458           if (!jax_enable_x64) {
459             dtype = Squash64BitTypes(dtype);
460           }
461           // We use reinterpret_cast<> to defend against environments where
462           // ssize_t may not be precisely the same type as int64_t, even if it
463           // is the same size (long vs long long).
464           static_assert(sizeof(int64_t) == sizeof(ssize_t),
465                         "Code assumes ssize_t is the same as int64_t");
466           return PyArgSignature(
467               dtype,
468               absl::MakeConstSpan(
469                   reinterpret_cast<const int64_t*>(numpy_array.shape()),
470                   numpy_array.ndim()),
471               /*weak_type=*/false);
472         };
473         const auto numpy = py::module::import("numpy");
474         (*p)[numpy.attr("ndarray").ptr()] = numpy_handler;
475 
476         ToPyArgSignatureHandler np_uint64_handler =
477             [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
478           if (jax_enable_x64) {
479             return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false);
480           } else {
481             return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false);
482           }
483         };
484         ToPyArgSignatureHandler np_int_handler =
485             [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
486           if (jax_enable_x64) {
487             return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false);
488           } else {
489             return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false);
490           }
491         };
492         ToPyArgSignatureHandler numpy_array_handler =
493             [](py::handle h, bool jax_enable_x64) -> StatusOr<PyArgSignature> {
494           // This block deals with all numpy scalar types, except for int64_dt,
495           // float64_dt and complex128_dt which are taken care of in previous if
496           // blocks.
497           TF_ASSIGN_OR_RETURN(auto dtype,
498                               DtypeToPrimitiveType(h.attr("dtype")));
499           return PyArgSignature(dtype, {}, /*weak_type=*/false);
500         };
501 
502         // This block deals with all numpy scalar types, except for int64_dt,
503         // float64_dt and complex128_dt which are taken care of in previous if
504         // blocks.
505         (*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
506         (*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
507         (*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
508         (*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
509         (*p)[dtypes.np_int64.ptr()] = np_int_handler;
510         (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
511         (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
512         (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
513         (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
514         (*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
515         (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler;
516         (*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
517         (*p)[dtypes.np_float64.ptr()] = float_handler;
518         (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler;
519         (*p)[dtypes.np_complex128.ptr()] = complex_handler;
520         (*p)[dtypes.np_longlong.ptr()] = np_int_handler;
521         (*p)[dtypes.np_intc.ptr()] = numpy_array_handler;
522 
523         return p;
524       }();
525 
526   // Fast-path for the most common case of PyBuffer.
527   if (arg.get_type().ptr() == PyBuffer::type()) {
528     // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
529     TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(arg));
530     bool weak_type = buffer->weak_type().has_value()
531                          ? *buffer->weak_type()
532                          : py::cast<bool>(arg.attr("aval").attr("weak_type"));
533     return PyArgSignature(buffer->buffer()->on_device_shape().element_type(),
534                           buffer->buffer()->on_device_shape().dimensions(),
535                           weak_type);
536   }
537 
538   auto res = handlers->find(arg.get_type().ptr());
539   if (res == handlers->end()) {
540     // We attempt to look at the MRO classes
541     for (auto base_class : arg.get_type().attr("mro")()) {
542       res = handlers->find(base_class.ptr());
543       if (res != handlers->end()) {
544         return res->second(arg, jax_enable_x64);
545       }
546     }
547     return InvalidArgument(
548         "%s",
549         absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts "
550                      "Buffer/DeviceArray/ShardedDeviceArray, Numpy "
551                      "arrays scalars of supported types "
552                      "(see implementation), or Python scalars. Got type ",
553                      py::cast<std::string>(py::str(arg.get_type()))));
554   }
555   return res->second(arg, jax_enable_x64);
556 }
557 
558 }  // namespace xla
559