/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

struct DLPackTensor {
  ~DLPackTensor();

  // `buffer_reference` is populated if we have shared (read-only) access.
  py::object buffer_reference;

  // `external_reference` is always populated.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference;

  std::vector<int64> shape;
  std::vector<int64> strides;
  DLManagedTensor tensor;
};

DLPackTensor::~DLPackTensor() {
  if (buffer_reference) {
    GlobalPyRefManager()->AddGarbage(
        absl::MakeSpan(&buffer_reference, /*size=*/1));
  }
}

void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}
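
// Lifetime note: the "dltensor" capsule produced further below exposes a
// DLManagedTensor whose `deleter` is DLPackTensorDeleter. Running that
// deleter destroys the owning DLPackTensor, which releases
// `external_reference` and, for borrowed buffers, hands `buffer_reference`
// to GlobalPyRefManager() so the Python reference is dropped at a point
// where it is safe to touch the interpreter.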

StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
      return DLDataType{kDLUInt, 8, 1};
    case C64:
    case C128:
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the strides for `shape`.
std::vector<int64> StridesForShape(const Shape& shape) {
  std::vector<int64> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64_t stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}
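
// Worked example (illustrative): for an f32[2,3] array with row-major layout
// (minor_to_major = {1, 0}) the loop above yields strides = {3, 1}; with a
// column-major layout (minor_to_major = {0, 1}) it yields strides = {1, 2}.
// The strides are counted in elements, matching DLTensor::strides.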

StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
                                             absl::Span<int64 const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    return dims[a] == 1 && dims[b] != 1;
  });
  int64_t stride = 1;
  for (int64_t d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}
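
// Worked example (illustrative): dims = {2, 3} with strides = {3, 1} recovers
// minor_to_major = {1, 0} (row-major), and strides = {1, 2} recovers {0, 1}
// (column-major). Non-compact strides such as {0, 1} for dims = {2, 3}
// (a broadcast view) fail the check above and return Unimplemented.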

StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
  if (device.client()->platform_id() == kCpuId) {
    return kDLCPU;
  } else if (device.client()->platform_id() == kGpuId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_hardware_id();
  return context;
}

StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient* cpu_client,
                                         const PjRtClient* gpu_client,
                                         const DLContext& context) {
  switch (context.device_type) {
    case kDLCPU:
      if (cpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on CPU, but no CPU backend was provided.");
      }
      TF_RET_CHECK(cpu_client->platform_id() == kCpuId);
      return cpu_client->LookupAddressableDevice(context.device_id);
    case kDLGPU:
      if (gpu_client == nullptr) {
        return InvalidArgument(
            "DLPack tensor is on GPU, but no GPU backend was provided.");
      }
      TF_RET_CHECK(gpu_client->platform_id() == kGpuId);
      return gpu_client->LookupAddressableDevice(context.device_id);
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
}

}  // namespace

StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
                                                  bool take_ownership) {
  TF_ASSIGN_OR_RETURN(PyBuffer * buffer, PyBuffer::AsPyBuffer(py_buffer));
  auto pack = std::make_unique<DLPackTensor>();
  if (buffer->buffer()->on_device_shape().IsTuple()) {
    return Unimplemented(
        "unsafe_buffer_pointer is not implemented for tuple "
        "buffers.");
  }
  if (buffer->buffer()->on_device_shape().is_dynamic()) {
    return Unimplemented("DynamicShape is not implemented in DLPack.");
  }

  DLTensor& dt = pack->tensor.dl_tensor;
  if (take_ownership) {
    // Block on outstanding operations, so that it is safe to read or mutate the
    // returned buffer.
    StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> buffer_or =
        buffer->buffer()->ReleaseDeviceMemoryOwnership(
            /*wait_for_operations_to_complete=*/true);
    if (!buffer_or.ok()) {
      return InvalidArgument(
          "Buffer synchronization failed converting to DLPack tensor: %s",
          buffer_or.status().ToString());
    }
    pack->external_reference = buffer_or.ConsumeValueOrDie();
    if (!pack->external_reference) {
      return InvalidArgument(
          "Cannot convert deleted/invalid buffer to DLPack tensor.");
    }
  } else {
    // Block on outstanding operations, so that it is safe to read or mutate the
    // returned buffer.
    TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
    pack->buffer_reference = py::reinterpret_borrow<py::object>(py_buffer);
    TF_ASSIGN_OR_RETURN(pack->external_reference,
                        buffer->buffer()->AcquireExternalReference());
  }
  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
  dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
                          buffer->buffer()->on_device_shape().element_type()));

  pack->shape = std::vector<int64>(
      buffer->buffer()->on_device_shape().dimensions().begin(),
      buffer->buffer()->on_device_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });
  return capsule;
}
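
// Usage sketch (illustrative; `py_buffer` and `consumer` are placeholders,
// not part of this file). With take_ownership=true the device memory is
// donated and the source buffer can no longer be used from Python; with
// take_ownership=false the capsule only borrows the buffer and keeps
// `py_buffer` alive until the capsule's deleter runs.
//
//   TF_ASSIGN_OR_RETURN(
//       py::capsule capsule,
//       BufferToDLPackManagedTensor(py_buffer, /*take_ownership=*/false));
//   consumer.attr("from_dlpack")(capsule);  // hypothetical DLPack consumer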

StatusOr<PyBuffer::object> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyClient> cpu_client,
    std::shared_ptr<PyClient> gpu_client) {
  // Backward compatibility: if only one client is passed, it may be from any
  // platform. Drop this support after dropping support for jax <= 0.2.14.
  if (cpu_client && cpu_client->pjrt_client()->platform_id() == kGpuId) {
    gpu_client = std::move(cpu_client);
    cpu_client = nullptr;
  }
  if (cpu_client && cpu_client->pjrt_client()->platform_id() != kCpuId) {
    return InvalidArgument("DLPack does not support platform %s",
                           cpu_client->pjrt_client()->platform_name());
  }

  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(
      PjRtDevice * device,
      DeviceForDLContext(cpu_client ? cpu_client->pjrt_client() : nullptr,
                         gpu_client ? gpu_client->pjrt_client() : nullptr,
                         dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64 const> strides(
        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
                      device->client()->CreateViewOfDeviceBuffer(
                          static_cast<char*>(dlmt->dl_tensor.data) +
                              dlmt->dl_tensor.byte_offset,
                          shape, device, on_delete_callback));
  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  // TODO(phawkins): simplify the expression below once we know cpu_client is
  // always non-null.
  return PyBuffer::Make(
      (cpu_client && device->client() == cpu_client->pjrt_client())
          ? std::move(cpu_client)
          : std::move(gpu_client),
      std::move(pjrt_buffer), Traceback::Get());
}
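
// Usage sketch (illustrative; `capsule` and the client handles are
// placeholders). Any capsule named "dltensor" produced by a DLPack exporter
// can be wrapped as an XLA buffer without copying; the capsule is renamed to
// "used_dltensor" above, so it can be consumed at most once.
//
//   TF_ASSIGN_OR_RETURN(
//       PyBuffer::object buffer,
//       DLPackManagedTensorToBuffer(capsule, cpu_client,
//                                   /*gpu_client=*/nullptr));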

}  // namespace xla