• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/direct_session.h"
17 
18 #include <map>
19 #include <memory>
20 #include <random>
21 #include <string>
22 #include <thread>  // NOLINT
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "absl/memory/memory.h"
27 #include "absl/strings/match.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/function_testlib.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_testutil.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/graph/costmodel.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/node_builder.h"
40 #include "tensorflow/core/graph/testlib.h"
41 #include "tensorflow/core/kernels/ops_util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/core/threadpool.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/stacktrace.h"
49 #include "tensorflow/core/platform/test.h"
50 #include "tensorflow/core/platform/test_benchmark.h"
51 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
52 #include "tensorflow/core/public/session.h"
53 #include "tensorflow/core/public/session_options.h"
54 #include "tensorflow/core/util/device_name_utils.h"
55 
56 #if GOOGLE_CUDA
57 #include "third_party/gpus/cuda/include/cuda.h"
58 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
59 #elif TENSORFLOW_USE_ROCM
60 #include "rocm/include/hip/hip_runtime.h"
61 #endif  // GOOGLE_CUDA
62 
63 namespace tensorflow {
64 namespace {
65 
MakeCallableOptions(gtl::ArraySlice<string> feeds,gtl::ArraySlice<string> fetches,gtl::ArraySlice<string> targets)66 CallableOptions MakeCallableOptions(gtl::ArraySlice<string> feeds,
67                                     gtl::ArraySlice<string> fetches,
68                                     gtl::ArraySlice<string> targets) {
69   CallableOptions ret;
70   for (const string& feed : feeds) {
71     ret.add_feed(feed);
72   }
73   for (const string& fetch : fetches) {
74     ret.add_fetch(fetch);
75   }
76   for (const string& target : targets) {
77     ret.add_target(target);
78   }
79   return ret;
80 }
81 
DefaultSessionOptions()82 SessionOptions DefaultSessionOptions() {
83   SessionOptions options;
84   (*options.config.mutable_device_count())["CPU"] = 2;
85   return options;
86 }
87 
CreateSession()88 std::unique_ptr<Session> CreateSession() {
89   return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
90 }
91 
// Fixture that builds a small two-device graph:
//   y = A * x;  y_neg = Neg(y);  z = Identity(y_neg)
// where A is a 2x2 constant on cpu:0 and x is the constant column vector
// [1, 1]^T on cpu:1. Node names are recorded so individual tests can
// build feed/fetch strings such as `y_ + ":0"`.
class DirectSessionMinusAXTest : public ::testing::Test {
 public:
  // Builds the graph and serializes it into `def_`. `a_values` holds the
  // four row-major entries of the constant matrix A.
  void Initialize(std::initializer_list<float> a_values) {
    Graph graph(OpRegistry::Global());

    Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
    test::FillValues<float>(&a_tensor, a_values);
    Node* a = test::graph::Constant(&graph, a_tensor);
    a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    a_ = a->name();

    Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
    test::FillValues<float>(&x_tensor, {1, 1});
    Node* x = test::graph::Constant(&graph, x_tensor);
    // x is deliberately placed on the *second* CPU device so that the
    // tests exercise cross-device execution.
    x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
    x_ = x->name();

    // y = A * x
    Node* y = test::graph::Matmul(&graph, a, x, false, false);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    y_ = y->name();

    Node* y_neg = test::graph::Unary(&graph, "Neg", y);
    y_neg_ = y_neg->name();
    y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    Node* z = test::graph::Unary(&graph, "Identity", y_neg);
    z_ = z->name();
    z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    graph.ToGraphDef(&def_);
  }

  // Node names recorded by Initialize().
  string a_;
  string x_;
  string y_;
  string y_neg_;
  string z_;
  GraphDef def_;  // Serialized graph produced by Initialize().
};
132 
// Runs the A*x graph fetching y and targeting (but not fetching) y_neg,
// then checks the fetched value. With A = [[3,2],[-1,0]] and x = [1,1]^T,
// entry (0, 0) of y is 3 + 2 = 5.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<std::pair<string, Tensor>> feeds;
  std::vector<string> fetches = {y_ + ":0"};
  std::vector<string> targets = {y_neg_};
  std::vector<Tensor> results;
  TF_ASSERT_OK(session->Run(feeds, fetches, targets, &results));

  ASSERT_EQ(1, results.size());
  // The fetched output should be initialized and hold the correct value.
  ASSERT_TRUE(results[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, results[0].matrix<float>()(0, 0));
}
154 
// Exercises the MakeCallable/RunCallable/ReleaseCallable lifecycle and
// its error paths (missing fetch vector, released handle, bogus handle).
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Run the test twice to ensure that the Make/Run/Release cycle is hermetic.
  for (int i = 0; i < 2; ++i) {
    // Request two targets: one fetch output and one non-fetched output.
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(
        MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

    // Fix: the inner loop previously reused `i`, shadowing the outer loop
    // variable; use a distinct name.
    for (int j = 0; j < 2; ++j) {
      std::vector<Tensor> outputs;
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));

      ASSERT_EQ(1, outputs.size());
      // The first output should be initialized and have the correct
      // output.
      auto mat = outputs[0].matrix<float>();
      ASSERT_TRUE(outputs[0].IsInitialized());
      EXPECT_FLOAT_EQ(5.0, mat(0, 0));
    }

    // Passing a null fetch vector is an invalid argument.
    Status s = session->RunCallable(handle, {}, nullptr, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "`fetch_tensors` must be provided"));

    TF_ASSERT_OK(session->ReleaseCallable(handle));

    // Running a released handle is an invalid argument.
    std::vector<Tensor> outputs;
    s = session->RunCallable(handle, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(
        s.error_message(),
        "Attempted to run callable after handle was released"));

    // A handle that was never issued is also rejected.
    s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "No such callable handle"));
  }
}
200 
// Verifies that with optimize_for_static_graph set the graph still runs
// normally, but a subsequent Extend() is rejected with FailedPrecondition.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_OptimizeForStaticGraph) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_optimize_for_static_graph(true);
  auto session = absl::WrapUnique(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;
  Status s = session->Run(inputs, output_names, target_nodes, &outputs);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Extending the graph must fail once the session was created with
  // optimize_for_static_graph.
  s = session->Extend({});
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "optimize_for_static_graph"));
}
230 
// Verifies that with disable_output_partition_graphs set a normal Run()
// succeeds, but a Run() that requests partition graphs via RunOptions is
// rejected with InvalidArgument.
TEST_F(DirectSessionMinusAXTest,
       RunSimpleNetwork_DisableOutputPartitionGraphs) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_disable_output_partition_graphs(
      true);
  auto session = absl::WrapUnique(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;
  Status s = session->Run(inputs, output_names, target_nodes, &outputs);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // The Run() call should fail when `output_partition_graphs` is set to true.
  RunOptions run_options;
  run_options.set_output_partition_graphs(true);
  RunMetadata run_metadata;
  s = session->Run(run_options, inputs, output_names, target_nodes, &outputs,
                   &run_metadata);

  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "disable_output_partition_graphs"));
}
268 
// Verifies that a callable made before Finalize() keeps working after
// finalization, while making a *new* callable afterwards fails with
// FailedPrecondition.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithCallables) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // The callable is usable after finalization.
  for (int i = 0; i < 2; ++i) {
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));

    ASSERT_EQ(1, outputs.size());
    // The first output should be initialized and have the correct
    // output.
    auto mat = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(5.0, mat(0, 0));
  }

  TF_ASSERT_OK(session->ReleaseCallable(handle));

  // Making a new callable fails because the session has been finalized.
  Status s =
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle);
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Session has been finalized."));
}
305 
// Verifies that after Finalize(), re-running a previously-run subgraph
// succeeds, but running a *different* subgraph fails with
// FailedPrecondition.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithRun) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &outputs));

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // Running the exact same subgraph succeeds after finalization.
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &outputs));
  ASSERT_EQ(1, outputs.size());
  mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Running a different subgraph fails because the session has been finalized.
  Status s = session->Run({}, {y_ + ":0"}, {}, &outputs);
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Session has been finalized."));
}
338 
// Exercises CallableOptions::tensor_connection, which rewires one tensor's
// output to stand in for another's: two success cases (with and without
// also fetching the source tensor) followed by six error cases (cycle,
// unknown source node, unknown source edge, unknown destination node, and
// two double-feed variants).
TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);".
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    ASSERT_EQ(1, outputs.size());
    // The fetch is -a, i.e. the negation of {3, 2, -1, 0}.
    auto mat = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(-3.0, mat(0, 0));
    EXPECT_FLOAT_EQ(-2.0, mat(0, 1));
    EXPECT_FLOAT_EQ(1.0, mat(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);"; also fetch the result of a.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(a_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    ASSERT_EQ(2, outputs.size());
    // First fetch: a itself, unchanged.
    auto mat_a = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(3.0, mat_a(0, 0));
    EXPECT_FLOAT_EQ(2.0, mat_a(0, 1));
    EXPECT_FLOAT_EQ(-1.0, mat_a(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat_a(1, 1));

    // Second fetch: Neg(a), because a was substituted for y.
    auto mat_y_neg = outputs[1].matrix<float>();
    ASSERT_TRUE(outputs[1].IsInitialized());
    EXPECT_FLOAT_EQ(-3.0, mat_y_neg(0, 0));
    EXPECT_FLOAT_EQ(-2.0, mat_y_neg(0, 1));
    EXPECT_FLOAT_EQ(1.0, mat_y_neg(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat_y_neg(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Wire the output of "Neg(Matmul(a, x))" to the output of "a",
    // creating an invalid cycle.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(y_ + ":0");
    c->set_to_tensor(a_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "would create a cycle"));
  }

  {
    // Attempt to wire a non-existent node to a node that does exist.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor("unknown_node:0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown node"));
  }

  {
    // Attempt to wire a non-existent output from a node that does
    // exist to another node.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":17");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown edge"));
  }

  {
    // Attempt to wire a tensor to a node that doesn't exist.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor("unknown_node:0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsNotFound(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "unable to find feed output"));
  }

  {
    // Attempt to wire two tensors to the same tensor.
    CallableOptions callable_options;
    TensorConnection* c1 = callable_options.add_tensor_connection();
    c1->set_from_tensor(a_ + ":0");
    c1->set_to_tensor(y_neg_ + ":0");
    TensorConnection* c2 = callable_options.add_tensor_connection();
    c2->set_from_tensor(x_ + ":0");
    c2->set_to_tensor(y_neg_ + ":0");
    callable_options.add_fetch(z_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }

  {
    // Attempt to wire a tensor to a tensor that is also being fed.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_feed(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }
}
490 
// Feeds a value for x and checks y = A * x with A = [[1,2],[3,4]].
TEST_F(DirectSessionMinusAXTest, TestFeed) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output.
  //
  // Note that the input being fed is on the second device.
  Tensor fed(DT_FLOAT, TensorShape({2, 1}));
  fed.matrix<float>()(0, 0) = 5;
  fed.matrix<float>()(1, 0) = 6;
  std::vector<std::pair<string, Tensor>> feeds = {{x_, fed}};
  std::vector<string> fetches = {y_ + ":0"};
  std::vector<Tensor> outputs;

  // Run the graph.
  TF_ASSERT_OK(session->Run(feeds, fetches, {}, &outputs));

  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, mat(0, 0));
  EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
519 
// Same as TestFeed, but via the MakeCallable/RunCallable API.
TEST_F(DirectSessionMinusAXTest, TestFeed_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output.
  //
  // Note that the input being fed is on the second device.
  // Fix: removed a dead local `CallableOptions callable_options` that was
  // populated but never passed to MakeCallable.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(MakeCallableOptions({x_}, {y_ + ":0"}, {}),
                                     &handle));
  Tensor t(DT_FLOAT, TensorShape({2, 1}));
  t.matrix<float>()(0, 0) = 5;
  t.matrix<float>()(1, 0) = 6;
  std::vector<Tensor> inputs = {t};
  std::vector<Tensor> outputs;

  // Run the callable.
  TF_ASSERT_OK(session->RunCallable(handle, inputs, &outputs, nullptr));

  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, mat(0, 0));
  EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
552 
// Runs the same fetch 1000 times from each of 4 threads concurrently, to
// check that Session::Run() is safe to call from multiple threads.
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output.
  // Fix: own the pool with unique_ptr instead of raw new/delete, so it is
  // destroyed (and its threads joined) even if the test exits early.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph.
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish: the ThreadPool destructor blocks
  // until all scheduled work is done.
  tp.reset();
}
584 
// Same as TestConcurrency, but via a shared callable handle: 4 threads
// each invoke RunCallable() 1000 times concurrently.
TEST_F(DirectSessionMinusAXTest, TestConcurrency_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output.
  // Fix: own the pool with unique_ptr instead of raw new/delete, so it is
  // destroyed (and its threads joined) even if the test exits early.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  Session::CallableHandle handle;
  TF_ASSERT_OK(
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle));

  // Run the callable 1000 times in 4 different threads concurrently.
  auto fn = [&session, handle]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<Tensor> outputs;
      // Run the graph.
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish: the ThreadPool destructor blocks
  // until all scheduled work is done.
  tp.reset();
}
617 
// Same concurrency test, but with use_per_session_threads enabled so the
// session uses its own thread pool instead of the global one.
TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
  Initialize({1, 2, 3, 4});

  SessionOptions options;
  options.config.set_use_per_session_threads(true);
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output.
  // Fix: own the pool with unique_ptr instead of raw new/delete, so it is
  // destroyed (and its threads joined) even if the test exits early.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph.
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish: the ThreadPool destructor blocks
  // until all scheduled work is done.
  tp.reset();
}
654 
// A session accepts exactly one Create(); a second call must fail.
TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  // The first Create() succeeds.
  TF_ASSERT_OK(session->Create(def_));

  // The second one is rejected.
  ASSERT_FALSE(session->Create(def_).ok());
}
664 
// Calling Run() before Create() must fail.
TEST_F(DirectSessionMinusAXTest, ForgetToCreate) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  std::vector<std::pair<string, Tensor>> feeds;
  std::vector<Tensor> outputs;
  Status s = session->Run(feeds, {y_ + ":0"}, {y_neg_}, &outputs);
  ASSERT_FALSE(s.ok());
}
673 
// Verifies that Create() rejects a graph whose node is assigned to a
// device that does not exist (cpu:2 when only two CPUs are configured),
// and that the graph runs once the placement is corrected.
TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
  GraphDef def;
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  a_tensor.flat<float>().setRandom();
  Node* a = test::graph::Constant(&graph, a_tensor);
  a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  x_tensor.flat<float>().setRandom();
  Node* x = test::graph::Constant(&graph, x_tensor);
  x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
  // Skip placing y.
  // cpu:2 is invalid: only two CPU devices are configured below.
  Node* y = test::graph::Matmul(&graph, a, x, false, false);
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");

  graph.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  // Should return an error.
  ASSERT_FALSE(session->Create(def).ok());

  // Fix placement and run again
  def.Clear();
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
  graph.ToGraphDef(&def);
  session.reset(NewSession(options));
  TF_ASSERT_OK(session->Create(def));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {y->name() + ":0"}, {}, &outputs));
}
708 
// Runs the simple network with software tracing enabled and checks that
// RunMetadata comes back populated with per-device step stats (one entry
// for each of the two CPU devices).
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::SOFTWARE_TRACE);
  RunMetadata run_metadata;
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, &run_metadata);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Checks RunMetadata is well-formed
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
742 
// Same as RunSimpleNetworkWithOpts, but the trace level is set on the
// callable's embedded RunOptions and the metadata is returned through
// RunCallable().
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  Session::CallableHandle handle;
  CallableOptions callable_options =
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_});
  callable_options.mutable_run_options()->set_trace_level(
      RunOptions::SOFTWARE_TRACE);
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));

  RunMetadata run_metadata;
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, &run_metadata));

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Checks RunMetadata is well-formed
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
774 
// Runs the simple network with the experimental run-handler thread pool
// enabled and checks the result is unchanged.
TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
802 
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
  // Build a graph with a variable, a constant full of 20s, and an assign op,
  // all pinned to the local CPU device.
  GraphDef def;
  Graph g(OpRegistry::Global());
  const char* const kDevice = "/job:localhost/replica:0/task:0/cpu:0";

  Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
  var->set_assigned_device_name(kDevice);

  Tensor twenty(DT_FLOAT, TensorShape({10}));
  auto twenty_flat = twenty.flat<float>();
  for (int idx = 0; idx < 10; ++idx) {
    twenty_flat(idx) = 20.0;
  }

  Node* twenty_node = test::graph::Constant(&g, twenty);
  twenty_node->set_assigned_device_name(kDevice);

  Node* init = test::graph::Assign(&g, var, twenty_node);
  init->set_assigned_device_name(kDevice);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;

  // Run the assign once to initialize the variable...
  TF_ASSERT_OK(session->Run(inputs, {init->name()}, {}, &outputs));

  // ...then read it back in a second Run() of the same session: the state
  // must have been kept between the two calls.
  TF_ASSERT_OK(session->Run(inputs, {var->name() + ":0"}, {}, &outputs));
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_EQ(20.0, outputs[0].flat<float>()(0));
}
841 
// Feeds given in different orders must be mapped to the correct placeholders,
// and feeding the same tensor twice in one Run() must fail.
TEST(DirectSessionTest, MultipleFeedTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  // first_identity forwards a constant 1.0 scalar.
  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  // second_identity forwards a constant 2.0 scalar.
  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding.
  Status s = session->Run(
      {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Fetch in reverse order; outputs must follow the fetch order.
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: reversing the feed order must not
  // change which placeholder receives which value.
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: a duplicate feed is InvalidArgument.
  s = session->Run(
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
914 
// Same feed-ordering coverage as MultipleFeedTest, but through the
// MakeCallable/RunCallable API; a duplicate feed must fail at MakeCallable.
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  // Two independent constant -> identity chains.
  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  Session::CallableHandle handle;
  std::vector<Tensor> outputs;

  // Fetch without feeding.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {first_identity->name() + ":0", second_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Fetch in reverse order; outputs must follow the fetch order.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {second_identity->name() + ":0", first_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), second_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_11, value_22}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: feed values are positional, so the
  // fetched results must still line up with the right placeholders.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {second_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_22, value_11}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feeds are rejected when the
  // callable is created, not when it runs.
  Status s = session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
997 
// Connecting one producer tensor to two different consumer tensors in a
// callable must be allowed: both consumers then read from the producer.
TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {1.0, 2.0, 3.0, 4.0});
  Node* a = test::graph::Constant(&graph, a_tensor);

  // Placeholder-like constants; their dummy values are never observed because
  // both are rewired to read from `a` below.
  Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
  test::FillValues<float>(&dummy_tensor, {-1.0});

  Node* left = test::graph::Constant(&graph, dummy_tensor);
  Node* right = test::graph::Constant(&graph, dummy_tensor);

  // y = left + right
  Node* y = test::graph::Add(&graph, left, right);

  GraphDef def;
  graph.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  CallableOptions callable_options;
  // Directly wire the output of node a to the outputs of nodes left
  // and right, making the callable graph into "a + a;".
  TensorConnection* c_left = callable_options.add_tensor_connection();
  c_left->set_from_tensor(a->name() + ":0");
  c_left->set_to_tensor(left->name() + ":0");
  TensorConnection* c_right = callable_options.add_tensor_connection();
  c_right->set_from_tensor(a->name() + ":0");
  c_right->set_to_tensor(right->name() + ":0");

  callable_options.add_fetch(y->name() + ":0");

  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  // Result is a + a, i.e. each element of {1,2,3,4} doubled.
  EXPECT_FLOAT_EQ(2.0, mat(0, 0));
  EXPECT_FLOAT_EQ(4.0, mat(0, 1));
  EXPECT_FLOAT_EQ(6.0, mat(1, 0));
  EXPECT_FLOAT_EQ(8.0, mat(1, 1));
  TF_ASSERT_OK(session->ReleaseCallable(handle));
}
1046 
// Fetching the same tensor endpoint twice in one Run() must yield two
// independent, initialized outputs with the same value.
TEST(DirectSessionTest, FetchMultipleTimes) {
  Graph g(OpRegistry::Global());
  Tensor seven_tensor(DT_INT32, TensorShape());
  seven_tensor.flat<int32>()(0) = 7;
  Node* seven_node = test::graph::Constant(&g, seven_tensor);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  const std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;

  // Request the same node twice in a single step.
  auto seven = seven_node->name();
  Status s = session->Run(inputs, {seven, seven}, {}, &outputs);
  TF_ASSERT_OK(s);

  EXPECT_EQ(2, outputs.size());
  // size_t index avoids the signed/unsigned comparison against
  // outputs.size() that the previous `int` loop variable produced.
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Tensor& t = outputs[i];
    ASSERT_TRUE(t.IsInitialized()) << i;
    EXPECT_EQ(7, t.flat<int32>()(0)) << i;
  }
}
1074 
// Same as MultipleFeedTest, but some Run() calls use the synchronous
// executor (inter_op_thread_pool = -1) to cover that code path too.
TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
  GraphDef def;
  Graph g(OpRegistry::Global());
  RunOptions run_options;
  // -1 selects the caller-thread (synchronous) executor.
  run_options.set_inter_op_thread_pool(-1);

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding (synchronous run).
  Status s = session->Run(
      run_options, {},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Fetch in reverse order (default thread pool).
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: the synchronous run must also reject a
  // duplicate feed with InvalidArgument.
  s = session->Run(
      run_options,
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
1151 
// Test-only op: ignores its int64 input and returns the DebugString of the
// session metadata visible to the kernel (empty string if none).
REGISTER_OP("SessionMetadataReader")
    .Input("x: int64")
    .Output("y: string")
    .SetIsStateful()
    .Doc(R"doc(SessionMetadataReader returns the session metadata.

x: int64
y: string
)doc");
1161 
1162 class SessionMetadataReaderOp : public OpKernel {
1163  public:
SessionMetadataReaderOp(OpKernelConstruction * ctx)1164   explicit SessionMetadataReaderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)1165   void Compute(OpKernelContext* ctx) override {
1166     Tensor* out_tensor = nullptr;
1167     OP_REQUIRES_OK(ctx,
1168                    ctx->allocate_output("y", TensorShape({}), &out_tensor));
1169     if (ctx->session_metadata() != nullptr) {
1170       out_tensor->scalar<tstring>()() = ctx->session_metadata()->DebugString();
1171     } else {
1172       out_tensor->scalar<tstring>()() = "";
1173     }
1174   }
1175 };
1176 REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_CPU),
1177                         SessionMetadataReaderOp);
1178 REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_GPU),
1179                         SessionMetadataReaderOp);
1180 
// Wraps SessionMetadataReader in a one-node function so the tests can also
// read the metadata through the function-library execution path.
FunctionDef SessionMetadataReaderOpFn() {
  return FunctionDefHelper::Define(
      // Name
      "SessionMetadataReaderFn",
      // Args
      {"x: int64"},
      // Return values
      {"y: string"},
      // Attr def
      {},
      // Nodes
      {{{"y"}, "SessionMetadataReader", {"x"}, {}}});
}
1194 
// Without session metadata in the options, SessionMetadataReader sees a null
// metadata pointer and must return the empty string.
TEST(DirectSessionTest, SessionMetadataAbsent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Check the status before dereferencing outputs[0]; previously `s` was
  // ignored, which would read an empty `outputs` vector on failure.
  TF_ASSERT_OK(s);

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1212 
// Same as SessionMetadataAbsent, but reading the metadata from inside a
// function-library function.
TEST(DirectSessionTest, SessionMetadataAbsentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Check the status before dereferencing outputs[0]; previously `s` was
  // ignored, which would read an empty `outputs` vector on failure.
  TF_ASSERT_OK(s);

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1234 
// With metadata set in SessionOptions, SessionMetadataReader must observe it
// and echo it back as a parseable SessionMetadata text proto.
TEST(DirectSessionTest, SessionMetadataPresent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Check the status before dereferencing outputs[0]; previously `s` was
  // ignored, which would read an empty `outputs` vector on failure.
  TF_ASSERT_OK(s);

  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
}
1261 
// Same as SessionMetadataPresent, but reading the metadata from inside a
// function-library function.
TEST(DirectSessionTest, SessionMetadataPresentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Check the status before dereferencing outputs[0]; previously `s` was
  // ignored, which would read an empty `outputs` vector on failure.
  TF_ASSERT_OK(s);

  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
}
1292 
// Session metadata (name, version) acts as a uniqueness key: two live
// sessions may not share a key, but the key becomes reusable once the first
// session is destroyed. Sessions without metadata are unrestricted.
TEST(DirectSessionTest, SessionMetadataKey) {
  auto session_options0 = DefaultSessionOptions();
  auto* session_metadata0 = session_options0.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata0->set_name("name");
  Session* sess0_ptr;
  ASSERT_TRUE(NewSession(session_options0, &sess0_ptr).ok());
  auto sess0 = absl::WrapUnique(sess0_ptr);

  // Trying to use the same metadata (name, version) will cause an error.
  Session* dup_ptr;
  EXPECT_TRUE(
      errors::IsInvalidArgument(NewSession(session_options0, &dup_ptr)));

  // A new (name, version) is fine.
  auto session_options1 = DefaultSessionOptions();
  auto* session_metadata1 = session_options1.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata1->set_name("name");
  session_metadata1->set_version(1);
  Session* sess1_ptr;
  EXPECT_TRUE(NewSession(session_options1, &sess1_ptr).ok());
  auto sess1 = absl::WrapUnique(sess1_ptr);

  // If the previous session, using the same (name, version) is gone, then it's
  // fine.
  sess0 = nullptr;  // Destroys the first session, releasing its key.
  EXPECT_TRUE(NewSession(session_options0, &dup_ptr).ok());
  auto dup = absl::WrapUnique(dup_ptr);

  // Sessions without metadata options are always fine.
  auto sess_without_metadata0 = CreateSession();
  EXPECT_NE(sess_without_metadata0, nullptr);
  auto sess_without_metadata1 = CreateSession();
  EXPECT_NE(sess_without_metadata1, nullptr);
}
1329 
// A negative session-metadata version must be rejected at NewSession time.
TEST(DirectSessionTest, SessionMetadataInvalid) {
  // Baseline: default options create a session successfully.
  const auto valid_session_options = DefaultSessionOptions();
  Session* sess_ptr;
  ASSERT_TRUE(NewSession(valid_session_options, &sess_ptr).ok());
  auto sess = absl::WrapUnique(sess_ptr);

  // Now take the same options but attach metadata with version -1, which is
  // outside the valid (non-negative) range.
  auto invalid_session_options = valid_session_options;
  auto* invalid_metadata = invalid_session_options.config.mutable_experimental()
                               ->mutable_session_metadata();
  invalid_metadata->set_name("name");
  invalid_metadata->set_version(-1);
  Session* error_sess_ptr;
  const Status creation_status =
      NewSession(invalid_session_options, &error_sess_ptr);
  EXPECT_TRUE(errors::IsInvalidArgument(creation_status));
}
1347 
// Test-only op: returns (as int64) a hash of the thread that ran the
// kernel's Compute(); used to verify synchronous, caller-thread execution.
REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc(
ThreadID returns the thread ID that called compute.

x: int64
y: int64
)doc");
1354 
1355 // The ThreadID kernel returns the thread ID that executed Compute.
1356 class ThreadIDOp : public OpKernel {
1357  public:
ThreadIDOp(OpKernelConstruction * ctx)1358   explicit ThreadIDOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)1359   void Compute(OpKernelContext* ctx) override {
1360     Tensor* out_tensor = nullptr;
1361     OP_REQUIRES_OK(ctx,
1362                    ctx->allocate_output("y", TensorShape({}), &out_tensor));
1363     std::hash<std::thread::id> hasher;
1364     out_tensor->scalar<int64>()() =
1365         static_cast<int64>(hasher(std::this_thread::get_id()));
1366   }
1367 };
1368 REGISTER_KERNEL_BUILDER(Name("ThreadID").Device(DEVICE_CPU), ThreadIDOp);
1369 
// With inter_op_thread_pool = -1 in RunOptions, Run() executes synchronously
// on the caller's thread, so ThreadID must report the caller's thread hash.
TEST(DirectSessionTest, SessionSyncRun) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Check the status before dereferencing outputs[0]; previously `s` was
  // ignored, which would read an empty `outputs` vector on failure.
  TF_ASSERT_OK(s);

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())),
            static_cast<int64>(outputs[0].scalar<int64>()()));
}
1389 
1390 REGISTER_OP("ExpensiveNoop").SetIsStateful();
1391 
1392 class ExpensiveNoopOp : public OpKernel {
1393  public:
1394   using OpKernel::OpKernel;
IsExpensive()1395   bool IsExpensive() override { return true; }
Compute(OpKernelContext * ctx)1396   void Compute(OpKernelContext* ctx) override {
1397     const string& stack_trace = tensorflow::CurrentStackTrace();
1398     const string process_method = "ExecutorState::Process()";
1399     size_t pos = 0;
1400     int frame_count = 0;
1401     while ((pos = stack_trace.find("ExecutorState::Process()", pos)) !=
1402            string::npos) {
1403       ++frame_count;
1404       ++pos;
1405     }
1406     OP_REQUIRES(ctx, frame_count <= 1,
1407                 errors::Internal(
1408                     "Recursive call to ExecutorState::Process() detected."));
1409   }
1410 };
1411 
1412 REGISTER_KERNEL_BUILDER(Name("ExpensiveNoop").Device(DEVICE_CPU),
1413                         ExpensiveNoopOp);
1414 
// Runs a small diamond of "expensive" no-op kernels with the synchronous
// (inter_op_thread_pool = -1) executor; the kernels themselves verify that
// expensive nodes do not recursively re-enter ExecutorState::Process().
TEST(DirectSessionTest, SessionSyncRun_DeepGraph) {
  Graph g(OpRegistry::Global());

  // Builds one ExpensiveNoop node with the given control dependencies.
  auto make_expensive_noop = [&g](gtl::ArraySlice<Node*> control_deps) {
    Node* ret;
    auto builder = NodeBuilder(g.NewName("N"), "ExpensiveNoop");
    for (Node* control_dep : control_deps) {
      builder = builder.ControlInput(control_dep);
    }
    TF_CHECK_OK(builder.Finalize(&g, &ret));
    return ret;
  };

  // base fans out to two children that are run as targets below.
  // (Removed an unused `nodes` vector that was reserved but never populated.)
  Node* base = make_expensive_noop({});

  Node* child_1 = make_expensive_noop({base});
  Node* child_2 = make_expensive_noop({base});

  GraphDef def;
  g.ToGraphDef(&def);

  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);

  EXPECT_TRUE(sess->Run(run_opts, {}, {}, {child_1->name(), child_2->name()},
                        &outputs, nullptr)
                  .ok());
}
1449 
// A session created with inter_op_parallelism_threads = -1 executes every
// Run() synchronously on the calling thread, even with default RunOptions.
TEST(DirectSessionTest, SyncSession) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  SessionOptions options;
  options.config.set_inter_op_parallelism_threads(-1);
  std::unique_ptr<Session> sess(NewSession(options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Check the status before dereferencing outputs[0]; previously `s` was
  // ignored, which would read an empty `outputs` vector on failure.
  TF_ASSERT_OK(s);

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())),
            static_cast<int64>(outputs[0].scalar<int64>()()));
}
1470 
// Test-only op whose kernel deliberately breaks its contract (see DarthOp).
REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
Darth promises one return value.

