• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/dlpack.h"
17 
18 #include <string>
19 
20 #include "include/dlpack/dlpack.h"  // from @dlpack
21 #include "tensorflow/c/eager/c_api.h"
22 #include "tensorflow/c/eager/c_api_experimental.h"
23 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
24 #include "tensorflow/c/tf_status_internal.h"
25 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_reference.h"
28 #include "tensorflow/core/platform/logging.h"
29 
30 namespace tensorflow {
31 
32 namespace {
33 
34 // Managing context for the DLManagedTensor, will manage the lifetime of
35 // DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
36 // original framework of destruction, and this context will be deleted also.
37 struct TfDlManagedTensorCtx {
38   TensorReference reference;
39   std::vector<int64_t> shape;
40   std::vector<int64_t> strides;
41   DLManagedTensor tensor;
42 
TfDlManagedTensorCtxtensorflow::__anon9d85d4420111::TfDlManagedTensorCtx43   explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
44 };
45 
46 // Gets tensor from eager tensor handle.
GetTensorFromHandle(TFE_TensorHandle * h,TF_Status * status)47 const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
48   if (h == nullptr) {
49     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
50     return nullptr;
51   }
52   tensorflow::TensorHandle* handle =
53       tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
54   if (handle->Type() != TensorHandle::LOCAL) {
55     status->status = tensorflow::errors::InvalidArgument(
56         "DLPack doesn't support ", handle->TypeString(), " tensor");
57     return nullptr;
58   }
59   const tensorflow::Tensor* tensor;
60   status->status = handle->Tensor(&tensor);
61   if (!status->status.ok()) {
62     return nullptr;
63   }
64   return tensor;
65 }
66 
67 // Deleter for DLManagedTensor
DLManagedTensorDeleter(DLManagedTensor * arg)68 void DLManagedTensorDeleter(DLManagedTensor* arg) {
69   TfDlManagedTensorCtx* owner =
70       static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
71   owner->reference.Unref();
72   delete owner;
73 }
74 
75 // Converts TF_DATAType to DLPack data type.
GetDlDataType(TF_DataType data_type,TF_Status * status)76 DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
77   DLDataType dtype;
78   dtype.lanes = 1;
79   dtype.bits = TF_DataTypeSize(data_type) * 8;
80   switch (data_type) {
81     case TF_DataType::TF_HALF:
82     case TF_DataType::TF_FLOAT:
83     case TF_DataType::TF_DOUBLE:
84       dtype.code = DLDataTypeCode::kDLFloat;
85       break;
86     case TF_DataType::TF_INT8:
87     case TF_DataType::TF_INT16:
88     case TF_DataType::TF_INT32:
89     case TF_DataType::TF_INT64:
90       dtype.code = DLDataTypeCode::kDLInt;
91       break;
92     case TF_DataType::TF_BOOL:
93     case TF_DataType::TF_UINT8:
94     case TF_DataType::TF_UINT16:
95     case TF_DataType::TF_UINT32:
96     case TF_DataType::TF_UINT64:
97       dtype.code = DLDataTypeCode::kDLUInt;
98       break;
99     case TF_DataType::TF_BFLOAT16:
100       dtype.code = DLDataTypeCode::kDLBfloat;
101       break;
102     case TF_DataType::TF_COMPLEX64:
103     case TF_DataType::TF_COMPLEX128:
104       dtype.code = DLDataTypeCode::kDLComplex;
105       break;
106     default:
107       status->status = tensorflow::errors::InvalidArgument(
108           DataType_Name(static_cast<DataType>(data_type)),
109           " is not supported by dlpack");
110       break;
111   }
112   return dtype;
113 }
114 
115 // Gets DLPack's DLDevice from eager tensor handle.
GetDlContext(TFE_TensorHandle * h,TF_Status * status)116 DLDevice GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
117   DLDevice ctx;
118   const char* device_name =
119       tensorflow::unwrap(h)->BackingDeviceName(&status->status);
120   DeviceNameUtils::ParsedName parsed_name;
121   tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
122   std::string device_type = parsed_name.type;
123   int device_id = 0;
124   if (parsed_name.has_id) {
125     device_id = parsed_name.id;
126   }
127 
128   ctx.device_id = device_id;
129   if (device_type == "CPU") {
130     ctx.device_type = DLDeviceType::kDLCPU;
131   } else if (device_type == "GPU") {
132     ctx.device_type = DLDeviceType::kDLCUDA;
133   } else {
134     status->status = tensorflow::errors::InvalidArgument(
135         "Unsupported Device Type for dlpack");
136   }
137 
138   return ctx;
139 }
140 
141 // Converts DLDevice to TF device name.
DeviceNameFromDlContext(const DLDevice & ctx,TF_Status * status)142 absl::optional<std::string> DeviceNameFromDlContext(const DLDevice& ctx,
143                                                     TF_Status* status) {
144   switch (ctx.device_type) {
145     case DLDeviceType::kDLCPU:
146       return "CPU:0";
147     case DLDeviceType::kDLCUDA:
148       return absl::StrCat("GPU:", ctx.device_id);
149     default:
150       return absl::nullopt;
151   }
152 }
153 
154 // Converts DLPack data type to TF_DATATYPE.
TfDataTypeFormDlDataType(const DLDataType & dtype,TF_DataType * tf_dtype)155 Status TfDataTypeFormDlDataType(const DLDataType& dtype,
156                                 TF_DataType* tf_dtype) {
157   switch (dtype.code) {
158     case DLDataTypeCode::kDLUInt:
159       switch (dtype.bits) {
160         case 8:
161           *tf_dtype = TF_DataType::TF_UINT8;
162           return OkStatus();
163         case 16:
164           *tf_dtype = TF_DataType::TF_UINT16;
165           return OkStatus();
166         case 32:
167           *tf_dtype = TF_DataType::TF_UINT32;
168           return OkStatus();
169         case 64:
170           *tf_dtype = TF_DataType::TF_UINT64;
171           return OkStatus();
172         default:
173           return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
174                                                      dtype.bits);
175       }
176       return OkStatus();
177     case DLDataTypeCode::kDLInt:
178       switch (dtype.bits) {
179         case 8:
180           *tf_dtype = TF_DataType::TF_INT8;
181           return OkStatus();
182         case 16:
183           *tf_dtype = TF_DataType::TF_INT16;
184           return OkStatus();
185         case 32:
186           *tf_dtype = TF_DataType::TF_INT32;
187           return OkStatus();
188         case 64:
189           *tf_dtype = TF_DataType::TF_INT64;
190           return OkStatus();
191         default:
192           return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
193                                                      dtype.bits);
194       }
195       return OkStatus();
196     case DLDataTypeCode::kDLFloat:
197       switch (dtype.bits) {
198         case 16:
199           *tf_dtype = TF_DataType::TF_HALF;
200           return OkStatus();
201         case 32:
202           *tf_dtype = TF_DataType::TF_FLOAT;
203           return OkStatus();
204         case 64:
205           *tf_dtype = TF_DataType::TF_DOUBLE;
206           return OkStatus();
207         default:
208           return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
209                                                      dtype.bits);
210       }
211       break;
212     case DLDataTypeCode::kDLBfloat:
213       switch (dtype.bits) {
214         case 16:
215           *tf_dtype = TF_DataType::TF_BFLOAT16;
216           return OkStatus();
217         default:
218           return tensorflow::errors::InvalidArgument(
219               "Unsupported BFloat bits: ", dtype.bits);
220       }
221       break;
222     case DLDataTypeCode::kDLComplex:
223       switch (dtype.bits) {
224         case 64:
225           *tf_dtype = TF_DataType::TF_COMPLEX64;
226           return OkStatus();
227         case 128:
228           *tf_dtype = TF_DataType::TF_COMPLEX128;
229           return OkStatus();
230         default:
231           return tensorflow::errors::InvalidArgument(
232               "Unsupported Complex bits: ", dtype.bits);
233       }
234       break;
235     default:
236       return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
237                                                  dtype.code);
238   }
239 }
240 
241 // Wraps the deleter function of DLManagedTensor to match the function signature
242 // TFE_NewTensorHandleFromDeviceMemory.
DeallocatorWrapperFunc(void * data,size_t len,void * dlmt_vptr)243 void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
244   TFE_CallDLManagedTensorDeleter(dlmt_vptr);
245 }
246 
// Returns true if `stride_arr` (in elements) describes a compact, row-major
// layout for a tensor whose `ndim` extents are given by `shape_arr`, i.e. if
// the buffer can be used by TensorFlow without any re-layout.
//
// Two cases are accepted even though their strides are not canonical:
//  * A dimension of extent 1 never affects which elements are addressed, so
//    its stride is ignored. Some producers (e.g. PyTorch) report arbitrary
//    values there.
//  * A tensor with any zero-extent dimension holds no data, so its strides
//    are irrelevant and it is always accepted.
bool IsValidStrideCompactRowMajorData(const int64_t* shape_arr,
                                      const int64_t* stride_arr, int ndim) {
  for (int i = 0; i < ndim; ++i) {
    if (shape_arr[i] == 0) {
      return true;
    }
  }
  // Walk from the innermost dimension outwards; a compact row-major stride is
  // the product of all inner extents.
  int64_t expected_stride = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    if (shape_arr[i] != 1 && stride_arr[i] != expected_stride) {
      return false;
    }
    expected_stride *= shape_arr[i];
  }
  return true;
}
261 }  // namespace
262 
TFE_CallDLManagedTensorDeleter(void * dlm_ptr)263 void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
264   DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
265   if (dlMTensor->deleter != nullptr) {
266     dlMTensor->deleter(dlMTensor);
267   }
268 }
269 
TFE_HandleToDLPack(TFE_TensorHandle * h,TF_Status * status)270 void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
271   auto tf_dlm_context = GetDlContext(h, status);
272   if (!status->status.ok()) {
273     return nullptr;
274   }
275 
276   auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
277   if (!status->status.ok()) {
278     return nullptr;
279   }
280 
281   const Tensor* tensor = GetTensorFromHandle(h, status);
282   TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
283 
284   auto tf_dlm_type = GetDlDataType(data_type, status);
285   if (!status->status.ok()) {
286     return nullptr;
287   }
288 
289   TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
290   auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
291   tf_dlm_tensor_ctx->reference = tensor_ref;
292 
293   DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
294   dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
295   dlm_tensor->deleter = &DLManagedTensorDeleter;
296   dlm_tensor->dl_tensor.device = tf_dlm_context;
297   int ndim = tensor->dims();
298   dlm_tensor->dl_tensor.ndim = ndim;
299   dlm_tensor->dl_tensor.data = tf_dlm_data;
300   dlm_tensor->dl_tensor.dtype = tf_dlm_type;
301 
302   std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
303   std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
304   shape_arr->resize(ndim);
305   stride_arr->resize(ndim, 1);
306   for (int i = 0; i < ndim; i++) {
307     (*shape_arr)[i] = tensor->dim_size(i);
308   }
309   for (int i = ndim - 2; i >= 0; --i) {
310     (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
311   }
312 
313   dlm_tensor->dl_tensor.shape = shape_arr->data();
314   // There are two ways to represent compact row-major data
315   // 1) nullptr indicates tensor is compact and row-majored.
316   // 2) fill in the strides array as the real case for compact row-major data.
317   // Here we choose option 2, since some frameworks didn't handle the strides
318   // argument properly.
319   dlm_tensor->dl_tensor.strides = stride_arr->data();
320 
321   dlm_tensor->dl_tensor.byte_offset =
322       0;  // TF doesn't handle the strides and byte_offsets here
323   return static_cast<void*>(dlm_tensor);
324 }
325 
TFE_HandleFromDLPack(void * dlm,TF_Status * status,TFE_Context * ctx)326 TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
327                                        TFE_Context* ctx) {
328   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
329   DLTensor* dl_tensor = &dlmt->dl_tensor;
330   absl::optional<std::string> device_name =
331       DeviceNameFromDlContext(dl_tensor->device, status);
332   if (!device_name.has_value()) {
333     status->status =
334         tensorflow::errors::InvalidArgument("Unsupported Device Type");
335     return nullptr;
336   }
337   TF_DataType dtype;
338   Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
339   if (!s.ok()) {
340     status->status = std::move(s);
341     return nullptr;
342   }
343   int num_dims = dl_tensor->ndim;
344   const int64_t* dims = dl_tensor->shape;
345   void* data = dl_tensor->data;
346 
347   size_t total_bytes = dl_tensor->dtype.bits / 8;
348   for (int i = 0; i < num_dims; i++) {
349     total_bytes *= dims[i];
350   }
351 
352   if (dl_tensor->strides != nullptr &&
353       !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
354                                         num_dims)) {
355     status->status = tensorflow::errors::InvalidArgument(
356         "Invalid strides array from DLPack");
357     return nullptr;
358   }
359 
360   TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
361       ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
362       total_bytes, &DeallocatorWrapperFunc, dlmt, status);
363 
364   return handle;
365 }
366 
367 }  // namespace tensorflow
368