1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/eager/abstract_tensor_handle.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/immediate_execution_operation.h"
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_op_internal.h"
35 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
36 #include "tensorflow/c/tf_buffer_internal.h"
37 #include "tensorflow/c/tf_tensor_internal.h"
38 #include "tensorflow/core/common_runtime/copy_tensor.h"
39 #include "tensorflow/core/common_runtime/device.h"
40 #include "tensorflow/core/common_runtime/device_factory.h"
41 #include "tensorflow/core/common_runtime/device_mgr.h"
42 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
43 #include "tensorflow/core/common_runtime/eager/context.h"
44 #include "tensorflow/core/common_runtime/eager/custom_device.h"
45 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
46 #include "tensorflow/core/common_runtime/eager/execute.h"
47 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
48 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
49 #include "tensorflow/core/common_runtime/function.h"
50 #include "tensorflow/core/framework/attr_value.pb.h"
51 #include "tensorflow/core/framework/device_attributes.pb.h"
52 #include "tensorflow/core/framework/function.h"
53 #include "tensorflow/core/framework/node_def_util.h"
54 #include "tensorflow/core/framework/rendezvous.h"
55 #include "tensorflow/core/framework/tensor_shape.pb.h"
56 #include "tensorflow/core/framework/types.h"
57 #include "tensorflow/core/platform/casts.h"
58 #include "tensorflow/core/platform/errors.h"
59 #include "tensorflow/core/platform/platform.h"
60 #include "tensorflow/core/platform/status.h"
61 #include "tensorflow/core/profiler/lib/traceme.h"
62 #include "tensorflow/core/protobuf/error_codes.pb.h"
63 #include "tensorflow/core/public/version.h"
64
65 // "tensorflow/core/platform/platform.h" must be included first before using
66 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
67 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
68 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
69 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
70 #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
71
72 #if !defined(IS_MOBILE_PLATFORM)
73 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
74 #endif // !IS_MOBILE_PLATFORM
75
76 using tensorflow::string;
77
78 namespace {
79
DeviceName(const tensorflow::Device * d)80 string DeviceName(const tensorflow::Device* d) {
81 return (d == nullptr) ? "cpu:0" : d->name();
82 }
83
84 // Annotate eager runtime construction context to the given `function_def` as
85 // an attribute.
// Annotate eager runtime construction context to the given `function_def` as
// an attribute.
void AnnotateEagerRuntimeConstructionContext(
    tensorflow::FunctionDef& function_def) {
  tensorflow::AttrValue value;
  SetAttrValue("kEagerRuntime", &value);
  // Stamp the function so downstream consumers can tell it was constructed by
  // the eager runtime rather than another path.
  (*function_def.mutable_attr())["_construction_context"] = value;
}
92
93 } // namespace
94
95 extern "C" {
96
// Allocates a default-initialized TFE_ContextOptions; the caller must release
// it with TFE_DeleteContextOptions.
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
98
// Parses a serialized ConfigProto (`proto`, `proto_len` bytes) into the
// embedded session options; reports parse failures through `status`.
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}
103
// Enables (non-zero) or disables (zero) asynchronous op execution for
// contexts created from these options.
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  options->async = enable;
}
108
// Sets the default device placement policy for contexts created from these
// options.
void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->device_placement_policy = policy;
}
113
// Releases options allocated by TFE_NewContextOptions.
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
115
// Creates a new eager context from `opts`. When `opts->use_tfrt` is set the
// TFRT-based context is built (only available in Google-internal, non-LIBTPU
// builds); otherwise a classic EagerContext is constructed over the local
// device set. Returns nullptr and sets `status` on failure.
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
        opts->session_options.options,
        static_cast<tensorflow::ContextDevicePlacementPolicy>(
            opts->device_placement_policy),
        opts->async, opts->use_tfrt_distributed_runtime);
#if !defined(IS_MOBILE_PLATFORM)
    tfrt_context->SetDistributedManager(
        tfrt::tf::CreateDistributedManagerContext(
            tfrt_context->GetCoreRuntime()->GetHostContext()));
