1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17 #include "tensorflow/c/eager/c_api_experimental.h"
18 #include "tensorflow/c/eager/c_api_internal.h"
19 #include "tensorflow/c/eager/c_api_test_util.h"
20 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
21 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
23 #include "tensorflow/core/platform/casts.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28
29 namespace {
30
31 using ::tensorflow::string;
32
ReplaceTaskInServerDef(tensorflow::ServerDef * server_def,int task_index)33 void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
34 tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
35 int port = tensorflow::testing::PickUnusedPortOrDie();
36 job_def->mutable_tasks()->at(task_index) =
37 tensorflow::strings::StrCat("localhost:", port);
38 }
39
CheckTFE_TensorHandleHasFloats(TFE_TensorHandle * handle,const std::vector<float> & expected_values)40 void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
41 const std::vector<float>& expected_values) {
42 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
43 TF_NewStatus(), TF_DeleteStatus);
44 TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
45 ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
46 std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
47 EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
48 memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
49 TF_DeleteTensor(t);
50
51 for (int i = 0; i < expected_values.size(); i++) {
52 EXPECT_EQ(expected_values[i], actual_values[i])
53 << "Mismatch in expected values at (zero-based) index " << i;
54 }
55 }
56
CheckRemoteMatMulExecutesOK(TFE_Context * ctx,const char * remote_device_name,const char * local_device_name)57 void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
58 const char* remote_device_name,
59 const char* local_device_name) {
60 TF_Status* status = TF_NewStatus();
61 TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
62
63 TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
64 TFE_OpSetDevice(matmul, remote_device_name, status);
65 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
66
67 TFE_TensorHandle* retvals[1];
68 int num_retvals = 1;
69 TFE_Execute(matmul, &retvals[0], &num_retvals, status);
70 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
71
72 auto* retval_task0 =
73 TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
74 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
75
76 CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
77
78 TFE_DeleteTensorHandle(retval_task0);
79 TFE_DeleteTensorHandle(h0_task0);
80 TFE_DeleteTensorHandle(retvals[0]);
81
82 TFE_DeleteOp(matmul);
83
84 TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
85 TFE_ExecutorWaitForAllPendingNodes(executor, status);
86 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
87 TFE_DeleteExecutor(executor);
88 TF_DeleteStatus(status);
89 }
90
91 // Read the value of variable `var` and save it into `out_value`.
ReadVariable(TFE_Context * ctx,TFE_TensorHandle * var,TFE_TensorHandle ** out_value)92 void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
93 TFE_TensorHandle** out_value) {
94 TF_Status* status = TF_NewStatus();
95 TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
96 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
97 TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
98 TFE_OpAddInput(op, var, status);
99 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
100 int num_retvals = 1;
101 TFE_Execute(op, out_value, &num_retvals, status);
102 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
103 TFE_DeleteOp(op);
104 TF_DeleteStatus(status);
105 }
106
TestRemoteExecuteChangeServerDef(bool async)107 void TestRemoteExecuteChangeServerDef(bool async) {
108 tensorflow::ServerDef server_def = GetServerDef(2);
109
110 // This server def has the task index set to 0.
111 string serialized = server_def.SerializeAsString();
112
113 server_def.set_task_index(1);
114
115 std::unique_ptr<tensorflow::GrpcServer> worker_server;
116 ASSERT_TRUE(tensorflow::GrpcServer::Create(
117 server_def, tensorflow::Env::Default(), &worker_server)
118 .ok());
119 ASSERT_TRUE(worker_server->Start().ok());
120
121 TF_Status* status = TF_NewStatus();
122 TFE_ContextOptions* opts = TFE_NewContextOptions();
123 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
124 TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
125 TFE_Context* ctx = TFE_NewContext(opts, status);
126 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
127 TFE_DeleteContextOptions(opts);
128
129 TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
130 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
131
132 const char remote_device_name[] =
133 "/job:localhost/replica:0/task:1/device:CPU:0";
134 const char local_device_name[] =
135 "/job:localhost/replica:0/task:0/device:CPU:0";
136 CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
137
138 TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
139 TFE_ExecutorWaitForAllPendingNodes(executor, status);
140 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
141
142 // TODO(b/136478427): Figure out how to correctly shut the server down.
143 worker_server.release();
144
145 // Update the server def with a new set of names (worker instead of
146 // localhost).
147 tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
148 serialized = updated_server_def.SerializeAsString();
149
150 updated_server_def.set_task_index(1);
151 tensorflow::Status s = tensorflow::GrpcServer::Create(
152 updated_server_def, tensorflow::Env::Default(), &worker_server);
153 ASSERT_TRUE(s.ok()) << s.error_message();
154 ASSERT_TRUE(worker_server->Start().ok());
155
156 TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
157 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
158
159 // Create a new tensor_handle.
160 TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
161
162 // Check that copying it to the old remote device (named localhost) fails.
163 TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
164 EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
165
166 // Copying and executing on the new remote device works.
167 const char new_remote_device_name[] =
168 "/job:worker/replica:0/task:1/device:CPU:0";
169 const char new_local_device_name[] =
170 "/job:worker/replica:0/task:0/device:CPU:0";
171
172 auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
173 h0_task0_new, ctx, new_remote_device_name, status);
174 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
175
176 TFE_DeleteTensorHandle(h0_task0_new);
177 TFE_DeleteTensorHandle(h0_task1_new);
178
179 CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
180 new_local_device_name);
181
182 TFE_ExecutorWaitForAllPendingNodes(executor, status);
183 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
184 TFE_DeleteExecutor(executor);
185
186 TF_DeleteStatus(status);
187
188 TFE_DeleteContext(ctx);
189
190 // TODO(b/136478427): Figure out how to correctly shut the server down.
191 worker_server.release();
192 }
193
// Server-def replacement scenario with `async` disabled.
TEST(CAPI, RemoteExecuteChangeServerDef) {
  TestRemoteExecuteChangeServerDef(/*async=*/false);
}

