• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17 
18 #include "absl/container/fixed_array.h"
19 #include "absl/memory/memory.h"
20 #include "absl/types/optional.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/eager/context.h"
25 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
26 #include "tensorflow/core/common_runtime/eager/execute.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/process_util.h"
29 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
30 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
31 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
32 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
33 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
34 #include "tensorflow/core/distributed_runtime/server_lib.h"
35 #include "tensorflow/core/distributed_runtime/session_mgr.h"
36 #include "tensorflow/core/distributed_runtime/worker_cache.h"
37 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
38 #include "tensorflow/core/distributed_runtime/worker_env.h"
39 #include "tensorflow/core/framework/rendezvous.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/gtl/cleanup.h"
42 #include "tensorflow/core/lib/random/random.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/lib/strings/stringprintf.h"
45 #include "tensorflow/core/platform/cpu_info.h"
46 #include "tensorflow/core/platform/env.h"
47 #include "tensorflow/core/platform/errors.h"
48 #include "tensorflow/core/platform/host_info.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/profiler/lib/traceme.h"
52 #include "tensorflow/core/protobuf/error_codes.pb.h"
53 
54 namespace tensorflow {
55 namespace eager {
56 
57 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)58 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
59                      const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
60                      int* num_retvals) {
61   const tensorflow::OpRegistrationData* op_reg_data = nullptr;
62   auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
63   if (errors::IsNotFound(status)) {
64     status = context->FindFunctionOpData(op_name, &op_reg_data);
65   }
66   TF_RETURN_IF_ERROR(status);
67 
68   const tensorflow::OpDef& op_def = op_reg_data->op_def;
69 
70   for (const auto& output_arg : op_def.output_arg()) {
71     if (!output_arg.number_attr().empty()) {
72       auto iter = attrs.find(output_arg.number_attr());
73       if (iter == attrs.end()) {
74         return errors::InvalidArgument("Unable to find number_attr ",
75                                        output_arg.number_attr(),
76                                        " for Op: ", op_name);
77       }
78       *num_retvals += iter->second.i();
79     } else if (!output_arg.type_list_attr().empty()) {
80       auto iter = attrs.find(output_arg.type_list_attr());
81       if (iter == attrs.end()) {
82         return errors::InvalidArgument("Unable to find type_list_attr ",
83                                        output_arg.type_list_attr(),
84                                        " for Op: ", op_name);
85       }
86       *num_retvals += iter->second.list().type_size();
87     } else {
88       *num_retvals += 1;
89     }
90   }
91 
92   return Status::OK();
93 }
94 
GetEagerOperationAndNumRetvals(const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,EagerOperation * eager_op,int * num_retvals)95 Status GetEagerOperationAndNumRetvals(const Operation& operation,
96                                       EagerContext* eager_context,
97                                       EagerExecutor* eager_executor,
98                                       EagerOperation* eager_op,
99                                       int* num_retvals) {
100   const char* name = operation.name().c_str();  // Shorthand
101   absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
102       absl::nullopt;
103   if (operation.is_function()) {
104     if (operation.is_component_function()) {
105       remote_func_params = {operation.id(), operation.func_step_id()};
106     } else {
107       remote_func_params = {operation.id(), absl::nullopt};
108     }
109   }
110   TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
111                                      eager_executor, remote_func_params));
112 
113   {
114     profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
115                                profiler::TraceMeLevel::kVerbose);
116     for (const auto& input : operation.op_inputs()) {
117       tensorflow::TensorHandle* handle;
118       if (input.has_remote_handle()) {
119         TF_RETURN_IF_ERROR(
120             eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
121                 input.remote_handle(), &handle));
122         TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
123       } else {
124         Tensor tensor;
125         if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
126           return errors::InvalidArgument("Invalid TensorProto: ",
127                                          input.tensor().DebugString());
128         } else {
129           handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
130                                                    nullptr, eager_context);
131           TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
132         }
133       }
134       // Unref handle since it has a ref as an input now.
135       handle->Unref();
136     }
137   }
138 
139   for (const auto& attr : operation.attrs()) {
140     eager_op->MutableAttrs()->Set(attr.first, attr.second);
141   }
142 
143   // TODO(nareshmodi): Consider caching this.
144   return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
145                        num_retvals);
146 }
147 
TensorHandleProto(TensorHandle * handle,TensorProto * proto)148 Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
149   const tensorflow::Tensor* t = nullptr;
150   TF_RETURN_IF_ERROR(handle->Tensor(&t));
151   t->AsProtoTensorContent(proto);
152   return Status::OK();
153 }
154 
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)155 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
156   const tensorflow::Tensor* t = nullptr;
157 
158   // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
159   if (handle->Type() == TensorHandle::LOCAL) {
160     TF_RETURN_IF_ERROR(handle->Tensor(&t));
161 
162     t->shape().AsProto(proto);
163   } else {
164     TensorShape shape;
165     TF_RETURN_IF_ERROR(handle->Shape(&shape));
166     shape.AsProto(proto);
167   }
168 
169   return Status::OK();
170 }
171 
AddOpRetvalsToResponse(EagerContext * eager_context,int op_id,int num_retvals,const std::vector<int32> & output_nums,TensorHandle ** retvals,std::function<TensorProto * ()> add_tensor_proto_fn,std::function<TensorShapeProto * ()> add_shape_proto_fn,std::function<string * ()> add_device_fn=nullptr)172 Status AddOpRetvalsToResponse(
173     EagerContext* eager_context, int op_id, int num_retvals,
174     const std::vector<int32>& output_nums, TensorHandle** retvals,
175     std::function<TensorProto*()> add_tensor_proto_fn,
176     std::function<TensorShapeProto*()> add_shape_proto_fn,
177     std::function<string*()> add_device_fn = nullptr) {
178   if (op_id == kInvalidRemoteOpId) {
179     // Copy the output tensors back along with the response, since the op id
180     // is invalid which cannot be added to RemoteMgr.
181     for (int i = 0; i < num_retvals; i++) {
182       TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
183       retvals[i]->Unref();
184     }
185   } else {
186     for (int i = 0; i < num_retvals; i++) {
187       TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
188       if (add_device_fn) {
189         Device* device = retvals[i]->device();
190         *add_device_fn() = device ? device->name() : "";
191       }
192       if (retvals[i]->Type() == TensorHandle::REMOTE) {
193         retvals[i]->Unref();
194       } else {
195         const int output_num = output_nums.empty() ? i : output_nums.at(i);
196         eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
197                                                        output_num);
198       }
199     }
200   }
201   return Status::OK();
202 }
203 }  // namespace
204 
// Creates a worker-side eager context for `request->context_id()` and
// registers it in `contexts_`.
//
// Outline:
//   1. Under `contexts_mu_`, reject the request if a context with this id
//      already exists and the request's view id is older than the existing
//      one; otherwise drop the existing entry so it can be recreated.
//   2. Create a worker session named "eager_<context_id>" and look up the
//      rendezvous for the context id.
//   3. Build an EagerContext and initialize it as a remote worker.
//   4. Fill `response` with the local device attributes and insert a new
//      ServerContext into `contexts_`.
//
// Returns the first error encountered, or OK.
//
// NOTE(review): several early returns happen after `CreateSession` and after
// `rendezvous_mgr->Find` but before the EagerContext exists. It looks like
// neither the session nor `r` is released on those paths - confirm whether
// that can leak.
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
                                       CreateContextResponse* response) {
  {
    mutex_lock l(contexts_mu_);
    auto context_it = contexts_.find(request->context_id());
    if (context_it != contexts_.end()) {
      if (request->context_view_id() <
          context_it->second->Context()->GetContextViewId()) {
        return errors::InvalidArgument("EagerService:CreateContext failed. ",
                                       "Context id: <", request->context_id(),
                                       "> already exists.");
      } else {
        // For existing context with a stale context_view_id, close the old one
        // and recreate with new view id. This is likely due to the worker
        // disconnected and then reconnected after one or more cluster updates.
        context_it->second->Unref();
        contexts_.erase(context_it);
      }
    }
  }
  // make sure env_ , env_->rendezvous_mgr available
  if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
    return tensorflow::errors::Internal(
        "invalid eager env_ or env_->rendezvous_mgr.");
  }

  // The context id doubles as the key for the context-wide rendezvous.
  auto* r = env_->rendezvous_mgr->Find(request->context_id());
  auto session_name =
      tensorflow::strings::StrCat("eager_", request->context_id());
  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
            << "/task:" << request->server_def().task_index();
    for (const auto& da : request->cluster_device_attributes()) {
      VLOG(2) << "    " << da.name();
    }
  }
  TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
      session_name, request->server_def(), request->cluster_device_attributes(),
      true));
  int64 context_id = request->context_id();
  // Passed to InitializeRemoteWorker below: cleans up the rendezvous for this
  // context id and deletes the worker session, logging (not failing) if the
  // session cannot be deleted.
  std::function<void()> session_destroyer = [this, context_id, session_name]() {
    env_->rendezvous_mgr->Cleanup(context_id);
    auto s = env_->session_mgr->DeleteSession(session_name);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to destroy worker session '" << session_name
                   << "' due to " << s.error_message();
    }
  };

  std::shared_ptr<WorkerSession> worker_session;
  TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
      session_name, &worker_session));

  const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();

  // Initialize remote tensor communication based on worker session.
  TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

  // Produces a rendezvous for a given step id, initialized against the same
  // worker session. Initialization errors are deliberately ignored here.
  std::function<Rendezvous*(const int64)> rendezvous_creator =
      [worker_session, this](const int64 step_id) {
        auto* r = env_->rendezvous_mgr->Find(step_id);
        r->Initialize(worker_session.get()).IgnoreError();
        return r;
      };

  LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
            << " eager service context with rendezvous_id on host "
            << port::Hostname() << " " << worker_session->worker_name();
  SessionOptions opts;
  opts.config = request->server_def().default_session_config();
  tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
      opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      request->async(), device_mgr, false, r, worker_session->cluster_flr());
  // Ownership will be transferred to the ServerContext, or else in an error
  // case ctx will be deleted by this unref.
  core::ScopedUnref unref_ctx(ctx);

  // Every worker known to the session except this one is a remote worker.
  std::vector<string> remote_workers;
  worker_session->worker_cache()->ListWorkers(&remote_workers);
  remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
                                   worker_session->worker_name()),
                       remote_workers.end());

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
      &remote_eager_workers));
  DistributedFunctionLibraryRuntime* cluster_flr =
      eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());

  auto remote_mgr =
      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
  Status s = ctx->InitializeRemoteWorker(
      std::move(remote_eager_workers), worker_session->remote_device_mgr(),
      remote_workers, request->context_id(), request->context_view_id(),
      std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr),
      std::move(session_destroyer));
  if (!s.ok()) {
    VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
            << s.ToString();
    return s;
  }

  // Report this worker's devices back to the caller.
  std::vector<DeviceAttributes> device_attributes;
  device_mgr->ListDeviceAttributes(&device_attributes);

  for (const auto& da : device_attributes) {
    *response->add_device_attributes() = da;
  }
  {
    // The lock was released after the first check, so the id must be checked
    // again before inserting.
    mutex_lock l(contexts_mu_);
    auto context_it = contexts_.find(request->context_id());
    if (context_it != contexts_.end()) {
      return errors::InvalidArgument("EagerService:CreateContext failed. ",
                                     "Context id: <", request->context_id(),
                                     "> already exists.");
    }
    contexts_.emplace(request->context_id(),
                      new ServerContext(ctx, request->keep_alive_secs(), env_));
  }

  return Status::OK();
}
327 
UpdateContext(const UpdateContextRequest * request,UpdateContextResponse * response)328 Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
329                                        UpdateContextResponse* response) {
330   // make sure env_ , env_->rendezvous_mgr available
331   if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
332     return tensorflow::errors::Internal(
333         "invalid eager env_ or env_->rendezvous_mgr.");
334   }
335 
336   // Find the context to update by the requested context_id
337   ServerContext* server_context = nullptr;
338   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &server_context));
339   core::ScopedUnref context_unref(server_context);
340 
341   tensorflow::EagerContext* ctx = server_context->Context();
342   if (request->context_view_id() != ctx->GetContextViewId() + 1) {
343     return errors::InvalidArgument(
344         "EagerService:UpdateContext failed. Context id: <",
345         request->context_id(), "> currently at view #", ctx->GetContextViewId(),
346         " but received update request at view #", request->context_view_id(),
347         ". View id should only be continuously incremented.");
348   }
349   if (request->cluster_device_attributes_size() == 0) {
350     // In this case, the client indicates that the updated `server_def` and
351     // device info is irrelevant to this worker, since it is not connected to
352     // the updated ones (likely due to device filter settings). The worker
353     // simply needs to update view ID and does not update other internal state.
354     ctx->IncrementContextViewId();
355     VLOG(1) << "Processing simplified UpdateContextRequest on "
356             << ctx->HostCPU()->name();
357     return Status::OK();
358   }
359   // TODO(b/143914772): Potential memory leak if rendezvous has pending
360   // tensors for removed / replaced workers.
361 
362   auto session_name =
363       tensorflow::strings::StrCat("eager_", request->context_id());
364 
365   TF_RETURN_IF_ERROR(env_->session_mgr->UpdateSession(
366       session_name, request->server_def(), request->cluster_device_attributes(),
367       true));
368 
369   std::shared_ptr<WorkerSession> worker_session;
370   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
371       session_name, &worker_session));
372 
373   const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
374 
375   std::vector<string> remote_workers;
376   worker_session->worker_cache()->ListWorkers(&remote_workers);
377   remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
378                                    worker_session->worker_name()),
379                        remote_workers.end());
380   VLOG(1) << "On existing server " << worker_session->worker_name()
381           << " updating remote workers";
382   if (VLOG_IS_ON(2)) {
383     for (const string& rw : remote_workers) {
384       VLOG(2) << "Remote worker " << rw;
385     }
386   }
387 
388   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
389   TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
390       &remote_eager_workers));
391 
392   ctx->ClearCachesAndThreadExecutors();
393   Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
394                                      remote_workers, request->context_id());
395   if (!s.ok()) {
396     VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
397     return s;
398   }
399 
400   std::vector<DeviceAttributes> device_attributes;
401   device_mgr->ListDeviceAttributes(&device_attributes);
402 
403   for (const auto& da : device_attributes) {
404     *response->add_device_attributes() = da;
405   }
406 
407   return Status::OK();
408 }
409 
CreateMasterContext(const tensorflow::uint64 context_id,EagerContext * context)410 Status EagerServiceImpl::CreateMasterContext(
411     const tensorflow::uint64 context_id, EagerContext* context) {
412   {
413     mutex_lock l(contexts_mu_);
414     auto iter = contexts_.find(context_id);
415     if (iter != contexts_.end()) {
416       return errors::InvalidArgument(
417           "EagerService:CreateMasterContext failed. ", "Context id: <",
418           context_id, "> already exists.");
419     }
420   }
421   ServerContext* server_context =
422       ServerContext::CreateMasterContext(context, env_);
423   mutex_lock l(contexts_mu_);
424   contexts_.emplace(context_id, server_context);
425   return Status::OK();
426 }
427 
RunComponentFunction(CallOptions * call_opts,const RunComponentFunctionRequest * request,RunComponentFunctionResponse * response,StatusCallback done)428 void EagerServiceImpl::RunComponentFunction(
429     CallOptions* call_opts, const RunComponentFunctionRequest* request,
430     RunComponentFunctionResponse* response, StatusCallback done) {
431   ServerContext* context = nullptr;
432   Status s = GetServerContext(request->context_id(), &context);
433   if (!s.ok()) {
434     done(s);
435     return;
436   }
437   core::ScopedUnref context_unref(context);
438 
439   auto& operation = request->operation();
440   // This codepath should only be triggered for executing component function
441   if (!operation.is_function() || !operation.is_component_function()) {
442     done(errors::Internal(
443         "RunComponentFunction request can only be used to execute "
444         "component functions."));
445     return;
446   }
447 
448   EagerContext* eager_context = context->Context();
449   EagerExecutor* eager_executor = &eager_context->Executor();
450 
451   EagerOperation* op = new EagerOperation(eager_context);
452   int* num_retvals = new int(0);
453   s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
454                                      op, num_retvals);
455   if (!s.ok()) {
456     done(s);
457     return;
458   }
459   if (!op->IsLocal()) {
460     done(errors::Internal(
461         "Received RunComponentFunction request with remote function device. "));
462     return;
463   }
464   s = op->SetAttrBool("is_component_function", true);
465   if (!s.ok()) {
466     done(errors::Internal("Error setting is_component_function attribute: ",
467                           s.error_message()));
468     return;
469   }
470 
471   auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
472   VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
473           << operation.id();
474   std::vector<int32> output_nums;
475   for (const int32 output_num : request->output_num()) {
476     output_nums.push_back(output_num);
477   }
478 
479   auto cm = std::make_shared<CancellationManager>();
480   op->SetCancellationManager(cm.get());
481   call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
482 
483   context->Ref();
484   EagerLocalExecuteAsync(
485       op, retvals->data(), num_retvals,
486       [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
487        call_opts, response, eager_context, context,
488        done = std::move(done)](const Status& status) {
489         call_opts->ClearCancelCallback();
490         auto wrapped_done = [&](const Status& status) {
491           context->Unref();
492           done(status);
493           delete op;
494           delete num_retvals;
495           delete retvals;
496         };
497         if (!status.ok()) {
498           wrapped_done(status);
499           return;
500         }
501         // The output device of a component function is the component device
502         // which is known on the default device of it's parent function.
503         wrapped_done(AddOpRetvalsToResponse(
504             eager_context, op_id, *num_retvals, output_nums, retvals->data(),
505             [response] { return response->add_tensor(); },
506             [response] { return response->add_shape(); }));
507       });
508 }
509 
ExecuteOp(CallOptions * call_opts,const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,QueueResponse * queue_response)510 Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
511                                    const Operation& operation,
512                                    EagerContext* eager_context,
513                                    EagerExecutor* eager_executor,
514                                    QueueResponse* queue_response) {
515   tensorflow::EagerOperation op(eager_context);
516   int num_retvals = 0;
517   TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
518       operation, eager_context, eager_executor, &op, &num_retvals));
519 
520   auto cm = std::make_shared<CancellationManager>();
521   if (call_opts) {
522     op.SetCancellationManager(cm.get());
523     call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
524   }
525 
526   absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
527   VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
528   TF_RETURN_IF_ERROR(op.Execute(
529       absl::MakeSpan(
530           reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
531           num_retvals),
532       &num_retvals));
533 
534   std::function<string*()> add_device_fn = nullptr;
535   // Send the output devices of a function back to let a client know where the
536   // outputs are. For a primitive op, an output devics is the op device which is
537   // known on a client.
538   if (op.is_function()) {
539     add_device_fn = [queue_response] { return queue_response->add_device(); };
540   }
541 
542   return AddOpRetvalsToResponse(
543       eager_context, operation.id(), num_retvals, /*output_nums=*/{},
544       retvals.data(), [queue_response] { return queue_response->add_tensor(); },
545       [queue_response] { return queue_response->add_shape(); },
546       std::move(add_device_fn));
547 }
548 
Enqueue(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,uint64 stream_id)549 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
550                                  const EnqueueRequest* request,
551                                  EnqueueResponse* response, uint64 stream_id) {
552   profiler::TraceMe activity(
553       [&] {
554         return absl::StrCat(
555             "EagerService:Enqueue#debug_str=", request->DebugString(), "#");
556       },
557       profiler::TraceMeLevel::kInfo);
558   ServerContext* context = nullptr;
559   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
560   core::ScopedUnref context_unref(context);
561 
562   EagerExecutor& executor =
563       stream_id == kInvalidStreamId
564           ? context->Context()->Executor()
565           : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream(
566                 stream_id);
567   Status s;
568   for (const auto& item : request->queue()) {
569     auto* queue_response = response->add_queue_response();
570     if (item.has_operation()) {
571       s = ExecuteOp(call_opts, item.operation(), context->Context(), &executor,
572                     queue_response);
573     } else if (item.has_handle_to_decref()) {
574       auto handle_to_decref = absl::make_unique<RemoteTensorHandleInternal>(
575           item.handle_to_decref());
576       auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
577           context, std::move(handle_to_decref));
578       s = context->Context()->Executor().AddOrExecute(std::move(node));
579     } else if (item.has_send_tensor()) {
580       s = SendTensor(item.send_tensor(), context->Context());
581     } else if (item.has_send_packed_handle()) {
582       s = SendPackedHandle(item.send_packed_handle(), context->Context());
583     } else if (item.has_register_function()) {
584       s = RegisterFunction(item.register_function(), context->Context());
585     } else if (item.has_cleanup_function()) {
586       s = CleanupFunction(item.cleanup_function());
587     } else {
588       DCHECK(item.has_sync_remote_executor_for_stream());
589       s = executor.WaitForAllPendingNodes();
590     }
591 
592     if (!s.ok()) {
593       if (stream_id != kInvalidStreamId) {
594         // TODO(b/138847548): Cleanup the executor when StreamCall is deleted.
595         context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id);
596       }
597       return s;
598     }
599   }
600 
601   return Status::OK();
602 }
603 
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)604 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
605                                        WaitQueueDoneResponse* response) {
606   ServerContext* context = nullptr;
607   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
608   core::ScopedUnref context_unref(context);
609 
610   if (request->op_id_size() > 0) {
611     return errors::Unimplemented(
612         "EagerServiceImpl::WaitQueueDone is not "
613         "implemented for particular op IDs.");
614   }
615   return context->Context()->Executor().WaitForAllPendingNodes();
616 }
617 
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)618 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
619                                    KeepAliveResponse* response) {
620   ServerContext* context = nullptr;
621   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
622   core::ScopedUnref context_unref(context);
623 
624   tensorflow::EagerContext* ctx = context->Context();
625   response->set_context_view_id(ctx->GetContextViewId());
626   return Status::OK();
627 }
628 
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)629 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
630                                       CloseContextResponse* response) {
631   VLOG(1) << "Executing EagerService::CloseContext for context "
632           << request->context_id();
633   ServerContext* context = nullptr;
634   if (!GetServerContext(request->context_id(), &context).ok()) {
635     // Swallow the error here.
636     return Status::OK();
637   }
638   core::ScopedUnref context_unref(context);
639 
640   if (request->context_view_id() < context->Context()->GetContextViewId()) {
641     // Swallow the error here.
642     LOG(INFO) << "Ignoring CloseContext request with a stale context_view_id "
643               << request->context_view_id() << "  for context_id "
644               << request->context_id() << ". The current context_view_id is "
645               << context->Context()->GetContextViewId() << ".";
646     return Status::OK();
647   }
648 
649   mutex_lock l(contexts_mu_);
650   contexts_.erase(request->context_id());
651 
652   // GetServerContext returns a newly Reffed copy of ServerContext, which is
653   // unreffed by context_unref. Additionally, we need to unref it one time since
654   // we are releasing it from the map.
655   context->Unref();
656 
657   return Status::OK();
658 }
659 
RegisterFunction(const RegisterFunctionOp & register_function,EagerContext * eager_context)660 Status EagerServiceImpl::RegisterFunction(
661     const RegisterFunctionOp& register_function, EagerContext* eager_context) {
662   // If the function is a component of a multi-device function, we only need to
663   // register it locally.
664   return eager_context->AddFunctionDef(
665       register_function.function_def(), register_function.library(),
666       register_function.is_component_function());
667 }
668 
CleanupFunction(const CleanupFunctionOp & cleanup_function)669 Status EagerServiceImpl::CleanupFunction(
670     const CleanupFunctionOp& cleanup_function) {
671   env_->rendezvous_mgr->Cleanup(cleanup_function.step_id());
672   return Status::OK();
673 }
674 
SendTensor(const SendTensorOp & send_tensor,EagerContext * eager_context)675 Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
676                                     EagerContext* eager_context) {
677   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
678   for (const auto& tensor_proto : send_tensor.tensors()) {
679     Tensor tensor;
680     if (!tensor.FromProto(tensor_proto)) {
681       return errors::InvalidArgument("Unable to parse tensor proto");
682     }
683 
684     TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
685         std::move(tensor), nullptr, nullptr, eager_context);
686     TensorHandle* copied_handle = nullptr;
687     Device* device;
688     TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
689         send_tensor.device_name().c_str(), &device));
690     TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context,
691                                          &eager_context->Executor(), device,
692                                          false, &copied_handle));
693     tensors.push_back(copied_handle);
694     tensor_handle->Unref();
695   }
696 
697   eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id());
698 
699   return Status::OK();
700 }
701 
SendPackedHandle(const SendPackedHandleOp & send_packed_handle,EagerContext * eager_context)702 Status EagerServiceImpl::SendPackedHandle(
703     const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
704   if (send_packed_handle.handles().empty()) {
705     return errors::InvalidArgument("Handles should not be empty.");
706   }
707 
708   std::vector<tensorflow::TensorHandle*> handles;
709   handles.resize(send_packed_handle.handles_size());
710   for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
711     const auto& item = send_packed_handle.handles(i);
712     if (item.has_local_handle()) {
713       Tensor tensor;
714       if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
715         return errors::InvalidArgument(
716             "Invalid TensorProto: ",
717             item.local_handle().tensor().DebugString());
718       }
719       Device* op_device = nullptr;
720       TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
721           item.local_handle().device().c_str(), &op_device));
722       handles[i] = TensorHandle::CreateLocalHandle(
723           std::move(tensor), /*d=*/nullptr, op_device, eager_context);
724     } else {
725       TF_RETURN_IF_ERROR(
726           eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
727               item.remote_handle(), &handles[i]));
728     }
729   }
730 
731   tensorflow::TensorHandle* packed_handle = nullptr;
732   std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
733   // Create a unshaped packed TensorHandle.
734   TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
735       std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
736       send_packed_handle.device_name(), eager_context, &packed_handle));
737 
738   for (auto* h : handles) {
739     // Unref handle since it has a ref in the packed handle now.
740     h->Unref();
741   }
742 
743   eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
744                                                   send_packed_handle.op_id());
745   return Status::OK();
746 }
747 
GetServerContext(uint64 context_id,ServerContext ** server_context)748 tensorflow::Status EagerServiceImpl::GetServerContext(
749     uint64 context_id, ServerContext** server_context) {
750   tf_shared_lock l(contexts_mu_);
751   auto iter = contexts_.find(context_id);
752   if (iter == contexts_.end()) {
753     *server_context = nullptr;
754     return errors::Unavailable(strings::Printf(
755         "Unable to find a context_id matching the specified one "
756         "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
757         static_cast<unsigned long long>(context_id)));
758   }
759 
760   *server_context = iter->second;
761   (*server_context)->Ref();
762 
763   (*server_context)->RecordAccess();
764 
765   return Status::OK();
766 }
767 
768 }  // namespace eager
769 }  // namespace tensorflow
770