/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_buffer.h"

#include <functional>
#include <type_traits>

#include "absl/base/casts.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace py = pybind11;

namespace {

// Representation of a DeviceArrayBase as a Python object. Since
// a DeviceArrayBase has no fields, this is just a PyObject.
struct PyBufferBasePyObject {
  PyObject_HEAD;
};
static_assert(std::is_standard_layout<PyBufferBasePyObject>::value,
              "PyBufferBasePyObject must be standard layout");

// Representation of a DeviceArray as a Python object.
struct PyBufferPyObject {
  PyBufferBasePyObject base;
  PyBuffer buffer;
  // Used by the Python interpreter to maintain a list of weak references to
  // this object.
  PyObject* weakrefs;
};
static_assert(std::is_standard_layout<PyBufferPyObject>::value,
              "PyBufferPyObject must be standard layout");
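
// Standard layout matters here: it guarantees that
// offsetof(PyBufferPyObject, buffer) is well defined, which
// PyBuffer::AsHandle() below relies on to recover the enclosing PyObject
// from a PyBuffer*.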

PyObject* PyBuffer_tp_new(PyTypeObject* subtype, PyObject* args,
                          PyObject* kwds) {
  PyBufferPyObject* self =
      reinterpret_cast<PyBufferPyObject*>(subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}

void PyBuffer_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  PyBufferPyObject* o = reinterpret_cast<PyBufferPyObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  o->buffer.~PyBuffer();
  tp->tp_free(self);
  // Instances of heap-allocated types hold a strong reference to their type;
  // release it now that the instance is gone.
  Py_DECREF(tp);
}

}  // namespace

/*static*/ PyBuffer::object PyBuffer::Make(
    std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer,
    std::shared_ptr<Traceback> traceback) {
  py::object obj = py::reinterpret_steal<py::object>(PyBuffer_tp_new(
      reinterpret_cast<PyTypeObject*>(type_), nullptr, nullptr));
  PyBufferPyObject* buf = reinterpret_cast<PyBufferPyObject*>(obj.ptr());
  // Placement-construct the C++ PyBuffer inside the Python object.
  new (&buf->buffer)
      PyBuffer(std::move(client), std::move(buffer), std::move(traceback));
  return py::reinterpret_borrow<PyBuffer::object>(obj);
}

bool PyBuffer::IsPyBuffer(py::handle handle) {
  return handle.get_type() == PyBuffer::type();
}

/*static*/ PyBuffer* PyBuffer::AsPyBufferUnchecked(pybind11::handle handle) {
  return &(reinterpret_cast<PyBufferPyObject*>(handle.ptr())->buffer);
}

/*static*/ StatusOr<PyBuffer*> PyBuffer::AsPyBuffer(pybind11::handle handle) {
  if (!IsPyBuffer(handle)) {
    return InvalidArgument("Expected a DeviceArray");
  }
  return AsPyBufferUnchecked(handle);
}

py::handle PyBuffer::AsHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(PyBufferPyObject, buffer));
}
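
// Illustrative round trip enabled by the offsetof arithmetic above (a
// sketch; `obj` is a hypothetical DeviceArray handle):
//   PyBuffer* buf = PyBuffer::AsPyBufferUnchecked(obj);
//   CHECK(buf->AsHandle().ptr() == obj.ptr());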

PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
                   std::shared_ptr<PjRtBuffer> buffer,
                   std::shared_ptr<Traceback> traceback)
    : client_(std::move(client)),
      buffer_(std::move(buffer)),
      traceback_(std::move(traceback)) {
  CHECK(PyGILState_Check());
  // Insert ourselves at the head of the client's per-device doubly-linked
  // list of live buffers.
  next_ = client_->buffers_[buffer_->device()->id()];
  client_->buffers_[buffer_->device()->id()] = this;
  prev_ = nullptr;
  if (next_) {
    next_->prev_ = this;
  }
}

PyBuffer::~PyBuffer() {
  CHECK(PyGILState_Check());
  // Unlink ourselves from the client's per-device buffer list.
  if (client_->buffers_[device()->id()] == this) {
    client_->buffers_[device()->id()] = next_;
  }
  if (prev_) {
    prev_->next_ = next_;
  }
  if (next_) {
    next_->prev_ = prev_;
  }
}

StatusOr<int64> PyBuffer::size() {
  Shape max_buffer_shape = buffer()->on_device_shape();
  if (max_buffer_shape.is_dynamic()) {
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
    return ShapeUtil::ElementsIn(*dynamic_shape);
  }
  return ShapeUtil::ElementsIn(max_buffer_shape);
}

StatusOr<const Shape*> PyBuffer::xla_dynamic_shape() {
  CHECK(PyGILState_Check());
  if (buffer_->on_device_shape().is_static()) {
    return &buffer_->on_device_shape();
  }
  // Python buffer protocol references shape data by pointer, therefore we must
  // store a valid copy of the shape.
  if (!dynamic_shape_) {
    Shape dynamic_shape;
    {
      py::gil_scoped_release gil_release;
      TF_ASSIGN_OR_RETURN(dynamic_shape, buffer_->logical_on_device_shape());
    }
    dynamic_shape_ = dynamic_shape;
  }
  return &dynamic_shape_.value();
}
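
// For example (illustrative): a buffer with the bounded dynamic shape
// f32[<=10] might have logical on-device shape f32[6]; size() then reports
// 6 elements rather than the 10-element bound.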

pybind11::tuple PyBuffer::python_shape() const {
  return SpanToTuple(buffer()->on_device_shape().dimensions());
}

pybind11::dtype PyBuffer::python_dtype() const {
  PrimitiveType primitive = buffer()->on_device_shape().element_type();
  return PrimitiveTypeToDtype(primitive).ValueOrDie();
}

ClientAndPtr<PjRtDevice> PyBuffer::device() const {
  return WrapWithClient(client_, buffer_->device());
}

PyBuffer::object PyBuffer::Clone() const {
  auto buffer = Make(client_, buffer_, traceback_);
  buffer.buf()->sticky_device_ = sticky_device_;
  buffer.buf()->aval_ = aval_;
  return buffer;
}

StatusOr<py::object> PyBuffer::CopyToDevice(
    const ClientAndPtr<PjRtDevice>& dst_device) const {
  CHECK(dst_device.get() != nullptr);
  GlobalPyRefManager()->CollectGarbage();
  std::unique_ptr<PjRtBuffer> out;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(out, buffer_->CopyToDevice(dst_device.get()));
  }
  auto traceback = Traceback::Get();
  return Make(dst_device.client, std::move(out), std::move(traceback));
}

Status PyBuffer::BlockHostUntilReady() {
  GlobalPyRefManager()->CollectGarbage();
  py::gil_scoped_release gil_release;
  return buffer_->BlockHostUntilReady();
}

Status PyBuffer::CopyToHostAsync() {
  if (!buffer_->IsOnCpu() && !host_value_) {
    std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();
    host_value_ = host_value;
    // TODO(b/182461453): This is a blocking call. If we further implemented
    // populating dynamic shape metadata while fetching the literal, we wouldn't
    // need this static approach.
    TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());

    py::gil_scoped_release gil;
    host_value->value = std::make_shared<Literal>(
        ShapeUtil::DeviceShapeToHostShape(*dynamic_shape));
    Literal* literal = host_value->value.get();
    buffer_->ToLiteral(literal,
                       [host_value{std::move(host_value)}](Status status) {
                         host_value->status = std::move(status);
                         host_value->ready.Notify();
                       });
  }
  return Status::OK();
}
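
// Typical Python-side sequence this enables (illustrative):
//   buf.copy_to_host_async()  # start the device-to-host transfer early
//   np_value = buf.to_py()    # later; waits only for the transfer to finish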

StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
  if (buffer_->IsDeleted()) {
    return InvalidArgument("DeviceArray has been deleted.");
  }
  TF_RET_CHECK(buffer_->on_device_shape().IsArray());
  // On CPU, we can return the value in a zero-copy way.
  if (buffer_->IsOnCpu()) {
    TF_ASSIGN_OR_RETURN(const auto* shape, xla_dynamic_shape());
    TF_ASSIGN_OR_RETURN(py::dtype dtype,
                        PrimitiveTypeToDtype(shape->element_type()));
    // Objects that must be kept alive while the array is alive.
    struct Hold {
      py::object buffer;
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
    };
    auto hold = std::make_unique<Hold>();
    TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
                        buffer_->AcquireExternalReference());
    hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
    void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
    py::capsule hold_capsule(hold.release(),
                             [](void* h) { delete static_cast<Hold*>(h); });
    py::array array(dtype, shape->dimensions(), ByteStridesForShape(*shape),
                    data, hold_capsule);
    array.attr("flags").attr("writeable") = Py_False;
    {
      py::gil_scoped_release gil;
      TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
    }
    return array;
  }

  TF_RETURN_IF_ERROR(CopyToHostAsync());
  if (!host_value_->ready.HasBeenNotified()) {
    py::gil_scoped_release gil;
    host_value_->ready.WaitForNotification();
  }
  TF_RETURN_IF_ERROR(host_value_->status);
  TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
  array.attr("flags").attr("writeable") = Py_False;
  return array;
}

StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
  return client_->pjrt_client()->UnsafeBufferPointer(buffer_.get());
}

StatusOr<py::dict> PyBuffer::CudaArrayInterface() {
  // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
  if (buffer_->client()->platform_id() != kGpuId) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
  }
  if (!buffer_->on_device_shape().IsArray()) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for array buffers.");
  }
  if (buffer_->on_device_shape().element_type() == BF16) {
    return InvalidArgument(
        "__cuda_array_interface__ is not supported for bfloat16 buffers.");
  }
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
      buffer_->on_device_shape().layout()));

  py::dict result;
  TF_ASSIGN_OR_RETURN(const auto* dynamic_shape, xla_dynamic_shape());
  result["shape"] = SpanToTuple(dynamic_shape->dimensions());
  TF_ASSIGN_OR_RETURN(py::str typestr,
                      TypeDescriptorForPrimitiveType(
                          buffer_->on_device_shape().element_type()));
  result["typestr"] = std::move(typestr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
      buffer_->AcquireExternalReference());
  const void* root_ptr =
      external_reference_hold->OpaqueDeviceMemoryDataPointer();
  py::tuple data(2);
  data[0] = py::int_(absl::bit_cast<std::uintptr_t>(root_ptr));
  data[1] = py::bool_(true);  // read-only
  result["data"] = std::move(data);
  result["version"] = py::int_(2);
  return result;
}
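
// The resulting dict resembles the following (values illustrative; '<f4' is
// the standard typestr for little-endian float32):
//   {'shape': (2, 3), 'typestr': '<f4',
//    'data': (140231099467776, True), 'version': 2}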

// PEP 3118 buffer protocol implementation.

namespace {

// Extra data to be kept alive by the consumer of the buffer protocol.
struct ExtraBufferInfo {
  explicit ExtraBufferInfo(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold)
      : external_reference_hold(std::move(external_reference_hold)) {}

  std::string format;
  std::vector<Py_ssize_t> strides;
  // We keep an external reference hold to the PjRtBuffer. This prevents a
  // use-after-free in the event that Delete() is called on a buffer with a
  // live buffer protocol view. It does, however, mean that Delete() sometimes
  // won't actually delete immediately.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
};

int PyBuffer_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) {
  Status status = [&]() {
    TF_ASSIGN_OR_RETURN(PyBuffer * py_buffer, PyBuffer::AsPyBuffer(exporter));
    PjRtBuffer& buffer = *py_buffer->buffer();
    TF_ASSIGN_OR_RETURN(const auto* shape, py_buffer->xla_dynamic_shape());
    // Py_buffer objects are POD C structures, so we don't need to hold the GIL.
    // Additionally we call BlockHostUntilReady() below, which may block.
    py::gil_scoped_release gil_release;

    if (!buffer.IsOnCpu()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for CPU buffers.");
    }
    if (!buffer.on_device_shape().IsArray()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for array buffers.");
    }
    // If we allowed exports of formatted BF16 buffers, consumers would get
    // confused about the type because there is no way to describe BF16 to
    // Python.
    if (buffer.on_device_shape().element_type() == BF16 &&
        ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
      return InvalidArgument(
          "bfloat16 buffer format not supported by Python buffer protocol.");
    }
    if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
      return InvalidArgument("XLA buffers are read-only.");
    }
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
        buffer.AcquireExternalReference());
    if (buffer.IsDeleted()) {
      return InvalidArgument("Deleted buffer used in buffer protocol.");
    }

    // Note: (flags & PyBUF_STRIDES) == PyBUF_ND means the consumer asked for
    // shape information but not strides, which requires C-contiguous data.
    if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
         (flags & PyBUF_STRIDES) == PyBUF_ND) &&
        !LayoutUtil::IsMonotonicWithDim0Major(shape->layout())) {
      return InvalidArgument("Buffer is not in C-contiguous layout.");
    } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in F-contiguous layout.");
    } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Major(shape->layout()) &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape->layout())) {
      return InvalidArgument("Buffer is not in contiguous layout.");
    }
    std::memset(view, 0, sizeof(Py_buffer));
    const void* root_ptr =
        external_reference_hold->OpaqueDeviceMemoryDataPointer();
    view->buf = const_cast<void*>(root_ptr);
    auto extra =
        absl::make_unique<ExtraBufferInfo>(std::move(external_reference_hold));
    view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape->element_type());
    view->len = ShapeUtil::ByteSizeOf(*shape);
    view->readonly = 1;
    if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
      TF_ASSIGN_OR_RETURN(extra->format, FormatDescriptorForPrimitiveType(
                                             shape->element_type()));
      view->format = const_cast<char*>(extra->format.c_str());
    }
    if ((flags & PyBUF_ND) == PyBUF_ND) {
      view->ndim = shape->dimensions_size();
      static_assert(sizeof(int64) == sizeof(Py_ssize_t),
                    "Py_ssize_t must be 64 bits");
      if (view->ndim != 0) {
        view->shape = reinterpret_cast<Py_ssize_t*>(
            const_cast<int64*>(shape->dimensions().data()));
        if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
          extra->strides = ByteStridesForShape(*shape);
          view->strides = extra->strides.data();
        }
      }
    }
    TF_RETURN_IF_ERROR(buffer.BlockHostUntilReady());
    view->internal = extra.release();
    return Status::OK();
  }();
  if (!status.ok()) {
    // numpy.asarray(...) silences the PyExc_BufferError, so log the failure
    // here to aid debugging when the error really occurs.
    VLOG(1) << "Buffer Protocol Error: " << status;
    PyErr_SetString(PyExc_BufferError, status.ToString().c_str());
    return -1;
  }
  view->obj = exporter;
  Py_INCREF(view->obj);
  return 0;
}
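
// Python consumers reach this via the buffer protocol, e.g. (illustrative):
//   memoryview(device_array)   or   numpy.asarray(device_array)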

void PyBuffer_bf_releasebuffer(PyObject*, Py_buffer* buffer) {
  auto extra = static_cast<ExtraBufferInfo*>(buffer->internal);
  delete extra;
}

PyBufferProcs PyBuffer_tp_as_buffer = []() {
  PyBufferProcs procs;
  procs.bf_getbuffer = &PyBuffer_bf_getbuffer;
  procs.bf_releasebuffer = &PyBuffer_bf_releasebuffer;
  return procs;
}();
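
// PyBuffer_tp_as_buffer is installed on the DeviceArray heap type via
// tp_as_buffer in RegisterTypes() below.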

// Helpers for building Python properties
template <typename Func>
py::object property_readonly(Func&& get) {
  py::handle property(reinterpret_cast<PyObject*>(&PyProperty_Type));
  return property(py::cpp_function(std::forward<Func>(get)), py::none(),
                  py::none(), "");
}

template <typename GetFunc, typename SetFunc>
py::object property(GetFunc&& get, SetFunc&& set) {
  py::handle property(reinterpret_cast<PyObject*>(&PyProperty_Type));
  return property(py::cpp_function(std::forward<GetFunc>(get)),
                  py::cpp_function(std::forward<SetFunc>(set)), py::none(), "");
}
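
// These helpers are equivalent to calling the Python builtin directly
// (illustrative): property(fget, fset, None, "").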

}  // namespace

PyObject* PyBuffer::base_type_ = nullptr;
PyObject* PyBuffer::type_ = nullptr;

Status PyBuffer::RegisterTypes(py::module& m) {
  // We do not use pybind11::class_ to build Python wrapper objects because
  // creation, destruction, and casting of buffer objects are performance
  // critical. By using hand-written Python classes, we can avoid extra C heap
  // allocations, and we can avoid pybind11's slow cast<>() implementation
  // during jit dispatch.

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  {
    py::str name = py::str("DeviceArrayBase");
    py::str qualname = py::str("DeviceArrayBase");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArrayBase";
    type->tp_basicsize = sizeof(PyBufferBasePyObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_BASETYPE;
    TF_RET_CHECK(PyType_Ready(type) == 0);
    base_type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object base_type = py::reinterpret_borrow<py::object>(base_type_);
  base_type.attr("__module__") = m.attr("__name__");

  m.attr("DeviceArrayBase") = base_type;
  {
    py::tuple bases = py::make_tuple(base_type);
    py::str name = py::str("DeviceArray");
    py::str qualname = py::str("DeviceArray");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called below. Otherwise the GC might see a
    // half-constructed type object.
    if (!heap_type) {
      return Internal("Unable to create heap type object");
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "DeviceArray";
    type->tp_basicsize = sizeof(PyBufferPyObject);
    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE;
    type->tp_bases = bases.release().ptr();
    type->tp_dealloc = PyBuffer_tp_dealloc;
    type->tp_new = PyBuffer_tp_new;
    // Supported protocols
    type->tp_as_number = &heap_type->as_number;
    type->tp_as_sequence = &heap_type->as_sequence;
    type->tp_as_mapping = &heap_type->as_mapping;
    type->tp_as_buffer = &PyBuffer_tp_as_buffer;

    // Allow weak references to DeviceArray objects.
    type->tp_weaklistoffset = offsetof(PyBufferPyObject, weakrefs);

    TF_RET_CHECK(PyType_Ready(type) == 0);
    type_ = reinterpret_cast<PyObject*>(type);
  }
  py::object type = py::reinterpret_borrow<py::object>(type_);
  m.attr("DeviceArray") = type;
  m.attr("PyLocalBuffer") = type;
  m.attr("Buffer") = type;

  // Add methods and properties to the class. We use pybind11 and add methods
  // dynamically mostly because this is easy to write and allows us to use
  // pybind11's casting logic. This is most likely slightly slower than
  // hand-writing bindings, but most of these methods are not performance
  // critical.
  type.attr("__array_priority__") =
      property_readonly([](py::object self) -> int { return 100; });
  type.attr("_device") = property(
      [](PyBuffer::object self) -> ClientAndPtr<PjRtDevice> {
        return WrapWithClient(self.buf()->client(),
                              self.buf()->sticky_device());
      },
      [](PyBuffer::object self, PjRtDevice* sticky_device) {
        return self.buf()->set_sticky_device(sticky_device);
      });
  type.attr("aval") = property(
      [](PyBuffer::object self) -> py::object { return self.buf()->GetAval(); },
      [](PyBuffer::object self, py::object aval) {
        return self.buf()->SetAval(std::move(aval));
      });
  type.attr("weak_type") = property(
      [](PyBuffer::object self) -> absl::optional<bool> {
        return self.buf()->weak_type();
      },
      [](PyBuffer::object self, absl::optional<bool> weak_type) {
        return self.buf()->set_weak_type(weak_type);
      });
  type.attr("_lazy_expr") =
      property_readonly([](py::handle self) { return py::none(); });
  type.attr("device_buffer") =
      property_readonly([](py::object self) { return self; });
  type.attr("shape") =
      property_readonly([](PyBuffer::object self) -> py::tuple {
        return SpanToTuple(
            self.buf()->buffer()->on_device_shape().dimensions());
      });
  type.attr("dtype") = property_readonly([](PyBuffer::object self) {
    PrimitiveType primitive =
        self.buf()->buffer()->on_device_shape().element_type();
    return PrimitiveTypeToDtype(primitive).ValueOrDie();
  });
  type.attr("size") =
      property_readonly([](PyBuffer::object self) -> StatusOr<int64_t> {
        return self.buf()->size();
      });
  type.attr("ndim") = property_readonly(
      [](PyBuffer::object self) -> int { return self.buf()->ndim(); });
  type.attr("_value") = property_readonly(
      [](PyBuffer::object self) -> StatusOr<pybind11::object> {
        GlobalPyRefManager()->CollectGarbage();
        return self.buf()->AsNumPyArray(self);
      });
  type.attr("copy_to_device") = py::cpp_function(
      [](PyBuffer::object self, const ClientAndPtr<PjRtDevice>& dst_device) {
        return self.buf()->CopyToDevice(dst_device);
      },
      py::is_method(type));
  type.attr("on_device_size_in_bytes") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<size_t> {
        return self.buf()->OnDeviceSizeInBytes();
      },
      py::is_method(type));
  type.attr("delete") = py::cpp_function(
      [](PyBuffer::object self) { self.buf()->Delete(); }, py::is_method(type));
  type.attr("block_host_until_ready") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->BlockHostUntilReady(); },
      py::is_method(type));
  type.attr("block_until_ready") = py::cpp_function(
      [](PyBuffer::object self) -> StatusOr<PyBuffer::object> {
        TF_RETURN_IF_ERROR(self.buf()->BlockHostUntilReady());
        return std::move(self);
      },
      py::is_method(type));
  type.attr("copy_to_host_async") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->CopyToHostAsync(); },
      py::is_method(type));
  type.attr("to_py") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->AsNumPyArray(self); },
      py::is_method(type));
  type.attr("xla_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->shape(); },
      py::is_method(type));
  type.attr("xla_dynamic_shape") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->xla_dynamic_shape(); },
      py::is_method(type));
  type.attr("client") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->client(); });
  type.attr("device") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->device(); },
      py::is_method(type));
  type.attr("platform") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->platform_name(); },
      py::is_method(type));
  type.attr("is_deleted") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->is_deleted(); },
      py::is_method(type));
  type.attr("unsafe_buffer_pointer") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->UnsafeBufferPointer(); },
      py::is_method(type));
  type.attr("__cuda_array_interface__") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->CudaArrayInterface(); });
  type.attr("traceback") = property_readonly(
      [](PyBuffer::object self) { return self.buf()->traceback(); });
  type.attr("clone") = py::cpp_function(
      [](PyBuffer::object self) { return self.buf()->Clone(); },
      py::is_method(type));
  type.attr("__module__") = m.attr("__name__");
  return Status::OK();
}
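
// Python-side usage of the registered DeviceArray type looks like
// (illustrative):
//   arr = device_array.to_py()                    # host numpy copy
//   other = device_array.copy_to_device(device)   # cross-device copy
//   device_array.block_until_ready()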

}  // namespace xla