/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // TF:dlpack
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/platform.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

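// Owner of an exported DLPack tensor: keeps the underlying device buffer
// alive and provides stable storage for the shape/strides arrays referenced
// by `tensor`.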
struct DLPackTensor {
  std::shared_ptr<SharedDeviceBuffer> buffer;
  std::vector<int64> shape;
  std::vector<int64> strides;
  DLManagedTensor tensor;
};

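// Deleter installed as DLManagedTensor::deleter; frees the owning
// DLPackTensor and thus releases its reference to the device buffer.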
void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

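// Converts an XLA primitive type to the corresponding DLPack data type.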
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
    case C64:
    case C128:
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

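// Converts a DLPack data type to the corresponding XLA primitive type.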
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the DLPack strides (in elements, not bytes) for `shape`, derived
// from its layout.
std::vector<int64> StridesForShape(const Shape& shape) {
  std::vector<int64> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64 stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

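// Converts element strides over `dims` into an XLA minor-to-major layout.
// Returns an error for strides that do not describe a compact
// (transposition-only) layout.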
StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
                                             absl::Span<int64 const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    return dims[a] == 1 && dims[b] != 1;
  });
  int64 stride = 1;
  for (int64 d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

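// Returns the DLPack device type for `device` (CPU host or CUDA GPU).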
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
  const se::Platform* platform =
      device.local_device_state()->executor()->platform();
  if (platform->id() == se::host::kHostPlatformId) {
    return kDLCPU;
  } else if (platform->id() == se::cuda::kCudaPlatformId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

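// Builds the DLContext (device type and device ordinal) describing `device`.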
StatusOr<DLContext> DLContextForDevice(const Device& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_device_state()->device_ordinal();
  return context;
}

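// Finds the local device of `client` that matches a DLPack context.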
StatusOr<std::shared_ptr<Device>> DeviceForDLContext(
    const PyLocalClient& client, const DLContext& context) {
  se::Platform::Id platform_id;
  switch (context.device_type) {
    case kDLCPU:
      platform_id = se::host::kHostPlatformId;
      break;
    case kDLGPU:
      platform_id = se::cuda::kCudaPlatformId;
      break;
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
  auto it = absl::c_find_if(
      client.local_devices(), [&](const std::shared_ptr<Device>& device) {
        return device->local_device_state()->executor()->platform()->id() ==
                   platform_id &&
               device->local_device_state()->device_ordinal() ==
                   context.device_id;
      });
  if (it == client.local_devices().end()) {
    return InvalidArgument(
        "No matching device found for DLPack device_type %d device_id %d",
        context.device_type, context.device_id);
  }
  return *it;
}

}  // namespace

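// Exports `buffer` as a DLPack "dltensor" capsule. The capsule shares (does
// not copy) the device memory; the DLPackTensor it owns keeps the buffer
// alive until the consumer invokes the DLManagedTensor deleter.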
StatusOr<py::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer) {
  auto pack = absl::make_unique<DLPackTensor>();
  pack->buffer = buffer->DeviceBuffer();
  if (!pack->buffer) {
    return InvalidArgument(
        "Cannot convert deleted/invalid buffer to DLPack tensor.");
  }
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  DLTensor& dt = pack->tensor.dl_tensor;
  if (buffer->on_device_shape().IsTuple()) {
    return Unimplemented(
        "DLPack conversion is not implemented for tuple buffers.");
  }
  TF_RET_CHECK(pack->buffer->device_memory().size() == 1);
  dt.data = pack->buffer->device_memory().front().opaque();
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->device()));
  dt.ctx.device_id = buffer->device()->local_device_state()->device_ordinal();
  dt.ndim = buffer->on_host_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType(
                                    buffer->on_host_shape().element_type()));

  pack->shape = std::vector<int64>(buffer->on_host_shape().dimensions().begin(),
                                   buffer->on_host_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->on_host_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });

  TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
  return capsule;
}

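// Imports a DLPack "dltensor" capsule as a PyLocalBuffer that aliases the
// tensor's device memory. The capsule is consumed: it is renamed to
// "used_dltensor", and the DLPack deleter is invoked when the resulting
// buffer is destroyed.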
StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client) {
  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
                      DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64 const> strides(
        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);
  se::DeviceMemoryBase buffer(
      static_cast<char*>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset,
      ShapeUtil::ByteSizeOf(shape));

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  auto device_buffer = std::make_shared<SharedDeviceBuffer>(
      /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
      std::initializer_list<se::DeviceMemoryBase>{buffer},
      /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
      /*definition_event=*/nullptr, std::move(on_delete_callback));

  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  return absl::make_unique<PyLocalBuffer>(shape, shape,
                                          std::move(device_buffer),
                                          std::move(client), std::move(device));
}

}  // namespace xla