x: float
y: float
)doc");

// The DarthOp kernel violates its promise to return one-value.
class DarthOp : public OpKernel {
 public:
  explicit DarthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  // Intentionally produces no output: the executor should report an internal
  // error because the promised output "y" was never allocated.
  void Compute(OpKernelContext* ctx) override {}
};
REGISTER_KERNEL_BUILDER(Name("Darth").Device(DEVICE_CPU), DarthOp);
1485 
// Running a kernel that fails to produce its promised output must surface an
// internal error from Run().
TEST(DirectSessionTest, DarthKernel) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_FLOAT, TensorShape({}));
  vx.scalar<float>()() = 1.0;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "Darth", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  const Status run_status = sess->Run({}, {y->name() + ":0"}, {}, &outputs);
  EXPECT_TRUE(errors::IsInternal(run_status));
}
1500 
1501 // Have the Darth op in the graph placed on GPU, but don't run it.
TEST(DirectSessionTest, PlacePrunedGraph) {
  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    // "Darth" is only registered for CPU, so placing it on GPU must fail.
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    // By default, we place the entire graph, so we should fail the
    // call to Create.
    SessionOptions options;
    std::unique_ptr<Session> sess(NewSession(options));
    auto s = sess->Create(def);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
  }

  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    SessionOptions options;
    // Set the option to place pruned graphs, we should expect this
    // to run.
    options.config.mutable_graph_options()->set_place_pruned_graph(true);
    std::unique_ptr<Session> sess(NewSession(options));
    TF_ASSERT_OK(sess->Create(def));
    std::vector<Tensor> outputs;
    // Only fetch x: the unplaceable Darth node is pruned before placement.
    auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs);
    TF_EXPECT_OK(s);
  }
}
1542 
// Exercises PRunSetup/PRun: feeds arrive over multiple partial-run steps, and
// a later fetch may depend on feeds supplied in earlier steps.
TEST(DirectSessionTest, PartialRunTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  // third = first_identity + second_identity.
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Declare all feeds and fetches of the partial run up front.
  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed first_const, fetch first_identity
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));

  // Feed second_const, fetch second_identity and third_identity
  s = session->PRun(
      handle, {{second_const->name(), value_22}},
      {second_identity->name() + ":0", third_identity->name() + ":0"},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(22.0, outputs[0].flat<float>()(0));
  // third_identity sees the feed from the FIRST partial-run step too.
  ASSERT_EQ(11.0 + 22.0, outputs[1].flat<float>()(0));
}
1598 
TEST(DirectSessionTest, PartialRunMissingFeed) {
  // Builds add(identity(a), identity(b)) -> identity, then checks that
  // fetching the sum after feeding only `a` fails with InvalidArgument.
  Graph g(OpRegistry::Global());

  Tensor a_val(DT_FLOAT, TensorShape({}));
  a_val.scalar<float>()() = 1.0;
  Node* a = test::graph::Constant(&g, a_val);
  Node* a_identity = test::graph::Identity(&g, a);

  Tensor b_val(DT_FLOAT, TensorShape({}));
  b_val.scalar<float>()() = 2.0;
  Node* b = test::graph::Constant(&g, b_val);
  Node* b_identity = test::graph::Identity(&g, b);

  Node* sum = test::graph::Add(&g, a_identity, b_identity);
  Node* sum_identity = test::graph::Identity(&g, sum);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Declare both constants as feeds; the sum is the only fetch.
  string handle;
  Status s = session->PRunSetup({a->name(), b->name()},
                                {sum_identity->name() + ":0"}, {}, &handle);
  TF_ASSERT_OK(s);

  // Feed only the first constant; the fetch is not computable yet.
  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  std::vector<Tensor> outputs;
  s = session->PRun(handle, {{a->name(), value_11}},
                    {sum_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));
}
1638 
TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
  // Feeds a specific output slot (switch:1) of a multi-output node and
  // fetches an identity that depends on it.
  Graph g(OpRegistry::Global());

  Tensor pred(DT_BOOL, TensorShape({}));
  pred.scalar<bool>()() = true;
  Node* pred_const = test::graph::Constant(&g, pred);
  Node* switch_node = test::graph::Switch(&g, pred_const, pred_const);
  // Identity wired to the switch's second output (port 1).
  Node* out_identity = test::graph::Identity(&g, switch_node, 1);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  string handle;
  Status s = session->PRunSetup({switch_node->name() + ":1"},
                                {out_identity->name() + ":0"}, {}, &handle);
  TF_ASSERT_OK(s);

  // Without the declared feed, the fetch cannot be computed.
  std::vector<Tensor> outputs;
  s = session->PRun(handle, {}, {out_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));

  // Feeding switch:1 makes the fetch computable.
  s = session->PRun(handle, {{switch_node->name() + ":1", pred}},
                    {out_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(true, outputs[0].flat<bool>()(0));
}
1675 
TEST(DirectSessionTest, RunHandleTest) {
  // Persists a tensor in the session's tensor store via GetSessionHandleV2,
  // reads it back through GetSessionTensor, then deletes it.
  Graph g(OpRegistry::Global());

  // Producer side: stores identity(1.0) + 2.0 == 3.0 under a handle.
  Tensor one(DT_FLOAT, TensorShape({}));
  one.scalar<float>()() = 1.0;
  Node* const0 = test::graph::Constant(&g, one);
  Node* identity0 = test::graph::Identity(&g, const0);

  Tensor two(DT_FLOAT, TensorShape({}));
  two.scalar<float>()() = 2.0;
  Node* const1 = test::graph::Constant(&g, two);
  Node* sum_node = test::graph::Add(&g, identity0, const1);
  Node* handle_node = test::graph::Unary(&g, "GetSessionHandleV2", sum_node);

  // Consumer side: const2 is fed the handle string; fetch stored + 2.0.
  Tensor handle_value(DT_STRING, TensorShape({}));
  Node* const2 = test::graph::Constant(&g, handle_value);
  Node* get_node = test::graph::GetSessionTensor(&g, const2);
  Node* add_node = test::graph::Add(&g, get_node, const1);

  // Cleanup side: deletes the tensor named by the fed handle.
  Node* delete_node = test::graph::Unary(&g, "DeleteSessionTensor", const2);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Step 1: produce a handle for the stored tensor.
  std::vector<Tensor> handle_outputs;
  Status s =
      session->Run({}, {handle_node->name() + ":0"}, {}, &handle_outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, handle_outputs.size());

  // Convert the ResourceHandle into the string form the feed expects.
  const ResourceHandle& resource_handle =
      handle_outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Step 2: consume the handle; stored 3.0 + 2.0 == 5.0.
  std::vector<Tensor> read_outputs;
  s = session->Run({{const2->name(), string_handle}}, {add_node->name() + ":0"},
                   {}, &read_outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, read_outputs.size());
  ASSERT_EQ(5.0, read_outputs[0].flat<float>()(0));

  // Step 3: delete the stored tensor.
  std::vector<Tensor> delete_outputs;
  s = session->Run({{const2->name(), string_handle}}, {}, {delete_node->name()},
                   &delete_outputs);
  ASSERT_TRUE(s.ok());
}
1728 
TEST(DirectSessionTest, RunHandleTest_Callable) {
  // Exercises the session tensor-store ops (GetSessionHandleV2,
  // GetSessionTensor, DeleteSessionTensor) across three Run() calls.
  // NOTE(review): despite the "_Callable" suffix, this body is identical to
  // RunHandleTest above and never calls MakeCallable/RunCallable --
  // presumably it was meant to exercise the callable API; confirm intent.
  GraphDef def;
  Graph g(OpRegistry::Global());

  // Producer: stores identity(1.0) + 2.0 in the session's tensor store.
  Tensor value0(DT_FLOAT, TensorShape({}));
  value0.scalar<float>()() = 1.0;
  Node* const0 = test::graph::Constant(&g, value0);
  Node* identity0 = test::graph::Identity(&g, const0);

  Tensor value1(DT_FLOAT, TensorShape({}));
  value1.scalar<float>()() = 2.0;
  Node* const1 = test::graph::Constant(&g, value1);
  Node* node3 = test::graph::Add(&g, identity0, const1);
  Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);

  // Consumer: const2 is fed the handle string at run time.
  Tensor value2(DT_STRING, TensorShape({}));
  Node* const2 = test::graph::Constant(&g, value2);
  Node* node5 = test::graph::GetSessionTensor(&g, const2);
  Node* node6 = test::graph::Add(&g, node5, const1);

  Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // First run call: Create a handle.
  std::vector<Tensor> outputs;
  Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs.size());

  // Convert the ResourceHandle into the string form that
  // GetSessionTensor/DeleteSessionTensor expect as a feed.
  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Second run call: Use a handle. Expects stored 3.0 + 2.0 == 5.0.
  std::vector<Tensor> outputs1;
  s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
                   {}, &outputs1);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs1.size());
  ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));

  // Third run call: Delete a handle.
  std::vector<Tensor> outputs2;
  s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
                   &outputs2);
  ASSERT_TRUE(s.ok());
}
1781 
TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
  Graph graph(OpRegistry::Global());

  // A variable `a` and a constant `b`.
  Node* a = test::graph::Var(&graph, DT_FLOAT, {});
  Node* b = test::graph::Constant(&graph, {});

  Tensor zero(DT_FLOAT, {});
  test::FillValues<float>(&zero, {0});

  // a = b
  Node* assign = test::graph::Assign(&graph, a, b);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  // The variable `a` is both fed (feeding replaces it with a constant-like
  // input in the rewritten graph) and the target of the Assign, so the
  // rewritten graph is invalid -- an assignment to a constant. The returned
  // Status of session->Run should flag this as an invalid argument.
  std::vector<Tensor> outputs;
  Status s = session->Run({{a->name(), zero}}, {assign->name()}, {}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
}
1803 
TEST(DirectSessionTest, TimeoutSession) {
  GraphDef graph;
  // Creates a graph with one FIFOQueue and one dequeue op.
  // Dequeueing from the (always-empty) queue blocks indefinitely, which lets
  // the test exercise both timeout mechanisms below.
  protobuf::TextFormat::ParseFromString(R"proto(
    node {
      name: 'fifo_queue'
      op: 'FIFOQueue'
      device: '/device:CPU:0'
      attr {
        key: 'capacity'
        value { i: 10 }
      }
      attr {
        key: 'component_types'
        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'container'
        value { s: '' }
      }
      attr {
        key: 'shapes'
        value { list {} }
      }
      attr {
        key: 'shared_name'
        value { s: '' }
      }
    }
    node {
      name: 'fifo_queue_Dequeue'
      op: 'QueueDequeue'
      input: 'fifo_queue'
      device: '/device:CPU:0'
      attr {
        key: 'component_types'
        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'timeout_ms'
        value { i: -1 }
      }
    }
    versions { producer: 9 }
  )proto", &graph);

  {
    // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
    // This session-level timeout applies to every Run call.
    SessionOptions options;
    (*options.config.mutable_device_count())["CPU"] = 2;
    options.config.set_operation_timeout_in_ms(100);

    std::unique_ptr<Session> session(NewSession(options));
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));

    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
    TF_ASSERT_OK(session->Close());
  }

  {
    // Creates a session with no operation_timeout_in_ms.
    // The per-step timeout supplied via RunOptions must apply instead.
    auto session = CreateSession();
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));
    RunOptions run_options;
    run_options.set_timeout_in_ms(20);
    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"},
                             nullptr, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
    TF_ASSERT_OK(session->Close());
  }
}
1880 
// Accesses the cancellation manager for the step after the step has been
// cancelled.
class CancellationMgrPollingOp : public OpKernel {
 public:
  explicit CancellationMgrPollingOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  // Spins (sleeping 1ms per iteration) until this step's cancellation
  // manager reports cancellation, then signals `notification` so the test
  // can verify the kernel observed the cancellation and ran to completion.
  void Compute(OpKernelContext* ctx) override {
    CancellationManager* cm = ctx->cancellation_manager();
    while (!cm->IsCancelled()) {
      ctx->env()->SleepForMicroseconds(1000);
    }
    notification.Notify();
  }
  // Signaled by Compute() once cancellation is observed; process-wide state
  // shared with the test body.
  static Notification notification;
};
Notification CancellationMgrPollingOp::notification;

