/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace tensorflow {
namespace eager {
namespace {

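// Test-only subclass of EagerServiceImpl that exposes lookup of the
// EagerContext and of remote tensor handles for a given context_id.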
class TestEagerServiceImpl : public EagerServiceImpl {
 public:
  explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
  Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);
    *ctx = context->Context();
    return Status::OK();
  }
  Status GetTensorHandle(const uint64 context_id,
                         const RemoteTensorHandleInternal& remote_handle,
                         tensorflow::TensorHandle** handle) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);

    return context->Context()->RemoteMgr()->GetTensorHandle(remote_handle,
                                                            handle);
  }
};

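// An EagerClient that skips the network and forwards every RPC directly to a
// local TestEagerServiceImpl.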
class FakeEagerClient : public EagerClient {
 public:
  FakeEagerClient() {}
  ~FakeEagerClient() override {}

  void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }

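// Expands to a synchronous *Async override that calls the corresponding
// service method and immediately invokes `done` with its status.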
#define CLIENT_METHOD(method)                                         \
  void method##Async(const method##Request* request,                  \
                     method##Response* response, StatusCallback done) \
      override {                                                      \
    done(impl_->method(request, response));                           \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);
  CLIENT_METHOD(CloseContext);
#undef CLIENT_METHOD

  void EnqueueAsync(CallOptions* call_opts, const EnqueueRequest* request,
                    EnqueueResponse* response, StatusCallback done) override {
    done(impl_->Enqueue(call_opts, request, response));
  }

  void RunComponentFunctionAsync(CallOptions* call_opts,
                                 const RunComponentFunctionRequest* request,
                                 RunComponentFunctionResponse* response,
                                 StatusCallback done) override {
    impl_->RunComponentFunction(call_opts, request, response, std::move(done));
  }

  void StreamingEnqueueAsync(CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    done(impl_->Enqueue(nullptr, request, response));
  }

  bool allow_multiple_pending_requests() const override { return false; }

 private:
  TestEagerServiceImpl* impl_;
};

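// An EagerClientCache that always hands out the same FakeEagerClient.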
class DummyEagerClientCache : public EagerClientCache {
 public:
  DummyEagerClientCache() : client_(new FakeEagerClient) {}
  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    client->reset(client_.get());
    client_->Ref();
    return Status::OK();
  }

 private:
  core::RefCountPtr<EagerClient> client_;
};

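// A worker cache that serves the DummyEagerClientCache and lists a single
// local worker.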
class FakeCache : public TestWorkerCache {
  Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
    eager_client_cache->reset(new DummyEagerClientCache);
    return Status::OK();
  }

  void ListWorkers(std::vector<string>* workers) const override {
    workers->push_back("/job:localhost/replica:0/task:0");
  }
};

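// Test fixture that wires up a minimal WorkerEnv: one local CPU device, a
// FakeCache-backed SessionMgr, and an RPC rendezvous manager.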
class EagerServiceImplTest : public ::testing::Test {
 public:
  EagerServiceImplTest()
      : rendezvous_mgr_(&worker_env_),
        session_mgr_(new SessionMgr(
            &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
            std::unique_ptr<WorkerCacheInterface>(new FakeCache),
            [](const ServerDef& server_def,
               WorkerCacheInterface** worker_cache) {
              *worker_cache = new FakeCache;
              return Status::OK();
            })) {
    worker_env_.env = Env::Default();

    worker_env_.rendezvous_mgr = &rendezvous_mgr_;
    worker_env_.session_mgr = session_mgr_.get();

    device_mgr_ = absl::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
    worker_env_.local_devices = device_mgr_->ListDevices();
    worker_env_.device_mgr = device_mgr_.get();
  }

 protected:
  WorkerEnv worker_env_;
  tensorflow::RpcRendezvousMgr rendezvous_mgr_;
  std::unique_ptr<SessionMgr> session_mgr_;
  std::unique_ptr<DeviceMgr> device_mgr_;
};

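// Fills `tensor_proto` with a 2x2 float tensor holding {1, 2, 3, 4}.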
void SetTensorProto(TensorProto* tensor_proto) {
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(
      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  tensorflow::Tensor tensor;
  TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor));
  tensor.AsProtoTensorContent(tensor_proto);
  TF_DeleteTensor(t);
}

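// Populates `operation` with the given id, name, attrs, and device. Each
// input is either an inline TensorProto or an (op_id, output_num) pair
// referring to the output of a previously enqueued operation.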
void BuildOperation(
    Operation* operation, int64_t id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device) {
  operation->set_id(id);
  operation->set_name(name);
  operation->set_device(device);

  for (const auto& input : inputs) {
    if (input.index() == 0) {
      *operation->add_op_inputs()->mutable_tensor() =
          absl::get<TensorProto>(input);
    } else {
      const auto& tensor_handle_pair =
          absl::get<std::pair<int64, int32>>(input);
      auto* input = operation->add_op_inputs()->mutable_remote_handle();
      input->set_op_id(tensor_handle_pair.first);
      input->set_output_num(tensor_handle_pair.second);
      input->set_op_device(device);
      input->set_device(device);
    }
  }

  for (const auto& attr_entry : attrs) {
    (*operation->mutable_attrs())[attr_entry.first] = attr_entry.second;
  }
}

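// Appends an operation built by BuildOperation to the request's queue.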
void AddOperationToEnqueueRequest(
    int64_t id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    EnqueueRequest* request) {
  auto* operation = request->add_queue()->mutable_operation();
  BuildOperation(operation, id, name, inputs, attrs, device);
}

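// Builds a component-function operation on the request and records the
// expected output number.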
void AddOperationToRunComponentFunctionRequest(
    int64_t id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    const int output_num, RunComponentFunctionRequest* request) {
  auto* operation = request->mutable_operation();
  operation->set_is_function(true);
  operation->set_is_component_function(true);
  request->add_output_num(output_num);
  BuildOperation(operation, id, name, inputs, attrs, device);
}

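// Returns a NodeDef that invokes MatMulFunction with the same input twice.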
tensorflow::NodeDef MatMulFunctionNodeDef() {
  tensorflow::NodeDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      " name: 'matmul_func'"
      " op: 'MatMulFunction'"
      " input: 'a'"
      " input: 'a'"
      " attr {"
      " key: 'T'"
      " value {"
      " type: DT_FLOAT"
      " }"
      " }",
      &def));
  return def;
}

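// Returns a FunctionDef computing m = MatMul(a, a) on float inputs.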
tensorflow::FunctionDef MatMulFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      " signature {"
      " name: 'MatMulFunction'"
      " input_arg {"
      " name: 'a'"
      " type: DT_FLOAT"
      " }"
      " output_arg {"
      " name: 'm'"
      " type: DT_FLOAT"
      " }"
      " }"
      " node_def {"
      " name: 'matmul'"
      " op: 'MatMul'"
      " input: 'a'"
      " input: 'a'"
      " attr {"
      " key: 'T'"
      " value {"
      " type: DT_FLOAT"
      " }"
      " }"
      " attr {"
      " key: 'transpose_a'"
      " value {"
      " b: false"
      " }"
      " }"
      " }"
      " ret {"
      " key: 'm'"
      " value: 'matmul:product'"
      " }",
      &def));
  return def;
}

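// Returns a FunctionDef that calls MatMulFunction from inside another
// function, to exercise nested function execution.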
tensorflow::FunctionDef MatMulNestedFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      " signature {"
      " name: 'MatMulNestedFunction'"
      " input_arg {"
      " name: 'a'"
      " type: DT_FLOAT"
      " }"
      " output_arg {"
      " name: 'matmul_nested'"
      " type: DT_FLOAT"
      " }"
      " }"
      " node_def {"
      " name: 'matmul_nested'"
      " op: 'MatMulFunction'"
      " input: 'a'"
      " attr {"
      " key: 'T'"
      " value {"
      " type: DT_FLOAT"
      " }"
      " }"
      " }"
      " ret {"
      " key: 'matmul_nested'"
      " value: 'matmul_nested:m:0'"
      " }",
      &def));
  return def;
}

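// Returns a FunctionDef whose body is a single client-terminated _Recv node;
// it does not complete until a tensor is sent, which makes it convenient for
// the cancellation tests below.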
tensorflow::FunctionDef SingleRecvNodeFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      " signature {"
      " name: 'SingleRecvNodeFunction'"
      " input_arg {"
      " name: 'a'"
      " type: DT_FLOAT"
      " }"
      " output_arg {"
      " name: 'recv_tensor'"
      " type: DT_FLOAT"
      " }"
      " }"
      " node_def {"
      " name: 'recv_node'"
      " op: '_Recv'"
      " device: '/job:localhost/replica:0/task:0/device:CPU:0'"
      " attr {"
      " key: 'client_terminated'"
      " value {"
      " b: true"
      " }"
      " }"
      " attr {"
      " key: 'recv_device'"
      " value {"
      " s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      " }"
      " }"
      " attr {"
      " key: 'send_device'"
      " value {"
      " s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      " }"
      " }"
      " attr {"
      " key: 'send_device_incarnation'"
      " value {"
      " i: 1"
      " }"
      " }"
      " attr {"
      " key: 'tensor_name'"
      " value {"
      " s: 't0'"
      " }"
      " }"
      " attr {"
      " key: 'tensor_type'"
      " value {"
      " type: DT_FLOAT"
      " }"
      " }"
      " }"
      " ret {"
      " key: 'recv_tensor'"
      " value: 'recv_node:tensor:0'"
      " }",
      &def));
  return def;
}

// Test creates a context and attempts to execute some ops.
TEST_F(EagerServiceImplTest, BasicTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  std::unordered_map<string, AttrValue> const_attrs;
  AttrValue val;
  val.set_type(tensorflow::DataType::DT_FLOAT);
  const_attrs.insert({"dtype", val});
  val.Clear();
  SetTensorProto(val.mutable_tensor());
  const_attrs.insert({"value", val});

  AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                               "/job:localhost/replica:0/task:0/device:CPU:0",
                               &remote_enqueue_request);

  std::unordered_map<string, AttrValue> attrs;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  auto& matmul_result_shape =
      remote_enqueue_response.queue_response(1).shape(0);
  EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
  EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);

  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));

  // This should be OK to do since we've placed all computation on the CPU
  // device.
  const tensorflow::Tensor* t = nullptr;
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  auto actual = t->flat<float>();

  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

class EagerServiceImplFunctionTest : public EagerServiceImplTest {
 public:
  EagerServiceImplFunctionTest() : EagerServiceImplTest() {}

  // Creates a context and attempts to execute a function.
  void TestFunction(const RegisterFunctionOp& register_op,
                    const string& function_name,
                    const bool local_inputs = false,
                    const bool test_cancel = false) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);

    uint64 context_id = random::New64();

    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;

    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;

    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    if (local_inputs) {
      TensorProto tensor_proto;
      SetTensorProto(&tensor_proto);
      AddOperationToEnqueueRequest(
          2, function_name, {tensor_proto},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);

    } else {
      std::unordered_map<string, AttrValue> const_attrs;
      AttrValue val;
      val.set_type(tensorflow::DataType::DT_FLOAT);
      const_attrs.insert({"dtype", val});
      val.Clear();

      SetTensorProto(val.mutable_tensor());
      const_attrs.insert({"value", val});

      AddOperationToEnqueueRequest(
          1, "Const", {}, const_attrs,
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
      AddOperationToEnqueueRequest(
          2, function_name, {std::make_pair(1, 0)},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
    }

    CallOptions call_opts;
    Status status;
    Notification n;
    Env::Default()->SchedClosure([&] {
      status = eager_service_impl.Enqueue(&call_opts, &remote_enqueue_request,
                                          &remote_enqueue_response);
      n.Notify();
    });

    if (test_cancel) {
      // Wait to let the Enqueue thread start running.
      Env::Default()->SleepForMicroseconds(500000);
      call_opts.StartCancel();
      n.WaitForNotification();
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      n.WaitForNotification();
      TF_ASSERT_OK(status);
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }

  // Creates a context and attempts to execute a component function.
  void TestComponentFunction(const RegisterFunctionOp& register_op,
                             const string& function_name,
                             const bool test_cancel) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);
    uint64 context_id = random::New64();

    // Create context.
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    // Register function.
    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    // First run an op to generate input for the function.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                                 "/job:localhost/replica:0/task:0/device:CPU:0",
                                 &remote_enqueue_request);
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                            &remote_enqueue_response));

    // Run the function with input from the previous op.
    RunComponentFunctionRequest run_comp_func_request;
    run_comp_func_request.set_context_id(context_id);
    RunComponentFunctionResponse run_comp_func_response;
    const int output_num = 5;
    AddOperationToRunComponentFunctionRequest(
        2, function_name, {std::make_pair(1, 0)},
        std::unordered_map<string, AttrValue>(),
        "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
        &run_comp_func_request);

    CallOptions call_opts;
    Notification n;
    Status status;
    eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request,
                                            &run_comp_func_response,
                                            [&status, &n](const Status& s) {
                                              status.Update(s);
                                              n.Notify();
                                            });
    if (test_cancel) {
      call_opts.StartCancel();
    }
    n.WaitForNotification();
    if (test_cancel) {
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      TF_ASSERT_OK(status);
      // Retrieve the output.
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, output_num),
          &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }
};

TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionWithLocalInputsTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction", /*local_inputs=*/true);
}

TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulNestedFunction();
  *register_op.mutable_library()->add_function() = MatMulFunction();
  TestFunction(register_op, "MatMulNestedFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestFunction(register_op, "SingleRecvNodeFunction", /*local_inputs=*/false,
               /*test_cancel=*/true);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestComponentFunction(register_op, "MatMulFunction", false);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestComponentFunction(register_op, "SingleRecvNodeFunction", true);
}

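// Fixture for tests that execute MatMulFunction on a second (simulated)
// task, with its input produced on the local task and passed around as a
// remote tensor handle.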
class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 public:
  FunctionWithRemoteInputsTest()
      : EagerServiceImplTest(), eager_service_impl_(&worker_env_) {
    remote_device_mgr_ = absl::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
    context_id_ = random::New64();
  }

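  // EagerKernelArgs whose remote inputs are serialized through a callback
  // supplied by the test.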
  class TestExecuteNodeArgs : public EagerKernelArgs {
   public:
    TestExecuteNodeArgs(
        gtl::InlinedVector<TensorValue, 4>&& tensor_args,
        std::function<Status(const int, eager::RemoteTensorHandle*)>
            serialize_remote_handle)
        : EagerKernelArgs(std::move(tensor_args)),
          serialize_remote_handle_(std::move(serialize_remote_handle)) {}

    bool HasRemoteOrPackedInputs() const override { return true; }

    Status GetRemoteArg(const FunctionArgIndex& index,
                        eager::RemoteTensorHandle* val) const override {
      return serialize_remote_handle_(index.index, val);
    }

   private:
    std::function<Status(const int, eager::RemoteTensorHandle*)>
        serialize_remote_handle_;
  };

  bool MatMulHasAttrWithDefaultValue(const tensorflow::FunctionDef& fdef) {
    for (const auto& node : fdef.node_def()) {
      if (node.op() == "MatMul") {
        return node.attr().find("transpose_a") != node.attr().end();
      }
    }
    return false;
  }

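  // Creates the eager context, points the FakeEagerClient at the local
  // service implementation, enqueues a Const op to produce the function
  // input, and builds the cluster FLR and ProcessFunctionLibraryRuntime used
  // by the tests.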
  void Init() {
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id_);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl_.CreateContext(&request, &response));

    // Make the fake EagerClient use the local eager_service_impl.
    EagerContext* ctx = nullptr;
    TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
    Device* device;
    TF_ASSERT_OK(ctx->FindDeviceFromName(local_device_.c_str(), &device));
    core::RefCountPtr<EagerClient> client;
    TF_ASSERT_OK(ctx->GetClient(device, &client));
    FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client.get());
    fake_client->SetServiceImpl(&eager_service_impl_);

    // Create an input on local_device for MatMulFunction.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id_);
    EnqueueResponse remote_enqueue_response;
    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, local_device_,
                                 &remote_enqueue_request);
    TF_EXPECT_OK(eager_service_impl_.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response));
    eager_cluster_flr_ = absl::make_unique<EagerClusterFunctionLibraryRuntime>(
        context_id_, ctx, device_mgr_.get());

    fdef_ = MatMulFunction();
    TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
    eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
        remote_device_mgr_.get(), Env::Default(), /*config=*/
        nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
        /*thread_pool=*/nullptr, eager_cluster_flr_.get(),
        /*session_metadata=*/nullptr,
        Rendezvous::Factory{[this](const int64_t step_id,
                                   const DeviceMgr* device_mgr,
                                   Rendezvous** r) {
          *r = worker_env_.rendezvous_mgr->Find(step_id);
          return Status::OK();
        }});
  }

  void CheckOutputTensorAndClose(const Tensor& tensor) {
    auto actual = tensor.flat<float>();
    EXPECT_EQ(4, actual.size());
    EXPECT_EQ(7, actual(0));
    EXPECT_EQ(10, actual(1));
    EXPECT_EQ(15, actual(2));
    EXPECT_EQ(22, actual(3));

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id_);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl_.CloseContext(&close_context_request,
                                                  &close_context_response));
  }

  void CheckOutputsAndClose(const std::vector<FunctionRet>& outputs,
                            const int64_t op_id) {
    const tensorflow::Tensor* t = nullptr;
    tensorflow::TensorHandle* tensor_handle;
    TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
        context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
    TF_ASSERT_OK(tensor_handle->Tensor(&t));
    EXPECT_EQ(outputs.size(), 1);
    EXPECT_EQ(outputs.at(0).index(), 1);
    const TensorShape& shape = absl::get<TensorShape>(outputs.at(0));
    EXPECT_EQ(shape, t->shape());
    CheckOutputTensorAndClose(*t);
  }

 protected:
  const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0";
  TestEagerServiceImpl eager_service_impl_;
  std::unique_ptr<DeviceMgr> remote_device_mgr_;
  uint64 context_id_;
  tensorflow::FunctionDef fdef_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> eager_pflr_;
  std::unique_ptr<EagerClusterFunctionLibraryRuntime> eager_cluster_flr_;
  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
};

// Test executes a remote function through
// ProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = remote_device_;
  options.is_multi_device_function = true;
  options.input_devices.push_back(local_device_);
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  TF_ASSERT_OK(eager_pflr_->Instantiate(
      fdef_.signature().name(), AttrSlice(&fdef_.attr()), options, &handle));
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  bool is_cross_process = false;
  TF_CHECK_OK(eager_pflr_->IsCrossProcess(handle, &is_cross_process));
  EXPECT_TRUE(is_cross_process);

  // Run MatMulFunction on remote_device.
  FunctionLibraryRuntime::Options opts;
  const uint64 op_id = 2;
  opts.op_id = op_id;
  Notification done;
  Status status;
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> inputs = {input};
  std::vector<FunctionRet> outputs;
  gtl::InlinedVector<TensorValue, 4> tensor_args = {TensorValue()};
  TestExecuteNodeArgs args(
      std::move(tensor_args),
      [&inputs](const int i, RemoteTensorHandle* handle) -> Status {
        *handle = inputs.at(i);
        return Status::OK();
      });
  eager_pflr_->Run(opts, handle, args, &outputs,
                   [&status, &done](const Status& s) {
                     status = s;
                     done.Notify();
                   });
  done.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function with local input and output tensors.
TEST_F(FunctionWithRemoteInputsTest,
       EagerClusterFLRTestWithLocalInputAndOutput) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  Status status;
  Notification instantiate_done;
  eager_cluster_flr_->Instantiate(
      fdef_.signature().name(), func_lib_def_, AttrSlice(&fdef_.attr()),
      FunctionLibraryRuntime::InstantiateOptions(), &handle,
      [&status, &instantiate_done](const Status& s) {
        status = s;
        instantiate_done.Notify();
      });
  instantiate_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  const tensorflow::Tensor* input_tensor = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
      context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor));

  // Send input_tensor to the remote device, execute MatMulFunction on the
  // remote device, and send the output back.
  FunctionLibraryRuntime::Options opts;
  Notification execute_done;
  std::vector<Tensor> inputs = {*input_tensor};
  std::vector<Tensor> outputs;
  eager_cluster_flr_->Run(opts, handle, inputs, &outputs,
                          [&status, &execute_done](const Status& s) {
                            status = s;
                            execute_done.Notify();
                          });
  execute_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EXPECT_EQ(outputs.size(), 1);
  CheckOutputTensorAndClose(outputs.at(0));
}

// Test executes a remote function through KernelAndDeviceFunc::Run.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64_t op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return Status::OK();
      });
  std::vector<FunctionRet> outputs;

  TF_ASSERT_OK(kernel->Run(/*step_container=*/nullptr, inputs, &outputs,
                           /*cancellation_manager=*/nullptr,
                           /*remote_func_params=*/absl::nullopt,
                           /*stack_trace=*/absl::nullopt,
                           /*coordination_service_agent=*/nullptr));

  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function through KernelAndDeviceFunc::RunAsync.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64_t op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return Status::OK();
      });
  std::vector<FunctionRet> outputs;

  Status status;
  Notification n;
  kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs,
                   /*cancellation_manager=*/nullptr,
                   /*remote_func_params=*/absl::nullopt,
                   /*coordination_service_agent=*/nullptr,
                   [&status, &n](const Status& s) {
                     status = s;
                     n.Notify();
                   });
  n.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test creates a context, attempts to send a tensor (using the RPC), and
// then uses the tensor.
TEST_F(EagerServiceImplTest, SendTensorTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  std::unordered_map<string, AttrValue> attrs;
  AttrValue val;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  const tensorflow::Tensor* t = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  EXPECT_EQ(tensor_handle->device(), nullptr);

  auto actual = t->flat<float>();
  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test serializes and sends a packed TensorHandle.
TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
  const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
  const string composite_device =
      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";

  uint64 context_id = random::New64();
  CreateContextRequest request;
  auto* server_def = request.mutable_server_def();
  server_def->set_job_name("localhost");
  server_def->set_task_index(0);
  request.add_cluster_device_attributes()->set_name(device0);
  request.add_cluster_device_attributes()->set_name(device1);
  request.add_cluster_device_attributes()->set_name(device2);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  // Copy a tensor to device0
  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Copy a packed handle to device0
  auto* send_packed_handle =
      remote_enqueue_request.add_queue()->mutable_send_packed_handle();
  send_packed_handle->set_op_id(3);
  RemoteTensorHandle* remote_handle =
      send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(send_tensor->op_id());
  remote_handle->set_output_num(0);
  remote_handle->set_op_device(device0);
  remote_handle->set_device(device0);

  SendPackedHandleOp::LocalTensorHandle* local_handle =
      send_packed_handle->add_handles()->mutable_local_handle();
  SetTensorProto(local_handle->mutable_tensor());
  local_handle->set_device(device1);

  remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(2);
  remote_handle->set_output_num(5);
  remote_handle->set_op_device(device2);
  remote_handle->set_device(device2);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  tensorflow::TensorHandle* packed_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));

  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
  EXPECT_EQ(packed_handle->device()->name(), composite_device);

  TensorHandle* handle0 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
  EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle0->op_device()->name(), device0);
  const Tensor* t0 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t0));
  auto actual = t0->flat<float>();
  EXPECT_EQ(4, actual.size());
  EXPECT_EQ(1.0, actual(0));
  EXPECT_EQ(2.0, actual(1));
  EXPECT_EQ(3.0, actual(2));
  EXPECT_EQ(4.0, actual(3));

  TensorHandle* handle1 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
  EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle1->op_device()->name(), device1);
  const Tensor* t1 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t1));
  EXPECT_EQ(t1, t0);

  TensorHandle* handle2 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
  EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
  EXPECT_EQ(handle2->op_device()->name(), device2);
  int64_t op_id;
  int32_t output_num;
  TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(),
                                      /*wait_until_ready=*/true, &op_id,
                                      &output_num));
  EXPECT_EQ(op_id, 2);
  EXPECT_EQ(output_num, 5);

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test requests sent to the eager service on master.
TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(device_mgr_.get());
  // Create a master eager context.
  tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      /*async=*/false, device_mgr_.get(), false, rendezvous);
  const uint64 context_id = random::New64();

  // Set RemoteMgr to ctx.
  auto remote_mgr =
      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true, ctx);
  TF_ASSERT_OK(ctx->InitializeRemoteWorker(
      /*remote_eager_workers=*/nullptr, /*remote_device_mgr=*/nullptr,
      /*remote_contexts=*/{}, context_id, /*context_view_id=*/0,
      /*rendezvous_creator=*/nullptr,
      /*cluster_flr=*/nullptr, std::move(remote_mgr),
      /*resource_deallocator=*/nullptr));

  TestEagerServiceImpl eager_service_impl(&worker_env_);

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Unable to handle the request since there is no eager context.
  Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response);
  EXPECT_EQ(error::UNAVAILABLE, status.code());
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "Unable to find a context_id matching the specified one"));

  // The request can be handled after adding the master eager context to the
  // service.
  TF_ASSERT_OK(eager_service_impl.CreateMasterContext(context_id, ctx));
  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));
  ctx->Unref();
}

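// Verifies that a context created with a short keep_alive_secs is reclaimed
// once that interval elapses without KeepAlive calls, while a freshly created
// context remains available.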
TEST_F(EagerServiceImplTest, KeepAliveTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();
  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  request.set_keep_alive_secs(3);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  worker_env_.env->SleepForMicroseconds(5 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  KeepAliveRequest keep_alive_request;
  KeepAliveResponse keep_alive_response;

  keep_alive_request.set_context_id(context_id);

  Status status =
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);

  EXPECT_EQ(status.code(), error::UNAVAILABLE);
  EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                      status.error_message());

  uint64 new_context_id = random::New64();
  // Create a new context.
  request.set_context_id(new_context_id);
  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  // The context should not be GC'd.
  worker_env_.env->SleepForMicroseconds(1 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  keep_alive_request.set_context_id(new_context_id);

  TF_ASSERT_OK(
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));
}

}  // namespace
}  // namespace eager
}  // namespace tensorflow