• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api.h"
17 #include "tensorflow/c/eager/c_api_experimental.h"
18 #include "tensorflow/c/eager/c_api_internal.h"
19 #include "tensorflow/c/eager/c_api_test_util.h"
20 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
21 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
23 #include "tensorflow/core/platform/casts.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28 
29 namespace {
30 
31 using ::tensorflow::string;
32 
ReplaceTaskInServerDef(tensorflow::ServerDef * server_def,int task_index)33 void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
34   tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
35   int port = tensorflow::testing::PickUnusedPortOrDie();
36   job_def->mutable_tasks()->at(task_index) =
37       tensorflow::strings::StrCat("localhost:", port);
38 }
39 
CheckTFE_TensorHandleHasFloats(TFE_TensorHandle * handle,const std::vector<float> & expected_values)40 void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
41                                     const std::vector<float>& expected_values) {
42   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
43       TF_NewStatus(), TF_DeleteStatus);
44   TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
45   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
46   std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
47   EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
48   memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
49   TF_DeleteTensor(t);
50 
51   for (int i = 0; i < expected_values.size(); i++) {
52     EXPECT_EQ(expected_values[i], actual_values[i])
53         << "Mismatch in expected values at (zero-based) index " << i;
54   }
55 }
56 
CheckRemoteMatMulExecutesOK(TFE_Context * ctx,const char * remote_device_name,const char * local_device_name)57 void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
58                                  const char* remote_device_name,
59                                  const char* local_device_name) {
60   TF_Status* status = TF_NewStatus();
61   TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
62 
63   TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
64   TFE_OpSetDevice(matmul, remote_device_name, status);
65   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
66 
67   TFE_TensorHandle* retvals[1];
68   int num_retvals = 1;
69   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
70   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
71 
72   auto* retval_task0 =
73       TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
74   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
75 
76   CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
77 
78   TFE_DeleteTensorHandle(retval_task0);
79   TFE_DeleteTensorHandle(h0_task0);
80   TFE_DeleteTensorHandle(retvals[0]);
81 
82   TFE_DeleteOp(matmul);
83 
84   TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
85   TFE_ExecutorWaitForAllPendingNodes(executor, status);
86   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
87   TFE_DeleteExecutor(executor);
88   TF_DeleteStatus(status);
89 }
90 
91 // Read the value of variable `var` and save it into `out_value`.
ReadVariable(TFE_Context * ctx,TFE_TensorHandle * var,TFE_TensorHandle ** out_value)92 void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
93                   TFE_TensorHandle** out_value) {
94   TF_Status* status = TF_NewStatus();
95   TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
96   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
97   TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
98   TFE_OpAddInput(op, var, status);
99   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
100   int num_retvals = 1;
101   TFE_Execute(op, out_value, &num_retvals, status);
102   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
103   TFE_DeleteOp(op);
104   TF_DeleteStatus(status);
105 }
106 
TestRemoteExecuteChangeServerDef(bool async)107 void TestRemoteExecuteChangeServerDef(bool async) {
108   tensorflow::ServerDef server_def = GetServerDef(2);
109 
110   // This server def has the task index set to 0.
111   string serialized = server_def.SerializeAsString();
112 
113   server_def.set_task_index(1);
114 
115   std::unique_ptr<tensorflow::GrpcServer> worker_server;
116   ASSERT_TRUE(tensorflow::GrpcServer::Create(
117                   server_def, tensorflow::Env::Default(), &worker_server)
118                   .ok());
119   ASSERT_TRUE(worker_server->Start().ok());
120 
121   TF_Status* status = TF_NewStatus();
122   TFE_ContextOptions* opts = TFE_NewContextOptions();
123   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
124   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
125   TFE_Context* ctx = TFE_NewContext(opts, status);
126   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
127   TFE_DeleteContextOptions(opts);
128 
129   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
130   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
131 
132   const char remote_device_name[] =
133       "/job:localhost/replica:0/task:1/device:CPU:0";
134   const char local_device_name[] =
135       "/job:localhost/replica:0/task:0/device:CPU:0";
136   CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
137 
138   TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
139   TFE_ExecutorWaitForAllPendingNodes(executor, status);
140   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
141 
142   // TODO(b/136478427): Figure out how to correctly shut the server down.
143   worker_server.release();
144 
145   // Update the server def with a new set of names (worker instead of
146   // localhost).
147   tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
148   serialized = updated_server_def.SerializeAsString();
149 
150   updated_server_def.set_task_index(1);
151   tensorflow::Status s = tensorflow::GrpcServer::Create(
152       updated_server_def, tensorflow::Env::Default(), &worker_server);
153   ASSERT_TRUE(s.ok()) << s.error_message();
154   ASSERT_TRUE(worker_server->Start().ok());
155 
156   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
157   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
158 
159   // Create a new tensor_handle.
160   TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
161 
162   // Check that copying it to the old remote device (named localhost) fails.
163   TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
164   EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
165 
166   // Copying and executing on the new remote device works.
167   const char new_remote_device_name[] =
168       "/job:worker/replica:0/task:1/device:CPU:0";
169   const char new_local_device_name[] =
170       "/job:worker/replica:0/task:0/device:CPU:0";
171 
172   auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
173       h0_task0_new, ctx, new_remote_device_name, status);
174   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
175 
176   TFE_DeleteTensorHandle(h0_task0_new);
177   TFE_DeleteTensorHandle(h0_task1_new);
178 
179   CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
180                               new_local_device_name);
181 
182   TFE_ExecutorWaitForAllPendingNodes(executor, status);
183   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
184   TFE_DeleteExecutor(executor);
185 
186   TF_DeleteStatus(status);
187 
188   TFE_DeleteContext(ctx);
189 
190   // TODO(b/136478427): Figure out how to correctly shut the server down.
191   worker_server.release();
192 }
193 
// Server-def replacement scenario with async execution disabled.
TEST(CAPI, RemoteExecuteChangeServerDef) {
  TestRemoteExecuteChangeServerDef(/*async=*/false);
}
// Server-def replacement scenario with async execution enabled.
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
  TestRemoteExecuteChangeServerDef(/*async=*/true);
}
200 
TestRemoteExecuteUpdateServerDef(bool async)201 void TestRemoteExecuteUpdateServerDef(bool async) {
202   tensorflow::ServerDef server_def = GetServerDef(2);
203   // This server def has the task index set to 0.
204   string serialized = server_def.SerializeAsString();
205 
206   server_def.set_task_index(1);
207   std::unique_ptr<tensorflow::GrpcServer> worker_server;
208   ASSERT_TRUE(tensorflow::GrpcServer::Create(
209                   server_def, tensorflow::Env::Default(), &worker_server)
210                   .ok());
211   ASSERT_TRUE(worker_server->Start().ok());
212 
213   TF_Status* status = TF_NewStatus();
214   TFE_ContextOptions* opts = TFE_NewContextOptions();
215   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
216   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
217   TFE_Context* ctx = TFE_NewContext(opts, status);
218   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
219   TFE_DeleteContextOptions(opts);
220 
221   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
222   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
223   const char local_device_name[] =
224       "/job:localhost/replica:0/task:0/device:CPU:0";
225   const char remote_device_name[] =
226       "/job:localhost/replica:0/task:1/device:CPU:0";
227   CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
228 
229   TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
230                              status);
231   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
232   CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
233 
234   TFE_DeleteContext(ctx);
235   TF_DeleteStatus(status);
236 
237   // TODO(b/136478427): Figure out how to correctly shut the server down.
238   worker_server.release();
239 }
240 
// Server-def update scenario with async execution disabled.
TEST(CAPI, RemoteExecuteUpdateServerDef) {
  TestRemoteExecuteUpdateServerDef(/*async=*/false);
}
244 
// Server-def update scenario with async execution enabled.
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
  TestRemoteExecuteUpdateServerDef(/*async=*/true);
}
248 
TestRemoteExecuteUpdateServerDefResourceAccess(bool async)249 void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
250   tensorflow::ServerDef server_def = GetServerDef(2);
251   // This server def has the task index set to 0.
252   string serialized = server_def.SerializeAsString();
253 
254   server_def.set_task_index(1);
255   std::unique_ptr<tensorflow::GrpcServer> worker_server;
256   ASSERT_TRUE(tensorflow::GrpcServer::Create(
257                   server_def, tensorflow::Env::Default(), &worker_server)
258                   .ok());
259   ASSERT_TRUE(worker_server->Start().ok());
260 
261   TF_Status* status = TF_NewStatus();
262   TFE_ContextOptions* opts = TFE_NewContextOptions();
263   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
264   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
265   TFE_Context* ctx = TFE_NewContext(opts, status);
266   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
267   TFE_DeleteContextOptions(opts);
268 
269   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
270   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
271   const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
272   const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
273 
274   TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
275   EXPECT_NE(var_handle0, nullptr);
276   TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
277   EXPECT_NE(var_handle1, nullptr);
278 
279   TFE_TensorHandle* value_handle = nullptr;
280   ReadVariable(ctx, var_handle1, &value_handle);
281   CheckTFE_TensorHandleHasFloats(value_handle, {2});
282   TFE_DeleteTensorHandle(value_handle);
283 
284   // Start a new worker to replace task:1
285   ReplaceTaskInServerDef(&server_def, 1);
286   server_def.set_task_index(1);
287   // TODO(b/136478427): Figure out how to correctly shut the server down.
288   worker_server.release();
289   ASSERT_TRUE(tensorflow::GrpcServer::Create(
290                   server_def, tensorflow::Env::Default(), &worker_server)
291                   .ok());
292   ASSERT_TRUE(worker_server->Start().ok());
293 
294   // Update server def to replace the remote device with the device info on the
295   // new worker (different incarnation ID).
296   server_def.set_task_index(0);
297   string serialized_update = server_def.SerializeAsString();
298   TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
299                              serialized_update.size(), status);
300   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
301 
302   // The device of var_handle0 is local device which is the same before and
303   // after cluster update. Remove resource with valid device should succeed.
304   TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
305   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
306   TFE_OpAddInput(op, var_handle0, status);
307   TFE_OpSetDevice(op, dev0_name, status);
308   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
309   int num_retvals = 0;
310   TFE_Execute(op, nullptr, &num_retvals, status);
311   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
312   TFE_DeleteOp(op);
313 
314   // The device of var_handle1 is remote device, which was replaced during
315   // cluster update. Removing resource with invalid device should fail
316   // gracefully (i.e., with error status) instead of crashing with segfaults.
317   op = TFE_NewOp(ctx, "DestroyResourceOp", status);
318   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
319   TFE_OpAddInput(op, var_handle1, status);
320   TFE_OpSetDevice(op, dev1_name, status);
321   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
322   num_retvals = 0;
323   TFE_Execute(op, nullptr, &num_retvals, status);
324   EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
325   TFE_DeleteOp(op);
326 
327   TFE_DeleteTensorHandle(var_handle0);
328   TFE_DeleteTensorHandle(var_handle1);
329 
330   TFE_DeleteContext(ctx);
331   TF_DeleteStatus(status);
332 
333   // TODO(b/136478427): Figure out how to correctly shut the server down.
334   worker_server.release();
335 }
336 
// Resource-access-after-update scenario with async execution disabled.
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
  TestRemoteExecuteUpdateServerDefResourceAccess(/*async=*/false);
}
340 
// Resource-access-after-update scenario with async execution enabled.
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
  TestRemoteExecuteUpdateServerDefResourceAccess(/*async=*/true);
}
344 
TestRemoteExecuteUpdateServerDefWithFailures(bool async)345 void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
346   // Fail fast on GetStatus requests so we can get errors instead of timeout
347   // when updating cluster with non-exsitent worker
348   tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
349 
350   tensorflow::ServerDef server_def = GetServerDef(2);
351   // This server def has the task index set to 0.
352   string serialized = server_def.SerializeAsString();
353 
354   server_def.set_task_index(1);
355   std::unique_ptr<tensorflow::GrpcServer> worker_server;
356   ASSERT_TRUE(tensorflow::GrpcServer::Create(
357                   server_def, tensorflow::Env::Default(), &worker_server)
358                   .ok());
359   ASSERT_TRUE(worker_server->Start().ok());
360 
361   TF_Status* status = TF_NewStatus();
362   TFE_ContextOptions* opts = TFE_NewContextOptions();
363   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
364   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
365   TFE_Context* ctx = TFE_NewContext(opts, status);
366   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
367   TFE_DeleteContextOptions(opts);
368 
369   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
370   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
371   const char local_device_name[] =
372       "/job:localhost/replica:0/task:0/device:CPU:0";
373   const char remote_device_name[] =
374       "/job:localhost/replica:0/task:1/device:CPU:0";
375   CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
376 
377   // Adding a non-existent remote worker to cluster def. This should cause the
378   // UpdateServerDef call to fail.
379   tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
380   tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
381   int port = tensorflow::testing::PickUnusedPortOrDie();
382   job_def->mutable_tasks()->insert(
383       {2, tensorflow::strings::StrCat("localhost:", port)});
384   server_def.set_task_index(0);
385   string serialized_update = server_def.SerializeAsString();
386   TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
387                              serialized_update.size(), status);
388   EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
389 
390   // Even after the prevoiusly failed cluster update, another update and op
391   // execution should work fine as long as the provided server_def is valid.
392   TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
393                              status);
394   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
395   CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
396 
397   TFE_DeleteContext(ctx);
398   TF_DeleteStatus(status);
399 
400   // TODO(b/136478427): Figure out how to correctly shut the server down.
401   worker_server.release();
402   tensorflow::unsetenv("GRPC_FAIL_FAST");
403 }
404 
// Failed-update recovery scenario with async execution disabled.
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
  TestRemoteExecuteUpdateServerDefWithFailures(/*async=*/false);
}
408 
// Failed-update recovery scenario with async execution enabled.
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
  TestRemoteExecuteUpdateServerDefWithFailures(/*async=*/true);
}
412 
TestConnectToCluster(bool keep_localhost_for_first_connect)413 void TestConnectToCluster(bool keep_localhost_for_first_connect) {
414   // Fail fast on GetStatus requests so we can get errors instead of timeout
415   // when updating cluster with non-exsitent worker
416   tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
417 
418   const string first_name =
419       keep_localhost_for_first_connect ? "localhost" : "abc";
420   tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
421 
422   TF_Status* status = TF_NewStatus();
423   TFE_ContextOptions* opts = TFE_NewContextOptions();
424   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
425   TFE_Context* ctx = TFE_NewContext(opts, status);
426   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
427   TFE_DeleteContextOptions(opts);
428 
429   const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
430   TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
431   EXPECT_NE(var_handle0, nullptr);
432 
433   tensorflow::Status status2;
434   EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
435 
436   // Rename local device
437   // This server def has the task index set to 0.
438   string serialized = server_def.SerializeAsString();
439   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
440   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
441 
442   const string dev1_name =
443       absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
444   TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
445   EXPECT_NE(var_handle1, nullptr);
446   EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
447 
448   // Another renaming of local device
449   const string second_name = "def";
450   server_def.set_job_name(second_name);
451   server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
452   (*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
453       absl::StrCat(second_name, ":",
454                    tensorflow::testing::PickUnusedPortOrDie());
455 
456   serialized = server_def.SerializeAsString();
457   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
458   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
459 
460   const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
461   TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
462   EXPECT_NE(var_handle2, nullptr);
463   EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
464 
465   TFE_DeleteTensorHandle(var_handle0);
466   TFE_DeleteTensorHandle(var_handle1);
467   TFE_DeleteTensorHandle(var_handle2);
468 
469   TFE_DeleteContext(ctx);
470   TF_DeleteStatus(status);
471 
472   tensorflow::unsetenv("GRPC_FAIL_FAST");
473 }
474 
// First connect keeps the "localhost" job name; only the second connect
// renames the local job. (TestConnectToCluster uses "localhost" as the first
// job name exactly when its flag is true.)
TEST(CAPI, ConnectToClusterLocalhostFirst) {
  TestConnectToCluster(/*keep_localhost_for_first_connect=*/true);
}

// First connect already renames the local job (to "abc"), i.e. the flag is
// false.
TEST(CAPI, ConnectToClusterRenameFirst) {
  TestConnectToCluster(/*keep_localhost_for_first_connect=*/false);
}
478 
479 }  // namespace
480