/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api_remote_test_util.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using ::tensorflow::string;

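// Returns a serialized FunctionDef for "MatMulFunction", which computes
// m = MatMul(a, b) for two DT_FLOAT inputs. If `matmul_device` is non-empty,
// the MatMul node is pinned to that device; otherwise placement is left to
// the runtime.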
string MatMulFunction(const string& matmul_device) {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      absl::StrCat("    signature {"
                   "      name: 'MatMulFunction'"
                   "      input_arg {"
                   "        name: 'a'"
                   "        type: DT_FLOAT"
                   "      }"
                   "      input_arg {"
                   "        name: 'b'"
                   "        type: DT_FLOAT"
                   "      }"
                   "      output_arg {"
                   "        name: 'm'"
                   "        type: DT_FLOAT"
                   "      }"
                   "    }"
                   "    node_def {"
                   "      name: 'matmul'"
                   "      op: 'MatMul'"
                   "      input: 'a'"
                   "      input: 'b'"
                   "      device: '",
                   matmul_device, "'",
                   "      attr {"
                   "        key: 'T'"
                   "        value {"
                   "          type: DT_FLOAT"
                   "        }"
                   "      }"
                   "    }"
                   "    ret {"
                   "      key: 'm'"
                   "      value: 'matmul:product'"
                   "    }"),
      &def));
  return def.SerializeAsString();
}

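// Runs a MatMul (as a raw op or as the "MatMulFunction" function) in a
// three-task cluster with inputs spread across tasks, relying on the SILENT
// device placement policy to copy them implicitly.
//  - `async`: run the eager executor in async mode.
//  - `remote`: place the op on task 1 instead of the local task.
//  - `func`: execute via MatMulFunction rather than a raw MatMul op.
//  - `heavy_load_on_streaming_rpc`: enqueue extra tensor copies first so
//    some RPC requests are already in flight.
//  - `remote_func_outputs`: pin the function's MatMul node (and thus its
//    output) to task 2.
//  - `has_packed_input`: wrap one input in a packed TensorHandle (requires
//    `func`).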
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
                                   bool heavy_load_on_streaming_rpc,
                                   bool remote_func_outputs,
                                   bool has_packed_input) {
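  // Packed inputs are only exercised through the function path in this test.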
  CHECK(!has_packed_input || func);
  tensorflow::ServerDef server_def = GetServerDef(3);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

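  // Start tasks 1 and 2 as in-process gRPC workers; the current process
  // participates as task 0 (the client).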
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

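  // Connect the local context to the cluster as task 0 (keep_alive_secs = 0).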
  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

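  // Both inputs are 2x2 float matrices created on the local task;
  // TestMatrixTensorHandle produces [[1, 2], [3, 4]].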
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
  std::vector<TFE_TensorHandle*> handles_task0;
  if (heavy_load_on_streaming_rpc) {
    // Send 50 tensor copy requests to simulate that some RPC requests have
    // already been enqueued.
    for (int i = 0; i < 50; ++i) {
      handles_task0.push_back(TestMatrixTensorHandle(ctx));
    }
  }
  const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
  const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  std::vector<TFE_TensorHandle*> handles_task2;
  for (auto* h_task0 : handles_task0) {
    handles_task2.push_back(
        TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  }

  auto* h1_task2 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

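  // A packed handle bundles per-replica handles into a single input; here it
  // wraps just the one replica living on task 2.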
  TFE_TensorHandle* packed_handle = nullptr;
  if (has_packed_input) {
    int num_replicas = 1;
    std::vector<TFE_TensorHandle*> packed_handles = {h1_task2};
    packed_handle = TFE_CreatePackedTensorHandle(ctx, packed_handles.data(),
                                                 &num_replicas, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  }

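  // Build the op to run: either a call to the registered MatMulFunction or a
  // raw MatMul op, depending on `func`.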
  TFE_Op* matmul = nullptr;
  if (func) {
    const string matmul_device = remote_func_outputs ? task2_name : "";
    string function_def = MatMulFunction(matmul_device);
    TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                              status);
    CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    matmul = TFE_NewOp(ctx, "MatMulFunction", status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
    TFE_OpAddInput(matmul, h0_task0, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
    TFE_OpAddInput(matmul, has_packed_input ? packed_handle : h1_task2, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  } else {
    // Handles are on task0 (local) and task2, but the op runs on task1.
    matmul = MatMulOp(ctx, h0_task0, h1_task2);
  }
  if (remote) {
    TFE_OpSetDevice(matmul, task1_name, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  } else if (!async) {
    // Set the local device to CPU to easily validate mirroring.
    string cpu_device_name;
    ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
    TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
    auto remote_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
    // Before execution, the remote input should not yet have a local mirror.
    ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
  }

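  // Execute the matmul. In async mode this returns immediately; any error
  // surfaces when the executor is drained below.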
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // TODO(gjn): Add support for waiting on async local mirrors
  if (!remote && !async && !remote_func_outputs) {
    auto remote_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
    // The input handle itself is unchanged; execution added a local mirror.
    ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
  }

  if (remote_func_outputs) {
    const string backing_device =
        TFE_TensorHandleBackingDeviceName(retvals[0], status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
    EXPECT_EQ(backing_device, task2_name);
  }

  auto* retval_task0 = TFE_TensorHandleCopyToDevice(
      retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteTensorHandle(retval_task0);
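  // [[1, 2], [3, 4]] x [[1, 2], [3, 4]] = [[7, 10], [15, 22]].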
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);

  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  if (packed_handle) {
    TFE_DeleteTensorHandle(packed_handle);
  }
  TFE_DeleteTensorHandle(h1_task2);
  TFE_DeleteTensorHandle(retvals[0]);
  for (auto* h : handles_task0) {
    TFE_DeleteTensorHandle(h);
  }
  for (auto* h : handles_task2) {
    TFE_DeleteTensorHandle(h);
  }

  TFE_DeleteOp(matmul);

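  // Drain the per-thread executor so all pending (possibly async) nodes
  // complete before tearing down the context.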
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  if (func) {
    TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
  }
  TFE_DeleteContext(ctx);

  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}