• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 // clang-format off
25 // Required for IS_MOBILE_PLATFORM
26 #include "tensorflow/core/platform/platform.h"
27 // clang-format on
28 
29 #include "absl/algorithm/container.h"
30 #include "absl/container/fixed_array.h"
31 #include "absl/memory/memory.h"
32 #include "tensorflow/c/c_api.h"
33 #include "tensorflow/c/c_api_internal.h"
34 #include "tensorflow/c/eager/tensor_handle_interface.h"
35 #include "tensorflow/c/tf_tensor_internal.h"
36 #include "tensorflow/c/eager/c_api_experimental.h"
37 #include "tensorflow/c/eager/c_api_internal.h"
38 #include "tensorflow/core/common_runtime/device.h"
39 #include "tensorflow/core/common_runtime/eager/context.h"
40 #include "tensorflow/core/framework/device_attributes.pb.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/framework/function.h"
44 #include "tensorflow/core/platform/errors.h"
45 #include "tensorflow/core/platform/platform.h"  // NOLINT
46 #include "tensorflow/core/protobuf/error_codes.pb.h"
47 #include "tensorflow/core/protobuf/device_filters.pb.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 #ifdef TENSORFLOW_EAGER_USE_XLA
50 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
51 #endif  // TENSORFLOW_EAGER_USE_XLA
52 #include "tensorflow/core/common_runtime/copy_tensor.h"
53 #include "tensorflow/core/common_runtime/device_factory.h"
54 #include "tensorflow/core/common_runtime/device_mgr.h"
55 #include "tensorflow/core/common_runtime/device_set.h"
56 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
57 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
58 #include "tensorflow/core/common_runtime/eager/execute.h"
59 #include "tensorflow/core/common_runtime/function.h"
60 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
61 #if !defined(IS_MOBILE_PLATFORM)
62 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
63 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
64 #include "tensorflow/core/distributed_runtime/remote_device.h"
65 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
66 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
67 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
68 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
69 #include "tensorflow/core/distributed_runtime/server_lib.h"
70 #include "tensorflow/core/distributed_runtime/worker_env.h"
71 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
72 #endif  // !IS_MOBILE_PLATFORM
73 #include "tensorflow/core/framework/node_def_util.h"
74 #include "tensorflow/core/framework/rendezvous.h"
75 #include "tensorflow/core/framework/tensor_shape.pb.h"
76 #include "tensorflow/core/framework/types.h"
77 #include "tensorflow/core/lib/core/blocking_counter.h"
78 #include "tensorflow/core/lib/core/notification.h"
79 #include "tensorflow/core/lib/core/refcount.h"
80 #include "tensorflow/core/lib/core/stringpiece.h"
81 #include "tensorflow/core/lib/gtl/cleanup.h"
82 #include "tensorflow/core/lib/gtl/flatmap.h"
83 #include "tensorflow/core/lib/gtl/map_util.h"
84 
85 #include "tensorflow/core/lib/random/random.h"
86 #include "tensorflow/core/platform/casts.h"
87 #include "tensorflow/core/platform/env.h"
88 #include "tensorflow/core/platform/mutex.h"
89 #include "tensorflow/core/platform/thread_annotations.h"
90 #include "tensorflow/core/profiler/lib/traceme.h"
91 #include "tensorflow/core/public/version.h"
92 
93 using tensorflow::int64;
94 using tensorflow::string;
95 
96 namespace {
97 
GetOpDef(TFE_Op * op,TF_Status * status)98 const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
99   const tensorflow::OpDef* op_def = op->operation.OpDef();
100   if (op_def) return op_def;
101   status->status =
102       tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
103   return op_def;
104 }
105 
IsCPU(const tensorflow::Device * d)106 bool IsCPU(const tensorflow::Device* d) {
107   return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
108 }
109 
DeviceName(const tensorflow::Device * d)110 string DeviceName(const tensorflow::Device* d) {
111   return (d == nullptr) ? "cpu:0" : d->name();
112 }
113 
114 #if !defined(IS_MOBILE_PLATFORM)
// Asynchronously queries every worker in `added_remote_workers` for its
// device list (via NewRemoteDevices) and registers all discovered devices
// with `remote_device_mgr`. Blocks until every query has completed; returns
// the first per-worker error, if any, otherwise OK.
tensorflow::Status AddRemoteDevicesToMgr(
    const std::vector<string>& added_remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  // Guards `remote_devices`: the NewRemoteDevices callbacks may run
  // concurrently on other threads.
  tensorflow::mutex remote_devices_mu;
  int num_added_workers = added_remote_workers.size();
  // Counts down once per worker so we can block until all callbacks fire.
  tensorflow::BlockingCounter counter(num_added_workers);
  // One status slot per worker; each callback writes only its own index,
  // so no lock is needed for `statuses`.
  std::vector<tensorflow::Status> statuses(num_added_workers);
  for (int i = 0; i < num_added_workers; i++) {
    tensorflow::NewRemoteDevices(
        tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
        [i, &statuses, &counter, &remote_devices, &remote_devices_mu](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
          statuses[i] = s;
          if (s.ok()) {
            tensorflow::mutex_lock l(remote_devices_mu);
            // Take ownership of each raw Device pointer.
            for (tensorflow::Device* d : *devices) {
              remote_devices.emplace_back(d);
            }
          }
          counter.DecrementCount();
        });
  }
  // All callbacks have run past this point; the locals captured by
  // reference above are safe to read without synchronization now.
  counter.Wait();
  for (int i = 0; i < num_added_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }

  TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
  return tensorflow::Status::OK();
}
148 
GetAllRemoteDevices(const std::vector<string> & remote_workers,tensorflow::WorkerCacheInterface * worker_cache,std::unique_ptr<tensorflow::DynamicDeviceMgr> * device_mgr)149 tensorflow::Status GetAllRemoteDevices(
150     const std::vector<string>& remote_workers,
151     tensorflow::WorkerCacheInterface* worker_cache,
152     std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
153   auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
154   TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
155                                            remote_device_mgr.get()));
156   *device_mgr = std::move(remote_device_mgr);
157   return tensorflow::Status::OK();
158 }
159 
RemoveRemoteDevicesFromMgr(const std::vector<string> & removed_remote_workers,tensorflow::DynamicDeviceMgr * remote_device_mgr)160 tensorflow::Status RemoveRemoteDevicesFromMgr(
161     const std::vector<string>& removed_remote_workers,
162     tensorflow::DynamicDeviceMgr* remote_device_mgr) {
163   const std::vector<tensorflow::Device*> remote_devices =
164       (remote_device_mgr->ListDevices());
165   std::vector<tensorflow::Device*> devices_to_remove;
166   for (tensorflow::Device* d : remote_devices) {
167     for (const string& remote_worker : removed_remote_workers) {
168       if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
169                                                           d->name())) {
170         devices_to_remove.emplace_back(d);
171         break;
172       }
173     }
174   }
175   TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
176   return tensorflow::Status::OK();
177 }
178 
ListRemoteWorkers(tensorflow::ServerInterface * server,const string & local_worker,std::vector<string> * remote_workers)179 tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
180                                      const string& local_worker,
181                                      std::vector<string>* remote_workers) {
182   tensorflow::GrpcServer* grpc_server =
183       dynamic_cast<tensorflow::GrpcServer*>(server);
184   if (grpc_server == nullptr) {
185     return tensorflow::errors::Internal(
186         "Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
187   }
188   grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
189   remote_workers->erase(
190       std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
191       remote_workers->end());
192   return tensorflow::Status::OK();
193 }
194 
// Computes, in a single merge pass over two sorted lists:
//   *added    = new_list  \ current_list   (set difference)
//   *removed  = current_list \ new_list    (set difference)
//   *existing = current_list ∩ new_list    (set intersection)
// Like std::set_difference, both inputs must be sorted before the call.
// Output vectors are overwritten.
void DifferentiateWorkerLists(const std::vector<string>* current_list,
                              const std::vector<string>* new_list,
                              std::vector<string>* added,
                              std::vector<string>* removed,
                              std::vector<string>* existing) {
  added->clear();
  removed->clear();
  existing->clear();
  auto curr = current_list->begin();
  auto next = new_list->begin();
  while (curr != current_list->end() && next != new_list->end()) {
    if (*curr < *next) {
      // Present before, absent now: removed.
      removed->push_back(*curr);
      ++curr;
    } else if (*next < *curr) {
      // Absent before, present now: added.
      added->push_back(*next);
      ++next;
    } else {
      // Present in both: existing.
      existing->push_back(*curr);
      ++curr;
      ++next;
    }
  }
  // Whatever remains in one list has no counterpart in the other.
  removed->insert(removed->end(), curr, current_list->end());
  added->insert(added->end(), next, new_list->end());
}
227 
GetReplacedFromExistingWorkers(const std::vector<string> * existing_workers,tensorflow::uint64 context_id,tensorflow::uint64 context_view_id,const tensorflow::ServerDef & server_def,tensorflow::eager::EagerClientCache * client_cache,std::vector<string> * replaced_workers)228 tensorflow::Status GetReplacedFromExistingWorkers(
229     const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
230     tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
231     tensorflow::eager::EagerClientCache* client_cache,
232     std::vector<string>* replaced_workers) {
233   tensorflow::BlockingCounter counter(existing_workers->size());
234   std::vector<tensorflow::Status> statuses(existing_workers->size());
235   tensorflow::eager::KeepAliveRequest request;
236   request.set_context_id(context_id);
237   std::vector<tensorflow::eager::KeepAliveResponse> responses(
238       existing_workers->size());
239   for (int i = 0; i < existing_workers->size(); i++) {
240     tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
241     statuses[i] =
242         client_cache->GetClient(existing_workers->at(i), &eager_client);
243     if (!statuses[i].ok()) {
244       counter.DecrementCount();
245       continue;
246     }
247     eager_client->KeepAliveAsync(
248         &request, &responses[i],
249         [i, &statuses, &counter](const tensorflow::Status& s) {
250           statuses[i] = s;
251           counter.DecrementCount();
252         });
253   }
254   counter.Wait();
255   for (int i = 0; i < existing_workers->size(); i++) {
256     // If the RPC fails (indicating that the requested ID doesn't exist on
257     // remote), or the returned view ID is not equal to the local one
258     // (indicating that the remote worker has a stale view of cluster), treat
259     // the worker as replaced.
260     if (!statuses[i].ok() ||
261         responses[i].context_view_id() != context_view_id) {
262       replaced_workers->emplace_back(existing_workers->at(i));
263     }
264   }
265   return tensorflow::Status::OK();
266 }
267 
CreateRemoteContexts(TFE_Context * ctx,const std::vector<string> & remote_workers,tensorflow::uint64 context_id,tensorflow::uint64 context_view_id,int keep_alive_secs,const tensorflow::ServerDef & server_def,tensorflow::eager::EagerClientCache * remote_eager_workers,bool async,const bool lazy_copy_remote_function_inputs,const tensorflow::eager::CreateContextRequest & base_request)268 tensorflow::Status CreateRemoteContexts(
269     TFE_Context* ctx, const std::vector<string>& remote_workers,
270     tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
271     int keep_alive_secs, const tensorflow::ServerDef& server_def,
272     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
273     const bool lazy_copy_remote_function_inputs,
274     const tensorflow::eager::CreateContextRequest& base_request) {
275   int num_remote_workers = remote_workers.size();
276   tensorflow::BlockingCounter counter(num_remote_workers);
277   std::vector<tensorflow::Status> statuses(num_remote_workers);
278   for (int i = 0; i < num_remote_workers; i++) {
279     const string& remote_worker = remote_workers[i];
280     tensorflow::DeviceNameUtils::ParsedName parsed_name;
281     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
282                                                     &parsed_name)) {
283       statuses[i] = tensorflow::errors::InvalidArgument(
284           "Unable to parse ", remote_worker, " as a device name");
285       counter.DecrementCount();
286       continue;
287     }
288 
289     tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
290     statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
291     if (eager_client == nullptr) {
292       statuses[i] = tensorflow::errors::Internal(
293           "Cannot find a client for the given target:", remote_worker);
294     }
295     if (!statuses[i].ok()) {
296       counter.DecrementCount();
297       continue;
298     }
299 
300     tensorflow::eager::CreateContextRequest request;
301     tensorflow::eager::CreateContextResponse* response =
302         new tensorflow::eager::CreateContextResponse();
303     request.set_context_id(context_id);
304     request.set_context_view_id(context_view_id);
305     *request.mutable_server_def() = server_def;
306     request.mutable_server_def()->set_job_name(parsed_name.job);
307     request.mutable_server_def()->set_task_index(parsed_name.task);
308     request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
309         server_def.default_session_config());
310 
311     std::vector<bool> filtered_device_mask;
312     ctx->context->FilterDevicesForRemoteWorkers(
313         remote_worker, base_request.cluster_device_attributes(),
314         &filtered_device_mask);
315     DCHECK_EQ(filtered_device_mask.size(),
316               base_request.cluster_device_attributes_size());
317     for (int i = 0; i < filtered_device_mask.size(); i++) {
318       if (filtered_device_mask[i]) {
319         const auto& da = base_request.cluster_device_attributes(i);
320         *request.add_cluster_device_attributes() = da;
321       }
322     }
323     request.set_async(async);
324     request.set_keep_alive_secs(keep_alive_secs);
325     request.set_lazy_copy_remote_function_inputs(
326         lazy_copy_remote_function_inputs);
327 
328     eager_client->CreateContextAsync(
329         &request, response,
330         [i, &statuses, &counter, response](const tensorflow::Status& s) {
331           statuses[i] = s;
332           delete response;
333           counter.DecrementCount();
334         });
335   }
336   counter.Wait();
337   for (int i = 0; i < num_remote_workers; i++) {
338     TF_RETURN_IF_ERROR(statuses[i]);
339   }
340   return tensorflow::Status::OK();
341 }
342 
UpdateRemoteContexts(TFE_Context * ctx,const std::vector<string> & remote_workers,const std::vector<string> & added_workers,const std::vector<string> & removed_workers,tensorflow::uint64 context_id,tensorflow::uint64 context_view_id,const tensorflow::ServerDef & server_def,tensorflow::eager::EagerClientCache * remote_eager_workers,const tensorflow::eager::CreateContextRequest & base_request)343 tensorflow::Status UpdateRemoteContexts(
344     TFE_Context* ctx, const std::vector<string>& remote_workers,
345     const std::vector<string>& added_workers,
346     const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
347     tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
348     tensorflow::eager::EagerClientCache* remote_eager_workers,
349     const tensorflow::eager::CreateContextRequest& base_request) {
350   int num_remote_workers = remote_workers.size();
351   tensorflow::BlockingCounter counter(num_remote_workers);
352   std::vector<tensorflow::Status> statuses(num_remote_workers);
353 
354   int cluster_device_count = base_request.cluster_device_attributes_size();
355   std::unordered_set<string> added_or_removed(added_workers.begin(),
356                                               added_workers.end());
357   std::copy(removed_workers.begin(), removed_workers.end(),
358             std::inserter(added_or_removed, added_or_removed.end()));
359   // Whether each device is in the updated (added or removed) workers
360   std::vector<bool> device_added_or_removed(cluster_device_count);
361   for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
362     const auto& da = base_request.cluster_device_attributes().at(i);
363     tensorflow::DeviceNameUtils::ParsedName pn;
364     tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
365     string task_name;
366     tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
367     if (added_or_removed.find(task_name) != added_or_removed.end()) {
368       device_added_or_removed[i] = true;
369     }
370   }
371 
372   for (int i = 0; i < num_remote_workers; i++) {
373     const string& remote_worker = remote_workers[i];
374     tensorflow::DeviceNameUtils::ParsedName parsed_name;
375     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
376                                                     &parsed_name)) {
377       statuses[i] = tensorflow::errors::InvalidArgument(
378           "Unable to parse ", remote_worker, " as a device name");
379       counter.DecrementCount();
380       continue;
381     }
382 
383     tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
384     statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
385     if (eager_client == nullptr) {
386       statuses[i] = tensorflow::errors::Internal(
387           "Cannot find a client for the given target:", remote_worker);
388     }
389     if (!statuses[i].ok()) {
390       counter.DecrementCount();
391       continue;
392     }
393 
394     std::vector<bool> filtered_device_mask;
395     ctx->context->FilterDevicesForRemoteWorkers(
396         remote_worker, base_request.cluster_device_attributes(),
397         &filtered_device_mask);
398     DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
399 
400     // If any of the devices that match the device filters are in the set of
401     // added or removed workers, we must send a complete UpdateContextRequest.
402     // Otherwise, only send a simple request to increment context view ID.
403     std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
404     std::transform(device_added_or_removed.begin(),
405                    device_added_or_removed.end(), filtered_device_mask.begin(),
406                    added_or_removed_filtered_devices.begin(),
407                    std::logical_and<bool>());
408     const bool full_update_request =
409         std::accumulate(added_or_removed_filtered_devices.begin(),
410                         added_or_removed_filtered_devices.end(), false,
411                         std::logical_or<bool>());
412 
413     tensorflow::eager::UpdateContextRequest request;
414     auto* response = new tensorflow::eager::UpdateContextResponse();
415     request.set_context_id(context_id);
416     request.set_context_view_id(context_view_id);
417     if (full_update_request) {
418       *request.mutable_server_def() = server_def;
419       request.mutable_server_def()->set_job_name(parsed_name.job);
420       request.mutable_server_def()->set_task_index(parsed_name.task);
421       request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
422           server_def.default_session_config());
423       for (int i = 0; i < cluster_device_count; i++) {
424         if (filtered_device_mask[i]) {
425           const auto& da = base_request.cluster_device_attributes(i);
426           *request.add_cluster_device_attributes() = da;
427         }
428       }
429     }
430 
431     eager_client->UpdateContextAsync(
432         &request, response,
433         [i, &statuses, &counter, response](const tensorflow::Status& s) {
434           statuses[i] = s;
435           delete response;
436           counter.DecrementCount();
437         });
438   }
439   counter.Wait();
440   for (int i = 0; i < num_remote_workers; i++) {
441     TF_RETURN_IF_ERROR(statuses[i]);
442   }
443   return tensorflow::Status::OK();
444 }
445 
UpdateTFE_ContextWithServerDef(int keep_alive_secs,const tensorflow::ServerDef & server_def,TFE_Context * ctx,bool reset_context)446 tensorflow::Status UpdateTFE_ContextWithServerDef(
447     int keep_alive_secs, const tensorflow::ServerDef& server_def,
448     TFE_Context* ctx, bool reset_context) {
449   // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
450   // server object (which currently CHECK-fails) and we miss the error, instead,
451   // we log the error, and then return to allow the user to see the error
452   // message.
453 #define LOG_AND_RETURN_IF_ERROR(...)                    \
454   do {                                                  \
455     const ::tensorflow::Status _status = (__VA_ARGS__); \
456     if (TF_PREDICT_FALSE(!_status.ok())) {              \
457       LOG(ERROR) << _status.error_message();            \
458       return _status;                                   \
459     }                                                   \
460   } while (0);
461 
462   string worker_name =
463       tensorflow::strings::StrCat("/job:", server_def.job_name(),
464                                   "/replica:0/task:", server_def.task_index());
465 
466   // List of current remote workers before updating server_def. Unused if
467   // resetting the server_def.
468   std::vector<string> curr_remote_workers;
469   // List of updated remote workers.
470   std::vector<string> remote_workers;
471 
472   // New server created for new server_def. Unused if updating server_def.
473   std::unique_ptr<tensorflow::ServerInterface> new_server;
474   tensorflow::EagerContext* context = ctx->context;
475   tensorflow::GrpcServer* grpc_server;
476   if (reset_context) {
477     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
478     grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
479     LOG_AND_RETURN_IF_ERROR(
480         ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
481   } else {
482     LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
483                                               &curr_remote_workers));
484     // No need to check the cast here, since `ListRemoteWorkers` already checks
485     // if the server is a GRPC server or not.
486     grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
487     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
488     LOG_AND_RETURN_IF_ERROR(
489         ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
490   }
491 
492   tensorflow::uint64 context_id = context->GetContextId();
493   tensorflow::uint64 context_view_id = context->GetContextViewId();
494   if (reset_context) {
495     context_id = tensorflow::EagerContext::NewContextId();
496     context_view_id = 0;
497     // Make master eager context accessible by local eager service, which might
498     // receive send tensor requests from remote workers.
499     LOG_AND_RETURN_IF_ERROR(
500         grpc_server->AddMasterEagerContextToEagerService(context_id, context));
501   }
502 
503   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
504   LOG_AND_RETURN_IF_ERROR(
505       grpc_server->master_env()->worker_cache->GetEagerClientCache(
506           &remote_eager_workers));
507 
508   // When updating an existing context, populate the following lists with:
509   // * added_workers: set(remote_workers) - set(curr_remote_workers)
510   // * removed_workers: set(curr_remote_workers) - set(remote_workers)
511   // * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
512   // * replaced_workers: workers with the same task names and potentially the
513   //     same `hostname:port`s, but replaced by different processes
514   std::vector<string> added_workers;
515   std::vector<string> removed_workers;
516   std::vector<string> existing_workers;
517   std::vector<string> replaced_workers;
518 
519   // New remote device manager created for new server_def. Unused if updating
520   // server_def.
521   std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
522   tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
523   if (reset_context) {
524     LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
525         remote_workers, grpc_server->master_env()->worker_cache,
526         &new_remote_device_mgr));
527     remote_device_mgr = new_remote_device_mgr.get();
528   } else {
529     context->ClearCachesAndDefaultExecutor();
530     // TODO(b/143914772): Potential memory leak if rendezvous has pending
531     // tensors for removed / replaced workers.
532 
533     remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
534     if (remote_device_mgr == nullptr) {
535       LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
536           "Updating context with an invalid set of remote devices."));
537     }
538     std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
539     std::sort(remote_workers.begin(), remote_workers.end());
540     DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
541                              &added_workers, &removed_workers,
542                              &existing_workers);
543     LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
544         &existing_workers, context_id, context->GetContextViewId(), server_def,
545         remote_eager_workers.get(), &replaced_workers));
546     if (VLOG_IS_ON(1)) {
547       VLOG(1) << "Updating cluster with following changes";
548       for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
549       for (const string& w : removed_workers)
550         VLOG(1) << "  Removed worker " << w;
551       for (const string& w : replaced_workers)
552         VLOG(1) << "  Replaced worker " << w;
553     }
554     if (!replaced_workers.empty()) {
555       // Treat replaced workers as removed then added back, so that we recreate
556       // remote devices and contexts, and re-register functions on those workers
557       removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
558                              replaced_workers.end());
559       added_workers.insert(added_workers.end(), replaced_workers.begin(),
560                            replaced_workers.end());
561       for (const string& w : replaced_workers) {
562         existing_workers.erase(
563             std::remove(existing_workers.begin(), existing_workers.end(), w),
564             existing_workers.end());
565       }
566     }
567     LOG_AND_RETURN_IF_ERROR(
568         RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
569     LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
570         added_workers, grpc_server->master_env()->worker_cache,
571         remote_device_mgr));
572   }
573 
574   std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
575   remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
576 
577   std::vector<tensorflow::DeviceAttributes> local_device_attributes;
578   grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
579       &local_device_attributes);
580 
581   // This request make sure that we can create Rendezvous properly between
582   // Local and Remote context.
583   tensorflow::eager::CreateContextRequest base_request;
584   for (const auto& da : cluster_device_attributes) {
585     *base_request.add_cluster_device_attributes() = da;
586   }
587   for (const auto& da : local_device_attributes) {
588     *base_request.add_cluster_device_attributes() = da;
589   }
590 
591   // Initialize remote eager workers.
592   // TODO(b/138847548) Create remote eager contexts in async mode by default.
593   if (reset_context) {
594     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
595         ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
596         server_def, remote_eager_workers.get(), context->Executor().Async(),
597         context->LazyCopyFunctionRemoteInputs(), base_request));
598   } else {
599     // The master's context_view_id will be incremented by one
600     // the UpdateRemoteMaster call later. We want all new workers and
601     // existing workers to also have the updated context_view_id, so
602     // we must set their context_view_id to the existing master's
603     // context_view_id + 1.
604     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
605         ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
606         server_def, remote_eager_workers.get(), context->Executor().Async(),
607         context->LazyCopyFunctionRemoteInputs(), base_request));
608     if (!existing_workers.empty()) {
609       if (VLOG_IS_ON(1)) {
610         for (const string& w : existing_workers) {
611           VLOG(1) << "Updating cluster with existing worker " << w;
612         }
613       }
614       LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
615           ctx, existing_workers, added_workers, removed_workers, context_id,
616           context_view_id + 1, server_def, remote_eager_workers.get(),
617           base_request));
618     }
619   }
620 
621   tensorflow::RemoteRendezvous* r =
622       grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
623   auto session_name = tensorflow::strings::StrCat("eager_", context_id);
624   auto* device_mgr = grpc_server->worker_env()->device_mgr;
625   std::shared_ptr<tensorflow::WorkerSession> worker_session;
626 
627   if (reset_context) {
628     TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
629         session_name, server_def, base_request.cluster_device_attributes(),
630         true));
631     TF_RETURN_IF_ERROR(
632         grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
633             session_name, &worker_session));
634 
635     // Initialize remote tensor communication based on worker session.
636     TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
637 
638     tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
639         tensorflow::eager::CreateClusterFLR(context_id, context,
640                                             worker_session.get());
641     auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
642         /*is_master=*/true, context);
643 
644     LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
645         std::move(new_server), grpc_server->worker_env(), worker_session,
646         std::move(remote_eager_workers), std::move(new_remote_device_mgr),
647         remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
648         std::move(remote_mgr)));
649 
650     // NOTE: We start the server after all other initialization, because the
651     // GrpcServer cannot be destroyed after it is started.
652     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
653   } else {
654     LOG_AND_RETURN_IF_ERROR(
655         grpc_server->worker_env()->session_mgr->UpdateSession(
656             session_name, server_def, base_request.cluster_device_attributes(),
657             true));
658     TF_RETURN_IF_ERROR(
659         grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
660             session_name, &worker_session));
661     tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
662         tensorflow::eager::CreateClusterFLR(context_id, context,
663                                             worker_session.get());
664     LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
665         grpc_server->worker_env(), std::move(remote_eager_workers),
666         added_workers, removed_workers, context_id, r, device_mgr,
667         keep_alive_secs, cluster_flr));
668   }
669 #undef LOG_AND_RETURN_IF_ERROR
670 
671   return tensorflow::Status::OK();
672 }
673 #endif  // !IS_MOBILE_PLATFORM
674 
675 }  // namespace
676 
677 extern "C" {
678 
// Allocates a TFE_ContextOptions with default settings; the caller releases it
// with TFE_DeleteContextOptions.
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
680 
// Installs a serialized ConfigProto (`proto`, `proto_len` bytes) into the
// options' embedded session options. Parse failures are reported via `status`.
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}
685 
// Enables (nonzero) or disables (zero) asynchronous op execution for contexts
// created from these options.
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  options->async = enable;
}
690 
// Records the default device placement policy to be used by contexts created
// from these options.
void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->device_placement_policy = policy;
}
695 
// Destroys options previously created by TFE_NewContextOptions.
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
697 
// Creates a new eager context with a freshly built set of local devices under
// "/job:localhost/replica:0/task:0". Returns nullptr with `status` set if
// device creation fails. The EagerContext takes ownership of the device
// manager (device_mgr_owned=true) and of the rendezvous.
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::StaticDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());

  return new TFE_Context{new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      tensorflow::GetDefaultCustomKernelCreator())};
}
719 
// Creates an eager context that borrows the device manager of an existing
// TF_Session (device_mgr_owned=false), so the session must outlive the
// returned context. Returns nullptr with `status` set on failure.
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
                                       TF_Session* sess, TF_Status* status) {
  const tensorflow::DeviceMgr* device_mgr = nullptr;
  status->status = sess->session->LocalDeviceManager(&device_mgr);
  if (!status->status.ok()) return nullptr;
  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr);

  return new TFE_Context{new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
      opts->async, opts->lazy_remote_inputs_copy, device_mgr,
      /*device_mgr_owned*/ false, r,
      tensorflow::GetDefaultCustomKernelCreator())};
}
737 
// Releases the wrapper's reference on the EagerContext and deletes the
// wrapper itself.
void TFE_DeleteContext(TFE_Context* ctx) {
  // context->RefCountIsOne() should be true here.
  // TODO(iga): Remove EagerContext refcounting.
  ctx->context->Unref();

  delete ctx;
}
745 
// Returns a newly allocated device list for the context; caller frees it with
// TF_DeleteDeviceList. Note: `status` is never written here — this call
// cannot fail.
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* l = new TF_DeviceList;
  ctx->context->ListDevices(&l->response);
  return l;
}
751 
// Drops the context's cached kernels and thread executors.
void TFE_ContextClearCaches(TFE_Context* ctx) {
  ctx->context->ClearCachesAndThreadExecutors();
}
755 
// Set server_def on the context, possibly updating it.
//
// Parses `proto` as a tensorflow.ServerDef, installs any per-worker device
// filters it carries, then (re)initializes the distributed runtime with
// reset_context=true. Unimplemented on mobile builds.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextSetServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  if (server_def.has_cluster_device_filters()) {
    const auto& cdf = server_def.cluster_device_filters();
    for (const auto& jdf : cdf.jobs()) {
      const string& remote_prefix = "/job:" + jdf.name() + "/task:";
      for (const auto& tdf : jdf.tasks()) {
        const int32_t task_index = tdf.first;
        std::vector<string> device_filters(tdf.second.device_filters_size());
        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
          device_filters[i] = tdf.second.device_filters(i);
        }
        // Filters are registered per fully-qualified worker, e.g.
        // "/job:worker/task:0".
        const string remote_worker = remote_prefix + std::to_string(task_index);
        status->status =
            ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
      }
    }
  }
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                  ctx, /*reset_context=*/true);
#endif  // !IS_MOBILE_PLATFORM
}
792 
TFE_ContextUpdateServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)793 TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
794                                                       int keep_alive_secs,
795                                                       const void* proto,
796                                                       size_t proto_len,
797                                                       TF_Status* status) {
798 #if defined(IS_MOBILE_PLATFORM)
799   status->status = tensorflow::errors::Unimplemented(
800       "TFE_ContextSetServerDef not supported on mobile");
801 #else   // !defined(IS_MOBILE_PLATFORM)
802   tensorflow::ServerDef server_def;
803   if (!server_def.ParseFromArray(proto, proto_len)) {
804     status->status = tensorflow::errors::InvalidArgument(
805         "Invalid tensorflow.ServerDef protocol buffer");
806     return;
807   } else if (ctx->context->GetContextId() ==
808              tensorflow::EagerContext::kInvalidContextId) {
809     status->status = tensorflow::errors::InvalidArgument(
810         "Trying to update a context with invalid context id.");
811   }
812   if (server_def.has_cluster_device_filters()) {
813     LOG(WARNING) << "Device filters can only be specified when initializing "
814                     "the cluster. Any changes in device filters are ignored "
815                     "when updating the server def.";
816   }
817   // TODO(haoyuzhang): Check server_def compatibility before the update
818   status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
819                                                   ctx, /*reset_context=*/false);
820 #endif  // !IS_MOBILE_PLATFORM
821 }
822 
TFE_ContextCheckAlive(TFE_Context * ctx,const char * worker_name,TF_Status * status)823 TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
824                                                  const char* worker_name,
825                                                  TF_Status* status) {
826 #if defined(IS_MOBILE_PLATFORM)
827   status->status = tensorflow::errors::Unimplemented(
828       "TFE_ContextSetServerDef not supported on mobile");
829   return false;
830 #else   // !defined(IS_MOBILE_PLATFORM)
831   tensorflow::EagerContext* context = ctx->context;
832   tensorflow::GrpcServer* grpc_server =
833       static_cast<tensorflow::GrpcServer*>(context->GetServer());
834 
835   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
836   status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
837       &remote_eager_workers);
838   if (!status->status.ok()) {
839     LOG(ERROR) << "Failed to get client cache for remote workers.";
840     return false;
841   }
842 
843   // TODO(yuefengz): support partially specified `worker_name`.
844   tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
845   status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
846   if (!status->status.ok()) {
847     return false;
848   }
849 
850   // Send a rpc request to the worker to check aliveness.
851   tensorflow::eager::KeepAliveRequest request;
852   request.set_context_id(context->GetContextId());
853   tensorflow::eager::KeepAliveResponse response;
854 
855   tensorflow::Status keep_alive_status;
856   tensorflow::Notification done;
857   eager_client->KeepAliveAsync(
858       &request, &response,
859       [&keep_alive_status, &done](const tensorflow::Status& s) {
860         keep_alive_status = s;
861         done.Notify();
862       });
863   done.WaitForNotification();
864 
865   status->status = tensorflow::Status::OK();
866 
867   // If `context_id` doesn't exist on the remote worker, an InvalidArgument
868   // error will return. But this still indicates that the remote worker is
869   // alive.
870   if (keep_alive_status.ok() ||
871       keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
872     return true;
873   } else {
874     LOG(INFO) << "Remote worker " << worker_name
875               << " is not alive: " << keep_alive_status.error_message();
876     return false;
877   }
878 #endif  // !IS_MOBILE_PLATFORM
879 }
880 
// Overrides the device placement policy for the calling thread only.
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  ctx->context->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
886 
// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      ctx->context->GetDevicePlacementPolicy());
}
895 
// Wraps a TF_Tensor in a new local tensor handle. Returns nullptr with
// `status` set if the TF_Tensor cannot be converted.
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;
  return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}