// Make the op available on CPU so the text-proto graphs below can use it.
REGISTER_KERNEL_BUILDER(Name("CancellationMgrPollingOp").Device(DEVICE_CPU),
                        CancellationMgrPollingOp);
REGISTER_OP("CancellationMgrPollingOp").Doc("");
1901 
TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
  GraphDef graph;
  // Creates a graph with a single CancellationMgrPollingOp node, which spins
  // until the step is cancelled (the previous comment, mentioning a
  // FIFOQueue, was copy-pasted from the test above).
  protobuf::TextFormat::ParseFromString(R"proto(
    node {
      name: 'cm_polling'
      op: 'CancellationMgrPollingOp'
      device: '/device:CPU:0'
    }
    versions { producer: 9 }
  )proto", &graph);

  // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
  SessionOptions options;
  options.config.set_operation_timeout_in_ms(100);
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(graph));

  // Verifies that the error code is DEADLINE_EXCEEDED.
  Status s = session->Run({}, {}, {"cm_polling"}, nullptr);
  ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());

  // Verify that the op ran to completion: the timeout must have *cancelled*
  // the step (so the polling op saw IsCancelled() and notified) rather than
  // abandoning it mid-flight.
  ASSERT_TRUE(CancellationMgrPollingOp::notification.HasBeenNotified());

  TF_ASSERT_OK(session->Close());
}
1930 
// Shared body for the TestSessionInterOp* tests below. Builds a graph with a
// constant `x` and a blocking op `y` (either the raw "BlockingOp" kernel or
// the function "BlockingOpFn" wrapping it, depending on `use_function_lib`),
// configures two inter-op thread pools (pool 0: default size, pool 1: one
// thread), and verifies that a blocked op saturating the small pool does not
// prevent runs on the large pool or synchronous runs on the caller's thread.
static void TestSessionInterOpThreadsImpl(bool use_function_lib,
                                          bool use_global_pools) {
  using test::function::blocking_op_state;
  using test::function::BlockingOpState;

  FunctionDefLibrary library_graph_def;
  if (use_function_lib) {
    *library_graph_def.add_function() = test::function::BlockingOpFn();
  }

  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  Node* y;
  if (use_function_lib) {
    y = test::graph::Unary(&g, "BlockingOpFn", x);
  } else {
    y = test::graph::Unary(&g, "BlockingOp", x);
  }
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;

  // Create session with two inter-op thread pools.
  SessionOptions options;
  // Turn off optimizations so that the blocking op doesn't get invoked during
  // graph setup.
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  options.config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_constant_folding(RewriterConfig::OFF);
  (*options.config.mutable_device_count())["CPU"] = 2;
  (*options.config.mutable_device_count())["GPU"] = 0;

  // Pool 0 keeps the default thread count; pool 1 has exactly one thread, so
  // a single blocked op saturates it. When `use_global_pools` is set, the
  // pools are named and shared across sessions.
  auto* p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("large pool");
  p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("small pool");
  p->set_num_threads(1);
  const int kSyncPool = -1;  // Run synchronously on the caller's thread.
  const int kLargePool = 0;
  const int kSmallPool = 1;

  std::vector<std::unique_ptr<Session>> sessions;
  if (!use_global_pools) {
    // With per-session pools a single shared session suffices.
    sessions.emplace_back(NewSession(options));
    TF_ASSERT_OK(sessions.back()->Create(def));
  }
  mutex sessions_mu;  // Guards `sessions` against concurrent emplace below.

  std::atomic<int32> num_done(0);
  // Runs session to compute <node>:0 using inter_op thread pool <pool>.
  auto add_session_run_call =
      [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done](
          thread::ThreadPool* tp, Node* node, int inter_op_pool) {
        auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu,
                   inter_op_pool, node, &num_done]() {
          RunOptions run_options;
          run_options.set_inter_op_thread_pool(inter_op_pool);
          std::vector<Tensor> outputs;

          Session* session;
          if (use_global_pools) {
            // Global pools are shared, so make a fresh session per call;
            // moving it into `sessions` keeps it alive past this scope while
            // the raw pointer stays valid.
            std::unique_ptr<Session> s(NewSession(options));
            TF_ASSERT_OK(s->Create(def));
            session = s.get();

            mutex_lock l(sessions_mu);
            sessions.emplace_back(std::move(s));
          } else {
            session = sessions[0].get();
          }

          Status s = session->Run(run_options, {} /* inputs */,
                                  {node->name() + ":0"} /* output_names */, {},
                                  &outputs, nullptr /* run_metadata */);
          TF_CHECK_OK(s);
          ASSERT_EQ(1, outputs.size());
          auto flat = outputs[0].flat<float>();
          EXPECT_FLOAT_EQ(1.2, flat(0));
          num_done.fetch_add(1);
        };
        // A null thread pool means "run inline on the calling thread".
        if (tp != nullptr) {
          tp->Schedule(fn);
        } else {
          fn();
        }
      };

  // For blocking states:
  // - Starts at 0, BlockingOp::Compute will move to 1.
  // - This main thread will wait for 1, then move to 2 when other ops are done.
  //   Moving to 2 unblocks the blocking op, which then moves to state 3.

  // Run the graph once on the non-limited pool.
  thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 1);
  blocking_op_state = new BlockingOpState();
  add_session_run_call(tp1, y, kLargePool);
  blocking_op_state->AwaitState(1);
  blocking_op_state->MoveToState(1, 2);
  blocking_op_state->AwaitState(3);
  blocking_op_state->MoveToState(3, 0);
  delete tp1;  // Deleting the pool joins the scheduled run.
  num_done = 0;

  tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);

  // Launch a session run call. It will not finish until the blocking op is
  // unblocked, because it is using all threads in the small pool.
  add_session_run_call(tp1, y, kSmallPool);

  blocking_op_state->AwaitState(1);  // Wait for the blocking op to Compute.

  // These will block on <BlockingOpState>.
  const int kBlockedThreads = 3;
  for (int i = 0; i < kBlockedThreads; ++i) {
    add_session_run_call(tp1, x, kSmallPool);
  }

  // Launch session calls using the other inter-op pool. These will finish
  // as they are in inter_op pool #2.
  thread::ThreadPool* tp2 = new thread::ThreadPool(Env::Default(), "tp2", 3);
  const int kUnblockedThreads = 4;
  for (int i = 0; i < kUnblockedThreads; ++i) {
    add_session_run_call(tp2, x, kLargePool);
  }
  delete tp2;  // Joins tp2, so all large-pool calls have completed here.
  EXPECT_EQ(kUnblockedThreads, num_done.load());

  // Launch a session call using this thread. This will finish as it runs
  // synchronously in this thread.
  add_session_run_call(nullptr, x, kSyncPool);

  // Unblock the blocked op and wait for the blocked functions to finish.
  blocking_op_state->MoveToState(1, 2);
  delete tp1;

  // Total: unblocked + previously-blocked + the small-pool blocking run + the
  // synchronous run.
  EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1 + 1, num_done.load());
  delete blocking_op_state;
  blocking_op_state = nullptr;
}
2076 
TEST(DirectSessionTest, TestSessionInterOpThreads) {
  // Per-session pools; the blocking op runs as a plain kernel.
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/false,
                                /*use_global_pools=*/false);
}
2081 
TEST(DirectSessionTest, TestSessionInterOpThreadsWithFunctions) {
  // Per-session pools; the blocking op runs wrapped in a function.
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/true,
                                /*use_global_pools=*/false);
}
2086 
TEST(DirectSessionTest, TestSessionInterOpGlobalPools) {
  // Globally named pools shared across sessions; plain blocking kernel.
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/false,
                                /*use_global_pools=*/true);
}
2091 
TEST(DirectSessionTest, TestSessionInterOpGlobalPoolsWithFunctions) {
  // Globally named pools shared across sessions; function-wrapped kernel.
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/true,
                                /*use_global_pools=*/true);
}
2096 
TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
  // A one-constant graph is enough to exercise thread-pool validation.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  (*options.config.mutable_device_count())["CPU"] = 2;

  options.config.add_session_inter_op_thread_pool();

  // Out-of-range pool indices on the Run call must be rejected: only pool 0
  // exists, so both -2 and 1 are invalid.
  {
    std::unique_ptr<Session> session(NewSession(options));
    TF_ASSERT_OK(session->Create(def));
    for (int pool_num : {-2, 1}) {
      RunOptions run_options;
      run_options.set_inter_op_thread_pool(pool_num);
      std::vector<Tensor> outputs;
      Status s = session->Run(run_options, {} /* inputs */,
                              {x->name() + ":0"} /* output_names */, {},
                              &outputs, nullptr /* run_metadata */);
      EXPECT_EQ(
          strings::StrCat("Invalid argument: Invalid inter_op_thread_pool: ",
                          pool_num),
          s.ToString());
    }
  }

  // A globally named pool must be configured consistently: re-creating it
  // with a different num_threads fails at session creation.
  std::vector<std::unique_ptr<Session>> sessions;
  auto* pool_config = options.config.mutable_session_inter_op_thread_pool(0);
  pool_config->set_num_threads(0);
  pool_config->set_global_name("foo");
  sessions.emplace_back(NewSession(options));
  TF_ASSERT_OK(sessions.back()->Create(def));
  sessions.emplace_back(NewSession(options));  // repeat creation, okay.
  TF_ASSERT_OK(sessions.back()->Create(def));
  for (int pass = 0; pass < 2; ++pass) {
    for (int i = 1; i < 128; ++i) {
      pool_config->set_num_threads(i);
      sessions.emplace_back(NewSession(options));
      auto status = sessions.back()->Create(def);
      ASSERT_FALSE(status.ok()) << status;
    }

    // Dropping the sessions does not release the global pool, so the
    // mismatch error persists into the second pass.
    sessions.clear();
  }
}
2152 
TEST(DirectSessionTest, TestDirectSessionRunClose) {
  // After Close(), both Run() and RunCallable() must fail with Cancelled.
  // Graph: a variable plus a single assign of a constant into it.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* init_value = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, init_value);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Initialize the variable.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Sanity-check a read while the session is still open.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Create a callable handle before closing the session.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {}, {var_assign->name()}), &handle));

  TF_ASSERT_OK(session->Close());

  // Plain Run reports the closed session...
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());

  // ...and so does the callable made earlier.
  s = session->RunCallable(handle, {}, {}, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
2198 
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
  // A pending partial run must fail with Cancelled once the session closes.
  Graph g(OpRegistry::Global());

  Tensor a_val(DT_FLOAT, TensorShape({}));
  a_val.scalar<float>()() = 1.0;
  Node* a = test::graph::Constant(&g, a_val);
  Node* a_identity = test::graph::Identity(&g, a);

  Tensor b_val(DT_FLOAT, TensorShape({}));
  b_val.scalar<float>()() = 2.0;
  Node* b = test::graph::Constant(&g, b_val);
  Node* b_identity = test::graph::Identity(&g, b);

  Node* sum = test::graph::Add(&g, a_identity, b_identity);
  Node* sum_identity = test::graph::Identity(&g, sum);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Set up a partial run feeding both constants and fetching all three
  // identities.
  string handle;
  Status s = session->PRunSetup(
      {a->name(), b->name()},
      {a_identity->name() + ":0", b_identity->name() + ":0",
       sum_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Close the session before the partial run is continued.
  TF_ASSERT_OK(session->Close());

  // Continuing the partial run now reports the closed session.
  std::vector<Tensor> outputs;
  s = session->PRun(handle, {{a->name(), value_11}},
                    {a_identity->name() + ":0"}, &outputs);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
2245 
TEST(DirectSessionTest, TestDirectSessionReset) {
  // Reset() clears session state; subsequent Run() calls fail as closed.
  // Graph: a variable plus a single assign of a constant into it.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* init_value = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, init_value);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Initialize the variable, then verify it reads back correctly.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Reset the containers.
  TF_EXPECT_OK(Reset(options, {}));

  // TODO(suharshs): This test only works because we close the Session in
  // Reset. If we change the behavior of Reset to not close the Session, this
  // test will fail, since the Variable buffer is cached by var.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
2285 
TEST(DirectSessionTest, LocalDeviceManager) {
  // A freshly created session must expose a non-empty local device set.
  std::unique_ptr<Session> session(NewSession(SessionOptions()));

  const DeviceMgr* device_mgr = nullptr;
  TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
  ASSERT_TRUE(device_mgr != nullptr);
  EXPECT_GT(device_mgr->ListDevices().size(), 0);
}
2295 
2296 // y = tf.square(x)
CreateGraphForYEqualsXSquared()2297 GraphDef CreateGraphForYEqualsXSquared() {
2298   GraphDef graph_def;
2299   const char* text_proto = R"EOF(
2300 node {
2301   name: "x"
2302   op: "Placeholder"
2303   attr { key: "dtype" value { type: DT_FLOAT } }
2304   attr { key: "shape" value { shape { unknown_rank: true } } }
2305 }
2306 node {
2307   name: "y"
2308   op: "Square"
2309   input: "x"
2310   attr { key: "T" value { type: DT_FLOAT } }
2311 }
2312 versions {
2313   producer: 26
2314 }
2315   )EOF";
2316 
2317   QCHECK(protobuf::TextFormat::ParseFromString(text_proto, &graph_def));
2318   return graph_def;
2319 }
2320 
// NOTE(review): this comment previously described "a graph that consumes and
// produces string tensors (which are not GPU-compatible)", but no such graph
// is defined here -- it appears misplaced. The helpers below support the
// device-memory feed/fetch tests instead.
// Returns true iff the backing buffer of `t` resides in GPU device memory
// (CUDA or ROCm build); always false in CPU-only builds.
bool IsCUDATensor(const Tensor& t) {
#ifdef GOOGLE_CUDA
  cudaPointerAttributes attributes;
  cudaError_t err =
      cudaPointerGetAttributes(&attributes, t.tensor_data().data());
  // cudaErrorInvalidValue: the pointer is unknown to the CUDA runtime,
  // i.e. ordinary host memory.
  if (err == cudaErrorInvalidValue) return false;
  CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  return (attributes.type == cudaMemoryTypeDevice);
#elif TENSORFLOW_USE_ROCM
  hipPointerAttribute_t attributes;
  hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
  // hipErrorInvalidValue: pointer not managed by HIP, i.e. host memory.
  if (err == hipErrorInvalidValue) return false;
  CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
  return (attributes.memoryType == hipMemoryTypeDevice);
#else
  return false;
#endif
}
2342 
GPUDeviceName(Session * session)2343 string GPUDeviceName(Session* session) {
2344   std::vector<DeviceAttributes> devices;
2345   TF_CHECK_OK(session->ListDevices(&devices));
2346   for (const DeviceAttributes& d : devices) {
2347     if (d.device_type() == "GPU" || d.device_type() == "gpu") {
2348       return d.name();
2349     }
2350   }
2351   return "";
2352 }
2353 
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory) {
  // Round-trips a tensor through GPU memory via callables: the first callable
  // fetches y into device memory; the second feeds that device tensor back
  // in and fetches on host. Skipped when no GPU is present.
  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateGraphForYEqualsXSquared()));

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  Tensor gpu_tensor;  // Holds the device-resident fetch from the first call.

  {
    // Callable 1: feed x from host, fetch y into GPU device memory.
    Session::CallableHandle feed_cpu_fetch_gpu;
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_cpu_fetch_gpu));
    Tensor input(DT_FLOAT, {});
    input.scalar<float>()() = 2.0f;
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(
        session->RunCallable(feed_cpu_fetch_gpu, {input}, &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_cpu_fetch_gpu));
    ASSERT_EQ(1, outputs.size());
    gpu_tensor = outputs[0];
    ASSERT_TRUE(IsCUDATensor(gpu_tensor));
  }

  {
    // Callable 2: feed the GPU-resident tensor back in, fetch y on host.
    // Note `opts` is reused: fetch_devices cleared, feed_devices added.
    Session::CallableHandle feed_gpu_fetch_cpu;
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_gpu_fetch_cpu));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(feed_gpu_fetch_cpu, {gpu_tensor},
                                      &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_gpu_fetch_cpu));
    ASSERT_EQ(1, outputs.size());
    // The output is in CPU/host memory, so it can be dereferenced.
    // (2^2)^2 == 16: the first callable squared 2.0, this one squares 4.0.
    ASSERT_EQ(16.0, outputs[0].scalar<float>()());
  }
}
2400 
CreateIdentityGraphDef(DataType dtype)2401 GraphDef CreateIdentityGraphDef(DataType dtype) {
2402   GraphDef def;
2403 
2404   AttrValue dtype_attr;
2405   dtype_attr.set_type(dtype);
2406 
2407   AttrValue shape_attr;
2408   shape_attr.mutable_shape()->set_unknown_rank(true);
2409 
2410   auto* placeholder = def.add_node();
2411   placeholder->set_name("x");
2412   placeholder->set_op("Placeholder");
2413   placeholder->mutable_attr()->insert({"dtype", dtype_attr});
2414   placeholder->mutable_attr()->insert({"shape", shape_attr});
2415 
2416   auto* identity = def.add_node();
2417   identity->set_name("y");
2418   identity->set_op("Identity");
2419   identity->add_input("x");
2420   identity->mutable_attr()->insert({"T", dtype_attr});
2421 
2422   return def;
2423 }
2424 
// Round-trips a 3-element tensor of type `dtype` through an identity graph:
// first fetches "y:0" directly into GPU device memory, then feeds that
// device-resident tensor back and checks the host output's bytes match the
// originally fed host tensor. Skips (with a log message) when no GPU exists.
void TestFeedAndFetchTensorsInDeviceMemory(
    const SessionOptions& session_options, DataType dtype) {
  std::unique_ptr<Session> session(NewSession(session_options));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
      << DataType_Name(dtype);

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  Tensor gpu_tensor;
  // NOTE(review): host_tensor's buffer is never explicitly initialized; the
  // byte comparison below checks round-trip fidelity of whatever bytes the
  // allocator provided.
  Tensor host_tensor(dtype, {3});
  {
    // Ask for the fetched tensor to be backed by device memory.
    // Even though the kernel that created the tensor produced it in host
    // memory.
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {host_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size()) << DataType_Name(dtype);
    gpu_tensor = outputs[0];
    // Verifies the fetch honored the device placement request.
    ASSERT_TRUE(IsCUDATensor(gpu_tensor)) << DataType_Name(dtype);
  }

  {
    // Feed a tensor backed by device memory, even though the operations in the
    // graph expect it in host memory.
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {gpu_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size());
    // Compare raw bytes: the identity graph must return the fed values
    // unchanged regardless of dtype.
    const StringPiece actual_data = outputs[0].tensor_data();
    const StringPiece expected_data = host_tensor.tensor_data();
    EXPECT_EQ(expected_data.size(), actual_data.size()) << DataType_Name(dtype);
    EXPECT_EQ(0, memcmp(expected_data.data(), actual_data.data(),
                        std::min(expected_data.size(), actual_data.size())))
        << DataType_Name(dtype);
  }
}
2480 
TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(const SessionOptions & session_options,DataType dtype)2481 void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
2482     const SessionOptions& session_options, DataType dtype) {
2483   std::unique_ptr<Session> session(NewSession(session_options));
2484   const string gpu_device_name = GPUDeviceName(session.get());
2485   if (gpu_device_name.empty()) {
2486     LOG(INFO) << "Skipping test since no GPU is available";
2487     return;
2488   }
2489 
2490   TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
2491       << DataType_Name(dtype);
2492 
2493   CallableOptions opts;
2494   opts.add_feed("x:0");
2495   opts.add_fetch("y:0");
2496 
2497   // Fail when asking to fetch into GPU memory.
2498   {
2499     opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
2500     opts.set_fetch_skip_sync(true);
2501     Session::CallableHandle handle;
2502     Status status = session->MakeCallable(opts, &handle);
2503     EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
2504     EXPECT_TRUE(absl::StrContains(
2505         status.error_message(),
2506         strings::StrCat(
2507             "Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
2508             " as feeding/fetching from GPU devices is not yet supported for ",
2509             DataTypeString(dtype), " tensors")))
2510         << DataType_Name(dtype) << ", Status: " << status;
2511   }
2512 
2513   // Fail when feeding from GPU memory.
2514   {
2515     opts.clear_feed_devices();
2516     opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
2517     Session::CallableHandle handle;
2518     Status status = session->MakeCallable(opts, &handle);
2519     EXPECT_FALSE(status.ok());
2520     EXPECT_TRUE(absl::StrContains(
2521         status.error_message(),
2522         strings::StrCat(
2523             "Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
2524             " as feeding/fetching from GPU devices is not yet supported for ",
2525             DataTypeString(dtype), " tensors")))
2526         << DataType_Name(dtype) << ", Status: " << status;
2527   }
2528 }
2529 
TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(const SessionOptions & opts)2530 void TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(
2531     const SessionOptions& opts) {
2532   // Feeding/fetching on device does not work for all DataTypes as it
2533   // relies on the implementation of the _Arg and _Retval kernels which
2534   // are not registered for some types or consume/produce inputs/outputs
2535   // in host memory for some types.
2536   //
2537   // Run through all datatypes to validate that either:
2538   // (a) MakeCallable fails (because the given type cannot be fed/fetched
2539   //     in device memory),
2540   //     OR
2541   // (b) Succeeds: RunCallable should gladly accept inputs in device memory
2542   //     and produce output tensors in device memory.
2543   for (int i = DataType_MIN; i <= DataType_MAX; ++i) {
2544     if (!DataType_IsValid(i)) continue;
2545     const DataType dtype = static_cast<DataType>(i);
2546     switch (dtype) {
2547       case DT_INVALID:
2548         break;
2549       case DT_BFLOAT16:
2550       case DT_BOOL:
2551       case DT_COMPLEX128:
2552       case DT_COMPLEX64:
2553       case DT_DOUBLE:
2554       case DT_FLOAT:
2555       case DT_HALF:
2556       case DT_INT16:
2557       case DT_INT64:
2558       case DT_INT8:
2559       case DT_UINT16:
2560       case DT_UINT8:
2561         TestFeedAndFetchTensorsInDeviceMemory(opts, dtype);
2562         break;
2563       default:
2564         // Ignore all REF types since Tensors of this type aren't intended to
2565         // be fed (and attempting to create one via the Tensor constructor
2566         // will result in a LOG(FATAL)).
2567         if (!IsRefType(dtype)) {
2568           TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(opts, dtype);
2569         }
2570         break;
2571     }
2572   }
2573 }
2574 
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory_AllDataTypes) {
  // Exercise all dtypes with soft placement disabled (strict placement).
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(false);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2580 
TEST(DirectSessionTest,
     FeedAndFetchTensorsInDeviceMemory_AllDataTypes_SoftPlacement) {
  // Same sweep as above, but with soft placement enabled.
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(true);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2587 
2588 // A simple benchmark for the overhead of `DirectSession::Run()` calls
2589 // with varying numbers of feeds/fetches.
// Shared driver for the BM_FeedFetch* benchmarks. Builds `num_feeds`
// independent Placeholder->Identity chains pinned to CPU, then times either
// the name-based Session::Run() path or the RunCallable() path, depending on
// `use_make_callable`. `inter_op_threads` is forwarded to the session config;
// `use_single_threaded_executor` selects the SINGLE_THREADED_EXECUTOR.
void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds,
                              bool use_make_callable, int inter_op_threads,
                              bool use_single_threaded_executor) {
  // A single scalar value fed to every placeholder.
  Tensor value(DT_FLOAT, TensorShape());
  value.flat<float>()(0) = 37.0;

  std::vector<std::pair<string, Tensor>> inputs;
  inputs.reserve(num_feeds);
  std::vector<string> outputs;

  Graph g(OpRegistry::Global());
  for (int i = 0; i < num_feeds; ++i) {
    // NOTE(mrry): We pin nodes to the "/cpu:0" device, so as not to
    // measure CPU<->GPU copying overhead. We should also optimize and
    // monitor this overhead where possible, but that is not the
    // object of study in this benchmark.
    Node* placeholder;
    TF_CHECK_OK(NodeBuilder(g.NewName("Placeholder"), "Placeholder")
                    .Attr("shape", TensorShape())
                    .Attr("dtype", DT_FLOAT)
                    .Device("/cpu:0")
                    .Finalize(&g, &placeholder));
    Node* identity;
    TF_CHECK_OK(NodeBuilder(g.NewName("Identity"), "Identity")
                    .Input(placeholder)
                    .Attr("T", DT_FLOAT)
                    .Device("/cpu:0")
                    .Finalize(&g, &identity));
    // Each placeholder is fed and each identity output is fetched.
    inputs.push_back({placeholder->name() + ":0", value});
    outputs.push_back(identity->name() + ":0");
  }
  GraphDef gd;
  g.ToGraphDef(&gd);
  SessionOptions opts;
  opts.config.set_inter_op_parallelism_threads(inter_op_threads);
  if (use_single_threaded_executor) {
    opts.config.mutable_experimental()->set_executor_type(
        "SINGLE_THREADED_EXECUTOR");
  }
  std::unique_ptr<Session> session(NewSession(opts));
  TF_CHECK_OK(session->Create(gd));
  if (use_make_callable) {
    // RunCallable() path: feeds/fetches are declared once up front and
    // tensors are then passed positionally on every timed iteration.
    Session::CallableHandle handle;
    CallableOptions callable_options;
    std::vector<Tensor> input_tensors;
    for (const auto& input : inputs) {
      callable_options.add_feed(input.first);
      input_tensors.push_back(input.second);
    }
    for (const string& output : outputs) {
      callable_options.add_fetch(output);
    }
    TF_CHECK_OK(session->MakeCallable(callable_options, &handle));

    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(
          session->RunCallable(handle, input_tensors, &output_values, nullptr));
    }
  } else {
    {
      // NOTE(mrry): Ignore the first run, which will incur the graph
      // partitioning/pruning overhead and skew the results.
      //
      // Note that we should also optimize and monitor the overhead on
      // the first run, which will impact application startup times, but
      // that is not the object of study in this benchmark.
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }

    // Name-based Run() path: feeds/fetches are resolved by string on every
    // timed iteration.
    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }
  }
}
2667 
BM_FeedFetch(::testing::benchmark::State & state)2668 void BM_FeedFetch(::testing::benchmark::State& state) {
2669   const int num_feeds = state.range(0);
2670 
2671   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false,
2672                            /* inter_op_threads */ 0,
2673                            /* use_single_threaded_executor */ false);
2674 }
BM_FeedFetchCallable(::testing::benchmark::State & state)2675 void BM_FeedFetchCallable(::testing::benchmark::State& state) {
2676   const int num_feeds = state.range(0);
2677 
2678   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2679                            /* inter_op_threads */ 0,
2680                            /* use_single_threaded_executor */ false);
2681 }
BM_FeedFetchCallableSingleThread(::testing::benchmark::State & state)2682 void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) {
2683   const int num_feeds = state.range(0);
2684 
2685   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2686                            /* inter_op_threads */ -1,
2687                            /* use_single_threaded_executor */ false);
2688 }
BM_FeedFetchCallableSingleThreadExecutor(::testing::benchmark::State & state)2689 void BM_FeedFetchCallableSingleThreadExecutor(
2690     ::testing::benchmark::State& state) {
2691   const int num_feeds = state.range(0);
2692 
2693   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2694                            /* inter_op_threads */ -1,
2695                            /* use_single_threaded_executor */ true);
2696 }
2697 
// Register each feed/fetch benchmark at 1, 2, 5, and 10 feed/fetch pairs.
BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThread)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThreadExecutor)
    ->Arg(1)
    ->Arg(2)
    ->Arg(5)
    ->Arg(10);
