1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "tensorflow/c/c_api.h"
26 #include "tensorflow/c/c_api_internal.h"
27 #include "tensorflow/c/eager/c_api_internal.h"
28 #include "tensorflow/core/platform/host_info.h"
29 #ifdef TENSORFLOW_EAGER_USE_XLA
30 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
31 #endif // TENSORFLOW_EAGER_USE_XLA
32 #include "tensorflow/core/common_runtime/copy_tensor.h"
33 #include "tensorflow/core/common_runtime/device_factory.h"
34 #include "tensorflow/core/common_runtime/device_mgr.h"
35 #include "tensorflow/core/common_runtime/device_set.h"
36 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
37 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
38 #include "tensorflow/core/common_runtime/eager/execute.h"
39 #include "tensorflow/core/common_runtime/function.h"
40 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
41 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
42 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
43 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
44 #include "tensorflow/core/distributed_runtime/server_lib.h"
45 #include "tensorflow/core/distributed_runtime/worker_env.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/rendezvous.h"
48 #include "tensorflow/core/framework/tensor_shape.pb.h"
49 #include "tensorflow/core/framework/types.h"
50 #include "tensorflow/core/lib/core/refcount.h"
51 #include "tensorflow/core/lib/core/stringpiece.h"
52 #include "tensorflow/core/lib/gtl/cleanup.h"
53 #include "tensorflow/core/lib/gtl/flatmap.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55 #include "tensorflow/core/lib/gtl/stl_util.h"
56 #include "tensorflow/core/lib/random/random.h"
57 #include "tensorflow/core/platform/env.h"
58 #include "tensorflow/core/platform/mutex.h"
59 #include "tensorflow/core/platform/thread_annotations.h"
60 #include "tensorflow/core/public/version.h"
61
62 using tensorflow::int64;
63 using tensorflow::string;
64
65 namespace {
IsCPU(const tensorflow::Device * d)66 bool IsCPU(const tensorflow::Device* d) {
67 return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
68 }
69
IsXLA(const tensorflow::Device * d)70 bool IsXLA(const tensorflow::Device* d) {
71 if (d == nullptr) return false;
72 const auto& device_type = d->attributes().device_type();
73 return device_type.find("XLA") != std::string::npos;
74 }
75
DeviceName(const tensorflow::Device * d)76 string DeviceName(const tensorflow::Device* d) {
77 return (d == nullptr) ? "cpu:0" : d->name();
78 }
79
GetAllRemoteDevices(const std::vector<string> & remote_workers,tensorflow::WorkerCacheInterface * worker_cache,std::unique_ptr<tensorflow::DeviceMgr> * device_mgr)80 tensorflow::Status GetAllRemoteDevices(
81 const std::vector<string>& remote_workers,
82 tensorflow::WorkerCacheInterface* worker_cache,
83 std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
84 std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
85 tensorflow::Status status;
86 // TODO(nareshmodi) do this in parallel instead of serially.
87 for (const string& remote_worker : remote_workers) {
88 tensorflow::Notification n;
89 tensorflow::NewRemoteDevices(
90 tensorflow::Env::Default(), worker_cache, remote_worker,
91 [&status, &n, &remote_devices](
92 const tensorflow::Status& s,
93 std::vector<tensorflow::Device*>* devices) {
94 status = s;
95 if (s.ok()) {
96 for (tensorflow::Device* d : *devices) {
97 remote_devices.emplace_back(d);
98 }
99 }
100 n.Notify();
101 });
102 n.WaitForNotification();
103 }
104 std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
105 new tensorflow::DeviceMgr(std::move(remote_devices)));
106
107 TF_RETURN_IF_ERROR(status);
108
109 *device_mgr = std::move(remote_device_mgr);
110 return tensorflow::Status::OK();
111 }
112
// Creates an eager context on every worker in `remote_workers`, one worker at
// a time.
//
// Each request carries a copy of `server_def` with job name and task index
// replaced by those parsed from the worker's name, plus `rendezvous_id`,
// `async` and `keep_alive_secs`. For each worker, the context id from the
// response is recorded in `*remote_contexts` under the worker's name.
//
// Returns InvalidArgument if a worker name cannot be parsed, Internal if no
// eager client exists for it, or the first RPC error encountered.
tensorflow::Status CreateRemoteContexts(
    const std::vector<string>& remote_workers, int64 rendezvous_id,
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
  // Range-for avoids the signed/unsigned comparison of an int index against
  // remote_workers.size().
  for (const string& remote_worker : remote_workers) {
    tensorflow::eager::CreateContextRequest request;
    tensorflow::eager::CreateContextResponse response;
    request.set_rendezvous_id(rendezvous_id);
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      return tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
    }
    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    auto* eager_client = remote_eager_workers->GetClient(remote_worker);
    if (eager_client == nullptr) {
      return tensorflow::errors::Internal(
          "Cannot find a client for the given target:", remote_worker);
    }
    tensorflow::Notification n;
    tensorflow::Status status;
    // TODO(nareshmodi) do this in parallel instead of serially.
    eager_client->CreateContextAsync(
        &request, &response, [&status, &n](const tensorflow::Status& s) {
          status = s;
          n.Notify();
        });
    n.WaitForNotification();
    TF_RETURN_IF_ERROR(status);

    remote_contexts->emplace(remote_worker, response.context_id());
  }
  return tensorflow::Status::OK();
}
155
// Starts a gRPC server described by `server_def` and connects `ctx` to the
// cluster it defines.
//
// Steps:
//   1. Creates an eager context on every other worker in the server's worker
//      cache and collects their devices.
//   2. Creates a worker session for a fresh rendezvous id.
//   3. Passes the resulting server, clients, devices and rendezvous to
//      ctx->context.InitializeRemote.
//
// Every early-return error is logged before being returned (see the macro
// comment below).
tensorflow::Status UpdateTFE_ContextWithServerDef(
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    TFE_Context* ctx) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
  // server object (which currently CHECK-fails) and we miss the error, instead,
  // we log the error, and then return to allow the user to see the error
  // message. This applies to every early return in this function, including
  // those after the server has been started.
  //
  // There is deliberately no semicolon after `while (0)`: each call site
  // supplies its own, so every use expands to exactly one statement.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  string worker_name =
      tensorflow::strings::StrCat("/job:", server_def.job_name(),
                                  "/replica:0/task:", server_def.task_index());

  std::unique_ptr<tensorflow::ServerInterface> server;
  LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));

  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(server.get());
  if (grpc_server == nullptr) {
    LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
        "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
  }

  LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

  int64 rendezvous_id = tensorflow::random::New64();

  // Every worker in the cluster except this one.
  std::vector<string> remote_workers;
  grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
  remote_workers.erase(
      std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
      remote_workers.end());

  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
  LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
      remote_workers, grpc_server->master_env()->worker_cache,
      &remote_device_mgr));

  std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
      grpc_server->channel_cache();
  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
      tensorflow::eager::NewGrpcEagerClientCache(channel_cache));

  // Initialize remote eager workers.
  tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
  LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
      remote_workers, rendezvous_id, keep_alive_secs, server_def,
      remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));

  tensorflow::RemoteRendezvous* r =
      grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
  // Per the RendezvousMgrInterface::Find contract, the caller owns one
  // reference on the returned rendezvous. Drop it if we fail before handing
  // `r` to the context below.
  auto rendezvous_unref = tensorflow::gtl::MakeCleanup([r] { r->Unref(); });

  auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
  LOG_AND_RETURN_IF_ERROR(
      grpc_server->worker_env()->session_mgr->CreateSession(session_name,
                                                            server_def, true));

  std::shared_ptr<tensorflow::WorkerSession> worker_session;
  LOG_AND_RETURN_IF_ERROR(
      grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
          session_name, &worker_session));

  // Initialize remote tensor communication based on worker session.
  LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

  auto* device_mgr = grpc_server->worker_env()->device_mgr;

  // `r` is handed to the context below with our reference still held,
  // matching the behavior before the cleanup was introduced.
  rendezvous_unref.release();
  return ctx->context.InitializeRemote(
      std::move(server), std::move(remote_eager_workers),
      std::move(remote_device_mgr), remote_contexts, r, device_mgr,
      keep_alive_secs);
