/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <array>
17 #include <string>
18
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/c/c_api_experimental.h"
21 #include "tensorflow/c/eager/c_api.h"
22 #include "tensorflow/c/eager/c_api_experimental.h"
23 #include "tensorflow/c/eager/parallel_device/parallel_device.h"
24 #include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
26 #include "tensorflow/core/platform/test.h"
27
GetServerDef(const std::string & job_name,int num_tasks)28 tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
29 tensorflow::ServerDef server_def;
30 server_def.set_protocol("grpc");
31 server_def.set_job_name(job_name);
32 server_def.set_task_index(0);
33 tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
34 tensorflow::JobDef* job_def = cluster_def->add_job();
35 job_def->set_name(job_name);
36 for (int i = 0; i < num_tasks; i++) {
37 int port = tensorflow::testing::PickUnusedPortOrDie();
38 job_def->mutable_tasks()->insert(
39 {i, tensorflow::strings::StrCat("localhost", ":", port)});
40 }
41 return server_def;
42 }
43
44 namespace tensorflow {
45 namespace parallel_device {
46
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
  // RAII wrappers for the eager C API objects this test owns.
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);

  // Three-task "worker" job. The serialized copy keeps task_index 0, which is
  // the slot the client context itself occupies.
  tensorflow::ServerDef server_def = GetServerDef("worker", 3);
  const std::string serialized = server_def.SerializeAsString();

  // Bring up in-process gRPC servers for tasks 1 and 2.
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  // Attach the eager context to the cluster as task 0.
  TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
                          serialized.size(), status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Run the shared parallel-device suite against one CPU on each remote task.
  BasicTestsForTwoDevices(context.get(),
                          "/job:worker/replica:0/task:1/device:CPU:0",
                          "/job:worker/replica:0/task:2/device:CPU:0");

  // NOTE(review): the servers are deliberately leaked — presumably because
  // GrpcServer has no clean shutdown while clients may still be attached.
  worker_server1.release();
  worker_server2.release();
}
84
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
  // RAII wrappers for the eager C API objects this test owns.
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);

  // Same three-task cluster bring-up as TestRemoteBasic: the serialized def
  // keeps task_index 0 for the client context.
  tensorflow::ServerDef server_def = GetServerDef("worker", 3);
  const std::string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
                          serialized.size(), status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Register a parallel device backed by one CPU on each remote task.
  const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
  const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{first_device, second_device};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Pack the scalars 3 and -2 into a single per-device value.
  TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
                                                 value_two.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), in_components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Square the packed value repeatedly; looping makes any synchronization
  // failures more deterministic.
  for (int iteration = 0; iteration < 100; ++iteration) {
    TensorHandlePtr squared(
        Multiply(context.get(), combined_value.get(), combined_value.get(),
                 status.get()));
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TensorHandlePtr, 2> out_components;
    ExtractPerDeviceValues(context.get(), squared.get(), &out_components,
                           status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // 3 * 3 and (-2) * (-2), one component per underlying remote device.
    ExpectScalarEq<float>(out_components[0].get(), 9.);
    ExpectScalarEq<float>(out_components[1].get(), 4.);
  }

  // NOTE(review): the servers are deliberately leaked — presumably because
  // GrpcServer has no clean shutdown while clients may still be attached.
  worker_server1.release();
  worker_server2.release();
}
151 } // namespace parallel_device
152 } // namespace tensorflow
153