• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/eager/abstract_tensor_handle.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/immediate_execution_operation.h"
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_op_internal.h"
35 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/core/common_runtime/copy_tensor.h"
38 #include "tensorflow/core/common_runtime/device.h"
39 #include "tensorflow/core/common_runtime/device_factory.h"
40 #include "tensorflow/core/common_runtime/device_mgr.h"
41 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
42 #include "tensorflow/core/common_runtime/eager/context.h"
43 #include "tensorflow/core/common_runtime/eager/custom_device.h"
44 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
45 #include "tensorflow/core/common_runtime/eager/execute.h"
46 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
47 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
48 #include "tensorflow/core/common_runtime/function.h"
49 #include "tensorflow/core/framework/attr_value.pb.h"
50 #include "tensorflow/core/framework/device_attributes.pb.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/rendezvous.h"
54 #include "tensorflow/core/framework/tensor_shape.pb.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/platform/casts.h"
57 #include "tensorflow/core/platform/errors.h"
58 #include "tensorflow/core/platform/platform.h"
59 #include "tensorflow/core/platform/status.h"
60 #include "tensorflow/core/profiler/lib/traceme.h"
61 #include "tensorflow/core/protobuf/error_codes.pb.h"
62 #include "tensorflow/core/public/version.h"
63 
64 // "tensorflow/core/platform/platform.h" must be included first before using
65 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
66 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
67 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
68 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
69 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
70 
71 #if !defined(IS_MOBILE_PLATFORM)
72 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
73 #endif  // !IS_MOBILE_PLATFORM
74 
75 using tensorflow::string;
76 
77 namespace {
78 
DeviceName(const tensorflow::Device * d)79 string DeviceName(const tensorflow::Device* d) {
80   return (d == nullptr) ? "cpu:0" : d->name();
81 }
82 
83 // Annotate eager runtime construction context to the given `function_def` as
84 // an attribute.
AnnotateEagerRuntimeConstructionContext(tensorflow::FunctionDef & function_def)85 void AnnotateEagerRuntimeConstructionContext(
86     tensorflow::FunctionDef& function_def) {
87   tensorflow::AttrValue value;
88   SetAttrValue("kEagerRuntime", &value);
89   (*function_def.mutable_attr())["_construction_context"] = value;
90 }
91 
92 }  // namespace
93 
94 extern "C" {
95 
TFE_NewContextOptions()96 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
97 
TFE_ContextOptionsSetConfig(TFE_ContextOptions * options,const void * proto,size_t proto_len,TF_Status * status)98 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
99                                  size_t proto_len, TF_Status* status) {
100   TF_SetConfig(&options->session_options, proto, proto_len, status);
101 }
102 
TFE_ContextOptionsSetAsync(TFE_ContextOptions * options,unsigned char enable)103 void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
104                                 unsigned char enable) {
105   options->async = enable;
106 }
107 
TFE_ContextOptionsSetDevicePlacementPolicy(TFE_ContextOptions * options,TFE_ContextDevicePlacementPolicy policy)108 void TFE_ContextOptionsSetDevicePlacementPolicy(
109     TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
110   options->device_placement_policy = policy;
111 }
112 
TFE_DeleteContextOptions(TFE_ContextOptions * options)113 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
114 
TFE_NewContext(const TFE_ContextOptions * opts,TF_Status * status)115 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
116   if (opts->use_tfrt) {
117 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
118     tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
119         opts->session_options.options,
120         static_cast<tensorflow::ContextDevicePlacementPolicy>(
121             opts->device_placement_policy),
122         opts->async, opts->use_tfrt_distributed_runtime);
123 #if !defined(IS_MOBILE_PLATFORM)
124     tfrt_context->SetDistributedManager(
125         tfrt::tf::CreateDistributedManagerContext(
126             tfrt_context->GetCoreRuntime()->GetHostContext()));
127 #endif  // !IS_MOBILE_PLATFORM
128     return tensorflow::wrap(tfrt_context);
129 #else
130     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
131     return nullptr;
132 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
133   }
134   std::vector<std::unique_ptr<tensorflow::Device>> devices;
135   status->status = tensorflow::DeviceFactory::AddDevices(
136       opts->session_options.options, "/job:localhost/replica:0/task:0",
137       &devices);
138   if (!status->status.ok()) return nullptr;
139   std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
140       new tensorflow::DynamicDeviceMgr(std::move(devices)));
141 
142   tensorflow::Rendezvous* r =
143       new tensorflow::IntraProcessRendezvous(device_mgr.get());
144   tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
145       opts->session_options.options,
146       static_cast<tensorflow::ContextDevicePlacementPolicy>(
147           opts->device_placement_policy),
148       opts->async, device_mgr.release(),
149       /*device_mgr_owned*/ true, r,
150       /*cluster_flr=*/nullptr,
151       /*collective_executor_mgr=*/nullptr,
152       /*run_eager_op_as_function=*/opts->run_eager_op_as_function);
153 #if !defined(IS_MOBILE_PLATFORM)
154   eager_context->SetDistributedManager(
155       std::make_unique<tensorflow::EagerContextDistributedManager>(
156           eager_context));
157 #endif  // !IS_MOBILE_PLATFORM
158   return tensorflow::wrap(eager_context);
159 }
160 
TFE_DeleteContext(TFE_Context * ctx)161 void TFE_DeleteContext(TFE_Context* ctx) {
162   if (ctx == nullptr) {
163     return;
164   }
165 
166   // ctx->RefCountIsOne() should be true here.
167   tensorflow::unwrap(ctx)->Release();
168 }
169 
TFE_ContextListDevices(TFE_Context * ctx,TF_Status * status)170 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
171   TF_DeviceList* l = new TF_DeviceList;
172   tensorflow::unwrap(ctx)->ListDevices(&l->response);
173   return l;
174 }
175 
TFE_ContextClearCaches(TFE_Context * ctx)176 void TFE_ContextClearCaches(TFE_Context* ctx) {
177   tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
178 }
179 
180 // Set server_def on the context, possibly updating it.
TFE_ContextSetServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)181 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
182                                                    int keep_alive_secs,
183                                                    const void* proto,
184                                                    size_t proto_len,
185                                                    TF_Status* status) {
186 #if defined(IS_MOBILE_PLATFORM)
187   status->status = tensorflow::errors::Unimplemented(
188       "TFE_ContextSetServerDef not supported on mobile");
189 #else   // !defined(IS_MOBILE_PLATFORM)
190   tensorflow::ServerDef server_def;
191   if (!server_def.ParseFromArray(proto, proto_len)) {
192     status->status = tensorflow::errors::InvalidArgument(
193         "Invalid tensorflow.ServerDef protocol buffer");
194     return;
195   }
196   status->status =
197       tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
198           server_def, /*reset_context=*/true, keep_alive_secs);
199 #endif  // !IS_MOBILE_PLATFORM
200 }
201 
TFE_ContextUpdateServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)202 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
203                                                       int keep_alive_secs,
204                                                       const void* proto,
205                                                       size_t proto_len,
206                                                       TF_Status* status) {
207 #if defined(IS_MOBILE_PLATFORM)
208   status->status = tensorflow::errors::Unimplemented(
209       "TFE_ContextSetServerDef not supported on mobile");
210 #else   // !defined(IS_MOBILE_PLATFORM)
211   tensorflow::ServerDef server_def;
212   tensorflow::EagerContext* context =
213       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
214   if (!server_def.ParseFromArray(proto, proto_len)) {
215     status->status = tensorflow::errors::InvalidArgument(
216         "Invalid tensorflow.ServerDef protocol buffer");
217     return;
218   } else if (context->GetContextId() ==
219              tensorflow::EagerContext::kInvalidContextId) {
220     status->status = tensorflow::errors::InvalidArgument(
221         "Trying to update a context with invalid context id.");
222   }
223   status->status =
224       tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
225           server_def, /*reset_context=*/false, keep_alive_secs);
226 #endif  // !IS_MOBILE_PLATFORM
227 }
228 
TFE_ContextCheckAlive(TFE_Context * ctx,const char * worker_name,TF_Status * status)229 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
230                                                  const char* worker_name,
231                                                  TF_Status* status) {
232 #if defined(IS_MOBILE_PLATFORM)
233   status->status = tensorflow::errors::Unimplemented(
234       "TFE_ContextSetServerDef not supported on mobile");
235   return false;
236 #else   // !defined(IS_MOBILE_PLATFORM)
237   bool is_alive;
238   status->status =
239       tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
240           worker_name, &is_alive);
241   return is_alive;
242 #endif  // !IS_MOBILE_PLATFORM
243 }
244 
TFE_ContextAsyncWait(TFE_Context * ctx,TF_Status * status)245 TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
246                                                 TF_Status* status) {
247 #if defined(IS_MOBILE_PLATFORM)
248   status->status = tensorflow::Status::OK();
249 #else   // !defined(IS_MOBILE_PLATFORM)
250   status->status = tensorflow::unwrap(ctx)->AsyncWait();
251 #endif  // !IS_MOBILE_PLATFORM
252 }
253 
TFE_ContextSetThreadLocalDevicePlacementPolicy(TFE_Context * ctx,TFE_ContextDevicePlacementPolicy policy)254 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
255     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
256   tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
257       static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
258 }
259 
260 // Note: this function looks up a thread local policy. So it should be called in
261 // the appropriate client thread. In particular, in async mode, it may not be
262 // safe to call this function from the async EagerExecutor threads.
TFE_ContextGetDevicePlacementPolicy(TFE_Context * ctx)263 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
264     TFE_Context* ctx) {
265   return static_cast<TFE_ContextDevicePlacementPolicy>(
266       tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
267 }
268 
TFE_NewTensorHandle(const TF_Tensor * t,TF_Status * status)269 TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
270   tensorflow::Tensor tensor;
271   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
272   if (!status->status.ok()) return nullptr;
273 
274   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
275 }
276 
TFE_DeleteTensorHandle(TFE_TensorHandle * h)277 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
278   if (h == nullptr) return;
279 
280   tensorflow::profiler::TraceMe activity(
281       "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
282   if (h) {
283     tensorflow::unwrap(h)->Release();
284   }
285 }
286 
TFE_TensorHandleDataType(TFE_TensorHandle * h)287 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
288   return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
289 }
290 
TFE_TensorHandleNumDims(TFE_TensorHandle * h,TF_Status * status)291 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
292   if (h == nullptr) {
293     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
294     return -1;
295   }
296 
297   int num_dims = -1;
298   status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
299   return num_dims;
300 }
301 
TFE_TensorHandleNumElements(TFE_TensorHandle * h,TF_Status * status)302 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
303   if (h == nullptr) {
304     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
305     return -1;
306   }
307 
308   int64_t num_elements = -1;
309   status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
310   return num_elements;
311 }
312 
TFE_TensorHandleDim(TFE_TensorHandle * h,int dim_index,TF_Status * status)313 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
314                             TF_Status* status) {
315   if (h == nullptr) {
316     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
317     return -1;
318   }
319 
320   int64_t dim = -1;
321   status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
322   return dim;
323 }
324 
TFE_TensorHandleDeviceName(TFE_TensorHandle * h,TF_Status * status)325 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
326   if (h == nullptr) {
327     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
328     return nullptr;
329   }
330   return tensorflow::unwrap(h)->DeviceName(&status->status);
331 }
332 
TFE_TensorHandleBackingDeviceName(TFE_TensorHandle * h,TF_Status * status)333 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
334                                               TF_Status* status) {
335   if (h == nullptr) {
336     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
337     return nullptr;
338   }
339   return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
340 }
341 
TFE_TensorHandleCopySharingTensor(TFE_TensorHandle * h,TF_Status * status)342 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
343     TFE_TensorHandle* h, TF_Status* status) {
344   if (h == nullptr) {
345     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
346     return nullptr;
347   }
348 
349   return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
350 }
351 
TFE_TensorHandleResolve(TFE_TensorHandle * h,TF_Status * status)352 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
353   if (h == nullptr) {
354     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
355     return nullptr;
356   }
357 
358   tensorflow::AbstractTensorInterface* t =
359       tensorflow::unwrap(h)->Resolve(&status->status);
360   if (t == nullptr) {
361     return nullptr;
362   }
363 
364   return new TF_Tensor{t};
365 }
366 
TFE_TensorHandleDevicePointer(TFE_TensorHandle * h,TF_Status * status)367 void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
368   if (h == nullptr) {
369     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
370     return nullptr;
371   }
372   tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
373       tensorflow::unwrap(h);
374   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
375   if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
376     return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
377                unwrapped_handle)
378         ->DevicePointer();
379   }
380   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
381   if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
382     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
383     return nullptr;
384   }
385   tensorflow::TensorHandle* handle =
386       tensorflow::TensorHandleFromInterface(unwrapped_handle);
387 
388   if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
389     status->status = tensorflow::errors::InvalidArgument(
390         "TFE_TensorHandleDevicePointer may not be called on a ",
391         handle->TypeString(), " tensor handle.");
392     return nullptr;
393   }
394   tensorflow::Device* device(handle->device());
395   if (device != nullptr) {
396     status->status = device->Sync();
397     if (!status->status.ok()) {
398       return nullptr;
399     }
400   }
401   const tensorflow::Tensor* tensor;
402   status->status = handle->Tensor(&tensor);
403   if (!status->status.ok()) {
404     return nullptr;
405   }
406   return const_cast<void*>(
407       static_cast<const void*>(tensor->tensor_data().data()));
408 }
409 
410 namespace tensorflow {
411 namespace {
412 class CustomDeviceAPI : public tensorflow::CustomDevice {
413  public:
CustomDeviceAPI(TFE_Context * context,TFE_CustomDevice device,void * info,string name)414   CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
415                   string name)
416       : context_(context), device_(device), info_(info), name_(name) {}
417 
~CustomDeviceAPI()418   ~CustomDeviceAPI() override { device_.delete_device(info_); }
419 
name()420   const string& name() override { return name_; }
421 
CopyTensorToDevice(ImmediateExecutionTensorHandle * handle,ImmediateExecutionTensorHandle ** result)422   tensorflow::Status CopyTensorToDevice(
423       ImmediateExecutionTensorHandle* handle,
424       ImmediateExecutionTensorHandle** result) override {
425     handle->Ref();
426     TF_Status status;
427     TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
428         context_, tensorflow::wrap(handle), &status, info_);
429     handle->Release();
430     if (!status.status.ok()) return status.status;
431     *result = tensorflow::unwrap(result_handle);
432     (*result)->Ref();
433     TFE_DeleteTensorHandle(result_handle);
434     return status.status;
435   }
436 
CopyTensorFromDevice(ImmediateExecutionTensorHandle * handle,const tensorflow::string & target_device_name,ImmediateExecutionTensorHandle ** result)437   tensorflow::Status CopyTensorFromDevice(
438       ImmediateExecutionTensorHandle* handle,
439       const tensorflow::string& target_device_name,
440       ImmediateExecutionTensorHandle** result) override {
441     TF_Status status;
442     handle->Ref();
443     TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
444         context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
445         info_);
446     handle->Release();
447     if (!status.status.ok()) return status.status;
448     *result = tensorflow::unwrap(result_handle);
449     (*result)->Ref();
450     TFE_DeleteTensorHandle(result_handle);
451     return status.status;
452   }
453 
Execute(const ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)454   tensorflow::Status Execute(const ImmediateExecutionOperation* op,
455                              ImmediateExecutionTensorHandle** retvals,
456                              int* num_retvals) override {
457     std::vector<TFE_TensorHandle*> outputs(*num_retvals);
458     TF_Status status;
459     device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
460                     info_);
461     if (status.status.ok()) {
462       for (int i = 0; i < *num_retvals; ++i) {
463         retvals[i] = tensorflow::unwrap(outputs[i]);
464         retvals[i]->Ref();
465         TFE_DeleteTensorHandle(outputs[i]);
466       }
467     }
468     return status.status;
469   }
470 
Pack(absl::Span<ImmediateExecutionTensorHandle * > handles,ImmediateExecutionTensorHandle ** result)471   tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
472                           ImmediateExecutionTensorHandle** result) override {
473     TF_Status status;
474     *result = tensorflow::unwrap(device_.pack(context_,
475                                               tensorflow::wrap(handles.data()),
476                                               handles.size(), &status, info_));
477     return status.status;
478   }
479 
480  private:
481   TFE_Context* context_;
482   TFE_CustomDevice device_;
483   void* info_;
484   string name_;
485 };
486 
487 // An adapter which wraps the shape/data produced by C custom devices and uses
488 // it to implement custom device methods.
489 class CAPICustomDeviceTensorHandle
490     : public tensorflow::CustomDeviceTensorHandle {
491  public:
CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext * context,tensorflow::CustomDevice * device,tensorflow::DataType dtype,void * data,TFE_CustomDeviceTensorHandleMethods methods)492   CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
493                                tensorflow::CustomDevice* device,
494                                tensorflow::DataType dtype, void* data,
495                                TFE_CustomDeviceTensorHandleMethods methods)
496       : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
497         data_(data),
498         methods_(methods) {}
499 
~CAPICustomDeviceTensorHandle()500   ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); }
DevicePointer() const501   void* DevicePointer() const override { return data_; }
NumDims(int * num_dims) const502   Status NumDims(int* num_dims) const override {
503     TF_Status s;
504     *num_dims = methods_.num_dims(data_, &s);
505     return s.status;
506   }
Dim(int dim_index,int64 * dim) const507   Status Dim(int dim_index, int64* dim) const override {
508     TF_Status s;
509     *dim = methods_.dim(data_, dim_index, &s);
510     return s.status;
511   }
512 
PreferCustomSummarizer() const513   bool PreferCustomSummarizer() const override {
514     return methods_.summarize != nullptr;
515   }
516 
SummarizeValue(std::string & summary) const517   Status SummarizeValue(std::string& summary) const override {
518     if (methods_.summarize == nullptr) {
519       return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary);
520     }
521     TF_Status c_status;
522     std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> summary_buffer(
523         methods_.summarize(data_, &c_status), TF_DeleteBuffer);
524     if (!c_status.status.ok()) {
525       return c_status.status;
526     }
527     summary = std::string(reinterpret_cast<const char*>(summary_buffer->data),
528                           summary_buffer->length);
529     return Status::OK();
530   }
531 
532  private:
533   void* const data_;
534   const TFE_CustomDeviceTensorHandleMethods methods_;
535 };
536 
537 }  // namespace
538 }  // namespace tensorflow
539 
TFE_NewCustomDeviceTensorHandle(TFE_Context * ctx,const char * device_name,TF_DataType dtype,void * data,TFE_CustomDeviceTensorHandleMethods methods,TF_Status * status)540 TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
541     TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
542     TFE_CustomDeviceTensorHandleMethods methods, TF_Status* status) {
543   tensorflow::ImmediateExecutionContext* context = tensorflow::unwrap(ctx);
544   tensorflow::CustomDevice* device = nullptr;
545   if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
546                                                                     &device)) {
547     methods.deallocator(data);
548     status->status =
549         tensorflow::errors::InvalidArgument(device_name, " unknown device.");
550     return nullptr;
551   }
552   return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
553       context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
554       methods));
555 }
556 
TFE_NewTensorHandleFromDeviceMemory(TFE_Context * ctx,const char * device_name,TF_DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg,TF_Status * status)557 TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
558     TFE_Context* ctx, const char* device_name, TF_DataType dtype,
559     const int64_t* dims, int num_dims, void* data, size_t len,
560     void (*deallocator)(void* data, size_t len, void* arg),
561     void* deallocator_arg, TF_Status* status) {
562   tensorflow::Device* device = nullptr;
563   tensorflow::EagerContext* context =
564       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
565   status->status = context->FindDeviceFromName(device_name, &device);
566   if (!status->status.ok()) {
567     deallocator(data, len, deallocator_arg);
568     status->status =
569         tensorflow::errors::InvalidArgument(device_name, " unknown device.");
570     return nullptr;
571   }
572   std::vector<tensorflow::int64> dimvec(num_dims);
573   for (int i = 0; i < num_dims; ++i) {
574     dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
575   }
576 
577   // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
578   // the device?
579   TF_ManagedBuffer* buf =
580       new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
581                            /*owns_memory=*/false);
582 
583   tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
584                        tensorflow::TensorShape(dimvec), buf);
585   buf->Unref();
586   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
587       std::move(t), device, device, context));
588 }
589 
590 // This function will block till the operation that produces `h` has
591 // completed. This is only valid on local TFE_TensorHandles. Returns the size in
592 // bytes of the memory pointed to by the device pointer returned above.
TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle * h,TF_Status * status)593 size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
594                                         TF_Status* status) {
595   if (h == nullptr) {
596     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
597     return 0;
598   }
599   tensorflow::TensorHandle* handle =
600       tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
601   if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
602     status->status = tensorflow::errors::InvalidArgument(
603         "TFE_TensorHandleDeviceMemorySize may not be called on a ",
604         handle->TypeString(), " tensor handle.");
605     return 0;
606   }
607   const tensorflow::Tensor* tensor;
608   status->status = handle->Tensor(&tensor);
609   if (!status->status.ok()) {
610     return 0;
611   }
612   return tensor->TotalBytes();
613 }
614 
TFE_NewOp(TFE_Context * ctx,const char * op_or_function_name,TF_Status * status)615 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
616                   TF_Status* status) {
617   tensorflow::ImmediateExecutionOperation* new_op =
618       tensorflow::unwrap(ctx)->CreateOperation();
619   status->status = new_op->Reset(op_or_function_name, nullptr);
620   if (!status->status.ok()) {
621     new_op->Release();
622     new_op = nullptr;
623   }
624   return tensorflow::wrap(new_op);
625 }
626 
TFE_DeleteOp(TFE_Op * op)627 void TFE_DeleteOp(TFE_Op* op) {
628   if (op == nullptr) {
629     return;
630   }
631 
632   tensorflow::unwrap(op)->Release();
633 }
634 
TFE_OpGetName(const TFE_Op * op,TF_Status * status)635 const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
636   return tensorflow::unwrap(op)->Name().c_str();
637 }
638 
TFE_OpGetContext(const TFE_Op * op,TF_Status * status)639 TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
640   return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
641 }
642 
TFE_OpSetDevice(TFE_Op * op,const char * device_name,TF_Status * status)643 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
644   status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
645 }
646 
TFE_OpGetDevice(const TFE_Op * op,TF_Status * status)647 const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
648   return tensorflow::unwrap(op)->DeviceName().c_str();
649 }
650 
TFE_OpAddInput(TFE_Op * op,TFE_TensorHandle * input,TF_Status * status)651 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
652   status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
653 }
654 
TFE_OpAddInputList(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs,TF_Status * status)655 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
656                         TF_Status* status) {
657   status->status = tensorflow::unwrap(op)->AddInputList(
658       {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
659            tensorflow::unwrap(inputs)),
660        static_cast<size_t>(num_inputs)});
661 }
662 
TFE_OpGetFlatInputCount(const TFE_Op * op,TF_Status * status)663 extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
664   return tensorflow::unwrap(op)->GetInputs().size();
665 }
666 
TFE_OpGetFlatInput(const TFE_Op * op,int index,TF_Status * status)667 extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
668                                             TF_Status* status) {
669   return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
670 }
671 
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)672 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
673                               unsigned char* is_list, TF_Status* status) {
674   TF_AttrType ret = TF_ATTR_INT;
675   const tensorflow::AttrTypeMap* attr_types_;
676   bool is_function;
677   status->status = tensorflow::AttrTypeMapForOp(
678       tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
679   if (!status->status.ok()) {
680     return ret;
681   }
682   status->status =
683       tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
684   return ret;
685 }
686 
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)687 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
688                                   const char* op_or_function_name,
689                                   const char* attr_name, unsigned char* is_list,
690                                   TF_Status* status) {
691   TF_AttrType ret;
692   TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
693   if (status->status.ok()) {
694     ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
695   } else {
696     ret = TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
697   }
698   TFE_DeleteOp(op);
699   return ret;
700 }
701 
TFE_OpSetAttrString(TFE_Op * op,const char * attr_name,const void * value,size_t length)702 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
703                          size_t length) {
704   auto s = tensorflow::unwrap(op)->SetAttrString(
705       attr_name, static_cast<const char*>(value), length);
706   if (!s.ok()) {
707     LOG(WARNING) << "Unable to set attribute: " << attr_name;
708   }
709 }
710 
TFE_OpSetAttrInt(TFE_Op * op,const char * attr_name,int64_t value)711 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
712   auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
713   if (!s.ok()) {
714     LOG(WARNING) << "Unable to set attribute: " << attr_name;
715   }
716 }
717 
TFE_OpSetAttrFloat(TFE_Op * op,const char * attr_name,float value)718 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
719   auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
720   if (!s.ok()) {
721     LOG(WARNING) << "Unable to set attribute: " << attr_name;
722   }
723 }
724 
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)725 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
726   auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
727                                                (value == 0) ? false : true);
728   if (!s.ok()) {
729     LOG(WARNING) << "Unable to set attribute: " << attr_name;
730   }
731 }
732 
TFE_OpSetAttrType(TFE_Op * op,const char * attr_name,TF_DataType value)733 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
734   auto s = tensorflow::unwrap(op)->SetAttrType(
735       attr_name, static_cast<tensorflow::DataType>(value));
736   if (!s.ok()) {
737     LOG(WARNING) << "Unable to set attribute: " << attr_name;
738   }
739 }
740 
TFE_OpSetAttrShape(TFE_Op * op,const char * attr_name,const int64_t * dims,const int num_dims,TF_Status * out_status)741 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
742                         const int num_dims, TF_Status* out_status) {
743   out_status->status =
744       tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
745 }
746 
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)747 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
748                            const TFE_Op* value) {
749   auto s = tensorflow::unwrap(op)->SetAttrFunction(
750       attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
751   if (!s.ok()) {
752     LOG(WARNING) << "Unable to set attribute: " << attr_name;
753   }
754 }
755 
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)756 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
757                                const char* data, size_t length) {
758   auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
759   if (!s.ok()) {
760     LOG(WARNING) << "Unable to set attribute: " << attr_name;
761   }
762 }
763 
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)764 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
765                          TF_Status* status) {
766   tensorflow::Tensor t;
767   status->status = TF_TensorToTensor(tensor, &t);
768   tensorflow::TensorInterface interface(t);
769   status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
770 }
771 
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)772 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
773                              const void* const* values, const size_t* lengths,
774                              int num_values) {
775   auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
776                                                      num_values);
777   if (!s.ok()) {
778     LOG(WARNING) << "Unable to set attribute: " << attr_name;
779   }
780 }
781 
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)782 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
783                             const float* values, int num_values) {
784   auto s =
785       tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
786   if (!s.ok()) {
787     LOG(WARNING) << "Unable to set attribute: " << attr_name;
788   }
789 }
790 
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)791 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
792                           const int64_t* values, int num_values) {
793   auto s =
794       tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
795   if (!s.ok()) {
796     LOG(WARNING) << "Unable to set attribute: " << attr_name;
797   }
798 }
799 
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)800 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
801                            const TF_DataType* values, int num_values) {
802   auto s = tensorflow::unwrap(op)->SetAttrTypeList(
803       attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
804       num_values);
805   if (!s.ok()) {
806     LOG(WARNING) << "Unable to set attribute: " << attr_name;
807   }
808 }
809 
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)810 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
811                            const unsigned char* values, int num_values) {
812   auto s =
813       tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
814   if (!s.ok()) {
815     LOG(WARNING) << "Unable to set attribute: " << attr_name;
816   }
817 }
818 
TFE_OpSetAttrShapeList(TFE_Op * op,const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values,TF_Status * out_status)819 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
820                             const int64_t** dims, const int* num_dims,
821                             int num_values, TF_Status* out_status) {
822   out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
823       attr_name, dims, num_dims, num_values);
824 }
825 
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)826 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
827                                const TFE_Op** value, int num_values) {
828   auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
829       attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
830                       tensorflow::unwrap(value)),
831                   static_cast<size_t>(num_values)});
832   if (!s.ok()) {
833     LOG(WARNING) << "Unable to set attribute: " << attr_name;
834   }
835 }
836 
TFE_OpSetAttrValueProto(const TFE_Op * op,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)837 void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
838                              const void* proto, size_t proto_len,
839                              TF_Status* status) {
840   tensorflow::AttrValue attr_value;
841   if (!attr_value.ParseFromArray(proto, proto_len)) {
842     status->status =
843         tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
844     return;
845   }
846   if (op == nullptr) {
847     status->status = tensorflow::errors::InvalidArgument(
848         "Got a null or uninitialized `op` argument");
849     return;
850   }
851   tensorflow::EagerOperation* operation =
852       OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
853   operation->MutableAttrs()->Set(attr_name, attr_value);
854 }
855 
TFE_OpGetInputLength(TFE_Op * op,const char * input_name,TF_Status * status)856 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
857                                                const char* input_name,
858                                                TF_Status* status) {
859   int ret = -1;
860   status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
861   return ret;
862 }
863 
TFE_OpGetOutputLength(TFE_Op * op,const char * output_name,TF_Status * status)864 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
865                                                 const char* output_name,
866                                                 TF_Status* status) {
867   int ret = -1;
868   status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
869   return ret;
870 }
871 
TFE_Execute(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)872 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
873                  TF_Status* status) {
874   tensorflow::ImmediateExecutionOperation* unwrapped_op =
875       tensorflow::unwrap(op);
876 
877   status->status =
878       unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
879           unwrapped_op,
880           reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
881               retvals),
882           num_retvals);
883 }
884 
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)885 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
886                                                TFE_Context* ctx,
887                                                const char* device_name,
888                                                TF_Status* status) {
889   if (h == nullptr) {
890     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
891     return nullptr;
892   }
893 
894   tensorflow::ImmediateExecutionContext* unwrapped_ctx =
895       tensorflow::unwrap(ctx);
896 
897   auto* result =
898       unwrapped_ctx->GetCustomDeviceOpHandler().CopyTensorHandleToDevice(
899           unwrapped_ctx, tensorflow::unwrap(h), device_name, &status->status);
900 
901   if (status->status.ok()) {
902     return tensorflow::wrap(result);
903   }
904   return nullptr;
905 }
906 
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)907 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
908                                const char* serialized_function_def, size_t size,
909                                TF_Status* status) {
910   tensorflow::FunctionDef function_def;
911   if (!function_def.ParseFromArray(serialized_function_def, size)) {
912     status->status =
913         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
914     return;
915   }
916 
917   AnnotateEagerRuntimeConstructionContext(function_def);
918   status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
919 }
920 
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)921 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
922                             TF_Status* status) {
923   AnnotateEagerRuntimeConstructionContext(function->fdef);
924   status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
925       function->fdef, function->stack_traces);
926 }
927 
TFE_ContextRemoveFunction(TFE_Context * ctx,const char * name,TF_Status * status)928 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
929                                TF_Status* status) {
930   status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
931 }
932 
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)933 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
934   return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
935 }
936 
TFE_ContextEnableRunMetadata(TFE_Context * ctx)937 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
938   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
939 }
940 
TFE_ContextDisableRunMetadata(TFE_Context * ctx)941 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
942   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
943 }
944 
945 }  // extern "C"
946 
TFE_NewTensorHandle(const tensorflow::Tensor & t,TF_Status * status)947 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
948                                       TF_Status* status) {
949   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
950 }
951 
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)952 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
953                                   TF_Status* status) {
954   auto* context = tensorflow::unwrap(ctx);
955   status->status = context->AsyncWait();
956   if (!status->status.ok()) return;
957   auto run_metadata = context->ExportRunMetadata();
958   status->status = MessageToBuffer(*run_metadata, buf);
959 }
960 
961 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)962 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
963                 TF_Status* status) {
964   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
965   for (const auto& attr : func.attr()) {
966     if (!status->status.ok()) return nullptr;
967     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
968     if (!status->status.ok()) return nullptr;
969   }
970   return func_op;
971 }
972 }  // namespace
973 
TFE_ContextStartStep(TFE_Context * ctx)974 void TFE_ContextStartStep(TFE_Context* ctx) {
975   tensorflow::unwrap(ctx)->StartStep();
976 }
977 
TFE_ContextEndStep(TFE_Context * ctx)978 void TFE_ContextEndStep(TFE_Context* ctx) {
979   tensorflow::unwrap(ctx)->EndStep();
980 }
981 
TFE_OpGetAttrs(const TFE_Op * op)982 const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
983   return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs());
984 }
985 
TFE_OpAddAttrs(TFE_Op * op,const TFE_OpAttrs * attrs)986 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
987   tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs));
988 }
989 
TFE_OpAttrsSerialize(const TFE_OpAttrs * attrs,TF_Buffer * buf,TF_Status * status)990 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
991                           TF_Status* status) {
992   tensorflow::NameAttrList name_and_attrs;
993   tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs);
994   status->status = MessageToBuffer(name_and_attrs, buf);
995 }
996 
997 namespace tensorflow {
SetOpAttrValueScalar(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,TF_Status * status)998 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
999                           const tensorflow::AttrValue& default_value,
1000                           const char* attr_name, TF_Status* status) {
1001   switch (default_value.value_case()) {
1002     case tensorflow::AttrValue::kS: {
1003       const string& v = default_value.s();
1004       TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
1005       break;
1006     }
1007     case tensorflow::AttrValue::kI:
1008       TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
1009       break;
1010     case tensorflow::AttrValue::kF:
1011       TFE_OpSetAttrFloat(op, attr_name, default_value.f());
1012       break;
1013     case tensorflow::AttrValue::kB:
1014       TFE_OpSetAttrBool(op, attr_name, default_value.b());
1015       break;
1016     case tensorflow::AttrValue::kType:
1017       TFE_OpSetAttrType(op, attr_name,
1018                         static_cast<TF_DataType>(default_value.type()));
1019       break;
1020     case tensorflow::AttrValue::kShape: {
1021       const auto& tensor_shape = default_value.shape();
1022       if (tensor_shape.unknown_rank()) {
1023         TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
1024       } else {
1025         const auto num_dims = tensor_shape.dim_size();
1026         std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
1027         for (int i = 0; i < num_dims; ++i) {
1028           dims[i] = tensor_shape.dim(i).size();
1029         }
1030         TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
1031       }
1032     } break;
1033     case tensorflow::AttrValue::kFunc: {
1034       const auto func_op = GetFunc(ctx, default_value.func(), status);
1035       if (!status->status.ok()) return;
1036       // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
1037       // require TFE_Op* and just convert it internally a NameAttrValue, so
1038       // consider adding an overload to the C API to make this case easier.
1039       TFE_OpSetAttrFunction(op, attr_name, func_op);
1040       TFE_DeleteOp(func_op);
1041     } break;
1042     case tensorflow::AttrValue::kList: {
1043       // String
1044       if (const int s_size = default_value.list().s_size()) {
1045         absl::InlinedVector<const void*, 4> values_vector;
1046         absl::InlinedVector<size_t, 4> lengths_vector;
1047         for (int i = 0; i < s_size; ++i) {
1048           const string& v = default_value.list().s(i);
1049           values_vector.push_back(v.data());
1050           lengths_vector.push_back(v.size());
1051         }
1052         TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
1053                                 lengths_vector.data(), s_size);
1054       }
1055 
1056       // Int
1057       if (const int i_size = default_value.list().i_size()) {
1058         absl::InlinedVector<int64_t, 4> i_vector;
1059         for (int i = 0; i < i_size; ++i) {
1060           i_vector.push_back(default_value.list().i(i));
1061         }
1062         TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
1063       }
1064       // Float
1065       if (const int f_size = default_value.list().f_size()) {
1066         absl::InlinedVector<float, 4> f_vector;
1067         for (int i = 0; i < f_size; ++i) {
1068           f_vector.push_back(default_value.list().f(i));
1069         }
1070         TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
1071       }
1072       // Bool
1073       if (const int b_size = default_value.list().b_size()) {
1074         absl::InlinedVector<unsigned char, 4> b_vector;
1075         for (int i = 0; i < b_size; i++) {
1076           b_vector.push_back(default_value.list().b(i));
1077         }
1078         TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
1079       }
1080       // Type
1081       if (const int type_size = default_value.list().type_size()) {
1082         absl::InlinedVector<unsigned int, 4> type_vector;
1083         for (int i = 0; i < type_size; ++i) {
1084           type_vector.push_back(default_value.list().type(i));
1085         }
1086         TFE_OpSetAttrTypeList(
1087             op, attr_name,
1088             reinterpret_cast<const TF_DataType*>(type_vector.data()),
1089             type_size);
1090       }
1091 
1092       // Rest are not supported.
1093       if (default_value.list().shape_size() > 0 ||
1094           default_value.list().func_size() > 0 ||
1095           default_value.list().tensor_size() > 0) {
1096         TF_SetStatus(
1097             status, TF_UNIMPLEMENTED,
1098             tensorflow::strings::StrCat("Unable to get setfor default value: ",
1099                                         default_value.DebugString())
1100                 .data());
1101       }
1102     } break;
1103     case tensorflow::AttrValue::kTensor:
1104       TF_FALLTHROUGH_INTENDED;
1105     case tensorflow::AttrValue::kPlaceholder:
1106       TF_FALLTHROUGH_INTENDED;
1107     case tensorflow::AttrValue::VALUE_NOT_SET:
1108       TF_SetStatus(
1109           status, TF_UNIMPLEMENTED,
1110           tensorflow::strings::StrCat("Unable to get setfor default value: ",
1111                                       default_value.DebugString())
1112               .data());
1113   }
1114 }
1115 }  // namespace tensorflow
1116 
1117 namespace {
DefaultCustomDevicePack(TFE_Context * context,TFE_TensorHandle ** handles,int num_handles,TF_Status * status,void * device_info)1118 TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
1119                                           TFE_TensorHandle** handles,
1120                                           int num_handles, TF_Status* status,
1121                                           void* device_info) {
1122   TF_SetStatus(status, TF_UNIMPLEMENTED,
1123                "This custom device does not support packing tensors.");
1124   return nullptr;
1125 }
1126 }  // namespace
1127 
1128 extern "C" {
1129 
TFE_RegisterCustomDevice(TFE_Context * ctx,TFE_CustomDevice device,const char * device_name,void * device_info,TF_Status * status)1130 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
1131                               const char* device_name, void* device_info,
1132                               TF_Status* status) {
1133   // Fill in default values for optional functionality.
1134   if (device.pack == nullptr) {
1135     device.pack = &DefaultCustomDevicePack;
1136   }
1137   auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
1138       ctx, device, device_info, device_name);
1139   status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
1140       device_name, std::move(custom_device));
1141 }
1142 
1143 }  // extern "C"
1144