• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/dlpack.h"
17 
18 #include "include/dlpack/dlpack.h"  // from @dlpack
19 #include "tensorflow/c/eager/c_api.h"
20 #include "tensorflow/c/eager/c_api_experimental.h"
21 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
22 #include "tensorflow/c/tf_status_internal.h"
23 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_reference.h"
26 #include "tensorflow/core/platform/logging.h"
27 
28 namespace tensorflow {
29 
30 namespace {
31 
32 // Managing context for the DLManagedTensor, will manage the lifetime of
33 // DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
34 // original framework of destruction, and this context will be deleted also.
35 struct TfDlManagedTensorCtx {
36   TensorReference reference;
37   std::vector<int64_t> shape;
38   std::vector<int64_t> strides;
39   DLManagedTensor tensor;
40 
TfDlManagedTensorCtxtensorflow::__anon25ce42a40111::TfDlManagedTensorCtx41   explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
42 };
43 
44 // Gets tensor from eager tensor handle.
GetTensorFromHandle(TFE_TensorHandle * h,TF_Status * status)45 const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
46   if (h == nullptr) {
47     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
48     return nullptr;
49   }
50   tensorflow::TensorHandle* handle =
51       tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
52   if (handle->Type() != TensorHandle::LOCAL) {
53     status->status = tensorflow::errors::InvalidArgument(
54         "DLPack doesn't support ", handle->TypeString(), " tensor");
55     return nullptr;
56   }
57   const tensorflow::Tensor* tensor;
58   status->status = handle->Tensor(&tensor);
59   if (!status->status.ok()) {
60     return nullptr;
61   }
62   return tensor;
63 }
64 
65 // Deleter for DLManagedTensor
DLManagedTensorDeleter(DLManagedTensor * arg)66 void DLManagedTensorDeleter(DLManagedTensor* arg) {
67   TfDlManagedTensorCtx* owner =
68       static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
69   owner->reference.Unref();
70   delete owner;
71 }
72 
73 // Converts TF_DATAType to DLPack data type.
GetDlDataType(TF_DataType data_type,TF_Status * status)74 DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
75   DLDataType dtype;
76   dtype.lanes = 1;
77   dtype.bits = TF_DataTypeSize(data_type) * 8;
78   switch (data_type) {
79     case TF_DataType::TF_HALF:
80     case TF_DataType::TF_FLOAT:
81     case TF_DataType::TF_DOUBLE:
82       dtype.code = DLDataTypeCode::kDLFloat;
83       break;
84     case TF_DataType::TF_INT8:
85     case TF_DataType::TF_INT16:
86     case TF_DataType::TF_INT32:
87     case TF_DataType::TF_INT64:
88       dtype.code = DLDataTypeCode::kDLInt;
89       break;
90     case TF_DataType::TF_BOOL:
91     case TF_DataType::TF_UINT8:
92     case TF_DataType::TF_UINT16:
93     case TF_DataType::TF_UINT32:
94     case TF_DataType::TF_UINT64:
95       dtype.code = DLDataTypeCode::kDLUInt;
96       break;
97     case TF_DataType::TF_BFLOAT16:
98       dtype.code = DLDataTypeCode::kDLBfloat;
99       break;
100     default:
101       status->status = tensorflow::errors::InvalidArgument(
102           DataType_Name(static_cast<DataType>(data_type)),
103           " is not supported by dlpack");
104       break;
105   }
106   return dtype;
107 }
108 
109 // Gets DLPack's DLContext from eager tensor handle.
GetDlContext(TFE_TensorHandle * h,TF_Status * status)110 DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
111   DLContext ctx;
112   const char* device_name =
113       tensorflow::unwrap(h)->BackingDeviceName(&status->status);
114   DeviceNameUtils::ParsedName parsed_name;
115   tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
116   std::string device_type = parsed_name.type;
117   int device_id = 0;
118   if (parsed_name.has_id) {
119     device_id = parsed_name.id;
120   }
121 
122   ctx.device_id = device_id;
123   if (device_type == "CPU") {
124     ctx.device_type = DLDeviceType::kDLCPU;
125   } else if (device_type == "GPU") {
126     ctx.device_type = DLDeviceType::kDLGPU;
127   } else {
128     status->status = tensorflow::errors::InvalidArgument(
129         "Unsupported Device Type for dlpack");
130   }
131 
132   return ctx;
133 }
134 
135 // Converts DLContext to TF device name.
DeviceNameFromDlContext(const DLContext & ctx,TF_Status * status)136 absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
137                                                     TF_Status* status) {
138   switch (ctx.device_type) {
139     case DLDeviceType::kDLCPU:
140       return "CPU:0";
141     case DLDeviceType::kDLGPU:
142       return absl::StrCat("GPU:", ctx.device_id);
143     default:
144       return absl::nullopt;
145   }
146 }
147 
148 // Converts DLPack data type to TF_DATATYPE.
TfDataTypeFormDlDataType(const DLDataType & dtype,TF_DataType * tf_dtype)149 Status TfDataTypeFormDlDataType(const DLDataType& dtype,
150                                 TF_DataType* tf_dtype) {
151   switch (dtype.code) {
152     case DLDataTypeCode::kDLUInt:
153       switch (dtype.bits) {
154         case 8:
155           *tf_dtype = TF_DataType::TF_UINT8;
156           return Status::OK();
157         case 16:
158           *tf_dtype = TF_DataType::TF_UINT16;
159           return Status::OK();
160         case 32:
161           *tf_dtype = TF_DataType::TF_UINT32;
162           return Status::OK();
163         case 64:
164           *tf_dtype = TF_DataType::TF_UINT64;
165           return Status::OK();
166         default:
167           return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
168                                                      dtype.bits);
169       }
170       return Status::OK();
171     case DLDataTypeCode::kDLInt:
172       switch (dtype.bits) {
173         case 8:
174           *tf_dtype = TF_DataType::TF_INT8;
175           return Status::OK();
176         case 16:
177           *tf_dtype = TF_DataType::TF_INT16;
178           return Status::OK();
179         case 32:
180           *tf_dtype = TF_DataType::TF_INT32;
181           return Status::OK();
182         case 64:
183           *tf_dtype = TF_DataType::TF_INT64;
184           return Status::OK();
185         default:
186           return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
187                                                      dtype.bits);
188       }
189       return Status::OK();
190     case DLDataTypeCode::kDLFloat:
191       switch (dtype.bits) {
192         case 16:
193           *tf_dtype = TF_DataType::TF_HALF;
194           return Status::OK();
195         case 32:
196           *tf_dtype = TF_DataType::TF_FLOAT;
197           return Status::OK();
198         case 64:
199           *tf_dtype = TF_DataType::TF_DOUBLE;
200           return Status::OK();
201         default:
202           return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
203                                                      dtype.bits);
204       }
205       break;
206     case DLDataTypeCode::kDLBfloat:
207       switch (dtype.bits) {
208         case 16:
209           *tf_dtype = TF_DataType::TF_BFLOAT16;
210           return Status::OK();
211         default:
212           return tensorflow::errors::InvalidArgument(
213               "Unsupported BFloat bits: ", dtype.bits);
214       }
215       break;
216     default:
217       return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
218                                                  dtype.code);
219   }
220 }
221 
222 // Wraps the deleter function of DLManagedTensor to match the function signature
223 // TFE_NewTensorHandleFromDeviceMemory.
DeallocatorWrapperFunc(void * data,size_t len,void * dlmt_vptr)224 void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
225   TFE_CallDLManagedTensorDeleter(dlmt_vptr);
226 }
227 
// Returns true when `stride_arr` (in elements) describes compact, row-major
// data for a tensor of shape `shape_arr` with `ndim` dimensions.
//
// Walking from the innermost dimension outwards, each stride must equal the
// product of the sizes of all dimensions inside it. Two cases are accepted
// regardless of the stride values, because the stride cannot affect which
// elements are addressed:
//   * a dimension of size 1 may carry any stride;
//   * a tensor with any dimension of size 0 has no elements at all.
// `ndim == 0` (a scalar) is always valid and the arrays are not read.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
                                      int ndim) {
  bool compact = true;
  int64_t expected_stride = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    // Empty tensors are compact whatever their strides say. Keep scanning
    // after a mismatch instead of returning early, since an outer dimension
    // may still turn out to be empty.
    if (shape_arr[i] == 0) {
      return true;
    }
    if (shape_arr[i] != 1 && stride_arr[i] != expected_stride) {
      compact = false;
    }
    expected_stride *= shape_arr[i];
  }
  return compact;
}
242 }  // namespace
243 
TFE_CallDLManagedTensorDeleter(void * dlm_ptr)244 void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
245   DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
246   if (dlMTensor->deleter != nullptr) {
247     dlMTensor->deleter(dlMTensor);
248   }
249 }
250 
TFE_HandleToDLPack(TFE_TensorHandle * h,TF_Status * status)251 void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
252   auto tf_dlm_context = GetDlContext(h, status);
253   if (!status->status.ok()) {
254     return nullptr;
255   }
256 
257   auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
258   if (!status->status.ok()) {
259     return nullptr;
260   }
261 
262   const Tensor* tensor = GetTensorFromHandle(h, status);
263   TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
264 
265   auto tf_dlm_type = GetDlDataType(data_type, status);
266   if (!status->status.ok()) {
267     return nullptr;
268   }
269 
270   TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
271   auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
272   tf_dlm_tensor_ctx->reference = tensor_ref;
273 
274   DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
275   dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
276   dlm_tensor->deleter = &DLManagedTensorDeleter;
277   dlm_tensor->dl_tensor.ctx = tf_dlm_context;
278   int ndim = tensor->dims();
279   dlm_tensor->dl_tensor.ndim = ndim;
280   dlm_tensor->dl_tensor.data = tf_dlm_data;
281   dlm_tensor->dl_tensor.dtype = tf_dlm_type;
282 
283   std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
284   std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
285   shape_arr->resize(ndim);
286   stride_arr->resize(ndim, 1);
287   for (int i = 0; i < ndim; i++) {
288     (*shape_arr)[i] = tensor->dim_size(i);
289   }
290   for (int i = ndim - 2; i >= 0; --i) {
291     (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
292   }
293 
294   dlm_tensor->dl_tensor.shape = shape_arr->data();
295   // There are two ways to represent compact row-major data
296   // 1) nullptr indicates tensor is compact and row-majored.
297   // 2) fill in the strides array as the real case for compact row-major data.
298   // Here we choose option 2, since some frameworks didn't handle the strides
299   // argument properly.
300   dlm_tensor->dl_tensor.strides = stride_arr->data();
301 
302   dlm_tensor->dl_tensor.byte_offset =
303       0;  // TF doesn't handle the strides and byte_offsets here
304   return static_cast<void*>(dlm_tensor);
305 }
306 
TFE_HandleFromDLPack(void * dlm,TF_Status * status,TFE_Context * ctx)307 TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
308                                        TFE_Context* ctx) {
309   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
310   DLTensor* dl_tensor = &dlmt->dl_tensor;
311   absl::optional<std::string> device_name =
312       DeviceNameFromDlContext(dl_tensor->ctx, status);
313   if (!device_name.has_value()) {
314     status->status =
315         tensorflow::errors::InvalidArgument("Unsupported Device Type");
316     return nullptr;
317   }
318   TF_DataType dtype;
319   Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
320   if (!s.ok()) {
321     status->status = std::move(s);
322     return nullptr;
323   }
324   int num_dims = dl_tensor->ndim;
325   const int64_t* dims = dl_tensor->shape;
326   void* data = dl_tensor->data;
327 
328   size_t total_bytes = dl_tensor->dtype.bits / 8;
329   for (int i = 0; i < num_dims; i++) {
330     total_bytes *= dims[i];
331   }
332 
333   if (dl_tensor->strides != nullptr &&
334       !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
335                                         num_dims)) {
336     status->status = tensorflow::errors::InvalidArgument(
337         "Invalid strides array from DLPack");
338     return nullptr;
339   }
340 
341   TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
342       ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
343       total_bytes, &DeallocatorWrapperFunc, dlmt, status);
344 
345   return handle;
346 }
347 
348 }  // namespace tensorflow
349