/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

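// Owner of the data backing an exported DLManagedTensor: keeps the device
// memory (and, for shared exports, the originating Python buffer) alive until
// the consumer invokes the DLManagedTensor's deleter.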
struct DLPackTensor {
  ~DLPackTensor();

  // `buffer_reference` is populated if we have shared (read-only) access.
  py::object buffer_reference;

  // `external_reference` is always populated.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference;

  std::vector<int64> shape;
  std::vector<int64> strides;
  DLManagedTensor tensor;
};

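// Releases any held Python buffer reference via the global Python reference
// manager rather than destroying it directly; the DLPack deleter may be
// invoked on a thread that does not hold the GIL.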
DLPackTensor::~DLPackTensor() {
  if (buffer_reference) {
    GlobalPyRefManager()->AddGarbage(
        absl::MakeSpan(&buffer_reference, /*size=*/1));
  }
}

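// Deleter installed on the exported DLManagedTensor: frees the DLPackTensor
// stored in manager_ctx, which in turn releases the device memory reference.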
void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

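// Maps an XLA primitive type to the corresponding DLPack data type. Complex
// types (and any type not listed below) have no DLPack equivalent.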
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
      return DLDataType{kDLUInt, 8, 1};
    case C64:
    case C128:
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

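// Inverse of PrimitiveTypeToDLDataType: maps a DLPack data type back to an
// XLA primitive type. Only single-lane (non-vectorized) DLPack types are
// accepted.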
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the strides for `shape`, in numbers of elements (not bytes), which
// is the convention DLPack uses.
std::vector<int64> StridesForShape(const Shape& shape) {
  std::vector<int64> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64_t stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

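// Converts a DLPack (dims, strides) pair into an XLA minor-to-major layout.
// Only compact (non-broadcasting) stridings, i.e. permutations of a dense
// layout, can be represented.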
StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
                                             absl::Span<int64 const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    // Equal strides can only occur for degenerate (size-1) dimensions in a
    // compact layout; place the size-1 dimension as more minor so the
    // compactness check below still passes.
    return dims[a] == 1 && dims[b] != 1;
  });
  int64_t stride = 1;
  for (int64_t d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

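// Returns the DLPack device type corresponding to a PjRt device. Only CPU and
// GPU platforms can be exported via DLPack.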
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
  if (device.client()->platform_id() == kCpuId) {
    return kDLCPU;
  } else if (device.client()->platform_id() == kGpuId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

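// Builds the DLContext (device type plus local device ordinal) for a PjRt
// device.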
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_hardware_id();
  return context;
}

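// Resolves a DLContext back to a PjRt device, using whichever of the CPU or
// GPU clients matches the context's device type.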
StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient* cpu_client,
                                         const PjRtClient* gpu_client,
                                         const DLContext& context) {
  switch (context.device_type) {
    case kDLCPU:
      if (cpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on CPU, but no CPU backend was provided.");
      }
      TF_RET_CHECK(cpu_client->platform_id() == kCpuId);
      return cpu_client->LookupAddressableDevice(context.device_id);
    case kDLGPU:
      if (gpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on GPU, but no GPU backend was provided.");
      }
      TF_RET_CHECK(gpu_client->platform_id() == kGpuId);
      return gpu_client->LookupAddressableDevice(context.device_id);
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
}

}  // namespace

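// Exports a PyBuffer as a DLPack capsule. If `take_ownership` is true, the
// buffer's device memory is donated to the consumer; otherwise the capsule
// holds a shared, read-only reference that keeps the Python buffer alive.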
StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
                                                  bool take_ownership) {
  TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(py_buffer));
  auto pack = std::make_unique<DLPackTensor>();
  if (buffer->buffer()->on_device_shape().IsTuple()) {
    return Unimplemented(
        "DLPack conversion is not implemented for tuple buffers.");
  }
  if (buffer->buffer()->on_device_shape().is_dynamic()) {
    return Unimplemented("DynamicShape is not implemented in DLPack.");
  }

  DLTensor& dt = pack->tensor.dl_tensor;
  if (take_ownership) {
    // Block on outstanding operations, so that it is safe to read or mutate
    // the returned buffer.
    StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> buffer_or =
        buffer->buffer()->ReleaseDeviceMemoryOwnership(
            /*wait_for_operations_to_complete=*/true);
    if (!buffer_or.ok()) {
      return InvalidArgument(
          "Buffer synchronization failed converting to DLPack tensor: %s",
          buffer_or.status().ToString());
    }
    pack->external_reference = buffer_or.ConsumeValueOrDie();
    if (!pack->external_reference) {
      return InvalidArgument(
          "Cannot convert deleted/invalid buffer to DLPack tensor.");
    }
  } else {
    // Block on outstanding operations, so that it is safe to read the shared
    // (read-only) buffer.
    TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
    pack->buffer_reference = py::reinterpret_borrow<py::object>(py_buffer);
    TF_ASSIGN_OR_RETURN(pack->external_reference,
                        buffer->buffer()->AcquireExternalReference());
  }
  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
  dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
                          buffer->buffer()->on_device_shape().element_type()));

  pack->shape = std::vector<int64>(
      buffer->buffer()->on_device_shape().dimensions().begin(),
      buffer->buffer()->on_device_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });
  return capsule;
}

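// Imports a DLPack capsule as a PyBuffer on a device owned by `cpu_client` or
// `gpu_client`, without copying: the resulting buffer is a view of the DLPack
// tensor's memory, and the capsule's deleter is invoked when that buffer is
// freed.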
StatusOr<PyBuffer::object> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyClient> cpu_client,
    std::shared_ptr<PyClient> gpu_client) {
  // Backward compatibility: if only one client is passed, it may be from any
  // platform. Drop this support after dropping support for jax <= 0.2.14.
  if (cpu_client && cpu_client->pjrt_client()->platform_id() == kGpuId) {
    gpu_client = std::move(cpu_client);
    cpu_client = nullptr;
  }
  if (cpu_client && cpu_client->pjrt_client()->platform_id() != kCpuId) {
    return InvalidArgument("DLPack does not support platform %s",
                           cpu_client->pjrt_client()->platform_name());
  }

  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(
      PjRtDevice * device,
      DeviceForDLContext(cpu_client ? cpu_client->pjrt_client() : nullptr,
                         gpu_client ? gpu_client->pjrt_client() : nullptr,
                         dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64 const> strides(
        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    // No strides were provided (or the tensor has a zero-sized dimension);
    // assume a dense, row-major layout.
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
                      device->client()->CreateViewOfDeviceBuffer(
                          static_cast<char*>(dlmt->dl_tensor.data) +
                              dlmt->dl_tensor.byte_offset,
                          shape, device, on_delete_callback));
  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  // TODO(phawkins): simplify the expression below once we know cpu_client is
  // always non-null.
  return PyBuffer::Make(
      (cpu_client && device->client() == cpu_client->pjrt_client())
          ? std::move(cpu_client)
          : std::move(gpu_client),
      std::move(pjrt_buffer), Traceback::Get());
}

}  // namespace xla