/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"  // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif  // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#endif  // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"

#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {

const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
  const tensorflow::OpDef* op_def = op->operation.OpDef();
  if (op_def) return op_def;
  status->status =
      tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
  return op_def;
}

bool IsCPU(const tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

string DeviceName(const tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}

#if !defined(IS_MOBILE_PLATFORM)
tensorflow::Status AddRemoteDevicesToMgr(
    const std::vector<string>& added_remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  tensorflow::mutex remote_devices_mu;
  int num_added_workers = added_remote_workers.size();
  tensorflow::BlockingCounter counter(num_added_workers);
  std::vector<tensorflow::Status> statuses(num_added_workers);
  for (int i = 0; i < num_added_workers; i++) {
    tensorflow::NewRemoteDevices(
        tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
        [i, &statuses, &counter, &remote_devices, &remote_devices_mu](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
          statuses[i] = s;
          if (s.ok()) {
            tensorflow::mutex_lock l(remote_devices_mu);
            for (tensorflow::Device* d : *devices) {
              remote_devices.emplace_back(d);
            }
          }
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < num_added_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }

  TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
  return tensorflow::Status::OK();
}

tensorflow::Status GetAllRemoteDevices(
    const std::vector<string>& remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
  auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
  TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
                                           remote_device_mgr.get()));
  *device_mgr = std::move(remote_device_mgr);
  return tensorflow::Status::OK();
}

tensorflow::Status RemoveRemoteDevicesFromMgr(
    const std::vector<string>& removed_remote_workers,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  const std::vector<tensorflow::Device*> remote_devices =
      (remote_device_mgr->ListDevices());
  std::vector<tensorflow::Device*> devices_to_remove;
  for (tensorflow::Device* d : remote_devices) {
    for (const string& remote_worker : removed_remote_workers) {
      if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
                                                          d->name())) {
        devices_to_remove.emplace_back(d);
        break;
      }
    }
  }
  TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
  return tensorflow::Status::OK();
}

tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
                                     const string& local_worker,
                                     std::vector<string>* remote_workers) {
  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(server);
  if (grpc_server == nullptr) {
    return tensorflow::errors::Internal(
        "Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
  }
  grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
  remote_workers->erase(
      std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
      remote_workers->end());
  return tensorflow::Status::OK();
}

void DifferentiateWorkerLists(const std::vector<string>* current_list,
                              const std::vector<string>* new_list,
                              std::vector<string>* added,
                              std::vector<string>* removed,
                              std::vector<string>* existing) {
  // Computes the STL set_difference and set_intersection in one list
  // traversal. As with the std::set_difference library function, the input
  // lists (`current_list` and `new_list`) must be sorted before calling this
  // function.
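  //
  // For example, with sorted inputs
  //   current_list = {"/job:w/task:0", "/job:w/task:1"}
  //   new_list     = {"/job:w/task:1", "/job:w/task:2"}
  // the outputs are added = {"/job:w/task:2"}, removed = {"/job:w/task:0"},
  // and existing = {"/job:w/task:1"}.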
  added->resize(new_list->size());
  removed->resize(current_list->size());
  existing->resize(current_list->size());
  std::vector<string>::const_iterator curr_it = current_list->begin();
  std::vector<string>::const_iterator new_it = new_list->begin();
  std::vector<string>::iterator added_it = added->begin();
  std::vector<string>::iterator removed_it = removed->begin();
  std::vector<string>::iterator existing_it = existing->begin();
  while (curr_it != current_list->end() && new_it != new_list->end()) {
    if (*curr_it < *new_it) {
      *removed_it++ = *curr_it++;
    } else if (*curr_it > *new_it) {
      *added_it++ = *new_it++;
    } else {
      *existing_it++ = *curr_it++;
      new_it++;
    }
  }
  removed_it = std::copy(curr_it, current_list->end(), removed_it);
  added_it = std::copy(new_it, new_list->end(), added_it);
  added->resize(added_it - added->begin());
  removed->resize(removed_it - removed->begin());
  existing->resize(existing_it - existing->begin());
}

tensorflow::Status GetReplacedFromExistingWorkers(
    const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* client_cache,
    std::vector<string>* replaced_workers) {
  tensorflow::BlockingCounter counter(existing_workers->size());
  std::vector<tensorflow::Status> statuses(existing_workers->size());
  tensorflow::eager::KeepAliveRequest request;
  request.set_context_id(context_id);
  std::vector<tensorflow::eager::KeepAliveResponse> responses(
      existing_workers->size());
  for (int i = 0; i < existing_workers->size(); i++) {
    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] =
        client_cache->GetClient(existing_workers->at(i), &eager_client);
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }
    eager_client->KeepAliveAsync(
        &request, &responses[i],
        [i, &statuses, &counter](const tensorflow::Status& s) {
          statuses[i] = s;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < existing_workers->size(); i++) {
    // If the RPC fails (indicating that the requested ID doesn't exist on the
    // remote), or the returned view ID is not equal to the local one
    // (indicating that the remote worker has a stale view of the cluster),
    // treat the worker as replaced.
    if (!statuses[i].ok() ||
        responses[i].context_view_id() != context_view_id) {
      replaced_workers->emplace_back(existing_workers->at(i));
    }
  }
  return tensorflow::Status::OK();
}
tensorflow::Status CreateRemoteContexts(
    TFE_Context* ctx, const std::vector<string>& remote_workers,
    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    const bool lazy_copy_remote_function_inputs,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);
  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      statuses[i] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
      counter.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
          "Cannot find a client for the given target: ", remote_worker);
    }
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }

    tensorflow::eager::CreateContextRequest request;
    tensorflow::eager::CreateContextResponse* response =
        new tensorflow::eager::CreateContextResponse();
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
        server_def.default_session_config());

    std::vector<bool> filtered_device_mask;
    ctx->context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(),
              base_request.cluster_device_attributes_size());
    for (int j = 0; j < filtered_device_mask.size(); j++) {
      if (filtered_device_mask[j]) {
        const auto& da = base_request.cluster_device_attributes(j);
        *request.add_cluster_device_attributes() = da;
      }
    }
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    request.set_lazy_copy_remote_function_inputs(
        lazy_copy_remote_function_inputs);

    eager_client->CreateContextAsync(
        &request, response,
        [i, &statuses, &counter, response](const tensorflow::Status& s) {
          statuses[i] = s;
          delete response;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < num_remote_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
  return tensorflow::Status::OK();
}

tensorflow::Status UpdateRemoteContexts(
    TFE_Context* ctx, const std::vector<string>& remote_workers,
    const std::vector<string>& added_workers,
    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);

  int cluster_device_count = base_request.cluster_device_attributes_size();
  std::unordered_set<string> added_or_removed(added_workers.begin(),
                                              added_workers.end());
  std::copy(removed_workers.begin(), removed_workers.end(),
            std::inserter(added_or_removed, added_or_removed.end()));
  // Whether each device is in the updated (added or removed) workers.
  std::vector<bool> device_added_or_removed(cluster_device_count);
  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
    const auto& da = base_request.cluster_device_attributes().at(i);
    tensorflow::DeviceNameUtils::ParsedName pn;
    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
    string task_name;
    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
    if (added_or_removed.find(task_name) != added_or_removed.end()) {
      device_added_or_removed[i] = true;
    }
  }

  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      statuses[i] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
      counter.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
          "Cannot find a client for the given target: ", remote_worker);
    }
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }

    std::vector<bool> filtered_device_mask;
    ctx->context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);

    // If any of the devices that match the device filters are in the set of
    // added or removed workers, we must send a complete UpdateContextRequest.
    // Otherwise, only send a simple request to increment the context view ID.
    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
    std::transform(device_added_or_removed.begin(),
                   device_added_or_removed.end(), filtered_device_mask.begin(),
                   added_or_removed_filtered_devices.begin(),
                   std::logical_and<bool>());
    const bool full_update_request =
        std::accumulate(added_or_removed_filtered_devices.begin(),
                        added_or_removed_filtered_devices.end(), false,
                        std::logical_or<bool>());

    tensorflow::eager::UpdateContextRequest request;
    auto* response = new tensorflow::eager::UpdateContextResponse();
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    if (full_update_request) {
      *request.mutable_server_def() = server_def;
      request.mutable_server_def()->set_job_name(parsed_name.job);
      request.mutable_server_def()->set_task_index(parsed_name.task);
      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
          server_def.default_session_config());
      for (int j = 0; j < cluster_device_count; j++) {
        if (filtered_device_mask[j]) {
          const auto& da = base_request.cluster_device_attributes(j);
          *request.add_cluster_device_attributes() = da;
        }
      }
    }

    eager_client->UpdateContextAsync(
        &request, response,
        [i, &statuses, &counter, response](const tensorflow::Status& s) {
          statuses[i] = s;
          delete response;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < num_remote_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
  return tensorflow::Status::OK();
}
tensorflow::Status UpdateTFE_ContextWithServerDef(
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    TFE_Context* ctx, bool reset_context) {
  // We don't use the TF_RETURN_IF_ERROR macro directly here, since that would
  // destroy the server object (which currently CHECK-fails) and we would miss
  // the error. Instead, we log the error and then return it, to allow the
  // user to see the error message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0);

  string worker_name =
      tensorflow::strings::StrCat("/job:", server_def.job_name(),
                                  "/replica:0/task:", server_def.task_index());

  // List of current remote workers before updating server_def. Unused if
  // resetting the server_def.
  std::vector<string> curr_remote_workers;
  // List of updated remote workers.
  std::vector<string> remote_workers;

  // New server created for the new server_def. Unused if updating server_def.
  std::unique_ptr<tensorflow::ServerInterface> new_server;
  tensorflow::EagerContext* context = ctx->context;
  tensorflow::GrpcServer* grpc_server;
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
  } else {
    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
                                              &curr_remote_workers));
    // No need to check the cast here, since `ListRemoteWorkers` already checks
    // whether the server is a GrpcServer.
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
  }

  tensorflow::uint64 context_id = context->GetContextId();
  tensorflow::uint64 context_view_id = context->GetContextViewId();
  if (reset_context) {
    context_id = tensorflow::EagerContext::NewContextId();
    context_view_id = 0;
    // Make the master eager context accessible by the local eager service,
    // which might receive send-tensor requests from remote workers.
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->AddMasterEagerContextToEagerService(context_id, context));
  }

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  LOG_AND_RETURN_IF_ERROR(
      grpc_server->master_env()->worker_cache->GetEagerClientCache(
          &remote_eager_workers));

  // When updating an existing context, populate the following lists with:
  // * added_workers: set(remote_workers) - set(curr_remote_workers)
  // * removed_workers: set(curr_remote_workers) - set(remote_workers)
  // * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
  // * replaced_workers: workers with the same task names and potentially the
  //   same `hostname:port`s, but replaced by different processes
  std::vector<string> added_workers;
  std::vector<string> removed_workers;
  std::vector<string> existing_workers;
  std::vector<string> replaced_workers;

  // New remote device manager created for the new server_def. Unused if
  // updating server_def.
  std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
  tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
        remote_workers, grpc_server->master_env()->worker_cache,
        &new_remote_device_mgr));
    remote_device_mgr = new_remote_device_mgr.get();
  } else {
    context->ClearCachesAndDefaultExecutor();
    // TODO(b/143914772): Potential memory leak if rendezvous has pending
    // tensors for removed / replaced workers.

    remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
    if (remote_device_mgr == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
          "Updating context with an invalid set of remote devices."));
    }
    std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
    std::sort(remote_workers.begin(), remote_workers.end());
    DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
                             &added_workers, &removed_workers,
                             &existing_workers);
    LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
        &existing_workers, context_id, context->GetContextViewId(), server_def,
        remote_eager_workers.get(), &replaced_workers));
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Updating cluster with the following changes";
      for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
      for (const string& w : removed_workers)
        VLOG(1) << "  Removed worker " << w;
      for (const string& w : replaced_workers)
        VLOG(1) << "  Replaced worker " << w;
    }
    if (!replaced_workers.empty()) {
      // Treat replaced workers as removed and then added back, so that we
      // recreate remote devices and contexts, and re-register functions on
      // those workers.
      removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
                             replaced_workers.end());
      added_workers.insert(added_workers.end(), replaced_workers.begin(),
                           replaced_workers.end());
      for (const string& w : replaced_workers) {
        existing_workers.erase(
            std::remove(existing_workers.begin(), existing_workers.end(), w),
            existing_workers.end());
      }
    }
    LOG_AND_RETURN_IF_ERROR(
        RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
    LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
        added_workers, grpc_server->master_env()->worker_cache,
        remote_device_mgr));
  }

  std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
  remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);

  std::vector<tensorflow::DeviceAttributes> local_device_attributes;
  grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
      &local_device_attributes);

  // This request makes sure that we can create the Rendezvous properly
  // between the local and remote contexts.
  tensorflow::eager::CreateContextRequest base_request;
  for (const auto& da : cluster_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }
  for (const auto& da : local_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }

  // Initialize remote eager workers.
  // TODO(b/138847548) Create remote eager contexts in async mode by default.
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
        ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
  } else {
    // The master's context_view_id will be incremented by one in the
    // UpdateRemoteMaster call later. We want all new workers and existing
    // workers to also have the updated context_view_id, so we must set their
    // context_view_id to the existing master's context_view_id + 1.
    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
        ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
    if (!existing_workers.empty()) {
      if (VLOG_IS_ON(1)) {
        for (const string& w : existing_workers) {
          VLOG(1) << "Updating cluster with existing worker " << w;
        }
      }
      LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
          ctx, existing_workers, added_workers, removed_workers, context_id,
          context_view_id + 1, server_def, remote_eager_workers.get(),
          base_request));
    }
  }

  tensorflow::RemoteRendezvous* r =
      grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
  auto session_name = tensorflow::strings::StrCat("eager_", context_id);
  auto* device_mgr = grpc_server->worker_env()->device_mgr;
  std::shared_ptr<tensorflow::WorkerSession> worker_session;

  if (reset_context) {
    TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
        session_name, server_def, base_request.cluster_device_attributes(),
        true));
    TF_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
            session_name, &worker_session));

    // Initialize remote tensor communication based on worker session.
    TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
        tensorflow::eager::CreateClusterFLR(context_id, context,
                                            worker_session.get());
    auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
        /*is_master=*/true, context);

    LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
        std::move(new_server), grpc_server->worker_env(), worker_session,
        std::move(remote_eager_workers), std::move(new_remote_device_mgr),
        remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
        std::move(remote_mgr)));

    // NOTE: We start the server after all other initialization, because the
    // GrpcServer cannot be destroyed after it is started.
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
  } else {
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->UpdateSession(
            session_name, server_def, base_request.cluster_device_attributes(),
            true));
    TF_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
            session_name, &worker_session));
    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
        tensorflow::eager::CreateClusterFLR(context_id, context,
                                            worker_session.get());
    LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster(
        grpc_server->worker_env(), std::move(remote_eager_workers),
        added_workers, removed_workers, context_id, r, device_mgr,
        keep_alive_secs, cluster_flr));
  }
#undef LOG_AND_RETURN_IF_ERROR

  return tensorflow::Status::OK();
}
#endif  // !IS_MOBILE_PLATFORM

}  // namespace

extern "C" {

TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }

void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}

void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  options->async = enable;
}

void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->device_placement_policy = policy;
}

void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
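
// A minimal usage sketch of the context lifecycle exposed below (illustrative
// only; error handling elided):
//   TF_Status* s = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, s);
//   TFE_DeleteContextOptions(opts);
//   ...
//   TFE_DeleteContext(ctx);
//   TF_DeleteStatus(s);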
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::StaticDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());

  return new TFE_Context{new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
      opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      tensorflow::GetDefaultCustomKernelCreator())};
}

TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
                                       TF_Session* sess, TF_Status* status) {
  const tensorflow::DeviceMgr* device_mgr = nullptr;
  status->status = sess->session->LocalDeviceManager(&device_mgr);
  if (!status->status.ok()) return nullptr;
  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr);

  return new TFE_Context{new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
      opts->async, opts->lazy_remote_inputs_copy, device_mgr,
      /*device_mgr_owned*/ false, r,
      tensorflow::GetDefaultCustomKernelCreator())};
}

void TFE_DeleteContext(TFE_Context* ctx) {
  // context->RefCountIsOne() should be true here.
  // TODO(iga): Remove EagerContext refcounting.
  ctx->context->Unref();

  delete ctx;
}

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* l = new TF_DeviceList;
  ctx->context->ListDevices(&l->response);
  return l;
}

void TFE_ContextClearCaches(TFE_Context* ctx) {
  ctx->context->ClearCachesAndThreadExecutors();
}

// Set server_def on the context, possibly updating it.
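//
// A minimal usage sketch (illustrative only; assumes `buf`/`buf_len` hold a
// serialized tensorflow.ServerDef, and elides error handling):
//   TF_Status* s = TF_NewStatus();
//   TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/600, buf, buf_len, s);
//   if (TF_GetCode(s) != TF_OK) { /* handle error */ }
//   TF_DeleteStatus(s);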
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextSetServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  if (server_def.has_cluster_device_filters()) {
    const auto& cdf = server_def.cluster_device_filters();
    for (const auto& jdf : cdf.jobs()) {
      const string& remote_prefix = "/job:" + jdf.name() + "/task:";
      for (const auto& tdf : jdf.tasks()) {
        const int32_t task_index = tdf.first;
        std::vector<string> device_filters(tdf.second.device_filters_size());
        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
          device_filters[i] = tdf.second.device_filters(i);
        }
        const string remote_worker = remote_prefix + std::to_string(task_index);
        status->status =
            ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
      }
    }
  }
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                  ctx, /*reset_context=*/true);
#endif  // !IS_MOBILE_PLATFORM
}

TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
                                                      int keep_alive_secs,
                                                      const void* proto,
                                                      size_t proto_len,
                                                      TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextUpdateServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  } else if (ctx->context->GetContextId() ==
             tensorflow::EagerContext::kInvalidContextId) {
    status->status = tensorflow::errors::InvalidArgument(
        "Trying to update a context with an invalid context id.");
  }
  if (server_def.has_cluster_device_filters()) {
    LOG(WARNING) << "Device filters can only be specified when initializing "
                    "the cluster. Any changes in device filters are ignored "
                    "when updating the server def.";
  }
  // TODO(haoyuzhang): Check server_def compatibility before the update.
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                  ctx, /*reset_context=*/false);
#endif  // !IS_MOBILE_PLATFORM
}

TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
                                                 const char* worker_name,
                                                 TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextCheckAlive not supported on mobile");
  return false;
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::EagerContext* context = ctx->context;
  tensorflow::GrpcServer* grpc_server =
      static_cast<tensorflow::GrpcServer*>(context->GetServer());

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
      &remote_eager_workers);
  if (!status->status.ok()) {
    LOG(ERROR) << "Failed to get client cache for remote workers.";
    return false;
  }

  // TODO(yuefengz): support partially specified `worker_name`.
  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
  status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
  if (!status->status.ok()) {
    return false;
  }

  // Send an RPC request to the worker to check its aliveness.
  tensorflow::eager::KeepAliveRequest request;
  request.set_context_id(context->GetContextId());
  tensorflow::eager::KeepAliveResponse response;

  tensorflow::Status keep_alive_status;
  tensorflow::Notification done;
  eager_client->KeepAliveAsync(
      &request, &response,
      [&keep_alive_status, &done](const tensorflow::Status& s) {
        keep_alive_status = s;
        done.Notify();
      });
  done.WaitForNotification();

  status->status = tensorflow::Status::OK();

  // If `context_id` doesn't exist on the remote worker, an InvalidArgument
  // error will be returned. But this still indicates that the remote worker
  // is alive.
  if (keep_alive_status.ok() ||
      keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
    return true;
  } else {
    LOG(INFO) << "Remote worker " << worker_name
              << " is not alive: " << keep_alive_status.error_message();
    return false;
  }
#endif  // !IS_MOBILE_PLATFORM
}

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  ctx->context->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}

// Note: this function looks up a thread-local policy, so it should be called
// from the appropriate client thread. In particular, in async mode, it may
// not be safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      ctx->context->GetDevicePlacementPolicy());
}

TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;
  return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}

void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
  if (h == nullptr) return;
  tensorflow::profiler::TraceMe activity(
      "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
  delete h;
}

tensorflow::TensorHandleInterface::~TensorHandleInterface() {
  VLOG(1) << "Deleting tensor handle " << this << " with internal handle "
          << handle_;
  if (handle_) {
    handle_->Unref();
  }
}

bool tensorflow::TensorHandleInterface::IsValid(Status* status) const {
  if (handle_ == nullptr) {
    *status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return false;
  }

  return true;
}

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return h->handle->DataType();
}

TF_DataType tensorflow::TensorHandleInterface::DataType() const {
  return static_cast<TF_DataType>(handle_->dtype);
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }

  return h->handle->NumDims(&status->status);
}

int tensorflow::TensorHandleInterface::NumDims(Status* status) const {
  if (!IsValid(status)) {
    return -1;
  }

  int result;
  *status = handle_->NumDims(&result);
  return result;
}

int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }

  return h->handle->NumElements(&status->status);
}

int64_t tensorflow::TensorHandleInterface::NumElements(Status* status) const {
  if (!IsValid(status)) {
    return -1;
  }

  tensorflow::int64 result;
  *status = handle_->NumElements(&result);
  return result;
}

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                            TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }

  return h->handle->Dim(dim_index, &status->status);
}

int64_t tensorflow::TensorHandleInterface::Dim(int dim_index,
                                               Status* status) const {
  if (!IsValid(status)) {
    return -1;
  }

  tensorflow::int64 result;
  *status = handle_->Dim(dim_index, &result);
  return result;
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  return h->handle->DeviceName(&status->status);
}

const char* tensorflow::TensorHandleInterface::DeviceName(
    Status* status) const {
  if (!IsValid(status)) {
    return nullptr;
  }
  tensorflow::Device* d = handle_->op_device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                              TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  return h->handle->BackingDeviceName(&status->status);
}

const char* tensorflow::TensorHandleInterface::BackingDeviceName(
    Status* status) const {
  if (!IsValid(status)) {
    return nullptr;
  }
  tensorflow::Device* d = handle_->device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || !h->handle->IsValid(&status->status)) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }

  return new TFE_TensorHandle{
      std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
}

AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
  handle_->Ref();
  return new TensorHandleInterface(handle_);
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }

  return h->handle->Resolve(&status->status);
}

TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
  if (!IsValid(status)) {
    return nullptr;
  }

  // TODO(agarwal): move this implementation inside TFE_TensorHandle.
  if (handle_->IsRemote()) {
    const tensorflow::Tensor* t = nullptr;
    tensorflow::TensorHandle* h_cpu = nullptr;
    *status = EagerCopyToDevice(handle_, handle_->Context(),
                                &handle_->Context()->Executor(),
                                handle_->Context()->HostCPU(), false, &h_cpu);
    if (!status->ok()) {
      return nullptr;
    }
    *status = h_cpu->Tensor(&t);
    if (!status->ok()) {
      h_cpu->Unref();
      return nullptr;
    }
    TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
    h_cpu->Unref();
    return retval;
  } else {
    tensorflow::Tensor tensor;
    if (IsCPU(handle_->device())) {
      const tensorflow::Tensor* src = nullptr;
      *status = handle_->Tensor(&src);
      if (!status->ok()) return nullptr;
      tensor = *src;
    } else {
      tensorflow::EagerContext* ctx = handle_->Context();
      CHECK_NE(ctx, nullptr);
      *status = handle_->CopyToDevice(*ctx, ctx->HostCPU(), &tensor);
      if (!status->ok()) return nullptr;
    }
    return tensorflow::TF_TensorFromTensor(tensor, status);
  }
}
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || !h->handle->IsValid(&status->status)) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle();

  if (handle->IsRemote()) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a remote tensor "
        "handle.");
    return nullptr;
  }
  if (handle->device() != nullptr) {
    status->status = handle->device()->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(
      static_cast<const void*>(tensor->tensor_data().data()));
}
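
// A minimal usage sketch (illustrative only): wrap an existing device buffer
// in a handle, handing ownership of the buffer to the deallocator. The names
// `my_buffer` and `FreeMyBuffer` are hypothetical placeholders.
//   int64_t dims[] = {2, 2};
//   TFE_TensorHandle* h = TFE_NewTensorHandleFromDeviceMemory(
//       ctx, "/job:localhost/replica:0/task:0/device:CPU:0", TF_FLOAT, dims,
//       2, my_buffer, 4 * sizeof(float), &FreeMyBuffer, nullptr, status);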
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype,
    const int64_t* dims, int num_dims, void* data, size_t len,
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status) {
  tensorflow::Device* device;
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    deallocator(data, len, deallocator_arg);
    return nullptr;
  }
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  if (dtype == TF_STRING || dtype == TF_RESOURCE ||
      !tensorflow::DataTypeCanUseMemcpy(
          static_cast<tensorflow::DataType>(dtype))) {
    status->status = tensorflow::errors::InvalidArgument(
        "Trying to create a tensor with a pointer to non-POD memory.");
    deallocator(data, len, deallocator_arg);
    return nullptr;
  }
  // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
  // the device?
  TF_ManagedBuffer* buf =
      new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);

  tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
                       tensorflow::TensorShape(dimvec), buf);
  buf->Unref();
  tensorflow::TensorHandle* ret_handle;
  status->status = tensorflow::TensorHandle::CreateLocalHandle(
      t, device, context, &ret_handle);
  if (!status->status.ok()) {
    return nullptr;
  }
  return new TFE_TensorHandle{
      std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}

// This function will block until the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. Returns the size
// in bytes of the memory pointed to by the device pointer returned above.
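//
// A typical pairing (a sketch; error handling elided): fetch the raw pointer
// and its extent together.
//   void* ptr = TFE_TensorHandleDevicePointer(h, status);
//   size_t nbytes = TFE_TensorHandleDeviceMemorySize(h, status);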
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
                                        TF_Status* status) {
  if (h == nullptr || !h->handle->IsValid(&status->status)) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return 0;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle();

  if (handle->IsRemote()) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
        "handle.");
    return 0;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return 0;
  }
  return tensor->TotalBytes();
}

TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  std::unique_ptr<TFE_Op> new_op(
      new TFE_Op{tensorflow::EagerOperation(ctx->context)});
  status->status =
      new_op->operation.Reset(op_or_function_name, nullptr, false, nullptr);
  if (!status->status.ok()) {
    new_op.reset();
  }
  return new_op.release();
}
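
// A minimal sketch of building and running an op with this API (illustrative
// only; `a` and `b` are assumed to be existing TFE_TensorHandles, TFE_Execute
// is declared in c_api.h, and error handling is elided):
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, a, status);
//   TFE_OpAddInput(matmul, b, status);
//   TFE_TensorHandle* retvals[1] = {nullptr};
//   int num_retvals = 1;
//   TFE_Execute(matmul, retvals, &num_retvals, status);
//   TFE_DeleteOp(matmul);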
1217
TFE_DeleteOp(TFE_Op * op)1218 void TFE_DeleteOp(TFE_Op* op) { delete op; }
1219
TFE_OpSetDevice(TFE_Op * op,const char * device_name,TF_Status * status)1220 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
1221 status->status = op->operation.SetDeviceName(device_name);
1222 }
1223
TFE_OpGetDevice(TFE_Op * op,TF_Status * status)1224 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
1225 tensorflow::Device* device = (op->operation.Device() == nullptr)
1226 ? op->operation.EagerContext().HostCPU()
1227 : op->operation.Device();
1228 return device->name().c_str();
1229 }
1230
TFE_OpSetXLACompilation(TFE_Op * op,unsigned char enable)1231 void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
1232 op->operation.SetUseXla(enable);
1233 #ifndef TENSORFLOW_EAGER_USE_XLA
1234 LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
1235 "built with XLA support.";
1236 #endif // TENSORFLOW_EAGER_USE_XLA
1237 }
1238
TFE_OpAddInput(TFE_Op * op,TFE_TensorHandle * input,TF_Status * status)1239 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
1240 tensorflow::TensorHandle* h =
1241 tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
1242 input->handle.get())
1243 ->Handle();
1244 op->operation.AddInput(h);
1245 status->status = op->operation.MaybeInferSingleInputAttrs(h);
1246 }
1247
TFE_OpAddInputList(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs,TF_Status * status)1248 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
1249 TF_Status* status) {
1250 for (int i = 0; i < num_inputs; ++i) {
1251 op->operation.AddInput(
1252 tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
1253 inputs[i]->handle.get())
1254 ->Handle());
1255 }
1256 status->status = op->operation.InferInputListAttrs(num_inputs);
1257 }
1258
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)1259 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
1260 unsigned char* is_list, TF_Status* status) {
1261 TF_AttrType ret = TF_ATTR_INT;
1262 status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
1263 attr_name, &ret, is_list);
1264 return ret;
1265 }
1266
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)1267 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
1268 const char* op_or_function_name,
1269 const char* attr_name, unsigned char* is_list,
1270 TF_Status* status) {
1271 TF_AttrType ret;
1272 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
1273 if (status->status.ok()) {
1274 ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
1275 } else {
1276 ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
1277 }
1278 TFE_DeleteOp(op);
1279 return ret;
1280 }
1281
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                         size_t length) {
  op->operation.MutableAttrs()->Set(
      attr_name,
      tensorflow::StringPiece(static_cast<const char*>(value), length));
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->operation.MutableAttrs()->Set(attr_name, value);
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->operation.MutableAttrs()->Set(attr_name,
                                    static_cast<tensorflow::DataType>(value));
}

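// Editor's note: a minimal sketch (not part of the original source) of the
// scalar attr setters above; `op` is assumed to be a valid TFE_Op and the
// attr names are illustrative.
//
//   TFE_OpSetAttrType(op, "T", TF_FLOAT);
//   TFE_OpSetAttrBool(op, "transpose_a", 1);
//   TFE_OpSetAttrString(op, "padding", "SAME", 4);
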
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->operation.MutableAttrs()->Set(attr_name, proto);
}

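// Editor's note: a hedged sketch (not in the original source) of the shape
// attr convention above: a negative num_dims means "unknown rank", while a
// dim size of -1 leaves that single dimension unknown.
//
//   const int64_t dims[] = {2, -1};  // rank 2, second dim unknown
//   TFE_OpSetAttrShape(op, "shape", dims, 2, status);
//   TFE_OpSetAttrShape(op, "shape", nullptr, -1, status);  // unknown rank
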
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->operation.Name());
  value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
  op->operation.MutableAttrs()->Set(attr_name, attr_value);
}

void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
                               const char* data, size_t length) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(data, length);
  op->operation.MutableAttrs()->Set(attr_name, attr_value);
}

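// Editor's note: a hedged sketch (not part of the original source). When only
// the function's registered name is known, TFE_OpSetAttrFunctionName avoids
// building a TFE_Op for it; "my_fn" is an illustrative function assumed to be
// registered in the context.
//
//   TFE_OpSetAttrFunctionName(op, "f", "my_fn", 5);
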
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
                         TF_Status* status) {
  tensorflow::Tensor t;
  status->status = TF_TensorToTensor(tensor, &t);
  if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
}

void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                             const void* const* values, const size_t* lengths,
                             int num_values) {
  std::vector<tensorflow::StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
                                   lengths[i]);
  }
  op->operation.MutableAttrs()->Set(attr_name, v);
}

void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                            const float* values, int num_values) {
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
}

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->operation.MutableAttrs()->Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

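// Editor's note: a hedged sketch (not in the original source) of the list
// setters above, using a strides-style int list; `op` is assumed valid.
//
//   const int64_t strides[] = {1, 1, 1, 1};
//   TFE_OpSetAttrIntList(op, "strides", strides, 4);
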
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                     proto.get(), num_values));
}

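// Editor's note: a hedged sketch (not in the original source). Each entry in
// `dims` is one shape; the per-shape rank convention matches
// TFE_OpSetAttrShape (a negative rank means unknown).
//
//   const int64_t shape0[] = {2, 2};
//   const int64_t* shapes[] = {shape0, nullptr};
//   const int ranks[] = {2, -1};  // second shape has unknown rank
//   TFE_OpSetAttrShapeList(op, "shapes", shapes, ranks, 2, status);
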
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  std::unique_ptr<tensorflow::NameAttrList[]> funcs(
      new tensorflow::NameAttrList[num_values]);
  for (int i = 0; i < num_values; i++) {
    funcs[i].set_name(value[i]->operation.Name());
    value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
  }
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
                     funcs.get(), num_values));
}

TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
                                               const char* input_name,
                                               TF_Status* status) {
  const tensorflow::OpDef* op_def = GetOpDef(op, status);
  if (!status->status.ok()) {
    return -1;
  }
  tensorflow::AttrValueMap attrs;
  op->operation.Attrs().FillAttrValueMap(&attrs);
  tensorflow::NameRangeMap name_ranges;
  status->status = tensorflow::NameRangesForNode(
      tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
  if (!status->status.ok()) {
    return -1;
  }
  auto iter = name_ranges.find(input_name);
  if (iter == name_ranges.end()) {
    status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
                                                         "' not found");
    return -1;
  }
  return iter->second.second - iter->second.first;
}

TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
                                                const char* output_name,
                                                TF_Status* status) {
  const tensorflow::OpDef* op_def = GetOpDef(op, status);
  if (!status->status.ok()) {
    return -1;
  }
  tensorflow::AttrValueMap attrs;
  op->operation.Attrs().FillAttrValueMap(&attrs);
  tensorflow::NameRangeMap name_ranges;
  status->status = tensorflow::NameRangesForNode(
      tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
  if (!status->status.ok()) {
    return -1;
  }
  auto iter = name_ranges.find(output_name);
  if (iter == name_ranges.end()) {
    status->status = tensorflow::errors::InvalidArgument(
        "Output '", output_name, "' not found");
    return -1;
  }
  return iter->second.second - iter->second.first;
}

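// Editor's note: a hedged sketch (not in the original source). The returned
// length is the size of the named input/output list once the relevant attrs
// are known, so list-sizing attrs (e.g. "N") should be set before querying.
//
//   TFE_Op* concat = TFE_NewOp(ctx, "ConcatV2", status);
//   TFE_OpSetAttrInt(concat, "N", 3);
//   int n = TFE_OpGetInputLength(concat, "values", status);  // expect 3
//   TFE_DeleteOp(concat);
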
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
  VLOG(1) << "Calling TFE_Execute() on op " << op;
  status->status = tensorflow::EagerExecute(&op->operation,
                                            handle_retvals.data(), num_retvals);
  if (!status->status.ok()) {
    return;
  }
  for (int i = 0; i < *num_retvals; ++i) {
    retvals[i] = new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(handle_retvals[i])};
  }
}

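// Editor's note: a hedged sketch (not in the original source). *num_retvals
// is in/out: the caller passes the capacity of `retvals` and gets back the
// number of outputs actually produced.
//
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   TFE_Execute(matmul, retvals, &num_retvals, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TFE_DeleteTensorHandle(retvals[0]);
//   }
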
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::TensorHandle* handle = nullptr;
  tensorflow::Device* device;
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    return nullptr;
  }
  status->status = tensorflow::EagerCopyToDevice(
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle(),
      context, &context->Executor(), device, false, &handle);
  if (status->status.ok()) {
    return new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(handle)};
  }
  return nullptr;
}

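// Editor's note: a hedged sketch (not in the original source). Device names
// use the fully qualified form; the GPU name below is illustrative and only
// resolves on a machine that actually has that device.
//
//   TFE_TensorHandle* on_gpu = TFE_TensorHandleCopyToDevice(
//       h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);
//   if (TF_GetCode(status) == TF_OK) {
//     TFE_DeleteTensorHandle(on_gpu);
//   }
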
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }
  status->status = ctx->context->AddFunctionDef(function_def);
}

void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  status->status = ctx->context->AddFunctionDef(function->fdef);
}

void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
                               TF_Status* status) {
  status->status = ctx->context->RemoveFunction(name);
}

unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
  return ctx->context->FindFunctionDef(name) != nullptr;
}

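// Editor's note: a hedged sketch (not in the original source) of the
// function-registry round trip; `fdef_proto`/`fdef_len` are assumed to hold a
// serialized FunctionDef whose signature is named "my_fn".
//
//   TFE_ContextAddFunctionDef(ctx, fdef_proto, fdef_len, status);
//   if (TF_GetCode(status) == TF_OK) {
//     CHECK(TFE_ContextHasFunction(ctx, "my_fn"));
//     TFE_ContextRemoveFunction(ctx, "my_fn", status);
//   }
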
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  ctx->context->SetShouldStoreGraphs(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  ctx->context->SetShouldStoreGraphs(false);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
                                      TF_Status* status) {
  return TFE_TensorHandle::CreateLocalHandle(t, status);
}

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->Executor().WaitForAllPendingNodes();
  if (!status->status.ok()) return;
  tensorflow::mutex_lock ml(*context->MetadataMu());
  status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
  context->ClearRunMetadata();
}

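// Editor's note: a hedged sketch (not in the original source). Metadata is
// only recorded while enabled, and exporting clears the accumulated
// RunMetadata (see ClearRunMetadata above).
//
//   TFE_ContextEnableRunMetadata(ctx);
//   // ... execute some ops ...
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextExportRunMetadata(ctx, buf, status);
//   TF_DeleteBuffer(buf);
//   TFE_ContextDisableRunMetadata(ctx);
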
namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (!status->status.ok()) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (!status->status.ok()) return nullptr;
  }
  return func_op;
}
}  // namespace

void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }

void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }

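// Editor's note: a hedged sketch (not in the original source). Bracketing a
// logical step keeps per-step state in the context alive until the matching
// EndStep call.
//
//   TFE_ContextStartStep(ctx);
//   // ... execute the ops that make up one step ...
//   TFE_ContextEndStep(ctx);
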
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
                          const tensorflow::AttrValue& default_value,
                          const char* attr_name, TF_Status* status) {
  switch (default_value.value_case()) {
    case tensorflow::AttrValue::kS: {
      const string& v = default_value.s();
      TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
      break;
    }
    case tensorflow::AttrValue::kI:
      TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
      break;
    case tensorflow::AttrValue::kF:
      TFE_OpSetAttrFloat(op, attr_name, default_value.f());
      break;
    case tensorflow::AttrValue::kB:
      TFE_OpSetAttrBool(op, attr_name, default_value.b());
      break;
    case tensorflow::AttrValue::kType:
      TFE_OpSetAttrType(op, attr_name,
                        static_cast<TF_DataType>(default_value.type()));
      break;
    case tensorflow::AttrValue::kShape: {
      const auto& tensor_shape = default_value.shape();
      if (tensor_shape.unknown_rank()) {
        TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
      } else {
        const auto num_dims = tensor_shape.dim_size();
        std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
        for (int i = 0; i < num_dims; ++i) {
          dims[i] = tensor_shape.dim(i).size();
        }
        TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
      }
    } break;
    case tensorflow::AttrValue::kFunc: {
      const auto func_op = GetFunc(ctx, default_value.func(), status);
      if (!status->status.ok()) return;
      // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
      // require a TFE_Op* and just convert it internally to a NameAttrList, so
      // consider adding an overload to the C API to make this case easier.
      TFE_OpSetAttrFunction(op, attr_name, func_op);
      TFE_DeleteOp(func_op);  // The attr stores a copy; release the temporary.
    } break;
    case tensorflow::AttrValue::kList:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kTensor:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kPlaceholder:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::VALUE_NOT_SET:
      TF_SetStatus(
          status, TF_UNIMPLEMENTED,
          tensorflow::strings::StrCat("Unable to set attr from default value: ",
                                      default_value.DebugString())
              .data());
  }
}
}  // namespace tensorflow