/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_buffer.h"

#include "absl/base/casts.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace py = pybind11;

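// The constructor links this PyBuffer into the owning PyClient's intrusive
// doubly-linked list of live buffers (client_->buffers_), inserting at the
// head, so the client can traverse every buffer it still owns. The GIL acts
// as the lock protecting the list, hence the PyGILState_Check() below.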
PyBuffer::PyBuffer(std::shared_ptr<PyClient> client,
                   std::unique_ptr<PjRtBuffer> buffer,
                   std::shared_ptr<Traceback> traceback)
    : client_(std::move(client)),
      buffer_(std::move(buffer)),
      traceback_(std::move(traceback)) {
  CHECK(PyGILState_Check());
  next_ = client_->buffers_;
  client_->buffers_ = this;
  prev_ = nullptr;
  if (next_) {
    next_->prev_ = this;
  }
}

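// The destructor unlinks the buffer from the client's live-buffer list.
// Like insertion, removal relies on the GIL for mutual exclusion.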
PyBuffer::~PyBuffer() {
  CHECK(PyGILState_Check());
  if (client_->buffers_ == this) {
    client_->buffers_ = next_;
  }
  if (prev_) {
    prev_->next_ = next_;
  }
  if (next_) {
    next_->prev_ = prev_;
  }
}

pybind11::tuple PyBuffer::python_shape() const {
  return IntSpanToTuple(buffer()->on_device_shape().dimensions());
}

pybind11::dtype PyBuffer::python_dtype() const {
  PrimitiveType primitive = buffer()->on_device_shape().element_type();
  return PrimitiveTypeToDtype(primitive).ValueOrDie();
}

ClientAndPtr<PjRtDevice> PyBuffer::device() const {
  return WrapWithClient(client_, buffer_->device());
}

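// Copies the buffer to another device. Deferred Python garbage is collected
// first, and the GIL is dropped for the duration of the (potentially slow)
// device-to-device transfer so other Python threads can make progress.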
StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
    const ClientAndPtr<PjRtDevice>& dst_device) const {
  CHECK(dst_device.get() != nullptr);
  GlobalPyRefManager()->CollectGarbage();
  std::unique_ptr<PjRtBuffer> out;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(out, buffer_->CopyToDevice(dst_device.get()));
  }
  auto traceback = Traceback::Get();
  return std::make_unique<PyBuffer>(dst_device.client, std::move(out),
                                    std::move(traceback));
}

Status PyBuffer::BlockHostUntilReady() {
  GlobalPyRefManager()->CollectGarbage();
  py::gil_scoped_release gil_release;
  return buffer_->BlockHostUntilReady();
}

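// Starts an asynchronous device-to-host copy into a freshly allocated
// Literal. The copy is skipped for CPU buffers (AsNumPyArray reads those
// zero-copy) and for buffers whose host value was already requested.
// Completion is signaled through host_value_->ready; any error is stored in
// host_value_->status.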
Status PyBuffer::CopyToHostAsync() {
  if (!buffer_->IsOnCpu() && !host_value_) {
    std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();
    host_value_ = host_value;
    py::gil_scoped_release gil;
    host_value->value = std::make_shared<Literal>(
        ShapeUtil::DeviceShapeToHostShape(buffer_->on_device_shape()));
    Literal* literal = host_value->value.get();
    buffer_->ToLiteral(literal,
                       [host_value{std::move(host_value)}](Status status) {
                         host_value->status = std::move(status);
                         host_value->ready.Notify();
                       });
  }
  return Status::OK();
}

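// Converts the buffer to a NumPy array. On CPU the result aliases device
// memory zero-copy: a py::capsule keeps both this Python object and a
// PjRtBuffer::ExternalReference alive for as long as the array exists. On
// other platforms the value is staged through CopyToHostAsync(). Either way
// the returned array is marked read-only.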
StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
  if (buffer_->IsDeleted()) {
    return InvalidArgument("DeviceArray has been deleted.");
  }
  TF_RET_CHECK(buffer_->on_device_shape().IsArray());
  // On CPU, we can return the value in a zero-copy way.
  if (buffer_->IsOnCpu()) {
    TF_ASSIGN_OR_RETURN(
        py::dtype dtype,
        PrimitiveTypeToDtype(buffer_->on_device_shape().element_type()));
    // Objects that must be kept alive while the array is alive.
    struct Hold {
      py::object buffer;
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
    };
    auto hold = std::make_unique<Hold>();
    TF_ASSIGN_OR_RETURN(hold->external_reference_hold,
                        buffer_->AcquireExternalReference());
    hold->buffer = py::reinterpret_borrow<py::object>(this_obj);
    void* data =
        hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
    py::capsule hold_capsule(hold.release(),
                             [](void* h) { delete static_cast<Hold*>(h); });
    py::array array(dtype, buffer_->on_device_shape().dimensions(),
                    ByteStridesForShape(buffer_->on_device_shape()), data,
                    hold_capsule);
    array.attr("flags").attr("writeable") = Py_False;
    {
      py::gil_scoped_release gil;
      TF_RETURN_IF_ERROR(buffer_->BlockHostUntilReady());
    }
    return array;
  }

  TF_RETURN_IF_ERROR(CopyToHostAsync());
  if (!host_value_->ready.HasBeenNotified()) {
    py::gil_scoped_release gil;
    host_value_->ready.WaitForNotification();
  }
  TF_RETURN_IF_ERROR(host_value_->status);
  TF_ASSIGN_OR_RETURN(py::object array, LiteralToPython(host_value_->value));
  array.attr("flags").attr("writeable") = Py_False;
  return array;
}

// TODO(zhangqiaorjc): Delete UnsafeBufferPointer.
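// Returns the raw device address of the buffer as an integer. "Unsafe"
// because the caller receives a bare pointer with no lifetime guarantee
// beyond the transient external reference acquired here.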
StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
  if (buffer_->on_device_shape().IsTuple()) {
    return Unimplemented(
        "unsafe_buffer_pointer is not implemented for tuple "
        "buffers.");
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
      buffer_->AcquireExternalReference());
  const void* ptr = external_reference_hold->OpaqueDeviceMemoryDataPointer();
  return absl::bit_cast<std::uintptr_t>(ptr);
}

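// Implements __cuda_array_interface__ (version 2), the dictionary protocol
// GPU libraries such as CuPy and Numba use to share device memory. A sketch
// of the result for a row-major f32[2,3] buffer, assuming an arbitrary
// device address:
//   {'shape': (2, 3), 'typestr': '<f4',
//    'data': (140301254098944, True),  # (pointer, read-only flag)
//    'version': 2}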
StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
  // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
  if (buffer_->client()->platform_id() != kGpuId) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for NVidia GPU buffers.");
  }
  if (!buffer_->on_device_shape().IsArray()) {
    return InvalidArgument(
        "__cuda_array_interface__ is only defined for array buffers.");
  }
  if (buffer_->on_device_shape().element_type() == BF16) {
    return InvalidArgument(
        "__cuda_array_interface__ is not supported for bfloat16 buffers.");
  }
  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
      buffer_->on_device_shape().layout()));

  py::dict result;
  result["shape"] = IntSpanToTuple(buffer_->on_device_shape().dimensions());
  TF_ASSIGN_OR_RETURN(py::str typestr,
                      TypeDescriptorForPrimitiveType(
                          buffer_->on_device_shape().element_type()));
  result["typestr"] = std::move(typestr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
      buffer_->AcquireExternalReference());
  const void* root_ptr =
      external_reference_hold->OpaqueDeviceMemoryDataPointer();
  py::tuple data(2);
  data[0] = py::int_(absl::bit_cast<std::uintptr_t>(root_ptr));
  data[1] = py::bool_(true);  // read-only
  result["data"] = std::move(data);
  result["version"] = py::int_(2);
  return result;
}

// PEP 3118 buffer protocol implementation.

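// The hooks below let Python code view a CPU buffer without copying it.
// Illustrative Python-side use (a sketch; the variable name is an
// assumption):
//   mv = memoryview(device_buffer)  # invokes PjRtBufferGetBuffer below
//   raw = bytes(mv)                 # reads the exported memory
//   del mv                          # invokes PjRtBufferReleaseBuffer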
namespace {

// Extra data to be kept alive by the consumer of the buffer protocol.
struct ExtraBufferInfo {
  explicit ExtraBufferInfo(
      std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold)
      : external_reference_hold(std::move(external_reference_hold)) {}

  std::string format;
  std::vector<Py_ssize_t> strides;
  // We keep an external reference hold to the PjRtBuffer. This prevents a
  // use-after-free in the event that Delete() is called on a buffer with a
  // live buffer protocol view. It does, however, mean that Delete()
  // sometimes won't actually delete immediately.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold;
};

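// bf_getbuffer hook. Validates that the buffer is an exportable CPU array,
// checks the consumer's layout and format flags, fills in the Py_buffer
// view, and stashes an ExtraBufferInfo in view->internal so the backing
// memory outlives the view. Returns 0 on success, or -1 with BufferError
// set on failure, per the CPython buffer protocol contract.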
int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
  auto& buffer =
      *py::reinterpret_borrow<py::object>(exporter).cast<PyBuffer&>().buffer();
  Status status = [&]() {
    // Py_buffer objects are POD C structures, so we don't need to hold the
    // GIL. Additionally we call BlockHostUntilReady() below, which may block.
    py::gil_scoped_release gil_release;

    if (!buffer.IsOnCpu()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for CPU buffers.");
    }
    if (!buffer.on_device_shape().IsArray()) {
      return InvalidArgument(
          "Python buffer protocol is only defined for array buffers.");
    }
    // If we allowed exports of formatted BF16 buffers, consumers would get
    // confused about the type because there is no way to describe BF16 to
    // Python.
    if (buffer.on_device_shape().element_type() == BF16 &&
        ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
      return InvalidArgument(
          "bfloat16 buffer format not supported by Python buffer protocol.");
    }
    if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) {
      return InvalidArgument("XLA buffers are read-only.");
    }
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtBuffer::ExternalReference> external_reference_hold,
        buffer.AcquireExternalReference());
    if (buffer.IsDeleted()) {
      return InvalidArgument("Deleted buffer used in buffer protocol.");
    }
    const Shape& shape = buffer.on_device_shape();
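    // A consumer that requests shape information without strides
    // ((flags & PyBUF_STRIDES) == PyBUF_ND) implicitly assumes C-contiguous
    // memory, so that case is folded into the C_CONTIGUOUS check below.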
    if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
         (flags & PyBUF_STRIDES) == PyBUF_ND) &&
        !LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
      return InvalidArgument("Buffer is not in C-contiguous layout.");
    } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape.layout())) {
      return InvalidArgument("Buffer is not in F-contiguous layout.");
    } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS &&
               !LayoutUtil::IsMonotonicWithDim0Major(shape.layout()) &&
               !LayoutUtil::IsMonotonicWithDim0Minor(shape.layout())) {
      return InvalidArgument("Buffer is not in contiguous layout.");
    }
    std::memset(view, 0, sizeof(Py_buffer));
    const void* root_ptr =
        external_reference_hold->OpaqueDeviceMemoryDataPointer();
    view->buf = const_cast<void*>(root_ptr);
    auto extra =
        absl::make_unique<ExtraBufferInfo>(std::move(external_reference_hold));
    view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
    view->len = ShapeUtil::ByteSizeOf(shape);
    view->readonly = 1;
    if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
      TF_ASSIGN_OR_RETURN(extra->format, FormatDescriptorForPrimitiveType(
                                             shape.element_type()));
      view->format = const_cast<char*>(extra->format.c_str());
    }
    if ((flags & PyBUF_ND) == PyBUF_ND) {
      view->ndim = shape.dimensions_size();
      static_assert(sizeof(int64) == sizeof(Py_ssize_t),
                    "Py_ssize_t must be 64 bits");
      if (view->ndim != 0) {
        view->shape = reinterpret_cast<Py_ssize_t*>(
            const_cast<int64*>(shape.dimensions().data()));
        if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
          extra->strides = ByteStridesForShape(shape);
          view->strides = extra->strides.data();
        }
      }
    }
    TF_RETURN_IF_ERROR(buffer.BlockHostUntilReady());
    view->internal = extra.release();
    return Status::OK();
  }();
  if (!status.ok()) {
    PyErr_SetString(PyExc_BufferError, status.ToString().c_str());
    return -1;
  }
  view->obj = exporter;
  Py_INCREF(view->obj);
  return 0;
}

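// bf_releasebuffer hook: frees the ExtraBufferInfo allocated in
// PjRtBufferGetBuffer, which drops the external reference hold and allows
// the underlying device memory to be freed.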
void PjRtBufferReleaseBuffer(PyObject*, Py_buffer* buffer) {
  auto extra = static_cast<ExtraBufferInfo*>(buffer->internal);
  delete extra;
}

PyBufferProcs PjRtBufferProcs = []() {
  PyBufferProcs procs;
  procs.bf_getbuffer = &PjRtBufferGetBuffer;
  procs.bf_releasebuffer = &PjRtBufferReleaseBuffer;
  return procs;
}();

}  // namespace

/*static*/ PyBufferProcs* PyBuffer::BufferProtocol() {
  return &PjRtBufferProcs;
}

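// Records the device a DeviceArray is committed ("sticky") to. The field is
// write-once: rebinding a buffer to a different sticky device is rejected,
// and callers must create a new buffer (or a `_DeviceArray`) instead.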
void PyBuffer::SetStickyDevice(pybind11::object sticky_device) {
  if (sticky_device_ && !sticky_device_->equal(sticky_device)) {
    throw std::invalid_argument(
        "The sticky device of a buffer can only be set once; create a new "
        "buffer or a `_DeviceArray` instead.");
  }
  sticky_device_ = sticky_device;
}

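// Caches the abstract value (jax "aval") describing the buffer. Like the
// sticky device, it is write-once so the cached metadata stays consistent
// with the underlying buffer.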
void PyBuffer::SetAval(pybind11::object aval) {
  if (aval_ && !aval_->equal(aval)) {
    throw std::invalid_argument(
        "The aval of a buffer can only be set once; create a new buffer or "
        "a `_DeviceArray` instead.");
  }
  aval_ = aval;
}

}  // namespace xla