902 
// Deletes a tensor handle; safe to call with nullptr.
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
  if (h == nullptr) return;
  // Record handle deletion in the profiler trace.
  tensorflow::profiler::TraceMe activity(
      "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
  delete h;
}
909 
// Releases this interface's reference on the wrapped TensorHandle, if any.
tensorflow::TensorHandleInterface::~TensorHandleInterface() {
  VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
          << handle_;
  if (handle_) {
    handle_->Unref();
  }
}
917 
IsValid(Status * status) const918 bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
919   if (handle_ == nullptr) {
920     *status = tensorflow::errors::InvalidArgument(
921         "The passed in handle is a nullptr");
922     return false;
923   }
924 
925   return true;
926 }
927 
// Returns the element dtype of the handle's tensor.
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return h->handle->DataType();
}
931 
// Exposes the wrapped handle's dtype as the C API enum. Assumes handle_ is
// non-null (no IsValid check here, unlike the other accessors).
TF_DataType tensorflow::TensorHandleInterface::DataType() const {
  return static_cast<TF_DataType>(handle_->dtype);
}
935 
// Returns the rank of the handle's tensor, or -1 with `status` set on error.
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }

  return h->handle->NumDims(&status->status);
}
945 
NumDims(Status * status) const946 int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
947   if (!IsValid(status)) {
948     return -1;
949   }
950 
951   int result;
952   *status = handle_->NumDims(&result);
953   return result;
954 }
955 
// Returns the total element count of the handle's tensor, or -1 with `status`
// set on error.
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }

  return h->handle->NumElements(&status->status);
}
965 
NumElements(Status * status) const966 int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
967   if (!IsValid(status)) {
968     return -1;
969   }
970 
971   tensorflow::int64 result;
972   *status = handle_->NumElements(&result);
973   return result;
974 }
975 
// Returns the size of dimension `dim_index`, or -1 with `status` set on error.
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                            TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }

  return h->handle->Dim(dim_index, &status->status);
}
986 
Dim(int dim_index,Status * status) const987 int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
988                                                Status* status) const {
989   if (!IsValid(status)) {
990     return -1;
991   }
992 
993   tensorflow::int64 result;
994   *status = handle_->Dim(dim_index, &result);
995   return result;
996 }
997 
// Returns the op device name of the handle's tensor, or nullptr with `status`
// set on error.
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  return h->handle->DeviceName(&status->status);
}
1006 
// Returns the name of the device the producing op ran on; a null op_device
// maps to the default local CPU device name.
const char* tensorflow::TensorHandleInterface::DeviceName(
    Status* status) const {
  if (!IsValid(status)) {
    return nullptr;
  }
  tensorflow::Device* d = handle_->op_device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}
