/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace tensorflow {
namespace eager {
namespace {

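// Exposes EagerServiceImpl's protected server-context lookup so tests can
// inspect the EagerContext and the TensorHandles it owns.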
class TestEagerServiceImpl : public EagerServiceImpl {
 public:
  explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
  Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);
    *ctx = context->Context();
    return Status::OK();
  }
  Status GetTensorHandle(const uint64 context_id,
                         const RemoteTensorHandleInternal& remote_handle,
                         tensorflow::TensorHandle** handle) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);

    return context->Context()->RemoteMgr()->GetTensorHandle(remote_handle,
                                                             handle);
  }
};

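// An EagerClient that forwards every "RPC" directly to a local
// TestEagerServiceImpl instead of going over the network.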
class FakeEagerClient : public EagerClient {
 public:
  FakeEagerClient() {}
  ~FakeEagerClient() override {}

  void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }

#define CLIENT_METHOD(method)                                         \
  void method##Async(const method##Request* request,                  \
                     method##Response* response, StatusCallback done) \
      override {                                                      \
    done(impl_->method(request, response));                           \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);
  CLIENT_METHOD(CloseContext);
#undef CLIENT_METHOD

  void EnqueueAsync(CallOptions* call_opts, const EnqueueRequest* request,
                    EnqueueResponse* response, StatusCallback done) override {
    done(impl_->Enqueue(call_opts, request, response));
  }

  void RunComponentFunctionAsync(CallOptions* call_opts,
                                 const RunComponentFunctionRequest* request,
                                 RunComponentFunctionResponse* response,
                                 StatusCallback done) override {
    impl_->RunComponentFunction(call_opts, request, response, std::move(done));
  }

  void StreamingEnqueueAsync(CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    done(impl_->Enqueue(nullptr, request, response));
  }

  bool allow_multiple_pending_requests() const override { return false; }

 private:
  TestEagerServiceImpl* impl_;
};

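// An EagerClientCache that hands out the same FakeEagerClient for every
// target.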
class DummyEagerClientCache : public EagerClientCache {
 public:
  DummyEagerClientCache() : client_(new FakeEagerClient) {}
  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    client->reset(client_.get());
    client_->Ref();
    return Status::OK();
  }

 private:
  core::RefCountPtr<EagerClient> client_;
};

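// A worker cache that serves the DummyEagerClientCache and reports a single
// local worker.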
class FakeCache : public TestWorkerCache {
  Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
    eager_client_cache->reset(new DummyEagerClientCache);
    return Status::OK();
  }

  void ListWorkers(std::vector<string>* workers) const override {
    workers->push_back("/job:localhost/replica:0/task:0");
  }
};

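// Test fixture that builds a minimal WorkerEnv (single CPU device, fake
// worker cache, RPC rendezvous manager) for exercising EagerServiceImpl.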
class EagerServiceImplTest : public ::testing::Test {
 public:
  EagerServiceImplTest()
      : rendezvous_mgr_(&worker_env_),
        session_mgr_(new SessionMgr(
            &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
            std::unique_ptr<WorkerCacheInterface>(new FakeCache),
            [](const ServerDef& server_def,
               WorkerCacheInterface** worker_cache) {
              *worker_cache = new FakeCache;
              return Status::OK();
            })) {
    worker_env_.env = Env::Default();

    worker_env_.rendezvous_mgr = &rendezvous_mgr_;
    worker_env_.session_mgr = session_mgr_.get();

    device_mgr_ = absl::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
    worker_env_.local_devices = device_mgr_->ListDevices();
    worker_env_.device_mgr = device_mgr_.get();
  }

 protected:
  WorkerEnv worker_env_;
  tensorflow::RpcRendezvousMgr rendezvous_mgr_;
  std::unique_ptr<SessionMgr> session_mgr_;
  std::unique_ptr<DeviceMgr> device_mgr_;
};

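// Fills `tensor_proto` with a 2x2 float tensor: [[1, 2], [3, 4]].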
void SetTensorProto(TensorProto* tensor_proto) {
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(
      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  tensorflow::Tensor tensor;
  TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor));
  tensor.AsProtoTensorContent(tensor_proto);
  TF_DeleteTensor(t);
}

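// Populates `operation` with the given id, op name, device, attrs, and
// inputs. Each input is either an inline TensorProto or an
// (op_id, output_num) pair referring to a remote tensor handle.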
void BuildOperation(
    Operation* operation, int64 id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device) {
  operation->set_id(id);
  operation->set_name(name);
  operation->set_device(device);

  for (const auto& input : inputs) {
    if (input.index() == 0) {
      *operation->add_op_inputs()->mutable_tensor() =
          absl::get<TensorProto>(input);
    } else {
      const auto& tensor_handle_pair =
          absl::get<std::pair<int64, int32>>(input);
      auto* input = operation->add_op_inputs()->mutable_remote_handle();
      input->set_op_id(tensor_handle_pair.first);
      input->set_output_num(tensor_handle_pair.second);
      input->set_op_device(device);
      input->set_device(device);
    }
  }

  for (const auto& attr_entry : attrs) {
    (*operation->mutable_attrs())[attr_entry.first] = attr_entry.second;
  }
}

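// Appends an operation built by BuildOperation to the queue of an
// EnqueueRequest.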
void AddOperationToEnqueueRequest(
    int64 id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    EnqueueRequest* request) {
  auto* operation = request->add_queue()->mutable_operation();
  BuildOperation(operation, id, name, inputs, attrs, device);
}

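// Fills a RunComponentFunctionRequest with a component-function operation and
// the index of the output to return.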
void AddOperationToRunComponentFunctionRequest(
    int64 id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    const int output_num, RunComponentFunctionRequest* request) {
  auto* operation = request->mutable_operation();
  operation->set_is_function(true);
  operation->set_is_component_function(true);
  request->add_output_num(output_num);
  BuildOperation(operation, id, name, inputs, attrs, device);
}

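// Returns a NodeDef that invokes MatMulFunction on a float input.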
tensorflow::NodeDef MatMulFunctionNodeDef() {
  tensorflow::NodeDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    name: 'matmul_func'"
      "    op: 'MatMulFunction'"
      "    input: 'a'"
      "    input: 'a'"
      "    attr {"
      "      key: 'T'"
      "      value {"
      "        type: DT_FLOAT"
      "      }"
      "    }",
      &def));
  return def;
}

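// Returns a FunctionDef that computes MatMul(a, a) for a float input 'a'.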
tensorflow::FunctionDef MatMulFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'm'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul'"
      "      op: 'MatMul'"
      "      input: 'a'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "      attr {"
      "        key: 'transpose_a'"
      "        value {"
      "          b: false"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'm'"
      "      value: 'matmul:product'"
      "    }",
      &def));
  return def;
}

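// Returns a FunctionDef whose body calls MatMulFunction, i.e. a function that
// invokes another function.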
tensorflow::FunctionDef MatMulNestedFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulNestedFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'matmul_nested'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul_nested'"
      "      op: 'MatMulFunction'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'matmul_nested'"
      "      value: 'matmul_nested:m:0'"
      "    }",
      &def));
  return def;
}

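// Returns a FunctionDef containing a single client-terminated _Recv node.
// The _Recv blocks until a tensor named 't0' is sent, which makes this
// function useful for cancellation tests.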
tensorflow::FunctionDef SingleRecvNodeFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'SingleRecvNodeFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'recv_tensor'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'recv_node'"
      "      op: '_Recv'"
      "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "      attr {"
      "        key: 'client_terminated'"
      "        value {"
      "          b: true"
      "        }"
      "      }"
      "      attr {"
      "        key: 'recv_device'"
      "        value {"
      "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'send_device'"
      "        value {"
      "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'send_device_incarnation'"
      "        value {"
      "          i: 1"
      "        }"
      "      }"
      "      attr {"
      "        key: 'tensor_name'"
      "        value {"
      "          s: 't0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'tensor_type'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'recv_tensor'"
      "      value: 'recv_node:tensor:0'"
      "    }",
      &def));
  return def;
}

// Test creates a context and attempts to execute some ops.
TEST_F(EagerServiceImplTest, BasicTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  std::unordered_map<string, AttrValue> const_attrs;
  AttrValue val;
  val.set_type(tensorflow::DataType::DT_FLOAT);
  const_attrs.insert({"dtype", val});
  val.Clear();
  SetTensorProto(val.mutable_tensor());
  const_attrs.insert({"value", val});

  AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                               "/job:localhost/replica:0/task:0/device:CPU:0",
                               &remote_enqueue_request);

  std::unordered_map<string, AttrValue> attrs;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  auto& matmul_result_shape =
      remote_enqueue_response.queue_response(1).shape(0);
  EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
  EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);

  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));

  // This should be OK to do since we've placed all computation on the CPU
  // device.
  const tensorflow::Tensor* t = nullptr;
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  auto actual = t->flat<float>();

  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

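// Fixture providing helpers that register a function in a fresh context and
// then execute it, optionally cancelling the request mid-flight.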
class EagerServiceImplFunctionTest : public EagerServiceImplTest {
 public:
  EagerServiceImplFunctionTest() : EagerServiceImplTest() {}

  // Creates a context and attempts to execute a function.
  void TestFunction(const RegisterFunctionOp& register_op,
                    const string& function_name,
                    const bool local_inputs = false,
                    const bool test_cancel = false) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);

    uint64 context_id = random::New64();

    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;

    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;

    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    if (local_inputs) {
      TensorProto tensor_proto;
      SetTensorProto(&tensor_proto);
      AddOperationToEnqueueRequest(
          2, function_name, {tensor_proto},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);

    } else {
      std::unordered_map<string, AttrValue> const_attrs;
      AttrValue val;
      val.set_type(tensorflow::DataType::DT_FLOAT);
      const_attrs.insert({"dtype", val});
      val.Clear();

      SetTensorProto(val.mutable_tensor());
      const_attrs.insert({"value", val});

      AddOperationToEnqueueRequest(
          1, "Const", {}, const_attrs,
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
      AddOperationToEnqueueRequest(
          2, function_name, {std::make_pair(1, 0)},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
    }

    CallOptions call_opts;
    Status status;
    Notification n;
    Env::Default()->SchedClosure([&] {
      status = eager_service_impl.Enqueue(&call_opts, &remote_enqueue_request,
                                          &remote_enqueue_response);
      n.Notify();
    });

    if (test_cancel) {
      // Wait to let the Enqueue thread start running.
      Env::Default()->SleepForMicroseconds(500000);
      call_opts.StartCancel();
      n.WaitForNotification();
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      n.WaitForNotification();
      TF_ASSERT_OK(status);
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }

  // Creates a context and attempts to execute a component function.
  void TestComponentFunction(const RegisterFunctionOp& register_op,
                             const string& function_name,
                             const bool test_cancel) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);
    uint64 context_id = random::New64();

    // Create context.
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    // Register function.
    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    // First run an op to generate input for the function.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                                 "/job:localhost/replica:0/task:0/device:CPU:0",
                                 &remote_enqueue_request);
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                            &remote_enqueue_response));

    // Run the function with input from the previous op.
    RunComponentFunctionRequest run_comp_func_request;
    run_comp_func_request.set_context_id(context_id);
    RunComponentFunctionResponse run_comp_func_response;
    const int output_num = 5;
    AddOperationToRunComponentFunctionRequest(
        2, function_name, {std::make_pair(1, 0)},
        std::unordered_map<string, AttrValue>(),
        "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
        &run_comp_func_request);

    CallOptions call_opts;
    Notification n;
    Status status;
    eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request,
                                            &run_comp_func_response,
                                            [&status, &n](const Status& s) {
                                              status.Update(s);
                                              n.Notify();
                                            });
    if (test_cancel) {
      call_opts.StartCancel();
    }
    n.WaitForNotification();
    if (test_cancel) {
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      TF_ASSERT_OK(status);
      // Retrieve the output.
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, output_num),
          &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }
};

TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionWithLocalInputsTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction", /*local_inputs=*/true);
}

TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulNestedFunction();
  *register_op.mutable_library()->add_function() = MatMulFunction();
  TestFunction(register_op, "MatMulNestedFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestFunction(register_op, "SingleRecvNodeFunction", /*local_inputs=*/false,
               /*test_cancel=*/true);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestComponentFunction(register_op, "MatMulFunction", /*test_cancel=*/false);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestComponentFunction(register_op, "SingleRecvNodeFunction",
                        /*test_cancel=*/true);
}

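// Fixture for running functions whose inputs live on a remote task: it wires
// a FakeEagerClient to the local service and builds a
// ProcessFunctionLibraryRuntime backed by EagerClusterFunctionLibraryRuntime.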
class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 public:
  FunctionWithRemoteInputsTest()
      : EagerServiceImplTest(), eager_service_impl_(&worker_env_) {
    remote_device_mgr_ = absl::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
    context_id_ = random::New64();
  }

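  // EagerKernelArgs whose remote inputs are produced by a caller-supplied
  // serialization callback rather than looked up in a RemoteMgr.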
  class TestExecuteNodeArgs : public EagerKernelArgs {
   public:
    TestExecuteNodeArgs(
        gtl::InlinedVector<TensorValue, 4>&& tensor_args,
        std::function<Status(const int, eager::RemoteTensorHandle*)>
            serialize_remote_handle)
        : EagerKernelArgs(std::move(tensor_args)),
          serialize_remote_handle_(std::move(serialize_remote_handle)) {}

    bool HasRemoteOrPackedInputs() const override { return true; }

    Status GetRemoteArg(const FunctionArgIndex& index,
                        eager::RemoteTensorHandle* val) const override {
      return serialize_remote_handle_(index.index, val);
    }

   private:
    std::function<Status(const int, eager::RemoteTensorHandle*)>
        serialize_remote_handle_;
  };

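  // Returns true if the MatMul node in `fdef` still carries the
  // default-valued 'transpose_a' attribute.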
  bool MatMulHasAttrWithDefaultValue(const tensorflow::FunctionDef& fdef) {
    for (const auto& node : fdef.node_def()) {
      if (node.op() == "MatMul") {
        return node.attr().find("transpose_a") != node.attr().end();
      }
    }
    return false;
  }

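  // Creates the context, points the FakeEagerClient at the local service,
  // enqueues a Const op as input for MatMulFunction, and builds the cluster
  // FLR and ProcessFunctionLibraryRuntime used by the tests.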
  void Init() {
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id_);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl_.CreateContext(&request, &response));

    // Make the fake EagerClient use the local eager_service_impl.
    EagerContext* ctx = nullptr;
    TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
    Device* device;
    TF_ASSERT_OK(ctx->FindDeviceFromName(local_device_.c_str(), &device));
    core::RefCountPtr<EagerClient> client;
    TF_ASSERT_OK(ctx->GetClient(device, &client));
    FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client.get());
    fake_client->SetServiceImpl(&eager_service_impl_);

    // Create an input on local_device for MatMulFunction.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id_);
    EnqueueResponse remote_enqueue_response;
    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, local_device_,
                                 &remote_enqueue_request);
    TF_EXPECT_OK(eager_service_impl_.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response));
    eager_cluster_flr_ = absl::make_unique<EagerClusterFunctionLibraryRuntime>(
        context_id_, ctx, device_mgr_.get());

    fdef_ = MatMulFunction();
    TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
    eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
        remote_device_mgr_.get(), Env::Default(), /*config=*/nullptr,
        TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
        /*thread_pool=*/nullptr, eager_cluster_flr_.get(),
        /*session_metadata=*/nullptr,
        Rendezvous::Factory{[this](const int64 step_id,
                                   const DeviceMgr* device_mgr,
                                   Rendezvous** r) {
          *r = worker_env_.rendezvous_mgr->Find(step_id);
          return Status::OK();
        }});
  }

  void CheckOutputTensorAndClose(const Tensor& tensor) {
    auto actual = tensor.flat<float>();
    EXPECT_EQ(4, actual.size());
    EXPECT_EQ(7, actual(0));
    EXPECT_EQ(10, actual(1));
    EXPECT_EQ(15, actual(2));
    EXPECT_EQ(22, actual(3));

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id_);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl_.CloseContext(&close_context_request,
                                                  &close_context_response));
  }

  void CheckOutputsAndClose(const std::vector<FunctionRet>& outputs,
                            const int64 op_id) {
    const tensorflow::Tensor* t = nullptr;
    tensorflow::TensorHandle* tensor_handle;
    TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
        context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
    TF_ASSERT_OK(tensor_handle->Tensor(&t));
    EXPECT_EQ(outputs.size(), 1);
    EXPECT_EQ(outputs.at(0).index(), 1);
    const TensorShape& shape = absl::get<TensorShape>(outputs.at(0));
    EXPECT_EQ(shape, t->shape());
    CheckOutputTensorAndClose(*t);
  }

 protected:
  const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0";
  TestEagerServiceImpl eager_service_impl_;
  std::unique_ptr<DeviceMgr> remote_device_mgr_;
  uint64 context_id_;
  tensorflow::FunctionDef fdef_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> eager_pflr_;
  std::unique_ptr<EagerClusterFunctionLibraryRuntime> eager_cluster_flr_;
  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
};

// Test executes a remote function through
// ProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = remote_device_;
  options.is_multi_device_function = true;
  options.input_devices.push_back(local_device_);
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  TF_ASSERT_OK(eager_pflr_->Instantiate(
      fdef_.signature().name(), AttrSlice(&fdef_.attr()), options, &handle));
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  bool is_cross_process = false;
  TF_CHECK_OK(eager_pflr_->IsCrossProcess(handle, &is_cross_process));
  EXPECT_TRUE(is_cross_process);

  // Run MatMulFunction on remote_device.
  FunctionLibraryRuntime::Options opts;
  const uint64 op_id = 2;
  opts.op_id = op_id;
  Notification done;
  Status status;
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> inputs = {input};
  std::vector<FunctionRet> outputs;
  gtl::InlinedVector<TensorValue, 4> tensor_args = {TensorValue()};
  TestExecuteNodeArgs args(
      std::move(tensor_args),
      [&inputs](const int i, RemoteTensorHandle* handle) -> Status {
        *handle = inputs.at(i);
        return Status::OK();
      });
  eager_pflr_->Run(opts, handle, args, &outputs,
                   [&status, &done](const Status& s) {
                     status = s;
                     done.Notify();
                   });
  done.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function with local input and output tensors.
TEST_F(FunctionWithRemoteInputsTest,
       EagerClusterFLRTestWithLocalInputAndOutput) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  Status status;
  Notification instantiate_done;
  eager_cluster_flr_->Instantiate(
      fdef_.signature().name(), func_lib_def_, AttrSlice(&fdef_.attr()),
      FunctionLibraryRuntime::InstantiateOptions(), &handle,
      [&status, &instantiate_done](const Status& s) {
        status = s;
        instantiate_done.Notify();
      });
  instantiate_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  const tensorflow::Tensor* input_tensor = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
      context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor));

  // Send input_tensor to the remote device, execute MatMulFunction on the
  // remote device, and send the output back.
  FunctionLibraryRuntime::Options opts;
  Notification execute_done;
  std::vector<Tensor> inputs = {*input_tensor};
  std::vector<Tensor> outputs;
  eager_cluster_flr_->Run(opts, handle, inputs, &outputs,
                          [&status, &execute_done](const Status& s) {
                            status = s;
                            execute_done.Notify();
                          });
  execute_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EXPECT_EQ(outputs.size(), 1);
  CheckOutputTensorAndClose(outputs.at(0));
}

// Test executes a remote function through KernelAndDeviceFunc::Run.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64 op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return Status::OK();
      });
  std::vector<FunctionRet> outputs;

  TF_ASSERT_OK(kernel->Run(/*step_container=*/nullptr, inputs, &outputs,
                           /*cancellation_manager=*/nullptr,
                           /*remote_func_params=*/absl::nullopt,
                           /*stack_trace=*/absl::nullopt));

  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function through KernelAndDeviceFunc::RunAsync.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64 op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return Status::OK();
      });
  std::vector<FunctionRet> outputs;

  Status status;
  Notification n;
  kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs,
                   /*cancellation_manager=*/nullptr,
                   /*remote_func_params=*/absl::nullopt,
                   [&status, &n](const Status& s) {
                     status = s;
                     n.Notify();
                   });
  n.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test creates a context, attempts to send a tensor (using the RPC), and
// then uses the tensor.
TEST_F(EagerServiceImplTest, SendTensorTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  std::unordered_map<string, AttrValue> attrs;
  AttrValue val;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  const tensorflow::Tensor* t = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  EXPECT_EQ(tensor_handle->device(), nullptr);

  auto actual = t->flat<float>();
  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test serializes and sends a packed TensorHandle.
TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
  const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
  const string composite_device =
      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";

  uint64 context_id = random::New64();
  CreateContextRequest request;
  auto* server_def = request.mutable_server_def();
  server_def->set_job_name("localhost");
  server_def->set_task_index(0);
  request.add_cluster_device_attributes()->set_name(device0);
  request.add_cluster_device_attributes()->set_name(device1);
  request.add_cluster_device_attributes()->set_name(device2);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  // Copy a tensor to device0.
  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Copy a packed handle to device0.
  auto* send_packed_handle =
      remote_enqueue_request.add_queue()->mutable_send_packed_handle();
  send_packed_handle->set_op_id(3);
  RemoteTensorHandle* remote_handle =
      send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(send_tensor->op_id());
  remote_handle->set_output_num(0);
  remote_handle->set_op_device(device0);
  remote_handle->set_device(device0);

  SendPackedHandleOp::LocalTensorHandle* local_handle =
      send_packed_handle->add_handles()->mutable_local_handle();
  SetTensorProto(local_handle->mutable_tensor());
  local_handle->set_device(device1);

  remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(2);
  remote_handle->set_output_num(5);
  remote_handle->set_op_device(device2);
  remote_handle->set_device(device2);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  tensorflow::TensorHandle* packed_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));

  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
  EXPECT_EQ(packed_handle->device()->name(), composite_device);

  TensorHandle* handle0 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
  EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle0->op_device()->name(), device0);
  const Tensor* t0 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t0));
  auto actual = t0->flat<float>();
  EXPECT_EQ(4, actual.size());
  EXPECT_EQ(1.0, actual(0));
  EXPECT_EQ(2.0, actual(1));
  EXPECT_EQ(3.0, actual(2));
  EXPECT_EQ(4.0, actual(3));

  TensorHandle* handle1 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
  EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle1->op_device()->name(), device1);
  const Tensor* t1 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t1));
  EXPECT_EQ(t1, t0);

  TensorHandle* handle2 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
  EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
  EXPECT_EQ(handle2->op_device()->name(), device2);
  int64 op_id;
  int32 output_num;
  TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(),
                                      /*wait_until_ready=*/true, &op_id,
                                      &output_num));
  EXPECT_EQ(op_id, 2);
  EXPECT_EQ(output_num, 5);

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test requests sent to the eager service on the master.
TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(device_mgr_.get());
  // Create a master eager context.
  tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      /*async=*/false, device_mgr_.get(), false, rendezvous);
  const uint64 context_id = random::New64();

  // Set RemoteMgr to ctx.
  auto remote_mgr =
      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true, ctx);
  TF_ASSERT_OK(ctx->InitializeRemoteWorker(
      /*remote_eager_workers=*/nullptr, /*remote_device_mgr=*/nullptr,
      /*remote_contexts=*/{}, context_id, /*context_view_id=*/0,
      /*rendezvous_creator=*/nullptr,
      /*cluster_flr=*/nullptr, std::move(remote_mgr),
      /*resource_deallocator=*/nullptr));

  TestEagerServiceImpl eager_service_impl(&worker_env_);

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Unable to handle the request since there is no eager context.
  Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response);
  EXPECT_EQ(error::UNAVAILABLE, status.code());
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "Unable to find a context_id matching the specified one"));

  // The request can be handled after adding the master eager context to the
  // service.
  TF_ASSERT_OK(eager_service_impl.CreateMasterContext(context_id, ctx));
  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));
  ctx->Unref();
}

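// Test that an idle context is garbage collected after keep_alive_secs, after
// which KeepAlive reports UNAVAILABLE, while a recently created context still
// accepts KeepAlive.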
TEST_F(EagerServiceImplTest, KeepAliveTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();
  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  request.set_keep_alive_secs(3);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  worker_env_.env->SleepForMicroseconds(5 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  KeepAliveRequest keep_alive_request;
  KeepAliveResponse keep_alive_response;

  keep_alive_request.set_context_id(context_id);

  Status status =
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);

  EXPECT_EQ(status.code(), error::UNAVAILABLE);
  EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                      status.error_message());

  uint64 new_context_id = random::New64();
  // Create a new context.
  request.set_context_id(new_context_id);
  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  // The context should not be GC'd.
  worker_env_.env->SleepForMicroseconds(1 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  keep_alive_request.set_context_id(new_context_id);

  TF_ASSERT_OK(
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));
}

}  // namespace
}  // namespace eager
}  // namespace tensorflow