• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <regex>  // NOLINT
17 
18 #include "tensorflow/c/eager/c_api.h"
19 #include "tensorflow/c/eager/c_api_experimental.h"
20 #include "tensorflow/c/eager/c_api_internal.h"
21 #include "tensorflow/c/eager/c_api_test_util.h"
22 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
23 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
24 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
25 #include "tensorflow/core/common_runtime/optimization_registry.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/platform/casts.h"
30 #include "tensorflow/core/platform/errors.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/protobuf/cluster.pb.h"
34 #include "tensorflow/core/protobuf/config.pb.h"
35 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
36 
37 namespace {
38 
39 using ::tensorflow::string;
40 
41 // Add the values of three variables on three different tasks.
AddVariablesFunction()42 string AddVariablesFunction() {
43   tensorflow::FunctionDef def;
44   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
45       "    signature {"
46       "      name: 'AddVariablesFunction'"
47       "      input_arg {"
48       "        name: 'var'"
49       "        type: DT_RESOURCE"
50       "      }"
51       "      output_arg {"
52       "        name: 'sum'"
53       "        type: DT_FLOAT"
54       "      }"
55       "    }"
56       "    node_def {"
57       "      name: 'read0'"
58       "      op: 'ReadVariableOp'"
59       "      input: 'var'"
60       "      device: '/job:localhost/replica:0/task:0/device:CPU:0'"
61       "      attr {"
62       "        key: 'dtype'"
63       "        value {"
64       "          type: DT_FLOAT"
65       "        }"
66       "      }"
67       "    }"
68       "    node_def {"
69       "      name: 'read1'"
70       "      op: 'ReadVariableOp'"
71       "      input: 'var'"
72       "      device: '/job:localhost/replica:0/task:1/device:CPU:0'"
73       "      attr {"
74       "        key: 'dtype'"
75       "        value {"
76       "          type: DT_FLOAT"
77       "        }"
78       "      }"
79       "    }"
80       "    node_def {"
81       "      name: 'read2'"
82       "      op: 'ReadVariableOp'"
83       "      input: 'var'"
84       "      device: '/job:localhost/replica:0/task:2/device:CPU:0'"
85       "      attr {"
86       "        key: 'dtype'"
87       "        value {"
88       "          type: DT_FLOAT"
89       "        }"
90       "      }"
91       "    }"
92       "    node_def {"
93       "      name: 'add1'"
94       "      op: 'Add'"
95       "      input: 'read0:value:0'"
96       "      input: 'read1:value:0'"
97       "      attr {"
98       "        key: 'T'"
99       "        value {"
100       "          type: DT_FLOAT"
101       "        }"
102       "      }"
103       "    }"
104       "    node_def {"
105       "      name: 'add2'"
106       "      op: 'Add'"
107       "      input: 'add1:z:0'"
108       "      input: 'read2:value:0'"
109       "      attr {"
110       "        key: 'T'"
111       "        value {"
112       "          type: DT_FLOAT"
113       "        }"
114       "      }"
115       "    }"
116       "    ret {"
117       "      key: 'sum'"
118       "      value: 'add2:z:0'"
119       "    }",
120       &def));
121   return def.SerializeAsString();
122 }
123 
// Runs a function whose single input is a packed handle built from three
// variables living on three different tasks, and checks that the function
// returns the sum of the variable values. `remote` selects whether the
// function is explicitly placed on task:1 (remote) or left on the default
// task:0 (local).
void TestFunctionWithPackedInput(const bool remote) {
  tensorflow::ServerDef server_def = GetServerDef(3);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  // Bring up in-process gRPC servers for tasks 1 and 2; task 0 is the client
  // context created below.
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
  const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
  const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  // Create one variable per task.
  TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task1_name);
  TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name);
  TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name);

  // Add a sync point to make sure that variables have been initialized
  // before the function execution starts.
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Pack 3 variable handles into one TFE_TensorHandle.
  // When remote is false, function device is placed on task0. Handle types are
  // REMOTE, REMOTE, LOCAL on task0. When remote is true, function device is
  // placed on task1, Handle types are LOCAL, REMOTE, LOCAL on task1.
  int num_replicas = 3;
  std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
  TFE_TensorHandle* packed_handle =
      TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  // The packed handle is a scalar resource, like each of its components.
  EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
  EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
  EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);

  // A packed handle is reported as living on a virtual COMPOSITE device.
  const string composite_device_name =
      "/job:localhost/replica:0/task:0/device:COMPOSITE:0";
  EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
            composite_device_name);
  EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
            composite_device_name);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // Register and run a function which returns the sum of 3 variables.
  const string function_def = AddVariablesFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, packed_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  if (remote) {
    TFE_OpSetDevice(func, task1_name, status);
    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  }

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(packed_handle);
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  float sum = 0;
  EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
  memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  // 1.0 + 2.0 + 3.0 from the three per-task variables.
  EXPECT_EQ(sum, 6.0);

  TFE_DeleteTensorHandle(h0);
  TFE_DeleteTensorHandle(h1);
  TFE_DeleteTensorHandle(h2);

  // Drain all pending async nodes before tearing the context down.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
  TFE_DeleteContext(ctx);

  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}
