1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/eager/abstract_tensor_handle.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/immediate_execution_operation.h"
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_op_internal.h"
35 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/core/common_runtime/copy_tensor.h"
38 #include "tensorflow/core/common_runtime/device.h"
39 #include "tensorflow/core/common_runtime/device_factory.h"
40 #include "tensorflow/core/common_runtime/device_mgr.h"
41 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
42 #include "tensorflow/core/common_runtime/eager/context.h"
43 #include "tensorflow/core/common_runtime/eager/custom_device.h"
44 #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
45 #include "tensorflow/core/common_runtime/eager/execute.h"
46 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
47 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
48 #include "tensorflow/core/common_runtime/function.h"
49 #include "tensorflow/core/framework/attr_value.pb.h"
50 #include "tensorflow/core/framework/device_attributes.pb.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/node_def_util.h"
53 #include "tensorflow/core/framework/rendezvous.h"
54 #include "tensorflow/core/framework/tensor_shape.pb.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/platform/casts.h"
57 #include "tensorflow/core/platform/errors.h"
58 #include "tensorflow/core/platform/platform.h"
59 #include "tensorflow/core/platform/status.h"
60 #include "tensorflow/core/profiler/lib/traceme.h"
61 #include "tensorflow/core/protobuf/error_codes.pb.h"
62 #include "tensorflow/core/public/version.h"
63
64 // "tensorflow/core/platform/platform.h" must be included first before using
65 // PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
66 #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
67 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
68 #include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
69 #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
70
71 #if !defined(IS_MOBILE_PLATFORM)
72 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
73 #endif // !IS_MOBILE_PLATFORM
74
75 using tensorflow::string;
76
77 namespace {
78
DeviceName(const tensorflow::Device * d)79 string DeviceName(const tensorflow::Device* d) {
80 return (d == nullptr) ? "cpu:0" : d->name();
81 }
82
83 // Annotate eager runtime construction context to the given `function_def` as
84 // an attribute.
AnnotateEagerRuntimeConstructionContext(tensorflow::FunctionDef & function_def)85 void AnnotateEagerRuntimeConstructionContext(
86 tensorflow::FunctionDef& function_def) {
87 tensorflow::AttrValue value;
88 SetAttrValue("kEagerRuntime", &value);
89 (*function_def.mutable_attr())["_construction_context"] = value;
90 }
91
92 } // namespace
93
94 extern "C" {
95
TFE_NewContextOptions()96 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
97
// Applies a serialized ConfigProto (`proto`, `proto_len` bytes) to the
// session options used when a context is created; errors via `status`.
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}
102
// Enables (nonzero) or disables (zero) asynchronous op execution for
// contexts created from these options.
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  options->async = enable;
}
107
// Sets the default device-placement policy for contexts created from these
// options (see TFE_ContextDevicePlacementPolicy for the choices).
void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->device_placement_policy = policy;
}
112
TFE_DeleteContextOptions(TFE_ContextOptions * options)113 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
114
// Creates a new eager context from `opts`. If `opts->use_tfrt` is set, the
// TFRT-based runtime is used (only compiled in for PLATFORM_GOOGLE,
// non-LIBTPU builds); otherwise a classic EagerContext backed by locally
// registered devices is built. Caller owns the returned context.
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
        opts->session_options.options,
        static_cast<tensorflow::ContextDevicePlacementPolicy>(
            opts->device_placement_policy),
        opts->async, opts->use_tfrt_distributed_runtime);
#if !defined(IS_MOBILE_PLATFORM)
    tfrt_context->SetDistributedManager(
        tfrt::tf::CreateDistributedManagerContext(
            tfrt_context->GetCoreRuntime()->GetHostContext()));
#endif  // !IS_MOBILE_PLATFORM
    return tensorflow::wrap(tfrt_context);
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
  }
  // Instantiate all locally registered devices under the default local job.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DynamicDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  // device_mgr ownership transfers to the EagerContext (device_mgr_owned).
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      opts->async, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      /*cluster_flr=*/nullptr,
      /*collective_executor_mgr=*/nullptr,
      /*run_eager_op_as_function=*/opts->run_eager_op_as_function);
#if !defined(IS_MOBILE_PLATFORM)
  eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          eager_context));
