/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace tensorflow {
namespace eager {
namespace {

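// EagerServiceImpl subclass that exposes otherwise-protected state (the
// EagerContext and remote tensor handles of a given context_id) so the tests
// below can inspect it.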
class TestEagerServiceImpl : public EagerServiceImpl {
 public:
  explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
  Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);
    *ctx = context->Context();
    return Status::OK();
  }
  Status GetTensorHandle(const uint64 context_id,
                         const RemoteTensorHandleInternal& remote_handle,
                         tensorflow::TensorHandle** handle) {
    ServerContext* context = nullptr;
    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
    core::ScopedUnref context_unref(context);

    return context->Context()->RemoteMgr()->GetTensorHandle(remote_handle,
                                                            handle);
  }
};

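// EagerClient that bypasses the RPC layer and forwards every call directly to
// a local TestEagerServiceImpl set via SetServiceImpl().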
class FakeEagerClient : public EagerClient {
 public:
  FakeEagerClient() {}
  ~FakeEagerClient() override {}

  void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }

#define CLIENT_METHOD(method)                                         \
  void method##Async(const method##Request* request,                  \
                     method##Response* response, StatusCallback done) \
      override {                                                      \
    done(impl_->method(request, response));                           \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);
  CLIENT_METHOD(CloseContext);
#undef CLIENT_METHOD

  void EnqueueAsync(CallOptions* call_opts, const EnqueueRequest* request,
                    EnqueueResponse* response, StatusCallback done) override {
    done(impl_->Enqueue(call_opts, request, response));
  }

  void RunComponentFunctionAsync(CallOptions* call_opts,
                                 const RunComponentFunctionRequest* request,
                                 RunComponentFunctionResponse* response,
                                 StatusCallback done) override {
    impl_->RunComponentFunction(call_opts, request, response, std::move(done));
  }

  void StreamingEnqueueAsync(CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    done(impl_->Enqueue(nullptr, request, response));
  }

  bool allow_multiple_pending_requests() const override { return false; }

 private:
  TestEagerServiceImpl* impl_;
};

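// EagerClientCache that hands out the same FakeEagerClient for every target.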
class DummyEagerClientCache : public EagerClientCache {
 public:
  DummyEagerClientCache() : client_(new FakeEagerClient) {}
  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    client->reset(client_.get());
    client_->Ref();
    return Status::OK();
  }

 private:
  core::RefCountPtr<EagerClient> client_;
};

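// Worker cache that serves DummyEagerClientCache instances and lists a single
// local worker.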
class FakeCache : public TestWorkerCache {
  Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
    eager_client_cache->reset(new DummyEagerClientCache);
    return Status::OK();
  }

  void ListWorkers(std::vector<string>* workers) const override {
    workers->push_back("/job:localhost/replica:0/task:0");
  }
};

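// Test fixture that wires up a WorkerEnv with a single local CPU device, an
// RPC rendezvous manager, and a SessionMgr backed by FakeCache.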
class EagerServiceImplTest : public ::testing::Test {
 public:
  EagerServiceImplTest()
      : rendezvous_mgr_(&worker_env_),
        session_mgr_(new SessionMgr(
            &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
            std::unique_ptr<WorkerCacheInterface>(new FakeCache),
            [](const ServerDef& server_def,
               WorkerCacheInterface** worker_cache) {
              *worker_cache = new FakeCache;
              return Status::OK();
            })) {
    worker_env_.env = Env::Default();

    worker_env_.rendezvous_mgr = &rendezvous_mgr_;
    worker_env_.session_mgr = session_mgr_.get();

    device_mgr_ = absl::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
    worker_env_.local_devices = device_mgr_->ListDevices();
    worker_env_.device_mgr = device_mgr_.get();
  }

 protected:
  WorkerEnv worker_env_;
  tensorflow::RpcRendezvousMgr rendezvous_mgr_;
  std::unique_ptr<SessionMgr> session_mgr_;
  std::unique_ptr<DeviceMgr> device_mgr_;
};

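// Fills `tensor_proto` with a 2x2 float tensor holding {1, 2, 3, 4}.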
void SetTensorProto(TensorProto* tensor_proto) {
  int64_t dims[] = {2, 2};
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_Tensor* t = TF_AllocateTensor(
      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  tensorflow::Tensor tensor;
  TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor));
  tensor.AsProtoTensorContent(tensor_proto);
  TF_DeleteTensor(t);
}

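// Populates `operation` with the given id, op name, device, attributes, and
// inputs; each input is either an inline TensorProto or an
// (op_id, output_num) pair referring to a previously enqueued op's output.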
void BuildOperation(
    Operation* operation, int64_t id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device) {
  operation->set_id(id);
  operation->set_name(name);
  operation->set_device(device);

  for (const auto& input : inputs) {
    if (input.index() == 0) {
      *operation->add_op_inputs()->mutable_tensor() =
          absl::get<TensorProto>(input);
    } else {
      const auto& tensor_handle_pair =
          absl::get<std::pair<int64, int32>>(input);
      auto* input = operation->add_op_inputs()->mutable_remote_handle();
      input->set_op_id(tensor_handle_pair.first);
      input->set_output_num(tensor_handle_pair.second);
      input->set_op_device(device);
      input->set_device(device);
    }
  }

  for (const auto& attr_entry : attrs) {
    (*operation->mutable_attrs())[attr_entry.first] = attr_entry.second;
  }
}

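// Appends an operation built by BuildOperation to the queue of an
// EnqueueRequest.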
void AddOperationToEnqueueRequest(
    int64_t id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    EnqueueRequest* request) {
  auto* operation = request->add_queue()->mutable_operation();
  BuildOperation(operation, id, name, inputs, attrs, device);
}

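// Builds the component-function operation of a RunComponentFunctionRequest
// and records which output index to return.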
void AddOperationToRunComponentFunctionRequest(
    int64_t id, const string& name,
    const std::vector<absl::variant<TensorProto, std::pair<int64, int32>>>&
        inputs,
    const std::unordered_map<string, AttrValue>& attrs, const string& device,
    const int output_num, RunComponentFunctionRequest* request) {
  auto* operation = request->mutable_operation();
  operation->set_is_function(true);
  operation->set_is_component_function(true);
  request->add_output_num(output_num);
  BuildOperation(operation, id, name, inputs, attrs, device);
}

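// NodeDef that calls the MatMulFunction defined below.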
tensorflow::NodeDef MatMulFunctionNodeDef() {
  tensorflow::NodeDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    name: 'matmul_func'"
      "    op: 'MatMulFunction'"
      "    input: 'a'"
      "    input: 'a'"
      "    attr {"
      "      key: 'T'"
      "      value {"
      "        type: DT_FLOAT"
      "      }"
      "    }",
      &def));
  return def;
}

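// FunctionDef that multiplies its single float input matrix by itself.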
tensorflow::FunctionDef MatMulFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'm'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul'"
      "      op: 'MatMul'"
      "      input: 'a'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "      attr {"
      "        key: 'transpose_a'"
      "        value {"
      "          b: false"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'm'"
      "      value: 'matmul:product'"
      "    }",
      &def));
  return def;
}

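// FunctionDef whose body calls MatMulFunction, used to exercise nested
// function execution.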
tensorflow::FunctionDef MatMulNestedFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulNestedFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'matmul_nested'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul_nested'"
      "      op: 'MatMulFunction'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'matmul_nested'"
      "      value: 'matmul_nested:m:0'"
      "    }",
      &def));
  return def;
}

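// FunctionDef whose body is a single client-terminated _Recv node; its _Recv
// blocks waiting for a tensor, which the cancellation tests rely on.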
tensorflow::FunctionDef SingleRecvNodeFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'SingleRecvNodeFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'recv_tensor'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'recv_node'"
      "      op: '_Recv'"
      "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "      attr {"
      "        key: 'client_terminated'"
      "        value {"
      "          b: true"
      "        }"
      "      }"
      "      attr {"
      "        key: 'recv_device'"
      "        value {"
      "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'send_device'"
      "        value {"
      "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'send_device_incarnation'"
      "        value {"
      "          i: 1"
      "        }"
      "      }"
      "      attr {"
      "        key: 'tensor_name'"
      "        value {"
      "          s: 't0'"
      "        }"
      "      }"
      "      attr {"
      "        key: 'tensor_type'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'recv_tensor'"
      "      value: 'recv_node:tensor:0'"
      "    }",
      &def));
  return def;
}

// Test creates a context and attempts to execute some ops.
TEST_F(EagerServiceImplTest, BasicTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  std::unordered_map<string, AttrValue> const_attrs;
  AttrValue val;
  val.set_type(tensorflow::DataType::DT_FLOAT);
  const_attrs.insert({"dtype", val});
  val.Clear();
  SetTensorProto(val.mutable_tensor());
  const_attrs.insert({"value", val});

  AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                               "/job:localhost/replica:0/task:0/device:CPU:0",
                               &remote_enqueue_request);

  std::unordered_map<string, AttrValue> attrs;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  auto& matmul_result_shape =
      remote_enqueue_response.queue_response(1).shape(0);
  EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
  EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);

  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));

  // This should be OK to do since we've placed all computation on the CPU
  // device.
  const tensorflow::Tensor* t = nullptr;
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  auto actual = t->flat<float>();

  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

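// Fixture whose helpers register a function and execute it, optionally with
// local tensor inputs or under cancellation.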
class EagerServiceImplFunctionTest : public EagerServiceImplTest {
 public:
  EagerServiceImplFunctionTest() : EagerServiceImplTest() {}

  // Creates a context and attempts to execute a function.
  void TestFunction(const RegisterFunctionOp& register_op,
                    const string& function_name,
                    const bool local_inputs = false,
                    const bool test_cancel = false) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);

    uint64 context_id = random::New64();

    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;

    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;

    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    if (local_inputs) {
      TensorProto tensor_proto;
      SetTensorProto(&tensor_proto);
      AddOperationToEnqueueRequest(
          2, function_name, {tensor_proto},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);

    } else {
      std::unordered_map<string, AttrValue> const_attrs;
      AttrValue val;
      val.set_type(tensorflow::DataType::DT_FLOAT);
      const_attrs.insert({"dtype", val});
      val.Clear();

      SetTensorProto(val.mutable_tensor());
      const_attrs.insert({"value", val});

      AddOperationToEnqueueRequest(
          1, "Const", {}, const_attrs,
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
      AddOperationToEnqueueRequest(
          2, function_name, {std::make_pair(1, 0)},
          std::unordered_map<string, AttrValue>(),
          "/job:localhost/replica:0/task:0/device:CPU:0",
          &remote_enqueue_request);
    }

    CallOptions call_opts;
    Status status;
    Notification n;
    Env::Default()->SchedClosure([&] {
      status = eager_service_impl.Enqueue(&call_opts, &remote_enqueue_request,
                                          &remote_enqueue_response);
      n.Notify();
    });

    if (test_cancel) {
      // Wait to let the Enqueue thread start running.
      Env::Default()->SleepForMicroseconds(500000);
      call_opts.StartCancel();
      n.WaitForNotification();
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      n.WaitForNotification();
      TF_ASSERT_OK(status);
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }

  // Creates a context and attempts to execute a component function.
  void TestComponentFunction(const RegisterFunctionOp& register_op,
                             const string& function_name,
                             const bool test_cancel) {
    TestEagerServiceImpl eager_service_impl(&worker_env_);
    uint64 context_id = random::New64();

    // Create context.
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

    // Register function.
    EnqueueRequest enqueue_request;
    enqueue_request.set_context_id(context_id);
    *enqueue_request.add_queue()->mutable_register_function() = register_op;
    EnqueueResponse enqueue_response;
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
                                            &enqueue_response));

    // First run an op to generate input for function.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id);
    EnqueueResponse remote_enqueue_response;

    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
                                 "/job:localhost/replica:0/task:0/device:CPU:0",
                                 &remote_enqueue_request);
    TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                            &remote_enqueue_response));

    // Run function with input from the previous op.
    RunComponentFunctionRequest run_comp_func_request;
    run_comp_func_request.set_context_id(context_id);
    RunComponentFunctionResponse run_comp_func_response;
    const int output_num = 5;
    AddOperationToRunComponentFunctionRequest(
        2, function_name, {std::make_pair(1, 0)},
        std::unordered_map<string, AttrValue>(),
        "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
        &run_comp_func_request);

    CallOptions call_opts;
    Notification n;
    Status status;
    eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request,
                                            &run_comp_func_response,
                                            [&status, &n](const Status& s) {
                                              status.Update(s);
                                              n.Notify();
                                            });
    if (test_cancel) {
      call_opts.StartCancel();
    }
    n.WaitForNotification();
    if (test_cancel) {
      EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
    } else {
      TF_ASSERT_OK(status);
      // Retrieve the output.
      const tensorflow::Tensor* t = nullptr;
      tensorflow::TensorHandle* tensor_handle;
      TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
          context_id, RemoteTensorHandleInternal(2, output_num),
          &tensor_handle));
      TF_ASSERT_OK(tensor_handle->Tensor(&t));

      auto actual = t->flat<float>();
      EXPECT_EQ(4, actual.size());

      EXPECT_EQ(7, actual(0));
      EXPECT_EQ(10, actual(1));
      EXPECT_EQ(15, actual(2));
      EXPECT_EQ(22, actual(3));
    }

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                                 &close_context_response));
  }
};

TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionWithLocalInputsTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestFunction(register_op, "MatMulFunction", /*local_inputs=*/true);
}

TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulNestedFunction();
  *register_op.mutable_library()->add_function() = MatMulFunction();
  TestFunction(register_op, "MatMulNestedFunction");
}

TEST_F(EagerServiceImplFunctionTest, FunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestFunction(register_op, "SingleRecvNodeFunction", /*local_inputs=*/false,
               /*test_cancel=*/true);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = MatMulFunction();
  TestComponentFunction(register_op, "MatMulFunction", false);
}

TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) {
  RegisterFunctionOp register_op;
  *register_op.mutable_function_def() = SingleRecvNodeFunction();
  TestComponentFunction(register_op, "SingleRecvNodeFunction", true);
}

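// Fixture for executing a registered function whose inputs are remote tensor
// handles, going through the cluster function library runtime.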
class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
 public:
  FunctionWithRemoteInputsTest()
      : EagerServiceImplTest(), eager_service_impl_(&worker_env_) {
    remote_device_mgr_ = absl::make_unique<StaticDeviceMgr>(
        DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
    context_id_ = random::New64();
  }

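  // EagerKernelArgs whose remote inputs are serialized via the callback
  // supplied at construction time.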
  class TestExecuteNodeArgs : public EagerKernelArgs {
   public:
    TestExecuteNodeArgs(
        gtl::InlinedVector<TensorValue, 4>&& tensor_args,
        std::function<Status(const int, eager::RemoteTensorHandle*)>
            serialize_remote_handle)
        : EagerKernelArgs(std::move(tensor_args)),
          serialize_remote_handle_(std::move(serialize_remote_handle)) {}

    bool HasRemoteOrPackedInputs() const override { return true; }

    Status GetRemoteArg(const FunctionArgIndex& index,
                        eager::RemoteTensorHandle* val) const override {
      return serialize_remote_handle_(index.index, val);
    }

   private:
    std::function<Status(const int, eager::RemoteTensorHandle*)>
        serialize_remote_handle_;
  };

  bool MatMulHasAttrWithDefaultValue(const tensorflow::FunctionDef& fdef) {
    for (const auto& node : fdef.node_def()) {
      if (node.op() == "MatMul") {
        return node.attr().find("transpose_a") != node.attr().end();
      }
    }
    return false;
  }

  void Init() {
    CreateContextRequest request;
    request.mutable_server_def()->set_job_name("localhost");
    request.mutable_server_def()->set_task_index(0);
    request.set_context_id(context_id_);
    CreateContextResponse response;
    TF_ASSERT_OK(eager_service_impl_.CreateContext(&request, &response));

    // Make the fake EagerClient use the local eager_service_impl.
    EagerContext* ctx = nullptr;
    TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
    Device* device;
    TF_ASSERT_OK(ctx->FindDeviceFromName(local_device_.c_str(), &device));
    core::RefCountPtr<EagerClient> client;
    TF_ASSERT_OK(ctx->GetClient(device, &client));
    FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client.get());
    fake_client->SetServiceImpl(&eager_service_impl_);

    // Create an input on local_device for MatMulFunction.
    EnqueueRequest remote_enqueue_request;
    remote_enqueue_request.set_context_id(context_id_);
    EnqueueResponse remote_enqueue_response;
    std::unordered_map<string, AttrValue> const_attrs;
    AttrValue val;
    val.set_type(tensorflow::DataType::DT_FLOAT);
    const_attrs.insert({"dtype", val});
    val.Clear();
    SetTensorProto(val.mutable_tensor());
    const_attrs.insert({"value", val});
    AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, local_device_,
                                 &remote_enqueue_request);
    TF_EXPECT_OK(eager_service_impl_.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response));
    eager_cluster_flr_ = absl::make_unique<EagerClusterFunctionLibraryRuntime>(
        context_id_, ctx, device_mgr_.get());

    fdef_ = MatMulFunction();
    TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
    eager_pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
        remote_device_mgr_.get(), Env::Default(), /*config=*/
        nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
        /*thread_pool=*/nullptr, eager_cluster_flr_.get(),
        /*session_metadata=*/nullptr,
        Rendezvous::Factory{[this](const int64_t step_id,
                                   const DeviceMgr* device_mgr,
                                   Rendezvous** r) {
          *r = worker_env_.rendezvous_mgr->Find(step_id);
          return Status::OK();
        }});
  }

  void CheckOutputTensorAndClose(const Tensor& tensor) {
    auto actual = tensor.flat<float>();
    EXPECT_EQ(4, actual.size());
    EXPECT_EQ(7, actual(0));
    EXPECT_EQ(10, actual(1));
    EXPECT_EQ(15, actual(2));
    EXPECT_EQ(22, actual(3));

    CloseContextRequest close_context_request;
    close_context_request.set_context_id(context_id_);
    close_context_request.set_context_view_id(0);
    CloseContextResponse close_context_response;
    TF_ASSERT_OK(eager_service_impl_.CloseContext(&close_context_request,
                                                  &close_context_response));
  }

  void CheckOutputsAndClose(const std::vector<FunctionRet>& outputs,
                            const int64_t op_id) {
    const tensorflow::Tensor* t = nullptr;
    tensorflow::TensorHandle* tensor_handle;
    TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
        context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
    TF_ASSERT_OK(tensor_handle->Tensor(&t));
    EXPECT_EQ(outputs.size(), 1);
    EXPECT_EQ(outputs.at(0).index(), 1);
    const TensorShape& shape = absl::get<TensorShape>(outputs.at(0));
    EXPECT_EQ(shape, t->shape());
    CheckOutputTensorAndClose(*t);
  }

 protected:
  const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0";
  TestEagerServiceImpl eager_service_impl_;
  std::unique_ptr<DeviceMgr> remote_device_mgr_;
  uint64 context_id_;
  tensorflow::FunctionDef fdef_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> eager_pflr_;
  std::unique_ptr<EagerClusterFunctionLibraryRuntime> eager_cluster_flr_;
  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
};

// Test executes a remote function through
// ProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::InstantiateOptions options;
  options.target = remote_device_;
  options.is_multi_device_function = true;
  options.input_devices.push_back(local_device_);
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  TF_ASSERT_OK(eager_pflr_->Instantiate(
      fdef_.signature().name(), AttrSlice(&fdef_.attr()), options, &handle));
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  bool is_cross_process = false;
  TF_CHECK_OK(eager_pflr_->IsCrossProcess(handle, &is_cross_process));
  EXPECT_TRUE(is_cross_process);

  // Run MatMulFunction on remote_device.
  FunctionLibraryRuntime::Options opts;
  const uint64 op_id = 2;
  opts.op_id = op_id;
  Notification done;
  Status status;
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> inputs = {input};
  std::vector<FunctionRet> outputs;
  gtl::InlinedVector<TensorValue, 4> tensor_args = {TensorValue()};
  TestExecuteNodeArgs args(
      std::move(tensor_args),
      [&inputs](const int i, RemoteTensorHandle* handle) -> Status {
        *handle = inputs.at(i);
        return Status::OK();
      });
  eager_pflr_->Run(opts, handle, args, &outputs,
                   [&status, &done](const Status& s) {
                     status = s;
                     done.Notify();
                   });
  done.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function with local input and output tensors.
TEST_F(FunctionWithRemoteInputsTest,
       EagerClusterFLRTestWithLocalInputAndOutput) {
  Init();
  // Instantiate MatMulFunction on remote_device.
  FunctionLibraryRuntime::Handle handle;
  EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
  Status status;
  Notification instantiate_done;
  eager_cluster_flr_->Instantiate(
      fdef_.signature().name(), func_lib_def_, AttrSlice(&fdef_.attr()),
      FunctionLibraryRuntime::InstantiateOptions(), &handle,
      [&status, &instantiate_done](const Status& s) {
        status = s;
        instantiate_done.Notify();
      });
  instantiate_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
    const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
    EXPECT_TRUE(fdef != nullptr);
    if (absl::StartsWith(func_name, "MatMulFunction")) {
      EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
    }
  }
  const tensorflow::Tensor* input_tensor = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
      context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor));

  // Send input_tensor to the remote device, execute MatMulFunction on the
  // remote device, and send the output back.
  FunctionLibraryRuntime::Options opts;
  Notification execute_done;
  std::vector<Tensor> inputs = {*input_tensor};
  std::vector<Tensor> outputs;
  eager_cluster_flr_->Run(opts, handle, inputs, &outputs,
                          [&status, &execute_done](const Status& s) {
                            status = s;
                            execute_done.Notify();
                          });
  execute_done.WaitForNotification();
  TF_ASSERT_OK(status);
  EXPECT_EQ(outputs.size(), 1);
  CheckOutputTensorAndClose(outputs.at(0));
}

// Test executes a remote function through KernelAndDeviceFunc::Run.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64_t op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return Status::OK();
      });
  std::vector<FunctionRet> outputs;

  TF_ASSERT_OK(kernel->Run(/*step_container=*/nullptr, inputs, &outputs,
                           /*cancellation_manager=*/nullptr,
                           /*remote_func_params=*/absl::nullopt,
                           /*stack_trace=*/absl::nullopt,
                           /*coordination_service_agent=*/nullptr));

  CheckOutputsAndClose(outputs, op_id);
}

// Test executes a remote function through KernelAndDeviceFunc::RunAsync.
TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
  Init();
  Device* local_device;
  TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
  std::vector<Device*> input_dev_ptrs;
  input_dev_ptrs.push_back(local_device);
  FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
  EagerContext* ctx = nullptr;
  TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
  core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
  const int64_t op_id = 2;
  kernel.reset(new KernelAndDeviceFunc(
      flr, eager_pflr_.get(), std::move(input_dev_ptrs),
      /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
      /*runner=*/nullptr,
      /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
      /*outputs_on_op_device=*/false, ctx->RendezvousCreator(),
      [=]() { return op_id; }));

  // Instantiate MatMulFunction on remote_device.
  const NodeDef node_def = MatMulFunctionNodeDef();
  TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));

  // Run MatMulFunction on remote_device.
  gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
  RemoteTensorHandle input;
  input.set_op_id(1);
  input.set_output_num(0);
  input.set_op_device(local_device_);
  input.set_device(local_device_);
  std::vector<RemoteTensorHandle> remote_handles = {input};
  TestExecuteNodeArgs inputs(
      std::move(input_tensors),
      [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
        *handle = remote_handles.at(index);
        return Status::OK();
      });
  std::vector<FunctionRet> outputs;

  Status status;
  Notification n;
  kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs,
                   /*cancellation_manager=*/nullptr,
                   /*remote_func_params=*/absl::nullopt,
                   /*coordination_service_agent=*/nullptr,
                   [&status, &n](const Status& s) {
                     status = s;
                     n.Notify();
                   });
  n.WaitForNotification();
  TF_ASSERT_OK(status);
  CheckOutputsAndClose(outputs, op_id);
}

// Test creates a context and attempts to send a tensor (using the RPC), and
// then use the tensor.
TEST_F(EagerServiceImplTest, SendTensorTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();

  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  std::unordered_map<string, AttrValue> attrs;
  AttrValue val;
  val.Clear();
  val.set_type(tensorflow::DataType::DT_FLOAT);
  attrs.insert({"T", val});
  val.Clear();
  val.set_b(false);
  attrs.insert({"transpose_a", val});
  attrs.insert({"transpose_b", val});

  AddOperationToEnqueueRequest(
      2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
      "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  const tensorflow::Tensor* t = nullptr;
  tensorflow::TensorHandle* tensor_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  EXPECT_EQ(tensor_handle->device(), nullptr);

  auto actual = t->flat<float>();
  EXPECT_EQ(4, actual.size());

  EXPECT_EQ(7, actual(0));
  EXPECT_EQ(10, actual(1));
  EXPECT_EQ(15, actual(2));
  EXPECT_EQ(22, actual(3));

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test serializes and sends a packed TensorHandle.
TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
  const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
  const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
  const string composite_device =
      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";

  uint64 context_id = random::New64();
  CreateContextRequest request;
  auto* server_def = request.mutable_server_def();
  server_def->set_job_name("localhost");
  server_def->set_task_index(0);
  request.add_cluster_device_attributes()->set_name(device0);
  request.add_cluster_device_attributes()->set_name(device1);
  request.add_cluster_device_attributes()->set_name(device2);
  request.set_context_id(context_id);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  // Copy a tensor to device0
  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Copy a packed handle to device0
  auto* send_packed_handle =
      remote_enqueue_request.add_queue()->mutable_send_packed_handle();
  send_packed_handle->set_op_id(3);
  RemoteTensorHandle* remote_handle =
      send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(send_tensor->op_id());
  remote_handle->set_output_num(0);
  remote_handle->set_op_device(device0);
  remote_handle->set_device(device0);

  SendPackedHandleOp::LocalTensorHandle* local_handle =
      send_packed_handle->add_handles()->mutable_local_handle();
  SetTensorProto(local_handle->mutable_tensor());
  local_handle->set_device(device1);

  remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
  remote_handle->set_op_id(2);
  remote_handle->set_output_num(5);
  remote_handle->set_op_device(device2);
  remote_handle->set_device(device2);

  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));

  tensorflow::TensorHandle* packed_handle;
  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
      context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));

  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
  EXPECT_EQ(packed_handle->device()->name(), composite_device);

  TensorHandle* handle0 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
  EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle0->op_device()->name(), device0);
  const Tensor* t0 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t0));
  auto actual = t0->flat<float>();
  EXPECT_EQ(4, actual.size());
  EXPECT_EQ(1.0, actual(0));
  EXPECT_EQ(2.0, actual(1));
  EXPECT_EQ(3.0, actual(2));
  EXPECT_EQ(4.0, actual(3));

  TensorHandle* handle1 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
  EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
  EXPECT_EQ(handle1->op_device()->name(), device1);
  const Tensor* t1 = nullptr;
  TF_ASSERT_OK(handle0->Tensor(&t1));
  EXPECT_EQ(t1, t0);

  TensorHandle* handle2 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
  EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
  EXPECT_EQ(handle2->op_device()->name(), device2);
  int64_t op_id;
  int32_t output_num;
  TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(),
                                      /*wait_until_ready=*/true, &op_id,
                                      &output_num));
  EXPECT_EQ(op_id, 2);
  EXPECT_EQ(output_num, 5);

  CloseContextRequest close_context_request;
  close_context_request.set_context_id(context_id);
  close_context_request.set_context_view_id(0);
  CloseContextResponse close_context_response;
  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
                                               &close_context_response));
}

// Test requests sent to the eager service on master.
TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(device_mgr_.get());
  // Create a master eager context.
  tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      /*async=*/false, device_mgr_.get(), false, rendezvous);
  const uint64 context_id = random::New64();

  // Set RemoteMgr to ctx.
  auto remote_mgr =
      absl::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true, ctx);
  TF_ASSERT_OK(ctx->InitializeRemoteWorker(
      /*remote_eager_workers=*/nullptr, /*remote_device_mgr=*/nullptr,
      /*remote_contexts=*/{}, context_id, /*context_view_id=*/0,
      /*rendezvous_creator=*/nullptr,
      /*cluster_flr=*/nullptr, std::move(remote_mgr),
      /*resource_deallocator=*/nullptr));

  TestEagerServiceImpl eager_service_impl(&worker_env_);

  EnqueueRequest remote_enqueue_request;
  remote_enqueue_request.set_context_id(context_id);
  EnqueueResponse remote_enqueue_response;

  auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(1);
  SetTensorProto(send_tensor->add_tensors());

  // Unable to handle the request since there is no eager context.
  Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                             &remote_enqueue_response);
  EXPECT_EQ(error::UNAVAILABLE, status.code());
  EXPECT_TRUE(absl::StrContains(
      status.error_message(),
      "Unable to find a context_id matching the specified one"));

  // The request can be handled after adding the master eager context to
  // service.
  TF_ASSERT_OK(eager_service_impl.CreateMasterContext(context_id, ctx));
  TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                          &remote_enqueue_response));
  ctx->Unref();
}

TEST_F(EagerServiceImplTest, KeepAliveTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);

  uint64 context_id = random::New64();
  CreateContextRequest request;
  request.mutable_server_def()->set_job_name("localhost");
  request.mutable_server_def()->set_task_index(0);
  request.set_context_id(context_id);
  request.set_keep_alive_secs(3);
  CreateContextResponse response;

  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  worker_env_.env->SleepForMicroseconds(5 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  KeepAliveRequest keep_alive_request;
  KeepAliveResponse keep_alive_response;

  keep_alive_request.set_context_id(context_id);

  Status status =
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);

  EXPECT_EQ(status.code(), error::UNAVAILABLE);
  EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                      status.error_message());

  uint64 new_context_id = random::New64();
  // Create a new context.
  request.set_context_id(new_context_id);
  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));

  // The context should not be GC'd.
  worker_env_.env->SleepForMicroseconds(1 *
                                        tensorflow::EnvTime::kSecondsToMicros);

  keep_alive_request.set_context_id(new_context_id);

  TF_ASSERT_OK(
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));
}

}  // namespace
}  // namespace eager
}  // namespace tensorflow