/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17
18 #include "absl/container/fixed_array.h"
19 #include "absl/memory/memory.h"
20 #include "absl/types/optional.h"
21 #include "tensorflow/c/c_api_internal.h"
22 #include "tensorflow/c/tf_status_helper.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/eager/context.h"
25 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
26 #include "tensorflow/core/common_runtime/eager/execute.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/process_util.h"
29 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
30 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
31 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
32 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
33 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
34 #include "tensorflow/core/distributed_runtime/server_lib.h"
35 #include "tensorflow/core/distributed_runtime/session_mgr.h"
36 #include "tensorflow/core/distributed_runtime/worker_cache.h"
37 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
38 #include "tensorflow/core/distributed_runtime/worker_env.h"
39 #include "tensorflow/core/framework/rendezvous.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/gtl/cleanup.h"
42 #include "tensorflow/core/lib/random/random.h"
43 #include "tensorflow/core/lib/strings/strcat.h"
44 #include "tensorflow/core/lib/strings/stringprintf.h"
45 #include "tensorflow/core/platform/cpu_info.h"
46 #include "tensorflow/core/platform/env.h"
47 #include "tensorflow/core/platform/errors.h"
48 #include "tensorflow/core/platform/host_info.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/refcount.h"
51 #include "tensorflow/core/profiler/lib/traceme.h"
52 #include "tensorflow/core/protobuf/error_codes.pb.h"
53
54 namespace tensorflow {
55 namespace eager {
56
57 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)58 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
59 const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
60 int* num_retvals) {
61 const tensorflow::OpRegistrationData* op_reg_data = nullptr;
62 auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
63 if (errors::IsNotFound(status)) {
64 status = context->FindFunctionOpData(op_name, &op_reg_data);
65 }
66 TF_RETURN_IF_ERROR(status);
67
68 const tensorflow::OpDef& op_def = op_reg_data->op_def;
69
70 for (const auto& output_arg : op_def.output_arg()) {
71 if (!output_arg.number_attr().empty()) {
72 auto iter = attrs.find(output_arg.number_attr());
73 if (iter == attrs.end()) {
74 return errors::InvalidArgument("Unable to find number_attr ",
75 output_arg.number_attr(),
76 " for Op: ", op_name);
77 }
78 *num_retvals += iter->second.i();
79 } else if (!output_arg.type_list_attr().empty()) {
80 auto iter = attrs.find(output_arg.type_list_attr());
81 if (iter == attrs.end()) {
82 return errors::InvalidArgument("Unable to find type_list_attr ",
83 output_arg.type_list_attr(),
84 " for Op: ", op_name);
85 }
86 *num_retvals += iter->second.list().type_size();
87 } else {
88 *num_retvals += 1;
89 }
90 }
91
92 return Status::OK();
93 }
94
GetEagerOperationAndNumRetvals(const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,EagerOperation * eager_op,int * num_retvals)95 Status GetEagerOperationAndNumRetvals(const Operation& operation,
96 EagerContext* eager_context,
97 EagerExecutor* eager_executor,
98 EagerOperation* eager_op,
99 int* num_retvals) {
100 const char* name = operation.name().c_str(); // Shorthand
101 absl::optional<tensorflow::EagerRemoteFunctionParams> remote_func_params =
102 absl::nullopt;
103 if (operation.is_function()) {
104 if (operation.is_component_function()) {
105 remote_func_params = {operation.id(), operation.func_step_id()};
106 } else {
107 remote_func_params = {operation.id(), absl::nullopt};
108 }
109 }
110 TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
111 eager_executor, remote_func_params));
112
113 {
114 profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
115 profiler::TraceMeLevel::kVerbose);
116 for (const auto& input : operation.op_inputs()) {
117 tensorflow::TensorHandle* handle;
118 if (input.has_remote_handle()) {
119 TF_RETURN_IF_ERROR(
120 eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
121 input.remote_handle(), &handle));
122 TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
123 } else {
124 Tensor tensor;
125 if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
126 return errors::InvalidArgument("Invalid TensorProto: ",
127 input.tensor().DebugString());
128 } else {
129 handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
130 nullptr, eager_context);
131 TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
132 }
133 }
134 // Unref handle since it has a ref as an input now.
135 handle->Unref();
136 }
137 }
138
139 for (const auto& attr : operation.attrs()) {
140 eager_op->MutableAttrs()->Set(attr.first, attr.second);
141 }
142
143 // TODO(nareshmodi): Consider caching this.
144 return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
145 num_retvals);
146 }
147
TensorHandleProto(TensorHandle * handle,TensorProto * proto)148 Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
149 const tensorflow::Tensor* t = nullptr;
150 TF_RETURN_IF_ERROR(handle->Tensor(&t));
151 t->AsProtoTensorContent(proto);
152 return Status::OK();
153 }
154
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)155 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
156 const tensorflow::Tensor* t = nullptr;
157
158 // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
159 if (handle->Type() == TensorHandle::LOCAL) {
160 TF_RETURN_IF_ERROR(handle->Tensor(&t));
161
162 t->shape().AsProto(proto);
163 } else {
164 TensorShape shape;
165 TF_RETURN_IF_ERROR(handle->Shape(&shape));
166 shape.AsProto(proto);
167 }
168
169 return Status::OK();
170 }
171
AddOpRetvalsToResponse(EagerContext * eager_context,int op_id,int num_retvals,const std::vector<int32> & output_nums,TensorHandle ** retvals,std::function<TensorProto * ()> add_tensor_proto_fn,std::function<TensorShapeProto * ()> add_shape_proto_fn,std::function<string * ()> add_device_fn=nullptr)172 Status AddOpRetvalsToResponse(
173 EagerContext* eager_context, int op_id, int num_retvals,
174 const std::vector<int32>& output_nums, TensorHandle** retvals,
175 std::function<TensorProto*()> add_tensor_proto_fn,
176 std::function<TensorShapeProto*()> add_shape_proto_fn,
177 std::function<string*()> add_device_fn = nullptr) {
178 if (op_id == kInvalidRemoteOpId) {
179 // Copy the output tensors back along with the response, since the op id
180 // is invalid which cannot be added to RemoteMgr.
181 for (int i = 0; i < num_retvals; i++) {
182 TF_RETURN_IF_ERROR(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
183 retvals[i]->Unref();
184 }
185 } else {
186 for (int i = 0; i < num_retvals; i++) {
187 TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
188 if (add_device_fn) {
189 Device* device = retvals[i]->device();
190 *add_device_fn() = device ? device->name() : "";
191 }
192 if (retvals[i]->Type() == TensorHandle::REMOTE) {
193 retvals[i]->Unref();
194 } else {
195 const int output_num = output_nums.empty() ? i : output_nums.at(i);
196 eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
197 output_num);
198 }
199 }
200 }
201 return Status::OK();
202 }
203 } // namespace
204
CreateContext(const CreateContextRequest * request,CreateContextResponse * response)205 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
206 CreateContextResponse* response) {
207 {
208 mutex_lock l(contexts_mu_);
209 auto context_it = contexts_.find(request->context_id());
210 if (context_it != contexts_.end()) {
211 if (request->context_view_id() <
212 context_it->second->Context()->GetContextViewId()) {
213 return errors::InvalidArgument("EagerService:CreateContext failed. ",
214 "Context id: <", request->context_id(),
215 "> already exists.");
216 } else {
217 // For existing context with a stale context_view_id, close the old one
218 // and recreate with new view id. This is likely due to the worker
219 // disconnected and then reconnected after one or more cluster updates.
220 context_it->second->Unref();
221 contexts_.erase(context_it);
222 }
223 }
224 }
225 // make sure env_ , env_->rendezvous_mgr available
226 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
227 return tensorflow::errors::Internal(
228 "invalid eager env_ or env_->rendezvous_mgr.");
229 }
230
231 auto* r = env_->rendezvous_mgr->Find(request->context_id());
232 auto session_name =
233 tensorflow::strings::StrCat("eager_", request->context_id());
234 if (VLOG_IS_ON(2)) {
235 VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
236 << "/task:" << request->server_def().task_index();
237 for (const auto& da : request->cluster_device_attributes()) {
238 VLOG(2) << " " << da.name();
239 }
240 }
241 TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
242 session_name, request->server_def(), request->cluster_device_attributes(),
243 true));
244 int64 context_id = request->context_id();
245 std::function<void()> session_destroyer = [this, context_id, session_name]() {
246 env_->rendezvous_mgr->Cleanup(context_id);
247 auto s = env_->session_mgr->DeleteSession(session_name);
248 if (!s.ok()) {
249 LOG(WARNING) << "Failed to destroy worker session '" << session_name
250 << "' due to " << s.error_message();
251 }
252 };
253
254 std::shared_ptr<WorkerSession> worker_session;
255 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
256 session_name, &worker_session));
257
258 const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
259
260 // Initialize remote tensor communication based on worker session.
261 TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
262
263 std::function<Rendezvous*(const int64)> rendezvous_creator =
264 [worker_session, this](const int64 step_id) {
265 auto* r = env_->rendezvous_mgr->Find(step_id);
266 r->Initialize(worker_session.get()).IgnoreError();
267 return r;
268 };
269
270 LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
271 << " eager service context with rendezvous_id on host "
272 << port::Hostname() << " " << worker_session->worker_name();
273 SessionOptions opts;
274 opts.config = request->server_def().default_session_config();
275 tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
276 opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
277 request->async(), device_mgr, false, r, worker_session->cluster_flr());
278 // Ownership will be transferred to the ServerContext, or else in an error
279 // case ctx will be deleted by this unref.
280 core::ScopedUnref unref_ctx(ctx);
281
282 std::vector<string> remote_workers;
283 worker_session->worker_cache()->ListWorkers(&remote_workers);
284 remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
285 worker_session->worker_name()),
286 remote_workers.end());
287
288 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
289 TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
290 &remote_eager_workers));
291 DistributedFunctionLibraryRuntime* cluster_flr =
292 eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
293
294 auto remote_mgr =
295 absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
296 Status s = ctx->InitializeRemoteWorker(
297 std::move(remote_eager_workers), worker_session->remote_device_mgr(),
298 remote_workers, request->context_id(), request->context_view_id(),
299 std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr),
300 std::move(session_destroyer));
301 if (!s.ok()) {
302 VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
303 << s.ToString();
304 return s;
305 }
306
307 std::vector<DeviceAttributes> device_attributes;
308 device_mgr->ListDeviceAttributes(&device_attributes);
309
310 for (const auto& da : device_attributes) {
311 *response->add_device_attributes() = da;
312 }
313 {
314 mutex_lock l(contexts_mu_);
315 auto context_it = contexts_.find(request->context_id());
316 if (context_it != contexts_.end()) {
317 return errors::InvalidArgument("EagerService:CreateContext failed. ",
318 "Context id: <", request->context_id(),
319 "> already exists.");
320 }
321 contexts_.emplace(request->context_id(),
322 new ServerContext(ctx, request->keep_alive_secs(), env_));
323 }
324
325 return Status::OK();
326 }
327
UpdateContext(const UpdateContextRequest * request,UpdateContextResponse * response)328 Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
329 UpdateContextResponse* response) {
330 // make sure env_ , env_->rendezvous_mgr available
331 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
332 return tensorflow::errors::Internal(
333 "invalid eager env_ or env_->rendezvous_mgr.");
334 }
335
336 // Find the context to update by the requested context_id
337 ServerContext* server_context = nullptr;
338 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &server_context));
339 core::ScopedUnref context_unref(server_context);
340
341 tensorflow::EagerContext* ctx = server_context->Context();
342 if (request->context_view_id() != ctx->GetContextViewId() + 1) {
343 return errors::InvalidArgument(
344 "EagerService:UpdateContext failed. Context id: <",
345 request->context_id(), "> currently at view #", ctx->GetContextViewId(),
346 " but received update request at view #", request->context_view_id(),
347 ". View id should only be continuously incremented.");
348 }
349 if (request->cluster_device_attributes_size() == 0) {
350 // In this case, the client indicates that the updated `server_def` and
351 // device info is irrelevant to this worker, since it is not connected to
352 // the updated ones (likely due to device filter settings). The worker
353 // simply needs to update view ID and does not update other internal state.
354 ctx->IncrementContextViewId();
355 VLOG(1) << "Processing simplified UpdateContextRequest on "
356 << ctx->HostCPU()->name();
357 return Status::OK();
358 }
359 // TODO(b/143914772): Potential memory leak if rendezvous has pending
360 // tensors for removed / replaced workers.
361
362 auto session_name =
363 tensorflow::strings::StrCat("eager_", request->context_id());
364
365 TF_RETURN_IF_ERROR(env_->session_mgr->UpdateSession(
366 session_name, request->server_def(), request->cluster_device_attributes(),
367 true));
368
369 std::shared_ptr<WorkerSession> worker_session;
370 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
371 session_name, &worker_session));
372
373 const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
374
375 std::vector<string> remote_workers;
376 worker_session->worker_cache()->ListWorkers(&remote_workers);
377 remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
378 worker_session->worker_name()),
379 remote_workers.end());
380 VLOG(1) << "On existing server " << worker_session->worker_name()
381 << " updating remote workers";
382 if (VLOG_IS_ON(2)) {
383 for (const string& rw : remote_workers) {
384 VLOG(2) << "Remote worker " << rw;
385 }
386 }
387
388 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
389 TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
390 &remote_eager_workers));
391
392 ctx->ClearCachesAndThreadExecutors();
393 Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
394 remote_workers, request->context_id());
395 if (!s.ok()) {
396 VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
397 return s;
398 }
399
400 std::vector<DeviceAttributes> device_attributes;
401 device_mgr->ListDeviceAttributes(&device_attributes);
402
403 for (const auto& da : device_attributes) {
404 *response->add_device_attributes() = da;
405 }
406
407 return Status::OK();
408 }
409
CreateMasterContext(const tensorflow::uint64 context_id,EagerContext * context)410 Status EagerServiceImpl::CreateMasterContext(
411 const tensorflow::uint64 context_id, EagerContext* context) {
412 {
413 mutex_lock l(contexts_mu_);
414 auto iter = contexts_.find(context_id);
415 if (iter != contexts_.end()) {
416 return errors::InvalidArgument(
417 "EagerService:CreateMasterContext failed. ", "Context id: <",
418 context_id, "> already exists.");
419 }
420 }
421 ServerContext* server_context =
422 ServerContext::CreateMasterContext(context, env_);
423 mutex_lock l(contexts_mu_);
424 contexts_.emplace(context_id, server_context);
425 return Status::OK();
426 }
427
RunComponentFunction(CallOptions * call_opts,const RunComponentFunctionRequest * request,RunComponentFunctionResponse * response,StatusCallback done)428 void EagerServiceImpl::RunComponentFunction(
429 CallOptions* call_opts, const RunComponentFunctionRequest* request,
430 RunComponentFunctionResponse* response, StatusCallback done) {
431 ServerContext* context = nullptr;
432 Status s = GetServerContext(request->context_id(), &context);
433 if (!s.ok()) {
434 done(s);
435 return;
436 }
437 core::ScopedUnref context_unref(context);
438
439 auto& operation = request->operation();
440 // This codepath should only be triggered for executing component function
441 if (!operation.is_function() || !operation.is_component_function()) {
442 done(errors::Internal(
443 "RunComponentFunction request can only be used to execute "
444 "component functions."));
445 return;
446 }
447
448 EagerContext* eager_context = context->Context();
449 EagerExecutor* eager_executor = &eager_context->Executor();
450
451 EagerOperation* op = new EagerOperation(eager_context);
452 int* num_retvals = new int(0);
453 s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
454 op, num_retvals);
455 if (!s.ok()) {
456 done(s);
457 return;
458 }
459 if (!op->IsLocal()) {
460 done(errors::Internal(
461 "Received RunComponentFunction request with remote function device. "));
462 return;
463 }
464 s = op->SetAttrBool("is_component_function", true);
465 if (!s.ok()) {
466 done(errors::Internal("Error setting is_component_function attribute: ",
467 s.error_message()));
468 return;
469 }
470
471 auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
472 VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
473 << operation.id();
474 std::vector<int32> output_nums;
475 for (const int32 output_num : request->output_num()) {
476 output_nums.push_back(output_num);
477 }
478
479 auto cm = std::make_shared<CancellationManager>();
480 op->SetCancellationManager(cm.get());
481 call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
482
483 context->Ref();
484 EagerLocalExecuteAsync(
485 op, retvals->data(), num_retvals,
486 [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
487 call_opts, response, eager_context, context,
488 done = std::move(done)](const Status& status) {
489 call_opts->ClearCancelCallback();
490 auto wrapped_done = [&](const Status& status) {
491 context->Unref();
492 done(status);
493 delete op;
494 delete num_retvals;
495 delete retvals;
496 };
497 if (!status.ok()) {
498 wrapped_done(status);
499 return;
500 }
501 // The output device of a component function is the component device
502 // which is known on the default device of it's parent function.
503 wrapped_done(AddOpRetvalsToResponse(
504 eager_context, op_id, *num_retvals, output_nums, retvals->data(),
505 [response] { return response->add_tensor(); },
506 [response] { return response->add_shape(); }));
507 });
508 }
509
ExecuteOp(CallOptions * call_opts,const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,QueueResponse * queue_response)510 Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
511 const Operation& operation,
512 EagerContext* eager_context,
513 EagerExecutor* eager_executor,
514 QueueResponse* queue_response) {
515 tensorflow::EagerOperation op(eager_context);
516 int num_retvals = 0;
517 TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
518 operation, eager_context, eager_executor, &op, &num_retvals));
519
520 auto cm = std::make_shared<CancellationManager>();
521 if (call_opts) {
522 op.SetCancellationManager(cm.get());
523 call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
524 }
525
526 absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
527 VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
528 TF_RETURN_IF_ERROR(op.Execute(
529 absl::MakeSpan(
530 reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
531 num_retvals),
532 &num_retvals));
533
534 std::function<string*()> add_device_fn = nullptr;
535 // Send the output devices of a function back to let a client know where the
536 // outputs are. For a primitive op, an output devics is the op device which is
537 // known on a client.
538 if (op.is_function()) {
539 add_device_fn = [queue_response] { return queue_response->add_device(); };
540 }
541
542 return AddOpRetvalsToResponse(
543 eager_context, operation.id(), num_retvals, /*output_nums=*/{},
544 retvals.data(), [queue_response] { return queue_response->add_tensor(); },
545 [queue_response] { return queue_response->add_shape(); },
546 std::move(add_device_fn));
547 }
548
Enqueue(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,uint64 stream_id)549 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
550 const EnqueueRequest* request,
551 EnqueueResponse* response, uint64 stream_id) {
552 profiler::TraceMe activity(
553 [&] {
554 return absl::StrCat(
555 "EagerService:Enqueue#debug_str=", request->DebugString(), "#");
556 },
557 profiler::TraceMeLevel::kInfo);
558 ServerContext* context = nullptr;
559 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
560 core::ScopedUnref context_unref(context);
561
562 EagerExecutor& executor =
563 stream_id == kInvalidStreamId
564 ? context->Context()->Executor()
565 : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream(
566 stream_id);
567 Status s;
568 for (const auto& item : request->queue()) {
569 auto* queue_response = response->add_queue_response();
570 if (item.has_operation()) {
571 s = ExecuteOp(call_opts, item.operation(), context->Context(), &executor,
572 queue_response);
573 } else if (item.has_handle_to_decref()) {
574 auto handle_to_decref = absl::make_unique<RemoteTensorHandleInternal>(
575 item.handle_to_decref());
576 auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
577 context, std::move(handle_to_decref));
578 s = context->Context()->Executor().AddOrExecute(std::move(node));
579 } else if (item.has_send_tensor()) {
580 s = SendTensor(item.send_tensor(), context->Context());
581 } else if (item.has_send_packed_handle()) {
582 s = SendPackedHandle(item.send_packed_handle(), context->Context());
583 } else if (item.has_register_function()) {
584 s = RegisterFunction(item.register_function(), context->Context());
585 } else if (item.has_cleanup_function()) {
586 s = CleanupFunction(item.cleanup_function());
587 } else {
588 DCHECK(item.has_sync_remote_executor_for_stream());
589 s = executor.WaitForAllPendingNodes();
590 }
591
592 if (!s.ok()) {
593 if (stream_id != kInvalidStreamId) {
594 // TODO(b/138847548): Cleanup the executor when StreamCall is deleted.
595 context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id);
596 }
597 return s;
598 }
599 }
600
601 return Status::OK();
602 }
603
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)604 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
605 WaitQueueDoneResponse* response) {
606 ServerContext* context = nullptr;
607 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
608 core::ScopedUnref context_unref(context);
609
610 if (request->op_id_size() > 0) {
611 return errors::Unimplemented(
612 "EagerServiceImpl::WaitQueueDone is not "
613 "implemented for particular op IDs.");
614 }
615 return context->Context()->Executor().WaitForAllPendingNodes();
616 }
617
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)618 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
619 KeepAliveResponse* response) {
620 ServerContext* context = nullptr;
621 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
622 core::ScopedUnref context_unref(context);
623
624 tensorflow::EagerContext* ctx = context->Context();
625 response->set_context_view_id(ctx->GetContextViewId());
626 return Status::OK();
627 }
628
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)629 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
630 CloseContextResponse* response) {
631 VLOG(1) << "Executing EagerService::CloseContext for context "
632 << request->context_id();
633 ServerContext* context = nullptr;
634 if (!GetServerContext(request->context_id(), &context).ok()) {
635 // Swallow the error here.
636 return Status::OK();
637 }
638 core::ScopedUnref context_unref(context);
639
640 if (request->context_view_id() < context->Context()->GetContextViewId()) {
641 // Swallow the error here.
642 LOG(INFO) << "Ignoring CloseContext request with a stale context_view_id "
643 << request->context_view_id() << " for context_id "
644 << request->context_id() << ". The current context_view_id is "
645 << context->Context()->GetContextViewId() << ".";
646 return Status::OK();
647 }
648
649 mutex_lock l(contexts_mu_);
650 contexts_.erase(request->context_id());
651
652 // GetServerContext returns a newly Reffed copy of ServerContext, which is
653 // unreffed by context_unref. Additionally, we need to unref it one time since
654 // we are releasing it from the map.
655 context->Unref();
656
657 return Status::OK();
658 }
659
RegisterFunction(const RegisterFunctionOp & register_function,EagerContext * eager_context)660 Status EagerServiceImpl::RegisterFunction(
661 const RegisterFunctionOp& register_function, EagerContext* eager_context) {
662 // If the function is a component of a multi-device function, we only need to
663 // register it locally.
664 return eager_context->AddFunctionDef(
665 register_function.function_def(), register_function.library(),
666 register_function.is_component_function());
667 }
668
CleanupFunction(const CleanupFunctionOp & cleanup_function)669 Status EagerServiceImpl::CleanupFunction(
670 const CleanupFunctionOp& cleanup_function) {
671 env_->rendezvous_mgr->Cleanup(cleanup_function.step_id());
672 return Status::OK();
673 }
674
SendTensor(const SendTensorOp & send_tensor,EagerContext * eager_context)675 Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
676 EagerContext* eager_context) {
677 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
678 for (const auto& tensor_proto : send_tensor.tensors()) {
679 Tensor tensor;
680 if (!tensor.FromProto(tensor_proto)) {
681 return errors::InvalidArgument("Unable to parse tensor proto");
682 }
683
684 TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
685 std::move(tensor), nullptr, nullptr, eager_context);
686 TensorHandle* copied_handle = nullptr;
687 Device* device;
688 TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
689 send_tensor.device_name().c_str(), &device));
690 TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context,
691 &eager_context->Executor(), device,
692 false, &copied_handle));
693 tensors.push_back(copied_handle);
694 tensor_handle->Unref();
695 }
696
697 eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id());
698
699 return Status::OK();
700 }
701
SendPackedHandle(const SendPackedHandleOp & send_packed_handle,EagerContext * eager_context)702 Status EagerServiceImpl::SendPackedHandle(
703 const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
704 if (send_packed_handle.handles().empty()) {
705 return errors::InvalidArgument("Handles should not be empty.");
706 }
707
708 std::vector<tensorflow::TensorHandle*> handles;
709 handles.resize(send_packed_handle.handles_size());
710 for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
711 const auto& item = send_packed_handle.handles(i);
712 if (item.has_local_handle()) {
713 Tensor tensor;
714 if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
715 return errors::InvalidArgument(
716 "Invalid TensorProto: ",
717 item.local_handle().tensor().DebugString());
718 }
719 Device* op_device = nullptr;
720 TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
721 item.local_handle().device().c_str(), &op_device));
722 handles[i] = TensorHandle::CreateLocalHandle(
723 std::move(tensor), /*d=*/nullptr, op_device, eager_context);
724 } else {
725 TF_RETURN_IF_ERROR(
726 eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
727 item.remote_handle(), &handles[i]));
728 }
729 }
730
731 tensorflow::TensorHandle* packed_handle = nullptr;
732 std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
733 // Create a unshaped packed TensorHandle.
734 TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
735 std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
736 send_packed_handle.device_name(), eager_context, &packed_handle));
737
738 for (auto* h : handles) {
739 // Unref handle since it has a ref in the packed handle now.
740 h->Unref();
741 }
742
743 eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
744 send_packed_handle.op_id());
745 return Status::OK();
746 }
747
GetServerContext(uint64 context_id,ServerContext ** server_context)748 tensorflow::Status EagerServiceImpl::GetServerContext(
749 uint64 context_id, ServerContext** server_context) {
750 tf_shared_lock l(contexts_mu_);
751 auto iter = contexts_.find(context_id);
752 if (iter == contexts_.end()) {
753 *server_context = nullptr;
754 return errors::Unavailable(strings::Printf(
755 "Unable to find a context_id matching the specified one "
756 "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
757 static_cast<unsigned long long>(context_id)));
758 }
759
760 *server_context = iter->second;
761 (*server_context)->Ref();
762
763 (*server_context)->RecordAccess();
764
765 return Status::OK();
766 }
767
768 } // namespace eager
769 } // namespace tensorflow
770