// Server-def replacement scenario with `async` enabled.
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
  TestRemoteExecuteChangeServerDef(/*async=*/true);
}
200
TestRemoteExecuteUpdateServerDef(bool async)201 void TestRemoteExecuteUpdateServerDef(bool async) {
202 tensorflow::ServerDef server_def = GetServerDef(2);
203 // This server def has the task index set to 0.
204 string serialized = server_def.SerializeAsString();
205
206 server_def.set_task_index(1);
207 std::unique_ptr<tensorflow::GrpcServer> worker_server;
208 ASSERT_TRUE(tensorflow::GrpcServer::Create(
209 server_def, tensorflow::Env::Default(), &worker_server)
210 .ok());
211 ASSERT_TRUE(worker_server->Start().ok());
212
213 TF_Status* status = TF_NewStatus();
214 TFE_ContextOptions* opts = TFE_NewContextOptions();
215 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
216 TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
217 TFE_Context* ctx = TFE_NewContext(opts, status);
218 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
219 TFE_DeleteContextOptions(opts);
220
221 TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
222 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
223 const char local_device_name[] =
224 "/job:localhost/replica:0/task:0/device:CPU:0";
225 const char remote_device_name[] =
226 "/job:localhost/replica:0/task:1/device:CPU:0";
227 CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
228
229 TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
230 status);
231 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
232 CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
233
234 TFE_DeleteContext(ctx);
235 TF_DeleteStatus(status);
236
237 // TODO(b/136478427): Figure out how to correctly shut the server down.
238 worker_server.release();
239 }
240
// Server-def update scenario with `async` disabled.
TEST(CAPI, RemoteExecuteUpdateServerDef) {
  TestRemoteExecuteUpdateServerDef(/*async=*/false);
}

