• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api.h"
17 
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/eager/abstract_tensor_handle.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/immediate_execution_operation.h"
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_op_internal.h"
35 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
36 #include "tensorflow/c/tf_buffer_internal.h"
37 #include "tensorflow/c/tf_tensor_internal.h"
38 #include "tensorflow/core/common_runtime/copy_tensor.h"
39 #include "tensorflow/core/common_runtime/device.h"
40 #include "tensorflow/core/common_runtime/device_factory.h"
41 #include "tensorflow/core/common_runtime/device_mgr.h"
42 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
43 #include "tensorflow/core/common_runtime/eager/context.h"
44 #include "tensorflow/core/common_runtime/eager/custom_device.h"
45 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
46 #include "tensorflow/core/common_runtime/eager/execute.h"
47 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
48 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
49 #include "tensorflow/core/common_runtime/function.h"
50 #include "tensorflow/core/framework/attr_value.pb.h"
51 #include "tensorflow/core/framework/device_attributes.pb.h"
52 #include "tensorflow/core/framework/function.h"
53 #include "tensorflow/core/framework/node_def_util.h"
54 #include "tensorflow/core/framework/rendezvous.h"
55 #include "tensorflow/core/framework/tensor_shape.pb.h"
56 #include "tensorflow/core/framework/types.h"
57 #include "tensorflow/core/platform/casts.h"
58 #include "tensorflow/core/platform/errors.h"
59 #include "tensorflow/core/platform/platform.h"
60 #include "tensorflow/core/platform/status.h"
61 #include "tensorflow/core/profiler/lib/traceme.h"
62 #include "tensorflow/core/protobuf/error_codes.pb.h"
63 #include "tensorflow/core/public/version.h"
64 
65 // "tensorflow/core/platform/platform.h" must be included first before using
66 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
67 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
68 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
69 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
70 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
71 
72 #if !defined(IS_MOBILE_PLATFORM)
73 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
74 #endif  // !IS_MOBILE_PLATFORM
75 
76 using tensorflow::string;
77 
78 namespace {
79 
DeviceName(const tensorflow::Device * d)80 string DeviceName(const tensorflow::Device* d) {
81   return (d == nullptr) ? "cpu:0" : d->name();
82 }
83 
84 // Annotate eager runtime construction context to the given `function_def` as
85 // an attribute.
AnnotateEagerRuntimeConstructionContext(tensorflow::FunctionDef & function_def)86 void AnnotateEagerRuntimeConstructionContext(
87     tensorflow::FunctionDef& function_def) {
88   tensorflow::AttrValue value;
89   SetAttrValue("kEagerRuntime", &value);
90   (*function_def.mutable_attr())["_construction_context"] = value;
91 }
92 
93 }  // namespace
94 
95 extern "C" {
96 
TFE_NewContextOptions()97 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
98 
TFE_ContextOptionsSetConfig(TFE_ContextOptions * options,const void * proto,size_t proto_len,TF_Status * status)99 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
100                                  size_t proto_len, TF_Status* status) {
101   TF_SetConfig(&options->session_options, proto, proto_len, status);
102 }
103 
TFE_ContextOptionsSetAsync(TFE_ContextOptions * options,unsigned char enable)104 void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
105                                 unsigned char enable) {
106   options->async = enable;
107 }
108 
TFE_ContextOptionsSetDevicePlacementPolicy(TFE_ContextOptions * options,TFE_ContextDevicePlacementPolicy policy)109 void TFE_ContextOptionsSetDevicePlacementPolicy(
110     TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
111   options->device_placement_policy = policy;
112 }
113 
TFE_DeleteContextOptions(TFE_ContextOptions * options)114 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
115 
TFE_NewContext(const TFE_ContextOptions * opts,TF_Status * status)116 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
117   if (opts->use_tfrt) {
118 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
119     tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
120         opts->session_options.options,
121         static_cast<tensorflow::ContextDevicePlacementPolicy>(
122             opts->device_placement_policy),
123         opts->async, opts->use_tfrt_distributed_runtime);
124 #if !defined(IS_MOBILE_PLATFORM)
125     tfrt_context->SetDistributedManager(
126         tfrt::tf::CreateDistributedManagerContext(
127             tfrt_context->GetCoreRuntime()->GetHostContext()));
128 #endif  // !IS_MOBILE_PLATFORM
129     return tensorflow::wrap(tfrt_context);
130 #else
131     status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
132     return nullptr;
133 #endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
134   }
135   std::vector<std::unique_ptr<tensorflow::Device>> devices;
136   status->status = tensorflow::DeviceFactory::AddDevices(
137       opts->session_options.options, "/job:localhost/replica:0/task:0",
138       &devices);
139   if (!status->status.ok()) return nullptr;
140   std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
141       new tensorflow::DynamicDeviceMgr(std::move(devices)));
142 
143   tensorflow::Rendezvous* r =
144       new tensorflow::IntraProcessRendezvous(device_mgr.get());
145   tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
146       opts->session_options.options,
147       static_cast<tensorflow::ContextDevicePlacementPolicy>(
148           opts->device_placement_policy),
149       opts->async, device_mgr.release(),
150       /*device_mgr_owned*/ true, r,
151       /*cluster_flr=*/nullptr,
152       /*collective_executor_mgr=*/nullptr,
153       /*run_eager_op_as_function=*/opts->run_eager_op_as_function,
154       /*jit_compile_rewrite=*/opts->jit_compile_rewrite);
155 #if !defined(IS_MOBILE_PLATFORM)
156   eager_context->SetDistributedManager(
157       std::make_unique<tensorflow::EagerContextDistributedManager>(
158           eager_context));
159 #endif  // !IS_MOBILE_PLATFORM
160   return tensorflow::wrap(eager_context);
161 }
162 
TFE_DeleteContext(TFE_Context * ctx)163 void TFE_DeleteContext(TFE_Context* ctx) {
164   if (ctx == nullptr) {
165     return;
166   }
167 
168   // ctx->RefCountIsOne() should be true here.
169   tensorflow::unwrap(ctx)->Release();
170 }
171 
TFE_ContextListDevices(TFE_Context * ctx,TF_Status * status)172 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
173   TF_DeviceList* l = new TF_DeviceList;
174   tensorflow::unwrap(ctx)->ListDevices(&l->response);
175   return l;
176 }
177 
TFE_ContextClearCaches(TFE_Context * ctx)178 void TFE_ContextClearCaches(TFE_Context* ctx) {
179   tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
180 }
181 
182 // Set server_def on the context, possibly updating it.
TFE_ContextSetServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)183 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
184                                                    int keep_alive_secs,
185                                                    const void* proto,
186                                                    size_t proto_len,
187                                                    TF_Status* status) {
188 #if defined(IS_MOBILE_PLATFORM)
189   status->status = tensorflow::errors::Unimplemented(
190       "TFE_ContextSetServerDef not supported on mobile");
191 #else   // !defined(IS_MOBILE_PLATFORM)
192   tensorflow::ServerDef server_def;
193   if (!server_def.ParseFromArray(proto, proto_len)) {
194     status->status = tensorflow::errors::InvalidArgument(
195         "Invalid tensorflow.ServerDef protocol buffer");
196     return;
197   }
198   status->status =
199       tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
200           server_def, /*reset_context=*/true, keep_alive_secs);
201 #endif  // !IS_MOBILE_PLATFORM
202 }
203 
TFE_ContextUpdateServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)204 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
205                                                       int keep_alive_secs,
206                                                       const void* proto,
207                                                       size_t proto_len,
208                                                       TF_Status* status) {
209 #if defined(IS_MOBILE_PLATFORM)
210   status->status = tensorflow::errors::Unimplemented(
211       "TFE_ContextSetServerDef not supported on mobile");
212 #else   // !defined(IS_MOBILE_PLATFORM)
213   tensorflow::ServerDef server_def;
214   tensorflow::EagerContext* context =
215       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
216   if (!server_def.ParseFromArray(proto, proto_len)) {
217     status->status = tensorflow::errors::InvalidArgument(
218         "Invalid tensorflow.ServerDef protocol buffer");
219     return;
220   } else if (context->GetContextId() ==
221              tensorflow::EagerContext::kInvalidContextId) {
222     status->status = tensorflow::errors::InvalidArgument(
223         "Trying to update a context with invalid context id.");
224   }
225   status->status =
226       tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
227           server_def, /*reset_context=*/false, keep_alive_secs);
228 #endif  // !IS_MOBILE_PLATFORM
229 }
230 
TFE_ContextCheckAlive(TFE_Context * ctx,const char * worker_name,TF_Status * status)231 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
232                                                  const char* worker_name,
233                                                  TF_Status* status) {
234 #if defined(IS_MOBILE_PLATFORM)
235   status->status = tensorflow::errors::Unimplemented(
236       "TFE_ContextSetServerDef not supported on mobile");
237   return false;
238 #else   // !defined(IS_MOBILE_PLATFORM)
239   bool is_alive;
240   status->status =
241       tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
242           worker_name, &is_alive);
243   return is_alive;
244 #endif  // !IS_MOBILE_PLATFORM
245 }
246 
TFE_ContextAsyncWait(TFE_Context * ctx,TF_Status * status)247 TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
248                                                 TF_Status* status) {
249 #if defined(IS_MOBILE_PLATFORM)
250   status->status = tensorflow::Status::OK();
251 #else   // !defined(IS_MOBILE_PLATFORM)
252   status->status = tensorflow::unwrap(ctx)->AsyncWait();
253 #endif  // !IS_MOBILE_PLATFORM
254 }
255 
TFE_ContextSetThreadLocalDevicePlacementPolicy(TFE_Context * ctx,TFE_ContextDevicePlacementPolicy policy)256 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
257     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
258   tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
259       static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
260 }
261 
262 // Note: this function looks up a thread local policy. So it should be called in
263 // the appropriate client thread. In particular, in async mode, it may not be
264 // safe to call this function from the async EagerExecutor threads.
TFE_ContextGetDevicePlacementPolicy(TFE_Context * ctx)265 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
266     TFE_Context* ctx) {
267   return static_cast<TFE_ContextDevicePlacementPolicy>(
268       tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
269 }
270 
TFE_NewTensorHandle(const TF_Tensor * t,TF_Status * status)271 TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
272   tensorflow::Tensor tensor;
273   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
274   if (!status->status.ok()) return nullptr;
275 
276   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
277 }
278 
TFE_DeleteTensorHandle(TFE_TensorHandle * h)279 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
280   if (h == nullptr) return;
281 
282   tensorflow::profiler::TraceMe activity(
283       "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
284   if (h) {
285     tensorflow::unwrap(h)->Release();
286   }
287 }
288 
TFE_TensorHandleDataType(TFE_TensorHandle * h)289 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
290   return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
291 }
292 
TFE_TensorHandleNumDims(TFE_TensorHandle * h,TF_Status * status)293 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
294   if (h == nullptr) {
295     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
296     return -1;
297   }
298 
299   int num_dims = -1;
300   status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
301   return num_dims;
302 }
303 
TFE_TensorHandleNumElements(TFE_TensorHandle * h,TF_Status * status)304 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
305   if (h == nullptr) {
306     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
307     return -1;
308   }
309 
310   int64_t num_elements = -1;
311   status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
312   return num_elements;
313 }
314 
TFE_TensorHandleDim(TFE_TensorHandle * h,int dim_index,TF_Status * status)315 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
316                             TF_Status* status) {
317   if (h == nullptr) {
318     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
319     return -1;
320   }
321 
322   int64_t dim = -1;
323   status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
324   return dim;
325 }
326 
TFE_TensorHandleDeviceName(TFE_TensorHandle * h,TF_Status * status)327 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
328   if (h == nullptr) {
329     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
330     return nullptr;
331   }
332   return tensorflow::unwrap(h)->DeviceName(&status->status);
333 }
334 
TFE_TensorHandleBackingDeviceName(TFE_TensorHandle * h,TF_Status * status)335 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
336                                               TF_Status* status) {
337   if (h == nullptr) {
338     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
339     return nullptr;
340   }
341   return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
342 }
343 
TFE_TensorHandleCopySharingTensor(TFE_TensorHandle * h,TF_Status * status)344 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
345     TFE_TensorHandle* h, TF_Status* status) {
346   if (h == nullptr) {
347     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
348     return nullptr;
349   }
350 
351   return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
352 }
353 
TFE_TensorHandleResolve(TFE_TensorHandle * h,TF_Status * status)354 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
355   if (h == nullptr) {
356     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
357     return nullptr;
358   }
359 
360   tensorflow::AbstractTensorInterface* t =
361       tensorflow::unwrap(h)->Resolve(&status->status);
362   if (t == nullptr) {
363     return nullptr;
364   }
365 
366   return new TF_Tensor{t};
367 }
368 
TFE_TensorHandleDevicePointer(TFE_TensorHandle * h,TF_Status * status)369 void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
370   if (h == nullptr) {
371     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
372     return nullptr;
373   }
374   tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
375       tensorflow::unwrap(h);
376   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
377   if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
378     return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
379                unwrapped_handle)
380         ->DevicePointer();
381   }
382   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
383   if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
384     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
385     return nullptr;
386   }
387   tensorflow::TensorHandle* handle =
388       tensorflow::TensorHandleFromInterface(unwrapped_handle);
389 
390   if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
391     status->status = tensorflow::errors::InvalidArgument(
392         "TFE_TensorHandleDevicePointer may not be called on a ",
393         handle->TypeString(), " tensor handle.");
394     return nullptr;
395   }
396   tensorflow::Device* device(handle->device());
397   if (device != nullptr) {
398     status->status = device->Sync();
399     if (!status->status.ok()) {
400       return nullptr;
401     }
402   }
403   const tensorflow::Tensor* tensor;
404   status->status = handle->Tensor(&tensor);
405   if (!status->status.ok()) {
406     return nullptr;
407   }
408   return const_cast<void*>(
409       static_cast<const void*>(tensor->tensor_data().data()));
410 }
411 
412 namespace tensorflow {
413 namespace {
414 class CustomDeviceAPI : public tensorflow::CustomDevice {
415  public:
CustomDeviceAPI(TFE_Context * context,TFE_CustomDevice device,void * info,string name)416   CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
417                   string name)
418       : context_(context), device_(device), info_(info), name_(name) {}
419 
~CustomDeviceAPI()420   ~CustomDeviceAPI() override { device_.delete_device(info_); }
421 
name()422   const string& name() override { return name_; }
423 
CopyTensorToDevice(ImmediateExecutionTensorHandle * handle,ImmediateExecutionTensorHandle ** result)424   tensorflow::Status CopyTensorToDevice(
425       ImmediateExecutionTensorHandle* handle,
426       ImmediateExecutionTensorHandle** result) override {
427     handle->Ref();
428     TF_Status status;
429     TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
430         context_, tensorflow::wrap(handle), &status, info_);
431     handle->Release();
432     if (!status.status.ok()) return status.status;
433     *result = tensorflow::unwrap(result_handle);
434     (*result)->Ref();
435     TFE_DeleteTensorHandle(result_handle);
436     return status.status;
437   }
438 
CopyTensorFromDevice(ImmediateExecutionTensorHandle * handle,const tensorflow::string & target_device_name,ImmediateExecutionTensorHandle ** result)439   tensorflow::Status CopyTensorFromDevice(
440       ImmediateExecutionTensorHandle* handle,
441       const tensorflow::string& target_device_name,
442       ImmediateExecutionTensorHandle** result) override {
443     TF_Status status;
444     handle->Ref();
445     TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
446         context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
447         info_);
448     handle->Release();
449     if (!status.status.ok()) return status.status;
450     *result = tensorflow::unwrap(result_handle);
451     (*result)->Ref();
452     TFE_DeleteTensorHandle(result_handle);
453     return status.status;
454   }
455 
Execute(const ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)456   tensorflow::Status Execute(const ImmediateExecutionOperation* op,
457                              ImmediateExecutionTensorHandle** retvals,
458                              int* num_retvals) override {
459     std::vector<TFE_TensorHandle*> outputs(*num_retvals);
460     TF_Status status;
461     device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
462                     info_);
463     if (status.status.ok()) {
464       for (int i = 0; i < *num_retvals; ++i) {
465         retvals[i] = tensorflow::unwrap(outputs[i]);
466         retvals[i]->Ref();
467         TFE_DeleteTensorHandle(outputs[i]);
468       }
469     }
470     return status.status;
471   }
472 
Pack(absl::Span<ImmediateExecutionTensorHandle * > handles,ImmediateExecutionTensorHandle ** result)473   tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
474                           ImmediateExecutionTensorHandle** result) override {
475     TF_Status status;
476     *result = tensorflow::unwrap(device_.pack(context_,
477                                               tensorflow::wrap(handles.data()),
478                                               handles.size(), &status, info_));
479     return status.status;
480   }
481 
482  private:
483   TFE_Context* context_;
484   TFE_CustomDevice device_;
485   void* info_;
486   string name_;
487 };
488 
489 // An adapter which wraps the shape/data produced by C custom devices and uses
490 // it to implement custom device methods.
491 class CAPICustomDeviceTensorHandle
492     : public tensorflow::CustomDeviceTensorHandle {
493  public:
CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext * context,tensorflow::CustomDevice * device,tensorflow::DataType dtype,void * data,TFE_CustomDeviceTensorHandleMethods methods)494   CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
495                                tensorflow::CustomDevice* device,
496                                tensorflow::DataType dtype, void* data,
497                                TFE_CustomDeviceTensorHandleMethods methods)
498       : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
499         data_(data),
500         methods_(methods) {}
501 
~CAPICustomDeviceTensorHandle()502   ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); }
DevicePointer() const503   void* DevicePointer() const override { return data_; }
NumDims(int * num_dims) const504   Status NumDims(int* num_dims) const override {
505     TF_Status s;
506     *num_dims = methods_.num_dims(data_, &s);
507     return s.status;
508   }
Dim(int dim_index,int64_t * dim) const509   Status Dim(int dim_index, int64_t* dim) const override {
510     TF_Status s;
511     *dim = methods_.dim(data_, dim_index, &s);
512     return s.status;
513   }
514 
PreferCustomSummarizer() const515   bool PreferCustomSummarizer() const override {
516     return methods_.summarize != nullptr;
517   }
518 
SummarizeValue(std::string & summary) const519   Status SummarizeValue(std::string& summary) const override {
520     if (methods_.summarize == nullptr) {
521       return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary);
522     }
523     TF_Status c_status;
524     std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> summary_buffer(
525         methods_.summarize(data_, &c_status), TF_DeleteBuffer);
526     if (!c_status.status.ok()) {
527       return c_status.status;
528     }
529     summary = std::string(reinterpret_cast<const char*>(summary_buffer->data),
530                           summary_buffer->length);
531     return OkStatus();
532   }
533 
534  private:
535   void* const data_;
536   const TFE_CustomDeviceTensorHandleMethods methods_;
537 };
538 
539 }  // namespace
540 }  // namespace tensorflow
541 
TFE_NewCustomDeviceTensorHandle(TFE_Context * ctx,const char * device_name,TF_DataType dtype,void * data,TFE_CustomDeviceTensorHandleMethods methods,TF_Status * status)542 TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
543     TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
544     TFE_CustomDeviceTensorHandleMethods methods, TF_Status* status) {
545   tensorflow::ImmediateExecutionContext* context = tensorflow::unwrap(ctx);
546   tensorflow::CustomDevice* device = nullptr;
547   if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
548                                                                     &device)) {
549     methods.deallocator(data);
550     status->status =
551         tensorflow::errors::InvalidArgument(device_name, " unknown device.");
552     return nullptr;
553   }
554   return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
555       context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
556       methods));
557 }
558 
TFE_NewTensorHandleFromDeviceMemory(TFE_Context * ctx,const char * device_name,TF_DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg,TF_Status * status)559 TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
560     TFE_Context* ctx, const char* device_name, TF_DataType dtype,
561     const int64_t* dims, int num_dims, void* data, size_t len,
562     void (*deallocator)(void* data, size_t len, void* arg),
563     void* deallocator_arg, TF_Status* status) {
564   tensorflow::Device* device = nullptr;
565   tensorflow::EagerContext* context =
566       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
567   status->status = context->FindDeviceFromName(device_name, &device);
568   if (!status->status.ok()) {
569     deallocator(data, len, deallocator_arg);
570     status->status =
571         tensorflow::errors::InvalidArgument(device_name, " unknown device.");
572     return nullptr;
573   }
574   std::vector<int64_t> dimvec(num_dims);
575   for (int i = 0; i < num_dims; ++i) {
576     dimvec[i] = static_cast<int64_t>(dims[i]);
577   }
578 
579   // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
580   // the device?
581   TF_ManagedBuffer* buf =
582       new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
583                            /*owns_memory=*/false);
584 
585   tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
586                        tensorflow::TensorShape(dimvec), buf);
587   buf->Unref();
588   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
589       std::move(t), device, device, context));
590 }
591 
592 // This function will block till the operation that produces `h` has
593 // completed. This is only valid on local TFE_TensorHandles. Returns the size in
594 // bytes of the memory pointed to by the device pointer returned above.
TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle * h,TF_Status * status)595 size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
596                                         TF_Status* status) {
597   if (h == nullptr) {
598     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
599     return 0;
600   }
601   tensorflow::TensorHandle* handle =
602       tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
603   if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
604     status->status = tensorflow::errors::InvalidArgument(
605         "TFE_TensorHandleDeviceMemorySize may not be called on a ",
606         handle->TypeString(), " tensor handle.");
607     return 0;
608   }
609   const tensorflow::Tensor* tensor;
610   status->status = handle->Tensor(&tensor);
611   if (!status->status.ok()) {
612     return 0;
613   }
614   return tensor->TotalBytes();
615 }
616 
TFE_NewOp(TFE_Context * ctx,const char * op_or_function_name,TF_Status * status)617 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
618                   TF_Status* status) {
619   tensorflow::ImmediateExecutionOperation* new_op =
620       tensorflow::unwrap(ctx)->CreateOperation();
621   status->status = new_op->Reset(op_or_function_name, nullptr);
622   if (!status->status.ok()) {
623     new_op->Release();
624     new_op = nullptr;
625   }
626   return tensorflow::wrap(new_op);
627 }
628 
TFE_DeleteOp(TFE_Op * op)629 void TFE_DeleteOp(TFE_Op* op) {
630   if (op == nullptr) {
631     return;
632   }
633 
634   tensorflow::unwrap(op)->Release();
635 }
636 
TFE_OpGetName(const TFE_Op * op,TF_Status * status)637 const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
638   return tensorflow::unwrap(op)->Name().c_str();
639 }
640 
TFE_OpGetContext(const TFE_Op * op,TF_Status * status)641 TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
642   return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
643 }
644 
TFE_OpSetDevice(TFE_Op * op,const char * device_name,TF_Status * status)645 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
646   status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
647 }
648 
TFE_OpGetDevice(const TFE_Op * op,TF_Status * status)649 const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
650   return tensorflow::unwrap(op)->DeviceName().c_str();
651 }
652 
TFE_OpAddInput(TFE_Op * op,TFE_TensorHandle * input,TF_Status * status)653 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
654   status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
655 }
656 
TFE_OpAddInputList(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs,TF_Status * status)657 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
658                         TF_Status* status) {
659   status->status = tensorflow::unwrap(op)->AddInputList(
660       {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
661            tensorflow::unwrap(inputs)),
662        static_cast<size_t>(num_inputs)});
663 }
664 
TFE_OpGetFlatInputCount(const TFE_Op * op,TF_Status * status)665 extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
666   return tensorflow::unwrap(op)->GetInputs().size();
667 }
668 
TFE_OpGetFlatInput(const TFE_Op * op,int index,TF_Status * status)669 extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
670                                             TF_Status* status) {
671   return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
672 }
673 
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)674 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
675                               unsigned char* is_list, TF_Status* status) {
676   TF_AttrType ret = TF_ATTR_INT;
677   const tensorflow::AttrTypeMap* attr_types_;
678   bool is_function;
679   status->status = tensorflow::AttrTypeMapForOp(
680       tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
681   if (!status->status.ok()) {
682     return ret;
683   }
684   status->status =
685       tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
686   return ret;
687 }
688 
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)689 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
690                                   const char* op_or_function_name,
691                                   const char* attr_name, unsigned char* is_list,
692                                   TF_Status* status) {
693   TF_AttrType ret;
694   TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
695   if (status->status.ok()) {
696     ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
697   } else {
698     ret = TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
699   }
700   TFE_DeleteOp(op);
701   return ret;
702 }
703 
TFE_OpSetAttrString(TFE_Op * op,const char * attr_name,const void * value,size_t length)704 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
705                          size_t length) {
706   auto s = tensorflow::unwrap(op)->SetAttrString(
707       attr_name, static_cast<const char*>(value), length);
708   if (!s.ok()) {
709     LOG(WARNING) << "Unable to set attribute: " << attr_name;
710   }
711 }
712 
TFE_OpSetAttrInt(TFE_Op * op,const char * attr_name,int64_t value)713 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
714   auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
715   if (!s.ok()) {
716     LOG(WARNING) << "Unable to set attribute: " << attr_name;
717   }
718 }
719 
TFE_OpSetAttrFloat(TFE_Op * op,const char * attr_name,float value)720 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
721   auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
722   if (!s.ok()) {
723     LOG(WARNING) << "Unable to set attribute: " << attr_name;
724   }
725 }
726 
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)727 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
728   auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
729                                                (value == 0) ? false : true);
730   if (!s.ok()) {
731     LOG(WARNING) << "Unable to set attribute: " << attr_name;
732   }
733 }
734 
TFE_OpSetAttrType(TFE_Op * op,const char * attr_name,TF_DataType value)735 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
736   auto s = tensorflow::unwrap(op)->SetAttrType(
737       attr_name, static_cast<tensorflow::DataType>(value));
738   if (!s.ok()) {
739     LOG(WARNING) << "Unable to set attribute: " << attr_name;
740   }
741 }
742 
TFE_OpSetAttrShape(TFE_Op * op,const char * attr_name,const int64_t * dims,const int num_dims,TF_Status * out_status)743 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
744                         const int num_dims, TF_Status* out_status) {
745   out_status->status =
746       tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
747 }
748 
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)749 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
750                            const TFE_Op* value) {
751   auto s = tensorflow::unwrap(op)->SetAttrFunction(
752       attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
753   if (!s.ok()) {
754     LOG(WARNING) << "Unable to set attribute: " << attr_name;
755   }
756 }
757 
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)758 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
759                                const char* data, size_t length) {
760   auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
761   if (!s.ok()) {
762     LOG(WARNING) << "Unable to set attribute: " << attr_name;
763   }
764 }
765 
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)766 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
767                          TF_Status* status) {
768   tensorflow::Tensor t;
769   status->status = TF_TensorToTensor(tensor, &t);
770   tensorflow::TensorInterface interface(t);
771   status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
772 }
773 
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)774 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
775                              const void* const* values, const size_t* lengths,
776                              int num_values) {
777   auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
778                                                      num_values);
779   if (!s.ok()) {
780     LOG(WARNING) << "Unable to set attribute: " << attr_name;
781   }
782 }
783 
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)784 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
785                             const float* values, int num_values) {
786   auto s =
787       tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
788   if (!s.ok()) {
789     LOG(WARNING) << "Unable to set attribute: " << attr_name;
790   }
791 }
792 
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)793 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
794                           const int64_t* values, int num_values) {
795   auto s =
796       tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
797   if (!s.ok()) {
798     LOG(WARNING) << "Unable to set attribute: " << attr_name;
799   }
800 }
801 
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)802 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
803                            const TF_DataType* values, int num_values) {
804   auto s = tensorflow::unwrap(op)->SetAttrTypeList(
805       attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
806       num_values);
807   if (!s.ok()) {
808     LOG(WARNING) << "Unable to set attribute: " << attr_name;
809   }
810 }
811 
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)812 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
813                            const unsigned char* values, int num_values) {
814   auto s =
815       tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
816   if (!s.ok()) {
817     LOG(WARNING) << "Unable to set attribute: " << attr_name;
818   }
819 }
820 
TFE_OpSetAttrShapeList(TFE_Op * op,const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values,TF_Status * out_status)821 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
822                             const int64_t** dims, const int* num_dims,
823                             int num_values, TF_Status* out_status) {
824   out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
825       attr_name, dims, num_dims, num_values);
826 }
827 
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)828 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
829                                const TFE_Op** value, int num_values) {
830   auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
831       attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
832                       tensorflow::unwrap(value)),
833                   static_cast<size_t>(num_values)});
834   if (!s.ok()) {
835     LOG(WARNING) << "Unable to set attribute: " << attr_name;
836   }
837 }
838 
TFE_OpSetAttrValueProto(const TFE_Op * op,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)839 void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
840                              const void* proto, size_t proto_len,
841                              TF_Status* status) {
842   tensorflow::AttrValue attr_value;
843   if (!attr_value.ParseFromArray(proto, proto_len)) {
844     status->status =
845         tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
846     return;
847   }
848   if (op == nullptr) {
849     status->status = tensorflow::errors::InvalidArgument(
850         "Got a null or uninitialized `op` argument");
851     return;
852   }
853   tensorflow::EagerOperation* operation =
854       OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
855   operation->MutableAttrs()->Set(attr_name, attr_value);
856 }
857 
TFE_OpGetInputLength(TFE_Op * op,const char * input_name,TF_Status * status)858 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
859                                                const char* input_name,
860                                                TF_Status* status) {
861   int ret = -1;
862   status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
863   return ret;
864 }
865 
TFE_OpGetOutputLength(TFE_Op * op,const char * output_name,TF_Status * status)866 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
867                                                 const char* output_name,
868                                                 TF_Status* status) {
869   int ret = -1;
870   status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
871   return ret;
872 }
873 
TFE_Execute(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)874 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
875                  TF_Status* status) {
876   tensorflow::ImmediateExecutionOperation* unwrapped_op =
877       tensorflow::unwrap(op);
878 
879   status->status =
880       unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
881           unwrapped_op,
882           reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
883               retvals),
884           num_retvals);
885 }
886 
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)887 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
888                                                TFE_Context* ctx,
889                                                const char* device_name,
890                                                TF_Status* status) {
891   if (h == nullptr) {
892     status->status = tensorflow::errors::InvalidArgument("Invalid handle");
893     return nullptr;
894   }
895 
896   tensorflow::ImmediateExecutionContext* unwrapped_ctx =
897       tensorflow::unwrap(ctx);
898 
899   auto* result =
900       unwrapped_ctx->GetCustomDeviceOpHandler().CopyTensorHandleToDevice(
901           unwrapped_ctx, tensorflow::unwrap(h), device_name, &status->status);
902 
903   if (status->status.ok()) {
904     return tensorflow::wrap(result);
905   }
906   return nullptr;
907 }
908 
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)909 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
910                                const char* serialized_function_def, size_t size,
911                                TF_Status* status) {
912   tensorflow::FunctionDef function_def;
913   if (!function_def.ParseFromArray(serialized_function_def, size)) {
914     status->status =
915         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
916     return;
917   }
918 
919   AnnotateEagerRuntimeConstructionContext(function_def);
920   status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
921 }
922 
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)923 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
924                             TF_Status* status) {
925   AnnotateEagerRuntimeConstructionContext(function->fdef);
926   status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
927       function->fdef, function->stack_traces);
928 }
929 
TFE_ContextRemoveFunction(TFE_Context * ctx,const char * name,TF_Status * status)930 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
931                                TF_Status* status) {
932   status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
933 }
934 
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)935 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
936   return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
937 }
938 
TFE_ContextEnableRunMetadata(TFE_Context * ctx)939 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
940   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
941 }
942 
TFE_ContextDisableRunMetadata(TFE_Context * ctx)943 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
944   tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
945 }
946 
947 }  // extern "C"
948 
TFE_NewTensorHandle(const tensorflow::Tensor & t,TF_Status * status)949 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
950                                       TF_Status* status) {
951   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
952 }
953 
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)954 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
955                                   TF_Status* status) {
956   auto* context = tensorflow::unwrap(ctx);
957   status->status = context->AsyncWait();
958   if (!status->status.ok()) return;
959   auto run_metadata = context->ExportRunMetadata();
960   status->status = MessageToBuffer(*run_metadata, buf);
961 }
962 
963 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)964 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
965                 TF_Status* status) {
966   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
967   for (const auto& attr : func.attr()) {
968     if (!status->status.ok()) return nullptr;
969     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
970     if (!status->status.ok()) return nullptr;
971   }
972   return func_op;
973 }
974 }  // namespace
975 
TFE_ContextStartStep(TFE_Context * ctx)976 void TFE_ContextStartStep(TFE_Context* ctx) {
977   tensorflow::unwrap(ctx)->StartStep();
978 }
979 
TFE_ContextEndStep(TFE_Context * ctx)980 void TFE_ContextEndStep(TFE_Context* ctx) {
981   tensorflow::unwrap(ctx)->EndStep();
982 }
983 
TFE_OpGetAttrs(const TFE_Op * op)984 const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
985   return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs());
986 }
987 
TFE_OpAddAttrs(TFE_Op * op,const TFE_OpAttrs * attrs)988 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
989   tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs));
990 }
991 
TFE_OpAttrsSerialize(const TFE_OpAttrs * attrs,TF_Buffer * buf,TF_Status * status)992 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
993                           TF_Status* status) {
994   tensorflow::NameAttrList name_and_attrs;
995   tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs);
996   status->status = MessageToBuffer(name_and_attrs, buf);
997 }
998 
999 namespace tensorflow {
SetOpAttrValueScalar(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,TF_Status * status)1000 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
1001                           const tensorflow::AttrValue& default_value,
1002                           const char* attr_name, TF_Status* status) {
1003   switch (default_value.value_case()) {
1004     case tensorflow::AttrValue::kS: {
1005       const string& v = default_value.s();
1006       TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
1007       break;
1008     }
1009     case tensorflow::AttrValue::kI:
1010       TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
1011       break;
1012     case tensorflow::AttrValue::kF:
1013       TFE_OpSetAttrFloat(op, attr_name, default_value.f());
1014       break;
1015     case tensorflow::AttrValue::kB:
1016       TFE_OpSetAttrBool(op, attr_name, default_value.b());
1017       break;
1018     case tensorflow::AttrValue::kType:
1019       TFE_OpSetAttrType(op, attr_name,
1020                         static_cast<TF_DataType>(default_value.type()));
1021       break;
1022     case tensorflow::AttrValue::kShape: {
1023       const auto& tensor_shape = default_value.shape();
1024       if (tensor_shape.unknown_rank()) {
1025         TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
1026       } else {
1027         const auto num_dims = tensor_shape.dim_size();
1028         std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
1029         for (int i = 0; i < num_dims; ++i) {
1030           dims[i] = tensor_shape.dim(i).size();
1031         }
1032         TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
1033       }
1034     } break;
1035     case tensorflow::AttrValue::kFunc: {
1036       const auto func_op = GetFunc(ctx, default_value.func(), status);
1037       if (!status->status.ok()) return;
1038       // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
1039       // require TFE_Op* and just convert it internally a NameAttrValue, so
1040       // consider adding an overload to the C API to make this case easier.
1041       TFE_OpSetAttrFunction(op, attr_name, func_op);
1042       TFE_DeleteOp(func_op);
1043     } break;
1044     case tensorflow::AttrValue::kList: {
1045       // String
1046       if (const int s_size = default_value.list().s_size()) {
1047         absl::InlinedVector<const void*, 4> values_vector;
1048         values_vector.reserve(s_size);
1049         absl::InlinedVector<size_t, 4> lengths_vector;
1050         lengths_vector.reserve(s_size);
1051         for (int i = 0; i < s_size; ++i) {
1052           const string& v = default_value.list().s(i);
1053           values_vector.push_back(v.data());
1054           lengths_vector.push_back(v.size());
1055         }
1056         TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
1057                                 lengths_vector.data(), s_size);
1058       }
1059 
1060       // Int
1061       if (const int i_size = default_value.list().i_size()) {
1062         absl::InlinedVector<int64_t, 4> i_vector;
1063         i_vector.reserve(i_size);
1064         for (int i = 0; i < i_size; ++i) {
1065           i_vector.push_back(default_value.list().i(i));
1066         }
1067         TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
1068       }
1069       // Float
1070       if (const int f_size = default_value.list().f_size()) {
1071         absl::InlinedVector<float, 4> f_vector;
1072         f_vector.reserve(f_size);
1073         for (int i = 0; i < f_size; ++i) {
1074           f_vector.push_back(default_value.list().f(i));
1075         }
1076         TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
1077       }
1078       // Bool
1079       if (const int b_size = default_value.list().b_size()) {
1080         absl::InlinedVector<unsigned char, 4> b_vector;
1081         b_vector.reserve(b_size);
1082         for (int i = 0; i < b_size; i++) {
1083           b_vector.push_back(default_value.list().b(i));
1084         }
1085         TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
1086       }
1087       // Type
1088       if (const int type_size = default_value.list().type_size()) {
1089         absl::InlinedVector<unsigned int, 4> type_vector;
1090         type_vector.reserve(type_size);
1091         for (int i = 0; i < type_size; ++i) {
1092           type_vector.push_back(default_value.list().type(i));
1093         }
1094         TFE_OpSetAttrTypeList(
1095             op, attr_name,
1096             reinterpret_cast<const TF_DataType*>(type_vector.data()),
1097             type_size);
1098       }
1099 
1100       // Rest are not supported.
1101       if (default_value.list().shape_size() > 0 ||
1102           default_value.list().func_size() > 0 ||
1103           default_value.list().tensor_size() > 0) {
1104         TF_SetStatus(
1105             status, TF_UNIMPLEMENTED,
1106             tensorflow::strings::StrCat("Unable to get setfor default value: ",
1107                                         default_value.DebugString())
1108                 .data());
1109       }
1110     } break;
1111     case tensorflow::AttrValue::kTensor:
1112       TF_FALLTHROUGH_INTENDED;
1113     case tensorflow::AttrValue::kPlaceholder:
1114       TF_FALLTHROUGH_INTENDED;
1115     case tensorflow::AttrValue::VALUE_NOT_SET:
1116       TF_SetStatus(
1117           status, TF_UNIMPLEMENTED,
1118           tensorflow::strings::StrCat("Unable to get setfor default value: ",
1119                                       default_value.DebugString())
1120               .data());
1121   }
1122 }
1123 }  // namespace tensorflow
1124 
1125 namespace {
DefaultCustomDevicePack(TFE_Context * context,TFE_TensorHandle ** handles,int num_handles,TF_Status * status,void * device_info)1126 TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
1127                                           TFE_TensorHandle** handles,
1128                                           int num_handles, TF_Status* status,
1129                                           void* device_info) {
1130   TF_SetStatus(status, TF_UNIMPLEMENTED,
1131                "This custom device does not support packing tensors.");
1132   return nullptr;
1133 }
1134 }  // namespace
1135 
1136 extern "C" {
1137 
TFE_RegisterCustomDevice(TFE_Context * ctx,TFE_CustomDevice device,const char * device_name,void * device_info,TF_Status * status)1138 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
1139                               const char* device_name, void* device_info,
1140                               TF_Status* status) {
1141   // Fill in default values for optional functionality.
1142   if (device.pack == nullptr) {
1143     device.pack = &DefaultCustomDevicePack;
1144   }
1145   auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
1146       ctx, device, device_info, device_name);
1147   status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
1148       device_name, std::move(custom_device));
1149 }
1150 
1151 }  // extern "C"
1152