2706 
2707 }  // namespace
2708 
// Fixture that builds graphs invoking CollectiveReduce through library
// functions and inspects the collective_graph_key that DirectSession derives
// for them.
class DirectSessionCollectiveTest : public ::testing::Test {
 public:
  // Creates a graph with CollectiveOps inside functions and runs it.  Returns
  // the generated collective_graph_key.
  Status RunGraphWithCollectiveFunctions(bool add_unused_function,
                                         int64* collective_graph_key) {
    GraphDef g = CreateGraph(add_unused_function);
    const Tensor t1 =
        test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
    const Tensor t2 =
        test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
    auto session = CreateSession();
    TF_RETURN_IF_ERROR(session->Create(g));
    std::vector<Tensor> outputs;
    // Run both collective calls so each CPU device participates in the
    // group_size=2 reduction.
    TF_RETURN_IF_ERROR(
        session->Run({{"input0:0", t1}, {"input1:0", t2}}, {},
                     {"collective_call0:0", "collective_call1:0"}, &outputs));
    // Read DirectSession internals under its own lock (access presumably
    // granted via friendship in direct_session.h -- confirm there).
    DirectSession* direct_session = static_cast<DirectSession*>(session.get());
    {
      mutex_lock l(direct_session->collective_graph_key_lock_);
      *collective_graph_key = direct_session->collective_graph_key_;
    }
    return Status::OK();
  }

