/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <regex>  // NOLINT

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace {

using ::tensorflow::string;

// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'AddVariablesFunction'"
      "      input_arg {"
      "        name: 'var'"
      "        type: DT_RESOURCE"
      "      }"
      "      output_arg {"
      "        name: 'sum'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'read0'"
      "      op: 'ReadVariableOp'"
      "      input: 'var'"
      "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
      "      attr {"
      "        key: 'dtype'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'read1'"
      "      op: 'ReadVariableOp'"
      "      input: 'var'"
      "      device: '/job:localhost/replica:0/task:1/device:CPU:0'"
      "      attr {"
      "        key: 'dtype'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'read2'"
      "      op: 'ReadVariableOp'"
      "      input: 'var'"
      "      device: '/job:localhost/replica:0/task:2/device:CPU:0'"
      "      attr {"
      "        key: 'dtype'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'add1'"
      "      op: 'Add'"
      "      input: 'read0:value:0'"
      "      input: 'read1:value:0'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'add2'"
      "      op: 'Add'"
      "      input: 'add1:z:0'"
      "      input: 'read2:value:0'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'sum'"
      "      value: 'add2:z:0'"
      "    }",
      &def));
  return def.SerializeAsString();
}

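// Runs AddVariablesFunction with three per-task variable handles packed into
// a single TFE_TensorHandle. When `remote` is true, the function is placed on
// task 1; otherwise it runs on the default device (task 0).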
void TestFunctionWithPackedInput(const bool remote) {
  tensorflow::ServerDef server_def = GetServerDef(3);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

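  // Start in-process gRPC servers for tasks 1 and 2; the current process
  // itself acts as task 0 once the server def is installed below.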
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

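  // Create an eager context with asynchronous execution enabled; the SILENT
  // placement policy lets tensors be copied across devices without raising
  // errors.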
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

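  // Connect the context to the cluster as task 0; after this call, devices on
  // tasks 1 and 2 are addressable from this context.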
  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
  const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
  const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  // Create one variable per task.
  TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task1_name);
  TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name);
  TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name);

  // Add a sync point to make sure that variables have been initialized
  // before the function execution starts.
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Pack the 3 variable handles into one TFE_TensorHandle.
  // When `remote` is false, the function is placed on task 0, where the handle
  // types are REMOTE, REMOTE, LOCAL. When `remote` is true, the function is
  // placed on task 1, where the handle types are LOCAL, REMOTE, LOCAL.
  int num_replicas = 3;
  std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
  TFE_TensorHandle* packed_handle =
      TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
  EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
  EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);

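  // The packed handle is reported as living on a client-side COMPOSITE device
  // rather than on any single physical device.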
  const string composite_device_name =
      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";
  EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
            composite_device_name);
  EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
            composite_device_name);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // Register and run a function which returns the sum of 3 variables.
  const string function_def = AddVariablesFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, packed_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  if (remote) {
    TFE_OpSetDevice(func, task1_name, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  }

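  // Execute the function and check that the result is 1.0 + 2.0 + 3.0.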
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(packed_handle);
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  float sum = 0;
  EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
  memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(sum, 6.0);

  TFE_DeleteTensorHandle(h0);
  TFE_DeleteTensorHandle(h1);
  TFE_DeleteTensorHandle(h2);

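  // Drain all pending nodes on the per-thread executor before tearing the
  // context down.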
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
  TFE_DeleteContext(ctx);

  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}

TEST(CAPI, TestLocalFunctionWithPackedInput) {
  TestFunctionWithPackedInput(/*remote=*/false);
}

TEST(CAPI, TestRemoteFunctionWithPackedInput) {
  TestFunctionWithPackedInput(/*remote=*/true);
}

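// Returns the text proto for a function that reads a variable and adds its
// value to itself, with the Add pinned to task 1 and the final Identity pinned
// to task 0, so the function spans two tasks when instantiated.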
string VariableAddFunctionSignature() {
  return "    signature {"
         "      name: 'VariableAddFunction'"
         "      input_arg {"
         "        name: 'var0'"
         "        type: DT_RESOURCE"
         "      }"
         "      output_arg {"
         "        name: 'var0_value'"
         "        type: DT_FLOAT"
         "      }"
         "    }"
         "    node_def {"
         "      name: 'read0'"
         "      op: 'ReadVariableOp'"
         "      input: 'var0'"
         "      attr {"
         "        key: 'dtype'"
         "        value {"
         "          type: DT_FLOAT"
         "        }"
         "      }"
         "    }"
         "    node_def {"
         "      name: 'add'"
         "      op: 'Add'"
         "      input: 'read0:value:0'"
         "      input: 'read0:value:0'"
         "      device: '/job:localhost/task:1/device:CPU:0'"
         "      attr {"
         "        key: 'T'"
         "        value {"
         "          type: DT_FLOAT"
         "        }"
         "      }"
         "    }"
         "    node_def {"
         "      name: 'identity'"
         "      op: 'Identity'"
         "      input: 'add:z:0'"
         "      device: '/job:localhost/task:0/device:CPU:0'"
         "      attr {"
         "        key: 'T'"
         "        value {"
         "          type: DT_FLOAT"
         "        }"
         "      }"
         "    }"
         "    ret {"
         "      key: 'var0_value'"
         "      value: 'identity:output:0'"
         "    }";
}

string VariableAddFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      VariableAddFunctionSignature(), &def));
  return def.SerializeAsString();
}

// A graph optimization pass that fails if it is triggered more than once.
class GraphErrorInjectionPass : public tensorflow::GraphOptimizationPass {
 public:
  static bool enabled_;
  GraphErrorInjectionPass() {}

  tensorflow::Status Run(
      const tensorflow::GraphOptimizationPassOptions& options) override {
    if (!enabled_) {
      return ::tensorflow::OkStatus();
    }
    if (first_call_) {
      first_call_ = false;
      return ::tensorflow::OkStatus();
    }
    return tensorflow::errors::Internal("Graph pass runs more than once!");
  }

 private:
  bool first_call_ = true;
};

// Once registered, the graph pass takes effect globally and could affect other
// test cases. Use a static flag to switch it on and off.
bool GraphErrorInjectionPass::enabled_ = false;

// Test to ensure that a registered graph optimization pass is only executed
// once (i.e., on the main-function side) when running distributed functions.
// This test creates a cluster with two workers, creates a variable on the
// second worker, and runs a distributed function (VariableAddFunction) whose
// ops span the local and remote workers. If the graph optimization pass were
// executed on both the main-function side and the component-function side, an
// error would be raised by the registered graph optimization pass.
TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) {
  // Register a graph pass that raises an error if called more than once.
  tensorflow::optimization_registration::OptimizationPassRegistration
      register_test_pass(tensorflow::OptimizationPassRegistry::PRE_PLACEMENT, 0,
                         std::make_unique<GraphErrorInjectionPass>(),
                         "error_injector");
  GraphErrorInjectionPass::enabled_ = true;

  tensorflow::ServerDef server_def = GetServerDef(3);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());
  const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

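  // Running the function should trigger the PRE_PLACEMENT pass exactly once,
  // on the main-function side; if it also ran for the component function, the
  // pass would return an Internal error and the assertions below would fail.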
  const string function_def = VariableAddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, var_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  float sum = 0;
  ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
  memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  ASSERT_EQ(sum, 4.0);

  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();

  // Disable the test graph pass so it does not affect other test cases.
  GraphErrorInjectionPass::enabled_ = false;
}

string VariableAddFunctionWithGraphError() {
  string signature = VariableAddFunctionSignature();
  // Replace the node 'read0' with 'read0_maybe_with_graph_error', so that the
  // error-injecting pass can identify it and introduce graph pass errors.
  signature = std::regex_replace(signature, std::regex("read0"),
                                 "read0_maybe_with_graph_error");
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(signature, &def));
  return def.SerializeAsString();
}

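// A function optimization pass that fails function instantiation whenever the
// graph contains a node whose name includes `error_node` and whose requested
// device is `error_device`.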
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
 public:
  FunctionErrorInjectionPass(string error_node, string error_device)
      : error_node_(error_node), error_device_(error_device) {}
  tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
                         const tensorflow::ConfigProto& config_proto,
                         std::unique_ptr<tensorflow::Graph>* graph,
                         tensorflow::FunctionLibraryDefinition* flib_def,
                         std::vector<std::string>* control_ret_node_names,
                         bool* control_rets_updated) override {
    // Inject a failure into function instantiation if we find a node that
    // contains the given node name (error_node_) and requested device
    // (error_device_).
    for (const auto node : graph->get()->nodes()) {
      if (node->name().find(error_node_) != string::npos &&
          node->requested_device() == error_device_) {
        return tensorflow::errors::Internal("Injected graph pass error.");
      }
    }
    return ::tensorflow::OkStatus();
  }

 private:
  const string error_node_;
  const string error_device_;
};

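// Runs VariableAddFunction (or its error-injecting variant) on a 3-task
// cluster and, when `inject_error` is true, checks that a failure raised while
// instantiating a component function on a remote task propagates back to the
// client through TFE_Execute.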
void TestDistributedFunctionCancellation(bool inject_error) {
  tensorflow::ServerDef server_def = GetServerDef(3);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());
  const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  if (inject_error) {
    // Inject a function optimization pass failure when the pass sees the
    // 'read0_maybe_with_graph_error' op with a requested device `dev2_name`.
    // During execution:
    //   * task:0 processes the main function
    //     `VariableAddFunctionWithGraphError` and places the
    //     'read0_maybe_with_graph_error' op on task:2;
    //   * task:0 partitions the main function, sending a subgraph containing
    //     'read0_maybe_with_graph_error' to task:2;
    //   * the graph pass on task:2 reports an error when it sees
    //     'read0_maybe_with_graph_error' with dev2_name.
    tensorflow::function_optimization_registration::
        FunctionOptimizationPassRegistration register_test_pass(
            std::make_unique<FunctionErrorInjectionPass>(
                "read0_maybe_with_graph_error", dev2_name));
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const string function_def = inject_error ? VariableAddFunctionWithGraphError()
                                           : VariableAddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, var_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);

  if (inject_error) {
    ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  } else {
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(1, num_retvals);
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteTensorHandle(retvals[0]);
    float sum = 0;
    ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
    memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
    ASSERT_EQ(sum, 4.0);
  }

  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}

TEST(CAPI, DistributedFunctionNoError) {
  TestDistributedFunctionCancellation(false);
}

// TODO(b/170399182): Update test once an alternative to using the function
// optimization hook is in place.
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
  TestDistributedFunctionCancellation(true);
}

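// Checks that deleting the eager context is safe while remote MatMul RPCs may
// still be in flight; the 100x100 inputs below keep the RPCs from completing
// before TFE_DeleteContext is called.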
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);

  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Use large matrices so that RPCs don't return before we get a chance
  // to call TFE_DeleteContext.
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  auto* h0_task1 =
      TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* h1_task1 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(h0_task1);
  TFE_DeleteTensorHandle(h1_task1);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  TFE_DeleteContext(ctx);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
}

TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
  TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}

TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
  TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
}  // namespace