/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"

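// Builds a single-job ServerDef for `job_name` with `num_tasks` tasks, all
// bound to localhost on freshly picked unused ports. The returned def has
// task_index 0; callers adjust the index per worker before starting servers.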
tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name(job_name);
  server_def.set_task_index(0);
  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
  tensorflow::JobDef* job_def = cluster_def->add_job();
  job_def->set_name(job_name);
  for (int i = 0; i < num_tasks; i++) {
    int port = tensorflow::testing::PickUnusedPortOrDie();
    job_def->mutable_tasks()->insert(
        {i, tensorflow::strings::StrCat("localhost", ":", port)});
  }
  return server_def;
}

namespace tensorflow {
namespace parallel_device {

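// Spins up two in-process gRPC workers, points the eager context at the
// cluster, and runs the shared two-device test suite with the remote
// task:1 and task:2 CPUs as the parallel device's component devices.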
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  tensorflow::ServerDef server_def = GetServerDef("worker", 3);

  // This server def has the task index set to 0.
  std::string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
                          serialized.size(), status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  BasicTestsForTwoDevices(context.get(),
                          "/job:worker/replica:0/task:1/device:CPU:0",
                          "/job:worker/replica:0/task:2/device:CPU:0");

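  // Note: the worker servers are released (leaked) rather than destroyed,
  // presumably because GrpcServer has no clean shutdown path while the eager
  // context may still reference the remote workers.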
  worker_server1.release();
  worker_server2.release();
}

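// Same cluster setup as TestRemoteBasic, but exercises the parallel device
// directly: two host scalars are packed into a per-device value on the remote
// CPUs and repeatedly squared, checking both components each iteration.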
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  tensorflow::ServerDef server_def = GetServerDef("worker", 3);

  // This server def has the task index set to 0.
  std::string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
                          serialized.size(), status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

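  // Register a custom parallel device backed by the two remote CPU devices.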
  const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
  const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{first_device, second_device};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

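  // Pack two host scalars (3. and -2.) into a single per-device value on the
  // parallel device; each component is placed on one of the underlying
  // remote CPUs.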
  TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
                                                 value_two.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), in_components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Loop to make synchronization failures more deterministic
  for (int i = 0; i < 100; ++i) {
    TensorHandlePtr multiply_result(
        Multiply(context.get(), combined_value.get(), combined_value.get(),
                 status.get()));
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TensorHandlePtr, 2> out_components;
    ExtractPerDeviceValues(context.get(), multiply_result.get(),
                           &out_components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    ExpectScalarEq<float>(out_components[0].get(), 9.);
    ExpectScalarEq<float>(out_components[1].get(), 4.);
  }

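  // As in TestRemoteBasic, the worker servers are intentionally leaked.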
  worker_server1.release();
  worker_server2.release();
}
}  // namespace parallel_device
}  // namespace tensorflow