// Server-def update scenario with `async` enabled.
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
  TestRemoteExecuteUpdateServerDef(/*async=*/true);
}
248
TestRemoteExecuteUpdateServerDefResourceAccess(bool async)249 void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
250 tensorflow::ServerDef server_def = GetServerDef(2);
251 // This server def has the task index set to 0.
252 string serialized = server_def.SerializeAsString();
253
254 server_def.set_task_index(1);
255 std::unique_ptr<tensorflow::GrpcServer> worker_server;
256 ASSERT_TRUE(tensorflow::GrpcServer::Create(
257 server_def, tensorflow::Env::Default(), &worker_server)
258 .ok());
259 ASSERT_TRUE(worker_server->Start().ok());
260
261 TF_Status* status = TF_NewStatus();
262 TFE_ContextOptions* opts = TFE_NewContextOptions();
263 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
264 TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
265 TFE_Context* ctx = TFE_NewContext(opts, status);
266 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
267 TFE_DeleteContextOptions(opts);
268
269 TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
270 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
271 const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
272 const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
273
274 TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
275 EXPECT_NE(var_handle0, nullptr);
276 TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
277 EXPECT_NE(var_handle1, nullptr);
278
279 TFE_TensorHandle* value_handle = nullptr;
280 ReadVariable(ctx, var_handle1, &value_handle);
281 CheckTFE_TensorHandleHasFloats(value_handle, {2});
282 TFE_DeleteTensorHandle(value_handle);
283
284 // Start a new worker to replace task:1
285 ReplaceTaskInServerDef(&server_def, 1);
286 server_def.set_task_index(1);
287 // TODO(b/136478427): Figure out how to correctly shut the server down.
288 worker_server.release();
289 ASSERT_TRUE(tensorflow::GrpcServer::Create(
290 server_def, tensorflow::Env::Default(), &worker_server)
291 .ok());
292 ASSERT_TRUE(worker_server->Start().ok());
293
294 // Update server def to replace the remote device with the device info on the
295 // new worker (different incarnation ID).
296 server_def.set_task_index(0);
297 string serialized_update = server_def.SerializeAsString();
298 TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
299 serialized_update.size(), status);
300 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
301
302 // The device of var_handle0 is local device which is the same before and
303 // after cluster update. Remove resource with valid device should succeed.
304 TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
305 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
306 TFE_OpAddInput(op, var_handle0, status);
307 TFE_OpSetDevice(op, dev0_name, status);
308 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
309 int num_retvals = 0;
310 TFE_Execute(op, nullptr, &num_retvals, status);
311 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
312 TFE_DeleteOp(op);
313
314 // The device of var_handle1 is remote device, which was replaced during
315 // cluster update. Removing resource with invalid device should fail
316 // gracefully (i.e., with error status) instead of crashing with segfaults.
317 op = TFE_NewOp(ctx, "DestroyResourceOp", status);
318 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
319 TFE_OpAddInput(op, var_handle1, status);
320 TFE_OpSetDevice(op, dev1_name, status);
321 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
322 num_retvals = 0;
323 TFE_Execute(op, nullptr, &num_retvals, status);
324 EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
325 TFE_DeleteOp(op);
326
327 TFE_DeleteTensorHandle(var_handle0);
328 TFE_DeleteTensorHandle(var_handle1);
329
330 TFE_DeleteContext(ctx);
331 TF_DeleteStatus(status);
332
333 // TODO(b/136478427): Figure out how to correctly shut the server down.
334 worker_server.release();
335 }
336
// Resource-access-after-update scenario with `async` disabled.
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
  TestRemoteExecuteUpdateServerDefResourceAccess(/*async=*/false);
}