 private:
  // Creates a function with name `function_name` and a single CollectiveReduce
  // node with instance key set as `instance_key`.
  FunctionDef CollectiveFunction(const string& function_name,
                                 int instance_key) {
    return FunctionDefHelper::Define(
        // Function name
        function_name,
        // In def
        {"arg:float"},
        // Out def
        {"reduce:float"},
        // Attr def
        {},
        // Node def
        {{
            {"reduce"},
            "CollectiveReduce",
            {"arg"},
            {{"group_size", 2},
             {"group_key", 1},
             {"instance_key", instance_key},
             {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
             {"merge_op", "Add"},
             {"final_op", "Div"},
             {"T", DT_FLOAT}},
        }});
  }

  // Returns a float Placeholder NodeDef named "input<id>".
  NodeDef Input(int id) {
    AttrValue dtype_attr;
    SetAttrValue(DT_FLOAT, &dtype_attr);
    NodeDef input;
    input.set_name(strings::StrCat("input", id));
    input.set_op("Placeholder");
    input.mutable_attr()->insert({"dtype", dtype_attr});
    return input;
  }

  // Returns a NodeDef named "collective_call<cpu_id>" that invokes function
  // `op` on `input`, pinned to the CPU device with index `cpu_id`.
  NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) {
    NodeDef collective_call;
    collective_call.set_name(strings::StrCat("collective_call", cpu_id));
    collective_call.set_op(op);
    collective_call.add_input(input);
    collective_call.set_device(
        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id));
    return collective_call;
  }

  // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
  // CPU1, with instance_key 1, and appropriate placeholder inputs.  If
  // `add_unused_function` is true, adds another CollectiveFunction with
  // instance_key 2 that is not invoked in the graph.
  GraphDef CreateGraph(bool add_unused_function) {
    GraphDef g;
    FunctionDef collective_function =
        CollectiveFunction("CollectiveFunction1", 1);
    FunctionDefLibrary* lib = g.mutable_library();
    *lib->add_function() = collective_function;
    if (add_unused_function) {
      FunctionDef unused_function =
          CollectiveFunction("CollectiveFunction2", 2);
      *lib->add_function() = unused_function;
    }

    *g.add_node() = Input(0);
    *g.add_node() = Input(1);
    // CollectiveReduce on CPU0 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0);
    // CollectiveReduce on CPU1 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1);

    return g;
  }
};
2809 
TEST_F(DirectSessionCollectiveTest,
       TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
  // The collective graph key must not change when the function library gains
  // a collective function that the graph never invokes.
  int64 key_without_unused;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key_without_unused));
  int64 key_with_unused;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key_with_unused));
  ASSERT_EQ(key_without_unused, key_with_unused);
}
2818 
// An op that reports, on each required output, how many of its outputs the
// current subgraph requires. Used to test per-subgraph output pruning below.
2821 class StatefulOutputRequiredOp : public OpKernel {
2822  public:
StatefulOutputRequiredOp(OpKernelConstruction * ctx)2823   explicit StatefulOutputRequiredOp(OpKernelConstruction* ctx)
2824       : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)2825   void Compute(OpKernelContext* ctx) override {
2826     // The op counts the number of outputs required in the current subgraph,
2827     // and emits that number on each of its required outputs.
2828     Tensor count_outputs_required_t(int64{0});
2829     int64& count_outputs_required = count_outputs_required_t.scalar<int64>()();
2830     for (int i = 0; i < num_outputs(); ++i) {
2831       if (ctx->output_required(i)) ++count_outputs_required;
2832     }
2833     for (int i = 0; i < num_outputs(); ++i) {
2834       if (ctx->output_required(i)) ctx->set_output(i, count_outputs_required_t);
2835     }
2836   }
2837 };
2838 
// Register the test op: stateful, CPU-only, with `num_outs` (default 5)
// int64 outputs.
REGISTER_KERNEL_BUILDER(Name("StatefulOutputRequired").Device(DEVICE_CPU),
                        StatefulOutputRequiredOp);
