• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/py_buffer.h"
17 
18 #include "absl/base/casts.h"
19 #include "pybind11/pybind11.h"
20 #include "pybind11/pytypes.h"
21 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
22 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
23 #include "tensorflow/compiler/xla/python/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 
26 namespace xla {
27 
28 namespace py = pybind11;
29 
PyBuffer(std::shared_ptr<PyClient> client,std::unique_ptr<PjRtBuffer> buffer,std::shared_ptr<Traceback> traceback)30 PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
31                    std::unique_ptr<PjRtBuffer> buffer,
32                    std::shared_ptr<Traceback> traceback)
33     : client_(std::move(client)),
34       buffer_(std::move(buffer)),
35       traceback_(std::move(traceback)) {
36   CHECK(PyGILState_Check());
37   next_ = client_->buffers_;
38   client_->buffers_ = this;
39   prev_ = nullptr;
40   if (next_) {
41     next_->prev_ = this;
42   }
43 }
44 
~PyBuffer()45 PyBuffer::~PyBuffer() {
46   CHECK(PyGILState_Check());
47   if (client_->buffers_ == this) {
48     client_->buffers_ = next_;
49   }
50   if (prev_) {
51     prev_->next_ = next_;
52   }
53   if (next_) {
54     next_->prev_ = prev_;
55   }
56 }
57 
python_shape() const58 pybind11::tuple PyBuffer::python_shape() const {
59   return IntSpanToTuple(buffer()->on_device_shape().dimensions());
60 }
61 
python_dtype() const62 pybind11::dtype PyBuffer::python_dtype() const {
63   PrimitiveType primitive = buffer()->on_device_shape().element_type();
64   return PrimitiveTypeToDtype(primitive).ValueOrDie();
65 }
66 
device() const67 ClientAndPtr<PjRtDevice> PyBuffer::device() const {
68   return WrapWithClient(client_, buffer_->device());
69 }
70 
CopyToDevice(const ClientAndPtr<PjRtDevice> & dst_device) const71 StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
72     const ClientAndPtr<PjRtDevice>& dst_device) const {
73   CHECK(dst_device.get() != nullptr);
74   GlobalPyRefManager()->CollectGarbage();
75   std::unique_ptr<PjRtBuffer> out;
76   {
77     py::gil_scoped_release gil_release;
78     TF_ASSIGN_OR_RETURN(out, buffer_->CopyToDevice(dst_device.get()));
79   }
80   auto traceback = Traceback::Get();
81   return std::make_unique<PyBuffer>(dst_device.client, std::move(out),
82                                     std::move(traceback));
83 }
84 
BlockHostUntilReady()85 Status PyBuffer::BlockHostUntilReady() {
86   GlobalPyRefManager()->CollectGarbage();
87   py::gil_scoped_release gil_release;
88   return buffer_->BlockHostUntilReady();
89 }
90 
CopyToHostAsync()91 Status PyBuffer::CopyToHostAsync() {
92   if (!buffer_->IsOnCpu() && !host_value_) {
93     std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();
94     host_value_ = host_value;
95     py::gil_scoped_release gil;
96     host_value->value = std::make_shared<Literal>(
97         ShapeUtil::DeviceShapeToHostShape(buffer_->on_device_shape()));
98     Literal* literal = host_value->value.get();
99     buffer_->ToLiteral(literal,
100                        [host_value{std::move(host_value)}](Status status) {
101                          host_value->status = std::move(status);
102                          host_value->ready.Notify();
103                        });
104   }
105   return Status::OK();
106 }
107 
AsNumPyArray(py::handle this_obj)108 StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
109   if (buffer_->IsDeleted()) {
110     return InvalidArgument("DeviceArray has been deleted.");
111   }
112   TF_RET_CHECK(buffer_->on_device_shape().IsArray());
113   // On CPU, we can return the value in a zero-copy way.
114   if (buffer_->IsOnCpu()) {
115     TF_ASSIGN_OR_RETURN(
116         py::dtype dtype,
117         PrimitiveTypeToDtype(buffer_->on_device_shape().element_type()));
118     // Objects that must be kept alive while the array is alive.
119     struct Hold {
120       py::object buffer;
121       std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
122     };
123     auto hold = std::make_unique<Hold>();
124     TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
125                         buffer_->AcquireExternalReference());
126     hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
127     void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
128     py::capsule hold_capsule(hold.release(),
129                              [](void* h) { delete static_cast<Hold*>(h); });
130     py::array array(dtype, buffer_->on_device_shape().dimensions(),
131                     ByteStridesForShape(buffer_->on_device_shape()), data,
132                     hold_capsule);
133     array.attr("flags").attr("writeable") = Py_False;
134     {
135       py::gil_scoped_release gil;
136       TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
137     }
138     return array;
139   }
140 
141   TF_RETURN_IF_ERROR(CopyToHostAsync());
142   if (!host_value_->ready.HasBeenNotified()) {
143     py::gil_scoped_release gil;
144     host_value_->ready.WaitForNotification();
145   }
146   TF_RETURN_IF_ERROR(host_value_->status);
147   TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
148   array.attr("flags").attr("writeable") = Py_False;
149   return array;
150 }
151 
152 // TODO(zhangqiaorjc): Delete UnsafeBufferPointer.
UnsafeBufferPointer() const153 StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
154   if (buffer_->on_device_shape().IsTuple()) {
155     return Unimplemented(
156         "unsafe_buffer_pointer is not implemented for tuple "
157         "buffers.");
158   }
159 
160   TF_ASSIGN_OR_RETURN(
161       std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
162       buffer_->AcquireExternalReference());
163   const void* ptr = external_reference_hold->OpaqueDeviceMemoryDataPointer();
164   return absl::bit_cast<std::uintptr_t>(ptr);
165 }
166 
CudaArrayInterface() const167 StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
168   // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
169   if (buffer_->client()->platform_id() != kGpuId) {
170     return InvalidArgument(
171         "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
172   }
173   if (!buffer_->on_device_shape().IsArray()) {
174     return InvalidArgument(
175         "__cuda_array_interface__ is only defined for array buffers.");
176   }
177   if (buffer_->on_device_shape().element_type() == BF16) {
178     return InvalidArgument(
179         "__cuda_array_interface__ is not supported for bfloat16 buffers.");
180   }
181   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
182       buffer_->on_device_shape().layout()));
183 
184   py::dict result;
185   result["shape"] = IntSpanToTuple(buffer_->on_device_shape().dimensions());
186   TF_ASSIGN_OR_RETURN(py::str typestr,
187                       TypeDescriptorForPrimitiveType(
188                           buffer_->on_device_shape().element_type()));
189   result["typestr"] = std::move(typestr);
190   TF_ASSIGN_OR_RETURN(
191       std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
192       buffer_->AcquireExternalReference());
193   const void* root_ptr =
194       external_reference_hold->OpaqueDeviceMemoryDataPointer();
195   py::tuple data(2);
196   data[0] = py::int_(absl::bit_cast<std::uintptr_t>(root_ptr));
197   data[1] = py::bool_(true);  // read-only
198   result["data"] = std::move(data);
199   result["version"] = py::int_(2);
200   return result;
201 }
202 
203 // PEP 3118 buffer protocol implementation.
204 
205 namespace {
206 
207 // Extra data to be kept alive by the consumer of the buffer protocol.
208 struct ExtraBufferInfo {
ExtraBufferInfoxla::__anonb86489e10311::ExtraBufferInfo209   explicit ExtraBufferInfo(
210       std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold)
211       : external_reference_hold(std::move(external_reference_hold)) {}
212 
213   std::string format;
214   std::vector<Py_ssize_t> strides;
215   // We keep an external reference hold to the PjRtBuffer. This prevents a
216   // use-after-free in the event that Delete() is called on a buffer with an
217   // live buffer protocol view. It does however mean that Delete() sometimes
218   // won't actually delete immediately.
219   std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
220 };
221 
PjRtBufferGetBuffer(PyObject * exporter,Py_buffer * view,int flags)222 int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
223   auto& buffer =
224       *py::reinterpret_borrow<py::object>(exporter).cast<PyBuffer&>().buffer();
225   Status status = [&]() {
226     // Py_buffer objects are POD C structures, so we don't need to hold the GIL.
227     // Additionally we call BlockHostUntilReady() below, which may block.
228     py::gil_scoped_release gil_release;
229 
230     if (!buffer.IsOnCpu()) {
231       return InvalidArgument(
232           "Python buffer protocol is only defined for CPU buffers.");
233     }
234     if (!buffer.on_device_shape().IsArray()) {
235       return InvalidArgument(
236           "Python buffer protocol is only defined for array buffers.");
237     }
238     // If we allowed exports of formatted BF16 buffers, consumers would get
239     // confused about the type because there is no way to describe BF16 to
240     // Python.
241     if (buffer.on_device_shape().element_type() == BF16 &&
242         ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
243       return InvalidArgument(
244           "bfloat16 buffer format not supported by Python buffer protocol.");
245     }
246     if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
247       return InvalidArgument("XLA buffers are read-only.");
248     }
249     TF_ASSIGN_OR_RETURN(
250         std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
251         buffer.AcquireExternalReference());
252     if (buffer.IsDeleted()) {
253       return InvalidArgument("Deleted buffer used in buffer protocol.");
254     }
255     const Shape& shape = buffer.on_device_shape();
256     if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
257          (flags & PyBUF_STRIDES) == PyBUF_ND) &&
258         !LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
259       return InvalidArgument("Buffer is not in C-contiguous layout.");
260     } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
261                !LayoutUtil::IsMonotonicWithDim0Minor(shape.layout())) {
262       return InvalidArgument("Buffer is not in F-contiguous layout.");
263     } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS &&
264                !LayoutUtil::IsMonotonicWithDim0Major(shape.layout()) &&
265                !LayoutUtil::IsMonotonicWithDim0Minor(shape.layout())) {
266       return InvalidArgument("Buffer is not in contiguous layout.");
267     }
268     std::memset(view, 0, sizeof(Py_buffer));
269     const void* root_ptr =
270         external_reference_hold->OpaqueDeviceMemoryDataPointer();
271     view->buf = const_cast<void*>(root_ptr);
272     auto extra =
273         absl::make_unique<ExtraBufferInfo>(std::move(external_reference_hold));
274     view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
275     view->len = ShapeUtil::ByteSizeOf(shape);
276     view->readonly = 1;
277     if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
278       TF_ASSIGN_OR_RETURN(extra->format, FormatDescriptorForPrimitiveType(
279                                              shape.element_type()));
280       view->format = const_cast<char*>(extra->format.c_str());
281     }
282     if ((flags & PyBUF_ND) == PyBUF_ND) {
283       view->ndim = shape.dimensions_size();
284       static_assert(sizeof(int64) == sizeof(Py_ssize_t),
285                     "Py_ssize_t must be 64 bits");
286       if (view->ndim != 0) {
287         view->shape = reinterpret_cast<Py_ssize_t*>(
288             const_cast<int64*>(shape.dimensions().data()));
289         if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
290           extra->strides = ByteStridesForShape(shape);
291           view->strides = extra->strides.data();
292         }
293       }
294     }
295     TF_RETURN_IF_ERROR(buffer.BlockHostUntilReady());
296     view->internal = extra.release();
297     return Status::OK();
298   }();
299   if (!status.ok()) {
300     PyErr_SetString(PyExc_BufferError, status.ToString().c_str());
301     return -1;
302   }
303   view->obj = exporter;
304   Py_INCREF(view->obj);
305   return 0;
306 }
307 
PjRtBufferReleaseBuffer(PyObject *,Py_buffer * buffer)308 void PjRtBufferReleaseBuffer(PyObject*, Py_buffer* buffer) {
309   auto extra = static_cast<ExtraBufferInfo*>(buffer->internal);
310   delete extra;
311 }
312 
__anonb86489e10502() 313 PyBufferProcs PjRtBufferProcs = []() {
314   PyBufferProcs procs;
315   procs.bf_getbuffer = &PjRtBufferGetBuffer;
316   procs.bf_releasebuffer = &PjRtBufferReleaseBuffer;
317   return procs;
318 }();
319 
320 }  // namespace
321 
BufferProtocol()322 /*static*/ PyBufferProcs* PyBuffer::BufferProtocol() {
323   return &PjRtBufferProcs;
324 }
325 
SetStickyDevice(pybind11::object sticky_device)326 void PyBuffer::SetStickyDevice(pybind11::object sticky_device) {
327   if (sticky_device_ && !sticky_device_->equal(sticky_device)) {
328     throw std::invalid_argument(
329         "One cannot set again the stickyness of a buffer and needs to create "
330         "a new one or a `_DeviceArray`");
331   }
332   sticky_device_ = sticky_device;
333 }
334 
SetAval(pybind11::object aval)335 void PyBuffer::SetAval(pybind11::object aval) {
336   if (aval_ && !aval_->equal(aval)) {
337     throw std::invalid_argument(
338         "One cannot set again the aval_ of a buffer and needs to create a "
339         "new one or a `_DeviceArray`");
340   }
341   aval_ = aval;
342 }
343 
344 }  // namespace xla
345