1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"

#include <limits>
#include <memory>

#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
55
56 namespace tensorflow {
57 namespace eager {
58
59 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)60 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
61 const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
62 int* num_retvals) {
63 const tensorflow::OpRegistrationData* op_reg_data = nullptr;
64 auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
65 if (errors::IsNotFound(status)) {
66 status = context->FindFunctionOpData(op_name, &op_reg_data);
67 }
68 TF_RETURN_IF_ERROR(status);
69
70 const tensorflow::OpDef& op_def = op_reg_data->op_def;
71
72 for (const auto& output_arg : op_def.output_arg()) {
73 if (!output_arg.number_attr().empty()) {
74 auto iter = attrs.find(output_arg.number_attr());
75 if (iter == attrs.end()) {
76 return errors::InvalidArgument("Unable to find number_attr ",
77 output_arg.number_attr(),
78 " for Op: ", op_name);
79 }
80 *num_retvals += iter->second.i();
81 } else if (!output_arg.type_list_attr().empty()) {
82 auto iter = attrs.find(output_arg.type_list_attr());
83 if (iter == attrs.end()) {
84 return errors::InvalidArgument("Unable to find type_list_attr ",
85 output_arg.type_list_attr(),
86 " for Op: ", op_name);
87 }
88 *num_retvals += iter->second.list().type_size();
89 } else {
90 *num_retvals += 1;
91 }
92 }
93
94 return Status::OK();
95 }
96
GetEagerOperationAndNumRetvals(const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,EagerOperation * eager_op,int * num_retvals)97 Status GetEagerOperationAndNumRetvals(const Operation& operation,
98 EagerContext* eager_context,
99 EagerExecutor* eager_executor,
100 EagerOperation* eager_op,
101 int* num_retvals) {
102 const char* name = operation.name().c_str(); // Shorthand
103 absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
104 absl::nullopt;
105 if (operation.is_function()) {
106 if (operation.is_component_function()) {
107 remote_func_params = {operation.id(), operation.func_step_id()};
108 } else {
109 remote_func_params = {operation.id(), absl::nullopt};
110 }
111 }
112 TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
113 eager_executor, remote_func_params));
114
115 {
116 profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
117 profiler::TraceMeLevel::kVerbose);
118 for (const auto& input : operation.op_inputs()) {
119 tensorflow::TensorHandle* handle;
120 if (input.has_remote_handle()) {
121 TF_RETURN_IF_ERROR(
122 eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
123 input.remote_handle(), &handle));
124 TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
125 } else {
126 Tensor tensor;
127 if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
128 return errors::InvalidArgument("Invalid TensorProto: ",
129 input.tensor().DebugString());
130 } else {
131 handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
132 nullptr, eager_context);
133 TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
134 }
135 }
136 // Unref handle since it has a ref as an input now.
137 handle->Unref();
138 }
139 }
140
141 for (const auto& attr : operation.attrs()) {
142 eager_op->MutableAttrs()->Set(attr.first, attr.second);
143 }
144
145 // TODO(nareshmodi): Consider caching this.
146 return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
147 num_retvals);
148 }
149
TensorHandleProto(TensorHandle * handle,TensorProto * proto)150 Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
151 const tensorflow::Tensor* t = nullptr;
152 TF_RETURN_IF_ERROR(handle->Tensor(&t));
153 t->AsProtoTensorContent(proto);
154 return Status::OK();
155 }
156
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)157 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
158 const tensorflow::Tensor* t = nullptr;
159
160 // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
161 if (handle->Type() == TensorHandle::LOCAL) {
162 TF_RETURN_IF_ERROR(handle->Tensor(&t));
163
164 t->shape().AsProto(proto);
165 } else {
166 TensorShape shape;
167 TF_RETURN_IF_ERROR(handle->Shape(&shape));
168 shape.AsProto(proto);
169 }
170
171 return Status::OK();
172 }
173
AddOpRetvalsToResponse(EagerContext * eager_context,int op_id,int num_retvals,const std::vector<int32> & output_nums,TensorHandle ** retvals,std::function<TensorProto * ()> add_tensor_proto_fn,std::function<TensorShapeProto * ()> add_shape_proto_fn,std::function<string * ()> add_device_fn=nullptr)174 Status AddOpRetvalsToResponse(
175 EagerContext* eager_context, int op_id, int num_retvals,
176 const std::vector<int32>& output_nums, TensorHandle** retvals,
177 std::function<TensorProto*()> add_tensor_proto_fn,
178 std::function<TensorShapeProto*()> add_shape_proto_fn,
179 std::function<string*()> add_device_fn = nullptr) {
180 if (op_id == kInvalidRemoteOpId) {
181 // Copy the output tensors back along with the response, since the op id
182 // is invalid which cannot be added to RemoteMgr.
183 for (int i = 0; i < num_retvals; i++) {
184 TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
185 retvals[i]->Unref();
186 }
187 } else {
188 for (int i = 0; i < num_retvals; i++) {
189 TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
190 if (add_device_fn) {
191 Device* device = retvals[i]->device();
192 *add_device_fn() = device ? device->name() : "";
193 }
194 if (retvals[i]->Type() == TensorHandle::REMOTE) {
195 retvals[i]->Unref();
196 } else {
197 const int output_num = output_nums.empty() ? i : output_nums.at(i);
198 eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
199 output_num);
200 }
201 }
202 }
203 return Status::OK();
204 }
205 } // namespace
206
CreateContext(const CreateContextRequest * request,CreateContextResponse * response)207 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
208 CreateContextResponse* response) {
209 {
210 mutex_lock l(contexts_mu_);
211 auto context_it = contexts_.find(request->context_id());
212 if (context_it != contexts_.end()) {
213 if (request->context_view_id() <
214 context_it->second->Context()->GetContextViewId()) {
215 return errors::InvalidArgument("EagerService:CreateContext failed. ",
216 "Context id: <", request->context_id(),
217 "> already exists.");
218 } else {
219 // For existing context with a stale context_view_id, close the old one
220 // and recreate with new view id. This is likely due to the worker
221 // disconnected and then reconnected after one or more cluster updates.
222 context_it->second->Unref();
223 contexts_.erase(context_it);
224 }
225 }
226 }
227 // make sure env_ , env_->rendezvous_mgr available
228 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
229 return tensorflow::errors::Internal(
230 "invalid eager env_ or env_->rendezvous_mgr.");
231 }
232
233 auto* r = env_->rendezvous_mgr->Find(request->context_id());
234 auto session_name =
235 tensorflow::strings::StrCat("eager_", request->context_id());
236 if (VLOG_IS_ON(2)) {
237 VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
238 << "/task:" << request->server_def().task_index();
239 for (const auto& da : request->cluster_device_attributes()) {
240 VLOG(2) << " " << da.name();
241 }
242 }
243 TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
244 session_name, request->server_def(), request->cluster_device_attributes(),
245 true));
246 int64_t context_id = request->context_id();
247 std::function<void()> session_destroyer = [this, context_id, session_name]() {
248 env_->rendezvous_mgr->Cleanup(context_id);
249 auto s = env_->session_mgr->DeleteSession(session_name);
250 if (!s.ok()) {
251 LOG(WARNING) << "Failed to destroy worker session '" << session_name
252 << "' due to " << s.error_message();
253 }
254 };
255
256 std::shared_ptr<WorkerSession> worker_session;
257 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
258 session_name, &worker_session));
259
260 tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
261
262 // Initialize remote tensor communication based on worker session.
263 TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
264
265 std::function<Rendezvous*(const int64_t)> rendezvous_creator =
266 [worker_session, this](const int64_t step_id) {
267 auto* r = env_->rendezvous_mgr->Find(step_id);
268 r->Initialize(worker_session.get()).IgnoreError();
269 return r;
270 };
271
272 LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
273 << " eager service context with rendezvous_id on host "
274 << port::Hostname() << " " << worker_session->worker_name();
275 SessionOptions opts;
276 opts.config = request->server_def().default_session_config();
277 tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
278 opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
279 request->async(), device_mgr, false, r, worker_session->cluster_flr(),
280 env_->collective_executor_mgr.get());
281 // Ownership will be transferred to the ServerContext, or else in an error
282 // case ctx will be deleted by this unref.
283 core::ScopedUnref unref_ctx(ctx);
284
285 std::vector<string> remote_workers;
286 worker_session->worker_cache()->ListWorkers(&remote_workers);
287 remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
288 worker_session->worker_name()),
289 remote_workers.end());
290
291 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
292 TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
293 &remote_eager_workers));
294 DistributedFunctionLibraryRuntime* cluster_flr =
295 eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
296
297 auto remote_mgr =
298 absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
299 Status s = ctx->InitializeRemoteWorker(
300 std::move(remote_eager_workers), worker_session->remote_device_mgr(),
301 remote_workers, request->context_id(), request->context_view_id(),
302 std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr),
303 std::move(session_destroyer));
304 if (!s.ok()) {
305 VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
306 << s.ToString();
307 return s;
308 }
309
310 #if !defined(IS_MOBILE_PLATFORM)
311 const auto& config = request->server_def().default_session_config();
312 const bool enable_coordination =
313 !config.experimental().coordination_service().empty();
314 if (enable_coordination) {
315 auto dist_mgr = std::make_unique<EagerContextDistributedManager>(ctx);
316 ctx->SetDistributedManager(std::move(dist_mgr));
317 TF_RETURN_IF_ERROR(ctx->GetDistributedManager()->EnableCoordinationService(
318 config.experimental().coordination_service(), env_,
319 request->server_def(), worker_session->worker_cache()));
320 std::unique_ptr<CoordinationClientCache> client_cache;
321 TF_RETURN_IF_ERROR(
322 worker_session->worker_cache()->GetCoordinationClientCache(
323 &client_cache));
324 TF_RETURN_IF_ERROR(
325 ctx->GetDistributedManager()->GetCoordinationServiceAgent()->Initialize(
326 env_, request->server_def(), std::move(client_cache),
327 /*error_fn=*/[](Status s) {
328 LOG(ERROR) << "Coordination agent is set to error: " << s;
329 }));
330 }
331 #endif // !IS_MOBILE_PLATFORM
332
333 std::vector<DeviceAttributes> device_attributes;
334 device_mgr->ListDeviceAttributes(&device_attributes);
335
336 for (const auto& da : device_attributes) {
337 *response->add_device_attributes() = da;
338 }
339 {
340 mutex_lock l(contexts_mu_);
341 auto context_it = contexts_.find(request->context_id());
342 if (context_it != contexts_.end()) {
343 return errors::InvalidArgument("EagerService:CreateContext failed. ",
344 "Context id: <", request->context_id(),
345 "> already exists.");
346 }
347 contexts_.emplace(request->context_id(),
348 new ServerContext(ctx, request->keep_alive_secs(), env_));
349 }
350
351 return Status::OK();
352 }
353
UpdateContext(const UpdateContextRequest * request,UpdateContextResponse * response)354 Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
355 UpdateContextResponse* response) {
356 // make sure env_ , env_->rendezvous_mgr available
357 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
358 return tensorflow::errors::Internal(
359 "invalid eager env_ or env_->rendezvous_mgr.");
360 }
361
362 // Find the context to update by the requested context_id
363 ServerContext* server_context = nullptr;
364 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &server_context));
365 core::ScopedUnref context_unref(server_context);
366
367 tensorflow::EagerContext* ctx = server_context->Context();
368 if (request->context_view_id() != ctx->GetContextViewId() + 1) {
369 return errors::InvalidArgument(
370 "EagerService:UpdateContext failed. Context id: <",
371 request->context_id(), "> currently at view #", ctx->GetContextViewId(),
372 " but received update request at view #", request->context_view_id(),
373 ". View id should only be continuously incremented.");
374 }
375 if (request->cluster_device_attributes_size() == 0) {
376 // In this case, the client indicates that the updated `server_def` and
377 // device info is irrelevant to this worker, since it is not connected to
378 // the updated ones (likely due to device filter settings). The worker
379 // simply needs to update view ID and does not update other internal state.
380 ctx->IncrementContextViewId();
381 VLOG(1) << "Processing simplified UpdateContextRequest on "
382 << ctx->HostCPU()->name();
383 return Status::OK();
384 }
385
386 auto session_name =
387 tensorflow::strings::StrCat("eager_", request->context_id());
388
389 TF_RETURN_IF_ERROR(env_->session_mgr->UpdateSession(
390 session_name, request->server_def(), request->cluster_device_attributes(),
391 true));
392
393 std::shared_ptr<WorkerSession> worker_session;
394 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
395 session_name, &worker_session));
396
397 const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
398
399 std::vector<string> remote_workers;
400 worker_session->worker_cache()->ListWorkers(&remote_workers);
401 remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
402 worker_session->worker_name()),
403 remote_workers.end());
404 VLOG(1) << "On existing server " << worker_session->worker_name()
405 << " updating remote workers";
406 if (VLOG_IS_ON(2)) {
407 for (const string& rw : remote_workers) {
408 VLOG(2) << "Remote worker " << rw;
409 }
410 }
411
412 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
413 TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
414 &remote_eager_workers));
415
416 ctx->ClearCachesAndThreadExecutors();
417 Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
418 remote_workers, request->context_id());
419 if (!s.ok()) {
420 VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
421 return s;
422 }
423
424 std::vector<DeviceAttributes> device_attributes;
425 device_mgr->ListDeviceAttributes(&device_attributes);
426
427 for (const auto& da : device_attributes) {
428 *response->add_device_attributes() = da;
429 }
430
431 return Status::OK();
432 }
433
CreateMasterContext(const tensorflow::uint64 context_id,EagerContext * context)434 Status EagerServiceImpl::CreateMasterContext(
435 const tensorflow::uint64 context_id, EagerContext* context) {
436 {
437 mutex_lock l(contexts_mu_);
438 auto iter = contexts_.find(context_id);
439 if (iter != contexts_.end()) {
440 return errors::InvalidArgument(
441 "EagerService:CreateMasterContext failed. ", "Context id: <",
442 context_id, "> already exists.");
443 }
444 }
445 ServerContext* server_context =
446 ServerContext::CreateMasterContext(context, env_);
447 mutex_lock l(contexts_mu_);
448 contexts_.emplace(context_id, server_context);
449 return Status::OK();
450 }
451
RunComponentFunction(CallOptions * call_opts,const RunComponentFunctionRequest * request,RunComponentFunctionResponse * response,StatusCallback done)452 void EagerServiceImpl::RunComponentFunction(
453 CallOptions* call_opts, const RunComponentFunctionRequest* request,
454 RunComponentFunctionResponse* response, StatusCallback done) {
455 ServerContext* context = nullptr;
456 Status s = GetServerContext(request->context_id(), &context);
457 if (!s.ok()) {
458 done(s);
459 return;
460 }
461 core::ScopedUnref context_unref(context);
462
463 auto& operation = request->operation();
464 // This codepath should only be triggered for executing component function
465 if (!operation.is_function() || !operation.is_component_function()) {
466 done(errors::Internal(
467 "RunComponentFunction request can only be used to execute "
468 "component functions."));
469 return;
470 }
471
472 EagerContext* eager_context = context->Context();
473 EagerExecutor* eager_executor = &eager_context->Executor();
474
475 EagerOperation* op = new EagerOperation(eager_context);
476 int* num_retvals = new int(0);
477 s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
478 op, num_retvals);
479 if (!s.ok()) {
480 done(s);
481 return;
482 }
483 if (!op->IsLocal()) {
484 done(errors::Internal(
485 "Received RunComponentFunction request with remote function device. "));
486 return;
487 }
488 s = op->SetAttrBool("is_component_function", true);
489 if (!s.ok()) {
490 done(errors::Internal("Error setting is_component_function attribute: ",
491 s.error_message()));
492 return;
493 }
494
495 auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
496 VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
497 << operation.id();
498 std::vector<int32> output_nums;
499 for (const int32_t output_num : request->output_num()) {
500 output_nums.push_back(output_num);
501 }
502
503 auto cm = std::make_shared<CancellationManager>();
504 op->SetCancellationManager(cm.get());
505 call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
506
507 context->Ref();
508 EagerLocalExecuteAsync(
509 op, retvals->data(), num_retvals,
510 [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
511 call_opts, response, eager_context, context,
512 done = std::move(done)](const Status& status) {
513 call_opts->ClearCancelCallback();
514 auto wrapped_done = [&](const Status& status) {
515 context->Unref();
516 done(status);
517 delete op;
518 delete num_retvals;
519 delete retvals;
520 };
521 if (!status.ok()) {
522 wrapped_done(status);
523 return;
524 }
525 // The output device of a component function is the component device
526 // which is known on the default device of it's parent function.
527 wrapped_done(AddOpRetvalsToResponse(
528 eager_context, op_id, *num_retvals, output_nums, retvals->data(),
529 [response] { return response->add_tensor(); },
530 [response] { return response->add_shape(); }));
531 });
532 }
533
ExecuteOp(CallOptions * call_opts,const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,QueueResponse * queue_response)534 Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
535 const Operation& operation,
536 EagerContext* eager_context,
537 EagerExecutor* eager_executor,
538 QueueResponse* queue_response) {
539 tensorflow::EagerOperation op(eager_context);
540 int num_retvals = 0;
541 TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
542 operation, eager_context, eager_executor, &op, &num_retvals));
543
544 auto cm = std::make_shared<CancellationManager>();
545 if (call_opts) {
546 op.SetCancellationManager(cm.get());
547 call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
548 }
549
550 absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
551 VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
552 TF_RETURN_IF_ERROR(op.Execute(
553 absl::MakeSpan(
554 reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
555 num_retvals),
556 &num_retvals));
557
558 std::function<string*()> add_device_fn = nullptr;
559 // Send the output devices of a function back to let a client know where the
560 // outputs are. For a primitive op, an output devics is the op device which is
561 // known on a client.
562 if (op.is_function()) {
563 add_device_fn = [queue_response] { return queue_response->add_device(); };
564 }
565
566 return AddOpRetvalsToResponse(
567 eager_context, operation.id(), num_retvals, /*output_nums=*/{},
568 retvals.data(), [queue_response] { return queue_response->add_tensor(); },
569 [queue_response] { return queue_response->add_shape(); },
570 std::move(add_device_fn));
571 }
572
Enqueue(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,uint64 stream_id)573 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
574 const EnqueueRequest* request,
575 EnqueueResponse* response, uint64 stream_id) {
576 profiler::TraceMe activity(
577 [&] {
578 return absl::StrCat(
579 "EagerService:Enqueue#debug_str=", request->DebugString(), "#");
580 },
581 profiler::TraceMeLevel::kInfo);
582 ServerContext* context = nullptr;
583 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
584 core::ScopedUnref context_unref(context);
585
586 EagerExecutor& executor =
587 stream_id == kInvalidStreamId
588 ? context->Context()->Executor()
589 : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream(
590 stream_id);
591 Status s;
592 for (const auto& item : request->queue()) {
593 auto* queue_response = response->add_queue_response();
594 if (item.has_operation()) {
595 s = ExecuteOp(call_opts, item.operation(), context->Context(), &executor,
596 queue_response);
597 } else if (item.has_handle_to_decref()) {
598 auto handle_to_decref = absl::make_unique<RemoteTensorHandleInternal>(
599 item.handle_to_decref());
600 auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
601 context, std::move(handle_to_decref));
602 s = context->Context()->Executor().AddOrExecute(std::move(node));
603 } else if (item.has_send_tensor()) {
604 s = SendTensor(item.send_tensor(), context->Context());
605 } else if (item.has_send_packed_handle()) {
606 s = SendPackedHandle(item.send_packed_handle(), context->Context());
607 } else if (item.has_register_function()) {
608 s = RegisterFunction(item.register_function(), context->Context());
609 } else if (item.has_cleanup_function()) {
610 s = CleanupFunction(item.cleanup_function());
611 } else {
612 DCHECK(item.has_sync_remote_executor_for_stream());
613 s = executor.WaitForAllPendingNodes();
614 }
615
616 if (!s.ok()) {
617 if (stream_id != kInvalidStreamId) {
618 context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id);
619 }
620 return s;
621 }
622 }
623
624 return Status::OK();
625 }
626
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)627 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
628 WaitQueueDoneResponse* response) {
629 ServerContext* context = nullptr;
630 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
631 core::ScopedUnref context_unref(context);
632
633 if (request->op_id_size() > 0) {
634 return errors::Unimplemented(
635 "EagerServiceImpl::WaitQueueDone is not "
636 "implemented for particular op IDs.");
637 }
638 return context->Context()->Executor().WaitForAllPendingNodes();
639 }
640
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)641 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
642 KeepAliveResponse* response) {
643 ServerContext* context = nullptr;
644 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
645 core::ScopedUnref context_unref(context);
646
647 tensorflow::EagerContext* ctx = context->Context();
648 response->set_context_view_id(ctx->GetContextViewId());
649 return Status::OK();
650 }
651
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)652 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
653 CloseContextResponse* response) {
654 VLOG(1) << "Executing EagerService::CloseContext for context "
655 << request->context_id();
656 ServerContext* context = nullptr;
657 if (!GetServerContext(request->context_id(), &context).ok()) {
658 // Swallow the error here.
659 return Status::OK();
660 }
661 core::ScopedUnref context_unref(context);
662
663 if (request->context_view_id() < context->Context()->GetContextViewId()) {
664 // Swallow the error here.
665 LOG(INFO) << "Ignoring CloseContext request with a stale context_view_id "
666 << request->context_view_id() << " for context_id "
667 << request->context_id() << ". The current context_view_id is "
668 << context->Context()->GetContextViewId() << ".";
669 return Status::OK();
670 }
671
672 mutex_lock l(contexts_mu_);
673 contexts_.erase(request->context_id());
674
675 // GetServerContext returns a newly Reffed copy of ServerContext, which is
676 // unreffed by context_unref. Additionally, we need to unref it one time since
677 // we are releasing it from the map.
678 context->Unref();
679
680 return Status::OK();
681 }
682
RegisterFunction(const RegisterFunctionOp & register_function,EagerContext * eager_context)683 Status EagerServiceImpl::RegisterFunction(
684 const RegisterFunctionOp& register_function, EagerContext* eager_context) {
685 // If the function is a component of a multi-device function, we only need to
686 // register it locally.
687 return eager_context->AddFunctionDef(
688 register_function.function_def(), register_function.library(),
689 register_function.is_component_function());
690 }
691
CleanupFunction(const CleanupFunctionOp & cleanup_function)692 Status EagerServiceImpl::CleanupFunction(
693 const CleanupFunctionOp& cleanup_function) {
694 env_->rendezvous_mgr->Cleanup(cleanup_function.step_id());
695 return Status::OK();
696 }
697
SendTensor(const SendTensorOp & send_tensor,EagerContext * eager_context)698 Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
699 EagerContext* eager_context) {
700 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
701 for (const auto& tensor_proto : send_tensor.tensors()) {
702 Tensor tensor;
703 if (!tensor.FromProto(tensor_proto)) {
704 return errors::InvalidArgument("Unable to parse tensor proto");
705 }
706
707 TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
708 std::move(tensor), nullptr, nullptr, eager_context);
709 TensorHandle* copied_handle = nullptr;
710 Device* device;
711 TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
712 send_tensor.device_name().c_str(), &device));
713 TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context,
714 &eager_context->Executor(), device,
715 false, &copied_handle));
716 tensors.push_back(copied_handle);
717 tensor_handle->Unref();
718 }
719
720 eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id());
721
722 return Status::OK();
723 }
724
SendPackedHandle(const SendPackedHandleOp & send_packed_handle,EagerContext * eager_context)725 Status EagerServiceImpl::SendPackedHandle(
726 const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
727 if (send_packed_handle.handles().empty()) {
728 return errors::InvalidArgument("Handles should not be empty.");
729 }
730
731 std::vector<tensorflow::TensorHandle*> handles;
732 handles.resize(send_packed_handle.handles_size());
733 for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
734 const auto& item = send_packed_handle.handles(i);
735 if (item.has_local_handle()) {
736 Tensor tensor;
737 if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
738 return errors::InvalidArgument(
739 "Invalid TensorProto: ",
740 item.local_handle().tensor().DebugString());
741 }
742 Device* op_device = nullptr;
743 TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
744 item.local_handle().device().c_str(), &op_device));
745 handles[i] = TensorHandle::CreateLocalHandle(
746 std::move(tensor), /*d=*/nullptr, op_device, eager_context);
747 } else {
748 TF_RETURN_IF_ERROR(
749 eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
750 item.remote_handle(), &handles[i]));
751 }
752 }
753
754 tensorflow::TensorHandle* packed_handle = nullptr;
755 std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
756 // Create a unshaped packed TensorHandle.
757 TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
758 std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
759 send_packed_handle.device_name(), eager_context, &packed_handle));
760
761 for (auto* h : handles) {
762 // Unref handle since it has a ref in the packed handle now.
763 h->Unref();
764 }
765
766 eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
767 send_packed_handle.op_id());
768 return Status::OK();
769 }
770
GetServerContext(uint64 context_id,ServerContext ** server_context)771 tensorflow::Status EagerServiceImpl::GetServerContext(
772 uint64 context_id, ServerContext** server_context) {
773 tf_shared_lock l(contexts_mu_);
774 auto iter = contexts_.find(context_id);
775 if (iter == contexts_.end()) {
776 *server_context = nullptr;
777 return errors::Unavailable(strings::Printf(
778 "Unable to find a context_id matching the specified one "
779 "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
780 static_cast<unsigned long long>(context_id)));
781 }
782
783 *server_context = iter->second;
784 (*server_context)->Ref();
785
786 (*server_context)->RecordAccess();
787
788 return Status::OK();
789 }
790
791 } // namespace eager
792 } // namespace tensorflow
793