/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

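// Backing storage for a DLManagedTensor exported from a PjRtBuffer. It keeps
// the device memory reference (and, for shared exports, the owning Python
// buffer object) alive for as long as the DLPack consumer holds the tensor,
// and owns the shape/strides arrays that the DLTensor points into.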
struct DLPackTensor {
  ~DLPackTensor();

  // `buffer_reference` is populated if we have shared (read-only) access.
  py::object buffer_reference;

  // `external_reference` is always populated.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference;

  std::vector<int64> shape;
  std::vector<int64> strides;
  DLManagedTensor tensor;
};

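// The destructor may run on a thread that does not hold the GIL (e.g. when the
// DLPack consumer drops the tensor), so the Python reference is handed to the
// global reference manager instead of being decremented directly.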
DLPackTensor::~DLPackTensor() {
  if (buffer_reference) {
    GlobalPyRefManager()->AddGarbage(
        absl::MakeSpan(&buffer_reference, /*size=*/1));
  }
}

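// Deleter installed on exported DLManagedTensors; reclaims the DLPackTensor
// stored in `manager_ctx`.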
void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

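// Maps an XLA primitive type to its DLPack data type. PRED is exported as an
// 8-bit unsigned integer; complex types have no DLPack equivalent.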
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
      return DLDataType{kDLUInt, 8, 1};
    case C64:
    case C128:
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

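// Inverse of PrimitiveTypeToDLDataType: maps a DLPack data type to an XLA
// primitive type. Only single-lane (non-vectorized) DLPack types are
// supported.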
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the strides for `shape`, in units of elements (not bytes), derived
// from the shape's minor-to-major layout.
std::vector<int64> StridesForShape(const Shape& shape) {
  std::vector<int64> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64 stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

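// Infers an XLA minor-to-major layout from DLPack strides. Only compact
// (possibly transposed) strides are accepted; strides that broadcast or skip
// elements produce an Unimplemented error.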
StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
                                             absl::Span<int64 const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    // If two dimensions have equal strides, treat the degenerate (size-1)
    // dimension as the more minor one; its stride does not affect addressing.
    return dims[a] == 1 && dims[b] != 1;
  });
  int64 stride = 1;
  for (int64 d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

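// Returns the DLPack device type (CPU or GPU) for a PjRtDevice.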
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
  if (device.client()->platform_id() == kCpuId) {
    return kDLCPU;
  } else if (device.client()->platform_id() == kGpuId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

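// Builds the DLContext (device type plus local device ordinal) for `device`.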
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_hardware_id();
  return context;
}

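// Maps a DLContext back to an addressable PjRtDevice, verifying that the
// context's device type matches the client's platform.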
StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
                                         const DLContext& context) {
  switch (context.device_type) {
    case kDLCPU:
      if (client.platform_id() != kCpuId) {
        return InvalidArgument(
            "DLPack CPU device type mismatch with PjRtClient platform %s",
            client.platform_name());
      }
      return client.LookupAddressableDevice(context.device_id);
    case kDLGPU:
      if (client.platform_id() != kGpuId) {
        return InvalidArgument(
            "DLPack GPU device type mismatch with PjRtClient platform %s",
            client.platform_name());
      }
      return client.LookupAddressableDevice(context.device_id);
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
}

}  // namespace

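// Exports `py_buffer` as a DLPack capsule named "dltensor". If
// `take_ownership` is true, the buffer's device memory is donated to the
// capsule; otherwise the capsule holds an external reference to the memory
// plus a Python reference that keeps the buffer alive. In both cases we block
// until outstanding operations on the buffer have completed.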
StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
                                                  bool take_ownership) {
  PyBuffer* buffer = py::cast<PyBuffer*>(py_buffer);
  auto pack = std::make_unique<DLPackTensor>();
  if (buffer->buffer()->on_device_shape().IsTuple()) {
    return Unimplemented(
        "BufferToDLPackManagedTensor is not implemented for tuple "
        "buffers.");
  }

  DLTensor& dt = pack->tensor.dl_tensor;
  if (take_ownership) {
    // Block on outstanding operations, so that it is safe to read or mutate
    // the returned buffer.
    StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> buffer_or =
        buffer->buffer()->ReleaseDeviceMemoryOwnership(
            /*wait_for_operations_to_complete=*/true);
    if (!buffer_or.ok()) {
      return InvalidArgument(
          "Buffer synchronization failed converting to DLPack tensor: %s",
          buffer_or.status().ToString());
    }
    pack->external_reference = buffer_or.ConsumeValueOrDie();
    if (!pack->external_reference) {
      return InvalidArgument(
          "Cannot convert deleted/invalid buffer to DLPack tensor.");
    }
  } else {
    // Block on outstanding operations, so that it is safe to read or mutate
    // the returned buffer.
    TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
    pack->buffer_reference = py::reinterpret_borrow<py::object>(py_buffer);
    TF_ASSIGN_OR_RETURN(pack->external_reference,
                        buffer->buffer()->AcquireExternalReference());
  }
  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
  dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
                          buffer->buffer()->on_device_shape().element_type()));

  pack->shape = std::vector<int64>(
      buffer->buffer()->on_device_shape().dimensions().begin(),
      buffer->buffer()->on_device_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });
  return capsule;
}

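// Imports a DLPack capsule as a PyBuffer that aliases (rather than copies) the
// capsule's device memory. The capsule is renamed to "used_dltensor" so it
// cannot be consumed a second time; the DLPack deleter is invoked when the
// resulting buffer is destroyed.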
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyClient> client) {
  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(
      PjRtDevice * device,
      DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64 const> strides(
        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    // No strides were provided (or the tensor has a zero-sized dimension), so
    // assume a dense row-major layout.
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
                      client->pjrt_client()->CreateViewOfDeviceBuffer(
                          static_cast<char*>(dlmt->dl_tensor.data) +
                              dlmt->dl_tensor.byte_offset,
                          shape, device, on_delete_callback));
  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
                                    Traceback::Get());
}

}  // namespace xla