/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

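// Returns SessionOptions that request `num_cpus` CPU and `num_gpus` GPU
// devices per worker in the test cluster.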
static SessionOptions Devices(int num_cpus, int num_gpus) {
  SessionOptions result;
  (*result.config.mutable_device_count())["CPU"] = num_cpus;
  (*result.config.mutable_device_count())["GPU"] = num_gpus;
  return result;
}

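// Builds a graph that computes c = a * b, where a = [1 2] (1x2) and
// b = [2 1]^T (2x1), so c is a 1x1 matrix containing 4. The names of the
// nodes a, b, and c are returned in `node_names`.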
void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({1, 2}));
  test::FillValues<float>(&a_tensor, {1, 2});
  Node* a = test::graph::Constant(&graph, a_tensor);
  node_names[0] = a->name();

  Tensor b_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&b_tensor, {2, 1});
  Node* b = test::graph::Constant(&graph, b_tensor);
  node_names[1] = b->name();

  Node* c = test::graph::Matmul(&graph, a, b, false, false);
  node_names[2] = c->name();

  test::graph::ToGraphDef(&graph, graph_def);
}

// Asserts that "val" is a single-element float tensor whose only value is
// "expected_val".
static void IsSingleFloatValue(const Tensor& val, float expected_val) {
  ASSERT_EQ(val.dtype(), DT_FLOAT);
  ASSERT_EQ(val.NumElements(), 1);
  ASSERT_EQ(val.flat<float>()(0), expected_val);
}

static SessionOptions Options(const string& target, int placement_period) {
  SessionOptions options;
  // NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
  // string.
  options.target = strings::StrCat("grpc://", target);
  options.config.set_placement_period(placement_period);
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  return options;
}

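// Creates a new remote session for `options`; CHECK-fails if the session
// cannot be constructed.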
static Session* NewRemote(const SessionOptions& options) {
  return CHECK_NOTNULL(NewSession(options));
}

TEST(GrpcSessionTest, BasicNonProtoAPI) {
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  for (int iters = 0; iters < 25; ++iters) {
    TF_CHECK_OK(session->Create(graph));
    {
      // Just run to target node
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<string> targets = {node_names[2]};
      TF_CHECK_OK(session->Run(inputs, {}, targets, nullptr));
    }
    {
      // Run to a target node and a real tensor
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<string> names = {node_names[2] + ":0"};
      std::vector<string> targets = {node_names[1]};
      std::vector<Tensor> outputs;
      TF_CHECK_OK(session->Run(inputs, names, targets, &outputs));
      ASSERT_TRUE(outputs[0].IsInitialized());
      ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
    }

    TF_CHECK_OK(session->Close());
  }
}

TEST(GrpcSessionTest, BasicCallable) {
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  for (int iters = 0; iters < 25; ++iters) {
    TF_CHECK_OK(session->Create(graph));
    {
      // Just run to target node
      CallableOptions opts;
      opts.add_target(node_names[2]);
      Session::CallableHandle handle;
      TF_CHECK_OK(session->MakeCallable(opts, &handle));
      TF_CHECK_OK(session->RunCallable(handle, {}, nullptr, nullptr));
      TF_CHECK_OK(session->ReleaseCallable(handle));
    }
    {
      // Run to a target node and a real tensor
      CallableOptions opts;
      opts.add_target(node_names[1]);
      opts.add_fetch(node_names[2] + ":0");
      Session::CallableHandle handle;
      TF_CHECK_OK(session->MakeCallable(opts, &handle));
      std::vector<Tensor> outputs;
      TF_CHECK_OK(session->RunCallable(handle, {}, &outputs, nullptr));
      ASSERT_EQ(1, outputs.size());
      ASSERT_TRUE(outputs[0].IsInitialized());
      ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
      TF_CHECK_OK(session->ReleaseCallable(handle));
    }

    TF_CHECK_OK(session->Close());
  }
}

TEST(GrpcSessionTest, CallableWithOnDeviceFeedsAndFetches) {
  // Specifying feed/fetch devices for remote sessions is not yet supported.
  // Ensure that the error is reported gracefully.
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(graph));

  std::vector<DeviceAttributes> devices;
  TF_CHECK_OK(session->ListDevices(&devices));
  ASSERT_GT(devices.size(), 0);
  const string device_name = devices.back().name();

  CallableOptions opts;
  const string fetch = node_names[2] + ":0";
  opts.add_fetch(fetch);
  opts.mutable_fetch_devices()->insert({fetch, device_name});

  Session::CallableHandle handle;
  Status status = session->MakeCallable(opts, &handle);
  EXPECT_EQ(error::UNIMPLEMENTED, status.code());
  TF_CHECK_OK(session->Close());
}

TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);
  ASSERT_TRUE(session->Create(graph).ok());

  // Test that the order of the output names matches the order of the
  // returned Tensors.
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<string> names = {node_names[2] + ":0", node_names[0] + ":0",
                               node_names[1] + ":0"};

  std::vector<string> target_ops = {node_names[1]};
  std::vector<Tensor> outputs;
  ASSERT_TRUE(session->Run(inputs, names, target_ops, &outputs).ok());
  ASSERT_TRUE(outputs[0].IsInitialized());
  ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
  ASSERT_TRUE(outputs[1].IsInitialized());
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
  ASSERT_TRUE(outputs[2].IsInitialized());
  ASSERT_EQ(2.0, outputs[2].flat<float>()(0));
  ASSERT_TRUE(session->Close().ok());
}

TEST(GrpcSessionTest, NonLocalWithFilters) {
  GraphDef graph;
  string node_names[3];
  // c = a * b
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  SessionOptions options;
  options.target = strings::StrCat("grpc://", cluster->targets()[0]);
  options.config.add_device_filters(cluster->devices()[0].name());

  std::unique_ptr<Session> session(NewRemote(options));
  ASSERT_TRUE(session != nullptr);

  {
    GraphDef graph_copy(graph);
    graph::SetDefaultDevice(cluster->devices()[0].name(), &graph_copy);
    TF_CHECK_OK(session->Create(graph_copy));
    TF_CHECK_OK(session->Run({}, {}, {node_names[2]}, nullptr));
    TF_CHECK_OK(session->Close());
  }
  {
    GraphDef graph_copy(graph);
    graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy);
    auto status = session->Create(graph_copy);
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
  }
}

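// Fetching the same tensor name twice in a single Run() call should yield
// two outputs with the same value.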
TEST(GrpcSessionTest, FetchMultipleTimes) {
  GraphDef graph;
  string node_names[3];
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(graph));
  const std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;

  const string node = node_names[2] + ":0";
  TF_CHECK_OK(session->Run(inputs, {node, node}, {}, &outputs));
  EXPECT_EQ(2, outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    const Tensor& t = outputs[i];
    ASSERT_TRUE(t.IsInitialized()) << i;
    ASSERT_EQ(4.0, t.flat<float>()(0)) << i;
  }
  TF_CHECK_OK(session->Close());
}

TEST(GrpcSessionTest, DisableOutputPartitionGraphs) {
  GraphDef graph;
  string node_names[3];
  CreateGraphDef(&graph, node_names);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  SessionOptions options = Options(cluster->targets()[0], 1);
  options.config.mutable_experimental()->set_disable_output_partition_graphs(
      true);

  std::unique_ptr<Session> session(NewRemote(options));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(graph));
  {
    // Just run to target node.
    TF_CHECK_OK(session->Run({}, {}, {node_names[2]}, nullptr));
  }
  {
    // Attempting to get the partition graphs should fail.
    RunOptions run_options;
    run_options.set_output_partition_graphs(true);
    RunMetadata run_metadata;
    Status s = session->Run(run_options, {}, {}, {node_names[2]}, nullptr,
                            &run_metadata);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "disable_output_partition_graphs"));
  }

  TF_CHECK_OK(session->Close());
}

// A = [3 2; -1 0]; x = rand(2, 1); We want to compute the largest
// eigenvalue of A, which is 2.0. Iteratively, we do
//   repeat x = y / y.norm(); y = A * x; end
// At the end, we expect "lambda" to converge to 2.0.
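// (This is power iteration: the eigenvalues of A are 1 and 2, so repeatedly
// multiplying by A and renormalizing aligns x with the dominant eigenvector,
// and the ratio "lambda" converges to the dominant eigenvalue 2.)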
void FindMaxEigen(const string& target) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  // Store rows [3, 2] and [-1, 0] in row major format.
  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
  Node* a = test::graph::Constant(&graph, a_tensor);

  // x is from the feed.
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  test::FillValues<float>(&x_tensor, {0, 0});
  Node* x = test::graph::Constant(&graph, x_tensor);

  // y = A * x
  Node* y = test::graph::Matmul(&graph, a, x, false, false);

  // y2 = y.^2
  Node* y2 = test::graph::Unary(&graph, "Square", y);

  // Const tensor for the reduction dimension.
  Tensor rdim_tensor(DT_INT32, TensorShape({}));
  rdim_tensor.scalar<int32>()() = 0;
  Node* rdim = test::graph::Constant(&graph, rdim_tensor);

  // y2_sum = sum(y2)
  Node* y2_sum = test::graph::Reduce(&graph, "Sum", y2, rdim);

  // y_norm = sqrt(y2_sum)
  Node* y_norm = test::graph::Unary(&graph, "Sqrt", y2_sum);

  // y_normalized = y ./ y_norm
  Node* y_normalized = test::graph::Binary(&graph, "Div", y, y_norm);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  std::unique_ptr<Session> session(NewRemote(Options(target, 1)));
  ASSERT_TRUE(session != nullptr);
  TF_CHECK_OK(session->Create(def));

  // Set up feeds and fetches.
  float lambda;
  Tensor feed_value(DT_FLOAT, TensorShape({2, 1}));
  feed_value.matrix<float>()(0, 0) = -3.1415;
  feed_value.matrix<float>()(1, 0) = +2.7183;

  for (int i = 0; i < 25; ++i) {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({{x->name(), feed_value}},
                             {y->name(), y_normalized->name()}, {}, &outputs));
    const Tensor& y = outputs[0];
    const Tensor& y_normalized = outputs[1];
    // Print out lambda, x, and y.
    CHECK_EQ(2, feed_value.NumElements());
    CHECK_EQ(2, y.NumElements());
    lambda = y.flat<float>()(0) / feed_value.flat<float>()(0);
    printf("%06d lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]\n", i,
           lambda, feed_value.flat<float>()(0), feed_value.flat<float>()(1),
           y.flat<float>()(0), y.flat<float>()(1));
    // Copy y_normalized into the next feed value for x.
    feed_value = y_normalized;
  }
  EXPECT_NEAR(2.0, lambda, 1e-6);
}

TEST(FindMaxEigenTest, RemoteDevice) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  FindMaxEigen(cluster->targets()[0]);
}

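// Assigns the node named `name` in `graph` to device `dev`; LOG(FATAL)s if
// no node with that name exists.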
void SetDevice(GraphDef* graph, const string& name, const string& dev) {
  for (int i = 0; i < graph->node_size(); ++i) {
    if (graph->node(i).name() == name) {
      graph->mutable_node(i)->set_device(dev);
      return;
    }
  }
  LOG(FATAL) << "Name '" << name << "' not found.";
}

// TODO(b/32636929): This test fails 1/1000 times. Disable it while we
// figure out why.
TEST(GrpcSessionTest, DISABLED_MultiDevices) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  Graph graph(OpRegistry::Global());
  const int kSize = 1048576;

  // c = a * b = 2 * 3 * kSize
  Tensor a_tensor(DT_FLOAT, TensorShape({1, kSize}));
  Tensor b_tensor(DT_FLOAT, TensorShape({kSize, 1}));
  for (int i = 0; i < kSize; ++i) {
    a_tensor.flat<float>()(i) = 2;
    b_tensor.flat<float>()(i) = 3;
  }
  Node* a = test::graph::Constant(&graph, a_tensor);
  Node* b = test::graph::Constant(&graph, b_tensor);
  Node* c = test::graph::Matmul(&graph, a, b, false, false);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // In this test, we force each of the nodes (a, b, c) onto every possible
  // device and test all combinations.
  for (const auto& a_dev : cluster->devices()) {
    for (const auto& b_dev : cluster->devices()) {
      for (const auto& c_dev : cluster->devices()) {
        LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name()
                  << " c: " << c_dev.name();

        SetDevice(&def, a->name(), a_dev.name());
        SetDevice(&def, b->name(), b_dev.name());
        SetDevice(&def, c->name(), c_dev.name());

        std::unique_ptr<Session> session(
            NewRemote(Options(cluster->targets()[0], 1000)));
        ASSERT_TRUE(session != nullptr);
        TF_CHECK_OK(session->Create(def));
        {
          std::vector<Tensor> outputs;
          RunOptions options;
          options.set_trace_level(RunOptions::FULL_TRACE);
          RunMetadata metadata;
          TF_CHECK_OK(
              session->Run(options, {}, {c->name()}, {}, &outputs, &metadata));
          ASSERT_EQ(1, outputs.size());
          IsSingleFloatValue(outputs[0], 6.0 * kSize);

          const StepStats& ss = metadata.step_stats();
          // NOTE(mrry): We only assert that `c` is placed correctly,
          // because the current placement algorithm will move its
          // inputs to be colocated with it when it is the sole
          // consumer.
          bool c_placed_correctly = false;
          for (const auto& dev : ss.dev_stats()) {
            for (const auto& node : dev.node_stats()) {
              if (node.node_name() == c->name() &&
                  dev.device() == c_dev.name()) {
                c_placed_correctly = true;
              }
            }
          }
          ASSERT_TRUE(c_placed_correctly);
        }
        TF_CHECK_OK(session->Close());
      }
    }
  }
}

TEST(GrpcSessionTest, LargeTensorSend) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  Graph graph(OpRegistry::Global());

  // Define a 1 GB fill result (1 x 256 x 1024 x 1024 floats).
  Tensor fill_shape_tensor(DT_INT32, TensorShape({4}));
  fill_shape_tensor.vec<int32>()(0) = 1;
  fill_shape_tensor.vec<int32>()(1) = 256;
  fill_shape_tensor.vec<int32>()(2) = 1024;
  fill_shape_tensor.vec<int32>()(3) = 1024;
  Node* fill_shape_node = test::graph::Constant(&graph, fill_shape_tensor);

  Tensor fill_val_tensor(DT_FLOAT, TensorShape({}));
  fill_val_tensor.flat<float>()(0) = 1.0;
  Node* fill_val_node = test::graph::Constant(&graph, fill_val_tensor);

  Node* fill_node =
      test::graph::Binary(&graph, "Fill", fill_shape_node, fill_val_node);

  Tensor max_axes_tensor(DT_INT32, TensorShape({4}));
  max_axes_tensor.vec<int32>()(0) = 0;
  max_axes_tensor.vec<int32>()(1) = 1;
  max_axes_tensor.vec<int32>()(2) = 2;
  max_axes_tensor.vec<int32>()(3) = 3;
  Node* max_axes_node = test::graph::Constant(&graph, max_axes_tensor);
  Node* max_node = test::graph::Reduce(&graph, "Max", fill_node, max_axes_node);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // Place the fill on one worker and the reduction on the other, so the
  // large tensor must be sent between workers.
  SetDevice(&def, fill_node->name(), cluster->devices()[0].name());
  SetDevice(&def, max_node->name(), cluster->devices()[1].name());

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1000)));
  ASSERT_TRUE(session != nullptr);
  TF_CHECK_OK(session->Create(def));
  {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({}, {max_node->name()}, {}, &outputs));
    ASSERT_EQ(1, outputs.size());
    IsSingleFloatValue(outputs[0], 1.0);
  }
  TF_CHECK_OK(session->Close());
}

TEST(GrpcSessionTest, MultiDevices_String) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1000)));
  ASSERT_TRUE(session != nullptr);

  // b = a
  Graph graph(OpRegistry::Global());
  Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
  for (int i = 0; i < 4; ++i) {
    a_tensor.flat<tstring>()(i) = "hello, world";
  }
  Node* a = test::graph::Constant(&graph, a_tensor);
  Node* b = test::graph::Identity(&graph, a);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // In this test, we force each of the nodes (a, b) onto every possible
  // device and test all combinations.
  for (const auto& a_dev : cluster->devices()) {
    for (const auto& b_dev : cluster->devices()) {
      LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name();
      SetDevice(&def, a->name(), a_dev.name());
      SetDevice(&def, b->name(), b_dev.name());

      Status s = session->Create(def);
      if (s.ok()) {
        std::vector<Tensor> outputs;
        TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs));
        ASSERT_EQ(1, outputs.size());
        ASSERT_EQ(outputs[0].dtype(), DT_STRING);
        ASSERT_EQ(outputs[0].NumElements(), 4);
        for (int i = 0; i < outputs[0].NumElements(); ++i) {
          EXPECT_EQ(outputs[0].flat<tstring>()(i), "hello, world");
        }
        TF_CHECK_OK(session->Close());
      } else {
        LOG(ERROR) << "Error: " << s;
        ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
                    (b_dev.device_type() == DEVICE_GPU));
        ASSERT_FALSE(s.ok());
      }
    }
  }
}

TEST(GrpcSessionTest, SendRecv_Node_Naming) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 3, &cluster));
  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  ASSERT_TRUE(session != nullptr);

  // This test case needs at least 3 devices.
  CHECK_GE(cluster->devices().size(), 3);
  const DeviceAttributes& src = cluster->devices()[0];
  const DeviceAttributes& dst0 = cluster->devices()[1];
  const DeviceAttributes& dst1 = cluster->devices()[2];
  LOG(INFO) << "src = " << src.name() << " dst0 = " << dst0.name()
            << " dst1 = " << dst1.name();

  // Within the same session, we compute two subgraphs:
  //   1) a on 'src' sends to b on 'dst0';
  //   2) a on 'src' sends to c on 'dst1'.
  Graph graph(OpRegistry::Global());
  Tensor a_tensor(DT_FLOAT, TensorShape({1, 1}));
  a_tensor.flat<float>()(0) = 100;
  Node* a = test::graph::Constant(&graph, a_tensor);
  Node* b = test::graph::Identity(&graph, a);
  Node* c = test::graph::Identity(&graph, a);

  GraphDef def;
  test::graph::ToGraphDef(&graph, &def);

  // The base graph has a, b, and c assigned to devices explicitly.
  SetDevice(&def, a->name(), src.name());
  SetDevice(&def, b->name(), dst0.name());
  SetDevice(&def, c->name(), dst1.name());
  TF_CHECK_OK(session->Create(def));

  // Run subgraph a -> b, and fetch b.
  {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs));
    ASSERT_EQ(1, outputs.size());
    IsSingleFloatValue(outputs[0], 100);
  }

  // Run subgraph a -> c, and fetch c.
  {
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
    ASSERT_EQ(1, outputs.size());
    IsSingleFloatValue(outputs[0], 100);
  }

  TF_CHECK_OK(session->Close());
}

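// The error-handling tests below rely on two test ops: test::graph::Error,
// which fails with the given message when executed, and test::graph::Delay,
// which sleeps for the given duration before forwarding its input.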
TEST(GrpcSessionTest, Error) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  const string& master = cluster->targets()[0];
  const string& dev_a = cluster->devices()[0].name();
  const string& dev_b = cluster->devices()[1].name();
  LOG(INFO) << "master " << master << " dev_a " << dev_a << " dev_b " << dev_b;
  GraphDef gdef;
  std::vector<string> fetches;
  {
    Graph g(OpRegistry::Global());

    // a2 = a + error(a)
    //
    // The subgraph for "a" fails. The master will cancel the subgraph for
    // "b" and then return from Session::Run.
    auto a = test::graph::Constant(&g, Tensor());
    a->set_assigned_device_name(dev_a);
    auto a_err = test::graph::Error(&g, a, "fantasia!");
    a_err->set_assigned_device_name(dev_a);
    auto a2 = test::graph::Add(&g, a, a_err);
    a2->set_assigned_device_name(dev_a);
    fetches.push_back(a2->name());

    // b2 = b + delay(b)
    //
    // The subgraph for "b" sleeps at the node "b_delay". When the sleep
    // finishes, the subgraph "b" will continue execution until it
    // notices that it is canceled. Meanwhile, the subgraph's executor
    // and its related state (registered ops) should still be alive.
    auto b = test::graph::Constant(&g, Tensor());
    b->set_assigned_device_name(dev_b);
    auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
    b_delay->set_assigned_device_name(dev_b);
    auto b2 = test::graph::Add(&g, b, b_delay);
    b2->set_assigned_device_name(dev_b);
    fetches.push_back(b2->name());
    test::graph::ToGraphDef(&g, &gdef);
  }
  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    Status status = session->Run({}, fetches, {}, nullptr);
    EXPECT_FALSE(status.ok());
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
  }
  // session->Close() shall clean up all state related to the session,
  // e.g., deregistering subgraphs with workers.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

TEST(GrpcSessionTest, ErrorStatusLog) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  const string& master = cluster->targets()[0];
  const string& dev_a = cluster->devices()[0].name();
  const string& dev_b = cluster->devices()[1].name();
  LOG(INFO) << "master " << master << " dev_a " << dev_a << " dev_b " << dev_b;
  GraphDef gdef;
  std::vector<string> fetches;
  {
    Graph g(OpRegistry::Global());

    // a2 = a + error(a)
    //
    // The subgraph for "a" fails. The master will cancel the subgraph for
    // "b" and then return from Session::Run.
    auto a = test::graph::Constant(&g, Tensor());
    a->set_assigned_device_name(dev_a);
    auto a_err = test::graph::Error(&g, a, "fantasia!", true);
    a_err->set_assigned_device_name(dev_a);
    auto a2 = test::graph::Add(&g, a, a_err);
    a2->set_assigned_device_name(dev_a);
    fetches.push_back(a2->name());

    // b2 = b + delay(b)
    //
    // The subgraph for "b" sleeps at the node "b_delay". When the sleep
    // finishes, the subgraph "b" will continue execution until it
    // notices that it is canceled. Meanwhile, the subgraph's executor
    // and its related state (registered ops) should still be alive.
    auto b = test::graph::Constant(&g, Tensor());
    b->set_assigned_device_name(dev_b);
    auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
    b_delay->set_assigned_device_name(dev_b);
    auto b2 = test::graph::Add(&g, b, b_delay);
    b2->set_assigned_device_name(dev_b);
    fetches.push_back(b2->name());
    g.ToGraphDef(&gdef);
  }
  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    Status status = session->Run({}, fetches, {}, nullptr);
    EXPECT_FALSE(status.ok());
    std::cerr << status << "\n";
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
    EXPECT_NE(status.ToString().find("ErrorOp: fantasia!"), string::npos);
  }
  // session->Close() shall clean up all state related to the session,
  // e.g., deregistering subgraphs with workers.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

TEST(GrpcSessionTest, LongErrorMessage) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  const string& master = cluster->targets()[0];
  const string& dev_a = cluster->devices()[0].name();
  const string& dev_b = cluster->devices()[1].name();
  LOG(INFO) << "master " << master << " dev_a " << dev_a << " dev_b " << dev_b;
  GraphDef gdef;
  std::vector<string> fetches;
  {
    Graph g(OpRegistry::Global());

    // a2 = a + error(a)
    //
    // The subgraph for "a" fails. The master will cancel the subgraph for
    // "b" and then return from Session::Run.
    auto a = test::graph::Constant(&g, Tensor());
    a->set_assigned_device_name(dev_a);
    std::vector<char> long_string_buffer(1024 * 1024, 'x');
    StringPiece long_string(long_string_buffer.data(), 1024 * 1024);
    string name = strings::StrCat(long_string, "fantasia!");
    auto a_err = test::graph::Error(&g, a, name);
    a_err->set_assigned_device_name(dev_a);
    auto a2 = test::graph::Add(&g, a, a_err);
    a2->set_assigned_device_name(dev_a);
    fetches.push_back(a2->name());

    // b2 = b + delay(b)
    //
    // The subgraph for "b" sleeps at the node "b_delay". When the sleep
    // finishes, the subgraph "b" will continue execution until it
    // notices that it is canceled. Meanwhile, the subgraph's executor
    // and its related state (registered ops) should still be alive.
    auto b = test::graph::Constant(&g, Tensor());
    b->set_assigned_device_name(dev_b);
    auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
    b_delay->set_assigned_device_name(dev_b);
    auto b2 = test::graph::Add(&g, b, b_delay);
    b2->set_assigned_device_name(dev_b);
    fetches.push_back(b2->name());
    test::graph::ToGraphDef(&g, &gdef);
  }
  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    Status status = session->Run({}, fetches, {}, nullptr);
    EXPECT_FALSE(status.ok());
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
  }
  // session->Close() shall clean up all state related to the session,
  // e.g., deregistering subgraphs with workers.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

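// Verifies that variable state created through one session persists on the
// worker and is visible to later sessions connecting to the same cluster.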
TEST(SessionTest, SharedVar) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
  const string master = cluster->targets()[0];
  CHECK_EQ(cluster->devices().size(), 1);

  GraphDef gdef;
  string init_name;
  string inc_name;
  string get_name;
  {
    Graph g(OpRegistry::Global());
    Tensor one(DT_FLOAT, TensorShape({}));
    one.scalar<float>()() = 1.0;
    Node* var = test::graph::Var(&g, DT_FLOAT, one.shape());
    Node* init = test::graph::Assign(&g, var, test::graph::Constant(&g, one));
    init_name = init->name();
    Node* update = test::graph::Assign(
        &g, var, test::graph::Add(&g, var, test::graph::Constant(&g, one)));
    inc_name = update->name();
    get_name = var->name();
    test::graph::ToGraphDef(&g, &gdef);
  }

  // Initialize the variable.
  {
    Session* sess = NewRemote(Options(master, 1));
    TF_CHECK_OK(sess->Create(gdef));
    std::vector<std::pair<string, Tensor>> inp;
    TF_CHECK_OK(sess->Run(inp, {}, {init_name}, nullptr));
    TF_CHECK_OK(sess->Close());
    delete sess;
  }

  for (int rep = 1; rep < 10; ++rep) {
    // Update the variable.
    {
      Session* sess = NewRemote(Options(master, 1));
      TF_CHECK_OK(sess->Create(gdef));
      std::vector<std::pair<string, Tensor>> inp;
      TF_CHECK_OK(sess->Run(inp, {}, {inc_name}, nullptr));
      TF_CHECK_OK(sess->Close());
      delete sess;
    }

    // Get the variable's value.
    {
      Session* sess = NewRemote(Options(master, 1));
      TF_CHECK_OK(sess->Create(gdef));
      std::vector<std::pair<string, Tensor>> inp;
      std::vector<Tensor> ret;
      TF_CHECK_OK(sess->Run(inp, {get_name}, {}, &ret));
      ASSERT_EQ(ret.size(), 1);
      EXPECT_EQ(ret[0].scalar<float>()(), 1.0 * (1 + rep));
      TF_CHECK_OK(sess->Close());
      delete sess;
    }
  }
}

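// Parses `graph_def_ascii` as a text-format GraphDef, attempts to create a
// remote session with it, and expects the resulting error to contain
// `error_substring`.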
void CreateInvalidGraph(const string& graph_def_ascii,
                        const string& error_substring) {
  GraphDef graph;
  CHECK(protobuf::TextFormat::ParseFromString(graph_def_ascii, &graph));

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  Status s = session->Create(graph);

  ASSERT_FALSE(s.ok());
  EXPECT_NE(s.error_message().find(error_substring), string::npos);
}

TEST(SessionTest, InvalidOpName) {
  CreateInvalidGraph(R"(
    node {
      name: 'a:b' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                     "Illegal op name");

  CreateInvalidGraph(R"(
    node {
      name: 'a:0' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                     "Illegal op name");

  CreateInvalidGraph(R"(
    node {
      name: '_a' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                     "Illegal op name");
}

TEST(SessionTest, InvalidOpInputName) {
  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'a:first' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
    }
  )",
                     "Illegal op input name");

  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'_a' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
    }
  )",
                     "Illegal op input name");

  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'_a:0' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
    }
  )",
                     "Illegal op input name");

  CreateInvalidGraph(R"(
    node {
      name: 'a' op: 'const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
    node {
      name:'b' op:'MatMul' input:'a:01' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
    }
  )",
                     "Illegal op input name");
}

TEST(SessionTest, ExtendValidation) {
  GraphDef graph;
  bool success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name: 'a' op: 'Const'
      attr { key: 'dtype' value { type: DT_FLOAT } }
      attr { key: 'value' value {
        tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
                 float_val: [100] }
      } }
    }
  )",
                                                       &graph);
  // NOTE(mrry): CHECK not done inline to avoid a compilation error in
  // open-source (due to a multi-line string in a macro argument).
  ASSERT_TRUE(success);

  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));

  std::unique_ptr<Session> session(
      NewRemote(Options(cluster->targets()[0], 1)));
  TF_CHECK_OK(session->Create(graph));

  // 1. Fail with an unknown input name.
  GraphDef extension;
  success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name:'b' op:'MatMul' input:'a:first' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
    }
  )",
                                                  &extension);
  ASSERT_TRUE(success);

  Status s = session->Extend(extension);
  ASSERT_FALSE(s.ok());
  EXPECT_NE(s.error_message().find("Illegal op input name"), string::npos);

  // 2. Succeed with a valid node.
  success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name:'b' op:'MatMul' input:'a' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
    }
  )",
                                                  &extension);
  ASSERT_TRUE(success);
  TF_CHECK_OK(session->Extend(extension));

  // 3. Fail with a duplicate node.
  success = protobuf::TextFormat::ParseFromString(R"(
    node {
      name:'b' op:'MatMul' input:'a' input:'a'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'transpose_a' value { b: false } }
      attr { key: 'transpose_b' value { b: false } }
    }
  )",
                                                  &extension);
  ASSERT_TRUE(success);
  s = session->Extend(extension);
  ASSERT_FALSE(s.ok());
  EXPECT_NE(s.error_message().find("'b', which was created by a previous call"),
            string::npos);
}

// Tests that Create() with "operation_timeout_in_ms" set times out.
TEST(SessionTest, CreateTimeoutWithSessionOptions) {
  // Creates a RemoteSession with "operation_timeout_in_ms" set to 100.
  SessionOptions options = Options("example.org:2222", 1);
  options.config.set_operation_timeout_in_ms(100);
  std::unique_ptr<Session> session(NewRemote(options));

  // Creates a long running op.
  Graph graph(OpRegistry::Global());
  Node* b = test::graph::Constant(&graph, Tensor());
  test::graph::Delay(&graph, b, Microseconds(1000000));
  GraphDef gdef;
  test::graph::ToGraphDef(&graph, &gdef);
  Status status = session->Create(gdef);
  // Either error is possible, depending on the environment.
  EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
              error::UNAVAILABLE == status.code());
}

// Tests that Create() with "timeout_in_ms" in RunOptions set times out.
TEST(SessionTest, CreateTimeoutWithRunOptions) {
  SessionOptions options = Options("example.org:2222", 1);
  std::unique_ptr<Session> session(NewRemote(options));

  // Creates a long running op.
  Graph graph(OpRegistry::Global());
  Node* b = test::graph::Constant(&graph, Tensor());
  test::graph::Delay(&graph, b, Microseconds(1000000));
  GraphDef gdef;
  test::graph::ToGraphDef(&graph, &gdef);
  RunOptions run_options;
  // Sets the RunOptions timeout_in_ms to 20.
  run_options.set_timeout_in_ms(20);
  Status status = session->Create(run_options, gdef);
  // Either error is possible, depending on the environment.
  EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
              error::UNAVAILABLE == status.code());
}

// Tests that Run() with "operation_timeout_in_ms" set times out.
TEST(SessionTest, RunTimeoutWithSessionOptions) {
  // Creates a RemoteSession with "operation_timeout_in_ms" set to 1.
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
  SessionOptions options = Options(cluster->targets()[0], 100);
  options.config.set_operation_timeout_in_ms(1);
  std::unique_ptr<Session> session(NewRemote(options));

  // Creates a long running op.
  Graph graph(OpRegistry::Global());
  Node* b = test::graph::Constant(&graph, Tensor());
  Node* b_delay = test::graph::Delay(&graph, b, Microseconds(2000000));
  GraphDef gdef;
  test::graph::ToGraphDef(&graph, &gdef);
  RunOptions run_options;
  TF_CHECK_OK(session->Create(run_options, gdef));

  // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED.
  std::vector<std::pair<string, Tensor>> inputs;
  Status status = session->Run(inputs, {}, {b_delay->name()}, nullptr);
  // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get
  // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL.
  EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
              error::INTERNAL == status.code());
}

// Tests that Run() with "timeout_in_ms" set times out.
TEST(SessionTest, RunTimeoutWithRunOptions) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
  SessionOptions options = Options(cluster->targets()[0], 1);
  std::unique_ptr<Session> session(NewRemote(options));

  // Creates a long running op.
  Graph graph(OpRegistry::Global());
  Node* b = test::graph::Constant(&graph, Tensor());
  Node* b_delay = test::graph::Delay(&graph, b, Microseconds(1000000));
  GraphDef gdef;
  test::graph::ToGraphDef(&graph, &gdef);
  TF_CHECK_OK(session->Create(gdef));

  // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED.
  std::vector<std::pair<string, Tensor>> inputs;
  RunOptions run_options;
  run_options.set_timeout_in_ms(100);
  Status status = session->Run(run_options, inputs, {}, {b_delay->name()},
                               nullptr, nullptr);
  // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get
  // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL.
  EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() ||
              error::INTERNAL == status.code());
}

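// Verifies that enabling gRPC compression on the session's RPC channel still
// round-trips a tensor value correctly.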
TEST(SessionTest, TestCompression) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
  SessionOptions options = Options(cluster->targets()[0], 100);
  RPCOptions* rpc_options = options.config.mutable_rpc_options();
  rpc_options->set_compression_algorithm("deflate");
  rpc_options->set_compression_level(GRPC_COMPRESS_LEVEL_HIGH);

  std::unique_ptr<Session> session(NewRemote(options));

  static const float kTestValue = 409.1934f;
  Graph graph(OpRegistry::Global());
  Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
  tensor.flat<float>()(0) = kTestValue;
  Node* b = test::graph::Constant(&graph, tensor);
  GraphDef gdef;
  graph.ToGraphDef(&gdef);
  RunOptions run_options;
  TF_CHECK_OK(session->Create(run_options, gdef));

  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session->Run(inputs, {b->name()}, {}, &outputs));
  ASSERT_EQ(1, outputs.size());
  IsSingleFloatValue(outputs[0], kTestValue);
}

TEST(GrpcSessionTest, ErrorAggregationTwoWorkersTwoErrors) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
  auto& devs = cluster->devices();
  const string& master = cluster->targets()[0];
  // worker 1
  const string w1_dev1 = devs[0].name();
  // worker 2
  const string w2_dev1 = devs[1].name();

  LOG(INFO) << "master " << master << " w1_dev1 " << w1_dev1 << " w2_dev1 "
            << w2_dev1;
  GraphDef gdef;
  std::vector<string> fetches;
  {
    // Set up a graph to test the error handling when two workers both report
    // original errors. The expected behavior is:
    //   1. The master issues a cancel operation upon receiving the first error.
    //   2. The master may receive one or both errors depending on the timing
    //      of the cancel operation.
    //
    // Setup:
    // Two workers; both report their errors to the master without any delay.
    Graph g(OpRegistry::Global());

    // Worker 1. a_err runs on w1_dev1.
    auto a = test::graph::Constant(&g, Tensor(1));
    a->set_assigned_device_name(w1_dev1);

    auto a_err = test::graph::Error(&g, a, "fantasia1!");
    a_err->set_assigned_device_name(w1_dev1);

    fetches.push_back(a_err->name());

    // Worker 2. b_err runs on w2_dev1 and reports its own error.
    auto b = test::graph::Constant(&g, Tensor(1));
    b->set_assigned_device_name(w2_dev1);

    auto b_err = test::graph::Error(&g, b, "fantasia2!");
    b_err->set_assigned_device_name(w2_dev1);

    fetches.push_back(b_err->name());

    g.ToGraphDef(&gdef);
  }

  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    std::vector<Tensor> outputs;
    Status status = session->Run({}, fetches, {}, &outputs);
    LOG(INFO) << status;
    EXPECT_FALSE(status.ok());
    // The status contains the error from either worker 1 or worker 2.
    EXPECT_NE(status.ToString().find("fantasia"), string::npos);
    EXPECT_EQ(status.code(), error::Code::INTERNAL);
  }
  // session->Close() shall clean up all state related to the session,
  // e.g., deregistering subgraphs with workers.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

TEST(GrpcSessionTest, ErrorAggregationTwoWorkerRace) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(2, 0), 2, &cluster));
  auto& devs = cluster->devices();
  const string& master = cluster->targets()[0];
  // worker 1
  const string w1_dev1 = devs[0].name();
  const string w1_dev2 = devs[1].name();
  // worker 2
  const string w2_dev1 = devs[2].name();

  LOG(INFO) << "master " << master << " w1_dev1 " << w1_dev1 << " w1_dev2 "
            << w1_dev2 << " w2_dev1 " << w2_dev1;
  GraphDef gdef;
  std::vector<string> fetches;
  std::vector<string> targets;
  {
    // Set up a graph to test the error handling when a derived error is
    // reported to the master before the original error. The expected behavior
    // is:
    //    1. The original error will be received by the master and reported
    //       to the user as the error status.
    //
    // Setup:
    //
    // Worker 1 generates the original error but delays for 5 seconds before
    // reporting it to the master. Worker 2 detects the error (via Rendezvous)
    // and reports to the master before worker 1.
    Graph g(OpRegistry::Global());

    // Worker 1. a_err runs on w1_dev1 and a_delay runs on w1_dev2.
    auto a = test::graph::Constant(&g, Tensor(1));
    a->set_assigned_device_name(w1_dev1);

    auto a_err = test::graph::Error(&g, a, "fantasia!");
    a_err->set_assigned_device_name(w1_dev1);

    auto a_delay = test::graph::Delay(&g, a, Microseconds(5000000));
    a_delay->set_assigned_device_name(w1_dev2);

    // We need to put a_delay in targets instead of fetches. Putting
    // a_delay in fetches would cause the executor for w1_dev2 to report a
    // failure status as well, as the Rendezvous is already poisoned by the
    // a_err op on w1_dev1.
    targets.push_back(a_delay->name());
    fetches.push_back(a_err->name());

    // Worker 2. b2 depends on a_err and detects the error via the rendezvous
    // from worker 1.
    auto b = test::graph::Constant(&g, Tensor(3));
    b->set_assigned_device_name(w2_dev1);
    auto b2 = test::graph::Add(&g, b, a_err);
    b2->set_assigned_device_name(w2_dev1);
    fetches.push_back(b2->name());

    g.ToGraphDef(&gdef);
  }

  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    std::vector<Tensor> outputs;
    Status status = session->Run({}, fetches, targets, &outputs);
    LOG(INFO) << status;
    EXPECT_FALSE(status.ok());
    // Assert that the status contains the root error.
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
    // Assert that the status does not contain a cancelled error.
    EXPECT_EQ(status.ToString().find("Cancelled"), string::npos);
    EXPECT_EQ(status.code(), error::Code::INTERNAL);
  }
  // session->Close() shall clean up all state related to the session,
  // e.g., deregistering subgraphs with workers.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

TEST(GrpcSessionTest, ErrorAggregationThreeWorkerRaceVariant1) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(2, 0), 3, &cluster));
  auto& devs = cluster->devices();
  const string& master = cluster->targets()[0];
  // worker 1
  const string w1_dev1 = devs[0].name();
  const string w1_dev2 = devs[1].name();
  // worker 2
  const string w2_dev1 = devs[2].name();
  // worker 3
  const string w3_dev1 = devs[4].name();

  LOG(INFO) << "master " << master << " w1_dev1 " << w1_dev1 << " w1_dev2 "
            << w1_dev2 << " w2_dev1 " << w2_dev1 << " w3_dev1 " << w3_dev1;
  GraphDef gdef;
  std::vector<string> fetches;
  std::vector<string> targets;
  {
    // Set up a graph to test the error handling when a derived error is
    // reported to the master before the original error and a third worker is
    // canceled by the master. The expected behavior is that
    //    1. the original error will be received by the master,
    //    2. the canceled error will be treated as a derived error.
    //
    // Setup:
    //
    // Worker 1 generates the original error but delays for 5 seconds before
    // reporting it to the master. Worker 2 detects the error (via Rendezvous)
    // and reports to the master before worker 1. Worker 3 runs a delay op and
    // will be canceled by the master.
    Graph g(OpRegistry::Global());

    // Worker 1. a_err runs on w1_dev1 and a_delay runs on w1_dev2.
    auto a = test::graph::Constant(&g, Tensor(1));
    a->set_assigned_device_name(w1_dev1);

    auto a_err = test::graph::Error(&g, a, "fantasia!");
    a_err->set_assigned_device_name(w1_dev1);

    auto a_delay = test::graph::Delay(&g, a, Microseconds(5000000));
    a_delay->set_assigned_device_name(w1_dev2);

    // Putting a_delay in fetches would cause the executor for w1_dev2 to
    // report a failure status as well due to the use of SendOp, as the
    // Rendezvous is already poisoned by the a_err op on w1_dev1.
    targets.push_back(a_delay->name());
    fetches.push_back(a_err->name());

    // Worker 2. b2 depends on a_err and detects the error via the rendezvous
    // from worker 1.
    auto b = test::graph::Constant(&g, Tensor(3));
    b->set_assigned_device_name(w2_dev1);
    auto b2 = test::graph::Add(&g, b, a_err);
    b2->set_assigned_device_name(w2_dev1);
    fetches.push_back(b2->name());

    // Worker 3. Runs only a delay op. This worker will be canceled by the
    // master when the master receives the root error from worker 1.
    auto c = test::graph::Constant(&g, Tensor(3));
    c->set_assigned_device_name(w3_dev1);
    auto c_delay = test::graph::Delay(&g, c, Microseconds(4000000));
    c_delay->set_assigned_device_name(w3_dev1);

    // Put c_delay in targets so that an implicit SendOp for c_delay to
    // worker 1 is not generated.
    targets.push_back(c_delay->name());

    g.ToGraphDef(&gdef);
  }

  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    std::vector<Tensor> outputs;
    Status status = session->Run({}, fetches, targets, &outputs);
    LOG(INFO) << status;
    EXPECT_FALSE(status.ok());
    // Assert that the status contains the root error.
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
    // Assert that the status does not contain a cancelled or aborted error.
    EXPECT_EQ(status.ToString().find("Cancelled"), string::npos);
    EXPECT_EQ(status.ToString().find("Aborted"), string::npos);
    EXPECT_EQ(status.code(), error::Code::INTERNAL);
  }
  // session->Close() shall clean up all state related to the session,
  // e.g., deregistering subgraphs with workers.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

TEST(GrpcSessionTest, ErrorAggregationThreeWorkerRaceVariant2) {
  std::unique_ptr<test::TestCluster> cluster;
  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(2, 0), 3, &cluster));
  auto& devs = cluster->devices();
  const string& master = cluster->targets()[0];
  // worker 1
  const string w1_dev1 = devs[0].name();
  const string w1_dev2 = devs[1].name();
  // worker 2
  const string w2_dev1 = devs[2].name();
  // worker 3
  const string w3_dev1 = devs[4].name();

  LOG(INFO) << "master " << master << " w1_dev1 " << w1_dev1 << " w1_dev2 "
            << w1_dev2 << " w2_dev1 " << w2_dev1 << " w3_dev1 " << w3_dev1;
  GraphDef gdef;
  std::vector<string> fetches;
  std::vector<string> targets;
  {
    // Set up a graph to test the error handling when a derived error is
    // reported to the master before the original error and a third worker is
    // aborted from an implicit SendOp. The expected behavior is that
    //    1. the original error will be received by the master,
    //    2. the aborted error will be treated as a derived error.
    //
    // Setup:
    //
    // Worker 1 generates the original error but delays for 5 seconds before
    // reporting it to the master. Worker 2 detects the error (via Rendezvous)
    // and reports to the master before worker 1. Worker 3 runs a delay op and
    // an implicit SendOp (for sending tensor c_delay to worker 1) and will be
    // aborted by worker 1.
    Graph g(OpRegistry::Global());

    // Worker 1. a_err runs on w1_dev1 and a_delay runs on w1_dev2.
    auto a = test::graph::Constant(&g, Tensor(1));
    a->set_assigned_device_name(w1_dev1);

    auto a_err = test::graph::Error(&g, a, "fantasia!");
    a_err->set_assigned_device_name(w1_dev1);

    auto a_delay = test::graph::Delay(&g, a, Microseconds(5000000));
    a_delay->set_assigned_device_name(w1_dev2);

    // Putting a_delay in fetches would cause the executor for w1_dev2 to
    // report a failure status as well due to the use of SendOp, as the
    // Rendezvous is already poisoned by the a_err op on w1_dev1.
    targets.push_back(a_delay->name());
    fetches.push_back(a_err->name());

    // Worker 2. b2 depends on a_err and detects the error via the rendezvous
    // from worker 1.
    auto b = test::graph::Constant(&g, Tensor(3));
    b->set_assigned_device_name(w2_dev1);
    auto b2 = test::graph::Add(&g, b, a_err);
    b2->set_assigned_device_name(w2_dev1);
    fetches.push_back(b2->name());

    // Worker 3. Runs a delay op and an implicit SendOp (sending c_delay to
    // worker 1); the SendOp will be aborted when worker 1 fails.
    auto c = test::graph::Constant(&g, Tensor(3));
    c->set_assigned_device_name(w3_dev1);
    auto c_delay = test::graph::Delay(&g, c, Microseconds(4000000));
    c_delay->set_assigned_device_name(w3_dev1);

    // Put c_delay in fetches so that an implicit SendOp for c_delay to worker 1
    // is generated.
    fetches.push_back(c_delay->name());

    g.ToGraphDef(&gdef);
  }

  std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
  ASSERT_TRUE(session != nullptr);

  TF_CHECK_OK(session->Create(gdef));
  {
    std::vector<Tensor> outputs;
    Status status = session->Run({}, fetches, targets, &outputs);
    LOG(INFO) << status;
    EXPECT_FALSE(status.ok());
    // Assert that the status contains the root error.
    EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
    // Assert that the status does not contain a cancelled or aborted error.
    EXPECT_EQ(status.ToString().find("Cancelled"), string::npos);
    EXPECT_EQ(status.ToString().find("Aborted"), string::npos);
    EXPECT_EQ(status.code(), error::Code::INTERNAL);
  }
  // session->Close() shall clean up all state related to the session,
  // e.g., deregistering subgraphs with workers.
  TF_CHECK_OK(session->Close());

  // Sleep a bit so that most of the asynchronous work finishes before
  // the test process exits.
  Env::Default()->SleepForMicroseconds(2000000);
}

}  // namespace tensorflow