1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/eager/abstract_tensor_handle.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/immediate_execution_operation.h"
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_op_internal.h"
35 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/core/common_runtime/copy_tensor.h"
38 #include "tensorflow/core/common_runtime/device.h"
39 #include "tensorflow/core/common_runtime/device_factory.h"
40 #include "tensorflow/core/common_runtime/device_mgr.h"
41 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
42 #include "tensorflow/core/common_runtime/eager/context.h"
43 #include "tensorflow/core/common_runtime/eager/custom_device.h"
44 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
45 #include "tensorflow/core/common_runtime/eager/execute.h"
46 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
47 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
48 #include "tensorflow/core/common_runtime/function.h"
49 #include "tensorflow/core/framework/attr_value.pb.h"
50 #include "tensorflow/core/framework/device_attributes.pb.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/rendezvous.h"
54 #include "tensorflow/core/framework/tensor_shape.pb.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/platform/casts.h"
57 #include "tensorflow/core/platform/errors.h"
58 #include "tensorflow/core/platform/platform.h"
59 #include "tensorflow/core/platform/status.h"
60 #include "tensorflow/core/profiler/lib/traceme.h"
61 #include "tensorflow/core/protobuf/error_codes.pb.h"
62 #include "tensorflow/core/public/version.h"
63 
64 // "tensorflow/core/platform/platform.h" must be included first before using
65 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
66 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
67 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
68 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
69 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
70 
71 #if !defined(IS_MOBILE_PLATFORM)
72 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
73 #endif  // !IS_MOBILE_PLATFORM
74 
75 using tensorflow::string;
76 
77 namespace {
78 
79 string DeviceName(const tensorflow::Device* d) {
80   return (d == nullptr) ? "cpu:0" : d->name();
81 }
82 
83 // Annotate eager runtime construction context to the given `function_def` as
84 // an attribute.
85 void AnnotateEagerRuntimeConstructionContext(
86     tensorflow::FunctionDef& function_def) {
87   tensorflow::AttrValue value;
88   SetAttrValue("kEagerRuntime", &value);
89   (*function_def.mutable_attr())["_construction_context"] = value;
90 }
91 
92 }  // namespace
93 
94 extern "C" {
95 
96 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
97 
98 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
99                                  size_t proto_len, TF_Status* status) {
100   TF_SetConfig(&options->session_options, proto, proto_len, status);
101 }
102 
103 void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
104                                 unsigned char enable) {
105   options->async = enable;
106 }
107 
108 void TFE_ContextOptionsSetDevicePlacementPolicy(
109     TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
110   options->device_placement_policy = policy;
111 }
112 
113 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
114 
115 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
116   if (opts->use_tfrt) {
117 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
118     tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
119         opts->session_options.options,
120         static_cast<tensorflow::ContextDevicePlacementPolicy>(
121             opts->device_placement_policy),
122         opts->async);
123 #if !defined(IS_MOBILE_PLATFORM)
124     tfrt_context->SetDistributedManager(
125         tfrt::tf::CreateDistributedManagerContext(
126             tfrt_context->GetCoreRuntime()->GetHostContext()));
127 #endif  // !IS_MOBILE_PLATFORM
128     return tensorflow::wrap(tfrt_context);
129 #else
130     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
131     return nullptr;
132 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
133   }
134   std::vector<std::unique_ptr<tensorflow::Device>> devices;
135   status->status = tensorflow::DeviceFactory::AddDevices(
136       opts->session_options.options, "/job:localhost/replica:0/task:0",
137       &devices);
138   if (!status->status.ok()) return nullptr;
139   std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
140       new tensorflow::StaticDeviceMgr(std::move(devices)));
141 
142   tensorflow::Rendezvous* r =
143       new tensorflow::IntraProcessRendezvous(device_mgr.get());
144   tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
145       opts->session_options.options,
146       static_cast<tensorflow::ContextDevicePlacementPolicy>(
147           opts->device_placement_policy),
148       opts->async, device_mgr.release(),
149       /*device_mgr_owned*/ true, r);
150 #if !defined(IS_MOBILE_PLATFORM)
151   eager_context->SetDistributedManager(
152       std::make_unique<tensorflow::EagerContextDistributedManager>(
153           eager_context));
154 #endif  // !IS_MOBILE_PLATFORM
155   return tensorflow::wrap(eager_context);
156 }
157 
158 void TFE_DeleteContext(TFE_Context* ctx) {
159   if (ctx == nullptr) {
160     return;
161   }
162 
163   // ctx->RefCountIsOne() should be true here.
164   tensorflow::unwrap(ctx)->Release();
165 }
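
// Illustrative usage sketch: a minimal client-side lifecycle for the context
// functions above, assuming the public C API headers; error checks elided.
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetAsync(opts, /*enable=*/0);
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);
//   // ... execute ops against ctx ...
//   TFE_DeleteContext(ctx);
//   TF_DeleteStatus(status);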
166 
167 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
168   TF_DeviceList* l = new TF_DeviceList;
169   tensorflow::unwrap(ctx)->ListDevices(&l->response);
170   return l;
171 }
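
// Illustrative usage sketch: enumerating the context's devices with the
// TF_DeviceList accessors from the classic C API; error checks elided.
//
//   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
//   for (int i = 0; i < TF_DeviceListCount(devices); ++i) {
//     const char* name = TF_DeviceListName(devices, i, status);
//   }
//   TF_DeleteDeviceList(devices);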
172 
173 void TFE_ContextClearCaches(TFE_Context* ctx) {
174   tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
175 }
176 
177 // Set server_def on the context, possibly updating it.
178 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
179                                                    int keep_alive_secs,
180                                                    const void* proto,
181                                                    size_t proto_len,
182                                                    TF_Status* status) {
183 #if defined(IS_MOBILE_PLATFORM)
184   status->status = tensorflow::errors::Unimplemented(
185       "TFE_ContextSetServerDef not supported on mobile");
186 #else   // !defined(IS_MOBILE_PLATFORM)
187   tensorflow::ServerDef server_def;
188   if (!server_def.ParseFromArray(proto, proto_len)) {
189     status->status = tensorflow::errors::InvalidArgument(
190         "Invalid tensorflow.ServerDef protocol buffer");
191     return;
192   }
193   status->status =
194       tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
195           server_def, /*reset_context=*/true, keep_alive_secs);
196 #endif  // !IS_MOBILE_PLATFORM
197 }
198 
199 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
200                                                       int keep_alive_secs,
201                                                       const void* proto,
202                                                       size_t proto_len,
203                                                       TF_Status* status) {
204 #if defined(IS_MOBILE_PLATFORM)
205   status->status = tensorflow::errors::Unimplemented(
206       "TFE_ContextUpdateServerDef not supported on mobile");
207 #else   // !defined(IS_MOBILE_PLATFORM)
208   tensorflow::ServerDef server_def;
209   tensorflow::EagerContext* context =
210       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
211   if (!server_def.ParseFromArray(proto, proto_len)) {
212     status->status = tensorflow::errors::InvalidArgument(
213         "Invalid tensorflow.ServerDef protocol buffer");
214     return;
215   } else if (context->GetContextId() ==
216              tensorflow::EagerContext::kInvalidContextId) {
217     status->status = tensorflow::errors::InvalidArgument(
218         "Trying to update a context with invalid context id.");
219   }
220   status->status =
221       tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
222           server_def, /*reset_context=*/false, keep_alive_secs);
223 #endif  // !IS_MOBILE_PLATFORM
224 }
225 
226 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
227                                                  const char* worker_name,
228                                                  TF_Status* status) {
229 #if defined(IS_MOBILE_PLATFORM)
230   status->status = tensorflow::errors::Unimplemented(
231       "TFE_ContextCheckAlive not supported on mobile");
232   return false;
233 #else   // !defined(IS_MOBILE_PLATFORM)
234   bool is_alive;
235   status->status =
236       tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
237           worker_name, &is_alive);
238   return is_alive;
239 #endif  // !IS_MOBILE_PLATFORM
240 }
241 
242 TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
243                                                 TF_Status* status) {
244 #if defined(IS_MOBILE_PLATFORM)
245   status->status = tensorflow::Status::OK();
246 #else   // !defined(IS_MOBILE_PLATFORM)
247   status->status = tensorflow::unwrap(ctx)->AsyncWait();
248 #endif  // !IS_MOBILE_PLATFORM
249 }
250 
251 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
252     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
253   tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
254       static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
255 }
256 
257 // Note: this function looks up a thread local policy. So it should be called in
258 // the appropriate client thread. In particular, in async mode, it may not be
259 // safe to call this function from the async EagerExecutor threads.
260 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
261     TFE_Context* ctx) {
262   return static_cast<TFE_ContextDevicePlacementPolicy>(
263       tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
264 }
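
// Illustrative usage sketch: choosing a thread-local placement policy for the
// calling thread and reading it back (policy values come from c_api.h).
//
//   TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
//                                                  TFE_DEVICE_PLACEMENT_SILENT);
//   TFE_ContextDevicePlacementPolicy policy =
//       TFE_ContextGetDevicePlacementPolicy(ctx);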
265 
266 TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
267   tensorflow::Tensor tensor;
268   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
269   if (!status->status.ok()) return nullptr;
270 
271   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
272 }
273 
274 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
275   if (h == nullptr) return;
276 
277   tensorflow::profiler::TraceMe activity(
278       "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
279   if (h) {
280     tensorflow::unwrap(h)->Release();
281   }
282 }
283 
284 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
285   return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
286 }
287 
288 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
289   if (h == nullptr) {
290     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
291     return -1;
292   }
293 
294   int num_dims = -1;
295   status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
296   return num_dims;
297 }
298 
299 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
300   if (h == nullptr) {
301     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
302     return -1;
303   }
304 
305   tensorflow::int64 num_elements = -1;
306   status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
307   return num_elements;
308 }
309 
310 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
311                             TF_Status* status) {
312   if (h == nullptr) {
313     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
314     return -1;
315   }
316 
317   tensorflow::int64 dim = -1;
318   status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
319   return dim;
320 }
321 
322 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
323   if (h == nullptr) {
324     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
325     return nullptr;
326   }
327   return tensorflow::unwrap(h)->DeviceName(&status->status);
328 }
329 
330 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
331                                               TF_Status* status) {
332   if (h == nullptr) {
333     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
334     return nullptr;
335   }
336   return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
337 }
338 
339 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
340     TFE_TensorHandle* h, TF_Status* status) {
341   if (h == nullptr) {
342     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
343     return nullptr;
344   }
345 
346   return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
347 }
348 
349 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
350   if (h == nullptr) {
351     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
352     return nullptr;
353   }
354 
355   tensorflow::AbstractTensorInterface* t =
356       tensorflow::unwrap(h)->Resolve(&status->status);
357   if (t == nullptr) {
358     return nullptr;
359   }
360 
361   return new TF_Tensor{t};
362 }
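
// Illustrative usage sketch: wrapping host data in a handle and resolving it
// back to a TF_Tensor via the functions above; `status` is a TF_Status* and
// error checks are elided.
//
//   float data[] = {1.0f, 2.0f};
//   int64_t dims[] = {2};
//   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 1, sizeof(data));
//   memcpy(TF_TensorData(t), data, sizeof(data));
//   TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
//   TF_Tensor* resolved = TFE_TensorHandleResolve(h, status);
//   TF_DeleteTensor(resolved);
//   TFE_DeleteTensorHandle(h);
//   TF_DeleteTensor(t);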
363 
364 void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
365   if (h == nullptr) {
366     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
367     return nullptr;
368   }
369   tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
370       tensorflow::unwrap(h);
371   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
372   if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
373     return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
374                unwrapped_handle)
375         ->DevicePointer();
376   }
377   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
378   if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
379     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
380     return nullptr;
381   }
382   tensorflow::TensorHandle* handle =
383       tensorflow::TensorHandleFromInterface(unwrapped_handle);
384 
385   if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
386     status->status = tensorflow::errors::InvalidArgument(
387         "TFE_TensorHandleDevicePointer may not be called on a ",
388         handle->TypeString(), " tensor handle.");
389     return nullptr;
390   }
391   tensorflow::Device* device(handle->device());
392   if (device != nullptr) {
393     status->status = device->Sync();
394     if (!status->status.ok()) {
395       return nullptr;
396     }
397   }
398   const tensorflow::Tensor* tensor;
399   status->status = handle->Tensor(&tensor);
400   if (!status->status.ok()) {
401     return nullptr;
402   }
403   return const_cast<void*>(
404       static_cast<const void*>(tensor->tensor_data().data()));
405 }
406 
407 namespace tensorflow {
408 namespace {
409 class CustomDeviceAPI : public tensorflow::CustomDevice {
410  public:
411   CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
412                   string name)
413       : context_(context), device_(device), info_(info), name_(name) {}
414 
415   ~CustomDeviceAPI() override { device_.delete_device(info_); }
416 
417   const string& name() override { return name_; }
418 
419   tensorflow::Status CopyTensorToDevice(
420       ImmediateExecutionTensorHandle* handle,
421       ImmediateExecutionTensorHandle** result) override {
422     handle->Ref();
423     TF_Status status;
424     TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
425         context_, tensorflow::wrap(handle), &status, info_);
426     handle->Release();
427     if (!status.status.ok()) return status.status;
428     *result = tensorflow::unwrap(result_handle);
429     (*result)->Ref();
430     TFE_DeleteTensorHandle(result_handle);
431     return status.status;
432   }
433 
434   tensorflow::Status CopyTensorFromDevice(
435       ImmediateExecutionTensorHandle* handle,
436       const tensorflow::string& target_device_name,
437       ImmediateExecutionTensorHandle** result) override {
438     TF_Status status;
439     handle->Ref();
440     TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
441         context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
442         info_);
443     handle->Release();
444     if (!status.status.ok()) return status.status;
445     *result = tensorflow::unwrap(result_handle);
446     (*result)->Ref();
447     TFE_DeleteTensorHandle(result_handle);
448     return status.status;
449   }
450 
451   tensorflow::Status Execute(const ImmediateExecutionOperation* op,
452                              ImmediateExecutionTensorHandle** retvals,
453                              int* num_retvals) override {
454     std::vector<TFE_TensorHandle*> outputs(*num_retvals);
455     TF_Status status;
456     device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
457                     info_);
458     if (status.status.ok()) {
459       for (int i = 0; i < *num_retvals; ++i) {
460         retvals[i] = tensorflow::unwrap(outputs[i]);
461         retvals[i]->Ref();
462         TFE_DeleteTensorHandle(outputs[i]);
463       }
464     }
465     return status.status;
466   }
467 
468   tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
469                           ImmediateExecutionTensorHandle** result) override {
470     TF_Status status;
471     *result = tensorflow::unwrap(device_.pack(context_,
472                                               tensorflow::wrap(handles.data()),
473                                               handles.size(), &status, info_));
474     return status.status;
475   }
476 
477  private:
478   TFE_Context* context_;
479   TFE_CustomDevice device_;
480   void* info_;
481   string name_;
482 };
483 
484 // An adapter which wraps the shape/data produced by C custom devices and uses
485 // it to implement custom device methods.
486 class CAPICustomDeviceTensorHandle
487     : public tensorflow::CustomDeviceTensorHandle {
488  public:
489   using NumDimsCallback = std::function<int(TF_Status* status)>;
490   using DimCallback = std::function<int64_t(int dim_index, TF_Status* status)>;
491   using DeallocatorCallback = std::function<void()>;
492 
493   CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
494                                tensorflow::CustomDevice* device,
495                                tensorflow::DataType dtype, void* data,
496                                NumDimsCallback num_dims_callback,
497                                DimCallback dim_callback,
498                                DeallocatorCallback deallocator)
499       : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
500         data_(data),
501         num_dims_callback_(num_dims_callback),
502         dim_callback_(dim_callback),
503         deallocator_(deallocator) {}
504 
505   ~CAPICustomDeviceTensorHandle() override { deallocator_(); }
506   void* DevicePointer() const override { return data_; }
507   Status NumDims(int* num_dims) const override {
508     TF_Status s;
509     *num_dims = num_dims_callback_(&s);
510     return s.status;
511   }
512   Status Dim(int dim_index, int64* dim) const override {
513     TF_Status s;
514     *dim = dim_callback_(dim_index, &s);
515     return s.status;
516   }
517 
518  private:
519   void* const data_;
520   NumDimsCallback num_dims_callback_;
521   DimCallback dim_callback_;
522   DeallocatorCallback deallocator_;
523 };
524 
525 }  // namespace
526 }  // namespace tensorflow
527 
528 TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
529     TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
530     int (*num_dims_callback)(void* data, void* arg, TF_Status* status),
531     int64_t (*dim_callback)(void* data, int dim_index, void* arg,
532                             TF_Status* status),
533     void (*deallocator)(void* data, void* arg), void* arg, TF_Status* status) {
534   tensorflow::EagerContext* context =
535       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
536   tensorflow::CustomDevice* device = nullptr;
537   if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
538                                                                     &device)) {
539     deallocator(data, arg);
540     status->status =
541         tensorflow::errors::InvalidArgument(device_name, " unknown device.");
542     return nullptr;
543   }
544   return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
545       context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
546       /*num_dims_callback=*/
547       [num_dims_callback, data, arg](TF_Status* status) {
548         return num_dims_callback(data, arg, status);
549       },
550       /*dim_callback=*/
551       [dim_callback, data, arg](int dim_index, TF_Status* status) {
552         return dim_callback(data, dim_index, arg, status);
553       },
554       /*deallocator=*/[deallocator, data, arg]() { deallocator(data, arg); }));
555 }
556 
557 TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
558     TFE_Context* ctx, const char* device_name, TF_DataType dtype,
559     const int64_t* dims, int num_dims, void* data, size_t len,
560     void (*deallocator)(void* data, size_t len, void* arg),
561     void* deallocator_arg, TF_Status* status) {
562   tensorflow::Device* device = nullptr;
563   tensorflow::EagerContext* context =
564       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
565   status->status = context->FindDeviceFromName(device_name, &device);
566   tensorflow::CustomDevice* custom_device = nullptr;
567   if (!status->status.ok()) {
568     if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(
569             device_name, &custom_device)) {
570       deallocator(data, len, deallocator_arg);
571       status->status =
572           tensorflow::errors::InvalidArgument(device_name, " unknown device.");
573       return nullptr;
574     } else {
575       status->status = tensorflow::Status::OK();
576     }
577   }
578   std::vector<tensorflow::int64> dimvec(num_dims);
579   for (int i = 0; i < num_dims; ++i) {
580     dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
581   }
582   if (custom_device != nullptr) {
583     return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
584         context, custom_device,
585         *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
586         /*num_dims_callback=*/
587         [num_dims](TF_Status* status) { return num_dims; },
588         /*dim_callback=*/
589         [dimvec](int dim_index, TF_Status* status) {
590           return dimvec[dim_index];
591         },
592         /*deallocator=*/
593         [data, len, deallocator, deallocator_arg]() {
594           deallocator(data, len, deallocator_arg);
595         }));
596   }
597 
598   // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
599   // the device?
600   TF_ManagedBuffer* buf =
601       new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
602                            /*owns_memory=*/false);
603 
604   tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
605                        tensorflow::TensorShape(dimvec), buf);
606   buf->Unref();
607   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
608       std::move(t), device, device, context));
609 }
610 
611 // This function will block till the operation that produces `h` has
612 // completed. This is only valid on local TFE_TensorHandles. Returns the size in
613 // bytes of the memory pointed to by the device pointer returned above.
614 size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
615                                         TF_Status* status) {
616   if (h == nullptr) {
617     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
618     return 0;
619   }
620   tensorflow::TensorHandle* handle =
621       tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
622   if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
623     status->status = tensorflow::errors::InvalidArgument(
624         "TFE_TensorHandleDeviceMemorySize may not be called on a ",
625         handle->TypeString(), " tensor handle.");
626     return 0;
627   }
628   const tensorflow::Tensor* tensor;
629   status->status = handle->Tensor(&tensor);
630   if (!status->status.ok()) {
631     return 0;
632   }
633   return tensor->TotalBytes();
634 }
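
// Illustrative usage sketch: pairing the device pointer with its size for a
// local handle; both calls block until the producing operation has finished.
//
//   void* ptr = TFE_TensorHandleDevicePointer(h, status);
//   size_t nbytes = TFE_TensorHandleDeviceMemorySize(h, status);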
635 
636 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
637                   TF_Status* status) {
638   tensorflow::ImmediateExecutionOperation* new_op =
639       tensorflow::unwrap(ctx)->CreateOperation();
640   status->status = new_op->Reset(op_or_function_name, nullptr);
641   if (!status->status.ok()) {
642     new_op->Release();
643     new_op = nullptr;
644   }
645   return tensorflow::wrap(new_op);
646 }
647 
648 void TFE_DeleteOp(TFE_Op* op) {
649   if (op == nullptr) {
650     return;
651   }
652 
653   tensorflow::unwrap(op)->Release();
654 }
655 
656 const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
657   return tensorflow::unwrap(op)->Name().c_str();
658 }
659 
660 TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
661   return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
662 }
663 
664 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
665   status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
666 }
667 
668 const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
669   return tensorflow::unwrap(op)->DeviceName().c_str();
670 }
671 
672 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
673   status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
674 }
675 
676 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
677                         TF_Status* status) {
678   status->status = tensorflow::unwrap(op)->AddInputList(
679       {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
680            tensorflow::unwrap(inputs)),
681        static_cast<size_t>(num_inputs)});
682 }
683 
684 extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
685   return tensorflow::unwrap(op)->GetInputs().size();
686 }
687 
688 extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
689                                             TF_Status* status) {
690   return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
691 }
692 
693 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
694                               unsigned char* is_list, TF_Status* status) {
695   TF_AttrType ret = TF_ATTR_INT;
696   const tensorflow::AttrTypeMap* attr_types_;
697   bool is_function;
698   status->status = tensorflow::AttrTypeMapForOp(
699       tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
700   if (!status->status.ok()) {
701     return ret;
702   }
703   status->status =
704       tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
705   return ret;
706 }
707 
708 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
709                                   const char* op_or_function_name,
710                                   const char* attr_name, unsigned char* is_list,
711                                   TF_Status* status) {
712   TF_AttrType ret;
713   TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
714   if (status->status.ok()) {
715     ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
716   } else {
717     ret = TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
718   }
719   TFE_DeleteOp(op);
720   return ret;
721 }
722 
723 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
724                          size_t length) {
725   auto s = tensorflow::unwrap(op)->SetAttrString(
726       attr_name, static_cast<const char*>(value), length);
727   if (!s.ok()) {
728     LOG(WARNING) << "Unable to set attribute: " << attr_name;
729   }
730 }
731 
732 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
733   auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
734   if (!s.ok()) {
735     LOG(WARNING) << "Unable to set attribute: " << attr_name;
736   }
737 }
738 
739 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
740   auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
741   if (!s.ok()) {
742     LOG(WARNING) << "Unable to set attribute: " << attr_name;
743   }
744 }
745 
746 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
747   auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
748                                                (value == 0) ? false : true);
749   if (!s.ok()) {
750     LOG(WARNING) << "Unable to set attribute: " << attr_name;
751   }
752 }
753 
754 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
755   auto s = tensorflow::unwrap(op)->SetAttrType(
756       attr_name, static_cast<tensorflow::DataType>(value));
757   if (!s.ok()) {
758     LOG(WARNING) << "Unable to set attribute: " << attr_name;
759   }
760 }
761 
762 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
763                         const int num_dims, TF_Status* out_status) {
764   out_status->status =
765       tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
766 }
767 
768 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
769                            const TFE_Op* value) {
770   auto s = tensorflow::unwrap(op)->SetAttrFunction(
771       attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
772   if (!s.ok()) {
773     LOG(WARNING) << "Unable to set attribute: " << attr_name;
774   }
775 }
776 
777 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
778                                const char* data, size_t length) {
779   auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
780   if (!s.ok()) {
781     LOG(WARNING) << "Unable to set attribute: " << attr_name;
782   }
783 }
784 
785 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
786                          TF_Status* status) {
787   tensorflow::Tensor t;
788   status->status = TF_TensorToTensor(tensor, &t);
789   tensorflow::TensorInterface interface(t);
790   status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
791 }
792 
793 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
794                              const void* const* values, const size_t* lengths,
795                              int num_values) {
796   auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
797                                                      num_values);
798   if (!s.ok()) {
799     LOG(WARNING) << "Unable to set attribute: " << attr_name;
800   }
801 }
802 
803 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
804                             const float* values, int num_values) {
805   auto s =
806       tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
807   if (!s.ok()) {
808     LOG(WARNING) << "Unable to set attribute: " << attr_name;
809   }
810 }
811 
812 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
813                           const int64_t* values, int num_values) {
814   auto s =
815       tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
816   if (!s.ok()) {
817     LOG(WARNING) << "Unable to set attribute: " << attr_name;
818   }
819 }
820 
821 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
822                            const TF_DataType* values, int num_values) {
823   auto s = tensorflow::unwrap(op)->SetAttrTypeList(
824       attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
825       num_values);
826   if (!s.ok()) {
827     LOG(WARNING) << "Unable to set attribute: " << attr_name;
828   }
829 }
830 
831 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
832                            const unsigned char* values, int num_values) {
833   auto s =
834       tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
835   if (!s.ok()) {
836     LOG(WARNING) << "Unable to set attribute: " << attr_name;
837   }
838 }
839 
840 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
841                             const int64_t** dims, const int* num_dims,
842                             int num_values, TF_Status* out_status) {
843   out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
844       attr_name, dims, num_dims, num_values);
845 }
846 
847 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
848                                const TFE_Op** value, int num_values) {
849   auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
850       attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
851                       tensorflow::unwrap(value)),
852                   static_cast<size_t>(num_values)});
853   if (!s.ok()) {
854     LOG(WARNING) << "Unable to set attribute: " << attr_name;
855   }
856 }
857 
858 void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
859                              const void* proto, size_t proto_len,
860                              TF_Status* status) {
861   tensorflow::AttrValue attr_value;
862   if (!attr_value.ParseFromArray(proto, proto_len)) {
863     status->status =
864         tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
865     return;
866   }
867   if (op == nullptr) {
868     status->status = tensorflow::errors::InvalidArgument(
869         "Got a null or uninitialized `op` argument");
870     return;
871   }
872   tensorflow::EagerOperation* operation =
873       OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
874   operation->MutableAttrs()->Set(attr_name, attr_value);
875 }
876 
877 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
878                                                const char* input_name,
879                                                TF_Status* status) {
880   int ret = -1;
881   status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
882   return ret;
883 }
884 
885 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
886                                                 const char* output_name,
887                                                 TF_Status* status) {
888   int ret = -1;
889   status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
890   return ret;
891 }
892 
893 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
894                  TF_Status* status) {
895   tensorflow::ImmediateExecutionOperation* unwrapped_op =
896       tensorflow::unwrap(op);
897 
898   status->status =
899       unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
900           unwrapped_op,
901           reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
902               retvals),
903           num_retvals);
904 }
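
// Illustrative usage sketch: building and running a MatMul through
// TFE_Execute, given a context `ctx` and two matching float matrix handles
// `a` and `b`; error checks elided.
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, a, status);
//   TFE_OpAddInput(matmul, b, status);
//   TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(a));
//   TFE_TensorHandle* retvals[1] = {nullptr};
//   int num_retvals = 1;
//   TFE_Execute(matmul, retvals, &num_retvals, status);
//   TFE_DeleteOp(matmul);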
905 
906 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
907                                                TFE_Context* ctx,
908                                                const char* device_name,
909                                                TF_Status* status) {
910   if (h == nullptr) {
911     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
912     return nullptr;
913   }
914 
915   auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
916       tensorflow::unwrap(h), device_name, &status->status);
917   if (status->status.ok()) {
918     return tensorflow::wrap(result);
919   }
920   return nullptr;
921 }
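
// Illustrative usage sketch: copying a handle to another device by name; the
// device string is an example and must exist in the context's device set.
//
//   TFE_TensorHandle* on_gpu = TFE_TensorHandleCopyToDevice(
//       h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);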
922 
923 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
924                                const char* serialized_function_def, size_t size,
925                                TF_Status* status) {
926   tensorflow::FunctionDef function_def;
927   if (!function_def.ParseFromArray(serialized_function_def, size)) {
928     status->status =
929         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
930     return;
931   }
932 
933   AnnotateEagerRuntimeConstructionContext(function_def);
934   status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
935 }
936 
937 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
938                             TF_Status* status) {
939   AnnotateEagerRuntimeConstructionContext(function->fdef);
940   status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
941       function->fdef, function->stack_traces);
942 }
943 
944 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
945                                TF_Status* status) {
946   status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
947 }
948 
949 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
950   return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
951 }
952 
953 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
954   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
955 }
956 
957 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
958   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
959 }
960 
961 }  // extern "C"
962 
963 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
964                                       TF_Status* status) {
965   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
966 }
967 
968 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
969                                   TF_Status* status) {
970   auto* context = tensorflow::unwrap(ctx);
971   status->status = context->AsyncWait();
972   if (!status->status.ok()) return;
973   auto run_metadata = context->ExportRunMetadata();
974   status->status = MessageToBuffer(*run_metadata, buf);
975 }
976 
977 namespace {
978 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
979                 TF_Status* status) {
980   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
981   for (const auto& attr : func.attr()) {
982     if (!status->status.ok()) return nullptr;
983     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
984     if (!status->status.ok()) return nullptr;
985   }
986   return func_op;
987 }
988 }  // namespace
989 
990 void TFE_ContextStartStep(TFE_Context* ctx) {
991   tensorflow::unwrap(ctx)->StartStep();
992 }
993 
994 void TFE_ContextEndStep(TFE_Context* ctx) {
995   tensorflow::unwrap(ctx)->EndStep();
996 }
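
// Illustrative usage sketch: bracketing a group of executions as one step so
// per-step resources are retained until the matching EndStep.
//
//   TFE_ContextStartStep(ctx);
//   // ... TFE_Execute calls ...
//   TFE_ContextEndStep(ctx);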
997 
998 const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
999   return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs());
1000 }
1001 
1002 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
1003   tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs));
1004 }
1005 
1006 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
1007                           TF_Status* status) {
1008   tensorflow::NameAttrList name_and_attrs;
1009   tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs);
1010   status->status = MessageToBuffer(name_and_attrs, buf);
1011 }
1012 
1013 namespace tensorflow {
1014 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
1015                           const tensorflow::AttrValue& default_value,
1016                           const char* attr_name, TF_Status* status) {
1017   switch (default_value.value_case()) {
1018     case tensorflow::AttrValue::kS: {
1019       const string& v = default_value.s();
1020       TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
1021       break;
1022     }
1023     case tensorflow::AttrValue::kI:
1024       TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
1025       break;
1026     case tensorflow::AttrValue::kF:
1027       TFE_OpSetAttrFloat(op, attr_name, default_value.f());
1028       break;
1029     case tensorflow::AttrValue::kB:
1030       TFE_OpSetAttrBool(op, attr_name, default_value.b());
1031       break;
1032     case tensorflow::AttrValue::kType:
1033       TFE_OpSetAttrType(op, attr_name,
1034                         static_cast<TF_DataType>(default_value.type()));
1035       break;
1036     case tensorflow::AttrValue::kShape: {
1037       const auto& tensor_shape = default_value.shape();
1038       if (tensor_shape.unknown_rank()) {
1039         TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
1040       } else {
1041         const auto num_dims = tensor_shape.dim_size();
1042         std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
1043         for (int i = 0; i < num_dims; ++i) {
1044           dims[i] = tensor_shape.dim(i).size();
1045         }
1046         TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
1047       }
1048     } break;
1049     case tensorflow::AttrValue::kFunc: {
1050       const auto func_op = GetFunc(ctx, default_value.func(), status);
1051       if (!status->status.ok()) return;
1052       // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
1053       // require TFE_Op* and just convert it internally to a NameAttrList, so
1054       // consider adding an overload to the C API to make this case easier.
1055       TFE_OpSetAttrFunction(op, attr_name, func_op);
1056       TFE_DeleteOp(func_op);
1057     } break;
1058     case tensorflow::AttrValue::kList: {
1059       // String
1060       if (const int s_size = default_value.list().s_size()) {
1061         absl::InlinedVector<const void*, 4> values_vector;
1062         absl::InlinedVector<size_t, 4> lengths_vector;
1063         for (int i = 0; i < s_size; ++i) {
1064           const string& v = default_value.list().s(i);
1065           values_vector.push_back(v.data());
1066           lengths_vector.push_back(v.size());
1067         }
1068         TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
1069                                 lengths_vector.data(), s_size);
1070       }
1071 
1072       // Int
1073       if (const int i_size = default_value.list().i_size()) {
1074         absl::InlinedVector<int64_t, 4> i_vector;
1075         for (int i = 0; i < i_size; ++i) {
1076           i_vector.push_back(default_value.list().i(i));
1077         }
1078         TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
1079       }
1080       // Float
1081       if (const int f_size = default_value.list().f_size()) {
1082         absl::InlinedVector<float, 4> f_vector;
1083         for (int i = 0; i < f_size; ++i) {
1084           f_vector.push_back(default_value.list().f(i));
1085         }
1086         TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
1087       }
1088       // Bool
1089       if (const int b_size = default_value.list().b_size()) {
1090         absl::InlinedVector<unsigned char, 4> b_vector;
1091         for (int i = 0; i < b_size; i++) {
1092           b_vector.push_back(default_value.list().b(i));
1093         }
1094         TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
1095       }
1096       // Type
1097       if (const int type_size = default_value.list().type_size()) {
1098         absl::InlinedVector<unsigned int, 4> type_vector;
1099         for (int i = 0; i < type_size; ++i) {
1100           type_vector.push_back(default_value.list().type(i));
1101         }
1102         TFE_OpSetAttrTypeList(
1103             op, attr_name,
1104             reinterpret_cast<const TF_DataType*>(type_vector.data()),
1105             type_size);
1106       }
1107 
1108       // Rest are not supported.
1109       if (default_value.list().shape_size() > 0 ||
1110           default_value.list().func_size() > 0 ||
1111           default_value.list().tensor_size() > 0) {
1112         TF_SetStatus(
1113             status, TF_UNIMPLEMENTED,
1114             tensorflow::strings::StrCat("Unable to set attr for default value: ",
1115                                         default_value.DebugString())
1116                 .data());
1117       }
1118     } break;
1119     case tensorflow::AttrValue::kTensor:
1120       TF_FALLTHROUGH_INTENDED;
1121     case tensorflow::AttrValue::kPlaceholder:
1122       TF_FALLTHROUGH_INTENDED;
1123     case tensorflow::AttrValue::VALUE_NOT_SET:
1124       TF_SetStatus(
1125           status, TF_UNIMPLEMENTED,
1126           tensorflow::strings::StrCat("Unable to set attr for default value: ",
1127                                       default_value.DebugString())
1128               .data());
1129   }
1130 }
1131 }  // namespace tensorflow
1132 
1133 namespace {
1134 TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
1135                                           TFE_TensorHandle** handles,
1136                                           int num_handles, TF_Status* status,
1137                                           void* device_info) {
1138   TF_SetStatus(status, TF_UNIMPLEMENTED,
1139                "This custom device does not support packing tensors.");
1140   return nullptr;
1141 }
1142 }  // namespace
1143 
1144 extern "C" {
1145 
1146 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
1147                               const char* device_name, void* device_info,
1148                               TF_Status* status) {
1149   // Fill in default values for optional functionality.
1150   if (device.pack == nullptr) {
1151     device.pack = &DefaultCustomDevicePack;
1152   }
1153   auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
1154       ctx, device, device_info, device_name);
1155   status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
1156       device_name, std::move(custom_device));
1157 }
1158 
1159 }  // extern "C"
1160