/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/dlpack.h"

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

namespace {

// Owning context for a DLManagedTensor: manages the lifetime of the
// DLManagedTensor. When DLManagedTensor::deleter is called, the original
// framework is notified of the destruction and this context is deleted as
// well.
struct TfDlManagedTensorCtx {
  TensorReference reference;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  DLManagedTensor tensor;

  explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};
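
// Note: shape and strides are stored in the context rather than on the stack
// so that the pointers placed in dl_tensor below remain valid for the whole
// lifetime of the DLManagedTensor.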

// Gets the tensor from an eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
  if (handle->Type() != TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "DLPack doesn't support ", handle->TypeString(), " tensors");
    return nullptr;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return tensor;
}
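
// Only LOCAL handles (tensors already materialized in this process) can be
// exported; other handle types (e.g. remote tensors) are rejected above.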

// Deleter for DLManagedTensor: unrefs the underlying tensor buffer and frees
// the owning context.
void DLManagedTensorDeleter(DLManagedTensor* arg) {
  TfDlManagedTensorCtx* owner =
      static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
  owner->reference.Unref();
  delete owner;
}
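
// This deleter is installed on every DLManagedTensor produced by
// TFE_HandleToDLPack below.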

// Converts a TF_DataType to a DLPack data type. TF_DataTypeSize returns the
// element size in bytes, so bits = size * 8.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = TF_DataTypeSize(data_type) * 8;
  switch (data_type) {
    case TF_DataType::TF_HALF:
    case TF_DataType::TF_FLOAT:
    case TF_DataType::TF_DOUBLE:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case TF_DataType::TF_INT8:
    case TF_DataType::TF_INT16:
    case TF_DataType::TF_INT32:
    case TF_DataType::TF_INT64:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case TF_DataType::TF_BOOL:
    case TF_DataType::TF_UINT8:
    case TF_DataType::TF_UINT16:
    case TF_DataType::TF_UINT32:
    case TF_DataType::TF_UINT64:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;
    case TF_DataType::TF_BFLOAT16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;
    default:
      status->status = tensorflow::errors::InvalidArgument(
          DataType_Name(static_cast<DataType>(data_type)),
          " is not supported by dlpack");
      break;
  }
  return dtype;
}
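
// For example, TF_FLOAT becomes {code = kDLFloat, bits = 32, lanes = 1} and
// TF_UINT8 becomes {code = kDLUInt, bits = 8, lanes = 1}.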

// Gets DLPack's DLContext from an eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
  DLContext ctx;
  const char* device_name =
      tensorflow::unwrap(h)->BackingDeviceName(&status->status);
  DeviceNameUtils::ParsedName parsed_name;
  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
  std::string device_type = parsed_name.type;
  int device_id = 0;
  if (parsed_name.has_id) {
    device_id = parsed_name.id;
  }

  ctx.device_id = device_id;
  if (device_type == "CPU") {
    ctx.device_type = DLDeviceType::kDLCPU;
  } else if (device_type == "GPU") {
    ctx.device_type = DLDeviceType::kDLGPU;
  } else {
    status->status = tensorflow::errors::InvalidArgument(
        "Unsupported Device Type for dlpack");
  }

  return ctx;
}
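
// For example, a handle backed by the device
// "/job:localhost/replica:0/task:0/device:GPU:0" parses to type "GPU" and
// id 0, yielding {kDLGPU, 0}.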

// Converts a DLContext to a TF device name. Unsupported device types are
// reported via the absl::nullopt return value; `status` is currently unused.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
                                                    TF_Status* status) {
  switch (ctx.device_type) {
    case DLDeviceType::kDLCPU:
      return "CPU:0";
    case DLDeviceType::kDLGPU:
      return absl::StrCat("GPU:", ctx.device_id);
    default:
      return absl::nullopt;
  }
}

// Converts a DLPack data type to a TF_DataType.
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
                                TF_DataType* tf_dtype) {
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_UINT8;
          return Status::OK();
        case 16:
          *tf_dtype = TF_DataType::TF_UINT16;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_UINT32;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_UINT64;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
                                                     dtype.bits);
      }
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_INT8;
          return Status::OK();
        case 16:
          *tf_dtype = TF_DataType::TF_INT16;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_INT32;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_INT64;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
                                                     dtype.bits);
      }
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_HALF;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_FLOAT;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_DOUBLE;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
                                                     dtype.bits);
      }
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_BFLOAT16;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument(
              "Unsupported BFloat bits: ", dtype.bits);
      }
    default:
      return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
                                                 dtype.code);
  }
}
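
// For example, {code = kDLInt, bits = 32, lanes = 1} maps back to TF_INT32,
// the inverse of GetDlDataType above.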

// Wraps the deleter of a DLManagedTensor to match the deallocator signature
// expected by TFE_NewTensorHandleFromDeviceMemory.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
  TFE_CallDLManagedTensorDeleter(dlmt_vptr);
}
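
// `data` and `len` are intentionally ignored: the DLManagedTensor's own
// deleter already knows which buffer to release.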

// Checks whether the stride array matches the layout of compact, row-major
// data.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
                                      int ndim) {
  // The innermost dimension must have stride 1.
  if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
    return false;
  }
  // Each outer stride must equal the product of the next inner dimension's
  // size and stride.
  for (int i = ndim - 2; i >= 0; --i) {
    if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
      return false;
    }
  }
  return true;
}
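
// For example, shape [2, 3, 4] with strides [12, 4, 1] is compact row-major,
// while strides [16, 4, 1] (row padding) would be rejected.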
}  // namespace

void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
  DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
  if (dlMTensor->deleter != nullptr) {
    dlMTensor->deleter(dlMTensor);
  }
}

void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
  auto tf_dlm_context = GetDlContext(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  const Tensor* tensor = GetTensorFromHandle(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());

  auto tf_dlm_type = GetDlDataType(data_type, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref().
  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);

  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
  dlm_tensor->deleter = &DLManagedTensorDeleter;
  dlm_tensor->dl_tensor.ctx = tf_dlm_context;
  int ndim = tensor->dims();
  dlm_tensor->dl_tensor.ndim = ndim;
  dlm_tensor->dl_tensor.data = tf_dlm_data;
  dlm_tensor->dl_tensor.dtype = tf_dlm_type;

  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
  shape_arr->resize(ndim);
  stride_arr->resize(ndim, 1);
  for (int i = 0; i < ndim; i++) {
    (*shape_arr)[i] = tensor->dim_size(i);
  }
  // Compute compact row-major strides from the innermost dimension outward.
  for (int i = ndim - 2; i >= 0; --i) {
    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
  }

  dlm_tensor->dl_tensor.shape = shape_arr->data();
  // There are two ways to represent compact row-major data:
  // 1) a null strides pointer, indicating the tensor is compact and
  //    row-major, or
  // 2) a fully populated strides array describing that same layout.
  // We choose option 2 here, since some frameworks don't handle a null
  // strides argument properly.
  dlm_tensor->dl_tensor.strides = stride_arr->data();

  // TF doesn't support nonzero byte offsets here.
  dlm_tensor->dl_tensor.byte_offset = 0;
  return static_cast<void*>(dlm_tensor);
}
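
// A consumer that receives this pointer reads the DLTensor fields and, once
// done with the buffer, calls dlm_tensor->deleter(dlm_tensor) (or
// TFE_CallDLManagedTensorDeleter above), which unrefs the TF buffer.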

TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
                                       TFE_Context* ctx) {
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
  DLTensor* dl_tensor = &dlmt->dl_tensor;
  absl::optional<std::string> device_name =
      DeviceNameFromDlContext(dl_tensor->ctx, status);
  if (!device_name.has_value()) {
    status->status =
        tensorflow::errors::InvalidArgument("Unsupported Device Type");
    return nullptr;
  }
  TF_DataType dtype;
  Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
  if (!s.ok()) {
    status->status = std::move(s);
    return nullptr;
  }
  int num_dims = dl_tensor->ndim;
  const int64_t* dims = dl_tensor->shape;
  void* data = dl_tensor->data;

  // Total size in bytes: element size times the number of elements.
  size_t total_bytes = dl_tensor->dtype.bits / 8;
  for (int i = 0; i < num_dims; i++) {
    total_bytes *= dims[i];
  }

  // A null strides pointer means compact row-major; otherwise the strides
  // must describe that same layout, since TF tensors don't carry strides.
  if (dl_tensor->strides != nullptr &&
      !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
                                        num_dims)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid strides array from DLPack");
    return nullptr;
  }

  TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
      ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
      total_bytes, &DeallocatorWrapperFunc, dlmt, status);

  return handle;
}
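
// Round-trip sketch (assuming a live TFE_Context* ctx and TF_Status* status):
//
//   void* dlm = TFE_HandleToDLPack(h, status);
//   TFE_TensorHandle* h2 = TFE_HandleFromDLPack(dlm, status, ctx);
//
// h2 borrows the DLPack buffer; when h2 is deleted, DeallocatorWrapperFunc
// runs and invokes the DLManagedTensor's deleter, releasing the original
// reference.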

}  // namespace tensorflow