// Resource-access-after-update scenario with `async` enabled.
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
  TestRemoteExecuteUpdateServerDefResourceAccess(/*async=*/true);
}
344
TestRemoteExecuteUpdateServerDefWithFailures(bool async)345 void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
346 // Fail fast on GetStatus requests so we can get errors instead of timeout
347 // when updating cluster with non-exsitent worker
348 tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
349
350 tensorflow::ServerDef server_def = GetServerDef(2);
351 // This server def has the task index set to 0.
352 string serialized = server_def.SerializeAsString();
353
354 server_def.set_task_index(1);
355 std::unique_ptr<tensorflow::GrpcServer> worker_server;
356 ASSERT_TRUE(tensorflow::GrpcServer::Create(
357 server_def, tensorflow::Env::Default(), &worker_server)
358 .ok());
359 ASSERT_TRUE(worker_server->Start().ok());
360
361 TF_Status* status = TF_NewStatus();
362 TFE_ContextOptions* opts = TFE_NewContextOptions();
363 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
364 TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
365 TFE_Context* ctx = TFE_NewContext(opts, status);
366 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
367 TFE_DeleteContextOptions(opts);
368
369 TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
370 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
371 const char local_device_name[] =
372 "/job:localhost/replica:0/task:0/device:CPU:0";
373 const char remote_device_name[] =
374 "/job:localhost/replica:0/task:1/device:CPU:0";
375 CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
376
377 // Adding a non-existent remote worker to cluster def. This should cause the
378 // UpdateServerDef call to fail.
379 tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
380 tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
381 int port = tensorflow::testing::PickUnusedPortOrDie();
382 job_def->mutable_tasks()->insert(
383 {2, tensorflow::strings::StrCat("localhost:", port)});
384 server_def.set_task_index(0);
385 string serialized_update = server_def.SerializeAsString();
386 TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
387 serialized_update.size(), status);
388 EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
389
390 // Even after the prevoiusly failed cluster update, another update and op
391 // execution should work fine as long as the provided server_def is valid.
392 TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
393 status);
394 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
395 CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
396
397 TFE_DeleteContext(ctx);
398 TF_DeleteStatus(status);
399
400 // TODO(b/136478427): Figure out how to correctly shut the server down.
401 worker_server.release();
402 tensorflow::unsetenv("GRPC_FAIL_FAST");
403 }
404
// Failed-then-valid cluster update scenario with `async` disabled.
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
  TestRemoteExecuteUpdateServerDefWithFailures(/*async=*/false);
}

// Failed-then-valid cluster update scenario with `async` enabled.
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
  TestRemoteExecuteUpdateServerDefWithFailures(/*async=*/true);
}
412
TestConnectToCluster(bool keep_localhost_for_first_connect)413 void TestConnectToCluster(bool keep_localhost_for_first_connect) {
414 // Fail fast on GetStatus requests so we can get errors instead of timeout
415 // when updating cluster with non-exsitent worker
416 tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
417
418 const string first_name =
419 keep_localhost_for_first_connect ? "localhost" : "abc";
420 tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
421
422 TF_Status* status = TF_NewStatus();
423 TFE_ContextOptions* opts = TFE_NewContextOptions();
424 TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
425 TFE_Context* ctx = TFE_NewContext(opts, status);
426 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
427 TFE_DeleteContextOptions(opts);
428
429 const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
430 TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
431 EXPECT_NE(var_handle0, nullptr);
432
433 tensorflow::Status status2;
434 EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
435
436 // Rename local device
437 // This server def has the task index set to 0.
438 string serialized = server_def.SerializeAsString();
439 TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
440 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
441
442 const string dev1_name =
443 absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
444 TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
445 EXPECT_NE(var_handle1, nullptr);
446 EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
447
448 // Another renaming of local device
449 const string second_name = "def";
450 server_def.set_job_name(second_name);
451 server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
452 (*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
453 absl::StrCat(second_name, ":",
454 tensorflow::testing::PickUnusedPortOrDie());
455
456 serialized = server_def.SerializeAsString();
457 TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
458 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
459
460 const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
461 TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
462 EXPECT_NE(var_handle2, nullptr);
463 EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
464
465 TFE_DeleteTensorHandle(var_handle0);
466 TFE_DeleteTensorHandle(var_handle1);
467 TFE_DeleteTensorHandle(var_handle2);
468
469 TFE_DeleteContext(ctx);
470 TF_DeleteStatus(status);
471
472 tensorflow::unsetenv("GRPC_FAIL_FAST");
473 }
474
// The flag decides whether the first server def keeps the "localhost" job name
// (true) or renames it to "abc" (false). The two tests previously passed the
// opposite values from what their names say; both cases are still covered.
TEST(CAPI, ConnectToClusterLocalhostFirst) {
  TestConnectToCluster(/*keep_localhost_for_first_connect=*/true);
}

TEST(CAPI, ConnectToClusterRenameFirst) {
  TestConnectToCluster(/*keep_localhost_for_first_connect=*/false);
}
478
479 } // namespace
480