1016 
// Returns the name of the device holding the tensor's memory, or nullptr with
// `status` set on error.
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                              TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  return h->handle->BackingDeviceName(&status->status);
}
1026 
// Returns the name of the device backing the tensor's memory; a null device
// maps to the default local CPU device name.
const char* tensorflow::TensorHandleInterface::BackingDeviceName(
    Status* status) const {
  if (!IsValid(status)) {
    return nullptr;
  }
  tensorflow::Device* d = handle_->device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}
1036 
// Returns a new handle sharing the same underlying tensor (ref-counted), or
// nullptr with `status` set on error.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || !h->handle->IsValid(&status->status)) {
    // NOTE(review): when h is non-null but invalid, IsValid already set a
    // status and this overwrites it with the nullptr message — looks like a
    // misleading error text; confirm before changing observable messages.
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }

  return new TFE_TensorHandle{
      std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
}
1048 
// Returns a new interface sharing the same TensorHandle; takes an extra ref
// that the new interface's destructor releases.
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
  handle_->Ref();
  return new TensorHandleInterface(handle_);
}
1053 
// Materializes the handle as a host TF_Tensor, or returns nullptr with
// `status` set on error.
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }

  return h->handle->Resolve(&status->status);
}
1063 
// Materializes the wrapped tensor as a TF_Tensor on the host, copying from a
// remote worker or from a local non-CPU device when needed. Returns nullptr
// with `*status` set on failure.
TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
  if (!IsValid(status)) {
    return nullptr;
  }

  // TODO(agarwal): move this implementation inside TFE_TensorHandle.
  if (handle_->IsRemote()) {
    const tensorflow::Tensor* t = nullptr;
    tensorflow::TensorHandle* h_cpu = nullptr;
    // Pull the remote tensor to the local host CPU first.
    *status = EagerCopyToDevice(handle_, handle_->Context(),
                                &handle_->Context()->Executor(),
                                handle_->Context()->HostCPU(), false, &h_cpu);
    if (!status->ok()) {
      return nullptr;
    }
    *status = h_cpu->Tensor(&t);
    if (!status->ok()) {
      h_cpu->Unref();
      return nullptr;
    }
    TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
    // Release the temporary host copy regardless of conversion outcome.
    h_cpu->Unref();
    return retval;
  } else {
    tensorflow::Tensor tensor;
    if (IsCPU(handle_->device())) {
      const tensorflow::Tensor* src = nullptr;
      *status = handle_->Tensor(&src);
      if (!status->ok()) return nullptr;
      tensor = *src;
    } else {
      // Device-resident tensor: copy to host CPU before conversion.
      tensorflow::EagerContext* ctx = handle_->Context();
      CHECK_NE(ctx, nullptr);
      *status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
      if (!status->ok()) return nullptr;
    }
    return tensorflow::TF_TensorFromTensor(tensor, status);
  }
}
1103 
// Returns a raw pointer to the tensor's device memory, or nullptr with
// `status` set. Only valid for local handles; syncs the device first so the
// producing computation has completed.
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || !h->handle->IsValid(&status->status)) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle();

  if (handle->IsRemote()) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a remote tensor "
        "handle.");
    return nullptr;
  }
  if (handle->device() != nullptr) {
    // Wait for pending device work so the buffer contents are settled.
    status->status = handle->device()->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(
      static_cast<const void*>(tensor->tensor_data().data()));
}
1134 
// Wraps caller-owned device memory in a new tensor handle without copying.
// `deallocator` is always invoked exactly once — immediately on every error
// path, or later when the tensor buffer is released. Only POD dtypes are
// supported. Returns nullptr with `status` set on failure.
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype,
    const int64_t* dims, int num_dims, void* data, size_t len,
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status) {
  tensorflow::Device* device;
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    deallocator(data, len, deallocator_arg);
    return nullptr;
  }
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  // Reject dtypes whose elements cannot be aliased as raw bytes.
  if (dtype == TF_STRING || dtype == TF_RESOURCE ||
      !tensorflow::DataTypeCanUseMemcpy(
          static_cast<tensorflow::DataType>(dtype))) {
    status->status = tensorflow::errors::InvalidArgument(
        "Trying to create a tensor with a pointer to non-pod memory.");
    deallocator(data, len, deallocator_arg);
    return nullptr;
  }
  // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
  // the device?
  TF_ManagedBuffer* buf =
      new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);

  tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
                       tensorflow::TensorShape(dimvec), buf);
  // The Tensor now holds its own reference to the buffer; drop ours.
  buf->Unref();
  tensorflow::TensorHandle* ret_handle;
  status->status = tensorflow::TensorHandle::CreateLocalHandle(
      t, device, context, &ret_handle);
  if (!status->status.ok()) {
    return nullptr;
  }
  return new TFE_TensorHandle{
      std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}