#undef LOG_AND_RETURN_IF_ERROR
}
235
OpInferSingleInputAttrs(TFE_Op * op,TFE_TensorHandle * input)236 tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
237 TFE_TensorHandle* input) {
238 TFE_OpInferenceContext* ictx = op->inference_ctx.get();
239 const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
240 if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
241 // Some clients that are still setting their input attributes manually are
242 // adding input list to their op by calling `TFE_OpAddInput` for each of
243 // its elements instead of calling `TFE_OpAddInputList`. When this happens,
244 // we cannot detect the end of such list, thus lose track of the input
245 // arguments in the op definition. To guarantee backward compatibility with
246 // those clients, disable automatic inference in this case.
247 op->inference_ctx.reset(nullptr);
248 return tensorflow::Status::OK();
249 }
250 const std::string& type_attr = input_def.type_attr();
251 if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
252 op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
253 ictx->attrs.insert(type_attr);
254 }
255 return tensorflow::Status::OK();
256 }
257
OpInferSingleTypeInputListAttrs(TFE_Op * op,const tensorflow::OpDef::ArgDef & input_def,TFE_TensorHandle ** inputs,int num_inputs)258 void OpInferSingleTypeInputListAttrs(TFE_Op* op,
259 const tensorflow::OpDef::ArgDef& input_def,
260 TFE_TensorHandle** inputs,
261 int num_inputs) {
262 TFE_OpInferenceContext* ictx = op->inference_ctx.get();
263 if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
264 op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
265 ictx->attrs.insert(input_def.number_attr());
266 }
267 if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
268 op->operation.MutableAttrs()->Set(input_def.type_attr(),
269 inputs[0]->handle->dtype);
270 ictx->attrs.insert(input_def.type_attr());
271 }
272 }
273
OpInferMixedTypeInputListAttrs(TFE_Op * op,const tensorflow::OpDef::ArgDef & input_def,TFE_TensorHandle ** inputs,int num_inputs)274 void OpInferMixedTypeInputListAttrs(TFE_Op* op,
275 const tensorflow::OpDef::ArgDef& input_def,
276 TFE_TensorHandle** inputs, int num_inputs) {
277 TFE_OpInferenceContext* ictx = op->inference_ctx.get();
278 if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
279 std::unique_ptr<tensorflow::DataType[]> dtypes(
280 new tensorflow::DataType[num_inputs]);
281 for (int i = 0; i < num_inputs; ++i) {
282 dtypes[i] = inputs[i]->handle->dtype;
283 }
284 op->operation.MutableAttrs()->Set(
285 input_def.type_list_attr(),
286 tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
287 num_inputs));
288 ictx->attrs.insert(input_def.type_list_attr());
289 }
290 }
291
OpInferInputListAttrs(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs)292 tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
293 int num_inputs) {
294 TFE_OpInferenceContext* ictx = op->inference_ctx.get();
295 const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
296 if (!input_def.type_list_attr().empty()) {
297 OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
298 } else if (!input_def.type_attr().empty() &&
299 !input_def.number_attr().empty()) {
300 OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
301 } else {
302 return tensorflow::errors::InvalidArgument("Invalid input list definition");
303 }
304 return tensorflow::Status::OK();
305 }
306
307 } // namespace
308
309 extern "C" {
310
TFE_NewContextOptions()311 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
312
TFE_ContextOptionsSetConfig(TFE_ContextOptions * options,const void * proto,size_t proto_len,TF_Status * status)313 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
314 size_t proto_len, TF_Status* status) {
315 TF_SetConfig(&options->session_options, proto, proto_len, status);
316 }
317
TFE_ContextOptionsSetAsync(TFE_ContextOptions * options,unsigned char enable)318 void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
319 unsigned char enable) {
320 options->async = enable;
321 }
322
TFE_ContextOptionsSetDevicePlacementPolicy(TFE_ContextOptions * options,TFE_ContextDevicePlacementPolicy policy)323 void TFE_ContextOptionsSetDevicePlacementPolicy(
324 TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
325 options->policy = policy;
326 }
327
TFE_ContextSetAsyncForThread(TFE_Context * ctx,unsigned char enable,TF_Status * status)328 TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
329 unsigned char enable,
330 TF_Status* status) {
331 status->status = ctx->context.SetAsyncForThread(enable);
332 }
333
TFE_DeleteContextOptions(TFE_ContextOptions * options)334 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
335
TFE_NewContext(const TFE_ContextOptions * opts,TF_Status * status)336 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
337 std::vector<std::unique_ptr<tensorflow::Device>> devices;
338 status->status = tensorflow::DeviceFactory::AddDevices(
339 opts->session_options.options, "/job:localhost/replica:0/task:0",
340 &devices);
341 if (!status->status.ok()) return nullptr;
342 std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
343 new tensorflow::DeviceMgr(std::move(devices)));
344
345 tensorflow::Rendezvous* r =
346 new tensorflow::IntraProcessRendezvous(device_mgr.get());
347
348 return new TFE_Context(opts->session_options.options, opts->policy,
349 opts->async, device_mgr.release(),
350 /*device_mgr_owned*/ true, r);
351 }
352
TFE_NewContextFromSession(const TFE_ContextOptions * opts,TF_Session * sess,TF_Status * status)353 TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
354 TF_Session* sess, TF_Status* status) {
355 const tensorflow::DeviceMgr* device_mgr = nullptr;
356 status->status = sess->session->LocalDeviceManager(&device_mgr);
357 if (!status->status.ok()) return nullptr;
358 tensorflow::Rendezvous* r =
359 new tensorflow::IntraProcessRendezvous(device_mgr);
360 return new TFE_Context(opts->session_options.options, opts->policy,
361 opts->async, device_mgr, /*device_mgr_owned*/ false,
362 r);
363 }
364
TFE_DeleteContext(TFE_Context * ctx)365 void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
366
TFE_ContextListDevices(TFE_Context * ctx,TF_Status * status)367 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
368 TF_DeviceList* list = new TF_DeviceList;
369 ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
370 if (ctx->context.remote_device_mgr()) {
371 ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
372 }
373 return list;
374 }
375
TFE_ContextClearCaches(TFE_Context * ctx,TF_Status * status)376 void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) {
377 status->status = ctx->context.ClearCaches();
378 }
379
380 // Set server_def on the context, possibly updating it.
TFE_ContextSetServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)381 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
382 int keep_alive_secs,
383 const void* proto,
384 size_t proto_len,
385 TF_Status* status) {
386 tensorflow::ServerDef server_def;
387 if (!server_def.ParseFromArray(proto, proto_len)) {
388 status->status = tensorflow::errors::InvalidArgument(
389 "Invalid tensorflow.ServerDef protocol buffer");
390 return;
391 }
392 status->status =
393 UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
394 }
395
TFE_ContextSetThreadLocalDevicePlacementPolicy(TFE_Context * ctx,TFE_ContextDevicePlacementPolicy policy)396 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
397 TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
398 ctx->context.SetThreadLocalDevicePlacementPolicy(
399 static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
400 }
401
402 // Note: this function looks up a thread local policy. So it should be called in
403 // the appropriate client thread. In particular, in async mode, it may not be
404 // safe to call this function from the async EagerExecutor threads.
TFE_ContextGetDevicePlacementPolicy(TFE_Context * ctx)405 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
406 TFE_Context* ctx) {
407 return static_cast<TFE_ContextDevicePlacementPolicy>(
408 ctx->context.GetDevicePlacementPolicy());
409 }
410
TFE_ContextAsyncWait(TFE_Context * ctx,TF_Status * status)411 void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
412 status->status = ctx->context.AsyncWait();
413 }
414
TFE_ContextGetStatus(TFE_Context * ctx,TF_Status * status)415 void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
416 status->status = ctx->context.GetStatus();
417 }
418
TFE_ContextAsyncClearError(TFE_Context * ctx)419 void TFE_ContextAsyncClearError(TFE_Context* ctx) {
420 ctx->context.ClearAsyncError();
421 }
422
TFE_NewTensorHandle(TF_Tensor * t,TF_Status * status)423 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
424 tensorflow::Tensor tensor;
425 status->status = tensorflow::TF_TensorToTensor(t, &tensor);
426 if (!status->status.ok()) return nullptr;
427 return new TFE_TensorHandle(tensor, nullptr, nullptr);
428 }
429
TFE_DeleteTensorHandle(TFE_TensorHandle * h)430 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
431 if (h == nullptr) return;
432 VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
433 << h->handle;
434 if (h->handle) {
435 h->handle->Unref();
436 }
437 delete h;
438 }
439
TFE_TensorHandleDataType(TFE_TensorHandle * h)440 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
441 return static_cast<TF_DataType>(h->handle->dtype);
442 }
443
TFE_TensorHandleNumDims(TFE_TensorHandle * h,TF_Status * status)444 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
445 if (h == nullptr || h->handle == nullptr) {
446 status->status = tensorflow::errors::InvalidArgument(
447 "The passed in handle is a nullptr");
448 return -1;
449 }
450 int result;
451 status->status = h->handle->NumDims(&result);
452 return result;
453 }
454
TFE_TensorHandleNumElements(TFE_TensorHandle * h,TF_Status * status)455 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
456 if (h == nullptr || h->handle == nullptr) {
457 status->status = tensorflow::errors::InvalidArgument(
458 "The passed in handle is a nullptr");
459 return -1;
460 }
461 tensorflow::int64 result;
462 status->status = h->handle->NumElements(&result);
463 return result;
464 }
465
TFE_TensorHandleDim(TFE_TensorHandle * h,int dim_index,TF_Status * status)466 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
467 TF_Status* status) {
468 if (h == nullptr || h->handle == nullptr) {
469 status->status = tensorflow::errors::InvalidArgument(
470 "The passed in handle is a nullptr");
471 return -1;
472 }
473 tensorflow::int64 result;
474 status->status = h->handle->Dim(dim_index, &result);
475 return result;
476 }
477
TFE_TensorHandleDeviceName(TFE_TensorHandle * h,TF_Status * status)478 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
479 if (h == nullptr || h->handle == nullptr) {
480 status->status = tensorflow::errors::InvalidArgument(
481 "The passed in handle is a nullptr");
482 return nullptr;
483 }
484 tensorflow::Device* d = h->handle->op_device();
485 return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
486 : d->name().c_str();
487 }
488
TFE_TensorHandleBackingDeviceName(TFE_TensorHandle * h,TF_Status * status)489 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
490 TF_Status* status) {
491 if (h == nullptr || h->handle == nullptr) {
492 status->status = tensorflow::errors::InvalidArgument(
493 "The passed in handle is a nullptr");
494 return nullptr;
495 }
496 tensorflow::Device* d = h->handle->device();
497 return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
498 : d->name().c_str();
499 }
500
TFE_TensorHandleCopySharingTensor(TFE_TensorHandle * h,TF_Status * status)501 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
502 TFE_TensorHandle* h, TF_Status* status) {
503 if (h == nullptr || h->handle == nullptr) {
504 status->status = tensorflow::errors::InvalidArgument(
505 "The passed in handle is a nullptr");
506 return nullptr;
507 }
508
509 h->handle->Ref();
510
511 return new TFE_TensorHandle(h->handle);
512 }
513
TFE_TensorHandleResolve(TFE_TensorHandle * h,TF_Status * status)514 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
515 if (h == nullptr || h->handle == nullptr) {
516 status->status = tensorflow::errors::InvalidArgument(
517 "The passed in handle is a nullptr");
518 return nullptr;
519 }
520 // TODO(agarwal): move this implementation inside TFE_TensorHandle.
521 const tensorflow::Tensor* t = nullptr;
522 tensorflow::TensorHandle* h_cpu = nullptr;
523 tensorflow::Device* d = nullptr;
524 tensorflow::Device* op_device = nullptr;
525
526 if (h->handle->IsRemote()) {
527 status->status = EagerCopyToDevice(
528 h->handle, h->handle->Context(),
529 h->handle->Context()->HostCPU()->name().c_str(), &h_cpu);
530 if (!status->status.ok()) {
531 return nullptr;
532 }
533 status->status = h_cpu->TensorAndDevice(&t, &d, &op_device);
534 if (!status->status.ok()) {
535 h_cpu->Unref();
536 return nullptr;
537 }
538 } else {
539 status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
540 if (!status->status.ok()) return nullptr;
541
542 if (!IsCPU(d)) {
543 status->status = h->handle->CopyToDevice(
544 h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu);
545 if (!status->status.ok()) {
546 return nullptr;
547 }
548 status->status = h_cpu->TensorAndDevice(&t, &d, &op_device);
549 if (!status->status.ok()) {
550 h_cpu->Unref();
551 return nullptr;
552 }
553 }
554 }
555 TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
556 if (h_cpu != nullptr) {
557 h_cpu->Unref();
558 }
559 return retval;
560 }
561
TFE_NewOp(TFE_Context * ctx,const char * op_or_function_name,TF_Status * status)562 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
563 TF_Status* status) {
564 const char* name = op_or_function_name; // Shorthand
565 const tensorflow::AttrTypeMap* types;
566 bool is_function = false;
567 status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
568 if (!status->status.ok()) {
569 return nullptr;
570 }
571 if (!is_function) {
572 const tensorflow::OpDef* op_def;
573 status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
574 if (!status->status.ok()) {
575 return nullptr;
576 }
577 return new TFE_Op(ctx, name, false, types,
578 new TFE_OpInferenceContext(op_def));
579 }
580 if (!ctx->context.FindFunctionByName(name)) {
581 status->status = tensorflow::errors::NotFound(
582 "'", name,
583 "' is neither a type of a primitive operation nor a name "
584 "of a function registered in binary running on ",
585 tensorflow::port::Hostname(),
586 ". Make sure the operation or function is "
587 "registered in the binary running in this process.");
588 return nullptr;
589 }
590 return new TFE_Op(ctx, name, true, types, nullptr);
591 }
592
TFE_DeleteOp(TFE_Op * op)593 void TFE_DeleteOp(TFE_Op* op) { delete op; }
594
TFE_OpSetDevice(TFE_Op * op,const char * device_name,TF_Status * status)595 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
596 status->status = op->operation.SetDevice(device_name);
597 }
598
TFE_OpGetDevice(TFE_Op * op,TF_Status * status)599 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
600 tensorflow::Device* device = (op->operation.Device() == nullptr)
601 ? op->operation.EagerContext()->HostCPU()
602 : op->operation.Device();
603 return device->name().c_str();
604 }
605
TFE_OpSetXLACompilation(TFE_Op * op,unsigned char enable)606 void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
607 op->operation.SetUseXla(enable);
608 #ifndef TENSORFLOW_EAGER_USE_XLA
609 LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
610 "built with XLA support.";
611 #endif // TENSORFLOW_EAGER_USE_XLA
612 }
613
TFE_OpAddInput(TFE_Op * op,TFE_TensorHandle * input,TF_Status * status)614 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
615 op->operation.AddInput(input->handle);
616 if (op->inference_ctx) {
617 status->status = OpInferSingleInputAttrs(op, input);
618 }
619 }
620
TFE_OpAddInputList(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs,TF_Status * status)621 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
622 TF_Status* status) {
623 for (int i = 0; i < num_inputs; ++i) {
624 op->operation.AddInput(inputs[i]->handle);
625 }
626 if (op->inference_ctx) {
627 status->status = OpInferInputListAttrs(op, inputs, num_inputs);
628 }
629 }
630
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)631 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
632 unsigned char* is_list, TF_Status* status) {
633 TF_AttrType ret;
634 status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
635 attr_name, &ret, is_list);
636 return ret;
637 }
638
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)639 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
640 const char* op_or_function_name,
641 const char* attr_name, unsigned char* is_list,
642 TF_Status* status) {
643 TF_AttrType ret;
644 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
645 if (!status->status.ok()) {
646 return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
647 }
648 ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
649 TFE_DeleteOp(op);
650 return ret;
651 }
652
TFE_OpSetAttrString(TFE_Op * op,const char * attr_name,const void * value,size_t length)653 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
654 size_t length) {
655 op->operation.MutableAttrs()->Set(
656 attr_name,
657 tensorflow::StringPiece(static_cast<const char*>(value), length));
658 }
659
TFE_OpSetAttrInt(TFE_Op * op,const char * attr_name,int64_t value)660 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
661 op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
662 }
663
TFE_OpSetAttrFloat(TFE_Op * op,const char * attr_name,float value)664 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
665 op->operation.MutableAttrs()->Set(attr_name, value);
666 }
667
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)668 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
669 op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
670 }
671
TFE_OpSetAttrType(TFE_Op * op,const char * attr_name,TF_DataType value)672 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
673 op->operation.MutableAttrs()->Set(attr_name,
674 static_cast<tensorflow::DataType>(value));
675 }
676
TFE_OpSetAttrShape(TFE_Op * op,const char * attr_name,const int64_t * dims,const int num_dims,TF_Status * out_status)677 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
678 const int num_dims, TF_Status* out_status) {
679 if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
680 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
681 tensorflow::strings::StrCat(
682 "Value specified for `", attr_name, "` has ", num_dims,
683 " dimensions which is over the limit of ",
684 tensorflow::TensorShape::MaxDimensions(), ".")
685 .c_str());
686 return;
687 }
688 tensorflow::TensorShapeProto proto;
689 if (num_dims < 0) {
690 proto.set_unknown_rank(true);
691 } else {
692 for (int d = 0; d < num_dims; ++d) {
693 proto.add_dim()->set_size(dims[d]);
694 }
695 }
696 op->operation.MutableAttrs()->Set(attr_name, proto);
697 }
698
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)699 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
700 const TFE_Op* value) {
701 tensorflow::AttrValue attr_value;
702 tensorflow::NameAttrList* func = attr_value.mutable_func();
703 func->set_name(value->operation.Name());
704 value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
705 op->operation.MutableAttrs()->Set(attr_name, attr_value);
706 }
707
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)708 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
709 const char* data, size_t length) {
710 tensorflow::AttrValue attr_value;
711 tensorflow::NameAttrList* func = attr_value.mutable_func();
712 func->set_name(data, length);
713 op->operation.MutableAttrs()->Set(attr_name, attr_value);
714 }
715
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)716 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
717 TF_Status* status) {
718 tensorflow::Tensor t;
719 status->status = TF_TensorToTensor(tensor, &t);
720 if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
721 }
722
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)723 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
724 const void* const* values, const size_t* lengths,
725 int num_values) {
726 std::vector<tensorflow::StringPiece> v(num_values);
727 for (int i = 0; i < num_values; ++i) {
728 v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
729 lengths[i]);
730 }
731 op->operation.MutableAttrs()->Set(attr_name, v);
732 }
733
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)734 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
735 const float* values, int num_values) {
736 op->operation.MutableAttrs()->Set(
737 attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
738 }
739
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)740 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
741 const int64_t* values, int num_values) {
742 op->operation.MutableAttrs()->Set(
743 attr_name, tensorflow::gtl::ArraySlice<const int64>(
744 reinterpret_cast<const int64*>(values), num_values));
745 }
746
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)747 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
748 const TF_DataType* values, int num_values) {
749 op->operation.MutableAttrs()->Set(
750 attr_name,
751 tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
752 reinterpret_cast<const tensorflow::DataType*>(values), num_values));
753 }
754
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)755 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
756 const unsigned char* values, int num_values) {
757 std::unique_ptr<bool[]> b(new bool[num_values]);
758 for (int i = 0; i < num_values; ++i) {
759 b[i] = values[i];
760 }
761 op->operation.MutableAttrs()->Set(
762 attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
763 }
764
TFE_OpSetAttrShapeList(TFE_Op * op,const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values,TF_Status * out_status)765 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
766 const int64_t** dims, const int* num_dims,
767 int num_values, TF_Status* out_status) {
768 std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
769 new tensorflow::TensorShapeProto[num_values]);
770 for (int i = 0; i < num_values; ++i) {
771 const auto num_dims_i = num_dims[i];
772
773 if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
774 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
775 tensorflow::strings::StrCat(
776 "Value specified for `", attr_name, "` has ", num_dims_i,
777 " dimensions which is over the limit of ",
778 tensorflow::TensorShape::MaxDimensions(), ".")
779 .c_str());
780 return;
781 }
782 if (num_dims_i < 0) {
783 proto[i].set_unknown_rank(true);
784 } else {
785 const int64_t* dims_i = dims[i];
786 auto proto_i = &proto[i];
787 for (int d = 0; d < num_dims_i; ++d) {
788 proto_i->add_dim()->set_size(dims_i[d]);
789 }
790 }
791 }
792 op->operation.MutableAttrs()->Set(
793 attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
794 proto.get(), num_values));
795 }
796
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)797 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
798 const TFE_Op** value, int num_values) {
799 std::unique_ptr<tensorflow::NameAttrList[]> funcs(
800 new tensorflow::NameAttrList[num_values]);
801 for (int i = 0; i < num_values; i++) {
802 funcs[i].set_name(value[i]->operation.Name());
803 value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
804 }
805 op->operation.MutableAttrs()->Set(
806 attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
807 funcs.get(), num_values));
808 }
809
TFE_Execute(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)810 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
811 TF_Status* status) {
812 VLOG(1) << "Calling TFE_Execute() on op " << op;
813 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
814 *num_retvals);
815 status->status =
816 tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals);
817 if (!status->status.ok()) {
818 return;
819 }
820 for (int i = 0; i < *num_retvals; ++i) {
821 retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
822 }
823 }
824
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)825 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
826 TFE_Context* ctx,
827 const char* device_name,
828 TF_Status* status) {
829 tensorflow::TensorHandle* handle;
830 status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
831 device_name, &handle);
832 if (status->status.ok()) {
833 return new TFE_TensorHandle(handle);
834 }
835 return nullptr;
836 }
837
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)838 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
839 const char* serialized_function_def, size_t size,
840 TF_Status* status) {
841 tensorflow::FunctionDef function_def;
842 if (!function_def.ParseFromArray(serialized_function_def, size)) {
843 status->status =
844 tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
845 return;
846 }
847 status->status = ctx->context.AddFunctionDef(function_def);
848 }
849
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)850 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
851 TF_Status* status) {
852 status->status = ctx->context.AddFunctionDef(function->fdef);
853 }
854
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)855 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
856 return ctx->context.FindFunctionDef(name) != nullptr;
857 }
858
TFE_ContextEnableRunMetadata(TFE_Context * ctx)859 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
860 ctx->context.SetShouldStoreGraphs(true);
861 ctx->context.SetShouldStoreStepStats(true);
862 }
863
TFE_ContextDisableRunMetadata(TFE_Context * ctx)864 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
865 ctx->context.SetShouldStoreGraphs(false);
866 ctx->context.SetShouldStoreStepStats(false);
867 }
868
869 } // extern "C"
870
TFE_NewTensorHandle(const tensorflow::Tensor & t)871 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
872 return new TFE_TensorHandle(t, nullptr, nullptr);
873 }
874
TFE_TensorHandleUnderlyingTensorInHostMemory(TFE_TensorHandle * h,TF_Status * status)875 const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
876 TFE_TensorHandle* h, TF_Status* status) {
877 if (!h->handle->OnHostCPU()) {
878 status->status = tensorflow::errors::FailedPrecondition(
879 "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
880 "a tensorflow::Tensor");
881 return nullptr;
882 }
883 tensorflow::Device* d = nullptr;
884 tensorflow::Device* op_device = nullptr;
885 const tensorflow::Tensor* t = nullptr;
886 status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
887 if (!status->status.ok()) return nullptr;
888 return t;
889 }
890
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)891 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
892 TF_Status* status) {
893 TFE_ContextAsyncWait(ctx, status);
894 if (!status->status.ok()) return;
895 tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
896 status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
897 ctx->context.ClearRunMetadata();
898 }
899
900 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)901 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
902 TF_Status* status) {
903 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
904 for (const auto& attr : func.attr()) {
905 if (TF_GetCode(status) != TF_OK) return nullptr;
906 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
907 if (TF_GetCode(status) != TF_OK) return nullptr;
908 }
909 return func_op;
910 }
911 } // namespace
912
TFE_ContextStartStep(TFE_Context * ctx)913 void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
914
TFE_ContextEndStep(TFE_Context * ctx)915 void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
916
917 namespace tensorflow {
SetOpAttrValueScalar(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,TF_Status * status)918 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
919 const tensorflow::AttrValue& default_value,
920 const char* attr_name, TF_Status* status) {
921 switch (default_value.value_case()) {
922 case tensorflow::AttrValue::kS: {
923 const string& v = default_value.s();
924 TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
925 break;
926 }
927 case tensorflow::AttrValue::kI:
928 TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
929 break;
930 case tensorflow::AttrValue::kF:
931 TFE_OpSetAttrFloat(op, attr_name, default_value.f());
932 break;
933 case tensorflow::AttrValue::kB:
934 TFE_OpSetAttrBool(op, attr_name, default_value.b());
935 break;
936 case tensorflow::AttrValue::kType:
937 TFE_OpSetAttrType(op, attr_name,
938 static_cast<TF_DataType>(default_value.type()));
939 break;
940 case tensorflow::AttrValue::kShape: {
941 const auto& tensor_shape = default_value.shape();
942 if (tensor_shape.unknown_rank()) {
943 TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
944 } else {
945 const auto num_dims = tensor_shape.dim_size();
946 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
947 for (int i = 0; i < num_dims; ++i) {
948 dims[i] = tensor_shape.dim(i).size();
949 }
950 TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
951 }
952 } break;
953 case tensorflow::AttrValue::kFunc: {
954 const auto func_op = GetFunc(ctx, default_value.func(), status);
955 if (TF_GetCode(status) != TF_OK) return;
956 // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
957 // require TFE_Op* and just convert it internally a NameAttrValue, so
958 // consider adding an overload to the C API to make this case easier.
959 TFE_OpSetAttrFunction(op, attr_name, func_op);
960 } break;
961 case tensorflow::AttrValue::kList:
962 TF_FALLTHROUGH_INTENDED;
963 case tensorflow::AttrValue::kTensor:
964 TF_FALLTHROUGH_INTENDED;
965 case tensorflow::AttrValue::kPlaceholder:
966 TF_FALLTHROUGH_INTENDED;
967 case tensorflow::AttrValue::VALUE_NOT_SET:
968 TF_SetStatus(
969 status, TF_UNIMPLEMENTED,
970 tensorflow::strings::StrCat("Unable to get setfor default value: ",
971 default_value.DebugString())
972 .data());
973 }
974 }
975 } // namespace tensorflow
976