238 
// Packed-input function executed on the default (local) task:0 device.
TEST(CAPI, TestLocalFunctionWithPackedInput) {
  TestFunctionWithPackedInput(/*remote=*/false);
}
242 
// Packed-input function explicitly placed on the remote task:1 device.
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
  TestFunctionWithPackedInput(/*remote=*/true);
}
246 
VariableAddFunctionSignature()247 string VariableAddFunctionSignature() {
248   return "    signature {"
249          "      name: 'VariableAddFunction'"
250          "      input_arg {"
251          "        name: 'var0'"
252          "        type: DT_RESOURCE"
253          "      }"
254          "      output_arg {"
255          "        name: 'var0_value'"
256          "        type: DT_FLOAT"
257          "      }"
258          "    }"
259          "    node_def {"
260          "      name: 'read0'"
261          "      op: 'ReadVariableOp'"
262          "      input: 'var0'"
263          "      attr {"
264          "        key: 'dtype'"
265          "        value {"
266          "          type: DT_FLOAT"
267          "        }"
268          "      }"
269          "    }"
270          "    node_def {"
271          "      name: 'add'"
272          "      op: 'Add'"
273          "      input: 'read0:value:0'"
274          "      input: 'read0:value:0'"
275          "      device: '/job:localhost/task:1/device:CPU:0'"
276          "      attr {"
277          "        key: 'T'"
278          "        value {"
279          "          type: DT_FLOAT"
280          "        }"
281          "      }"
282          "    }"
283          "    node_def {"
284          "      name: 'identity'"
285          "      op: 'Identity'"
286          "      input: 'add:z:0'"
287          "      device: '/job:localhost/task:0/device:CPU:0'"
288          "      attr {"
289          "        key: 'T'"
290          "        value {"
291          "          type: DT_FLOAT"
292          "        }"
293          "      }"
294          "    }"
295          "    ret {"
296          "      key: 'var0_value'"
297          "      value: 'identity:output:0'"
298          "    }";
299 }
300 
VariableAddFunction()301 string VariableAddFunction() {
302   tensorflow::FunctionDef def;
303   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
304       VariableAddFunctionSignature(), &def));
305   return def.SerializeAsString();
306 }
307 
308 // A graph optimization pass that would fail when triggered for more than once.
309 class GraphErrorInjectionPass : public tensorflow::GraphOptimizationPass {
310  public:
311   static bool enabled_;
GraphErrorInjectionPass()312   GraphErrorInjectionPass() {}
313 
Run(const tensorflow::GraphOptimizationPassOptions & options)314   tensorflow::Status Run(
315       const tensorflow::GraphOptimizationPassOptions& options) override {
316     if (!enabled_) {
317       return ::tensorflow::OkStatus();
318     }
319     if (first_call_) {
320       first_call_ = false;
321       return ::tensorflow::OkStatus();
322     }
323     return tensorflow::errors::Internal("Graph pass runs for more than once!");
324   }
325 
326  private:
327   bool first_call_ = true;
328 };
329 
// After the graph pass is registered, it takes effect globally and can affect
// other test cases. Define a static variable to switch it on and off.
// Starts off; tests that need the pass flip it on and restore it afterwards.
bool GraphErrorInjectionPass::enabled_ = false;
333 
// Test to ensure that a registered graph optimization pass is only executed
// once (i.e., on the main function side) in running distributed functions.
// This test creates a cluster with two workers, create a variable on the
// second worker, and run a distributed function (VariableAddFunction) whose ops
// span the local and remote workers. If the graph optimization pass is executed
// on both the main function side and the component function side, an error will
// be thrown in the registered graph optimization pass.
TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) {
  // Register graph pass that will raise error if called more than once.
  tensorflow::optimization_registration::OptimizationPassRegistration
      register_test_pass(tensorflow::OptimizationPassRegistry::PRE_PLACEMENT, 0,
                         std::make_unique<GraphErrorInjectionPass>(),
                         "error_injector");
  GraphErrorInjectionPass::enabled_ = true;

  tensorflow::ServerDef server_def = GetServerDef(3);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  // Start in-process gRPC servers for tasks 1 and 2.
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());
  const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Create a variable on task:2 and wait until it is initialized.
  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const string function_def = VariableAddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, var_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  // Execution succeeds only if the error-injecting pass ran at most once.
  TFE_Execute(func, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  float sum = 0;
  ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
  memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  // The function computes var + var with var == 2.0.
  ASSERT_EQ(sum, 4.0);

  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();

  // Disable the test graph pass so it does not affect other test cases.
  GraphErrorInjectionPass::enabled_ = false;
}
417 
VariableAddFunctionWithGraphError()418 string VariableAddFunctionWithGraphError() {
419   string signature = VariableAddFunctionSignature();
420   // Replace the node 'read0' with 'read0_maybe_with_graph_error', so that the
421   // error injecting pass can identify and introduce graph pass errors.
422   signature = std::regex_replace(signature, std::regex("read0"),
423                                  "read0_maybe_with_graph_error");
424   tensorflow::FunctionDef def;
425   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(signature, &def));
426   return def.SerializeAsString();
427 }
428 
429 class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
430  public:
FunctionErrorInjectionPass(string error_node,string error_device)431   FunctionErrorInjectionPass(string error_node, string error_device)
432       : error_node_(error_node), error_device_(error_device) {}
Run(const tensorflow::DeviceSet & device_set,const tensorflow::ConfigProto & config_proto,std::unique_ptr<tensorflow::Graph> * graph,tensorflow::FunctionLibraryDefinition * flib_def,std::vector<std::string> * control_ret_node_names,bool * control_rets_updated)433   tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
434                          const tensorflow::ConfigProto& config_proto,
435                          std::unique_ptr<tensorflow::Graph>* graph,
436                          tensorflow::FunctionLibraryDefinition* flib_def,
437                          std::vector<std::string>* control_ret_node_names,
438                          bool* control_rets_updated) override {
439     // Inject failure to function instantiation if finding a node that contains
440     // the given node name (error_node_) and requested device (error_device_).
441     for (const auto node : graph->get()->nodes()) {
442       if (node->name().find(error_node_) != string::npos &&
443           node->requested_device() == error_device_) {
444         return tensorflow::errors::Internal("Injected graph pass error.");
445       }
446     }
447     return ::tensorflow::OkStatus();
448   }
449 
450  private:
451   const string error_node_;
452   const string error_device_;
453 };
454 
// Runs VariableAddFunction across a three-task cluster. With `inject_error`
// set, registers a function optimization pass that fails when it sees the
// renamed read node placed on task:2, and expects TFE_Execute to surface an
// internal error; otherwise expects the function to return var + var == 4.0.
void TestDistributedFunctionCancellation(bool inject_error) {
  tensorflow::ServerDef server_def = GetServerDef(3);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  // Start in-process gRPC servers for tasks 1 and 2.
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());
  const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  if (inject_error) {
    // Inject a function optimization pass failure when it sees the
    // 'read0_maybe_with_graph_error' op having a requested device `dev2_name`.
    // During execution:
    //   * task:0 processes main function `VariableAddFunctionWithGraphError`
    //     and places the 'read0_maybe_with_graph_error' op on task:2
    //   * task:0 partitions the main function with a subgraph containing
    //     'read0_maybe_with_graph_error' sent to task:2
    //   * task:2 graph pass reports an error when it sees
    //     'read0_maybe_with_graph_error' with dev2_name
    // NOTE(review): the registration object is scoped to this block;
    // presumably the constructor registers the pass globally so the
    // registration outlives this scope — TODO confirm.
    tensorflow::function_optimization_registration::
        FunctionOptimizationPassRegistration register_test_pass(
            std::make_unique<FunctionErrorInjectionPass>(
                "read0_maybe_with_graph_error", dev2_name));
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Create a variable on task:2 and wait until it is initialized.
  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);
  TFE_ContextAsyncWait(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Both variants register under the name 'VariableAddFunction'; the
  // error-injecting variant only renames an internal node.
  const string function_def = inject_error ? VariableAddFunctionWithGraphError()
                                           : VariableAddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, var_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);

  if (inject_error) {
    // The injected pass failure must surface as an INTERNAL error.
    ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  } else {
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(1, num_retvals);
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteTensorHandle(retvals[0]);
    float sum = 0;
    ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
    memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
    // The function computes var + var with var == 2.0.
    ASSERT_EQ(sum, 4.0);
  }

  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}
543 
// Happy path: no error injection, the distributed function returns 4.0.
TEST(CAPI, DistributedFunctionNoError) {
  TestDistributedFunctionCancellation(false);
}
547 
// TODO(b/170399182): Update test once an alternative to using the function
// optimization hook is in place.
// Error path: injected pass failure should cancel the distributed function.
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
  TestDistributedFunctionCancellation(true);
}
553 
// Deletes the eager context while remote RPCs for a large MatMul may still be
// outstanding, in both sync and async executor modes; the test passes if
// teardown does not crash or hang.
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);

  // Start an in-process gRPC server for task 1.
  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Use large matrices so that RPCs don't return before we get a chance
  // to call TFE_DeleteContext.
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  auto* h0_task1 =
      TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* h1_task1 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Launch the MatMul on the remote task without waiting for its result.
  TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(h0_task1);
  TFE_DeleteTensorHandle(h1_task1);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  // Tear down the context while the remote MatMul may still be in flight.
  TFE_DeleteContext(ctx);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
}
616 
// Context teardown with an outstanding RPC, synchronous executor.
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
  TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
620 
// Context teardown with an outstanding RPC, asynchronous executor.
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
  TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
624 }  // namespace
625