1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17
#include <algorithm>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/eager/abstract_tensor_handle.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/immediate_execution_operation.h"
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_op_internal.h"
35 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/core/common_runtime/copy_tensor.h"
38 #include "tensorflow/core/common_runtime/device.h"
39 #include "tensorflow/core/common_runtime/device_factory.h"
40 #include "tensorflow/core/common_runtime/device_mgr.h"
41 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
42 #include "tensorflow/core/common_runtime/eager/context.h"
43 #include "tensorflow/core/common_runtime/eager/custom_device.h"
44 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
45 #include "tensorflow/core/common_runtime/eager/execute.h"
46 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
47 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
48 #include "tensorflow/core/common_runtime/function.h"
49 #include "tensorflow/core/framework/attr_value.pb.h"
50 #include "tensorflow/core/framework/device_attributes.pb.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/rendezvous.h"
54 #include "tensorflow/core/framework/tensor_shape.pb.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/platform/casts.h"
57 #include "tensorflow/core/platform/errors.h"
58 #include "tensorflow/core/platform/platform.h"
59 #include "tensorflow/core/platform/status.h"
60 #include "tensorflow/core/profiler/lib/traceme.h"
61 #include "tensorflow/core/protobuf/error_codes.pb.h"
62 #include "tensorflow/core/public/version.h"
63
64 // "tensorflow/core/platform/platform.h" must be included first before using
65 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
66 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
67 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
68 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
69 #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
70
71 #if !defined(IS_MOBILE_PLATFORM)
72 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
73 #endif // !IS_MOBILE_PLATFORM
74
75 using tensorflow::string;
76
77 namespace {
78
DeviceName(const tensorflow::Device * d)79 string DeviceName(const tensorflow::Device* d) {
80 return (d == nullptr) ? "cpu:0" : d->name();
81 }
82
83 // Annotate eager runtime construction context to the given `function_def` as
84 // an attribute.
AnnotateEagerRuntimeConstructionContext(tensorflow::FunctionDef & function_def)85 void AnnotateEagerRuntimeConstructionContext(
86 tensorflow::FunctionDef& function_def) {
87 tensorflow::AttrValue value;
88 SetAttrValue("kEagerRuntime", &value);
89 (*function_def.mutable_attr())["_construction_context"] = value;
90 }
91
92 } // namespace
93
94 extern "C" {
95
TFE_NewContextOptions()96 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
97
TFE_ContextOptionsSetConfig(TFE_ContextOptions * options,const void * proto,size_t proto_len,TF_Status * status)98 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
99 size_t proto_len, TF_Status* status) {
100 TF_SetConfig(&options->session_options, proto, proto_len, status);
101 }
102
TFE_ContextOptionsSetAsync(TFE_ContextOptions * options,unsigned char enable)103 void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
104 unsigned char enable) {
105 options->async = enable;
106 }
107
TFE_ContextOptionsSetDevicePlacementPolicy(TFE_ContextOptions * options,TFE_ContextDevicePlacementPolicy policy)108 void TFE_ContextOptionsSetDevicePlacementPolicy(
109 TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
110 options->device_placement_policy = policy;
111 }
112
TFE_DeleteContextOptions(TFE_ContextOptions * options)113 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
114
TFE_NewContext(const TFE_ContextOptions * opts,TF_Status * status)115 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
116 if (opts->use_tfrt) {
117 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
118 tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
119 opts->session_options.options,
120 static_cast<tensorflow::ContextDevicePlacementPolicy>(
121 opts->device_placement_policy),
122 opts->async);
123 #if !defined(IS_MOBILE_PLATFORM)
124 tfrt_context->SetDistributedManager(
125 tfrt::tf::CreateDistributedManagerContext(
126 tfrt_context->GetCoreRuntime()->GetHostContext()));
127 #endif // !IS_MOBILE_PLATFORM
128 return tensorflow::wrap(tfrt_context);
129 #else
130 status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
131 return nullptr;
132 #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
133 }
134 std::vector<std::unique_ptr<tensorflow::Device>> devices;
135 status->status = tensorflow::DeviceFactory::AddDevices(
136 opts->session_options.options, "/job:localhost/replica:0/task:0",
137 &devices);
138 if (!status->status.ok()) return nullptr;
139 std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
140 new tensorflow::StaticDeviceMgr(std::move(devices)));
141
142 tensorflow::Rendezvous* r =
143 new tensorflow::IntraProcessRendezvous(device_mgr.get());
144 tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
145 opts->session_options.options,
146 static_cast<tensorflow::ContextDevicePlacementPolicy>(
147 opts->device_placement_policy),
148 opts->async, device_mgr.release(),
149 /*device_mgr_owned*/ true, r);
150 #if !defined(IS_MOBILE_PLATFORM)
151 eager_context->SetDistributedManager(
152 std::make_unique<tensorflow::EagerContextDistributedManager>(
153 eager_context));
154 #endif // !IS_MOBILE_PLATFORM
155 return tensorflow::wrap(eager_context);
156 }
157
TFE_DeleteContext(TFE_Context * ctx)158 void TFE_DeleteContext(TFE_Context* ctx) {
159 if (ctx == nullptr) {
160 return;
161 }
162
163 // ctx->RefCountIsOne() should be true here.
164 tensorflow::unwrap(ctx)->Release();
165 }
166
TFE_ContextListDevices(TFE_Context * ctx,TF_Status * status)167 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
168 TF_DeviceList* l = new TF_DeviceList;
169 tensorflow::unwrap(ctx)->ListDevices(&l->response);
170 return l;
171 }
172
TFE_ContextClearCaches(TFE_Context * ctx)173 void TFE_ContextClearCaches(TFE_Context* ctx) {
174 tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
175 }
176
177 // Set server_def on the context, possibly updating it.
TFE_ContextSetServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)178 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
179 int keep_alive_secs,
180 const void* proto,
181 size_t proto_len,
182 TF_Status* status) {
183 #if defined(IS_MOBILE_PLATFORM)
184 status->status = tensorflow::errors::Unimplemented(
185 "TFE_ContextSetServerDef not supported on mobile");
186 #else // !defined(IS_MOBILE_PLATFORM)
187 tensorflow::ServerDef server_def;
188 if (!server_def.ParseFromArray(proto, proto_len)) {
189 status->status = tensorflow::errors::InvalidArgument(
190 "Invalid tensorflow.ServerDef protocol buffer");
191 return;
192 }
193 status->status =
194 tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
195 server_def, /*reset_context=*/true, keep_alive_secs);
196 #endif // !IS_MOBILE_PLATFORM
197 }
198
TFE_ContextUpdateServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)199 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
200 int keep_alive_secs,
201 const void* proto,
202 size_t proto_len,
203 TF_Status* status) {
204 #if defined(IS_MOBILE_PLATFORM)
205 status->status = tensorflow::errors::Unimplemented(
206 "TFE_ContextSetServerDef not supported on mobile");
207 #else // !defined(IS_MOBILE_PLATFORM)
208 tensorflow::ServerDef server_def;
209 tensorflow::EagerContext* context =
210 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
211 if (!server_def.ParseFromArray(proto, proto_len)) {
212 status->status = tensorflow::errors::InvalidArgument(
213 "Invalid tensorflow.ServerDef protocol buffer");
214 return;
215 } else if (context->GetContextId() ==
216 tensorflow::EagerContext::kInvalidContextId) {
217 status->status = tensorflow::errors::InvalidArgument(
218 "Trying to update a context with invalid context id.");
219 }
220 status->status =
221 tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
222 server_def, /*reset_context=*/false, keep_alive_secs);
223 #endif // !IS_MOBILE_PLATFORM
224 }
225
TFE_ContextCheckAlive(TFE_Context * ctx,const char * worker_name,TF_Status * status)226 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
227 const char* worker_name,
228 TF_Status* status) {
229 #if defined(IS_MOBILE_PLATFORM)
230 status->status = tensorflow::errors::Unimplemented(
231 "TFE_ContextSetServerDef not supported on mobile");
232 return false;
233 #else // !defined(IS_MOBILE_PLATFORM)
234 bool is_alive;
235 status->status =
236 tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
237 worker_name, &is_alive);
238 return is_alive;
239 #endif // !IS_MOBILE_PLATFORM
240 }
241
TFE_ContextAsyncWait(TFE_Context * ctx,TF_Status * status)242 TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
243 TF_Status* status) {
244 #if defined(IS_MOBILE_PLATFORM)
245 status->status = tensorflow::Status::OK();
246 #else // !defined(IS_MOBILE_PLATFORM)
247 status->status = tensorflow::unwrap(ctx)->AsyncWait();
248 #endif // !IS_MOBILE_PLATFORM
249 }
250
TFE_ContextSetThreadLocalDevicePlacementPolicy(TFE_Context * ctx,TFE_ContextDevicePlacementPolicy policy)251 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
252 TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
253 tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
254 static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
255 }
256
257 // Note: this function looks up a thread local policy. So it should be called in
258 // the appropriate client thread. In particular, in async mode, it may not be
259 // safe to call this function from the async EagerExecutor threads.
TFE_ContextGetDevicePlacementPolicy(TFE_Context * ctx)260 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
261 TFE_Context* ctx) {
262 return static_cast<TFE_ContextDevicePlacementPolicy>(
263 tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
264 }
265
TFE_NewTensorHandle(const TF_Tensor * t,TF_Status * status)266 TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
267 tensorflow::Tensor tensor;
268 status->status = tensorflow::TF_TensorToTensor(t, &tensor);
269 if (!status->status.ok()) return nullptr;
270
271 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
272 }
273
TFE_DeleteTensorHandle(TFE_TensorHandle * h)274 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
275 if (h == nullptr) return;
276
277 tensorflow::profiler::TraceMe activity(
278 "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
279 if (h) {
280 tensorflow::unwrap(h)->Release();
281 }
282 }
283
TFE_TensorHandleDataType(TFE_TensorHandle * h)284 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
285 return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
286 }
287
TFE_TensorHandleNumDims(TFE_TensorHandle * h,TF_Status * status)288 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
289 if (h == nullptr) {
290 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
291 return -1;
292 }
293
294 int num_dims = -1;
295 status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
296 return num_dims;
297 }
298
TFE_TensorHandleNumElements(TFE_TensorHandle * h,TF_Status * status)299 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
300 if (h == nullptr) {
301 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
302 return -1;
303 }
304
305 tensorflow::int64 num_elements = -1;
306 status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
307 return num_elements;
308 }
309
TFE_TensorHandleDim(TFE_TensorHandle * h,int dim_index,TF_Status * status)310 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
311 TF_Status* status) {
312 if (h == nullptr) {
313 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
314 return -1;
315 }
316
317 tensorflow::int64 dim = -1;
318 status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
319 return dim;
320 }
321
TFE_TensorHandleDeviceName(TFE_TensorHandle * h,TF_Status * status)322 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
323 if (h == nullptr) {
324 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
325 return nullptr;
326 }
327 return tensorflow::unwrap(h)->DeviceName(&status->status);
328 }
329
TFE_TensorHandleBackingDeviceName(TFE_TensorHandle * h,TF_Status * status)330 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
331 TF_Status* status) {
332 if (h == nullptr) {
333 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
334 return nullptr;
335 }
336 return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
337 }
338
TFE_TensorHandleCopySharingTensor(TFE_TensorHandle * h,TF_Status * status)339 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
340 TFE_TensorHandle* h, TF_Status* status) {
341 if (h == nullptr) {
342 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
343 return nullptr;
344 }
345
346 return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
347 }
348
TFE_TensorHandleResolve(TFE_TensorHandle * h,TF_Status * status)349 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
350 if (h == nullptr) {
351 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
352 return nullptr;
353 }
354
355 tensorflow::AbstractTensorInterface* t =
356 tensorflow::unwrap(h)->Resolve(&status->status);
357 if (t == nullptr) {
358 return nullptr;
359 }
360
361 return new TF_Tensor{t};
362 }
363
TFE_TensorHandleDevicePointer(TFE_TensorHandle * h,TF_Status * status)364 void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
365 if (h == nullptr) {
366 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
367 return nullptr;
368 }
369 tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
370 tensorflow::unwrap(h);
371 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
372 if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
373 return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
374 unwrapped_handle)
375 ->DevicePointer();
376 }
377 // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
378 if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
379 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
380 return nullptr;
381 }
382 tensorflow::TensorHandle* handle =
383 tensorflow::TensorHandleFromInterface(unwrapped_handle);
384
385 if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
386 status->status = tensorflow::errors::InvalidArgument(
387 "TFE_TensorHandleDevicePointer may not be called on a ",
388 handle->TypeString(), " tensor handle.");
389 return nullptr;
390 }
391 tensorflow::Device* device(handle->device());
392 if (device != nullptr) {
393 status->status = device->Sync();
394 if (!status->status.ok()) {
395 return nullptr;
396 }
397 }
398 const tensorflow::Tensor* tensor;
399 status->status = handle->Tensor(&tensor);
400 if (!status->status.ok()) {
401 return nullptr;
402 }
403 return const_cast<void*>(
404 static_cast<const void*>(tensor->tensor_data().data()));
405 }
406
407 namespace tensorflow {
408 namespace {
409 class CustomDeviceAPI : public tensorflow::CustomDevice {
410 public:
CustomDeviceAPI(TFE_Context * context,TFE_CustomDevice device,void * info,string name)411 CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
412 string name)
413 : context_(context), device_(device), info_(info), name_(name) {}
414
~CustomDeviceAPI()415 ~CustomDeviceAPI() override { device_.delete_device(info_); }
416
name()417 const string& name() override { return name_; }
418
CopyTensorToDevice(ImmediateExecutionTensorHandle * handle,ImmediateExecutionTensorHandle ** result)419 tensorflow::Status CopyTensorToDevice(
420 ImmediateExecutionTensorHandle* handle,
421 ImmediateExecutionTensorHandle** result) override {
422 handle->Ref();
423 TF_Status status;
424 TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
425 context_, tensorflow::wrap(handle), &status, info_);
426 handle->Release();
427 if (!status.status.ok()) return status.status;
428 *result = tensorflow::unwrap(result_handle);
429 (*result)->Ref();
430 TFE_DeleteTensorHandle(result_handle);
431 return status.status;
432 }
433
CopyTensorFromDevice(ImmediateExecutionTensorHandle * handle,const tensorflow::string & target_device_name,ImmediateExecutionTensorHandle ** result)434 tensorflow::Status CopyTensorFromDevice(
435 ImmediateExecutionTensorHandle* handle,
436 const tensorflow::string& target_device_name,
437 ImmediateExecutionTensorHandle** result) override {
438 TF_Status status;
439 handle->Ref();
440 TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
441 context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
442 info_);
443 handle->Release();
444 if (!status.status.ok()) return status.status;
445 *result = tensorflow::unwrap(result_handle);
446 (*result)->Ref();
447 TFE_DeleteTensorHandle(result_handle);
448 return status.status;
449 }
450
Execute(const ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)451 tensorflow::Status Execute(const ImmediateExecutionOperation* op,
452 ImmediateExecutionTensorHandle** retvals,
453 int* num_retvals) override {
454 std::vector<TFE_TensorHandle*> outputs(*num_retvals);
455 TF_Status status;
456 device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
457 info_);
458 if (status.status.ok()) {
459 for (int i = 0; i < *num_retvals; ++i) {
460 retvals[i] = tensorflow::unwrap(outputs[i]);
461 retvals[i]->Ref();
462 TFE_DeleteTensorHandle(outputs[i]);
463 }
464 }
465 return status.status;
466 }
467
Pack(absl::Span<ImmediateExecutionTensorHandle * > handles,ImmediateExecutionTensorHandle ** result)468 tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
469 ImmediateExecutionTensorHandle** result) override {
470 TF_Status status;
471 *result = tensorflow::unwrap(device_.pack(context_,
472 tensorflow::wrap(handles.data()),
473 handles.size(), &status, info_));
474 return status.status;
475 }
476
477 private:
478 TFE_Context* context_;
479 TFE_CustomDevice device_;
480 void* info_;
481 string name_;
482 };
483
484 // An adapter which wraps the shape/data produced by C custom devices and uses
485 // it to implement custom device methods.
486 class CAPICustomDeviceTensorHandle
487 : public tensorflow::CustomDeviceTensorHandle {
488 public:
489 using NumDimsCallback = std::function<int(TF_Status* status)>;
490 using DimCallback = std::function<int64_t(int dim_index, TF_Status* status)>;
491 using DeallocatorCallback = std::function<void()>;
492
CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext * context,tensorflow::CustomDevice * device,tensorflow::DataType dtype,void * data,NumDimsCallback num_dims_callback,DimCallback dim_callback,DeallocatorCallback deallocator)493 CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
494 tensorflow::CustomDevice* device,
495 tensorflow::DataType dtype, void* data,
496 NumDimsCallback num_dims_callback,
497 DimCallback dim_callback,
498 DeallocatorCallback deallocator)
499 : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
500 data_(data),
501 num_dims_callback_(num_dims_callback),
502 dim_callback_(dim_callback),
503 deallocator_(deallocator) {}
504
~CAPICustomDeviceTensorHandle()505 ~CAPICustomDeviceTensorHandle() override { deallocator_(); }
DevicePointer() const506 void* DevicePointer() const override { return data_; }
NumDims(int * num_dims) const507 Status NumDims(int* num_dims) const override {
508 TF_Status s;
509 *num_dims = num_dims_callback_(&s);
510 return s.status;
511 }
Dim(int dim_index,int64 * dim) const512 Status Dim(int dim_index, int64* dim) const override {
513 TF_Status s;
514 *dim = dim_callback_(dim_index, &s);
515 return s.status;
516 }
517
518 private:
519 void* const data_;
520 NumDimsCallback num_dims_callback_;
521 DimCallback dim_callback_;
522 DeallocatorCallback deallocator_;
523 };
524
525 } // namespace
526 } // namespace tensorflow
527
TFE_NewCustomDeviceTensorHandle(TFE_Context * ctx,const char * device_name,TF_DataType dtype,void * data,int (* num_dims_callback)(void * data,void * arg,TF_Status * status),int64_t (* dim_callback)(void * data,int dim_index,void * arg,TF_Status * status),void (* deallocator)(void * data,void * arg),void * arg,TF_Status * status)528 TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
529 TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
530 int (*num_dims_callback)(void* data, void* arg, TF_Status* status),
531 int64_t (*dim_callback)(void* data, int dim_index, void* arg,
532 TF_Status* status),
533 void (*deallocator)(void* data, void* arg), void* arg, TF_Status* status) {
534 tensorflow::EagerContext* context =
535 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
536 tensorflow::CustomDevice* device = nullptr;
537 if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
538 &device)) {
539 deallocator(data, arg);
540 status->status =
541 tensorflow::errors::InvalidArgument(device_name, " unknown device.");
542 return nullptr;
543 }
544 return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
545 context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
546 /*num_dims_callback=*/
547 [num_dims_callback, data, arg](TF_Status* status) {
548 return num_dims_callback(data, arg, status);
549 },
550 /*dim_callback=*/
551 [dim_callback, data, arg](int dim_index, TF_Status* status) {
552 return dim_callback(data, dim_index, arg, status);
553 },
554 /*deallocator=*/[deallocator, data, arg]() { deallocator(data, arg); }));
555 }
556
TFE_NewTensorHandleFromDeviceMemory(TFE_Context * ctx,const char * device_name,TF_DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg,TF_Status * status)557 TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
558 TFE_Context* ctx, const char* device_name, TF_DataType dtype,
559 const int64_t* dims, int num_dims, void* data, size_t len,
560 void (*deallocator)(void* data, size_t len, void* arg),
561 void* deallocator_arg, TF_Status* status) {
562 tensorflow::Device* device = nullptr;
563 tensorflow::EagerContext* context =
564 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
565 status->status = context->FindDeviceFromName(device_name, &device);
566 tensorflow::CustomDevice* custom_device = nullptr;
567 if (!status->status.ok()) {
568 if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(
569 device_name, &custom_device)) {
570 deallocator(data, len, deallocator_arg);
571 status->status =
572 tensorflow::errors::InvalidArgument(device_name, " unknown device.");
573 return nullptr;
574 } else {
575 status->status = tensorflow::Status::OK();
576 }
577 }
578 std::vector<tensorflow::int64> dimvec(num_dims);
579 for (int i = 0; i < num_dims; ++i) {
580 dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
581 }
582 if (custom_device != nullptr) {
583 return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
584 context, custom_device,
585 *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
586 /*num_dims_callback=*/
587 [num_dims](TF_Status* status) { return num_dims; },
588 /*dim_callback=*/
589 [dimvec](int dim_index, TF_Status* status) {
590 return dimvec[dim_index];
591 },
592 /*deallocator=*/
593 [data, len, deallocator, deallocator_arg]() {
594 deallocator(data, len, deallocator_arg);
595 }));
596 }
597
598 // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
599 // the device?
600 TF_ManagedBuffer* buf =
601 new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
602 /*owns_memory=*/false);
603
604 tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
605 tensorflow::TensorShape(dimvec), buf);
606 buf->Unref();
607 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
608 std::move(t), device, device, context));
609 }
610
611 // This function will block till the operation that produces `h` has
612 // completed. This is only valid on local TFE_TensorHandles. Returns the size in
613 // bytes of the memory pointed to by the device pointer returned above.
TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle * h,TF_Status * status)614 size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
615 TF_Status* status) {
616 if (h == nullptr) {
617 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
618 return 0;
619 }
620 tensorflow::TensorHandle* handle =
621 tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
622 if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
623 status->status = tensorflow::errors::InvalidArgument(
624 "TFE_TensorHandleDeviceMemorySize may not be called on a ",
625 handle->TypeString(), " tensor handle.");
626 return 0;
627 }
628 const tensorflow::Tensor* tensor;
629 status->status = handle->Tensor(&tensor);
630 if (!status->status.ok()) {
631 return 0;
632 }
633 return tensor->TotalBytes();
634 }
635
TFE_NewOp(TFE_Context * ctx,const char * op_or_function_name,TF_Status * status)636 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
637 TF_Status* status) {
638 tensorflow::ImmediateExecutionOperation* new_op =
639 tensorflow::unwrap(ctx)->CreateOperation();
640 status->status = new_op->Reset(op_or_function_name, nullptr);
641 if (!status->status.ok()) {
642 new_op->Release();
643 new_op = nullptr;
644 }
645 return tensorflow::wrap(new_op);
646 }
647
TFE_DeleteOp(TFE_Op * op)648 void TFE_DeleteOp(TFE_Op* op) {
649 if (op == nullptr) {
650 return;
651 }
652
653 tensorflow::unwrap(op)->Release();
654 }
655
TFE_OpGetName(const TFE_Op * op,TF_Status * status)656 const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
657 return tensorflow::unwrap(op)->Name().c_str();
658 }
659
TFE_OpGetContext(const TFE_Op * op,TF_Status * status)660 TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
661 return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
662 }
663
TFE_OpSetDevice(TFE_Op * op,const char * device_name,TF_Status * status)664 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
665 status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
666 }
667
TFE_OpGetDevice(const TFE_Op * op,TF_Status * status)668 const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
669 return tensorflow::unwrap(op)->DeviceName().c_str();
670 }
671
TFE_OpAddInput(TFE_Op * op,TFE_TensorHandle * input,TF_Status * status)672 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
673 status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
674 }
675
TFE_OpAddInputList(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs,TF_Status * status)676 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
677 TF_Status* status) {
678 status->status = tensorflow::unwrap(op)->AddInputList(
679 {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
680 tensorflow::unwrap(inputs)),
681 static_cast<size_t>(num_inputs)});
682 }
683
TFE_OpGetFlatInputCount(const TFE_Op * op,TF_Status * status)684 extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
685 return tensorflow::unwrap(op)->GetInputs().size();
686 }
687
TFE_OpGetFlatInput(const TFE_Op * op,int index,TF_Status * status)688 extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
689 TF_Status* status) {
690 return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
691 }
692
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)693 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
694 unsigned char* is_list, TF_Status* status) {
695 TF_AttrType ret = TF_ATTR_INT;
696 const tensorflow::AttrTypeMap* attr_types_;
697 bool is_function;
698 status->status = tensorflow::AttrTypeMapForOp(
699 tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
700 if (!status->status.ok()) {
701 return ret;
702 }
703 status->status =
704 tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
705 return ret;
706 }
707
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)708 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
709 const char* op_or_function_name,
710 const char* attr_name, unsigned char* is_list,
711 TF_Status* status) {
712 TF_AttrType ret;
713 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
714 if (status->status.ok()) {
715 ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
716 } else {
717 ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
718 }
719 TFE_DeleteOp(op);
720 return ret;
721 }
722
TFE_OpSetAttrString(TFE_Op * op,const char * attr_name,const void * value,size_t length)723 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
724 size_t length) {
725 auto s = tensorflow::unwrap(op)->SetAttrString(
726 attr_name, static_cast<const char*>(value), length);
727 if (!s.ok()) {
728 LOG(WARNING) << "Unable to set attribute: " << attr_name;
729 }
730 }
731
TFE_OpSetAttrInt(TFE_Op * op,const char * attr_name,int64_t value)732 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
733 auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
734 if (!s.ok()) {
735 LOG(WARNING) << "Unable to set attribute: " << attr_name;
736 }
737 }
738
TFE_OpSetAttrFloat(TFE_Op * op,const char * attr_name,float value)739 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
740 auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
741 if (!s.ok()) {
742 LOG(WARNING) << "Unable to set attribute: " << attr_name;
743 }
744 }
745
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)746 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
747 auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
748 (value == 0) ? false : true);
749 if (!s.ok()) {
750 LOG(WARNING) << "Unable to set attribute: " << attr_name;
751 }
752 }
753
TFE_OpSetAttrType(TFE_Op * op,const char * attr_name,TF_DataType value)754 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
755 auto s = tensorflow::unwrap(op)->SetAttrType(
756 attr_name, static_cast<tensorflow::DataType>(value));
757 if (!s.ok()) {
758 LOG(WARNING) << "Unable to set attribute: " << attr_name;
759 }
760 }
761
TFE_OpSetAttrShape(TFE_Op * op,const char * attr_name,const int64_t * dims,const int num_dims,TF_Status * out_status)762 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
763 const int num_dims, TF_Status* out_status) {
764 out_status->status =
765 tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
766 }
767
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)768 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
769 const TFE_Op* value) {
770 auto s = tensorflow::unwrap(op)->SetAttrFunction(
771 attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
772 if (!s.ok()) {
773 LOG(WARNING) << "Unable to set attribute: " << attr_name;
774 }
775 }
776
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)777 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
778 const char* data, size_t length) {
779 auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
780 if (!s.ok()) {
781 LOG(WARNING) << "Unable to set attribute: " << attr_name;
782 }
783 }
784
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)785 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
786 TF_Status* status) {
787 tensorflow::Tensor t;
788 status->status = TF_TensorToTensor(tensor, &t);
789 tensorflow::TensorInterface interface(t);
790 status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
791 }
792
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)793 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
794 const void* const* values, const size_t* lengths,
795 int num_values) {
796 auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
797 num_values);
798 if (!s.ok()) {
799 LOG(WARNING) << "Unable to set attribute: " << attr_name;
800 }
801 }
802
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)803 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
804 const float* values, int num_values) {
805 auto s =
806 tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
807 if (!s.ok()) {
808 LOG(WARNING) << "Unable to set attribute: " << attr_name;
809 }
810 }
811
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)812 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
813 const int64_t* values, int num_values) {
814 auto s =
815 tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
816 if (!s.ok()) {
817 LOG(WARNING) << "Unable to set attribute: " << attr_name;
818 }
819 }
820
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)821 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
822 const TF_DataType* values, int num_values) {
823 auto s = tensorflow::unwrap(op)->SetAttrTypeList(
824 attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
825 num_values);
826 if (!s.ok()) {
827 LOG(WARNING) << "Unable to set attribute: " << attr_name;
828 }
829 }
830
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)831 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
832 const unsigned char* values, int num_values) {
833 auto s =
834 tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
835 if (!s.ok()) {
836 LOG(WARNING) << "Unable to set attribute: " << attr_name;
837 }
838 }
839
TFE_OpSetAttrShapeList(TFE_Op * op,const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values,TF_Status * out_status)840 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
841 const int64_t** dims, const int* num_dims,
842 int num_values, TF_Status* out_status) {
843 out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
844 attr_name, dims, num_dims, num_values);
845 }
846
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)847 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
848 const TFE_Op** value, int num_values) {
849 auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
850 attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
851 tensorflow::unwrap(value)),
852 static_cast<size_t>(num_values)});
853 if (!s.ok()) {
854 LOG(WARNING) << "Unable to set attribute: " << attr_name;
855 }
856 }
857
TFE_OpSetAttrValueProto(const TFE_Op * op,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)858 void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
859 const void* proto, size_t proto_len,
860 TF_Status* status) {
861 tensorflow::AttrValue attr_value;
862 if (!attr_value.ParseFromArray(proto, proto_len)) {
863 status->status =
864 tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
865 return;
866 }
867 if (op == nullptr) {
868 status->status = tensorflow::errors::InvalidArgument(
869 "Got a null or uninitialized `op` argument");
870 return;
871 }
872 tensorflow::EagerOperation* operation =
873 OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
874 operation->MutableAttrs()->Set(attr_name, attr_value);
875 }
876
TFE_OpGetInputLength(TFE_Op * op,const char * input_name,TF_Status * status)877 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
878 const char* input_name,
879 TF_Status* status) {
880 int ret = -1;
881 status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
882 return ret;
883 }
884
TFE_OpGetOutputLength(TFE_Op * op,const char * output_name,TF_Status * status)885 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
886 const char* output_name,
887 TF_Status* status) {
888 int ret = -1;
889 status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
890 return ret;
891 }
892
TFE_Execute(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)893 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
894 TF_Status* status) {
895 tensorflow::ImmediateExecutionOperation* unwrapped_op =
896 tensorflow::unwrap(op);
897
898 status->status =
899 unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
900 unwrapped_op,
901 reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
902 retvals),
903 num_retvals);
904 }
905
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)906 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
907 TFE_Context* ctx,
908 const char* device_name,
909 TF_Status* status) {
910 if (h == nullptr) {
911 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
912 return nullptr;
913 }
914
915 auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
916 tensorflow::unwrap(h), device_name, &status->status);
917 if (status->status.ok()) {
918 return tensorflow::wrap(result);
919 }
920 return nullptr;
921 }
922
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)923 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
924 const char* serialized_function_def, size_t size,
925 TF_Status* status) {
926 tensorflow::FunctionDef function_def;
927 if (!function_def.ParseFromArray(serialized_function_def, size)) {
928 status->status =
929 tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
930 return;
931 }
932
933 AnnotateEagerRuntimeConstructionContext(function_def);
934 status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
935 }
936
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)937 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
938 TF_Status* status) {
939 AnnotateEagerRuntimeConstructionContext(function->fdef);
940 status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
941 function->fdef, function->stack_traces);
942 }
943
TFE_ContextRemoveFunction(TFE_Context * ctx,const char * name,TF_Status * status)944 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
945 TF_Status* status) {
946 status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
947 }
948
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)949 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
950 return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
951 }
952
TFE_ContextEnableRunMetadata(TFE_Context * ctx)953 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
954 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
955 }
956
TFE_ContextDisableRunMetadata(TFE_Context * ctx)957 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
958 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
959 }
960
961 } // extern "C"
962
TFE_NewTensorHandle(const tensorflow::Tensor & t,TF_Status * status)963 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
964 TF_Status* status) {
965 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
966 }
967
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)968 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
969 TF_Status* status) {
970 auto* context = tensorflow::unwrap(ctx);
971 status->status = context->AsyncWait();
972 if (!status->status.ok()) return;
973 auto run_metadata = context->ExportRunMetadata();
974 status->status = MessageToBuffer(*run_metadata, buf);
975 }
976
977 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)978 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
979 TF_Status* status) {
980 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
981 for (const auto& attr : func.attr()) {
982 if (!status->status.ok()) return nullptr;
983 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
984 if (!status->status.ok()) return nullptr;
985 }
986 return func_op;
987 }
988 } // namespace
989
TFE_ContextStartStep(TFE_Context * ctx)990 void TFE_ContextStartStep(TFE_Context* ctx) {
991 tensorflow::unwrap(ctx)->StartStep();
992 }
993
TFE_ContextEndStep(TFE_Context * ctx)994 void TFE_ContextEndStep(TFE_Context* ctx) {
995 tensorflow::unwrap(ctx)->EndStep();
996 }
997
TFE_OpGetAttrs(const TFE_Op * op)998 const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
999 return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs());
1000 }
1001
TFE_OpAddAttrs(TFE_Op * op,const TFE_OpAttrs * attrs)1002 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
1003 tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs));
1004 }
1005
TFE_OpAttrsSerialize(const TFE_OpAttrs * attrs,TF_Buffer * buf,TF_Status * status)1006 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
1007 TF_Status* status) {
1008 tensorflow::NameAttrList name_and_attrs;
1009 tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs);
1010 status->status = MessageToBuffer(name_and_attrs, buf);
1011 }
1012
1013 namespace tensorflow {
SetOpAttrValueScalar(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,TF_Status * status)1014 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
1015 const tensorflow::AttrValue& default_value,
1016 const char* attr_name, TF_Status* status) {
1017 switch (default_value.value_case()) {
1018 case tensorflow::AttrValue::kS: {
1019 const string& v = default_value.s();
1020 TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
1021 break;
1022 }
1023 case tensorflow::AttrValue::kI:
1024 TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
1025 break;
1026 case tensorflow::AttrValue::kF:
1027 TFE_OpSetAttrFloat(op, attr_name, default_value.f());
1028 break;
1029 case tensorflow::AttrValue::kB:
1030 TFE_OpSetAttrBool(op, attr_name, default_value.b());
1031 break;
1032 case tensorflow::AttrValue::kType:
1033 TFE_OpSetAttrType(op, attr_name,
1034 static_cast<TF_DataType>(default_value.type()));
1035 break;
1036 case tensorflow::AttrValue::kShape: {
1037 const auto& tensor_shape = default_value.shape();
1038 if (tensor_shape.unknown_rank()) {
1039 TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
1040 } else {
1041 const auto num_dims = tensor_shape.dim_size();
1042 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
1043 for (int i = 0; i < num_dims; ++i) {
1044 dims[i] = tensor_shape.dim(i).size();
1045 }
1046 TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
1047 }
1048 } break;
1049 case tensorflow::AttrValue::kFunc: {
1050 const auto func_op = GetFunc(ctx, default_value.func(), status);
1051 if (!status->status.ok()) return;
1052 // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
1053 // require TFE_Op* and just convert it internally a NameAttrValue, so
1054 // consider adding an overload to the C API to make this case easier.
1055 TFE_OpSetAttrFunction(op, attr_name, func_op);
1056 TFE_DeleteOp(func_op);
1057 } break;
1058 case tensorflow::AttrValue::kList: {
1059 // String
1060 if (const int s_size = default_value.list().s_size()) {
1061 absl::InlinedVector<const void*, 4> values_vector;
1062 absl::InlinedVector<size_t, 4> lengths_vector;
1063 for (int i = 0; i < s_size; ++i) {
1064 const string& v = default_value.list().s(i);
1065 values_vector.push_back(v.data());
1066 lengths_vector.push_back(v.size());
1067 }
1068 TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
1069 lengths_vector.data(), s_size);
1070 }
1071
1072 // Int
1073 if (const int i_size = default_value.list().i_size()) {
1074 absl::InlinedVector<int64_t, 4> i_vector;
1075 for (int i = 0; i < i_size; ++i) {
1076 i_vector.push_back(default_value.list().i(i));
1077 }
1078 TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
1079 }
1080 // Float
1081 if (const int f_size = default_value.list().f_size()) {
1082 absl::InlinedVector<float, 4> f_vector;
1083 for (int i = 0; i < f_size; ++i) {
1084 f_vector.push_back(default_value.list().f(i));
1085 }
1086 TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
1087 }
1088 // Bool
1089 if (const int b_size = default_value.list().b_size()) {
1090 absl::InlinedVector<unsigned char, 4> b_vector;
1091 for (int i = 0; i < b_size; i++) {
1092 b_vector.push_back(default_value.list().b(i));
1093 }
1094 TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
1095 }
1096 // Type
1097 if (const int type_size = default_value.list().type_size()) {
1098 absl::InlinedVector<unsigned int, 4> type_vector;
1099 for (int i = 0; i < type_size; ++i) {
1100 type_vector.push_back(default_value.list().type(i));
1101 }
1102 TFE_OpSetAttrTypeList(
1103 op, attr_name,
1104 reinterpret_cast<const TF_DataType*>(type_vector.data()),
1105 type_size);
1106 }
1107
1108 // Rest are not supported.
1109 if (default_value.list().shape_size() > 0 ||
1110 default_value.list().func_size() > 0 ||
1111 default_value.list().tensor_size() > 0) {
1112 TF_SetStatus(
1113 status, TF_UNIMPLEMENTED,
1114 tensorflow::strings::StrCat("Unable to get setfor default value: ",
1115 default_value.DebugString())
1116 .data());
1117 }
1118 } break;
1119 case tensorflow::AttrValue::kTensor:
1120 TF_FALLTHROUGH_INTENDED;
1121 case tensorflow::AttrValue::kPlaceholder:
1122 TF_FALLTHROUGH_INTENDED;
1123 case tensorflow::AttrValue::VALUE_NOT_SET:
1124 TF_SetStatus(
1125 status, TF_UNIMPLEMENTED,
1126 tensorflow::strings::StrCat("Unable to get setfor default value: ",
1127 default_value.DebugString())
1128 .data());
1129 }
1130 }
1131 } // namespace tensorflow
1132
1133 namespace {
DefaultCustomDevicePack(TFE_Context * context,TFE_TensorHandle ** handles,int num_handles,TF_Status * status,void * device_info)1134 TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
1135 TFE_TensorHandle** handles,
1136 int num_handles, TF_Status* status,
1137 void* device_info) {
1138 TF_SetStatus(status, TF_UNIMPLEMENTED,
1139 "This custom device does not support packing tensors.");
1140 return nullptr;
1141 }
1142 } // namespace
1143
1144 extern "C" {
1145
TFE_RegisterCustomDevice(TFE_Context * ctx,TFE_CustomDevice device,const char * device_name,void * device_info,TF_Status * status)1146 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
1147 const char* device_name, void* device_info,
1148 TF_Status* status) {
1149 // Fill in default values for optional functionality.
1150 if (device.pack == nullptr) {
1151 device.pack = &DefaultCustomDevicePack;
1152 }
1153 auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
1154 ctx, device, device_info, device_name);
1155 status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
1156 device_name, std::move(custom_device));
1157 }
1158
1159 } // extern "C"
1160