/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

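// Owns the resources behind an exported DLManagedTensor: a reference that
// keeps the device memory alive plus the shape/strides arrays that the
// DLTensor points into.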
struct DLPackTensor {
  ~DLPackTensor();

  // `buffer_reference` is populated if we have shared (read-only) access.
  py::object buffer_reference;

  // `external_reference` is always populated.
  std::unique_ptr<PjRtBuffer::ExternalReference> external_reference;

  std::vector<int64> shape;
  std::vector<int64> strides;
  DLManagedTensor tensor;
};

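// The DLPack consumer may drop the tensor on a thread that does not hold the
// Python GIL, so the Python buffer reference is handed to GlobalPyRefManager()
// for deferred destruction rather than being released here directly.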
DLPackTensor::~DLPackTensor() {
  if (buffer_reference) {
    GlobalPyRefManager()->AddGarbage(
        absl::MakeSpan(&buffer_reference, /*size=*/1));
  }
}

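// Deleter installed in DLManagedTensor::deleter: destroys the DLPackTensor
// that owns the exported buffer.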
void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

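// Maps an XLA primitive type to the corresponding DLPack data type. Complex
// types (C64/C128) have no DLPack equivalent and are rejected.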
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
      return DLDataType{kDLUInt, 8, 1};
    case C64:
    case C128:
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

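// Inverse of PrimitiveTypeToDLDataType: maps a DLPack data type back to an XLA
// primitive type. Vectorized types (lanes != 1) are not supported. Note that
// PRED cannot be round-tripped: it is exported as kDLUInt/8 and imports as U8.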
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the strides for `shape`.
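// Strides are expressed in elements (not bytes), following the DLPack
// convention. For example, a row-major f32[2,3] shape (layout {1,0}) yields
// strides {3, 1}.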
std::vector<int64> StridesForShape(const Shape& shape) {
  std::vector<int64> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64 stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

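// Inverse of StridesForShape: reconstructs a minor-to-major layout from dense
// (compact) strides. For example, dims [2,3] with strides [3,1] gives layout
// {1,0}, while strides [1,2] gives {0,1}; strides that imply broadcasting or
// padding are rejected.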
StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
                                             absl::Span<int64 const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    return dims[a] == 1 && dims[b] != 1;
  });
  int64 stride = 1;
  for (int64 d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

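// Maps a PjRt device to the DLPack device type; only CPU and GPU platforms
// can be exported.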
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
  if (device.client()->platform_id() == kCpuId) {
    return kDLCPU;
  } else if (device.client()->platform_id() == kGpuId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

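// Builds the DLContext (device type + local device ordinal) describing where
// the exported tensor lives.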
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_hardware_id();
  return context;
}

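// Resolves the PjRt device addressed by a DLContext, checking that the
// capsule's device type matches the client's platform.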
StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
                                         const DLContext& context) {
  switch (context.device_type) {
    case kDLCPU:
      if (client.platform_id() != kCpuId) {
        return InvalidArgument(
            "DLPack CPU device type mismatch with PjRtClient platform %s",
            client.platform_name());
      }
      return client.LookupAddressableDevice(context.device_id);
    case kDLGPU:
      if (client.platform_id() != kGpuId) {
        return InvalidArgument(
            "DLPack GPU device type mismatch with PjRtClient platform %s",
            client.platform_name());
      }
      return client.LookupAddressableDevice(context.device_id);
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
}

}  // namespace

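// Exports a PyBuffer as a DLPack capsule named "dltensor". If `take_ownership`
// is true, the buffer's device memory is donated to the capsule; otherwise the
// capsule holds a read-only external reference and keeps the Python buffer
// alive until the consumer's deleter runs.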
StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
                                                  bool take_ownership) {
  PyBuffer* buffer = py::cast<PyBuffer*>(py_buffer);
  auto pack = std::make_unique<DLPackTensor>();
  if (buffer->buffer()->on_device_shape().IsTuple()) {
    return Unimplemented(
        "BufferToDLPackManagedTensor is not implemented for tuple "
        "buffers.");
  }

  DLTensor& dt = pack->tensor.dl_tensor;
  if (take_ownership) {
    // Block on outstanding operations, so that it is safe to read or mutate the
    // returned buffer.
    StatusOr<std::unique_ptr<PjRtBuffer::ExternalReference>> buffer_or =
        buffer->buffer()->ReleaseDeviceMemoryOwnership(
            /*wait_for_operations_to_complete=*/true);
    if (!buffer_or.ok()) {
      return InvalidArgument(
          "Buffer synchronization failed converting to DLPack tensor: %s",
          buffer_or.status().ToString());
    }
    pack->external_reference = buffer_or.ConsumeValueOrDie();
    if (!pack->external_reference) {
      return InvalidArgument(
          "Cannot convert deleted/invalid buffer to DLPack tensor.");
    }
  } else {
    // Block on outstanding operations, so that it is safe to read or mutate the
    // returned buffer.
    TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
    pack->buffer_reference = py::reinterpret_borrow<py::object>(py_buffer);
    TF_ASSIGN_OR_RETURN(pack->external_reference,
                        buffer->buffer()->AcquireExternalReference());
  }
  dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer();
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
  dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
                          buffer->buffer()->on_device_shape().element_type()));

  pack->shape = std::vector<int64>(
      buffer->buffer()->on_device_shape().dimensions().begin(),
      buffer->buffer()->on_device_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });
  return capsule;
}

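// Imports a DLPack capsule as a PyBuffer that aliases the producer's memory
// (zero copy). On success the capsule is renamed to "used_dltensor" so it
// cannot be consumed a second time, and the producer's deleter is invoked when
// the resulting buffer is destroyed.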
StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyClient> client) {
  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(
      PjRtDevice * device,
      DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64 const> strides(
        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
                      client->pjrt_client()->CreateViewOfDeviceBuffer(
                          static_cast<char*>(dlmt->dl_tensor.data) +
                              dlmt->dl_tensor.byte_offset,
                          shape, device, on_delete_callback));
  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  return std::make_unique<PyBuffer>(std::move(client), std::move(pjrt_buffer),
                                    Traceback::Get());
}

}  // namespace xla