#endif  // !IS_MOBILE_PLATFORM
  return tensorflow::wrap(eager_context);
}
160
TFE_DeleteContext(TFE_Context * ctx)161 void TFE_DeleteContext(TFE_Context* ctx) {
162 if (ctx == nullptr) {
163 return;
164 }
165
166 // ctx->RefCountIsOne() should be true here.
167 tensorflow::unwrap(ctx)->Release();
168 }
169
TFE_ContextListDevices(TFE_Context * ctx,TF_Status * status)170 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
171 TF_DeviceList* l = new TF_DeviceList;
172 tensorflow::unwrap(ctx)->ListDevices(&l->response);
173 return l;
174 }
175
// Clears the context's caches and thread executors (presumably to reclaim
// memory / reset kernel state — confirm against ImmediateExecutionContext).
void TFE_ContextClearCaches(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
}
179
180 // Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextSetServerDef not supported on mobile");
#else  // !defined(IS_MOBILE_PLATFORM)
  // Deserialize the caller-supplied ServerDef proto.
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  // reset_context=true: rebuild context state for the new cluster definition
  // (contrast with TFE_ContextUpdateServerDef, which updates in place).
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/true, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}
201
TFE_ContextUpdateServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)202 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
203 int keep_alive_secs,
204 const void* proto,
205 size_t proto_len,
206 TF_Status* status) {
207 #if defined(IS_MOBILE_PLATFORM)
208 status->status = tensorflow::errors::Unimplemented(
209 "TFE_ContextSetServerDef not supported on mobile");
210 #else // !defined(IS_MOBILE_PLATFORM)
211 tensorflow::ServerDef server_def;
212 tensorflow::EagerContext* context =
213 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
214 if (!server_def.ParseFromArray(proto, proto_len)) {
215 status->status = tensorflow::errors::InvalidArgument(
216 "Invalid tensorflow.ServerDef protocol buffer");
217 return;
218 } else if (context->GetContextId() ==
219 tensorflow::EagerContext::kInvalidContextId) {
220 status->status = tensorflow::errors::InvalidArgument(
221 "Trying to update a context with invalid context id.");
222 }
223 status->status =
224 tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
225 server_def, /*reset_context=*/false, keep_alive_secs);
226 #endif // !IS_MOBILE_PLATFORM
227 }
228
TFE_ContextCheckAlive(TFE_Context * ctx,const char * worker_name,TF_Status * status)229 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
230 const char* worker_name,
231 TF_Status* status) {
232 #if defined(IS_MOBILE_PLATFORM)
233 status->status = tensorflow::errors::Unimplemented(
234 "TFE_ContextSetServerDef not supported on mobile");
235 return false;
236 #else // !defined(IS_MOBILE_PLATFORM)
237 bool is_alive;
238 status->status =
239 tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
240 worker_name, &is_alive);
241 return is_alive;
242 #endif // !IS_MOBILE_PLATFORM
243 }
244
// Blocks until pending async work in the context has completed. Async
// execution is unavailable on mobile, so that path reports OK immediately.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
                                                TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::Status::OK();
