/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/dlpack.h"

#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "include/dlpack/dlpack.h"  // TF:dlpack
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/platform.h"

namespace py = pybind11;

namespace xla {
namespace {

const char* const kDlTensorCapsuleName = "dltensor";

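// Owns the data backing an exported DLPack tensor: a reference to the XLA
// device buffer plus the shape/stride arrays that the DLTensor points into.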
struct DLPackTensor {
  std::shared_ptr<SharedDeviceBuffer> buffer;
  std::vector<int64> shape;
  std::vector<int64> strides;
  DLManagedTensor tensor;
};

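// Deleter installed on the DLManagedTensor; frees the DLPackTensor (and with
// it the reference to the XLA device buffer) when the consumer is done.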
void DLPackTensorDeleter(DLManagedTensor* t) {
  if (t) {
    delete static_cast<DLPackTensor*>(t->manager_ctx);
  }
}

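// Maps an XLA primitive type to the corresponding DLPack data type.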
StatusOr<DLDataType> PrimitiveTypeToDLDataType(PrimitiveType type) {
  switch (type) {
    case S8:
      return DLDataType{kDLInt, 8, 1};
    case S16:
      return DLDataType{kDLInt, 16, 1};
    case S32:
      return DLDataType{kDLInt, 32, 1};
    case S64:
      return DLDataType{kDLInt, 64, 1};
    case U8:
      return DLDataType{kDLUInt, 8, 1};
    case U16:
      return DLDataType{kDLUInt, 16, 1};
    case U32:
      return DLDataType{kDLUInt, 32, 1};
    case U64:
      return DLDataType{kDLUInt, 64, 1};
    case F16:
      return DLDataType{kDLFloat, 16, 1};
    case F32:
      return DLDataType{kDLFloat, 32, 1};
    case F64:
      return DLDataType{kDLFloat, 64, 1};
    case BF16:
      return DLDataType{kDLBfloat, 16, 1};
    case PRED:
    case C64:
    case C128:
    default:
      return Unimplemented("XLA type %s has no DLPack equivalent",
                           PrimitiveType_Name(type));
  }
}

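// Maps a DLPack data type back to an XLA primitive type; vectorized types
// (lanes != 1) are rejected.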
StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
  if (type.lanes != 1) {
    return Unimplemented("DLPack types with lanes != 1 not implemented, got %d",
                         type.lanes);
  }
  switch (type.code) {
    case kDLInt:
      switch (type.bits) {
        case 8:
          return S8;
        case 16:
          return S16;
        case 32:
          return S32;
        case 64:
          return S64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack integer width: %d bits",
              type.bits);
      }
    case kDLUInt:
      switch (type.bits) {
        case 8:
          return U8;
        case 16:
          return U16;
        case 32:
          return U32;
        case 64:
          return U64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack unsigned integer width: %d bits",
              type.bits);
      }
    case kDLFloat:
      switch (type.bits) {
        case 16:
          return F16;
        case 32:
          return F32;
        case 64:
          return F64;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack float width: %d bits", type.bits);
      }
    case kDLBfloat:
      switch (type.bits) {
        case 16:
          return BF16;
        default:
          return Unimplemented(
              "Invalid or unsupported DLPack Bfloat width: %d bits", type.bits);
      }
    default:
      return Unimplemented("Unknown or invalid DLPack type code %d", type.code);
  }
}

// Returns the strides for `shape`, in units of elements (as DLPack expects),
// derived from the shape's minor-to-major layout.
std::vector<int64> StridesForShape(const Shape& shape) {
  std::vector<int64> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  int64 stride = 1;
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

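// Infers an XLA minor-to-major layout from DLPack dimensions and strides.
// Only compact (non-broadcasting) strides are accepted; anything else is
// rejected as Unimplemented.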
StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
                                             absl::Span<int64 const> strides) {
  CHECK_EQ(dims.size(), strides.size());
  std::vector<int64> minor_to_major(dims.size());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  absl::c_sort(minor_to_major, [&](int a, int b) {
    if (strides[a] < strides[b]) {
      return true;
    }
    if (strides[a] > strides[b]) {
      return false;
    }
    return dims[a] == 1 && dims[b] != 1;
  });
  int64 stride = 1;
  for (int64 d : minor_to_major) {
    if (strides[d] != stride) {
      return Unimplemented(
          "Only DLPack tensors with trivial (compact) striding are supported; "
          "i.e., tensors whose striding represents a transposition of the "
          "underlying buffer but not broadcasting. Dimensions were: [%s], "
          "strides were [%s].",
          absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
    }
    stride *= dims[d];
  }
  return minor_to_major;
}

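// Returns the DLPack device type (CPU or GPU) for an XLA device.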
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
  const se::Platform* platform =
      device.local_device_state()->executor()->platform();
  if (platform->id() == se::host::kHostPlatformId) {
    return kDLCPU;
  } else if (platform->id() == se::cuda::kCudaPlatformId) {
    return kDLGPU;
  }
  return InvalidArgument("Device %s cannot be used as a DLPack device.",
                         device.DebugString());
}

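// Builds the DLContext (device type + device ordinal) for an XLA device.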
StatusOr<DLContext> DLContextForDevice(const Device& device) {
  DLContext context;
  TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
  context.device_id = device.local_device_state()->device_ordinal();
  return context;
}

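// Looks up the client's local device that matches a DLContext.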
StatusOr<std::shared_ptr<Device>> DeviceForDLContext(
    const PyLocalClient& client, const DLContext& context) {
  se::Platform::Id platform_id;
  switch (context.device_type) {
    case kDLCPU:
      platform_id = se::host::kHostPlatformId;
      break;
    case kDLGPU:
      platform_id = se::cuda::kCudaPlatformId;
      break;
    default:
      return InvalidArgument("Unknown/unsupported DLPack device type %d",
                             context.device_type);
  }
  auto it = absl::c_find_if(
      client.local_devices(), [&](const std::shared_ptr<Device>& device) {
        return device->local_device_state()->executor()->platform()->id() ==
                   platform_id &&
               device->local_device_state()->device_ordinal() ==
                   context.device_id;
      });
  if (it == client.local_devices().end()) {
    return InvalidArgument(
        "No matching device found for DLPack device_type %d device_id %d",
        context.device_type, context.device_id);
  }
  return *it;
}

}  // namespace

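// Exports a PyLocalBuffer as a DLPack managed tensor wrapped in a Python
// capsule named "dltensor". The capsule keeps the device buffer alive until
// the consumer invokes the deleter.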
StatusOr<py::capsule> BufferToDLPackManagedTensor(PyLocalBuffer* buffer) {
  auto pack = absl::make_unique<DLPackTensor>();
  pack->buffer = buffer->DeviceBuffer();
  if (!pack->buffer) {
    return InvalidArgument(
        "Cannot convert deleted/invalid buffer to DLPack tensor.");
  }
  pack->tensor.manager_ctx = pack.get();
  pack->tensor.deleter = DLPackTensorDeleter;
  DLTensor& dt = pack->tensor.dl_tensor;
  if (buffer->on_device_shape().IsTuple()) {
    return Unimplemented(
        "DLPack conversion is not implemented for tuple buffers.");
  }
  TF_RET_CHECK(pack->buffer->device_memory().size() == 1);
  dt.data = pack->buffer->device_memory().front().opaque();
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->device()));
  dt.ctx.device_id = buffer->device()->local_device_state()->device_ordinal();
  dt.ndim = buffer->on_host_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType(
                                    buffer->on_host_shape().element_type()));

  pack->shape = std::vector<int64>(buffer->on_host_shape().dimensions().begin(),
                                   buffer->on_host_shape().dimensions().end());
  pack->strides = StridesForShape(buffer->on_host_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;

  py::capsule capsule(&pack.release()->tensor, kDlTensorCapsuleName,
                      [](PyObject* obj) {
                        DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(
                            PyCapsule_GetPointer(obj, kDlTensorCapsuleName));
                        if (dlmt) {
                          DLPackTensorDeleter(dlmt);
                        } else {
                          // The tensor has been deleted. Clear any error from
                          // PyCapsule_GetPointer.
                          PyErr_Clear();
                        }
                      });

  TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
  return capsule;
}

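// Imports a DLPack capsule as a PyLocalBuffer that aliases the original
// memory. The capsule is renamed to "used_dltensor" so it cannot be consumed
// twice.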
StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
    const pybind11::capsule& tensor, std::shared_ptr<PyLocalClient> client) {
  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
    return InvalidArgument(
        "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
        "Note that a DLPack tensor may be consumed at most once.",
        absl::string_view(tensor.name()));
  }
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor);
  if (dlmt->dl_tensor.ndim < 0) {
    return InvalidArgument(
        "Number of dimensions in DLManagedTensor must be nonnegative, got %d",
        dlmt->dl_tensor.ndim);
  }
  TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
                      DeviceForDLContext(*client, dlmt->dl_tensor.ctx));
  absl::Span<int64 const> dimensions(
      reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
  TF_ASSIGN_OR_RETURN(PrimitiveType element_type,
                      DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype));

  std::vector<int64> minor_to_major;
  if (dlmt->dl_tensor.strides &&
      absl::c_find(dimensions, 0) == dimensions.end()) {
    absl::Span<int64 const> strides(
        reinterpret_cast<int64*>(dlmt->dl_tensor.strides),
        dlmt->dl_tensor.ndim);
    TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides));
  } else {
    minor_to_major.resize(dlmt->dl_tensor.ndim);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  }
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(element_type, dimensions, minor_to_major);
  se::DeviceMemoryBase buffer(
      static_cast<char*>(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset,
      ShapeUtil::ByteSizeOf(shape));

  std::function<void()> on_delete_callback;
  if (dlmt->deleter) {
    on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
  }
  auto device_buffer = std::make_shared<SharedDeviceBuffer>(
      /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
      std::initializer_list<se::DeviceMemoryBase>{buffer},
      /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
      /*definition_event=*/nullptr, std::move(on_delete_callback));

  // We have taken ownership of the array inside the capsule; make sure the
  // capsule cannot be used again.
  PyCapsule_SetName(tensor.ptr(), "used_dltensor");
  PyCapsule_SetDestructor(tensor.ptr(), nullptr);
  return absl::make_unique<PyLocalBuffer>(shape, shape,
                                          std::move(device_buffer),
                                          std::move(client), std::move(device));
}


}  // namespace xla