/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_buffer.h"

#include <functional>
#include <type_traits>

#include "absl/base/casts.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace py = pybind11;

namespace {

// Representation of a DeviceArrayBase as a Python object. Since
// a DeviceArrayBase has no fields, this is just a PyObject.
struct PyBufferBasePyObject {
  PyObject_HEAD;
};
static_assert(std::is_standard_layout<PyBufferBasePyObject>::value,
              "PyBufferBasePyObject must be standard layout");

// Representation of a DeviceArray as a Python object.
struct PyBufferPyObject {
  PyBufferBasePyObject base;
  PyBuffer buffer;
  // Used by the Python interpreter to maintain a list of weak references to
  // this object.
  PyObject* weakrefs;
};
static_assert(std::is_standard_layout<PyBufferPyObject>::value,
              "PyBufferPyObject must be standard layout");

PyObject* PyBuffer_tp_new(PyTypeObject* subtype, PyObject* args,
                          PyObject* kwds) {
  PyBufferPyObject* self =
      reinterpret_cast<PyBufferPyObject*>(subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}

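// Mirrors the two-phase construction performed by PyBuffer_tp_new and
// PyBuffer::Make(): tp_new leaves the `buffer` field unconstructed, so the
// destructor must be invoked explicitly here before the storage is freed.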
void PyBuffer_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  PyBufferPyObject* o = reinterpret_cast<PyBufferPyObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  o->buffer.~PyBuffer();
  tp->tp_free(self);
  Py_DECREF(tp);
}

}  // namespace

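// Two-phase construction: PyBuffer_tp_new above allocates raw storage for the
// PyBuffer field, and the placement new below constructs it in place.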
/*static*/ PyBuffer::object PyBuffer::Make(
    std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer,
    std::shared_ptr<Traceback> traceback) {
  py::object obj = py::reinterpret_steal<py::object>(PyBuffer_tp_new(
      reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr));
  PyBufferPyObject* buf = reinterpret_cast<PyBufferPyObject*>(obj.ptr());
  new (&buf->buffer)
      PyBuffer(std::move(client), std::move(buffer), std::move(traceback));
  return py::reinterpret_borrow<PyBuffer::object>(obj);
}

bool PyBuffer::IsPyBuffer(py::handle handle) {
  return handle.get_type() == PyBuffer::type();
}

/*static*/ PyBuffer* PyBuffer::AsPyBufferUnchecked(pybind11::handle handle) {
  return &(reinterpret_cast<PyBufferPyObject*>(handle.ptr())->buffer);
}

/*static*/ StatusOr<PyBuffer*> PyBuffer::AsPyBuffer(pybind11::handle handle) {
  if (!IsPyBuffer(handle)) {
    return InvalidArgument("Expected a DeviceArray");
  }
  return AsPyBufferUnchecked(handle);
}

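// Recovers the owning PyObject from a PyBuffer*; the inverse of
// AsPyBufferUnchecked(). Correct only because PyBufferPyObject is standard
// layout, as asserted above.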
py::handle PyBuffer::AsHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(PyBufferPyObject, buffer));
}

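// The constructor links this buffer into a per-device intrusive doubly-linked
// list owned by the PyClient, and the destructor unlinks it; this lets the
// client enumerate the live buffers on each device without extra allocations.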
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
                   std::shared_ptr<PjRtBuffer> buffer,
                   std::shared_ptr<Traceback> traceback)
    : client_(std::move(client)),
      buffer_(std::move(buffer)),
      traceback_(std::move(traceback)) {
  CHECK(PyGILState_Check());
  next_ = client_->buffers_[buffer_->device()->id()];
  client_->buffers_[buffer_->device()->id()] = this;
  prev_ = nullptr;
  if (next_) {
    next_->prev_ = this;
  }
}

PyBuffer::~PyBuffer() {
  CHECK(PyGILState_Check());
  if (client_->buffers_[device()->id()] == this) {
    client_->buffers_[device()->id()] = next_;
  }
  if (prev_) {
    prev_->next_ = next_;
  }
  if (next_) {
    next_->prev_ = prev_;
  }
}

StatusOr<int64> PyBuffer::size() {
  Shape max_buffer_shape = buffer()->on_device_shape();
  if (max_buffer_shape.is_dynamic()) {
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
    return ShapeUtil::ElementsIn(*dynamic_shape);
  }
  return ShapeUtil::ElementsIn(max_buffer_shape);
}

StatusOr<const Shape*> PyBuffer::xla_dynamic_shape() {
  CHECK(PyGILState_Check());
  if (buffer_->on_device_shape().is_static()) {
    return &buffer_->on_device_shape();
  }
  // The Python buffer protocol references shape data by pointer, so we must
  // store a valid copy of the shape.
  if (!dynamic_shape_) {
    Shape dynamic_shape;
    {
      py::gil_scoped_release gil_release;
      TF_ASSIGN_OR_RETURN(dynamic_shape, buffer_->logical_on_device_shape());
    }
    dynamic_shape_ = dynamic_shape;
  }
  return &dynamic_shape_.value();
}

pybind11::tuple PyBuffer::python_shape() const {
  return SpanToTuple(buffer()->on_device_shape().dimensions());
}

pybind11::dtype PyBuffer::python_dtype() const {
  PrimitiveType primitive = buffer()->on_device_shape().element_type();
  return PrimitiveTypeToDtype(primitive).ValueOrDie();
}

ClientAndPtr<PjRtDevice> PyBuffer::device() const {
  return WrapWithClient(client_, buffer_->device());
}

PyBuffer::object PyBuffer::Clone() const {
  auto buffer = Make(client_, buffer_, traceback_);
  buffer.buf()->sticky_device_ = sticky_device_;
  buffer.buf()->aval_ = aval_;
  return buffer;
}

StatusOr<py::object> PyBuffer::CopyToDevice(
    const ClientAndPtr<PjRtDevice>& dst_device) const {
  CHECK(dst_device.get() != nullptr);
  GlobalPyRefManager()->CollectGarbage();
  std::unique_ptr<PjRtBuffer> out;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(out, buffer_->CopyToDevice(dst_device.get()));
  }
  auto traceback = Traceback::Get();
  return Make(dst_device.client, std::move(out), std::move(traceback));
}

Status PyBuffer::BlockHostUntilReady() {
  GlobalPyRefManager()->CollectGarbage();
  py::gil_scoped_release gil_release;
  return buffer_->BlockHostUntilReady();
}

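// Starts a device-to-host copy into host_value_ unless one is already in
// flight. The ToLiteral() callback fires the `ready` notification that
// AsNumPyArray() below waits on.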
Status PyBuffer::CopyToHostAsync() {
  if (!buffer_->IsOnCpu() && !host_value_) {
    std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();
    host_value_ = host_value;
    // TODO(b/182461453): xla_dynamic_shape() is a blocking call. If we
    // instead populated the dynamic shape metadata while fetching the
    // literal, this blocking step would not be needed.
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());

    py::gil_scoped_release gil;
    host_value->value = std::make_shared<Literal>(
        ShapeUtil::DeviceShapeToHostShape(*dynamic_shape));
    Literal* literal = host_value->value.get();
    buffer_->ToLiteral(literal,
                       [host_value{std::move(host_value)}](Status status) {
                         host_value->status = std::move(status);
                         host_value->ready.Notify();
                       });
  }
  return Status::OK();
}

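// Converts the buffer to a NumPy array. On CPU the result aliases device
// memory zero-copy: the py::capsule base object keeps both the Python buffer
// object and an external reference to the PjRtBuffer alive for as long as
// NumPy uses the array. On other platforms the data is staged through
// host_value_.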
StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
  if (buffer_->IsDeleted()) {
    return InvalidArgument("DeviceArray has been deleted.");
  }
  TF_RET_CHECK(buffer_->on_device_shape().IsArray());
  // On CPU, we can return the value in a zero-copy way.
  if (buffer_->IsOnCpu()) {
    TF_ASSIGN_OR_RETURN(const auto* shape, xla_dynamic_shape());
    TF_ASSIGN_OR_RETURN(py::dtype dtype,
                        PrimitiveTypeToDtype(shape->element_type()));
    // Objects that must be kept alive while the array is alive.
    struct Hold {
      py::object buffer;
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
    };
    auto hold = std::make_unique<Hold>();
    TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
                        buffer_->AcquireExternalReference());
    hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
    void* data =
        hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
    py::capsule hold_capsule(hold.release(),
                             [](void* h) { delete static_cast<Hold*>(h); });
    py::array array(dtype, shape->dimensions(), ByteStridesForShape(*shape),
                    data, hold_capsule);
    array.attr("flags").attr("writeable") = Py_False;
    {
      py::gil_scoped_release gil;
      TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
    }
    return array;
  }

  TF_RETURN_IF_ERROR(CopyToHostAsync());
  if (!host_value_->ready.HasBeenNotified()) {
    py::gil_scoped_release gil;
    host_value_->ready.WaitForNotification();
  }
  TF_RETURN_IF_ERROR(host_value_->status);
  TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
  array.attr("flags").attr("writeable") = Py_False;
  return array;
}

StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
  return client_->pjrt_client()->UnsafeBufferPointer(buffer_.get());
}

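// Implements __cuda_array_interface__, version 2. For a 2x3 f32 buffer the
// resulting dict would look roughly like (illustrative values only):
//   {'shape': (2, 3), 'typestr': '<f4', 'data': (140234..., True),
//    'version': 2}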
StatusOr<py::dict> PyBuffer::CudaArrayInterface() {
  // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
  if (buffer_->client()->platform_id() != kGpuId) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
  }
  if (!buffer_->on_device_shape().IsArray()) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for array buffers.");
  }
  if (buffer_->on_device_shape().element_type() == BF16) {
    return InvalidArgument(
        "__cuda_array_interface__ is not supported for bfloat16 buffers.");
  }
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
      buffer_->on_device_shape().layout()));

  py::dict result;
  TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
  result["shape"] = SpanToTuple(dynamic_shape->dimensions());
  TF_ASSIGN_OR_RETURN(py::str typestr,
                      TypeDescriptorForPrimitiveType(
                          buffer_->on_device_shape().element_type()));
  result["typestr"] = std::move(typestr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
      buffer_->AcquireExternalReference());
  const void* root_ptr =
      external_reference_hold->OpaqueDeviceMemoryDataPointer();
  py::tuple data(2);
  data[0] = py::int_(absl::bit_cast<std::uintptr_t>(root_ptr));
  data[1] = py::bool_(true);  // read-only
  result["data"] = std::move(data);
  result["version"] = py::int_(2);
  return result;
}

// PEP 3118 buffer protocol implementation.

namespace {

// Extra data to be kept alive by the consumer of the buffer protocol.
struct ExtraBufferInfo {
  explicit ExtraBufferInfo(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold)
      : external_reference_hold(std::move(external_reference_hold)) {}

  std::string format;
  std::vector<Py_ssize_t> strides;
  // We keep an external reference hold to the PjRtBuffer. This prevents a
  // use-after-free in the event that Delete() is called on a buffer with a
  // live buffer protocol view. It does, however, mean that Delete() sometimes
  // won't actually delete immediately.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
};

int PyBuffer_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) {
  Status status = [&]() {
    TF_ASSIGN_OR_RETURN(PyBuffer * py_buffer, PyBuffer::AsPyBuffer(exporter));
    PjRtBuffer& buffer = *py_buffer->buffer();
    TF_ASSIGN_OR_RETURN(const auto* shape, py_buffer->xla_dynamic_shape());
    // Py_buffer objects are POD C structures, so we don't need to hold the
    // GIL. Additionally we call BlockHostUntilReady() below, which may block.
    py::gil_scoped_release gil_release;

    if (!buffer.IsOnCpu()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for CPU buffers.");
    }
    if (!buffer.on_device_shape().IsArray()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for array buffers.");
    }
    // If we allowed exports of formatted BF16 buffers, consumers would get
    // confused about the type because there is no way to describe BF16 to
    // Python.
    if (buffer.on_device_shape().element_type() == BF16 &&
        ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
      return InvalidArgument(
          "bfloat16 buffer format not supported by Python buffer protocol.");
    }
    if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
      return InvalidArgument("XLA buffers are read-only.");
    }
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
        buffer.AcquireExternalReference());
    if (buffer.IsDeleted()) {
      return InvalidArgument("Deleted buffer used in buffer protocol.");
    }

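    // Note (PEP 3118): a request with PyBUF_ND but without PyBUF_STRIDES
    // implies C-contiguous data, which is why the first branch also matches
    // (flags & PyBUF_STRIDES) == PyBUF_ND.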
    if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
         (flags & PyBUF_STRIDES) == PyBUF_ND) &&
        !LayoutUtil::IsMonotonicWithDim0Major(shape->layout())) {
      return InvalidArgument("Buffer is not in C-contiguous layout.");
    } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in F-contiguous layout.");
    } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Major(shape->layout()) &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in contiguous layout.");
    }
    std::memset(view, 0, sizeof(Py_buffer));
    const void* root_ptr =
        external_reference_hold->OpaqueDeviceMemoryDataPointer();
    view->buf = const_cast<void*>(root_ptr);
    auto extra =
        absl::make_unique<ExtraBufferInfo>(std::move(external_reference_hold));
    view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape->element_type());
    view->len = ShapeUtil::ByteSizeOf(*shape);
    view->readonly = 1;
    if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
      TF_ASSIGN_OR_RETURN(extra->format, FormatDescriptorForPrimitiveType(
                                             shape->element_type()));
      view->format = const_cast<char*>(extra->format.c_str());
    }
    if ((flags & PyBUF_ND) == PyBUF_ND) {
      view->ndim = shape->dimensions_size();
      static_assert(sizeof(int64) == sizeof(Py_ssize_t),
                    "Py_ssize_t must be 64 bits");
      if (view->ndim != 0) {
        view->shape = reinterpret_cast<Py_ssize_t*>(
            const_cast<int64*>(shape->dimensions().data()));
        if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
          extra->strides = ByteStridesForShape(*shape);
          view->strides = extra->strides.data();
        }
      }
    }
    TF_RETURN_IF_ERROR(buffer.BlockHostUntilReady());
    view->internal = extra.release();
    return Status::OK();
  }();
  if (!status.ok()) {
    // numpy.asarray(...) silences the PyExc_BufferError. Adding a log here
    // helps with debugging when the error actually occurs.
    VLOG(1) << "Buffer Protocol Error: " << status;
    PyErr_SetString(PyExc_BufferError, status.ToString().c_str());
    return -1;
  }
  view->obj = exporter;
  Py_INCREF(view->obj);
  return 0;
}

void PyBuffer_bf_releasebuffer(PyObject*, Py_buffer* buffer) {
  auto extra = static_cast<ExtraBufferInfo*>(buffer->internal);
  delete extra;
}

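// These hooks are what make `memoryview(device_array)` and zero-copy
// `np.asarray(device_array)` work for CPU buffers (a usage sketch; not
// exercised in this file).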
PyBufferProcs PyBuffer_tp_as_buffer = []() {
  PyBufferProcs procs;
  procs.bf_getbuffer = &PyBuffer_bf_getbuffer;
  procs.bf_releasebuffer = &PyBuffer_bf_releasebuffer;
  return procs;
}();

// Helpers for building Python properties
template <typename Func>
py::object property_readonly(Func&& get) {
  py::handle property(reinterpret_cast<PyObject*>(&PyProperty_Type));
  return property(py::cpp_function(std::forward<Func>(get)), py::none(),
                  py::none(), "");
}

template <typename GetFunc, typename SetFunc>
py::object property(GetFunc&& get, SetFunc&& set) {
  py::handle property(reinterpret_cast<PyObject*>(&PyProperty_Type));
  return property(py::cpp_function(std::forward<GetFunc>(get)),
                  py::cpp_function(std::forward<SetFunc>(set)), py::none(), "");
}

}  // namespace

PyObject* PyBuffer::base_type_ = nullptr;
PyObject* PyBuffer::type_ = nullptr;

Status PyBuffer::RegisterTypes(py::module& m) {
  // We do not use pybind11::class_ to build Python wrapper objects because
  // creation, destruction, and casting of buffer objects is performance
  // critical. By using hand-written Python classes, we can avoid extra C heap
  // allocations, and we can avoid pybind11's slow cast<>() implementation
  // during jit dispatch.

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  {
    py::str name = py::str("DeviceArrayBase");
    py::str qualname = py::str("DeviceArrayBase");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArrayBase";
    type->tp_basicsize = sizeof(PyBufferBasePyObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_BASETYPE;
    TF_RET_CHECK(PyType_Ready(type) == 0);
    base_type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object base_type = py::reinterpret_borrow<py::object>(base_type_);
  base_type.attr("__module__") = m.attr("__name__");

  m.attr("DeviceArrayBase") = base_type;
  {
    py::tuple bases = py::make_tuple(base_type);
    py::str name = py::str("DeviceArray");
    py::str qualname = py::str("DeviceArray");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called below. Otherwise the GC might see a
    // half-constructed type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArray";
    type->tp_basicsize = sizeof(PyBufferPyObject);
    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE;
    type->tp_bases = bases.release().ptr();
    type->tp_dealloc = PyBuffer_tp_dealloc;
    type->tp_new = PyBuffer_tp_new;
    // Supported protocols
    type->tp_as_number = &heap_type->as_number;
    type->tp_as_sequence = &heap_type->as_sequence;
    type->tp_as_mapping = &heap_type->as_mapping;
    type->tp_as_buffer = &PyBuffer_tp_as_buffer;

    // Allow weak references to DeviceArray objects.
    type->tp_weaklistoffset = offsetof(PyBufferPyObject, weakrefs);

    TF_RET_CHECK(PyType_Ready(type) == 0);
    type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object type = py::reinterpret_borrow<py::object>(type_);
  m.attr("DeviceArray") = type;
  m.attr("PyLocalBuffer") = type;
  m.attr("Buffer") = type;
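  // From Python, all three module attributes name the same class; e.g. (a
  // sketch, assuming the extension module is imported as `xla_extension`):
  //   xla_extension.DeviceArray is xla_extension.Buffer  # -> True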

  // Add methods and properties to the class. We use pybind11 and add methods
  // dynamically mostly because this is easy to write and allows us to use
  // pybind11's casting logic. This is most likely slightly slower than
  // hand-writing bindings, but most of these methods are not performance
  // critical.
  type.attr("__array_priority__") =
      property_readonly([](py::object self) -> int { return 100; });
  type.attr("_device") = property(
      [](PyBuffer::object self) -> ClientAndPtr<PjRtDevice> {
        return WrapWithClient(self.buf()->client(),
                              self.buf()->sticky_device());
      },
      [](PyBuffer::object self, PjRtDevice* sticky_device) {
        return self.buf()->set_sticky_device(sticky_device);
      });
  type.attr("aval") = property(
      [](PyBuffer::object self) -> py::object {
        return self.buf()->GetAval();
      },
      [](PyBuffer::object self, py::object aval) {
        return self.buf()->SetAval(std::move(aval));
      });
  type.attr("weak_type") = property(
      [](PyBuffer::object self) -> absl::optional<bool> {
        return self.buf()->weak_type();
      },
      [](PyBuffer::object self, absl::optional<bool> weak_type) {
        return self.buf()->set_weak_type(weak_type);
      });
  type.attr("_lazy_expr") =
      property_readonly([](py::handle self) { return py::none(); });
  type.attr("device_buffer") =
      property_readonly([](py::object self) { return self; });
  type.attr("shape") =
      property_readonly([](PyBuffer::object self) -> py::tuple {
        return SpanToTuple(
            self.buf()->buffer()->on_device_shape().dimensions());
      });
  type.attr("dtype") = property_readonly([](PyBuffer::object self) {
    PrimitiveType primitive =
        self.buf()->buffer()->on_device_shape().element_type();
    return PrimitiveTypeToDtype(primitive).ValueOrDie();
  });
  type.attr("size") =
      property_readonly([](PyBuffer::object self) -> StatusOr<int64_t> {
        return self.buf()->size();
      });
  type.attr("ndim") = property_readonly(
      [](PyBuffer::object self) -> int { return self.buf()->ndim(); });
  type.attr("_value") = property_readonly(
      [](PyBuffer::object self) -> StatusOr<pybind11::object> {
        GlobalPyRefManager()->CollectGarbage();
        return self.buf()->AsNumPyArray(self);
      });
  type.attr("copy_to_device") = py::cpp_function(
      [](PyBuffer::object self, const ClientAndPtr<PjRtDevice>& dst_device) {
        return self.buf()->CopyToDevice(dst_device);
      },
      py::is_method(type));
  type.attr("on_device_size_in_bytes") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<size_t> {
        return self.buf()->OnDeviceSizeInBytes();
      },
      py::is_method(type));
  type.attr("delete") = py::cpp_function(
      [](PyBuffer::object self) { self.buf()->Delete(); },
      py::is_method(type));
  type.attr("block_host_until_ready") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->BlockHostUntilReady(); },
      py::is_method(type));
  type.attr("block_until_ready") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<PyBuffer::object> {
        TF_RETURN_IF_ERROR(self.buf()->BlockHostUntilReady());
        return std::move(self);
      },
      py::is_method(type));
  type.attr("copy_to_host_async") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->CopyToHostAsync(); },
      py::is_method(type));
  type.attr("to_py") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->AsNumPyArray(self); },
      py::is_method(type));
  type.attr("xla_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->shape(); },
      py::is_method(type));
  type.attr("xla_dynamic_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->xla_dynamic_shape(); },
      py::is_method(type));
  type.attr("client") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->client(); });
  type.attr("device") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->device(); },
      py::is_method(type));
  type.attr("platform") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->platform_name(); },
      py::is_method(type));
  type.attr("is_deleted") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->is_deleted(); },
      py::is_method(type));
  type.attr("unsafe_buffer_pointer") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->UnsafeBufferPointer(); },
      py::is_method(type));
  type.attr("__cuda_array_interface__") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->CudaArrayInterface(); });
  type.attr("traceback") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->traceback(); });
  type.attr("clone") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->Clone(); },
      py::is_method(type));
  type.attr("__module__") = m.attr("__name__");
  return Status::OK();
}

}  // namespace xla