#else   // !defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif  // !IS_MOBILE_PLATFORM
}
253
// Overrides the device placement policy for the calling thread only.
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
259
260 // Note: this function looks up a thread local policy. So it should be called in
261 // the appropriate client thread. In particular, in async mode, it may not be
262 // safe to call this function from the async EagerExecutor threads.
// Returns the placement policy in effect for the calling thread (thread-local
// override if set, else the context default).
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}
268
TFE_NewTensorHandle(const TF_Tensor * t,TF_Status * status)269 TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
270 tensorflow::Tensor tensor;
271 status->status = tensorflow::TF_TensorToTensor(t, &tensor);
272 if (!status->status.ok()) return nullptr;
273
274 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
275 }
276
TFE_DeleteTensorHandle(TFE_TensorHandle * h)277 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
278 if (h == nullptr) return;
279
280 tensorflow::profiler::TraceMe activity(
281 "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
282 if (h) {
283 tensorflow::unwrap(h)->Release();
284 }
285 }
286
// Returns the element type of the tensor referenced by `h`.
// NOTE(review): unlike the other accessors below, `h` is not null-checked.
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
}
290
TFE_TensorHandleNumDims(TFE_TensorHandle * h,TF_Status * status)291 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
292 if (h == nullptr) {
293 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
294 return -1;
295 }
296
297 int num_dims = -1;
298 status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
299 return num_dims;
300 }
301
TFE_TensorHandleNumElements(TFE_TensorHandle * h,TF_Status * status)302 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
303 if (h == nullptr) {
304 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
305 return -1;
306 }
307
308 int64_t num_elements = -1;
309 status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
310 return num_elements;
311 }
312
TFE_TensorHandleDim(TFE_TensorHandle * h,int dim_index,TF_Status * status)313 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
314 TF_Status* status) {
315 if (h == nullptr) {
316 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
317 return -1;
318 }
319
320 int64_t dim = -1;
321 status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
322 return dim;
323 }
324
TFE_TensorHandleDeviceName(TFE_TensorHandle * h,TF_Status * status)325 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
326 if (h == nullptr) {
327 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
328 return nullptr;
329 }
330 return tensorflow::unwrap(h)->DeviceName(&status->status);
331 }
332
TFE_TensorHandleBackingDeviceName(TFE_TensorHandle * h,TF_Status * status)333 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
334 TF_Status* status) {
335 if (h == nullptr) {
336 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
337 return nullptr;
338 }
339 return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
340 }
341
TFE_TensorHandleCopySharingTensor(TFE_TensorHandle * h,TF_Status * status)342 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
343 TFE_TensorHandle* h, TF_Status* status) {
344 if (h == nullptr) {
345 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
346 return nullptr;
347 }
348
349 return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
350 }
351
TFE_TensorHandleResolve(TFE_TensorHandle * h,TF_Status * status)352 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
353 if (h == nullptr) {
354 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
355 return nullptr;
356 }
357
358 tensorflow::AbstractTensorInterface* t =
359 tensorflow::unwrap(h)->Resolve(&status->status);
360 if (t == nullptr) {
361 return nullptr;
362 }
363
364 return new TF_Tensor{t};
365 }
366
// Returns the raw device memory address backing `h`, or nullptr with `status`
// set on error. Works for custom-device handles (direct buffer) and LOCAL
// tensor handles (after syncing the device); other handle kinds are rejected.
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
      tensorflow::unwrap(h);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
    // Custom devices expose their buffer directly.
    return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
               unwrapped_handle)
        ->DevicePointer();
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(unwrapped_handle);

  // Non-LOCAL (e.g. remote) handles have no locally addressable memory.
  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return nullptr;
  }
  tensorflow::Device* device(handle->device());
  if (device != nullptr) {
    // Wait for pending device work so the pointer refers to finished data.
    status->status = device->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(
      static_cast<const void*>(tensor->tensor_data().data()));
}
409
410 namespace tensorflow {
411 namespace {
// Adapts the C callback table of a TFE_CustomDevice to the C++ CustomDevice
// interface. Each method forwards to the user-supplied callback, translating
// between wrapped (C) and unwrapped (C++) handle representations and managing
// reference counts across the C boundary.
class CustomDeviceAPI : public tensorflow::CustomDevice {
 public:
  CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
                  string name)
      : context_(context), device_(device), info_(info), name_(name) {}

  // Lets the user callback release whatever `info_` points to.
  ~CustomDeviceAPI() override { device_.delete_device(info_); }

  const string& name() override { return name_; }

  tensorflow::Status CopyTensorToDevice(
      ImmediateExecutionTensorHandle* handle,
      ImmediateExecutionTensorHandle** result) override {
    // Pin `handle` across the callback, which receives a wrapped (non-owning)
    // view of it.
    handle->Ref();
    TF_Status status;
    TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
        context_, tensorflow::wrap(handle), &status, info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    // Take our own reference on the result before dropping the callback's
    // reference via TFE_DeleteTensorHandle.
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status CopyTensorFromDevice(
      ImmediateExecutionTensorHandle* handle,
      const tensorflow::string& target_device_name,
      ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    // Same Ref/Release choreography as CopyTensorToDevice.
    handle->Ref();
    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
        context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
        info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status Execute(const ImmediateExecutionOperation* op,
                             ImmediateExecutionTensorHandle** retvals,
                             int* num_retvals) override {
    // `num_retvals` is in/out: the callback may report fewer outputs.
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    TF_Status status;
    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                    info_);
    if (status.status.ok()) {
      // Transfer each output: take our reference, drop the callback's.
      for (int i = 0; i < *num_retvals; ++i) {
        retvals[i] = tensorflow::unwrap(outputs[i]);
        retvals[i]->Ref();
        TFE_DeleteTensorHandle(outputs[i]);
      }
    }
    return status.status;
  }

  tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
                          ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    *result = tensorflow::unwrap(device_.pack(context_,
                                              tensorflow::wrap(handles.data()),
                                              handles.size(), &status, info_));
    return status.status;
  }

 private:
  TFE_Context* context_;       // Not owned.
  TFE_CustomDevice device_;    // User-supplied callback table.
  void* info_;                 // Opaque user state, freed by delete_device.
  string name_;
};
486
487 // An adapter which wraps the shape/data produced by C custom devices and uses
488 // it to implement custom device methods.
class CAPICustomDeviceTensorHandle
    : public tensorflow::CustomDeviceTensorHandle {
 public:
  // Takes ownership of `data`; it is released via `methods.deallocator` when
  // this handle is destroyed.
  CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
                               tensorflow::CustomDevice* device,
                               tensorflow::DataType dtype, void* data,
                               TFE_CustomDeviceTensorHandleMethods methods)
      : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
        data_(data),
        methods_(methods) {}

  ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); }
  void* DevicePointer() const override { return data_; }
  // Shape queries delegate to the user-supplied C callbacks.
  Status NumDims(int* num_dims) const override {
    TF_Status s;
    *num_dims = methods_.num_dims(data_, &s);
    return s.status;
  }
  Status Dim(int dim_index, int64* dim) const override {
    TF_Status s;
    *dim = methods_.dim(data_, dim_index, &s);
    return s.status;
  }

  // A custom summarizer is optional in the callback table.
  bool PreferCustomSummarizer() const override {
    return methods_.summarize != nullptr;
  }

  Status SummarizeValue(std::string& summary) const override {
    if (methods_.summarize == nullptr) {
      // No callback provided: use the base-class summary.
      return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary);
    }
    TF_Status c_status;
    // unique_ptr with TF_DeleteBuffer frees the callback-allocated buffer on
    // every exit path.
    std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> summary_buffer(
        methods_.summarize(data_, &c_status), TF_DeleteBuffer);
    if (!c_status.status.ok()) {
      return c_status.status;
    }
    summary = std::string(reinterpret_cast<const char*>(summary_buffer->data),
                          summary_buffer->length);
    return Status::OK();
  }

 private:
  void* const data_;  // Owned; freed via methods_.deallocator.
  const TFE_CustomDeviceTensorHandleMethods methods_;
};
536
537 } // namespace
538 } // namespace tensorflow
539
TFE_NewCustomDeviceTensorHandle(TFE_Context * ctx,const char * device_name,TF_DataType dtype,void * data,TFE_CustomDeviceTensorHandleMethods methods,TF_Status * status)540 TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
541 TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
542 TFE_CustomDeviceTensorHandleMethods methods, TF_Status* status) {
543 tensorflow::ImmediateExecutionContext* context = tensorflow::unwrap(ctx);
544 tensorflow::CustomDevice* device = nullptr;
545 if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
546 &device)) {
547 methods.deallocator(data);
548 status->status =
549 tensorflow::errors::InvalidArgument(device_name, " unknown device.");
550 return nullptr;
551 }
552 return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
553 context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
554 methods));
555 }
556
// Creates a LOCAL tensor handle that aliases caller-provided device memory
// (`data`, `len` bytes, shape `dims[0..num_dims)`) on `device_name`.
// `deallocator` is invoked when the buffer is released — including on the
// unknown-device error path here.
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype,
    const int64_t* dims, int num_dims, void* data, size_t len,
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status) {
  tensorflow::Device* device = nullptr;
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    // We cannot produce a handle, so release the caller's buffer now.
    deallocator(data, len, deallocator_arg);
    status->status =
        tensorflow::errors::InvalidArgument(device_name, " unknown device.");
    return nullptr;
  }
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
  // the device?
  TF_ManagedBuffer* buf =
      new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
                           /*owns_memory=*/false);

  tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
                       tensorflow::TensorShape(dimvec), buf);
  // The Tensor now holds its own reference to the buffer; drop ours.
  buf->Unref();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
      std::move(t), device, device, context));
}
589
590 // This function will block till the operation that produces `h` has
591 // completed. This is only valid on local TFE_TensorHandles. Returns the size in
592 // bytes of the memory pointed to by the device pointer returned above.
TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle * h,TF_Status * status)593 size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
594 TF_Status* status) {
595 if (h == nullptr) {
596 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
597 return 0;
598 }
599 tensorflow::TensorHandle* handle =
600 tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
601 if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
602 status->status = tensorflow::errors::InvalidArgument(
603 "TFE_TensorHandleDeviceMemorySize may not be called on a ",
604 handle->TypeString(), " tensor handle.");
605 return 0;
606 }
607 const tensorflow::Tensor* tensor;
608 status->status = handle->Tensor(&tensor);
609 if (!status->status.ok()) {
610 return 0;
611 }
612 return tensor->TotalBytes();
613 }
614
TFE_NewOp(TFE_Context * ctx,const char * op_or_function_name,TF_Status * status)615 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
616 TF_Status* status) {
617 tensorflow::ImmediateExecutionOperation* new_op =
618 tensorflow::unwrap(ctx)->CreateOperation();
619 status->status = new_op->Reset(op_or_function_name, nullptr);
620 if (!status->status.ok()) {
621 new_op->Release();
622 new_op = nullptr;
623 }
624 return tensorflow::wrap(new_op);
625 }
626
TFE_DeleteOp(TFE_Op * op)627 void TFE_DeleteOp(TFE_Op* op) {
628 if (op == nullptr) {
629 return;
630 }
631
632 tensorflow::unwrap(op)->Release();
633 }
634
// Returns the op/function name. `status` is unused.
// NOTE(review): assumes Name() returns a stable reference owned by the op —
// the returned pointer is only valid while the op is alive; confirm.
const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->Name().c_str();
}
638
// Returns the (non-owned) context the op was created in. `status` is unused.
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
}
642
// Requests that the op run on `device_name`; errors reported via `status`.
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}
646
// Returns the op's device name. `status` is unused; the pointer is owned by
// the op and valid only while the op is alive.
const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->DeviceName().c_str();
}
650
// Appends `input` to the op's input list; errors reported via `status`.
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
654
// Appends `num_inputs` handles as a single input list.
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                        TF_Status* status) {
  // Reinterpret the unwrapped handle array as the base AbstractTensorHandle**
  // and pass it as a span of `num_inputs` elements.
  status->status = tensorflow::unwrap(op)->AddInputList(
      {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
           tensorflow::unwrap(inputs)),
       static_cast<size_t>(num_inputs)});
}
662
// Returns the total number of inputs added so far (lists flattened).
// `status` is unused.
extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->GetInputs().size();
}
666
// Returns the input at flattened position `index`. `status` is unused and
// `index` is not bounds-checked.
extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
                                            TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
}
671
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)672 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
673 unsigned char* is_list, TF_Status* status) {
674 TF_AttrType ret = TF_ATTR_INT;
675 const tensorflow::AttrTypeMap* attr_types_;
676 bool is_function;
677 status->status = tensorflow::AttrTypeMapForOp(
678 tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
679 if (!status->status.ok()) {
680 return ret;
681 }
682 status->status =
683 tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
684 return ret;
685 }
686
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)687 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
688 const char* op_or_function_name,
689 const char* attr_name, unsigned char* is_list,
690 TF_Status* status) {
691 TF_AttrType ret;
692 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
693 if (status->status.ok()) {
694 ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
695 } else {
696 ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
697 }
698 TFE_DeleteOp(op);
699 return ret;
700 }
701
TFE_OpSetAttrString(TFE_Op * op,const char * attr_name,const void * value,size_t length)702 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
703 size_t length) {
704 auto s = tensorflow::unwrap(op)->SetAttrString(
705 attr_name, static_cast<const char*>(value), length);
706 if (!s.ok()) {
707 LOG(WARNING) << "Unable to set attribute: " << attr_name;
708 }
709 }
710
TFE_OpSetAttrInt(TFE_Op * op,const char * attr_name,int64_t value)711 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
712 auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
713 if (!s.ok()) {
714 LOG(WARNING) << "Unable to set attribute: " << attr_name;
715 }
716 }
717
TFE_OpSetAttrFloat(TFE_Op * op,const char * attr_name,float value)718 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
719 auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
720 if (!s.ok()) {
721 LOG(WARNING) << "Unable to set attribute: " << attr_name;
722 }
723 }
724
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)725 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
726 auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
727 (value == 0) ? false : true);
728 if (!s.ok()) {
729 LOG(WARNING) << "Unable to set attribute: " << attr_name;
730 }
731 }
732
TFE_OpSetAttrType(TFE_Op * op,const char * attr_name,TF_DataType value)733 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
734 auto s = tensorflow::unwrap(op)->SetAttrType(
735 attr_name, static_cast<tensorflow::DataType>(value));
736 if (!s.ok()) {
737 LOG(WARNING) << "Unable to set attribute: " << attr_name;
738 }
739 }
740
// Sets a shape attribute from `dims[0..num_dims)`; errors via `out_status`.
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  out_status->status =
      tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
}
746
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)747 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
748 const TFE_Op* value) {
749 auto s = tensorflow::unwrap(op)->SetAttrFunction(
750 attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
751 if (!s.ok()) {
752 LOG(WARNING) << "Unable to set attribute: " << attr_name;
753 }
754 }
755
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)756 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
757 const char* data, size_t length) {
758 auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
759 if (!s.ok()) {
760 LOG(WARNING) << "Unable to set attribute: " << attr_name;
761 }
762 }
763
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)764 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
765 TF_Status* status) {
766 tensorflow::Tensor t;
767 status->status = TF_TensorToTensor(tensor, &t);
768 tensorflow::TensorInterface interface(t);
769 status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
770 }
771
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)772 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
773 const void* const* values, const size_t* lengths,
774 int num_values) {
775 auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
776 num_values);
777 if (!s.ok()) {
778 LOG(WARNING) << "Unable to set attribute: " << attr_name;
779 }
780 }
781
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)782 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
783 const float* values, int num_values) {
784 auto s =
785 tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
786 if (!s.ok()) {
787 LOG(WARNING) << "Unable to set attribute: " << attr_name;
788 }
789 }
790
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)791 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
792 const int64_t* values, int num_values) {
793 auto s =
794 tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
795 if (!s.ok()) {
796 LOG(WARNING) << "Unable to set attribute: " << attr_name;
797 }
798 }
799
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)800 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
801 const TF_DataType* values, int num_values) {
802 auto s = tensorflow::unwrap(op)->SetAttrTypeList(
803 attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
804 num_values);
805 if (!s.ok()) {
806 LOG(WARNING) << "Unable to set attribute: " << attr_name;
807 }
808 }
809
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)810 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
811 const unsigned char* values, int num_values) {
812 auto s =
813 tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
814 if (!s.ok()) {
815 LOG(WARNING) << "Unable to set attribute: " << attr_name;
816 }
817 }
818
TFE_OpSetAttrShapeList(TFE_Op * op,const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values,TF_Status * out_status)819 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
820 const int64_t** dims, const int* num_dims,
821 int num_values, TF_Status* out_status) {
822 out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
823 attr_name, dims, num_dims, num_values);
824 }
825
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)826 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
827 const TFE_Op** value, int num_values) {
828 auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
829 attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
830 tensorflow::unwrap(value)),
831 static_cast<size_t>(num_values)});
832 if (!s.ok()) {
833 LOG(WARNING) << "Unable to set attribute: " << attr_name;
834 }
835 }
836
TFE_OpSetAttrValueProto(const TFE_Op * op,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)837 void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
838 const void* proto, size_t proto_len,
839 TF_Status* status) {
840 tensorflow::AttrValue attr_value;
841 if (!attr_value.ParseFromArray(proto, proto_len)) {
842 status->status =
843 tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
844 return;
845 }
846 if (op == nullptr) {
847 status->status = tensorflow::errors::InvalidArgument(
848 "Got a null or uninitialized `op` argument");
849 return;
850 }
851 tensorflow::EagerOperation* operation =
852 OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
853 operation->MutableAttrs()->Set(attr_name, attr_value);
854 }
855
TFE_OpGetInputLength(TFE_Op * op,const char * input_name,TF_Status * status)856 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
857 const char* input_name,
858 TF_Status* status) {
859 int ret = -1;
860 status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
861 return ret;
862 }
863
TFE_OpGetOutputLength(TFE_Op * op,const char * output_name,TF_Status * status)864 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
865 const char* output_name,
866 TF_Status* status) {
867 int ret = -1;
868 status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
869 return ret;
870 }
871
TFE_Execute(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)872 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
873 TF_Status* status) {
874 tensorflow::ImmediateExecutionOperation* unwrapped_op =
875 tensorflow::unwrap(op);
876
877 status->status =
878 unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
879 unwrapped_op,
880 reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
881 retvals),
882 num_retvals);
883 }
884
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)885 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
886 TFE_Context* ctx,
887 const char* device_name,
888 TF_Status* status) {
889 if (h == nullptr) {
890 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
891 return nullptr;
892 }
893
894 tensorflow::ImmediateExecutionContext* unwrapped_ctx =
895 tensorflow::unwrap(ctx);
896
897 auto* result =
898 unwrapped_ctx->GetCustomDeviceOpHandler().CopyTensorHandleToDevice(
899 unwrapped_ctx, tensorflow::unwrap(h), device_name, &status->status);
900
901 if (status->status.ok()) {
902 return tensorflow::wrap(result);
903 }
904 return nullptr;
905 }
906
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)907 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
908 const char* serialized_function_def, size_t size,
909 TF_Status* status) {
910 tensorflow::FunctionDef function_def;
911 if (!function_def.ParseFromArray(serialized_function_def, size)) {
912 status->status =
913 tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
914 return;
915 }
916
917 AnnotateEagerRuntimeConstructionContext(function_def);
918 status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
919 }
920
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)921 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
922 TF_Status* status) {
923 AnnotateEagerRuntimeConstructionContext(function->fdef);
924 status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
925 function->fdef, function->stack_traces);
926 }
927
TFE_ContextRemoveFunction(TFE_Context * ctx,const char * name,TF_Status * status)928 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
929 TF_Status* status) {
930 status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
931 }
932
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)933 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
934 return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
935 }
936
TFE_ContextEnableRunMetadata(TFE_Context * ctx)937 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
938 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
939 }
940
TFE_ContextDisableRunMetadata(TFE_Context * ctx)941 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
942 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
943 }
944
945 } // extern "C"
946
TFE_NewTensorHandle(const tensorflow::Tensor & t,TF_Status * status)947 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
948 TF_Status* status) {
949 return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
950 }
951
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)952 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
953 TF_Status* status) {
954 auto* context = tensorflow::unwrap(ctx);
955 status->status = context->AsyncWait();
956 if (!status->status.ok()) return;
957 auto run_metadata = context->ExportRunMetadata();
958 status->status = MessageToBuffer(*run_metadata, buf);
959 }
960
961 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)962 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
963 TF_Status* status) {
964 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
965 for (const auto& attr : func.attr()) {
966 if (!status->status.ok()) return nullptr;
967 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
968 if (!status->status.ok()) return nullptr;
969 }
970 return func_op;
971 }
972 } // namespace
973
TFE_ContextStartStep(TFE_Context * ctx)974 void TFE_ContextStartStep(TFE_Context* ctx) {
975 tensorflow::unwrap(ctx)->StartStep();
976 }
977
TFE_ContextEndStep(TFE_Context * ctx)978 void TFE_ContextEndStep(TFE_Context* ctx) {
979 tensorflow::unwrap(ctx)->EndStep();
980 }
981
TFE_OpGetAttrs(const TFE_Op * op)982 const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
983 return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs());
984 }
985
TFE_OpAddAttrs(TFE_Op * op,const TFE_OpAttrs * attrs)986 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
987 tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs));
988 }
989
TFE_OpAttrsSerialize(const TFE_OpAttrs * attrs,TF_Buffer * buf,TF_Status * status)990 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
991 TF_Status* status) {
992 tensorflow::NameAttrList name_and_attrs;
993 tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs);
994 status->status = MessageToBuffer(name_and_attrs, buf);
995 }
996
997 namespace tensorflow {
// Copies a single AttrValue (typically a default value from an OpDef) onto
// `op` via the corresponding TFE_OpSetAttr* C API call, dispatching on the
// proto's value_case(). Shape-setting errors and unsupported cases are
// reported through `status`; scalar setter failures are only logged by the
// callees. Tensor, placeholder, unset, and list-of-{shape,func,tensor}
// values are not supported and set TF_UNIMPLEMENTED.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
                          const tensorflow::AttrValue& default_value,
                          const char* attr_name, TF_Status* status) {
  switch (default_value.value_case()) {
    case tensorflow::AttrValue::kS: {
      const string& v = default_value.s();
      TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
      break;
    }
    case tensorflow::AttrValue::kI:
      TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
      break;
    case tensorflow::AttrValue::kF:
      TFE_OpSetAttrFloat(op, attr_name, default_value.f());
      break;
    case tensorflow::AttrValue::kB:
      TFE_OpSetAttrBool(op, attr_name, default_value.b());
      break;
    case tensorflow::AttrValue::kType:
      TFE_OpSetAttrType(op, attr_name,
                        static_cast<TF_DataType>(default_value.type()));
      break;
    case tensorflow::AttrValue::kShape: {
      const auto& tensor_shape = default_value.shape();
      if (tensor_shape.unknown_rank()) {
        // (nullptr, -1) is the C API encoding for an unknown-rank shape.
        TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
      } else {
        // Copy the proto dims into a flat int64 array for the C API.
        const auto num_dims = tensor_shape.dim_size();
        std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
        for (int i = 0; i < num_dims; ++i) {
          dims[i] = tensor_shape.dim(i).size();
        }
        TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
      }
    } break;
    case tensorflow::AttrValue::kFunc: {
      // Build a temporary op representing the function (including its
      // attributes) just to pass it to the setter, then delete it.
      const auto func_op = GetFunc(ctx, default_value.func(), status);
      if (!status->status.ok()) return;
      // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
      // require TFE_Op* and just convert it internally a NameAttrValue, so
      // consider adding an overload to the C API to make this case easier.
      TFE_OpSetAttrFunction(op, attr_name, func_op);
      TFE_DeleteOp(func_op);
    } break;
    case tensorflow::AttrValue::kList: {
      // Each list field that is populated is forwarded with its matching
      // TFE_OpSetAttr*List call; empty fields are skipped.
      // String
      if (const int s_size = default_value.list().s_size()) {
        absl::InlinedVector<const void*, 4> values_vector;
        absl::InlinedVector<size_t, 4> lengths_vector;
        for (int i = 0; i < s_size; ++i) {
          // values_vector borrows the proto's string storage; safe because
          // `default_value` outlives the setter call below.
          const string& v = default_value.list().s(i);
          values_vector.push_back(v.data());
          lengths_vector.push_back(v.size());
        }
        TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
                                lengths_vector.data(), s_size);
      }

      // Int
      if (const int i_size = default_value.list().i_size()) {
        absl::InlinedVector<int64_t, 4> i_vector;
        for (int i = 0; i < i_size; ++i) {
          i_vector.push_back(default_value.list().i(i));
        }
        TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
      }
      // Float
      if (const int f_size = default_value.list().f_size()) {
        absl::InlinedVector<float, 4> f_vector;
        for (int i = 0; i < f_size; ++i) {
          f_vector.push_back(default_value.list().f(i));
        }
        TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
      }
      // Bool
      if (const int b_size = default_value.list().b_size()) {
        absl::InlinedVector<unsigned char, 4> b_vector;
        for (int i = 0; i < b_size; i++) {
          b_vector.push_back(default_value.list().b(i));
        }
        TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
      }
      // Type
      if (const int type_size = default_value.list().type_size()) {
        absl::InlinedVector<unsigned int, 4> type_vector;
        for (int i = 0; i < type_size; ++i) {
          type_vector.push_back(default_value.list().type(i));
        }
        TFE_OpSetAttrTypeList(
            op, attr_name,
            reinterpret_cast<const TF_DataType*>(type_vector.data()),
            type_size);
      }

      // Rest are not supported.
      // NOTE(review): "get setfor" below looks like a typo (appears twice in
      // this function) — confirm the intended wording before changing, since
      // callers/tests may match this message.
      if (default_value.list().shape_size() > 0 ||
          default_value.list().func_size() > 0 ||
          default_value.list().tensor_size() > 0) {
        TF_SetStatus(
            status, TF_UNIMPLEMENTED,
            tensorflow::strings::StrCat("Unable to get setfor default value: ",
                                        default_value.DebugString())
                .data());
      }
    } break;
    case tensorflow::AttrValue::kTensor:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kPlaceholder:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::VALUE_NOT_SET:
      TF_SetStatus(
          status, TF_UNIMPLEMENTED,
          tensorflow::strings::StrCat("Unable to get setfor default value: ",
                                      default_value.DebugString())
              .data());
  }
}
1115 } // namespace tensorflow
1116
1117 namespace {
DefaultCustomDevicePack(TFE_Context * context,TFE_TensorHandle ** handles,int num_handles,TF_Status * status,void * device_info)1118 TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
1119 TFE_TensorHandle** handles,
1120 int num_handles, TF_Status* status,
1121 void* device_info) {
1122 TF_SetStatus(status, TF_UNIMPLEMENTED,
1123 "This custom device does not support packing tensors.");
1124 return nullptr;
1125 }
1126 } // namespace
1127
1128 extern "C" {
1129
TFE_RegisterCustomDevice(TFE_Context * ctx,TFE_CustomDevice device,const char * device_name,void * device_info,TF_Status * status)1130 void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
1131 const char* device_name, void* device_info,
1132 TF_Status* status) {
1133 // Fill in default values for optional functionality.
1134 if (device.pack == nullptr) {
1135 device.pack = &DefaultCustomDevicePack;
1136 }
1137 auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
1138 ctx, device, device_info, device_name);
1139 status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
1140 device_name, std::move(custom_device));
1141 }
1142
1143 } // extern "C"
1144