• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <optional>
21 #include <unordered_map>
22 #include <utility>
23 #include <variant>
24 #include <vector>
25 
26 #include "absl/types/optional.h"
27 #include "absl/types/variant.h"
28 #include "tensorflow/c/tf_tensor.h"
29 #include "tensorflow/c/tf_tensor_internal.h"
30 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
31 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
32 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
33 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
34 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
35 #include "tensorflow/core/distributed_runtime/session_mgr.h"
36 #include "tensorflow/core/distributed_runtime/test_utils.h"
37 #include "tensorflow/core/distributed_runtime/worker_env.h"
38 #include "tensorflow/core/framework/attr_value.pb.h"
39 #include "tensorflow/core/lib/core/status_test_util.h"
40 #include "tensorflow/core/platform/errors.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/protobuf/eager_service.pb.h"
44 #include "tensorflow/core/protobuf/error_codes.pb.h"
45 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
46 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
47 
48 namespace tensorflow {
49 namespace eager {
50 namespace {
51 
52 class TestEagerServiceImpl : public EagerServiceImpl {
53  public:
TestEagerServiceImpl(const WorkerEnv * env)54   explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
GetEagerContext(const uint64 context_id,EagerContext ** ctx)55   Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
56     ServerContext* context = nullptr;
57     TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
58     core::ScopedUnref context_unref(context);
59     *ctx = context->Context();
60     return OkStatus();
61   }
GetTensorHandle(const uint64 context_id,const RemoteTensorHandleInternal & remote_handle,tensorflow::TensorHandle ** handle)62   Status GetTensorHandle(const uint64 context_id,
63                          const RemoteTensorHandleInternal& remote_handle,
64                          tensorflow::TensorHandle** handle) {
65     ServerContext* context = nullptr;
66     TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
67     core::ScopedUnref context_unref(context);
68 
69     return context->Context()->RemoteMgr()->GetTensorHandle(remote_handle,
70                                                             handle);
71   }
72 };
73 
74 class FakeEagerClient : public EagerClient {
75  public:
FakeEagerClient()76   FakeEagerClient() {}
~FakeEagerClient()77   ~FakeEagerClient() override {}
78 
SetServiceImpl(TestEagerServiceImpl * impl)79   void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }
80 
81 #define CLIENT_METHOD(method)                                         \
82   void method##Async(const method##Request* request,                  \
83                      method##Response* response, StatusCallback done) \
84       override {                                                      \
85     done(impl_->method(request, response));                           \
86   }
87 
88   CLIENT_METHOD(CreateContext);
89   CLIENT_METHOD(UpdateContext);
90   CLIENT_METHOD(WaitQueueDone);
91   CLIENT_METHOD(KeepAlive);
92   CLIENT_METHOD(CloseContext);
93 #undef CLIENT_METHOD
94 
EnqueueAsync(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,StatusCallback done)95   void EnqueueAsync(CallOptions* call_opts, const EnqueueRequest* request,
96                     EnqueueResponse* response, StatusCallback done) override {
97     done(impl_->Enqueue(call_opts, request, response));
98   }
99 
RunComponentFunctionAsync(CallOptions * call_opts,const RunComponentFunctionRequest * request,RunComponentFunctionResponse * response,StatusCallback done)100   void RunComponentFunctionAsync(CallOptions* call_opts,
101                                  const RunComponentFunctionRequest* request,
102                                  RunComponentFunctionResponse* response,
103                                  StatusCallback done) override {
104     impl_->RunComponentFunction(call_opts, request, response, std::move(done));
105   }
106 
StreamingEnqueueAsync(bool enable_streaming_enqueue,CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,StatusCallback done)107   void StreamingEnqueueAsync(bool enable_streaming_enqueue,
108                              CallOptions* call_opts,
109                              const EnqueueRequest* request,
110                              EnqueueResponse* response,
111                              StatusCallback done) override {
112     done(impl_->Enqueue(nullptr, request, response));
113   }
114 
allow_multiple_pending_requests() const115   bool allow_multiple_pending_requests() const override { return false; }
116 
117  private:
118   TestEagerServiceImpl* impl_;
119 };
120 
121 class DummyEagerClientCache : public EagerClientCache {
122  public:
DummyEagerClientCache()123   DummyEagerClientCache() : client_(new FakeEagerClient) {}
GetClient(const string & target,core::RefCountPtr<EagerClient> * client)124   Status GetClient(const string& target,
125                    core::RefCountPtr<EagerClient>* client) override {
126     client->reset(client_.get());
127     client_->Ref();
128     return OkStatus();
129   }
130 
131  private:
132   core::RefCountPtr<EagerClient> client_;
133 };
134 
135 class FakeCache : public TestWorkerCache {
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)136   Status GetEagerClientCache(
137       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
138     *eager_client_cache = std::make_unique<DummyEagerClientCache>();
139     return OkStatus();
140   }
141 
ListWorkers(std::vector<string> * workers) const142   void ListWorkers(std::vector<string>* workers) const override {
143     workers->push_back("/job:localhost/replica:0/task:0");
144   }
145 };
146 
147 class EagerServiceImplTest : public ::testing::Test {
148  public:
EagerServiceImplTest()149   EagerServiceImplTest()
150       : rendezvous_mgr_(&worker_env_),
151         session_mgr_(new SessionMgr(
152             &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
153             std::unique_ptr<WorkerCacheInterface>(new FakeCache),
154             [](const ServerDef& server_def,
155                WorkerCacheInterface** worker_cache) {
156               *worker_cache = new FakeCache;
157               return OkStatus();
158             })) {
159     worker_env_.env = Env::Default();
160 
161     worker_env_.rendezvous_mgr = &rendezvous_mgr_;
162     worker_env_.session_mgr = session_mgr_.get();
163 
164     device_mgr_ = std::make_unique<StaticDeviceMgr>(
165         DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
166     worker_env_.local_devices = device_mgr_->ListDevices();
167     worker_env_.device_mgr = device_mgr_.get();
168   }
169 
170  protected:
171   WorkerEnv worker_env_;
172   tensorflow::RpcRendezvousMgr rendezvous_mgr_;
173   std::unique_ptr<SessionMgr> session_mgr_;
174   std::unique_ptr<DeviceMgr> device_mgr_;
175 };
176 
SetTensorProto(TensorProto * tensor_proto)177 void SetTensorProto(TensorProto* tensor_proto) {
178   int64_t dims[] = {2, 2};
179   float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
180   TF_Tensor* t = TF_AllocateTensor(
181       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
182   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
183   tensorflow::Tensor tensor;
184   TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor));
185   tensor.AsProtoTensorContent(tensor_proto);
186   TF_DeleteTensor(t);
187 }
188 
BuildOperation(Operation * operation,int64_t id,const string & name,const std::vector<std::variant<TensorProto,std::pair<int64_t,int32>>> & inputs,const std::unordered_map<string,AttrValue> & attrs,const string & device)189 void BuildOperation(
190     Operation* operation, int64_t id, const string& name,
191     const std::vector<std::variant<TensorProto, std::pair<int64_t, int32>>>&
192         inputs,
193     const std::unordered_map<string, AttrValue>& attrs, const string& device) {
194   operation->set_id(id);
195   operation->set_name(name);
196   operation->set_device(device);
197 
198   for (const auto& input : inputs) {
199     if (input.index() == 0) {
200       *operation->add_op_inputs()->mutable_tensor() =
201           std::get<TensorProto>(input);
202     } else {
203       const auto& tensor_handle_pair =
204           std::get<std::pair<int64_t, int32>>(input);
205       auto* input = operation->add_op_inputs()->mutable_remote_handle();
206       input->set_op_id(tensor_handle_pair.first);
207       input->set_output_num(tensor_handle_pair.second);
208       input->set_op_device(device);
209       input->set_device(device);
210     }
211   }
212 
213   for (const auto& attr_entry : attrs) {
214     (*operation->mutable_attrs())[attr_entry.first] = attr_entry.second;
215   }
216 }
217 
AddOperationToEnqueueRequest(int64_t id,const string & name,const std::vector<std::variant<TensorProto,std::pair<int64_t,int32>>> & inputs,const std::unordered_map<string,AttrValue> & attrs,const string & device,EnqueueRequest * request)218 void AddOperationToEnqueueRequest(
219     int64_t id, const string& name,
220     const std::vector<std::variant<TensorProto, std::pair<int64_t, int32>>>&
221         inputs,
222     const std::unordered_map<string, AttrValue>& attrs, const string& device,
223     EnqueueRequest* request) {
224   auto* operation = request->add_queue()->mutable_operation();
225   BuildOperation(operation, id, name, inputs, attrs, device);
226 }
227 
AddOperationToRunComponentFunctionRequest(int64_t id,const string & name,const std::vector<std::variant<TensorProto,std::pair<int64_t,int32>>> & inputs,const std::unordered_map<string,AttrValue> & attrs,const string & device,const int output_num,RunComponentFunctionRequest * request)228 void AddOperationToRunComponentFunctionRequest(
229     int64_t id, const string& name,
230     const std::vector<std::variant<TensorProto, std::pair<int64_t, int32>>>&
231         inputs,
232     const std::unordered_map<string, AttrValue>& attrs, const string& device,
233     const int output_num, RunComponentFunctionRequest* request) {
234   auto* operation = request->mutable_operation();
235   operation->set_is_function(true);
236   operation->set_is_component_function(true);
237   request->add_output_num(output_num);
238   BuildOperation(operation, id, name, inputs, attrs, device);
239 }
240 
MatMulFunctionNodeDef()241 tensorflow::NodeDef MatMulFunctionNodeDef() {
242   tensorflow::NodeDef def;
243   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
244       "    name: 'matmul_func'"
245       "    op: 'MatMulFunction'"
246       "    input: 'a'"
247       "    input: 'a'"
248       "    attr {"
249       "      key: 'T'"
250       "      value {"
251       "        type: DT_FLOAT"
252       "      }"
253       "    }",
254       &def));
255   return def;
256 }
257 
MatMulFunction()258 tensorflow::FunctionDef MatMulFunction() {
259   tensorflow::FunctionDef def;
260   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
261       "    signature {"
262       "      name: 'MatMulFunction'"
263       "      input_arg {"
264       "        name: 'a'"
265       "        type: DT_FLOAT"
266       "      }"
267       "      output_arg {"
268       "        name: 'm'"
269       "        type: DT_FLOAT"
270       "      }"
271       "    }"
272       "    node_def {"
273       "      name: 'matmul'"
274       "      op: 'MatMul'"
275       "      input: 'a'"
276       "      input: 'a'"
277       "      attr {"
278       "        key: 'T'"
279       "        value {"
280       "          type: DT_FLOAT"
281       "        }"
282       "      }"
283       "      attr {"
284       "        key: 'transpose_a'"
285       "        value {"
286       "          b: false"
287       "        }"
288       "      }"
289       "    }"
290       "    ret {"
291       "      key: 'm'"
292       "      value: 'matmul:product'"
293       "    }",
294       &def));
295   return def;
296 }
297 
MatMulNestedFunction()298 tensorflow::FunctionDef MatMulNestedFunction() {
299   tensorflow::FunctionDef def;
300   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
301       "    signature {"
302       "      name: 'MatMulNestedFunction'"
303       "      input_arg {"
304       "        name: 'a'"
305       "        type: DT_FLOAT"
306       "      }"
307       "      output_arg {"
308       "        name: 'matmul_nested'"
309       "        type: DT_FLOAT"
310       "      }"
311       "    }"
312       "    node_def {"
313       "      name: 'matmul_nested'"
314       "      op: 'MatMulFunction'"
315       "      input: 'a'"
316       "      attr {"
317       "        key: 'T'"
318       "        value {"
319       "          type: DT_FLOAT"
320       "        }"
321       "      }"
322       "    }"
323       "    ret {"
324       "      key: 'matmul_nested'"
325       "      value: 'matmul_nested:m:0'"
326       "    }",
327       &def));
328   return def;
329 }
330 
SingleRecvNodeFunction()331 tensorflow::FunctionDef SingleRecvNodeFunction() {
332   tensorflow::FunctionDef def;
333   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
334       "    signature {"
335       "      name: 'SingleRecvNodeFunction'"
336       "      input_arg {"
337       "        name: 'a'"
338       "        type: DT_FLOAT"
339       "      }"
340       "      output_arg {"
341       "        name: 'recv_tensor'"
342       "        type: DT_FLOAT"
343       "      }"
344       "    }"
345       "    node_def {"
346       "      name: 'recv_node'"
347       "      op: '_Recv'"
348       "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
349       "      attr {"
350       "        key: 'client_terminated'"
351       "        value {"
352       "          b: true"
353       "        }"
354       "      }"
355       "      attr {"
356       "        key: 'recv_device'"
357       "        value {"
358       "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
359       "        }"
360       "      }"
361       "      attr {"
362       "        key: 'send_device'"
363       "        value {"
364       "          s: '/job:localhost/replica:0/task:0/device:CPU:0'"
365       "        }"
366       "      }"
367       "      attr {"
368       "        key: 'send_device_incarnation'"
369       "        value {"
370       "          i: 1"
371       "        }"
372       "      }"
373       "      attr {"
374       "        key: 'tensor_name'"
375       "        value {"
376       "          s: 't0'"
377       "        }"
378       "      }"
379       "      attr {"
380       "        key: 'tensor_type'"
381       "        value {"
382       "          type: DT_FLOAT"
383       "        }"
384       "      }"
385       "    }"
386       "    ret {"
387       "      key: 'recv_tensor'"
388       "      value: 'recv_node:tensor:0'"
389       "    }",
390       &def));
391   return def;
392 }
393 
394 // Test creates a context and attempts to execute some ops.
TEST_F(EagerServiceImplTest,BasicTest)395 TEST_F(EagerServiceImplTest, BasicTest) {
396   TestEagerServiceImpl eager_service_impl(&worker_env_);
397 
398   uint64 context_id = random::New64();
399 
400   CreateContextRequest request;
401   request.mutable_server_def()->set_job_name("localhost");
402   request.mutable_server_def()->set_task_index(0);
403   request.set_context_id(context_id);
404   CreateContextResponse response;
405 
406   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
407 
408   EnqueueRequest remote_enqueue_request;
409   remote_enqueue_request.set_context_id(context_id);
410   EnqueueResponse remote_enqueue_response;
411 
412   std::unordered_map<string, AttrValue> const_attrs;
413   AttrValue val;
414   val.set_type(tensorflow::DataType::DT_FLOAT);
415   const_attrs.insert({"dtype", val});
416   val.Clear();
417   SetTensorProto(val.mutable_tensor());
418   const_attrs.insert({"value", val});
419 
420   AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
421                                "/job:localhost/replica:0/task:0/device:CPU:0",
422                                &remote_enqueue_request);
423 
424   std::unordered_map<string, AttrValue> attrs;
425   val.Clear();
426   val.set_type(tensorflow::DataType::DT_FLOAT);
427   attrs.insert({"T", val});
428   val.Clear();
429   val.set_b(false);
430   attrs.insert({"transpose_a", val});
431   attrs.insert({"transpose_b", val});
432 
433   AddOperationToEnqueueRequest(
434       2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
435       "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);
436 
437   TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
438                                           &remote_enqueue_response));
439 
440   auto& matmul_result_shape =
441       remote_enqueue_response.queue_response(1).shape(0);
442   EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
443   EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);
444 
445   tensorflow::TensorHandle* tensor_handle;
446   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
447       context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
448 
449   // This should be OK to do since we've placed all computation on the CPU
450   // device.
451   const tensorflow::Tensor* t = nullptr;
452   TF_ASSERT_OK(tensor_handle->Tensor(&t));
453 
454   auto actual = t->flat<float>();
455 
456   EXPECT_EQ(4, actual.size());
457 
458   EXPECT_EQ(7, actual(0));
459   EXPECT_EQ(10, actual(1));
460   EXPECT_EQ(15, actual(2));
461   EXPECT_EQ(22, actual(3));
462 
463   CloseContextRequest close_context_request;
464   close_context_request.set_context_id(context_id);
465   close_context_request.set_context_view_id(0);
466   CloseContextResponse close_context_response;
467   TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
468                                                &close_context_response));
469 }
470 
471 class EagerServiceImplFunctionTest : public EagerServiceImplTest {
472  public:
EagerServiceImplFunctionTest()473   EagerServiceImplFunctionTest() : EagerServiceImplTest() {}
474 
475   // Creates a context and attempts to execute a function.
TestFunction(const RegisterFunctionOp & register_op,const string & function_name,const bool local_inputs=false,const bool test_cancel=false)476   void TestFunction(const RegisterFunctionOp& register_op,
477                     const string& function_name,
478                     const bool local_inputs = false,
479                     const bool test_cancel = false) {
480     TestEagerServiceImpl eager_service_impl(&worker_env_);
481 
482     uint64 context_id = random::New64();
483 
484     CreateContextRequest request;
485     request.mutable_server_def()->set_job_name("localhost");
486     request.mutable_server_def()->set_task_index(0);
487     request.set_context_id(context_id);
488     CreateContextResponse response;
489 
490     TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
491 
492     EnqueueRequest enqueue_request;
493     enqueue_request.set_context_id(context_id);
494     *enqueue_request.add_queue()->mutable_register_function() = register_op;
495     EnqueueResponse enqueue_response;
496 
497     TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
498                                             &enqueue_response));
499 
500     EnqueueRequest remote_enqueue_request;
501     remote_enqueue_request.set_context_id(context_id);
502     EnqueueResponse remote_enqueue_response;
503 
504     if (local_inputs) {
505       TensorProto tensor_proto;
506       SetTensorProto(&tensor_proto);
507       AddOperationToEnqueueRequest(
508           2, function_name, {tensor_proto},
509           std::unordered_map<string, AttrValue>(),
510           "/job:localhost/replica:0/task:0/device:CPU:0",
511           &remote_enqueue_request);
512 
513     } else {
514       std::unordered_map<string, AttrValue> const_attrs;
515       AttrValue val;
516       val.set_type(tensorflow::DataType::DT_FLOAT);
517       const_attrs.insert({"dtype", val});
518       val.Clear();
519 
520       SetTensorProto(val.mutable_tensor());
521       const_attrs.insert({"value", val});
522 
523       AddOperationToEnqueueRequest(
524           1, "Const", {}, const_attrs,
525           "/job:localhost/replica:0/task:0/device:CPU:0",
526           &remote_enqueue_request);
527       AddOperationToEnqueueRequest(
528           2, function_name, {std::make_pair(1, 0)},
529           std::unordered_map<string, AttrValue>(),
530           "/job:localhost/replica:0/task:0/device:CPU:0",
531           &remote_enqueue_request);
532     }
533 
534     CallOptions call_opts;
535     Status status;
536     Notification n;
537     Env::Default()->SchedClosure([&] {
538       status = eager_service_impl.Enqueue(&call_opts, &remote_enqueue_request,
539                                           &remote_enqueue_response);
540       n.Notify();
541     });
542 
543     if (test_cancel) {
544       // Wait to let the Enqueue thread starts running
545       Env::Default()->SleepForMicroseconds(500000);
546       call_opts.StartCancel();
547       n.WaitForNotification();
548       EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
549     } else {
550       n.WaitForNotification();
551       TF_ASSERT_OK(status);
552       const tensorflow::Tensor* t = nullptr;
553       tensorflow::TensorHandle* tensor_handle;
554       TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
555           context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
556       TF_ASSERT_OK(tensor_handle->Tensor(&t));
557 
558       auto actual = t->flat<float>();
559       EXPECT_EQ(4, actual.size());
560 
561       EXPECT_EQ(7, actual(0));
562       EXPECT_EQ(10, actual(1));
563       EXPECT_EQ(15, actual(2));
564       EXPECT_EQ(22, actual(3));
565     }
566 
567     CloseContextRequest close_context_request;
568     close_context_request.set_context_id(context_id);
569     close_context_request.set_context_view_id(0);
570     CloseContextResponse close_context_response;
571     TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
572                                                  &close_context_response));
573   }
574 
575   // Creates a context and attempts to execute a component function.
TestComponentFunction(const RegisterFunctionOp & register_op,const string & function_name,const bool test_cancel)576   void TestComponentFunction(const RegisterFunctionOp& register_op,
577                              const string& function_name,
578                              const bool test_cancel) {
579     TestEagerServiceImpl eager_service_impl(&worker_env_);
580     uint64 context_id = random::New64();
581 
582     // Create context.
583     CreateContextRequest request;
584     request.mutable_server_def()->set_job_name("localhost");
585     request.mutable_server_def()->set_task_index(0);
586     request.set_context_id(context_id);
587     CreateContextResponse response;
588     TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
589 
590     // Register function.
591     EnqueueRequest enqueue_request;
592     enqueue_request.set_context_id(context_id);
593     *enqueue_request.add_queue()->mutable_register_function() = register_op;
594     EnqueueResponse enqueue_response;
595     TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &enqueue_request,
596                                             &enqueue_response));
597 
598     // First run an op to generate input for function.
599     EnqueueRequest remote_enqueue_request;
600     remote_enqueue_request.set_context_id(context_id);
601     EnqueueResponse remote_enqueue_response;
602 
603     std::unordered_map<string, AttrValue> const_attrs;
604     AttrValue val;
605     val.set_type(tensorflow::DataType::DT_FLOAT);
606     const_attrs.insert({"dtype", val});
607     val.Clear();
608     SetTensorProto(val.mutable_tensor());
609     const_attrs.insert({"value", val});
610     AddOperationToEnqueueRequest(1, "Const", {}, const_attrs,
611                                  "/job:localhost/replica:0/task:0/device:CPU:0",
612                                  &remote_enqueue_request);
613     TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
614                                             &remote_enqueue_response));
615 
616     // Run function with input from the previous op.
617     RunComponentFunctionRequest run_comp_func_request;
618     run_comp_func_request.set_context_id(context_id);
619     RunComponentFunctionResponse run_comp_func_response;
620     const int output_num = 5;
621     AddOperationToRunComponentFunctionRequest(
622         2, function_name, {std::make_pair(1, 0)},
623         std::unordered_map<string, AttrValue>(),
624         "/job:localhost/replica:0/task:0/device:CPU:0", output_num,
625         &run_comp_func_request);
626 
627     CallOptions call_opts;
628     Notification n;
629     Status status;
630     eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request,
631                                             &run_comp_func_response,
632                                             [&status, &n](const Status& s) {
633                                               status.Update(s);
634                                               n.Notify();
635                                             });
636     if (test_cancel) {
637       call_opts.StartCancel();
638     }
639     n.WaitForNotification();
640     if (test_cancel) {
641       EXPECT_TRUE(errors::IsCancelled(status)) << status.error_message();
642     } else {
643       TF_ASSERT_OK(status);
644       // Retrieve the output.
645       const tensorflow::Tensor* t = nullptr;
646       tensorflow::TensorHandle* tensor_handle;
647       TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
648           context_id, RemoteTensorHandleInternal(2, output_num),
649           &tensor_handle));
650       TF_ASSERT_OK(tensor_handle->Tensor(&t));
651 
652       auto actual = t->flat<float>();
653       EXPECT_EQ(4, actual.size());
654 
655       EXPECT_EQ(7, actual(0));
656       EXPECT_EQ(10, actual(1));
657       EXPECT_EQ(15, actual(2));
658       EXPECT_EQ(22, actual(3));
659     }
660 
661     CloseContextRequest close_context_request;
662     close_context_request.set_context_id(context_id);
663     close_context_request.set_context_view_id(0);
664     CloseContextResponse close_context_response;
665     TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
666                                                  &close_context_response));
667   }
668 };
669 
// Registers MatMulFunction and calls it on the output of a Const op.
TEST_F(EagerServiceImplFunctionTest, BasicFunctionTest) {
  RegisterFunctionOp op;
  op.mutable_function_def()->CopyFrom(MatMulFunction());
  TestFunction(op, "MatMulFunction");
}
675 
// Registers MatMulFunction and calls it with an inline tensor as input.
TEST_F(EagerServiceImplFunctionTest, FunctionWithLocalInputsTest) {
  RegisterFunctionOp op;
  op.mutable_function_def()->CopyFrom(MatMulFunction());
  constexpr bool kLocalInputs = true;
  TestFunction(op, "MatMulFunction", kLocalInputs);
}
681 
// Registers MatMulNestedFunction. MatMulFunction, which it calls, is shipped
// in the accompanying function library.
TEST_F(EagerServiceImplFunctionTest, NestedFunctionTest) {
  RegisterFunctionOp op;
  op.mutable_library()->add_function()->CopyFrom(MatMulFunction());
  op.mutable_function_def()->CopyFrom(MatMulNestedFunction());
  TestFunction(op, "MatMulNestedFunction");
}
688 
// Runs SingleRecvNodeFunction, whose _Recv node waits for tensor 't0' that
// this test never sends. The call is expected to end in CANCELLED via
// TestFunction's cancellation path.
TEST_F(EagerServiceImplFunctionTest, FunctionCancellationTest) {
  RegisterFunctionOp op;
  op.mutable_function_def()->CopyFrom(SingleRecvNodeFunction());
  constexpr bool kLocalInputs = false;
  constexpr bool kTestCancel = true;
  TestFunction(op, "SingleRecvNodeFunction", kLocalInputs, kTestCancel);
}
695 
// Runs MatMulFunction as a component function and checks its output.
TEST_F(EagerServiceImplFunctionTest, ComponentFunctionTest) {
  RegisterFunctionOp op;
  op.mutable_function_def()->CopyFrom(MatMulFunction());
  constexpr bool kTestCancel = false;
  TestComponentFunction(op, "MatMulFunction", kTestCancel);
}
701 
// Runs SingleRecvNodeFunction as a component function, cancels it, and
// expects a CANCELLED status.
TEST_F(EagerServiceImplFunctionTest, ComponentFunctionCancellationTest) {
  RegisterFunctionOp op;
  op.mutable_function_def()->CopyFrom(SingleRecvNodeFunction());
  constexpr bool kTestCancel = true;
  TestComponentFunction(op, "SingleRecvNodeFunction", kTestCancel);
}
707 
708 class FunctionWithRemoteInputsTest : public EagerServiceImplTest {
709  public:
FunctionWithRemoteInputsTest()710   FunctionWithRemoteInputsTest()
711       : EagerServiceImplTest(), eager_service_impl_(&worker_env_) {
712     remote_device_mgr_ = std::make_unique<StaticDeviceMgr>(
713         DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
714     context_id_ = random::New64();
715   }
716 
717   class TestExecuteNodeArgs : public EagerKernelArgs {
718    public:
TestExecuteNodeArgs(gtl::InlinedVector<TensorValue,4> && tensor_args,std::function<Status (const int,eager::RemoteTensorHandle *)> serialize_remote_handle)719     TestExecuteNodeArgs(
720         gtl::InlinedVector<TensorValue, 4>&& tensor_args,
721         std::function<Status(const int, eager::RemoteTensorHandle*)>
722             serialize_remote_handle)
723         : EagerKernelArgs(std::move(tensor_args)),
724           serialize_remote_handle_(std::move(serialize_remote_handle)) {}
725 
HasRemoteOrPackedInputs() const726     bool HasRemoteOrPackedInputs() const override { return true; }
727 
GetRemoteArg(const FunctionArgIndex & index,eager::RemoteTensorHandle * val) const728     Status GetRemoteArg(const FunctionArgIndex& index,
729                         eager::RemoteTensorHandle* val) const override {
730       return serialize_remote_handle_(index.index, val);
731     }
732 
733    private:
734     std::function<Status(const int, eager::RemoteTensorHandle*)>
735         serialize_remote_handle_;
736   };
737 
MatMulHasAttrWithDefaultValue(const tensorflow::FunctionDef & fdef)738   bool MatMulHasAttrWithDefaultValue(const tensorflow::FunctionDef& fdef) {
739     for (const auto& node : fdef.node_def()) {
740       if (node.op() == "MatMul") {
741         return node.attr().find("transpose_a") != node.attr().end();
742       }
743     }
744     return false;
745   }
746 
Init()747   void Init() {
748     CreateContextRequest request;
749     request.mutable_server_def()->set_job_name("localhost");
750     request.mutable_server_def()->set_task_index(0);
751     request.set_context_id(context_id_);
752     CreateContextResponse response;
753     TF_ASSERT_OK(eager_service_impl_.CreateContext(&request, &response));
754 
755     // Make the fake EagerClient use the local eager_service_impl.
756     EagerContext* ctx = nullptr;
757     TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
758     Device* device;
759     TF_ASSERT_OK(ctx->FindDeviceFromName(local_device_.c_str(), &device));
760     core::RefCountPtr<EagerClient> client;
761     TF_ASSERT_OK(ctx->GetClient(device, &client));
762     FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client.get());
763     fake_client->SetServiceImpl(&eager_service_impl_);
764 
765     // Create an input on local_device for MatMulFunction.
766     EnqueueRequest remote_enqueue_request;
767     remote_enqueue_request.set_context_id(context_id_);
768     EnqueueResponse remote_enqueue_response;
769     std::unordered_map<string, AttrValue> const_attrs;
770     AttrValue val;
771     val.set_type(tensorflow::DataType::DT_FLOAT);
772     const_attrs.insert({"dtype", val});
773     val.Clear();
774     SetTensorProto(val.mutable_tensor());
775     const_attrs.insert({"value", val});
776     AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, local_device_,
777                                  &remote_enqueue_request);
778     TF_EXPECT_OK(eager_service_impl_.Enqueue(nullptr, &remote_enqueue_request,
779                                              &remote_enqueue_response));
780     eager_cluster_flr_ = std::make_unique<EagerClusterFunctionLibraryRuntime>(
781         context_id_, ctx, device_mgr_.get());
782 
783     fdef_ = MatMulFunction();
784     TF_ASSERT_OK(func_lib_def_.AddFunctionDef(fdef_));
785     eager_pflr_ = std::make_unique<ProcessFunctionLibraryRuntime>(
786         remote_device_mgr_.get(), Env::Default(), /*config=*/
787         nullptr, TF_GRAPH_DEF_VERSION, &func_lib_def_, OptimizerOptions(),
788         /*thread_pool=*/nullptr, eager_cluster_flr_.get(),
789         /*session_metadata=*/nullptr,
790         Rendezvous::Factory{[this](const int64_t step_id,
791                                    const DeviceMgr* device_mgr,
792                                    Rendezvous** r) {
793           *r = worker_env_.rendezvous_mgr->Find(step_id);
794           return OkStatus();
795         }});
796   }
797 
CheckOutputTensorAndClose(const Tensor & tensor)798   void CheckOutputTensorAndClose(const Tensor& tensor) {
799     auto actual = tensor.flat<float>();
800     EXPECT_EQ(4, actual.size());
801     EXPECT_EQ(7, actual(0));
802     EXPECT_EQ(10, actual(1));
803     EXPECT_EQ(15, actual(2));
804     EXPECT_EQ(22, actual(3));
805 
806     CloseContextRequest close_context_request;
807     close_context_request.set_context_id(context_id_);
808     close_context_request.set_context_view_id(0);
809     CloseContextResponse close_context_response;
810     TF_ASSERT_OK(eager_service_impl_.CloseContext(&close_context_request,
811                                                   &close_context_response));
812   }
813 
CheckOutputsAndClose(const std::vector<FunctionRet> & outputs,const int64_t op_id)814   void CheckOutputsAndClose(const std::vector<FunctionRet>& outputs,
815                             const int64_t op_id) {
816     const tensorflow::Tensor* t = nullptr;
817     tensorflow::TensorHandle* tensor_handle;
818     TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
819         context_id_, RemoteTensorHandleInternal(2, 0), &tensor_handle));
820     TF_ASSERT_OK(tensor_handle->Tensor(&t));
821     EXPECT_EQ(outputs.size(), 1);
822     EXPECT_EQ(outputs.at(0).index(), 1);
823     const TensorShape& shape = std::get<TensorShape>(outputs.at(0));
824     EXPECT_EQ(shape, t->shape());
825     CheckOutputTensorAndClose(*t);
826   }
827 
828  protected:
829   const string local_device_ = "/job:localhost/replica:0/task:0/device:CPU:0";
830   const string remote_device_ = "/job:localhost/replica:0/task:1/device:CPU:0";
831   TestEagerServiceImpl eager_service_impl_;
832   std::unique_ptr<DeviceMgr> remote_device_mgr_;
833   uint64 context_id_;
834   tensorflow::FunctionDef fdef_;
835   std::unique_ptr<ProcessFunctionLibraryRuntime> eager_pflr_;
836   std::unique_ptr<EagerClusterFunctionLibraryRuntime> eager_cluster_flr_;
837   FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
838 };
839 
840 // Test executes a remote function through
841 // ProcessFunctionLibraryRuntime(EagerClusterFunctionLibraryRuntime).
TEST_F(FunctionWithRemoteInputsTest,EagerPFLRTest)842 TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) {
843   Init();
844   // Instantiate MatMulFunction on remote_device.
845   FunctionLibraryRuntime::InstantiateOptions options;
846   options.target = remote_device_;
847   options.is_multi_device_function = true;
848   options.input_devices.push_back(local_device_);
849   FunctionLibraryRuntime::Handle handle;
850   EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
851   TF_ASSERT_OK(eager_pflr_->Instantiate(
852       fdef_.signature().name(), AttrSlice(&fdef_.attr()), options, &handle));
853   EagerContext* ctx = nullptr;
854   TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
855   for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
856     const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
857     EXPECT_TRUE(fdef != nullptr);
858     if (absl::StartsWith(func_name, "MatMulFunction")) {
859       EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
860     }
861   }
862   bool is_cross_process = false;
863   TF_CHECK_OK(eager_pflr_->IsCrossProcess(handle, &is_cross_process));
864   EXPECT_TRUE(is_cross_process);
865 
866   // Run MatMulFunction on remote_device.
867   FunctionLibraryRuntime::Options opts;
868   const uint64 op_id = 2;
869   opts.op_id = op_id;
870   Notification done;
871   Status status;
872   RemoteTensorHandle input;
873   input.set_op_id(1);
874   input.set_output_num(0);
875   input.set_op_device(local_device_);
876   input.set_device(local_device_);
877   std::vector<RemoteTensorHandle> inputs = {input};
878   std::vector<FunctionRet> outputs;
879   gtl::InlinedVector<TensorValue, 4> tensor_args = {TensorValue()};
880   TestExecuteNodeArgs args(
881       std::move(tensor_args),
882       [&inputs](const int i, RemoteTensorHandle* handle) -> Status {
883         *handle = inputs.at(i);
884         return OkStatus();
885       });
886   eager_pflr_->Run(opts, handle, args, &outputs,
887                    [&status, &done](const Status& s) {
888                      status = s;
889                      done.Notify();
890                    });
891   done.WaitForNotification();
892   TF_ASSERT_OK(status);
893   CheckOutputsAndClose(outputs, op_id);
894 }
895 
896 // Test executes a remote function with local input and output tensors.
TEST_F(FunctionWithRemoteInputsTest,EagerClusterFLRTestWithLocalInputAndOutput)897 TEST_F(FunctionWithRemoteInputsTest,
898        EagerClusterFLRTestWithLocalInputAndOutput) {
899   Init();
900   // Instantiate MatMulFunction on remote_device.
901   FunctionLibraryRuntime::Handle handle;
902   EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_));
903   Status status;
904   Notification instantiate_done;
905   eager_cluster_flr_->Instantiate(
906       fdef_.signature().name(), func_lib_def_, AttrSlice(&fdef_.attr()),
907       FunctionLibraryRuntime::InstantiateOptions(), &handle,
908       [&status, &instantiate_done](const Status& s) {
909         status = s;
910         instantiate_done.Notify();
911       });
912   instantiate_done.WaitForNotification();
913   TF_ASSERT_OK(status);
914   EagerContext* ctx = nullptr;
915   TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
916   for (const string& func_name : ctx->FuncLibDef()->ListFunctionNames()) {
917     const FunctionDef* fdef = ctx->FuncLibDef()->Find(func_name);
918     EXPECT_TRUE(fdef != nullptr);
919     if (absl::StartsWith(func_name, "MatMulFunction")) {
920       EXPECT_FALSE(MatMulHasAttrWithDefaultValue(*fdef));
921     }
922   }
923   const tensorflow::Tensor* input_tensor = nullptr;
924   tensorflow::TensorHandle* tensor_handle;
925   TF_ASSERT_OK(eager_service_impl_.GetTensorHandle(
926       context_id_, RemoteTensorHandleInternal(1, 0), &tensor_handle));
927   TF_ASSERT_OK(tensor_handle->Tensor(&input_tensor));
928 
929   // Send input_tensor to the remote device, execute MatMulFunction on the
930   // remote device, and send the output back.
931   FunctionLibraryRuntime::Options opts;
932   Notification execute_done;
933   std::vector<Tensor> inputs = {*input_tensor};
934   std::vector<Tensor> outputs;
935   eager_cluster_flr_->Run(opts, handle, inputs, &outputs,
936                           [&status, &execute_done](const Status& s) {
937                             status = s;
938                             execute_done.Notify();
939                           });
940   execute_done.WaitForNotification();
941   TF_ASSERT_OK(status);
942   EXPECT_EQ(outputs.size(), 1);
943   CheckOutputTensorAndClose(outputs.at(0));
944 }
945 
946 // Test executes a remote function through KernelAndDeviceFunc::Run.
TEST_F(FunctionWithRemoteInputsTest,KernelAndDeviceFuncTest)947 TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) {
948   Init();
949   Device* local_device;
950   TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
951   std::vector<Device*> input_dev_ptrs;
952   input_dev_ptrs.push_back(local_device);
953   FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
954   EagerContext* ctx = nullptr;
955   TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
956   core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
957   const int64_t op_id = 2;
958   kernel.reset(new KernelAndDeviceFunc(
959       flr, eager_pflr_.get(), std::move(input_dev_ptrs),
960       /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
961       /*runner=*/nullptr,
962       /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
963       /*outputs_on_op_device=*/false,
964       /*allow_small_function_optimizations=*/false,
965       /*allow_control_flow_sync_execution=*/false,
966       /*shape_inference_on_tfe_dialect_import=*/true,
967       /*int_args_and_retvals_on_device=*/false,
968       /*xla_compile_device_type=*/std::nullopt, ctx->RendezvousCreator(),
969       [=]() { return op_id; }));
970 
971   // Instantiate MatMulFunction on remote_device.
972   const NodeDef node_def = MatMulFunctionNodeDef();
973   TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));
974 
975   // Run MatMulFunction on remote_device.
976   gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
977   RemoteTensorHandle input;
978   input.set_op_id(1);
979   input.set_output_num(0);
980   input.set_op_device(local_device_);
981   input.set_device(local_device_);
982   std::vector<RemoteTensorHandle> remote_handles = {input};
983   TestExecuteNodeArgs inputs(
984       std::move(input_tensors),
985       [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
986         *handle = remote_handles.at(index);
987         return OkStatus();
988       });
989   std::vector<FunctionRet> outputs;
990 
991   TF_ASSERT_OK(kernel->Run(/*step_container=*/nullptr, inputs, &outputs,
992                            /*cancellation_manager=*/nullptr,
993                            /*eager_func_params=*/std::nullopt,
994                            /*stack_trace=*/std::nullopt,
995                            /*coordination_service_agent=*/nullptr));
996 
997   CheckOutputsAndClose(outputs, op_id);
998 }
999 
1000 // Test executes a remote function through KernelAndDeviceFunc::RunAsync.
TEST_F(FunctionWithRemoteInputsTest,KernelAndDeviceFuncAsyncTest)1001 TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) {
1002   Init();
1003   Device* local_device;
1004   TF_ASSERT_OK(device_mgr_->LookupDevice(local_device_, &local_device));
1005   std::vector<Device*> input_dev_ptrs;
1006   input_dev_ptrs.push_back(local_device);
1007   FunctionLibraryRuntime* flr = eager_pflr_->GetFLR(remote_device_);
1008   EagerContext* ctx = nullptr;
1009   TF_ASSERT_OK(eager_service_impl_.GetEagerContext(context_id_, &ctx));
1010   core::RefCountPtr<KernelAndDeviceFunc> kernel = nullptr;
1011   const int64_t op_id = 2;
1012   kernel.reset(new KernelAndDeviceFunc(
1013       flr, eager_pflr_.get(), std::move(input_dev_ptrs),
1014       /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{},
1015       /*runner=*/nullptr,
1016       /*collective_executor=*/nullptr, local_device, fdef_.signature().name(),
1017       /*outputs_on_op_device=*/false,
1018       /*allow_small_function_optimizations=*/false,
1019       /*allow_control_flow_sync_execution=*/false,
1020       /*shape_inference_on_tfe_dialect_import=*/true,
1021       /*int_args_and_retvals_on_device=*/false,
1022       /*xla_compile_device_type=*/std::nullopt, ctx->RendezvousCreator(),
1023       [=]() { return op_id; }));
1024 
1025   // Instantiate MatMulFunction on remote_device.
1026   const NodeDef node_def = MatMulFunctionNodeDef();
1027   TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr));
1028 
1029   // Run MatMulFunction on remote_device.
1030   gtl::InlinedVector<TensorValue, 4> input_tensors = {TensorValue()};
1031   RemoteTensorHandle input;
1032   input.set_op_id(1);
1033   input.set_output_num(0);
1034   input.set_op_device(local_device_);
1035   input.set_device(local_device_);
1036   std::vector<RemoteTensorHandle> remote_handles = {input};
1037   TestExecuteNodeArgs inputs(
1038       std::move(input_tensors),
1039       [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status {
1040         *handle = remote_handles.at(index);
1041         return OkStatus();
1042       });
1043   std::vector<FunctionRet> outputs;
1044 
1045   Status status;
1046   Notification n;
1047   kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs,
1048                    /*cancellation_manager=*/nullptr,
1049                    /*eager_func_params=*/std::nullopt,
1050                    /*coordination_service_agent=*/nullptr,
1051                    [&status, &n](const Status& s) {
1052                      status = s;
1053                      n.Notify();
1054                    });
1055   n.WaitForNotification();
1056   TF_ASSERT_OK(status);
1057   CheckOutputsAndClose(outputs, op_id);
1058 }
1059 
1060 // Test creates a context and attempts to send a tensor (using the RPC), and
1061 // then use the tensor.
TEST_F(EagerServiceImplTest,SendTensorTest)1062 TEST_F(EagerServiceImplTest, SendTensorTest) {
1063   TestEagerServiceImpl eager_service_impl(&worker_env_);
1064 
1065   uint64 context_id = random::New64();
1066 
1067   CreateContextRequest request;
1068   request.mutable_server_def()->set_job_name("localhost");
1069   request.mutable_server_def()->set_task_index(0);
1070   request.set_context_id(context_id);
1071   CreateContextResponse response;
1072 
1073   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
1074 
1075   EnqueueRequest remote_enqueue_request;
1076   remote_enqueue_request.set_context_id(context_id);
1077   EnqueueResponse remote_enqueue_response;
1078 
1079   auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
1080   send_tensor->set_op_id(1);
1081   SetTensorProto(send_tensor->add_tensors());
1082 
1083   std::unordered_map<string, AttrValue> attrs;
1084   AttrValue val;
1085   val.Clear();
1086   val.set_type(tensorflow::DataType::DT_FLOAT);
1087   attrs.insert({"T", val});
1088   val.Clear();
1089   val.set_b(false);
1090   attrs.insert({"transpose_a", val});
1091   attrs.insert({"transpose_b", val});
1092 
1093   AddOperationToEnqueueRequest(
1094       2, "MatMul", {std::make_pair(1, 0), std::make_pair(1, 0)}, attrs,
1095       "/job:localhost/replica:0/task:0/device:CPU:0", &remote_enqueue_request);
1096 
1097   TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
1098                                           &remote_enqueue_response));
1099 
1100   const tensorflow::Tensor* t = nullptr;
1101   tensorflow::TensorHandle* tensor_handle;
1102   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
1103       context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
1104   TF_ASSERT_OK(tensor_handle->Tensor(&t));
1105 
1106   EXPECT_EQ(tensor_handle->device(), nullptr);
1107 
1108   auto actual = t->flat<float>();
1109   EXPECT_EQ(4, actual.size());
1110 
1111   EXPECT_EQ(7, actual(0));
1112   EXPECT_EQ(10, actual(1));
1113   EXPECT_EQ(15, actual(2));
1114   EXPECT_EQ(22, actual(3));
1115 
1116   CloseContextRequest close_context_request;
1117   close_context_request.set_context_id(context_id);
1118   close_context_request.set_context_view_id(0);
1119   CloseContextResponse close_context_response;
1120   TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
1121                                                &close_context_response));
1122 }
1123 
1124 // Test serializes and sends a pack TensorHandle.
TEST_F(EagerServiceImplTest,SendPackedHandleTest)1125 TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
1126   TestEagerServiceImpl eager_service_impl(&worker_env_);
1127 
1128   const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
1129   const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
1130   const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
1131   const string composite_device =
1132       "/job:localhost/replica:0/task:0/device:COMPOSITE:0";
1133 
1134   uint64 context_id = random::New64();
1135   CreateContextRequest request;
1136   auto* server_def = request.mutable_server_def();
1137   server_def->set_job_name("localhost");
1138   server_def->set_task_index(0);
1139   request.add_cluster_device_attributes()->set_name(device0);
1140   request.add_cluster_device_attributes()->set_name(device1);
1141   request.add_cluster_device_attributes()->set_name(device2);
1142   request.set_context_id(context_id);
1143   CreateContextResponse response;
1144 
1145   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
1146 
1147   EnqueueRequest remote_enqueue_request;
1148   remote_enqueue_request.set_context_id(context_id);
1149   EnqueueResponse remote_enqueue_response;
1150 
1151   // Copy a tensor to device0
1152   auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
1153   send_tensor->set_op_id(1);
1154   SetTensorProto(send_tensor->add_tensors());
1155 
1156   // Copy a packed handle to device0
1157   auto* send_packed_handle =
1158       remote_enqueue_request.add_queue()->mutable_send_packed_handle();
1159   send_packed_handle->set_op_id(3);
1160   RemoteTensorHandle* remote_handle =
1161       send_packed_handle->add_handles()->mutable_remote_handle();
1162   remote_handle->set_op_id(send_tensor->op_id());
1163   remote_handle->set_output_num(0);
1164   remote_handle->set_op_device(device0);
1165   remote_handle->set_device(device0);
1166 
1167   SendPackedHandleOp::LocalTensorHandle* lcoal_handle =
1168       send_packed_handle->add_handles()->mutable_local_handle();
1169   SetTensorProto(lcoal_handle->mutable_tensor());
1170   lcoal_handle->set_device(device1);
1171 
1172   remote_handle = send_packed_handle->add_handles()->mutable_remote_handle();
1173   remote_handle->set_op_id(2);
1174   remote_handle->set_output_num(5);
1175   remote_handle->set_op_device(device2);
1176   remote_handle->set_device(device2);
1177 
1178   TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
1179                                           &remote_enqueue_response));
1180 
1181   tensorflow::TensorHandle* packed_handle;
1182   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
1183       context_id, RemoteTensorHandleInternal(3, 0), &packed_handle));
1184 
1185   EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
1186   EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
1187   EXPECT_EQ(packed_handle->device()->name(), composite_device);
1188 
1189   TensorHandle* handle0 = nullptr;
1190   TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
1191   EXPECT_EQ(handle0->Type(), TensorHandle::LOCAL);
1192   EXPECT_EQ(handle0->op_device()->name(), device0);
1193   const Tensor* t0 = nullptr;
1194   TF_ASSERT_OK(handle0->Tensor(&t0));
1195   auto actual = t0->flat<float>();
1196   EXPECT_EQ(4, actual.size());
1197   EXPECT_EQ(1.0, actual(0));
1198   EXPECT_EQ(2.0, actual(1));
1199   EXPECT_EQ(3.0, actual(2));
1200   EXPECT_EQ(4.0, actual(3));
1201 
1202   TensorHandle* handle1 = nullptr;
1203   TF_ASSERT_OK(packed_handle->ExtractPackedHandle(1, &handle1));
1204   EXPECT_EQ(handle1->Type(), TensorHandle::LOCAL);
1205   EXPECT_EQ(handle1->op_device()->name(), device1);
1206   const Tensor* t1 = nullptr;
1207   TF_ASSERT_OK(handle0->Tensor(&t1));
1208   EXPECT_EQ(t1, t0);
1209 
1210   TensorHandle* handle2 = nullptr;
1211   TF_ASSERT_OK(packed_handle->ExtractPackedHandle(2, &handle2));
1212   EXPECT_EQ(handle2->Type(), TensorHandle::REMOTE);
1213   EXPECT_EQ(handle2->op_device()->name(), device2);
1214   int64_t op_id;
1215   int32_t output_num;
1216   TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(),
1217                                       /*wait_until_ready=*/true, &op_id,
1218                                       &output_num));
1219   EXPECT_EQ(op_id, 2);
1220   EXPECT_EQ(output_num, 5);
1221 
1222   CloseContextRequest close_context_request;
1223   close_context_request.set_context_id(context_id);
1224   close_context_request.set_context_view_id(0);
1225   CloseContextResponse close_context_response;
1226   TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
1227                                                &close_context_response));
1228 }
1229 
1230 // Test requests sent to the eager service on master.
TEST_F(EagerServiceImplTest,RequestsToMasterTest)1231 TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
1232   tensorflow::Rendezvous* rendezvous =
1233       new tensorflow::IntraProcessRendezvous(device_mgr_.get());
1234   // Create a master eager context.
1235   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
1236       SessionOptions(),
1237       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
1238       /*async=*/false, device_mgr_.get(), false, rendezvous);
1239   const uint64 context_id = random::New64();
1240 
1241   // Set RemoteMgr to ctx.
1242   auto remote_mgr =
1243       std::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/true, ctx);
1244   TF_ASSERT_OK(ctx->InitializeRemoteWorker(
1245       /*remote_eager_workers=*/nullptr, /*remote_device_mgr=*/nullptr,
1246       /*remote_contexts=*/{}, context_id, /*context_view_id=*/0,
1247       /*rendezvous_creator=*/nullptr,
1248       /*cluster_flr=*/nullptr, std::move(remote_mgr),
1249       /*resource_deallocator=*/nullptr));
1250 
1251   TestEagerServiceImpl eager_service_impl(&worker_env_);
1252 
1253   EnqueueRequest remote_enqueue_request;
1254   remote_enqueue_request.set_context_id(context_id);
1255   EnqueueResponse remote_enqueue_response;
1256 
1257   auto* send_tensor = remote_enqueue_request.add_queue()->mutable_send_tensor();
1258   send_tensor->set_op_id(1);
1259   SetTensorProto(send_tensor->add_tensors());
1260 
1261   // Unable to handle the request since there is no eager context.
1262   Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
1263                                              &remote_enqueue_response);
1264   EXPECT_EQ(error::ABORTED, status.code());
1265   EXPECT_TRUE(absl::StrContains(
1266       status.error_message(),
1267       "Unable to find a context_id matching the specified one"));
1268 
1269   // The request can be handled after adding the master eager context to
1270   // service.
1271   TF_ASSERT_OK(eager_service_impl.CreateMasterContext(context_id, ctx));
1272   TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
1273                                           &remote_enqueue_response));
1274   ctx->Unref();
1275 }
1276 
// A context created with keep_alive_secs(3) and left idle for 5 seconds is
// expected to be gone (KeepAlive -> ABORTED). A second context from the same
// request, checked after 1 second, must still answer KeepAlive successfully.
TEST_F(EagerServiceImplTest, KeepAliveTest) {
  TestEagerServiceImpl eager_service_impl(&worker_env_);
  constexpr auto kMicrosPerSecond = tensorflow::EnvTime::kSecondsToMicros;

  const uint64 stale_context_id = random::New64();
  CreateContextRequest create_request;
  CreateContextResponse create_response;
  create_request.mutable_server_def()->set_job_name("localhost");
  create_request.mutable_server_def()->set_task_index(0);
  create_request.set_context_id(stale_context_id);
  create_request.set_keep_alive_secs(3);

  TF_ASSERT_OK(
      eager_service_impl.CreateContext(&create_request, &create_response));

  // Idle past the keep-alive window.
  worker_env_.env->SleepForMicroseconds(5 * kMicrosPerSecond);

  KeepAliveRequest keep_alive_request;
  KeepAliveResponse keep_alive_response;
  keep_alive_request.set_context_id(stale_context_id);

  const Status stale_status =
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
  EXPECT_EQ(stale_status.code(), error::ABORTED);
  EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                      stale_status.error_message());

  // Reuse the request (still keep_alive_secs(3)) for a fresh context id.
  const uint64 fresh_context_id = random::New64();
  create_request.set_context_id(fresh_context_id);
  TF_ASSERT_OK(
      eager_service_impl.CreateContext(&create_request, &create_response));

  // Well inside the keep-alive window: the context must still exist.
  worker_env_.env->SleepForMicroseconds(1 * kMicrosPerSecond);

  keep_alive_request.set_context_id(fresh_context_id);
  TF_ASSERT_OK(
      eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));
}
1319 
1320 }  // namespace
1321 }  // namespace eager
1322 }  // namespace tensorflow
1323