#endif  // !IS_MOBILE_PLATFORM
    return tensorflow::wrap(tfrt_context);
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
  }
  // Enumerate all local devices for the single-host job.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DynamicDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  // The EagerContext takes ownership of the device manager (released below)
  // and the rendezvous.
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      opts->async, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      /*cluster_flr=*/nullptr,
      /*collective_executor_mgr=*/nullptr,
      /*run_eager_op_as_function=*/opts->run_eager_op_as_function,
      /*jit_compile_rewrite=*/opts->jit_compile_rewrite);
#if !defined(IS_MOBILE_PLATFORM)
  eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          eager_context));
#endif  // !IS_MOBILE_PLATFORM
  return tensorflow::wrap(eager_context);
}
162
TFE_DeleteContext(TFE_Context * ctx)163 void TFE_DeleteContext(TFE_Context* ctx) {
164 if (ctx == nullptr) {
165 return;
166 }
167
168 // ctx->RefCountIsOne() should be true here.
169 tensorflow::unwrap(ctx)->Release();
170 }
171
// Returns the list of devices known to `ctx`; caller owns the returned list.
// Note: `status` is currently never set by this function.
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* l = new TF_DeviceList;
  tensorflow::unwrap(ctx)->ListDevices(&l->response);
  return l;
}
177
// Clears the context's kernel/step caches and thread executors.
void TFE_ContextClearCaches(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
}
181
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextSetServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  // `proto` is a serialized tensorflow.ServerDef of `proto_len` bytes.
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  // reset_context=true: tears down and rebuilds distributed state.
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/true, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}
203
TFE_ContextUpdateServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)204 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
205 int keep_alive_secs,
206 const void* proto,
207 size_t proto_len,
208 TF_Status* status) {
209 #if defined(IS_MOBILE_PLATFORM)
210 status->status = tensorflow::errors::Unimplemented(
211 "TFE_ContextSetServerDef not supported on mobile");
212 #else // !defined(IS_MOBILE_PLATFORM)
213 tensorflow::ServerDef server_def;
214 tensorflow::EagerContext* context =
215 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
216 if (!server_def.ParseFromArray(proto, proto_len)) {
217 status->status = tensorflow::errors::InvalidArgument(
218 "Invalid tensorflow.ServerDef protocol buffer");
219 return;
220 } else if (context->GetContextId() ==
221 tensorflow::EagerContext::kInvalidContextId) {
222 status->status = tensorflow::errors::InvalidArgument(
223 "Trying to update a context with invalid context id.");
224 }
225 status->status =
226 tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
227 server_def, /*reset_context=*/false, keep_alive_secs);
228 #endif // !IS_MOBILE_PLATFORM
229 }
230
TFE_ContextCheckAlive(TFE_Context * ctx,const char * worker_name,TF_Status * status)231 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
232 const char* worker_name,
233 TF_Status* status) {
234 #if defined(IS_MOBILE_PLATFORM)
235 status->status = tensorflow::errors::Unimplemented(
236 "TFE_ContextSetServerDef not supported on mobile");
237 return false;
238 #else // !defined(IS_MOBILE_PLATFORM)
239 bool is_alive;
240 status->status =
241 tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
242 worker_name, &is_alive);
243 return is_alive;
244 #endif // !IS_MOBILE_PLATFORM
245 }
246
// Blocks until all pending async operations in `ctx` have completed.
// On mobile builds there is no async executor to drain, so this is a no-op.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
                                                TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::Status::OK();