1177 
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
                                        TF_Status* status) {
  if (h == nullptr || !h->handle->IsValid(&status->status)) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return 0;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle();

  if (handle->IsRemote()) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
        "handle.");
    return 0;
  }
  const tensorflow::Tensor* tensor;
  // Tensor() blocks until the producing operation has finished.
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return 0;
  }
  return tensor->TotalBytes();
}
1205 
// Creates a new eager op (or function call) by name. Returns nullptr with
// `status` set when the op/function is unknown.
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  std::unique_ptr<TFE_Op> new_op(
      new TFE_Op{tensorflow::EagerOperation(ctx->context)});
  status->status =
      new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
  if (!status->status.ok()) {
    new_op.reset();
  }
  return new_op.release();
}
1217 
// Destroys an op created by TFE_NewOp (deleting nullptr is a no-op).
void TFE_DeleteOp(TFE_Op* op) { delete op; }
1219 
// Pins the op to the named device; `status` reports an invalid device name.
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  status->status = op->operation.SetDeviceName(device_name);
}
1223 
// Returns the op's device name, falling back to the context's host CPU when
// no device was set. Note: `status` is never written — this cannot fail.
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
  tensorflow::Device* device = (op->operation.Device() == nullptr)
                                   ? op->operation.EagerContext().HostCPU()
                                   : op->operation.Device();
  return device->name().c_str();
}
1230 
// Toggles XLA compilation for this op; a warning is logged when the binary
// was built without XLA support (the flag is then ineffective).
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
  op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}
