1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17 
18 #include "absl/container/fixed_array.h"
19 #include "absl/memory/memory.h"
20 #include "absl/types/optional.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
23 #include "tensorflow/c/tf_status_helper.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/eager/context.h"
26 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
27 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
28 #include "tensorflow/core/common_runtime/eager/execute.h"
29 #include "tensorflow/core/common_runtime/function.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
32 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
33 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
34 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
35 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
36 #include "tensorflow/core/distributed_runtime/server_lib.h"
37 #include "tensorflow/core/distributed_runtime/session_mgr.h"
38 #include "tensorflow/core/distributed_runtime/worker_cache.h"
39 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
40 #include "tensorflow/core/distributed_runtime/worker_env.h"
41 #include "tensorflow/core/framework/rendezvous.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/gtl/cleanup.h"
44 #include "tensorflow/core/lib/random/random.h"
45 #include "tensorflow/core/lib/strings/strcat.h"
46 #include "tensorflow/core/lib/strings/stringprintf.h"
47 #include "tensorflow/core/platform/cpu_info.h"
48 #include "tensorflow/core/platform/env.h"
49 #include "tensorflow/core/platform/errors.h"
50 #include "tensorflow/core/platform/host_info.h"
51 #include "tensorflow/core/platform/mutex.h"
52 #include "tensorflow/core/platform/refcount.h"
53 #include "tensorflow/core/profiler/lib/traceme.h"
54 #include "tensorflow/core/protobuf/error_codes.pb.h"
55 
56 namespace tensorflow {
57 namespace eager {
58 
59 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)60 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
61                      const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
62                      int* num_retvals) {
63   const tensorflow::OpRegistrationData* op_reg_data = nullptr;
64   auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
65   if (errors::IsNotFound(status)) {
66     status = context->FindFunctionOpData(op_name, &op_reg_data);
67   }
68   TF_RETURN_IF_ERROR(status);
69 
70   const tensorflow::OpDef& op_def = op_reg_data->op_def;
71 
72   for (const auto& output_arg : op_def.output_arg()) {
73     if (!output_arg.number_attr().empty()) {
74       auto iter = attrs.find(output_arg.number_attr());
75       if (iter == attrs.end()) {
76         return errors::InvalidArgument("Unable to find number_attr ",
77                                        output_arg.number_attr(),
78                                        " for Op: ", op_name);
79       }
80       *num_retvals += iter->second.i();
81     } else if (!output_arg.type_list_attr().empty()) {
82       auto iter = attrs.find(output_arg.type_list_attr());
83       if (iter == attrs.end()) {
84         return errors::InvalidArgument("Unable to find type_list_attr ",
85                                        output_arg.type_list_attr(),
86                                        " for Op: ", op_name);
87       }
88       *num_retvals += iter->second.list().type_size();
89     } else {
90       *num_retvals += 1;
91     }
92   }
93 
94   return Status::OK();
95 }
96 
GetEagerOperationAndNumRetvals(const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,EagerOperation * eager_op,int * num_retvals)97 Status GetEagerOperationAndNumRetvals(const Operation& operation,
98                                       EagerContext* eager_context,
99                                       EagerExecutor* eager_executor,
100                                       EagerOperation* eager_op,
101                                       int* num_retvals) {
102   const char* name = operation.name().c_str();  // Shorthand
103   absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
104       absl::nullopt;
105   if (operation.is_function()) {
106     if (operation.is_component_function()) {
107       remote_func_params = {operation.id(), operation.func_step_id()};
108     } else {
109       remote_func_params = {operation.id(), absl::nullopt};
110     }
111   }
112   TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
113                                      eager_executor, remote_func_params));
114 
115   {
116     profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
117                                profiler::TraceMeLevel::kVerbose);
118     for (const auto& input : operation.op_inputs()) {
119       tensorflow::TensorHandle* handle;
120       if (input.has_remote_handle()) {
121         TF_RETURN_IF_ERROR(
122             eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
123                 input.remote_handle(), &handle));
124         TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
125       } else {
126         Tensor tensor;
127         if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
128           return errors::InvalidArgument("Invalid TensorProto: ",
129                                          input.tensor().DebugString());
130         } else {
131           handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
132                                                    nullptr, eager_context);
133           TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
134         }
135       }
136       // Unref handle since it has a ref as an input now.
137       handle->Unref();
138     }
139   }
140 
141   for (const auto& attr : operation.attrs()) {
142     eager_op->MutableAttrs()->Set(attr.first, attr.second);
143   }
144 
145   // TODO(nareshmodi): Consider caching this.
146   return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
147                        num_retvals);
148 }
149 
TensorHandleProto(TensorHandle * handle,TensorProto * proto)150 Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
151   const tensorflow::Tensor* t = nullptr;
152   TF_RETURN_IF_ERROR(handle->Tensor(&t));
153   t->AsProtoTensorContent(proto);
154   return Status::OK();
155 }
156 
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)157 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
158   const tensorflow::Tensor* t = nullptr;
159 
160   // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
161   if (handle->Type() == TensorHandle::LOCAL) {
162     TF_RETURN_IF_ERROR(handle->Tensor(&t));
163 
164     t->shape().AsProto(proto);
165   } else {
166     TensorShape shape;
167     TF_RETURN_IF_ERROR(handle->Shape(&shape));
168     shape.AsProto(proto);
169   }
170 
171   return Status::OK();
172 }
173 
AddOpRetvalsToResponse(EagerContext * eager_context,int op_id,int num_retvals,const std::vector<int32> & output_nums,TensorHandle ** retvals,std::function<TensorProto * ()> add_tensor_proto_fn,std::function<TensorShapeProto * ()> add_shape_proto_fn,std::function<string * ()> add_device_fn=nullptr)174 Status AddOpRetvalsToResponse(
175     EagerContext* eager_context, int op_id, int num_retvals,
176     const std::vector<int32>& output_nums, TensorHandle** retvals,
177     std::function<TensorProto*()> add_tensor_proto_fn,
178     std::function<TensorShapeProto*()> add_shape_proto_fn,
179     std::function<string*()> add_device_fn = nullptr) {
180   if (op_id == kInvalidRemoteOpId) {
181     // Copy the output tensors back along with the response, since the op id
182     // is invalid which cannot be added to RemoteMgr.
183     for (int i = 0; i < num_retvals; i++) {
184       TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
185       retvals[i]->Unref();
186     }
187   } else {
188     for (int i = 0; i < num_retvals; i++) {
189       TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
190       if (add_device_fn) {
191         Device* device = retvals[i]->device();
192         *add_device_fn() = device ? device->name() : "";
193       }
194       if (retvals[i]->Type() == TensorHandle::REMOTE) {
195         retvals[i]->Unref();
196       } else {
197         const int output_num = output_nums.empty() ? i : output_nums.at(i);
198         eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
199                                                        output_num);
200       }
201     }
202   }
203   return Status::OK();
204 }
205 }  // namespace
206 
CreateContext(const CreateContextRequest * request,CreateContextResponse * response)207 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
208                                        CreateContextResponse* response) {
209   {
210     mutex_lock l(contexts_mu_);
211     auto context_it = contexts_.find(request->context_id());
212     if (context_it != contexts_.end()) {
213       if (request->context_view_id() <
214           context_it->second->Context()->GetContextViewId()) {
215         return errors::InvalidArgument("EagerService:CreateContext failed. ",
216                                        "Context id: <", request->context_id(),
217                                        "> already exists.");
218       } else {
219         // For existing context with a stale context_view_id, close the old one
220         // and recreate with new view id. This is likely due to the worker
221         // disconnected and then reconnected after one or more cluster updates.
222         context_it->second->Unref();
223         contexts_.erase(context_it);
224       }
225     }
226   }
227   // make sure env_ , env_->rendezvous_mgr available
228   if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
229     return tensorflow::errors::Internal(
230         "invalid eager env_ or env_->rendezvous_mgr.");
231   }
232 
233   auto* r = env_->rendezvous_mgr->Find(request->context_id());
234   auto session_name =
235       tensorflow::strings::StrCat("eager_", request->context_id());
236   if (VLOG_IS_ON(2)) {
237     VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
238             << "/task:" << request->server_def().task_index();
239     for (const auto& da : request->cluster_device_attributes()) {
240       VLOG(2) << "    " << da.name();
241     }
242   }
243   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
244       session_name, request->server_def(), request->cluster_device_attributes(),
245       true));
246   int64_t context_id = request->context_id();
247   std::function<void()> session_destroyer = [this, context_id, session_name]() {
248     env_->rendezvous_mgr->Cleanup(context_id);
249     auto s = env_->session_mgr->DeleteSession(session_name);
250     if (!s.ok()) {
251       LOG(WARNING) << "Failed to destroy worker session '" << session_name
252                    << "' due to " << s.error_message();
253     }
254   };
255 
256   std::shared_ptr<WorkerSession> worker_session;
257   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
258       session_name, &worker_session));
259 
260   tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
261 
262   // Initialize remote tensor communication based on worker session.
263   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
264 
265   std::function<Rendezvous*(const int64_t)> rendezvous_creator =
266       [worker_session, this](const int64_t step_id) {
267         auto* r = env_->rendezvous_mgr->Find(step_id);
268         r->Initialize(worker_session.get()).IgnoreError();
269         return r;
270       };
271 
272   LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
273             << " eager service context with rendezvous_id on host "
274             << port::Hostname() << " " << worker_session->worker_name();
275   SessionOptions opts;
276   opts.config = request->server_def().default_session_config();
277   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
278       opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
279       request->async(), device_mgr, false, r, worker_session->cluster_flr(),
280       env_->collective_executor_mgr.get());
281   // Ownership will be transferred to the ServerContext, or else in an error
282   // case ctx will be deleted by this unref.
283   core::ScopedUnref unref_ctx(ctx);
284 
285   std::vector<string> remote_workers;
286   worker_session->worker_cache()->ListWorkers(&remote_workers);
287   remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
288                                    worker_session->worker_name()),
289                        remote_workers.end());
290 
291   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
292   TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
293       &remote_eager_workers));
294   DistributedFunctionLibraryRuntime* cluster_flr =
295       eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
296 
297   auto remote_mgr =
298       absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
299   Status s = ctx->InitializeRemoteWorker(
300       std::move(remote_eager_workers), worker_session->remote_device_mgr(),
301       remote_workers, request->context_id(), request->context_view_id(),
302       std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr),
303       std::move(session_destroyer));
304   if (!s.ok()) {
305     VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
306             << s.ToString();
307     return s;
308   }
309 
310 #if !defined(IS_MOBILE_PLATFORM)
311   const auto& config = request->server_def().default_session_config();
312   const bool enable_coordination =
313       !config.experimental().coordination_service().empty();
314   if (enable_coordination) {
315     auto dist_mgr = std::make_unique<EagerContextDistributedManager>(ctx);
316     ctx->SetDistributedManager(std::move(dist_mgr));
317     TF_RETURN_IF_ERROR(ctx->GetDistributedManager()->EnableCoordinationService(
318         config.experimental().coordination_service(), env_,
319         request->server_def(), worker_session->worker_cache()));
320     std::unique_ptr<CoordinationClientCache> client_cache;
321     TF_RETURN_IF_ERROR(
322         worker_session->worker_cache()->GetCoordinationClientCache(
323             &client_cache));
324     TF_RETURN_IF_ERROR(
325         ctx->GetDistributedManager()->GetCoordinationServiceAgent()->Initialize(
326             env_, request->server_def(), std::move(client_cache),
327             /*error_fn=*/[](Status s) {
328               LOG(ERROR) << "Coordination agent is set to error: " << s;
329             }));
330   }
331 #endif  // !IS_MOBILE_PLATFORM
332 
333   std::vector<DeviceAttributes> device_attributes;
334   device_mgr->ListDeviceAttributes(&device_attributes);
335 
336   for (const auto& da : device_attributes) {
337     *response->add_device_attributes() = da;
338   }
339   {
340     mutex_lock l(contexts_mu_);
341     auto context_it = contexts_.find(request->context_id());
342     if (context_it != contexts_.end()) {
343       return errors::InvalidArgument("EagerService:CreateContext failed. ",
344                                      "Context id: <", request->context_id(),
345                                      "> already exists.");
346     }
347     contexts_.emplace(request->context_id(),
348                       new ServerContext(ctx, request->keep_alive_secs(), env_));
349   }
350 
351   return Status::OK();
352 }
353 
UpdateContext(const UpdateContextRequest * request,UpdateContextResponse * response)354 Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
355                                        UpdateContextResponse* response) {
356   // make sure env_ , env_->rendezvous_mgr available
357   if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
358     return tensorflow::errors::Internal(
359         "invalid eager env_ or env_->rendezvous_mgr.");
360   }
361 
362   // Find the context to update by the requested context_id
363   ServerContext* server_context = nullptr;
364   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &server_context));
365   core::ScopedUnref context_unref(server_context);
366 
367   tensorflow::EagerContext* ctx = server_context->Context();
368   if (request->context_view_id() != ctx->GetContextViewId() + 1) {
369     return errors::InvalidArgument(
370         "EagerService:UpdateContext failed. Context id: <",
371         request->context_id(), "> currently at view #", ctx->GetContextViewId(),
372         " but received update request at view #", request->context_view_id(),
373         ". View id should only be continuously incremented.");
374   }
375   if (request->cluster_device_attributes_size() == 0) {
376     // In this case, the client indicates that the updated `server_def` and
377     // device info is irrelevant to this worker, since it is not connected to
378     // the updated ones (likely due to device filter settings). The worker
379     // simply needs to update view ID and does not update other internal state.
380     ctx->IncrementContextViewId();
381     VLOG(1) << "Processing simplified UpdateContextRequest on "
382             << ctx->HostCPU()->name();
383     return Status::OK();
384   }
385 
386   auto session_name =
387       tensorflow::strings::StrCat("eager_", request->context_id());
388 
389   TF_RETURN_IF_ERROR(env_->session_mgr->UpdateSession(
390       session_name, request->server_def(), request->cluster_device_attributes(),
391       true));
392 
393   std::shared_ptr<WorkerSession> worker_session;
394   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
395       session_name, &worker_session));
396 
397   const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
398 
399   std::vector<string> remote_workers;
400   worker_session->worker_cache()->ListWorkers(&remote_workers);
401   remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
402                                    worker_session->worker_name()),
403                        remote_workers.end());
404   VLOG(1) << "On existing server " << worker_session->worker_name()
405           << " updating remote workers";
406   if (VLOG_IS_ON(2)) {
407     for (const string& rw : remote_workers) {
408       VLOG(2) << "Remote worker " << rw;
409     }
410   }
411 
412   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
413   TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
414       &remote_eager_workers));
415 
416   ctx->ClearCachesAndThreadExecutors();
417   Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
418                                      remote_workers, request->context_id());
419   if (!s.ok()) {
420     VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
421     return s;
422   }
423 
424   std::vector<DeviceAttributes> device_attributes;
425   device_mgr->ListDeviceAttributes(&device_attributes);
426 
427   for (const auto& da : device_attributes) {
428     *response->add_device_attributes() = da;
429   }
430 
431   return Status::OK();
432 }
433 
CreateMasterContext(const tensorflow::uint64 context_id,EagerContext * context)434 Status EagerServiceImpl::CreateMasterContext(
435     const tensorflow::uint64 context_id, EagerContext* context) {
436   {
437     mutex_lock l(contexts_mu_);
438     auto iter = contexts_.find(context_id);
439     if (iter != contexts_.end()) {
440       return errors::InvalidArgument(
441           "EagerService:CreateMasterContext failed. ", "Context id: <",
442           context_id, "> already exists.");
443     }
444   }
445   ServerContext* server_context =
446       ServerContext::CreateMasterContext(context, env_);
447   mutex_lock l(contexts_mu_);
448   contexts_.emplace(context_id, server_context);
449   return Status::OK();
450 }
451 
RunComponentFunction(CallOptions * call_opts,const RunComponentFunctionRequest * request,RunComponentFunctionResponse * response,StatusCallback done)452 void EagerServiceImpl::RunComponentFunction(
453     CallOptions* call_opts, const RunComponentFunctionRequest* request,
454     RunComponentFunctionResponse* response, StatusCallback done) {
455   ServerContext* context = nullptr;
456   Status s = GetServerContext(request->context_id(), &context);
457   if (!s.ok()) {
458     done(s);
459     return;
460   }
461   core::ScopedUnref context_unref(context);
462 
463   auto& operation = request->operation();
464   // This codepath should only be triggered for executing component function
465   if (!operation.is_function() || !operation.is_component_function()) {
466     done(errors::Internal(
467         "RunComponentFunction request can only be used to execute "
468         "component functions."));
469     return;
470   }
471 
472   EagerContext* eager_context = context->Context();
473   EagerExecutor* eager_executor = &eager_context->Executor();
474 
475   EagerOperation* op = new EagerOperation(eager_context);
476   int* num_retvals = new int(0);
477   s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
478                                      op, num_retvals);
479   if (!s.ok()) {
480     done(s);
481     return;
482   }
483   if (!op->IsLocal()) {
484     done(errors::Internal(
485         "Received RunComponentFunction request with remote function device. "));
486     return;
487   }
488   s = op->SetAttrBool("is_component_function", true);
489   if (!s.ok()) {
490     done(errors::Internal("Error setting is_component_function attribute: ",
491                           s.error_message()));
492     return;
493   }
494 
495   auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
496   VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
497           << operation.id();
498   std::vector<int32> output_nums;
499   for (const int32_t output_num : request->output_num()) {
500     output_nums.push_back(output_num);
501   }
502 
503   auto cm = std::make_shared<CancellationManager>();
504   op->SetCancellationManager(cm.get());
505   call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
506 
507   context->Ref();
508   EagerLocalExecuteAsync(
509       op, retvals->data(), num_retvals,
510       [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
511        call_opts, response, eager_context, context,
512        done = std::move(done)](const Status& status) {
513         call_opts->ClearCancelCallback();
514         auto wrapped_done = [&](const Status& status) {
515           context->Unref();
516           done(status);
517           delete op;
518           delete num_retvals;
519           delete retvals;
520         };
521         if (!status.ok()) {
522           wrapped_done(status);
523           return;
524         }
525         // The output device of a component function is the component device
526         // which is known on the default device of it's parent function.
527         wrapped_done(AddOpRetvalsToResponse(
528             eager_context, op_id, *num_retvals, output_nums, retvals->data(),
529             [response] { return response->add_tensor(); },
530             [response] { return response->add_shape(); }));
531       });
532 }
533 
ExecuteOp(CallOptions * call_opts,const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,QueueResponse * queue_response)534 Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
535                                    const Operation& operation,
536                                    EagerContext* eager_context,
537                                    EagerExecutor* eager_executor,
538                                    QueueResponse* queue_response) {
539   tensorflow::EagerOperation op(eager_context);
540   int num_retvals = 0;
541   TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
542       operation, eager_context, eager_executor, &op, &num_retvals));
543 
544   auto cm = std::make_shared<CancellationManager>();
545   if (call_opts) {
546     op.SetCancellationManager(cm.get());
547     call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
548   }
549 
550   absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
551   VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
552   TF_RETURN_IF_ERROR(op.Execute(
553       absl::MakeSpan(
554           reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
555           num_retvals),
556       &num_retvals));
557 
558   std::function<string*()> add_device_fn = nullptr;
559   // Send the output devices of a function back to let a client know where the
560   // outputs are. For a primitive op, an output devics is the op device which is
561   // known on a client.
562   if (op.is_function()) {
563     add_device_fn = [queue_response] { return queue_response->add_device(); };
564   }
565 
566   return AddOpRetvalsToResponse(
567       eager_context, operation.id(), num_retvals, /*output_nums=*/{},
568       retvals.data(), [queue_response] { return queue_response->add_tensor(); },
569       [queue_response] { return queue_response->add_shape(); },
570       std::move(add_device_fn));
571 }
572 
Enqueue(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,uint64 stream_id)573 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
574                                  const EnqueueRequest* request,
575                                  EnqueueResponse* response, uint64 stream_id) {
576   profiler::TraceMe activity(
577       [&] {
578         return absl::StrCat(
579             "EagerService:Enqueue#debug_str=", request->DebugString(), "#");
580       },
581       profiler::TraceMeLevel::kInfo);
582   ServerContext* context = nullptr;
583   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
584   core::ScopedUnref context_unref(context);
585 
586   EagerExecutor& executor =
587       stream_id == kInvalidStreamId
588           ? context->Context()->Executor()
589           : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream(
590                 stream_id);
591   Status s;
592   for (const auto& item : request->queue()) {
593     auto* queue_response = response->add_queue_response();
594     if (item.has_operation()) {
595       s = ExecuteOp(call_opts, item.operation(), context->Context(), &executor,
596                     queue_response);
597     } else if (item.has_handle_to_decref()) {
598       auto handle_to_decref = absl::make_unique<RemoteTensorHandleInternal>(
599           item.handle_to_decref());
600       auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
601           context, std::move(handle_to_decref));
602       s = context->Context()->Executor().AddOrExecute(std::move(node));
603     } else if (item.has_send_tensor()) {
604       s = SendTensor(item.send_tensor(), context->Context());
605     } else if (item.has_send_packed_handle()) {
606       s = SendPackedHandle(item.send_packed_handle(), context->Context());
607     } else if (item.has_register_function()) {
608       s = RegisterFunction(item.register_function(), context->Context());
609     } else if (item.has_cleanup_function()) {
610       s = CleanupFunction(item.cleanup_function());
611     } else {
612       DCHECK(item.has_sync_remote_executor_for_stream());
613       s = executor.WaitForAllPendingNodes();
614     }
615 
616     if (!s.ok()) {
617       if (stream_id != kInvalidStreamId) {
618         context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id);
619       }
620       return s;
621     }
622   }
623 
624   return Status::OK();
625 }
626 
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)627 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
628                                        WaitQueueDoneResponse* response) {
629   ServerContext* context = nullptr;
630   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
631   core::ScopedUnref context_unref(context);
632 
633   if (request->op_id_size() > 0) {
634     return errors::Unimplemented(
635         "EagerServiceImpl::WaitQueueDone is not "
636         "implemented for particular op IDs.");
637   }
638   return context->Context()->Executor().WaitForAllPendingNodes();
639 }
640 
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)641 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
642                                    KeepAliveResponse* response) {
643   ServerContext* context = nullptr;
644   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
645   core::ScopedUnref context_unref(context);
646 
647   tensorflow::EagerContext* ctx = context->Context();
648   response->set_context_view_id(ctx->GetContextViewId());
649   return Status::OK();
650 }
651 
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)652 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
653                                       CloseContextResponse* response) {
654   VLOG(1) << "Executing EagerService::CloseContext for context "
655           << request->context_id();
656   ServerContext* context = nullptr;
657   if (!GetServerContext(request->context_id(), &context).ok()) {
658     // Swallow the error here.
659     return Status::OK();
660   }
661   core::ScopedUnref context_unref(context);
662 
663   if (request->context_view_id() < context->Context()->GetContextViewId()) {
664     // Swallow the error here.
665     LOG(INFO) << "Ignoring CloseContext request with a stale context_view_id "
666               << request->context_view_id() << "  for context_id "
667               << request->context_id() << ". The current context_view_id is "
668               << context->Context()->GetContextViewId() << ".";
669     return Status::OK();
670   }
671 
672   mutex_lock l(contexts_mu_);
673   contexts_.erase(request->context_id());
674 
675   // GetServerContext returns a newly Reffed copy of ServerContext, which is
676   // unreffed by context_unref. Additionally, we need to unref it one time since
677   // we are releasing it from the map.
678   context->Unref();
679 
680   return Status::OK();
681 }
682 
RegisterFunction(const RegisterFunctionOp & register_function,EagerContext * eager_context)683 Status EagerServiceImpl::RegisterFunction(
684     const RegisterFunctionOp& register_function, EagerContext* eager_context) {
685   // If the function is a component of a multi-device function, we only need to
686   // register it locally.
687   return eager_context->AddFunctionDef(
688       register_function.function_def(), register_function.library(),
689       register_function.is_component_function());
690 }
691 
CleanupFunction(const CleanupFunctionOp & cleanup_function)692 Status EagerServiceImpl::CleanupFunction(
693     const CleanupFunctionOp& cleanup_function) {
694   env_->rendezvous_mgr->Cleanup(cleanup_function.step_id());
695   return Status::OK();
696 }
697 
SendTensor(const SendTensorOp & send_tensor,EagerContext * eager_context)698 Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
699                                     EagerContext* eager_context) {
700   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
701   for (const auto& tensor_proto : send_tensor.tensors()) {
702     Tensor tensor;
703     if (!tensor.FromProto(tensor_proto)) {
704       return errors::InvalidArgument("Unable to parse tensor proto");
705     }
706 
707     TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
708         std::move(tensor), nullptr, nullptr, eager_context);
709     TensorHandle* copied_handle = nullptr;
710     Device* device;
711     TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
712         send_tensor.device_name().c_str(), &device));
713     TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context,
714                                          &eager_context->Executor(), device,
715                                          false, &copied_handle));
716     tensors.push_back(copied_handle);
717     tensor_handle->Unref();
718   }
719 
720   eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id());
721 
722   return Status::OK();
723 }
724 
SendPackedHandle(const SendPackedHandleOp & send_packed_handle,EagerContext * eager_context)725 Status EagerServiceImpl::SendPackedHandle(
726     const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
727   if (send_packed_handle.handles().empty()) {
728     return errors::InvalidArgument("Handles should not be empty.");
729   }
730 
731   std::vector<tensorflow::TensorHandle*> handles;
732   handles.resize(send_packed_handle.handles_size());
733   for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
734     const auto& item = send_packed_handle.handles(i);
735     if (item.has_local_handle()) {
736       Tensor tensor;
737       if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
738         return errors::InvalidArgument(
739             "Invalid TensorProto: ",
740             item.local_handle().tensor().DebugString());
741       }
742       Device* op_device = nullptr;
743       TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
744           item.local_handle().device().c_str(), &op_device));
745       handles[i] = TensorHandle::CreateLocalHandle(
746           std::move(tensor), /*d=*/nullptr, op_device, eager_context);
747     } else {
748       TF_RETURN_IF_ERROR(
749           eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
750               item.remote_handle(), &handles[i]));
751     }
752   }
753 
754   tensorflow::TensorHandle* packed_handle = nullptr;
755   std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
756   // Create a unshaped packed TensorHandle.
757   TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
758       std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
759       send_packed_handle.device_name(), eager_context, &packed_handle));
760 
761   for (auto* h : handles) {
762     // Unref handle since it has a ref in the packed handle now.
763     h->Unref();
764   }
765 
766   eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
767                                                   send_packed_handle.op_id());
768   return Status::OK();
769 }
770 
GetServerContext(uint64 context_id,ServerContext ** server_context)771 tensorflow::Status EagerServiceImpl::GetServerContext(
772     uint64 context_id, ServerContext** server_context) {
773   tf_shared_lock l(contexts_mu_);
774   auto iter = contexts_.find(context_id);
775   if (iter == contexts_.end()) {
776     *server_context = nullptr;
777     return errors::Unavailable(strings::Printf(
778         "Unable to find a context_id matching the specified one "
779         "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
780         static_cast<unsigned long long>(context_id)));
781   }
782 
783   *server_context = iter->second;
784   (*server_context)->Ref();
785 
786   (*server_context)->RecordAccess();
787 
788   return Status::OK();
789 }
790 
791 }  // namespace eager
792 }  // namespace tensorflow
793