#else   // !defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif  // !IS_MOBILE_PLATFORM
}
255
// Overrides the device placement policy for the calling thread only.
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
261
// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}
270
TFE_NewTensorHandle(const TF_Tensor * t,TF_Status * status)271 TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
272 tensorflow::Tensor tensor;
273 status->status = tensorflow::TF_TensorToTensor(t, &tensor);
274 if (!status->status.ok()) return nullptr;
275
276 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
277 }
278
TFE_DeleteTensorHandle(TFE_TensorHandle * h)279 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
280 if (h == nullptr) return;
281
282 tensorflow::profiler::TraceMe activity(
283 "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
284 if (h) {
285 tensorflow::unwrap(h)->Release();
286 }
287 }
288
// Returns the element dtype of the handle. `h` must be non-null.
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
}
292
TFE_TensorHandleNumDims(TFE_TensorHandle * h,TF_Status * status)293 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
294 if (h == nullptr) {
295 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
296 return -1;
297 }
298
299 int num_dims = -1;
300 status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
301 return num_dims;
302 }
303
TFE_TensorHandleNumElements(TFE_TensorHandle * h,TF_Status * status)304 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
305 if (h == nullptr) {
306 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
307 return -1;
308 }
309
310 int64_t num_elements = -1;
311 status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
312 return num_elements;
313 }
314
TFE_TensorHandleDim(TFE_TensorHandle * h,int dim_index,TF_Status * status)315 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
316 TF_Status* status) {
317 if (h == nullptr) {
318 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
319 return -1;
320 }
321
322 int64_t dim = -1;
323 status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
324 return dim;
325 }
326
// Returns the name of the device the handle is nominally placed on, or
// nullptr with `status` set on error.
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  return tensorflow::unwrap(h)->DeviceName(&status->status);
}
334
// Returns the name of the device actually holding the tensor's memory, or
// nullptr with `status` set on error.
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                              TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
}
343
// Returns a new handle that shares the same underlying tensor as `h`
// (via the handle's Copy()), or nullptr with `status` set on error.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
}
353
// Materializes the value of `h` as a TF_Tensor, blocking until the producing
// computation (and any needed copy) completes. Returns nullptr with `status`
// set on error; caller owns the returned tensor.
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  tensorflow::AbstractTensorInterface* t =
      tensorflow::unwrap(h)->Resolve(&status->status);
  if (t == nullptr) {
    return nullptr;
  }

  // The TF_Tensor takes ownership of the resolved tensor interface.
  return new TF_Tensor{t};
}
368
// Returns a raw pointer to the tensor's device memory. For custom-device
// handles this is the device's opaque pointer; otherwise only LOCAL
// TensorHandles are supported. Returns nullptr with `status` set on error.
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
      tensorflow::unwrap(h);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
    return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
               unwrapped_handle)
        ->DevicePointer();
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(unwrapped_handle);

  // Remote/packed handles have no locally addressable memory.
  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return nullptr;
  }
  // Sync the device first so the returned memory reflects completed writes.
  tensorflow::Device* device(handle->device());
  if (device != nullptr) {
    status->status = device->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(
      static_cast<const void*>(tensor->tensor_data().data()));
}
411
412 namespace tensorflow {
413 namespace {
// Adapts a C callback table (TFE_CustomDevice) to the C++ CustomDevice
// interface. Takes ownership of `info`, releasing it through
// device_.delete_device in the destructor.
class CustomDeviceAPI : public tensorflow::CustomDevice {
 public:
  CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
                  string name)
      : context_(context), device_(device), info_(info), name_(name) {}

  ~CustomDeviceAPI() override { device_.delete_device(info_); }

  const string& name() override { return name_; }

  tensorflow::Status CopyTensorToDevice(
      ImmediateExecutionTensorHandle* handle,
      ImmediateExecutionTensorHandle** result) override {
    // Pin `handle` across the C callback so it cannot be destroyed mid-call.
    handle->Ref();
    TF_Status status;
    TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
        context_, tensorflow::wrap(handle), &status, info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    // Take our own reference to the produced handle, then drop the C
    // wrapper's reference via TFE_DeleteTensorHandle.
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status CopyTensorFromDevice(
      ImmediateExecutionTensorHandle* handle,
      const tensorflow::string& target_device_name,
      ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    // Same Ref/Release protocol as CopyTensorToDevice.
    handle->Ref();
    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
        context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
        info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status Execute(const ImmediateExecutionOperation* op,
                             ImmediateExecutionTensorHandle** retvals,
                             int* num_retvals) override {
    // The callback writes up to *num_retvals outputs and updates the count.
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    TF_Status status;
    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                    info_);
    if (status.status.ok()) {
      for (int i = 0; i < *num_retvals; ++i) {
        // Adopt each output: take a reference, then release the C wrapper.
        retvals[i] = tensorflow::unwrap(outputs[i]);
        retvals[i]->Ref();
        TFE_DeleteTensorHandle(outputs[i]);
      }
    }
    return status.status;
  }

  tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
                          ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    *result = tensorflow::unwrap(device_.pack(context_,
                                              tensorflow::wrap(handles.data()),
                                              handles.size(), &status, info_));
    return status.status;
  }

 private:
  TFE_Context* context_;
  TFE_CustomDevice device_;  // C callback table supplied at registration.
  void* info_;               // Opaque per-device state owned by this object.
  string name_;
};
488
// An adapter which wraps the shape/data produced by C custom devices and uses
// it to implement custom device methods.
class CAPICustomDeviceTensorHandle
    : public tensorflow::CustomDeviceTensorHandle {
 public:
  // Takes ownership of `data`; it is released via methods.deallocator in the
  // destructor.
  CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
                               tensorflow::CustomDevice* device,
                               tensorflow::DataType dtype, void* data,
                               TFE_CustomDeviceTensorHandleMethods methods)
      : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
        data_(data),
        methods_(methods) {}

  ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); }
  void* DevicePointer() const override { return data_; }
  // Rank, via the C callback; errors are propagated from the TF_Status.
  Status NumDims(int* num_dims) const override {
    TF_Status s;
    *num_dims = methods_.num_dims(data_, &s);
    return s.status;
  }
  // Size of one dimension, via the C callback.
  Status Dim(int dim_index, int64_t* dim) const override {
    TF_Status s;
    *dim = methods_.dim(data_, dim_index, &s);
    return s.status;
  }

  // The `summarize` callback is optional; prefer it only when provided.
  bool PreferCustomSummarizer() const override {
    return methods_.summarize != nullptr;
  }

  Status SummarizeValue(std::string& summary) const override {
    if (methods_.summarize == nullptr) {
      // Fall back to the base class's generic summary.
      return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary);
    }
    TF_Status c_status;
    std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> summary_buffer(
        methods_.summarize(data_, &c_status), TF_DeleteBuffer);
    if (!c_status.status.ok()) {
      return c_status.status;
    }
    summary = std::string(reinterpret_cast<const char*>(summary_buffer->data),
                          summary_buffer->length);
    return OkStatus();
  }

 private:
  void* const data_;  // Opaque device data, owned; freed by methods_.deallocator.
  const TFE_CustomDeviceTensorHandleMethods methods_;
};
538
539 } // namespace
540 } // namespace tensorflow
541
TFE_NewCustomDeviceTensorHandle(TFE_Context * ctx,const char * device_name,TF_DataType dtype,void * data,TFE_CustomDeviceTensorHandleMethods methods,TF_Status * status)542 TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
543 TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
544 TFE_CustomDeviceTensorHandleMethods methods, TF_Status* status) {
545 tensorflow::ImmediateExecutionContext* context = tensorflow::unwrap(ctx);
546 tensorflow::CustomDevice* device = nullptr;
547 if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
548 &device)) {
549 methods.deallocator(data);
550 status->status =
551 tensorflow::errors::InvalidArgument(device_name, " unknown device.");
552 return nullptr;
553 }
554 return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
555 context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
556 methods));
557 }
558
TFE_NewTensorHandleFromDeviceMemory(TFE_Context * ctx,const char * device_name,TF_DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg,TF_Status * status)559 TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
560 TFE_Context* ctx, const char* device_name, TF_DataType dtype,
561 const int64_t* dims, int num_dims, void* data, size_t len,
562 void (*deallocator)(void* data, size_t len, void* arg),
563 void* deallocator_arg, TF_Status* status) {
564 tensorflow::Device* device = nullptr;
565 tensorflow::EagerContext* context =
566 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
567 status->status = context->FindDeviceFromName(device_name, &device);
568 if (!status->status.ok()) {
569 deallocator(data, len, deallocator_arg);
570 status->status =
571 tensorflow::errors::InvalidArgument(device_name, " unknown device.");
572 return nullptr;
573 }
574 std::vector<int64_t> dimvec(num_dims);
575 for (int i = 0; i < num_dims; ++i) {
576 dimvec[i] = static_cast<int64_t>(dims[i]);
577 }
578
579 // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
580 // the device?
581 TF_ManagedBuffer* buf =
582 new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
583 /*owns_memory=*/false);
584
585 tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
586 tensorflow::TensorShape(dimvec), buf);
587 buf->Unref();
588 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
589 std::move(t), device, device, context));
590 }
591
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
                                        TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return 0;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
  // Only LOCAL handles have locally measurable backing memory.
  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDeviceMemorySize may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return 0;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return 0;
  }
  return tensor->TotalBytes();
}
616
// Creates a new operation (primitive op or registered function) named
// `op_or_function_name` in `ctx`. Returns nullptr with `status` set if the
// name cannot be resolved; otherwise caller releases with TFE_DeleteOp.
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  tensorflow::ImmediateExecutionOperation* new_op =
      tensorflow::unwrap(ctx)->CreateOperation();
  status->status = new_op->Reset(op_or_function_name, nullptr);
  if (!status->status.ok()) {
    new_op->Release();
    new_op = nullptr;
  }
  return tensorflow::wrap(new_op);
}
628
TFE_DeleteOp(TFE_Op * op)629 void TFE_DeleteOp(TFE_Op* op) {
630 if (op == nullptr) {
631 return;
632 }
633
634 tensorflow::unwrap(op)->Release();
635 }
636
// Returns the op's name; the pointer is owned by the op. `status` is unused.
const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->Name().c_str();
}
640
// Returns the context the op was created in. `status` is unused.
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
}
644
// Requests that the op run on `device_name`; errors reported via `status`.
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}
648
// Returns the op's currently requested device name; pointer owned by the op.
// `status` is unused.
const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->DeviceName().c_str();
}
652
// Appends a single tensor input to the op.
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
656
// Appends `num_inputs` tensor inputs to the op as one input list.
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                        TF_Status* status) {
  // The unwrapped handle array is reinterpreted as the base-class pointer
  // array expected by AddInputList.
  status->status = tensorflow::unwrap(op)->AddInputList(
      {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
           tensorflow::unwrap(inputs)),
       static_cast<size_t>(num_inputs)});
}
664
// Returns the number of inputs added to the op so far. `status` is unused.
extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->GetInputs().size();
}
668
// Returns the op's input at `index`. No bounds checking is performed;
// `status` is unused.
extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
                                            TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
}
673
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)674 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
675 unsigned char* is_list, TF_Status* status) {
676 TF_AttrType ret = TF_ATTR_INT;
677 const tensorflow::AttrTypeMap* attr_types_;
678 bool is_function;
679 status->status = tensorflow::AttrTypeMapForOp(
680 tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
681 if (!status->status.ok()) {
682 return ret;
683 }
684 status->status =
685 tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
686 return ret;
687 }
688
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)689 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
690 const char* op_or_function_name,
691 const char* attr_name, unsigned char* is_list,
692 TF_Status* status) {
693 TF_AttrType ret;
694 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
695 if (status->status.ok()) {
696 ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
697 } else {
698 ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
699 }
700 TFE_DeleteOp(op);
701 return ret;
702 }
703
// Sets a string attribute from a raw byte buffer. Failures are logged, not
// reported, matching the void C API signature.
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                         size_t length) {
  auto s = tensorflow::unwrap(op)->SetAttrString(
      attr_name, static_cast<const char*>(value), length);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
712
// Sets an int64 attribute. Failures are logged, not reported.
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
719
// Sets a float attribute. Failures are logged, not reported.
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
726
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)727 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
728 auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
729 (value == 0) ? false : true);
730 if (!s.ok()) {
731 LOG(WARNING) << "Unable to set attribute: " << attr_name;
732 }
733 }
734
// Sets a dtype attribute. Failures are logged, not reported.
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  auto s = tensorflow::unwrap(op)->SetAttrType(
      attr_name, static_cast<tensorflow::DataType>(value));
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}
742
// Sets a shape attribute from `num_dims` dimension sizes; unlike the other
// attr setters this one reports failure through `out_status`.
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  out_status->status =
      tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
}
748
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)749 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
750 const TFE_Op* value) {
751 auto s = tensorflow::unwrap(op)->SetAttrFunction(
752 attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
753 if (!s.ok()) {
754 LOG(WARNING) << "Unable to set attribute: " << attr_name;
755 }
756 }
757
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)758 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
759 const char* data, size_t length) {
760 auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
761 if (!s.ok()) {
762 LOG(WARNING) << "Unable to set attribute: " << attr_name;
763 }
764 }
765
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)766 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
767 TF_Status* status) {
768 tensorflow::Tensor t;
769 status->status = TF_TensorToTensor(tensor, &t);
770 tensorflow::TensorInterface interface(t);
771 status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
772 }
773
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)774 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
775 const void* const* values, const size_t* lengths,
776 int num_values) {
777 auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
778 num_values);
779 if (!s.ok()) {
780 LOG(WARNING) << "Unable to set attribute: " << attr_name;
781 }
782 }
783
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)784 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
785 const float* values, int num_values) {
786 auto s =
787 tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
788 if (!s.ok()) {
789 LOG(WARNING) << "Unable to set attribute: " << attr_name;
790 }
791 }
792
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)793 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
794 const int64_t* values, int num_values) {
795 auto s =
796 tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
797 if (!s.ok()) {
798 LOG(WARNING) << "Unable to set attribute: " << attr_name;
799 }
800 }
801
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)802 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
803 const TF_DataType* values, int num_values) {
804 auto s = tensorflow::unwrap(op)->SetAttrTypeList(
805 attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
806 num_values);
807 if (!s.ok()) {
808 LOG(WARNING) << "Unable to set attribute: " << attr_name;
809 }
810 }
811
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)812 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
813 const unsigned char* values, int num_values) {
814 auto s =
815 tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
816 if (!s.ok()) {
817 LOG(WARNING) << "Unable to set attribute: " << attr_name;
818 }
819 }
820
TFE_OpSetAttrShapeList(TFE_Op * op,const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values,TF_Status * out_status)821 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
822 const int64_t** dims, const int* num_dims,
823 int num_values, TF_Status* out_status) {
824 out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
825 attr_name, dims, num_dims, num_values);
826 }
827
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)828 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
829 const TFE_Op** value, int num_values) {
830 auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
831 attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
832 tensorflow::unwrap(value)),
833 static_cast<size_t>(num_values)});
834 if (!s.ok()) {
835 LOG(WARNING) << "Unable to set attribute: " << attr_name;
836 }
837 }
838
TFE_OpSetAttrValueProto(const TFE_Op * op,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)839 void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
840 const void* proto, size_t proto_len,
841 TF_Status* status) {
842 tensorflow::AttrValue attr_value;
843 if (!attr_value.ParseFromArray(proto, proto_len)) {
844 status->status =
845 tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
846 return;
847 }
848 if (op == nullptr) {
849 status->status = tensorflow::errors::InvalidArgument(
850 "Got a null or uninitialized `op` argument");
851 return;
852 }
853 tensorflow::EagerOperation* operation =
854 OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
855 operation->MutableAttrs()->Set(attr_name, attr_value);
856 }
857
TFE_OpGetInputLength(TFE_Op * op,const char * input_name,TF_Status * status)858 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
859 const char* input_name,
860 TF_Status* status) {
861 int ret = -1;
862 status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
863 return ret;
864 }
865
TFE_OpGetOutputLength(TFE_Op * op,const char * output_name,TF_Status * status)866 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
867 const char* output_name,
868 TF_Status* status) {
869 int ret = -1;
870 status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
871 return ret;
872 }
873
TFE_Execute(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)874 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
875 TF_Status* status) {
876 tensorflow::ImmediateExecutionOperation* unwrapped_op =
877 tensorflow::unwrap(op);
878
879 status->status =
880 unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
881 unwrapped_op,
882 reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
883 retvals),
884 num_retvals);
885 }
886
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)887 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
888 TFE_Context* ctx,
889 const char* device_name,
890 TF_Status* status) {
891 if (h == nullptr) {
892 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
893 return nullptr;
894 }
895
896 tensorflow::ImmediateExecutionContext* unwrapped_ctx =
897 tensorflow::unwrap(ctx);
898
899 auto* result =
900 unwrapped_ctx->GetCustomDeviceOpHandler().CopyTensorHandleToDevice(
901 unwrapped_ctx, tensorflow::unwrap(h), device_name, &status->status);
902
903 if (status->status.ok()) {
904 return tensorflow::wrap(result);
905 }
906 return nullptr;
907 }
908
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)909 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
910 const char* serialized_function_def, size_t size,
911 TF_Status* status) {
912 tensorflow::FunctionDef function_def;
913 if (!function_def.ParseFromArray(serialized_function_def, size)) {
914 status->status =
915 tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
916 return;
917 }
918
919 AnnotateEagerRuntimeConstructionContext(function_def);
920 status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
921 }
922
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)923 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
924 TF_Status* status) {
925 AnnotateEagerRuntimeConstructionContext(function->fdef);
926 status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
927 function->fdef, function->stack_traces);
928 }
929
TFE_ContextRemoveFunction(TFE_Context * ctx,const char * name,TF_Status * status)930 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
931 TF_Status* status) {
932 status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
933 }
934
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)935 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
936 return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
937 }
938
TFE_ContextEnableRunMetadata(TFE_Context * ctx)939 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
940 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
941 }
942
TFE_ContextDisableRunMetadata(TFE_Context * ctx)943 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
944 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
945 }
946
947 } // extern "C"
948
TFE_NewTensorHandle(const tensorflow::Tensor & t,TF_Status * status)949 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
950 TF_Status* status) {
951 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
952 }
953
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)954 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
955 TF_Status* status) {
956 auto* context = tensorflow::unwrap(ctx);
957 status->status = context->AsyncWait();
958 if (!status->status.ok()) return;
959 auto run_metadata = context->ExportRunMetadata();
960 status->status = MessageToBuffer(*run_metadata, buf);
961 }
962
963 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)964 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
965 TF_Status* status) {
966 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
967 for (const auto& attr : func.attr()) {
968 if (!status->status.ok()) return nullptr;
969 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
970 if (!status->status.ok()) return nullptr;
971 }
972 return func_op;
973 }
974 } // namespace
975
TFE_ContextStartStep(TFE_Context * ctx)976 void TFE_ContextStartStep(TFE_Context* ctx) {
977 tensorflow::unwrap(ctx)->StartStep();
978 }
979
TFE_ContextEndStep(TFE_Context * ctx)980 void TFE_ContextEndStep(TFE_Context* ctx) {
981 tensorflow::unwrap(ctx)->EndStep();
982 }
983
TFE_OpGetAttrs(const TFE_Op * op)984 const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
985 return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs());
986 }
987
TFE_OpAddAttrs(TFE_Op * op,const TFE_OpAttrs * attrs)988 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
989 tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs));
990 }
991
TFE_OpAttrsSerialize(const TFE_OpAttrs * attrs,TF_Buffer * buf,TF_Status * status)992 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
993 TF_Status* status) {
994 tensorflow::NameAttrList name_and_attrs;
995 tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs);
996 status->status = MessageToBuffer(name_and_attrs, buf);
997 }
998
999 namespace tensorflow {
SetOpAttrValueScalar(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,TF_Status * status)1000 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
1001 const tensorflow::AttrValue& default_value,
1002 const char* attr_name, TF_Status* status) {
1003 switch (default_value.value_case()) {
1004 case tensorflow::AttrValue::kS: {
1005 const string& v = default_value.s();
1006 TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
1007 break;
1008 }
1009 case tensorflow::AttrValue::kI:
1010 TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
1011 break;
1012 case tensorflow::AttrValue::kF:
1013 TFE_OpSetAttrFloat(op, attr_name, default_value.f());
1014 break;
1015 case tensorflow::AttrValue::kB:
1016 TFE_OpSetAttrBool(op, attr_name, default_value.b());
1017 break;
1018 case tensorflow::AttrValue::kType:
1019 TFE_OpSetAttrType(op, attr_name,
1020 static_cast<TF_DataType>(default_value.type()));
1021 break;
1022 case tensorflow::AttrValue::kShape: {
1023 const auto& tensor_shape = default_value.shape();
1024 if (tensor_shape.unknown_rank()) {
1025 TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
1026 } else {
1027 const auto num_dims = tensor_shape.dim_size();
1028 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
1029 for (int i = 0; i < num_dims; ++i) {
1030 dims[i] = tensor_shape.dim(i).size();
1031 }
1032 TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
1033 }
1034 } break;
1035 case tensorflow::AttrValue::kFunc: {
1036 const auto func_op = GetFunc(ctx, default_value.func(), status);
1037 if (!status->status.ok()) return;
1038 // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
1039 // require TFE_Op* and just convert it internally a NameAttrValue, so
1040 // consider adding an overload to the C API to make this case easier.
1041 TFE_OpSetAttrFunction(op, attr_name, func_op);
1042 TFE_DeleteOp(func_op);
1043 } break;
1044 case tensorflow::AttrValue::kList: {
1045 // String
1046 if (const int s_size = default_value.list().s_size()) {
1047 absl::InlinedVector<const void*, 4> values_vector;
1048 values_vector.reserve(s_size);
1049 absl::InlinedVector<size_t, 4> lengths_vector;
1050 lengths_vector.reserve(s_size);
1051 for (int i = 0; i < s_size; ++i) {
1052 const string& v = default_value.list().s(i);
1053 values_vector.push_back(v.data());
1054 lengths_vector.push_back(v.size());
1055 }
1056 TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
1057 lengths_vector.data(), s_size);
1058 }
1059
1060 // Int
1061 if (const int i_size = default_value.list().i_size()) {
1062 absl::InlinedVector<int64_t, 4> i_vector;
1063 i_vector.reserve(i_size);
1064 for (int i = 0; i < i_size; ++i) {
1065 i_vector.push_back(default_value.list().i(i));
1066 }
1067 TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
1068 }
1069 // Float
1070 if (const int f_size = default_value.list().f_size()) {
1071 absl::InlinedVector<float, 4> f_vector;
1072 f_vector.reserve(f_size);
1073 for (int i = 0; i < f_size; ++i) {
1074 f_vector.push_back(default_value.list().f(i));
1075 }
1076 TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
1077 }
1078 // Bool
1079 if (const int b_size = default_value.list().b_size()) {
1080 absl::InlinedVector<unsigned char, 4> b_vector;
1081 b_vector.reserve(b_size);
1082 for (int i = 0; i < b_size; i++) {
1083 b_vector.push_back(default_value.list().b(i));
1084 }
1085 TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
1086 }
1087 // Type
1088 if (const int type_size = default_value.list().type_size()) {
1089 absl::InlinedVector<unsigned int, 4> type_vector;
1090 type_vector.reserve(type_size);
1091 for (int i = 0; i < type_size; ++i) {
1092 type_vector.push_back(default_value.list().type(i));
1093 }
1094 TFE_OpSetAttrTypeList(
1095 op, attr_name,
1096 reinterpret_cast<const TF_DataType*>(type_vector.data()),
1097 type_size);
1098 }
1099
1100 // Rest are not supported.
1101 if (default_value.list().shape_size() > 0 ||
1102 default_value.list().func_size() > 0 ||
1103 default_value.list().tensor_size() > 0) {
1104 TF_SetStatus(
1105 status, TF_UNIMPLEMENTED,
1106 tensorflow::strings::StrCat("Unable to get setfor default value: ",
1107 default_value.DebugString())
1108 .data());
1109 }
1110 } break;
1111 case tensorflow::AttrValue::kTensor:
1112 TF_FALLTHROUGH_INTENDED;
1113 case tensorflow::AttrValue::kPlaceholder:
1114 TF_FALLTHROUGH_INTENDED;
1115 case tensorflow::AttrValue::VALUE_NOT_SET:
1116 TF_SetStatus(
1117 status, TF_UNIMPLEMENTED,
1118 tensorflow::strings::StrCat("Unable to get setfor default value: ",
1119 default_value.DebugString())
1120 .data());
1121 }
1122 }
1123 } // namespace tensorflow
1124
1125 namespace {
DefaultCustomDevicePack(TFE_Context * context,TFE_TensorHandle ** handles,int num_handles,TF_Status * status,void * device_info)1126 TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
1127 TFE_TensorHandle** handles,
1128 int num_handles, TF_Status* status,
1129 void* device_info) {
1130 TF_SetStatus(status, TF_UNIMPLEMENTED,
1131 "This custom device does not support packing tensors.");
1132 return nullptr;
1133 }
1134 } // namespace
1135
1136 extern "C" {
1137
TFE_RegisterCustomDevice(TFE_Context * ctx,TFE_CustomDevice device,const char * device_name,void * device_info,TF_Status * status)1138 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
1139 const char* device_name, void* device_info,
1140 TF_Status* status) {
1141 // Fill in default values for optional functionality.
1142 if (device.pack == nullptr) {
1143 device.pack = &DefaultCustomDevicePack;
1144 }
1145 auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
1146 ctx, device, device_info, device_name);
1147 status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
1148 device_name, std::move(custom_device));
1149 }
1150
1151 } // extern "C"
1152