1238 
// Appends one input tensor to the op and infers single-input attrs (e.g. the
// input's dtype) when possible; `status` reports inference failures.
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  tensorflow::TensorHandle* h =
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
          input->handle.get())
          ->Handle();
  op->operation.AddInput(h);
  status->status = op->operation.MaybeInferSingleInputAttrs(h);
}
1247 
// Appends `num_inputs` tensors as a single input list and infers the
// corresponding list attrs; `status` reports inference failures.
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                        TF_Status* status) {
  for (int i = 0; i < num_inputs; ++i) {
    op->operation.AddInput(
        tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
            inputs[i]->handle.get())
            ->Handle());
  }
  status->status = op->operation.InferInputListAttrs(num_inputs);
}
1258 
// Looks up the declared type of attr `attr_name` on this op, setting
// `*is_list` accordingly. Returns TF_ATTR_INT as a dummy value on failure
// (with `status` set).
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret = TF_ATTR_INT;
  status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
                                              attr_name, &ret, is_list);
  return ret;
}
1266 
// Convenience wrapper: builds a temporary op for `op_or_function_name` to
// query an attr's type, then deletes it (TFE_DeleteOp(nullptr) is safe when
// op creation failed).
TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TF_AttrType ret;
  TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
  if (status->status.ok()) {
    ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
  } else {
    ret = TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
  }
  TFE_DeleteOp(op);
  return ret;
}
1281 
// Sets a string attribute from a (possibly non-NUL-terminated) byte buffer.
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                         size_t length) {
  op->operation.MutableAttrs()->Set(
      attr_name,
      tensorflow::StringPiece(static_cast<const char*>(value), length));
}
1288 
// Sets an integer attribute.
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
}
1292 
// Sets a float attribute.
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->operation.MutableAttrs()->Set(attr_name, value);
}
1296 
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)1297 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
1298   op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
1299 }
1300 
// Sets a dtype attribute.
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->operation.MutableAttrs()->Set(attr_name,
                                    static_cast<tensorflow::DataType>(value));
}
1305 
// Sets a shape attribute. A negative `num_dims` means unknown rank; each
// dims[d] may be -1 for an unknown dimension. Rejects ranks above
// TensorShape::MaxDimensions() via `out_status`.
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->operation.MutableAttrs()->Set(attr_name, proto);
}
1327 
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)1328 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
1329                            const TFE_Op* value) {
1330   tensorflow::AttrValue attr_value;
1331   tensorflow::NameAttrList* func = attr_value.mutable_func();
1332   func->set_name(value->operation.Name());
1333   value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
1334   op->operation.MutableAttrs()->Set(attr_name, attr_value);
1335 }
1336 
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)1337 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
1338                                const char* data, size_t length) {
1339   tensorflow::AttrValue attr_value;
1340   tensorflow::NameAttrList* func = attr_value.mutable_func();
1341   func->set_name(data, length);
1342   op->operation.MutableAttrs()->Set(attr_name, attr_value);
1343 }
1344 
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)1345 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
1346                          TF_Status* status) {
1347   tensorflow::Tensor t;
1348   status->status = TF_TensorToTensor(tensor, &t);
1349   if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
1350 }
1351 
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)1352 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
1353                              const void* const* values, const size_t* lengths,
1354                              int num_values) {
1355   std::vector<tensorflow::StringPiece> v(num_values);
1356   for (int i = 0; i < num_values; ++i) {
1357     v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
1358                                    lengths[i]);
1359   }
1360   op->operation.MutableAttrs()->Set(attr_name, v);
1361 }
1362 
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)1363 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
1364                             const float* values, int num_values) {
1365   op->operation.MutableAttrs()->Set(
1366       attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
1367 }
1368 
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)1369 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
1370                           const int64_t* values, int num_values) {
1371   op->operation.MutableAttrs()->Set(
1372       attr_name, tensorflow::gtl::ArraySlice<const int64>(
1373                      reinterpret_cast<const int64*>(values), num_values));
1374 }
1375 
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)1376 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
1377                            const TF_DataType* values, int num_values) {
1378   op->operation.MutableAttrs()->Set(
1379       attr_name,
1380       tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
1381           reinterpret_cast<const tensorflow::DataType*>(values), num_values));
1382 }
1383 
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)1384 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
1385                            const unsigned char* values, int num_values) {
1386   std::unique_ptr<bool[]> b(new bool[num_values]);
1387   for (int i = 0; i < num_values; ++i) {
1388     b[i] = values[i];
1389   }
1390   op->operation.MutableAttrs()->Set(
1391       attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
1392 }
1393 
// Sets a list(shape) attr on `op`. `dims[i]` / `num_dims[i]` describe the
// i-th shape; a negative `num_dims[i]` denotes an unknown-rank shape. If any
// rank exceeds TensorShape::MaxDimensions(), `out_status` is set and the
// attribute is left unset.
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    // Rank limit check mirrors TFE_OpSetAttrShape above.
    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      // Negative rank encodes "unknown rank" in the C API.
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  // Set() copies the protos; the temporary array dies with this scope.
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                     proto.get(), num_values));
}
1425 
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)1426 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
1427                                const TFE_Op** value, int num_values) {
1428   std::unique_ptr<tensorflow::NameAttrList[]> funcs(
1429       new tensorflow::NameAttrList[num_values]);
1430   for (int i = 0; i < num_values; i++) {
1431     funcs[i].set_name(value[i]->operation.Name());
1432     value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
1433   }
1434   op->operation.MutableAttrs()->Set(
1435       attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
1436                      funcs.get(), num_values));
1437 }
1438 
TFE_OpGetInputLength(TFE_Op * op,const char * input_name,TF_Status * status)1439 TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
1440                                                const char* input_name,
1441                                                TF_Status* status) {
1442   const tensorflow::OpDef* op_def = GetOpDef(op, status);
1443   if (!status->status.ok()) {
1444     return -1;
1445   }
1446   tensorflow::AttrValueMap attrs;
1447   op->operation.Attrs().FillAttrValueMap(&attrs);
1448   tensorflow::NameRangeMap name_ranges;
1449   status->status = tensorflow::NameRangesForNode(
1450       tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
1451   if (!status->status.ok()) {
1452     return -1;
1453   }
1454   auto iter = name_ranges.find(input_name);
1455   if (iter == name_ranges.end()) {
1456     status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
1457                                                          "' not found");
1458     return -1;
1459   }
1460   return iter->second.second - iter->second.first;
1461 }
1462 
TFE_OpGetOutputLength(TFE_Op * op,const char * output_name,TF_Status * status)1463 TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
1464                                                 const char* output_name,
1465                                                 TF_Status* status) {
1466   const tensorflow::OpDef* op_def = GetOpDef(op, status);
1467   if (!status->status.ok()) {
1468     return -1;
1469   }
1470   tensorflow::AttrValueMap attrs;
1471   op->operation.Attrs().FillAttrValueMap(&attrs);
1472   tensorflow::NameRangeMap name_ranges;
1473   status->status = tensorflow::NameRangesForNode(
1474       tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
1475   if (!status->status.ok()) {
1476     return -1;
1477   }
1478   auto iter = name_ranges.find(output_name);
1479   if (iter == name_ranges.end()) {
1480     status->status = tensorflow::errors::InvalidArgument(
1481         "Output '", output_name, "' not found");
1482     return -1;
1483   }
1484   return iter->second.second - iter->second.first;
1485 }
1486 
// Executes `op` eagerly. On entry `*num_retvals` is the caller-provided
// capacity of `retvals`; on success it holds the number of outputs actually
// produced, and `retvals[0..*num_retvals)` are freshly allocated handles the
// caller must delete.
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
  VLOG(1) << "Calling TFE_Execute() on op " << op;
  status->status = tensorflow::EagerExecute(&op->operation,
                                            handle_retvals.data(), num_retvals);
  if (!status->status.ok()) {
    return;
  }
  // Wrap each raw TensorHandle in the C API handle type. NOTE(review): the
  // TensorHandleInterface ctor presumably assumes ownership of the raw
  // handle's reference — confirm against its definition.
  for (int i = 0; i < *num_retvals; ++i) {
    retvals[i] = new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
  }
}
1501 
// Copies the tensor behind `h` to the device named `device_name` in `ctx`.
// Returns a new handle on success, or nullptr with `status` set on failure.
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::TensorHandle* handle = nullptr;
  tensorflow::Device* device;
  tensorflow::EagerContext* context = ctx->context;
  // Resolve the target device first; fail fast on an unknown name.
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    return nullptr;
  }
  // NOTE(review): the literal `false` argument presumably disables mirroring
  // of the source handle — confirm against EagerCopyToDevice's signature.
  status->status = tensorflow::EagerCopyToDevice(
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle(),
      context, &context->Executor(), device, false, &handle);
  if (status->status.ok()) {
    return new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(handle)};
  }
  return nullptr;
}
1523 
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)1524 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
1525                                const char* serialized_function_def, size_t size,
1526                                TF_Status* status) {
1527   tensorflow::FunctionDef function_def;
1528   if (!function_def.ParseFromArray(serialized_function_def, size)) {
1529     status->status =
1530         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
1531     return;
1532   }
1533   status->status = ctx->context->AddFunctionDef(function_def);
1534 }
1535 
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)1536 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
1537                             TF_Status* status) {
1538   status->status = ctx->context->AddFunctionDef(function->fdef);
1539 }
1540 
TFE_ContextRemoveFunction(TFE_Context * ctx,const char * name,TF_Status * status)1541 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
1542                                TF_Status* status) {
1543   status->status = ctx->context->RemoveFunction(name);
1544 }
1545 
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)1546 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
1547   return ctx->context->FindFunctionDef(name) != nullptr;
1548 }
1549 
TFE_ContextEnableRunMetadata(TFE_Context * ctx)1550 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
1551   ctx->context->SetShouldStoreGraphs(true);
1552 }
1553 
TFE_ContextDisableRunMetadata(TFE_Context * ctx)1554 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
1555   ctx->context->SetShouldStoreGraphs(false);
1556 }
1557 
1558 }  // extern "C"
1559 
TFE_NewTensorHandle(const tensorflow::Tensor & t,TF_Status * status)1560 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
1561                                       TF_Status* status) {
1562   return TFE_TensorHandle::CreateLocalHandle(t, status);
1563 }
1564 
// Serializes the context's accumulated RunMetadata into `buf`, then clears
// the stored metadata. Waits for all pending eager nodes first so the export
// reflects every op issued so far.
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->Executor().WaitForAllPendingNodes();
  if (!status->status.ok()) return;
  // Hold MetadataMu across both the serialization and the clear so no
  // concurrently-recorded metadata is dropped between the two steps.
  tensorflow::mutex_lock ml(*context->MetadataMu());
  status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
  context->ClearRunMetadata();
}
1574 
1575 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)1576 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
1577                 TF_Status* status) {
1578   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
1579   for (const auto& attr : func.attr()) {
1580     if (!status->status.ok()) return nullptr;
1581     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
1582     if (!status->status.ok()) return nullptr;
1583   }
1584   return func_op;
1585 }
1586 }  // namespace
1587 
TFE_ContextStartStep(TFE_Context * ctx)1588 void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
1589 
TFE_ContextEndStep(TFE_Context * ctx)1590 void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
1591 
1592 namespace tensorflow {
SetOpAttrValueScalar(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,TF_Status * status)1593 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
1594                           const tensorflow::AttrValue& default_value,
1595                           const char* attr_name, TF_Status* status) {
1596   switch (default_value.value_case()) {
1597     case tensorflow::AttrValue::kS: {
1598       const string& v = default_value.s();
1599       TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
1600       break;
1601     }
1602     case tensorflow::AttrValue::kI:
1603       TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
1604       break;
1605     case tensorflow::AttrValue::kF:
1606       TFE_OpSetAttrFloat(op, attr_name, default_value.f());
1607       break;
1608     case tensorflow::AttrValue::kB:
1609       TFE_OpSetAttrBool(op, attr_name, default_value.b());
1610       break;
1611     case tensorflow::AttrValue::kType:
1612       TFE_OpSetAttrType(op, attr_name,
1613                         static_cast<TF_DataType>(default_value.type()));
1614       break;
1615     case tensorflow::AttrValue::kShape: {
1616       const auto& tensor_shape = default_value.shape();
1617       if (tensor_shape.unknown_rank()) {
1618         TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
1619       } else {
1620         const auto num_dims = tensor_shape.dim_size();
1621         std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
1622         for (int i = 0; i < num_dims; ++i) {
1623           dims[i] = tensor_shape.dim(i).size();
1624         }
1625         TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
1626       }
1627     } break;
1628     case tensorflow::AttrValue::kFunc: {
1629       const auto func_op = GetFunc(ctx, default_value.func(), status);
1630       if (!status->status.ok()) return;
1631       // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
1632       // require TFE_Op* and just convert it internally a NameAttrValue, so
1633       // consider adding an overload to the C API to make this case easier.
1634       TFE_OpSetAttrFunction(op, attr_name, func_op);
1635     } break;
1636     case tensorflow::AttrValue::kList:
1637       TF_FALLTHROUGH_INTENDED;
1638     case tensorflow::AttrValue::kTensor:
1639       TF_FALLTHROUGH_INTENDED;
1640     case tensorflow::AttrValue::kPlaceholder:
1641       TF_FALLTHROUGH_INTENDED;
1642     case tensorflow::AttrValue::VALUE_NOT_SET:
1643       TF_SetStatus(
1644           status, TF_UNIMPLEMENTED,
1645           tensorflow::strings::StrCat("Unable to get setfor default value: ",
1646                                       default_value.DebugString())
1647               .data());
1648   }
1649 }
1650 }  // namespace tensorflow
1651