REGISTER_OP("StatefulOutputRequired")
    .Output("results : num_outs * int64")
    .Attr("num_outs : int = 5")
    .SetIsStateful();
2845 
TEST(DirectSessionTest, TestStatefulOutputRequiredOp) {
  GraphDef graph;
  // Creates a graph with a StatefulOutputRequired op with 5 outputs.
  protobuf::TextFormat::ParseFromString(
      R"proto(
        node { name: 'n' op: 'StatefulOutputRequired' device: '/device:CPU:0' }
        versions { producer: 9 }
      )proto",
      &graph);

  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(std::move(graph)));

  // As a stateful op, a single StatefulOutputRequired kernel will be created
  // and shared across multiple subgraphs. We create 5 different subgraphs,
  // fetching different prefixes of the output of the op.
  for (int required = 1; required <= 5; ++required) {
    std::vector<string> fetches;
    fetches.reserve(required);
    for (int output_idx = 0; output_idx < required; ++output_idx) {
      fetches.push_back(strings::StrCat("n:", output_idx));
    }
    std::vector<Tensor> results;
    TF_ASSERT_OK(session->Run({}, fetches, {}, &results));
    ASSERT_EQ(required, results.size());
    // Every fetched output reports the number of outputs this subgraph needed.
    for (const Tensor& t : results) {
      ASSERT_EQ(required, t.scalar<int64>()());
    }
  }

  TF_ASSERT_OK(session->Close());
}
2880 
2881 }  // namespace tensorflow
2882