1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/direct_session.h"
17 
18 #include <map>
19 #include <memory>
20 #include <random>
21 #include <string>
22 #include <thread>  // NOLINT
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "absl/memory/memory.h"
27 #include "absl/strings/match.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/function_testlib.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_testutil.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/graph/costmodel.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/node_builder.h"
40 #include "tensorflow/core/graph/testlib.h"
41 #include "tensorflow/core/kernels/ops_util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/core/threadpool.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/stacktrace.h"
49 #include "tensorflow/core/platform/test.h"
50 #include "tensorflow/core/platform/test_benchmark.h"
51 #include "tensorflow/core/protobuf/error_codes.pb.h"
52 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
53 #include "tensorflow/core/public/session.h"
54 #include "tensorflow/core/public/session_options.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56 
57 #if GOOGLE_CUDA
58 #include "third_party/gpus/cuda/include/cuda.h"
59 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
60 #elif TENSORFLOW_USE_ROCM
61 #include "rocm/include/hip/hip_runtime.h"
62 #endif  // GOOGLE_CUDA
63 
64 namespace tensorflow {
65 namespace {
66 
MakeCallableOptions(gtl::ArraySlice<string> feeds,gtl::ArraySlice<string> fetches,gtl::ArraySlice<string> targets)67 CallableOptions MakeCallableOptions(gtl::ArraySlice<string> feeds,
68                                     gtl::ArraySlice<string> fetches,
69                                     gtl::ArraySlice<string> targets) {
70   CallableOptions ret;
71   for (const string& feed : feeds) {
72     ret.add_feed(feed);
73   }
74   for (const string& fetch : fetches) {
75     ret.add_fetch(fetch);
76   }
77   for (const string& target : targets) {
78     ret.add_target(target);
79   }
80   return ret;
81 }
82 
DefaultSessionOptions()83 SessionOptions DefaultSessionOptions() {
84   SessionOptions options;
85   (*options.config.mutable_device_count())["CPU"] = 2;
86   return options;
87 }
88 
CreateSession()89 std::unique_ptr<Session> CreateSession() {
90   return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
91 }
92 
// Test fixture that builds the graph "z = Identity(Neg(A * x))", with nodes
// deliberately assigned across two CPU devices so DirectSession's
// cross-device execution is exercised.
class DirectSessionMinusAXTest : public ::testing::Test {
 public:
  // Builds the GraphDef in `def_`, using `a_values` (row-major) as the
  // contents of the 2x2 constant matrix A.
  void Initialize(std::initializer_list<float> a_values) {
    Graph graph(OpRegistry::Global());

    // A: 2x2 constant, placed on the first CPU device.
    Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
    test::FillValues<float>(&a_tensor, a_values);
    Node* a = test::graph::Constant(&graph, a_tensor);
    a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    a_ = a->name();

    // x: 2x1 constant of ones, placed on the second CPU device.
    Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
    test::FillValues<float>(&x_tensor, {1, 1});
    Node* x = test::graph::Constant(&graph, x_tensor);
    x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
    x_ = x->name();

    // y = A * x
    Node* y = test::graph::Matmul(&graph, a, x, false, false);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    y_ = y->name();

    // y_neg = -y, placed on the second device (forces a cross-device edge).
    Node* y_neg = test::graph::Unary(&graph, "Neg", y);
    y_neg_ = y_neg->name();
    y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    // z = Identity(y_neg).
    Node* z = test::graph::Unary(&graph, "Identity", y_neg);
    z_ = z->name();
    z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    graph.ToGraphDef(&def_);
  }

  // Node names recorded at graph-construction time, for use as feed/fetch
  // names in the tests.
  string a_;
  string x_;
  string y_;
  string y_neg_;
  string z_;
  GraphDef def_;
};
133 
// Runs the A*x graph end-to-end through Session::Run, fetching y:0 and
// additionally targeting y_neg (executed but not returned).
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // One fetched output plus one non-fetched target node.
  std::vector<std::pair<string, Tensor>> no_inputs;
  std::vector<string> fetches = {y_ + ":0"};
  std::vector<string> targets = {y_neg_};
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(no_inputs, fetches, targets, &outputs));

  ASSERT_EQ(1, outputs.size());
  // The single fetched tensor must be initialized and hold the product:
  // y(0, 0) = 3*1 + 2*1 = 5.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
155 
// Exercises the MakeCallable/RunCallable/ReleaseCallable lifecycle, including
// the error paths for a missing output vector, a released handle, and an
// unknown handle.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Run the test twice to ensure that the Make/Run/Release cycle is hermetic.
  for (int i = 0; i < 2; ++i) {
    // Request two targets: one fetch output and one non-fetched output.
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(
        MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

    // Fix: the inner loop previously reused `i`, shadowing the outer loop
    // variable; renamed to `j` to avoid the shadow.
    for (int j = 0; j < 2; ++j) {
      std::vector<Tensor> outputs;
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));

      ASSERT_EQ(1, outputs.size());
      // The first output should be initialized and have the correct
      // output.
      auto mat = outputs[0].matrix<float>();
      ASSERT_TRUE(outputs[0].IsInitialized());
      EXPECT_FLOAT_EQ(5.0, mat(0, 0));
    }

    // Passing a null output vector for a callable with fetches is an error.
    Status s = session->RunCallable(handle, {}, nullptr, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "`fetch_tensors` must be provided"));

    TF_ASSERT_OK(session->ReleaseCallable(handle));

    // Running a released handle is an error.
    std::vector<Tensor> outputs;
    s = session->RunCallable(handle, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(
        s.error_message(),
        "Attempted to run callable after handle was released"));

    // Running a handle that was never created is an error.
    s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "No such callable handle"));
  }
}
201 
// With optimize_for_static_graph set, Run works normally but Extend is
// rejected because the graph is frozen after creation.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_OptimizeForStaticGraph) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_optimize_for_static_graph(true);
  auto session = absl::WrapUnique(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;
  Status s = session->Run(inputs, output_names, target_nodes, &outputs);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Extending a static-graph session must fail with FailedPrecondition and
  // name the option that caused the rejection.
  s = session->Extend({});
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "optimize_for_static_graph"));
}
231 
// With disable_output_partition_graphs set, an ordinary Run succeeds, but a
// Run that requests partition graphs in its RunOptions is rejected.
TEST_F(DirectSessionMinusAXTest,
       RunSimpleNetwork_DisableOutputPartitionGraphs) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_disable_output_partition_graphs(
      true);
  auto session = absl::WrapUnique(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;
  Status s = session->Run(inputs, output_names, target_nodes, &outputs);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // The Run() call should fail when `output_partition_graphs` is set to true.
  RunOptions run_options;
  run_options.set_output_partition_graphs(true);
  RunMetadata run_metadata;
  s = session->Run(run_options, inputs, output_names, target_nodes, &outputs,
                   &run_metadata);

  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "disable_output_partition_graphs"));
}
269 
// Finalize() must leave already-made callables runnable, while making any
// new callable afterwards must fail.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithCallables) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // The callable is usable after finalization.
  for (int i = 0; i < 2; ++i) {
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));

    ASSERT_EQ(1, outputs.size());
    // The first output should be initialized and have the correct
    // output.
    auto mat = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(5.0, mat(0, 0));
  }

  TF_ASSERT_OK(session->ReleaseCallable(handle));

  // Making a new callable fails because the session has been finalized.
  Status s =
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle);
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Session has been finalized."));
}
306 
// After Finalize(), re-running a subgraph that was already executed must
// still succeed, while running a different (not-yet-pruned) subgraph must
// fail.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithRun) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &outputs));

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // Running the exact same subgraph succeeds after finalization.
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &outputs));
  ASSERT_EQ(1, outputs.size());
  mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Running a different subgraph fails because the session has been finalized.
  Status s = session->Run({}, {y_ + ":0"}, {}, &outputs);
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Session has been finalized."));
}
339 
// Exercises CallableOptions::tensor_connection, which rewires one tensor's
// output to stand in for another: two success cases followed by six invalid
// configurations (cycle, unknown nodes/edges, and double-feeding).
TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);".
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    ASSERT_EQ(1, outputs.size());
    // Output is -A elementwise, since y was replaced by a before the Neg.
    auto mat = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(-3.0, mat(0, 0));
    EXPECT_FLOAT_EQ(-2.0, mat(0, 1));
    EXPECT_FLOAT_EQ(1.0, mat(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);"; also fetch the result of a.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(a_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    ASSERT_EQ(2, outputs.size());
    // First fetch: A itself, unchanged.
    auto mat_a = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(3.0, mat_a(0, 0));
    EXPECT_FLOAT_EQ(2.0, mat_a(0, 1));
    EXPECT_FLOAT_EQ(-1.0, mat_a(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat_a(1, 1));

    // Second fetch: Neg(A).
    auto mat_y_neg = outputs[1].matrix<float>();
    ASSERT_TRUE(outputs[1].IsInitialized());
    EXPECT_FLOAT_EQ(-3.0, mat_y_neg(0, 0));
    EXPECT_FLOAT_EQ(-2.0, mat_y_neg(0, 1));
    EXPECT_FLOAT_EQ(1.0, mat_y_neg(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat_y_neg(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Wire the output of "Neg(Matmul(a, x))" to the output of "a",
    // creating an invalid cycle.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(y_ + ":0");
    c->set_to_tensor(a_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "would create a cycle"));
  }

  {
    // Attempt to wire a non-existent node to a node that does exist.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor("unknown_node:0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown node"));
  }

  {
    // Attempt to wire a non-existent output from a node that does
    // exist to another node.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":17");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown edge"));
  }

  {
    // Attempt to wire a tensor to a node that doesn't exist.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor("unknown_node:0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsNotFound(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "unable to find feed output"));
  }

  {
    // Attempt to wire two tensors to the same tensor.
    CallableOptions callable_options;
    TensorConnection* c1 = callable_options.add_tensor_connection();
    c1->set_from_tensor(a_ + ":0");
    c1->set_to_tensor(y_neg_ + ":0");
    TensorConnection* c2 = callable_options.add_tensor_connection();
    c2->set_from_tensor(x_ + ":0");
    c2->set_to_tensor(y_neg_ + ":0");
    callable_options.add_fetch(z_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }

  {
    // Attempt to wire a tensor to a tensor that is also being fed.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_feed(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }
}
491 
// Feeds a value for x and checks that y = A * x is computed from the fed
// value rather than the graph's constant.
TEST_F(DirectSessionMinusAXTest, TestFeed) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output.
  //
  // Note that the input being fed is on the second device.
  Tensor feed_value(DT_FLOAT, TensorShape({2, 1}));
  feed_value.matrix<float>()(0, 0) = 5;
  feed_value.matrix<float>()(1, 0) = 6;

  // Run the graph with the feed in place.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      session->Run({{x_, feed_value}}, {y_ + ":0"}, {}, &outputs));

  ASSERT_EQ(1, outputs.size());
  auto result = outputs[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, result(0, 0));
  EXPECT_FLOAT_EQ(39.0, result(1, 0));
}
520 
// Feeds a value for x through the callable interface and checks y = A * x.
//
// Fix: the original built a local `CallableOptions callable_options` with
// add_feed/add_fetch that was never used — the options passed to
// MakeCallable came from MakeCallableOptions(). The dead locals are removed;
// behavior is unchanged.
TEST_F(DirectSessionMinusAXTest, TestFeed_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  //
  // Note that the input being fed is on the second device.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(MakeCallableOptions({x_}, {y_ + ":0"}, {}),
                                     &handle));
  Tensor t(DT_FLOAT, TensorShape({2, 1}));
  t.matrix<float>()(0, 0) = 5;
  t.matrix<float>()(1, 0) = 6;
  std::vector<Tensor> inputs = {t};
  std::vector<Tensor> outputs;

  // Run the callable
  TF_ASSERT_OK(session->RunCallable(handle, inputs, &outputs, nullptr));

  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, mat(0, 0));
  EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
553 
// Issues 4000 Run calls (4 threads x 1000 iterations) against one shared
// session to check that concurrent Session::Run is safe.
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      // y(0, 0) = 1*1 + 2*1 = 3 for A = [[1,2],[3,4]], x = [1,1]^T.
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish.
  // (Deleting the pool joins all of its worker threads.)
  delete tp;
}
585 
// Same as TestConcurrency, but all threads share a single callable handle to
// check that concurrent RunCallable on one handle is safe.
TEST_F(DirectSessionMinusAXTest, TestConcurrency_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);

  Session::CallableHandle handle;
  TF_ASSERT_OK(
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle));

  // Run the callable 1000 times in 4 different threads concurrently.
  auto fn = [&session, handle]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<Tensor> outputs;
      // Run the graph
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
      ASSERT_EQ(1, outputs.size());
      // y(0, 0) = 1*1 + 2*1 = 3 for A = [[1,2],[3,4]], x = [1,1]^T.
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish.
  // (Deleting the pool joins all of its worker threads.)
  delete tp;
}
618 
// Same concurrency test, but with use_per_session_threads so the session
// owns its own thread pool rather than sharing the process-global one.
TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
  Initialize({1, 2, 3, 4});

  SessionOptions options;
  options.config.set_use_per_session_threads(true);
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      // y(0, 0) = 1*1 + 2*1 = 3 for A = [[1,2],[3,4]], x = [1,1]^T.
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Wait for the functions to finish.
  // (Deleting the pool joins all of its worker threads.)
  delete tp;
}
655 
// A session's graph may be created at most once; a second Create must fail.
TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  // The first Create succeeds.
  TF_ASSERT_OK(session->Create(def_));

  // Second is not.
  const Status second_create = session->Create(def_);
  ASSERT_FALSE(second_create.ok());
}
665 
// Calling Run before Create must fail cleanly rather than crash.
TEST_F(DirectSessionMinusAXTest, ForgetToCreate) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;
  const Status run_status =
      session->Run(inputs, {y_ + ":0"}, {y_neg_}, &outputs);
  ASSERT_FALSE(run_status.ok());
}
674 
// Creating a session with a node assigned to a nonexistent device (cpu:2
// when only two CPUs are configured) must fail; after fixing the placement,
// a fresh session runs the graph normally.
TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
  GraphDef def;
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  a_tensor.flat<float>().setRandom();
  Node* a = test::graph::Constant(&graph, a_tensor);
  a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  x_tensor.flat<float>().setRandom();
  Node* x = test::graph::Constant(&graph, x_tensor);
  x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
  // Skip placing y.
  // cpu:2 is invalid: only two CPU devices are configured below.
  Node* y = test::graph::Matmul(&graph, a, x, false, false);
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");

  graph.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  // Should return an error.
  ASSERT_FALSE(session->Create(def).ok());

  // Fix placement and run again
  def.Clear();
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
  graph.ToGraphDef(&def);
  session.reset(NewSession(options));
  TF_ASSERT_OK(session->Create(def));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {y->name() + ":0"}, {}, &outputs));
}
709 
// Runs with SOFTWARE_TRACE enabled and checks that RunMetadata comes back
// populated with step stats for both CPU devices used by the graph.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::SOFTWARE_TRACE);
  RunMetadata run_metadata;
  // Sanity check: no device stats before the run.
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, &run_metadata);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Checks RunMetadata is well-formed: one dev_stats entry per device.
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
743 
// Same tracing check as RunSimpleNetworkWithOpts, but with the trace level
// set in the callable's embedded RunOptions instead of per-Run options.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  Session::CallableHandle handle;
  CallableOptions callable_options =
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_});
  callable_options.mutable_run_options()->set_trace_level(
      RunOptions::SOFTWARE_TRACE);
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));

  RunMetadata run_metadata;
  // Sanity check: no device stats before the run.
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, &run_metadata));

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Checks RunMetadata is well-formed: one dev_stats entry per device.
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
775 
// Runs the graph with the experimental run-handler thread pool enabled and
// verifies the computed result is unchanged.
TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));
  std::vector<std::pair<string, Tensor>> inputs;

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, nullptr);
  TF_ASSERT_OK(s);

  ASSERT_EQ(1, outputs.size());
  // The first output should be initialized and have the correct
  // output.
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
803 
// A variable assigned in one Run call must keep its value in later Run calls
// on the same session.
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
  Graph g(OpRegistry::Global());
  const char* const kDevice = "/job:localhost/replica:0/task:0/cpu:0";

  // var: a float vector of length 10.
  Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
  var->set_assigned_device_name(kDevice);

  // A constant filled with 20.0 to assign into the variable.
  Tensor twenty(DT_FLOAT, TensorShape({10}));
  auto twenty_flat = twenty.flat<float>();
  for (int i = 0; i < 10; ++i) {
    twenty_flat(i) = 20.0;
  }

  Node* twenty_node = test::graph::Constant(&g, twenty);
  twenty_node->set_assigned_device_name(kDevice);

  Node* init = test::graph::Assign(&g, var, twenty_node);
  init->set_assigned_device_name(kDevice);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // First run: initialize the variable.
  TF_ASSERT_OK(session->Run({}, {init->name()}, {}, &outputs));

  // Second run: the previously assigned value must still be readable.
  TF_ASSERT_OK(session->Run({}, {var->name() + ":0"}, {}, &outputs));
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_EQ(20.0, outputs[0].flat<float>()(0));
}
842 
// Verifies feed/fetch handling in Session::Run: fetching works without feeds,
// output order follows fetch order, feeds override constants regardless of
// feed order, and feeding the same tensor twice is rejected.
TEST(DirectSessionTest, MultipleFeedTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding.
  Status s = session->Run(
      {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Fetch in the reverse order; outputs must follow the fetch order.
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]: fed values replace the constants.
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: feed order does not affect results.
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feeds are invalid.
  s = session->Run(
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
915 
// Same feed/fetch coverage as MultipleFeedTest, but through the MakeCallable /
// RunCallable API; duplicate feeds are rejected at MakeCallable time.
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  Session::CallableHandle handle;
  std::vector<Tensor> outputs;

  // Fetch without feeding.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {first_identity->name() + ":0", second_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Fetch in the reverse order; outputs must follow the fetch order.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {second_identity->name() + ":0", first_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), second_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_11, value_22}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: feed order must not change results.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {second_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_22, value_11}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feeds fail at MakeCallable.
  Status s = session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
998 
// A single tensor (node a's output) may serve as the "from" side of two
// tensor connections within the same callable.
TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {1.0, 2.0, 3.0, 4.0});
  Node* a = test::graph::Constant(&graph, a_tensor);

  // Placeholder-like constants whose values will be overridden by the
  // tensor connections below.
  Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
  test::FillValues<float>(&dummy_tensor, {-1.0});

  Node* left = test::graph::Constant(&graph, dummy_tensor);
  Node* right = test::graph::Constant(&graph, dummy_tensor);

  // y = left + right
  Node* y = test::graph::Add(&graph, left, right);

  GraphDef def;
  graph.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  CallableOptions callable_options;
  // Directly wire the output of node a to the outputs of nodes left
  // and right, making the callable graph into "a + a;".
  TensorConnection* c_left = callable_options.add_tensor_connection();
  c_left->set_from_tensor(a->name() + ":0");
  c_left->set_to_tensor(left->name() + ":0");
  TensorConnection* c_right = callable_options.add_tensor_connection();
  c_right->set_from_tensor(a->name() + ":0");
  c_right->set_to_tensor(right->name() + ":0");

  callable_options.add_fetch(y->name() + ":0");

  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  // Each element is a + a, i.e. twice the corresponding value in a_tensor.
  EXPECT_FLOAT_EQ(2.0, mat(0, 0));
  EXPECT_FLOAT_EQ(4.0, mat(0, 1));
  EXPECT_FLOAT_EQ(6.0, mat(1, 0));
  EXPECT_FLOAT_EQ(8.0, mat(1, 1));
  TF_ASSERT_OK(session->ReleaseCallable(handle));
}
1047 
// Fetching the same tensor twice in one Run must produce one output slot per
// fetch, each holding the tensor's value.
TEST(DirectSessionTest, FetchMultipleTimes) {
  Graph g(OpRegistry::Global());
  Tensor seven_tensor(DT_INT32, TensorShape());
  seven_tensor.flat<int32>()(0) = 7;
  Node* seven_node = test::graph::Constant(&g, seven_tensor);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  const std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;

  auto seven = seven_node->name();
  Status s = session->Run(inputs, {seven, seven}, {}, &outputs);
  TF_ASSERT_OK(s);

  EXPECT_EQ(2, outputs.size());
  // size_t index avoids the signed/unsigned comparison against
  // outputs.size() that the previous `int` loop variable caused.
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Tensor& t = outputs[i];
    ASSERT_TRUE(t.IsInitialized()) << i;
    EXPECT_EQ(7, t.flat<int32>()(0)) << i;
  }
}
1075 
// Same coverage as MultipleFeedTest, but the first and last Run calls pass
// RunOptions with inter_op_thread_pool = -1 (caller-thread execution) while
// the other calls use the default pool.
TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
  GraphDef def;
  Graph g(OpRegistry::Global());
  RunOptions run_options;
  run_options.set_inter_op_thread_pool(-1);

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding (synchronous Run).
  Status s = session->Run(
      run_options, {},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Fetch in the reverse order (default pool).
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feeds are invalid, also for
  // the synchronous Run path.
  s = session->Run(
      run_options,
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
1152 
// Test-only op: emits the session's metadata (serialized via DebugString),
// or the empty string when the session carries no metadata.
REGISTER_OP("SessionMetadataReader")
    .Input("x: int64")
    .Output("y: string")
    .SetIsStateful()
    .Doc(R"doc(SessionMetadataReader returns the session metadata.

x: int64
y: string
)doc");

class SessionMetadataReaderOp : public OpKernel {
 public:
  explicit SessionMetadataReaderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    Tensor* out_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("y", TensorShape({}), &out_tensor));
    // Write the metadata's debug string if the session provides one,
    // otherwise the empty string (the input tensor x is ignored).
    if (ctx->session_metadata() != nullptr) {
      out_tensor->scalar<tstring>()() = ctx->session_metadata()->DebugString();
    } else {
      out_tensor->scalar<tstring>()() = "";
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_CPU),
                        SessionMetadataReaderOp);
REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_GPU),
                        SessionMetadataReaderOp);
1181 
// Wraps SessionMetadataReader in a FunctionDef so tests can exercise the
// session-metadata plumbing through the function-library execution path.
FunctionDef SessionMetadataReaderOpFn() {
  return FunctionDefHelper::Define(
      // Name
      "SessionMetadataReaderFn",
      // Args
      {"x: int64"},
      // Return values
      {"y: string"},
      // Attr def
      {},
      // Nodes
      {{{"y"}, "SessionMetadataReader", {"x"}, {}}});
}
1195 
// Without session metadata configured, SessionMetadataReader must observe a
// null metadata pointer and emit the empty string.
TEST(DirectSessionTest, SessionMetadataAbsent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // The status was previously ignored; assert it so a failed Run doesn't
  // lead to reading an empty/invalid `outputs` vector below.
  TF_ASSERT_OK(s);

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1213 
// Same as SessionMetadataAbsent, but the reader op runs inside a library
// function to cover the function-execution path.
TEST(DirectSessionTest, SessionMetadataAbsentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // The status was previously ignored; assert it so a failed Run doesn't
  // lead to reading an empty/invalid `outputs` vector below.
  TF_ASSERT_OK(s);

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1235 
// With session metadata configured, the SessionMetadataReader kernel must see
// it, and RunMetadata (when requested) must carry a copy of it.
TEST(DirectSessionTest, SessionMetadataPresent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);

  // Run without requesting RunMetadata.
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                     nullptr /*=run_metadata*/);
  // Previously unchecked; fail fast instead of parsing invalid outputs.
  TF_ASSERT_OK(s);
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());

  // Run with requesting RunMetadata.
  RunMetadata metadata;
  s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                &metadata /*=run_metadata*/);
  TF_ASSERT_OK(s);
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
  // The RunMetadata must echo the session's own metadata.
  EXPECT_EQ(session_metadata->name(), metadata.session_metadata().name());
  EXPECT_EQ(session_metadata->version(), metadata.session_metadata().version());
}
1275 
// Same as SessionMetadataPresent, but the reader op runs inside a library
// function to cover the function-execution path.
TEST(DirectSessionTest, SessionMetadataPresentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);

  // Run without requesting RunMetadata.
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                     nullptr /*=run_metadata*/);
  // Previously unchecked; fail fast instead of parsing invalid outputs.
  TF_ASSERT_OK(s);
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());

  // Run with requesting RunMetadata.
  RunMetadata metadata;
  s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                &metadata /*=run_metadata*/);
  TF_ASSERT_OK(s);
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
  // The RunMetadata must echo the session's own metadata.
  EXPECT_EQ(session_metadata->name(), metadata.session_metadata().name());
  EXPECT_EQ(session_metadata->version(), metadata.session_metadata().version());
}
1319 
// Session metadata (name, version) acts as a uniqueness key among live
// sessions: a duplicate is rejected until its holder is destroyed, and
// sessions without metadata are never constrained.
TEST(DirectSessionTest, SessionMetadataKey) {
  auto session_options0 = DefaultSessionOptions();
  auto* session_metadata0 = session_options0.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata0->set_name("name");
  Session* sess0_ptr;
  ASSERT_TRUE(NewSession(session_options0, &sess0_ptr).ok());
  auto sess0 = absl::WrapUnique(sess0_ptr);

  // Trying to use the same metadata (name, version) will cause an error.
  Session* dup_ptr;
  EXPECT_TRUE(
      errors::IsInvalidArgument(NewSession(session_options0, &dup_ptr)));

  // A new (name, version) is fine.
  auto session_options1 = DefaultSessionOptions();
  auto* session_metadata1 = session_options1.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata1->set_name("name");
  session_metadata1->set_version(1);
  Session* sess1_ptr;
  EXPECT_TRUE(NewSession(session_options1, &sess1_ptr).ok());
  auto sess1 = absl::WrapUnique(sess1_ptr);

  // If the previous session, using the same (name, version) is gone, then it's
  // fine.
  sess0 = nullptr;
  EXPECT_TRUE(NewSession(session_options0, &dup_ptr).ok());
  auto dup = absl::WrapUnique(dup_ptr);

  // Sessions without metadata options are always fine.
  auto sess_without_metadata0 = CreateSession();
  EXPECT_NE(sess_without_metadata0, nullptr);
  auto sess_without_metadata1 = CreateSession();
  EXPECT_NE(sess_without_metadata1, nullptr);
}
1356 
// A negative session-metadata version must be rejected at session-creation
// time with InvalidArgument.
TEST(DirectSessionTest, SessionMetadataInvalid) {
  // Baseline: the default options (no metadata) create a session fine.
  const auto valid_session_options = DefaultSessionOptions();
  Session* sess_ptr;
  ASSERT_TRUE(NewSession(valid_session_options, &sess_ptr).ok());
  auto sess = absl::WrapUnique(sess_ptr);

  auto invalid_session_options = valid_session_options;
  auto* invalid_metadata =
      invalid_session_options.config.mutable_experimental()
          ->mutable_session_metadata();
  invalid_metadata->set_name("name");
  // Version should be >= 0.
  invalid_metadata->set_version(-1);
  Session* error_sess_ptr;
  EXPECT_TRUE(errors::IsInvalidArgument(
      NewSession(invalid_session_options, &error_sess_ptr)));
}
1374 
// Test-only op used to observe which thread executed Compute.
REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc(
ThreadID returns the thread ID that called compute.

x: int64
y: int64
)doc");

// The ThreadID kernel returns the thread ID that executed Compute.
class ThreadIDOp : public OpKernel {
 public:
  explicit ThreadIDOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    Tensor* out_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("y", TensorShape({}), &out_tensor));
    // Hash std::thread::id into an int64 so tests can compare thread
    // identity across the op and the caller (input x is ignored).
    std::hash<std::thread::id> hasher;
    out_tensor->scalar<int64_t>()() =
        static_cast<int64_t>(hasher(std::this_thread::get_id()));
  }
};
REGISTER_KERNEL_BUILDER(Name("ThreadID").Device(DEVICE_CPU), ThreadIDOp);
1396 
// With RunOptions.inter_op_thread_pool = -1, ops run on the calling thread,
// so ThreadID must observe the test's own thread id.
TEST(DirectSessionTest, SessionSyncRun) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // The status was previously ignored; assert it so a failed Run doesn't
  // lead to reading an empty/invalid `outputs` vector below.
  TF_ASSERT_OK(s);

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64_t>(hasher(std::this_thread::get_id())),
            static_cast<int64_t>(outputs[0].scalar<int64_t>()()));
}
1416 
1417 REGISTER_OP("ExpensiveNoop").SetIsStateful();
1418 
1419 class ExpensiveNoopOp : public OpKernel {
1420  public:
1421   using OpKernel::OpKernel;
IsExpensive()1422   bool IsExpensive() override { return true; }
Compute(OpKernelContext * ctx)1423   void Compute(OpKernelContext* ctx) override {
1424     const string& stack_trace = tensorflow::CurrentStackTrace();
1425     const string process_method = "ExecutorState::Process()";
1426     size_t pos = 0;
1427     int frame_count = 0;
1428     while ((pos = stack_trace.find("ExecutorState::Process()", pos)) !=
1429            string::npos) {
1430       ++frame_count;
1431       ++pos;
1432     }
1433     OP_REQUIRES(ctx, frame_count <= 1,
1434                 errors::Internal(
1435                     "Recursive call to ExecutorState::Process() detected."));
1436   }
1437 };
1438 
1439 REGISTER_KERNEL_BUILDER(Name("ExpensiveNoop").Device(DEVICE_CPU),
1440                         ExpensiveNoopOp);
1441 
// Synchronous execution of a fan-out of "expensive" nodes must not recurse
// into ExecutorState::Process() (the ExpensiveNoop kernel asserts this).
TEST(DirectSessionTest, SessionSyncRun_DeepGraph) {
  Graph g(OpRegistry::Global());

  // Adds one ExpensiveNoop node with the given control dependencies.
  auto make_expensive_noop = [&g](gtl::ArraySlice<Node*> control_deps) {
    Node* ret;
    auto builder = NodeBuilder(g.NewName("N"), "ExpensiveNoop");
    for (Node* control_dep : control_deps) {
      builder = builder.ControlInput(control_dep);
    }
    TF_CHECK_OK(builder.Finalize(&g, &ret));
    return ret;
  };

  // (The previous version also declared and reserved an unused
  // `std::vector<Node*> nodes`; it has been removed.)
  Node* base = make_expensive_noop({});

  Node* child_1 = make_expensive_noop({base});
  Node* child_2 = make_expensive_noop({base});

  GraphDef def;
  g.ToGraphDef(&def);

  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);

  // TF_EXPECT_OK reports the actual error message on failure, unlike the
  // previous EXPECT_TRUE(...ok()).
  TF_EXPECT_OK(sess->Run(run_opts, {}, {}, {child_1->name(), child_2->name()},
                         &outputs, nullptr));
}
1476 
// Like SessionSyncRun, but synchronous execution is selected once at session
// creation (config.inter_op_parallelism_threads = -1) rather than per Run.
TEST(DirectSessionTest, SyncSession) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  SessionOptions options;
  options.config.set_inter_op_parallelism_threads(-1);
  std::unique_ptr<Session> sess(NewSession(options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // The status was previously ignored; assert it so a failed Run doesn't
  // lead to reading an empty/invalid `outputs` vector below.
  TF_ASSERT_OK(s);

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64_t>(hasher(std::this_thread::get_id())),
            static_cast<int64_t>(outputs[0].scalar<int64_t>()()));
}
1497 
// Test-only op whose kernel deliberately violates its output contract.
REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
Darth promises one return value.

x: float
y: float
)doc");

// The DarthOp kernel violates its promise to return one-value.
class DarthOp : public OpKernel {
 public:
  explicit DarthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  // Intentionally produces no output; the executor should report this.
  void Compute(OpKernelContext* ctx) override {}
};
REGISTER_KERNEL_BUILDER(Name("Darth").Device(DEVICE_CPU), DarthOp);
1512 
// A kernel that fails to produce its promised output must surface an
// Internal error from Session::Run.
TEST(DirectSessionTest, DarthKernel) {
  Graph g(OpRegistry::Global());
  Tensor one(DT_FLOAT, TensorShape({}));
  one.scalar<float>()() = 1.0;
  Node* input = test::graph::Constant(&g, one);
  Node* darth = test::graph::Unary(&g, "Darth", input);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  Status s = sess->Run({}, {darth->name() + ":0"}, {}, &outputs);
  EXPECT_TRUE(errors::IsInternal(s));
}
1527 
// Have the Darth op in the graph placed on GPU, but don't run it.
TEST(DirectSessionTest, PlacePrunedGraph) {
  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    // Darth has no GPU kernel, so full-graph placement must fail.
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    // By default, we place the entire graph, so we should fail the
    // call to Create.
    SessionOptions options;
    std::unique_ptr<Session> sess(NewSession(options));
    auto s = sess->Create(def);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
  }

  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    SessionOptions options;
    // Set the option to place pruned graphs, we should expect this
    // to run.
    options.config.mutable_graph_options()->set_place_pruned_graph(true);
    std::unique_ptr<Session> sess(NewSession(options));
    TF_ASSERT_OK(sess->Create(def));
    std::vector<Tensor> outputs;
    // Only fetch x, so the unplaceable Darth node is pruned away.
    auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs);
    TF_EXPECT_OK(s);
  }
}
1569 
// Exercises the partial-run API: PRunSetup declares all feeds/fetches up
// front, then successive PRun calls feed and fetch incrementally, with state
// (first feed's value) carried across the calls.
TEST(DirectSessionTest, PartialRunTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  // third = first_identity + second_identity
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Declare both feeds and all three fetches for the whole partial run.
  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed first_const, fetch first_identity
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));

  // Feed second_const, fetch second_identity and third_identity
  s = session->PRun(
      handle, {{second_const->name(), value_22}},
      {second_identity->name() + ":0", third_identity->name() + ":0"},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(22.0, outputs[0].flat<float>()(0));
  // The sum uses the value fed in the *previous* PRun call.
  ASSERT_EQ(11.0 + 22.0, outputs[1].flat<float>()(0));
}
1625 
// A PRun fetch whose value depends on a declared-but-not-yet-supplied feed
// must fail with InvalidArgument rather than blocking or returning garbage.
TEST(DirectSessionTest, PartialRunMissingFeed) {
  Graph g(OpRegistry::Global());

  // Helper: builds a scalar float tensor holding `v`.
  auto scalar_tensor = [](float v) {
    Tensor t(DT_FLOAT, TensorShape({}));
    t.scalar<float>()() = v;
    return t;
  };

  // Graph: third_identity depends on BOTH constants.
  Node* first_const = test::graph::Constant(&g, scalar_tensor(1.0f));
  Node* first_identity = test::graph::Identity(&g, first_const);
  Node* second_const = test::graph::Constant(&g, scalar_tensor(2.0f));
  Node* second_identity = test::graph::Identity(&g, second_const);
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  string handle;
  TF_ASSERT_OK(session->PRunSetup({first_const->name(), second_const->name()},
                                  {third_identity->name() + ":0"}, {},
                                  &handle));

  // Only first_const is fed here; third_identity also needs second_const,
  // so the fetch cannot be satisfied.
  std::vector<Tensor> outputs;
  Status s = session->PRun(handle,
                           {{first_const->name(), scalar_tensor(11.0f)}},
                           {third_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));
}
1665 
// Feeds a non-zero output slot (Switch:1) in a partial run. Fetching before
// the feed must fail; fetching after feeding the slot must succeed.
TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
  Graph g(OpRegistry::Global());

  Tensor bool_value(DT_BOOL, TensorShape({}));
  bool_value.scalar<bool>()() = true;
  Node* bool_const = test::graph::Constant(&g, bool_value);
  Node* switch_node = test::graph::Switch(&g, bool_const, bool_const);
  // Reads output 1 (the "true" branch) of the Switch.
  Node* fourth_identity = test::graph::Identity(&g, switch_node, 1);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  string handle;
  TF_ASSERT_OK(session->PRunSetup({switch_node->name() + ":1"},
                                  {fourth_identity->name() + ":0"}, {},
                                  &handle));

  // Without the declared feed, the fetch cannot be computed.
  std::vector<Tensor> outputs;
  Status s =
      session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));

  // Feed switch_node:1 and fetch fourth_identity.
  s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}},
                    {fourth_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(true, outputs[0].flat<bool>()(0));
}
1702 
// Exercises the session tensor store: GetSessionHandleV2 stashes a value
// under a handle, GetSessionTensor reads it back through the handle, and
// DeleteSessionTensor removes it.
TEST(DirectSessionTest, RunHandleTest) {
  Graph g(OpRegistry::Global());

  Tensor one(DT_FLOAT, TensorShape({}));
  one.scalar<float>()() = 1.0;
  Node* one_const = test::graph::Constant(&g, one);
  Node* one_identity = test::graph::Identity(&g, one_const);

  Tensor two(DT_FLOAT, TensorShape({}));
  two.scalar<float>()() = 2.0;
  Node* two_const = test::graph::Constant(&g, two);
  Node* sum = test::graph::Add(&g, one_identity, two_const);
  // Stores `sum` (= 3) in the session tensor store; outputs a handle.
  Node* get_handle = test::graph::Unary(&g, "GetSessionHandleV2", sum);

  // `handle_const` is fed at run time with the string form of the handle.
  Tensor empty_string(DT_STRING, TensorShape({}));
  Node* handle_const = test::graph::Constant(&g, empty_string);
  Node* stored = test::graph::GetSessionTensor(&g, handle_const);
  Node* stored_plus_two = test::graph::Add(&g, stored, two_const);

  Node* delete_handle =
      test::graph::Unary(&g, "DeleteSessionTensor", handle_const);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // First run call: create a handle for the stored 1 + 2.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {get_handle->name() + ":0"}, {}, &outputs));
  ASSERT_EQ(1, outputs.size());

  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Second run call: read the stored tensor back through the handle, add 2.
  std::vector<Tensor> outputs1;
  TF_ASSERT_OK(session->Run({{handle_const->name(), string_handle}},
                            {stored_plus_two->name() + ":0"}, {}, &outputs1));
  ASSERT_EQ(1, outputs1.size());
  ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));

  // Third run call: delete the stored tensor.
  std::vector<Tensor> outputs2;
  TF_ASSERT_OK(session->Run({{handle_const->name(), string_handle}}, {},
                            {delete_handle->name()}, &outputs2));
}
1755 
// NOTE(review): despite the _Callable suffix, this test is currently a
// byte-for-byte duplicate of RunHandleTest above and never calls
// MakeCallable/RunCallable. TODO: port it to the Callable API (or delete it)
// so the name matches what is exercised.
TEST(DirectSessionTest, RunHandleTest_Callable) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor value0(DT_FLOAT, TensorShape({}));
  value0.scalar<float>()() = 1.0;
  Node* const0 = test::graph::Constant(&g, value0);
  Node* identity0 = test::graph::Identity(&g, const0);

  Tensor value1(DT_FLOAT, TensorShape({}));
  value1.scalar<float>()() = 2.0;
  Node* const1 = test::graph::Constant(&g, value1);
  Node* node3 = test::graph::Add(&g, identity0, const1);
  // Stores the Add result in the session tensor store; outputs a handle.
  Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);

  // const2 is fed at run time with the handle's string form.
  Tensor value2(DT_STRING, TensorShape({}));
  Node* const2 = test::graph::Constant(&g, value2);
  Node* node5 = test::graph::GetSessionTensor(&g, const2);
  Node* node6 = test::graph::Add(&g, node5, const1);

  Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // First run call: Create a handle.
  std::vector<Tensor> outputs;
  Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs.size());

  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Second run call: Use a handle. Expected: (1 + 2) + 2 = 5.
  std::vector<Tensor> outputs1;
  s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
                   {}, &outputs1);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs1.size());
  ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));

  // Third run call: Delete a handle.
  std::vector<Tensor> outputs2;
  s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
                   &outputs2);
  ASSERT_TRUE(s.ok());
}
1808 
// Feeding a variable while also running the Assign that writes it must be
// rejected with InvalidArgument: the fed value and the in-graph assignment
// would conflict.
TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
  Graph graph(OpRegistry::Global());

  Node* a = test::graph::Var(&graph, DT_FLOAT, {});
  Node* b = test::graph::Constant(&graph, {});

  // Value fed for variable `a` below.
  Tensor zero(DT_FLOAT, {});
  test::FillValues<float>(&zero, {0});

  // a = b
  Node* assign = test::graph::Assign(&graph, a, b);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  // The run feeds variable `a` and simultaneously fetches the assign that
  // writes `a`, which is invalid. (A previous comment here claimed "a
  // constant cannot be assigned to a constant" — that did not match this
  // graph: `a` is a variable.)
  // The return Status of session->Run should flag this as an invalid argument.
  std::vector<Tensor> outputs;
  Status s = session->Run({{a->name(), zero}}, {assign->name()}, {}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
}
1830 
// Verifies timeout handling: a blocked QueueDequeue must fail with
// DEADLINE_EXCEEDED both via the session-wide operation_timeout_in_ms
// setting and via a per-step RunOptions timeout.
TEST(DirectSessionTest, TimeoutSession) {
  GraphDef graph;
  // Creates a graph with one FIFOQueue and one dequeue op. Nothing ever
  // enqueues, so the dequeue blocks until a timeout fires.
  // Fix: the parse result is now checked — a silent parse failure previously
  // left `graph` empty and made the test pass vacuously.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(R"pb(
                                          node {
                                            name: 'fifo_queue'
                                            op: 'FIFOQueue'
                                            device: '/device:CPU:0'
                                            attr {
                                              key: 'capacity'
                                              value { i: 10 }
                                            }
                                            attr {
                                              key: 'component_types'
                                              value { list { type: DT_FLOAT } }
                                            }
                                            attr {
                                              key: 'container'
                                              value { s: '' }
                                            }
                                            attr {
                                              key: 'shapes'
                                              value { list {} }
                                            }
                                            attr {
                                              key: 'shared_name'
                                              value { s: '' }
                                            }
                                          }
                                          node {
                                            name: 'fifo_queue_Dequeue'
                                            op: 'QueueDequeue'
                                            input: 'fifo_queue'
                                            device: '/device:CPU:0'
                                            attr {
                                              key: 'component_types'
                                              value { list { type: DT_FLOAT } }
                                            }
                                            attr {
                                              key: 'timeout_ms'
                                              value { i: -1 }
                                            }
                                          }
                                          versions { producer: 9 }
                                        )pb",
                                        &graph));

  {
    // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
    SessionOptions options;
    (*options.config.mutable_device_count())["CPU"] = 2;
    options.config.set_operation_timeout_in_ms(100);

    std::unique_ptr<Session> session(NewSession(options));
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));

    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
    TF_ASSERT_OK(session->Close());
  }

  {
    // Creates a session with no operation_timeout_in_ms, but requests a
    // 20ms deadline for this single step via RunOptions.
    auto session = CreateSession();
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));
    RunOptions run_options;
    run_options.set_timeout_in_ms(20);
    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"},
                             nullptr, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
    TF_ASSERT_OK(session->Close());
  }
}
1908 
1909 // Accesses the cancellation manager for the step after the step has been
1910 // cancelled.
1911 class CancellationMgrPollingOp : public OpKernel {
1912  public:
CancellationMgrPollingOp(OpKernelConstruction * ctx)1913   explicit CancellationMgrPollingOp(OpKernelConstruction* ctx)
1914       : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)1915   void Compute(OpKernelContext* ctx) override {
1916     CancellationManager* cm = ctx->cancellation_manager();
1917     while (!cm->IsCancelled()) {
1918       ctx->env()->SleepForMicroseconds(1000);
1919     }
1920     notification.Notify();
1921   }
1922   static Notification notification;
1923 };
1924 Notification CancellationMgrPollingOp::notification;
1925 
// Registers the polling kernel (CPU only) and a matching op definition so
// graphs in this file can reference "CancellationMgrPollingOp" by name.
REGISTER_KERNEL_BUILDER(Name("CancellationMgrPollingOp").Device(DEVICE_CPU),
                        CancellationMgrPollingOp);
REGISTER_OP("CancellationMgrPollingOp").Doc("");
1929 
// Verifies that a step that hits the session timeout is cancelled cleanly:
// the polling op observes the cancellation manager and runs to completion
// (signalling its notification) instead of being abandoned mid-flight.
TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
  GraphDef graph;
  // Creates a graph with a single CancellationMgrPollingOp, which spins
  // until its step is cancelled. (The previous comment describing "one
  // FIFOQueue and one dequeue op" was copied from the test above and did
  // not match this graph.)
  // Fix: the parse result is now checked — a silent parse failure previously
  // left `graph` empty and made the test pass vacuously.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(R"pb(
                                          node {
                                            name: 'cm_polling'
                                            op: 'CancellationMgrPollingOp'
                                            device: '/device:CPU:0'
                                          }
                                          versions { producer: 9 }
                                        )pb",
                                        &graph));

  // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
  SessionOptions options;
  options.config.set_operation_timeout_in_ms(100);
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(graph));

  // Verifies that the error code is DEADLINE_EXCEEDED.
  Status s = session->Run({}, {}, {"cm_polling"}, nullptr);
  ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());

  // Verify that the op ran to completion.
  ASSERT_TRUE(CancellationMgrPollingOp::notification.HasBeenNotified());

  TF_ASSERT_OK(session->Close());
}
1959 
// Shared body for the TestSessionInterOpThreads* tests below.
//
// Builds a graph with one blocking op — the raw "BlockingOp" kernel or a
// "BlockingOpFn" function wrapping it, per `use_function_lib` — and checks
// that Run() calls are isolated per inter-op thread pool: a blocking op
// saturating the single-thread "small" pool must not prevent runs on the
// "large" pool (or synchronous runs) from completing.
// `use_global_pools` names the pools so they are shared across sessions,
// in which case every run creates its own session.
//
// The blocking op and this thread hand-shake through the shared
// BlockingOpState state machine (see the state comments further down).
static void TestSessionInterOpThreadsImpl(bool use_function_lib,
                                          bool use_global_pools) {
  using test::function::blocking_op_state;
  using test::function::BlockingOpState;

  FunctionDefLibrary library_graph_def;
  if (use_function_lib) {
    *library_graph_def.add_function() = test::function::BlockingOpFn();
  }

  // Graph: y = BlockingOp(x) (or the function wrapper), x a scalar constant.
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  Node* y;
  if (use_function_lib) {
    y = test::graph::Unary(&g, "BlockingOpFn", x);
  } else {
    y = test::graph::Unary(&g, "BlockingOp", x);
  }
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;

  // Create session with two inter-op thread pools.
  SessionOptions options;
  // Turn off optimizations so that the blocking op doesn't get invoked during
  // graph setup.
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  options.config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_constant_folding(RewriterConfig::OFF);
  (*options.config.mutable_device_count())["CPU"] = 2;
  (*options.config.mutable_device_count())["GPU"] = 0;

  // Pool 0 ("large") keeps the default thread count; pool 1 ("small") is
  // restricted to one thread so a single blocking op saturates it.
  auto* p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("large pool");
  p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("small pool");
  p->set_num_threads(1);
  const int kSyncPool = -1;  // run synchronously on the calling thread
  const int kLargePool = 0;
  const int kSmallPool = 1;

  std::vector<std::unique_ptr<Session>> sessions;
  if (!use_global_pools) {
    // One shared session for all runs.
    sessions.emplace_back(NewSession(options));
    TF_ASSERT_OK(sessions.back()->Create(def));
  }
  mutex sessions_mu;  // guards `sessions` when runs create sessions lazily

  std::atomic<int32> num_done(0);
  // Runs session to compute <node>:0 using inter_op thread pool <pool>.
  // Schedules on `tp` when non-null, else runs synchronously. With global
  // pools each call creates (and retains) its own session, since the named
  // pools are shared across sessions.
  auto add_session_run_call =
      [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done](
          thread::ThreadPool* tp, Node* node, int inter_op_pool) {
        auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu,
                   inter_op_pool, node, &num_done]() {
          RunOptions run_options;
          run_options.set_inter_op_thread_pool(inter_op_pool);
          std::vector<Tensor> outputs;

          Session* session;
          if (use_global_pools) {
            std::unique_ptr<Session> s(NewSession(options));
            TF_ASSERT_OK(s->Create(def));
            session = s.get();

            mutex_lock l(sessions_mu);
            sessions.emplace_back(std::move(s));
          } else {
            session = sessions[0].get();
          }

          Status s = session->Run(run_options, {} /* inputs */,
                                  {node->name() + ":0"} /* output_names */, {},
                                  &outputs, nullptr /* run_metadata */);
          TF_CHECK_OK(s);
          ASSERT_EQ(1, outputs.size());
          auto flat = outputs[0].flat<float>();
          EXPECT_FLOAT_EQ(1.2, flat(0));
          num_done.fetch_add(1);
        };
        if (tp != nullptr) {
          tp->Schedule(fn);
        } else {
          fn();
        }
      };

  // For blocking states:
  // - Starts at 0, BlockingOp::Compute will move to 1.
  // - This main thread will wait for 1, then move to 2 when other ops are done.
  //   Moving to 2 unblocks the blocking op, which then moves to state 3.

  // Run the graph once on the non-limited pool.
  thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 1);
  blocking_op_state = new BlockingOpState();
  add_session_run_call(tp1, y, kLargePool);
  blocking_op_state->AwaitState(1);
  blocking_op_state->MoveToState(1, 2);
  blocking_op_state->AwaitState(3);
  blocking_op_state->MoveToState(3, 0);
  delete tp1;  // joins the worker, so the run above has fully completed
  num_done = 0;

  tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);

  // Launch a session run call. It will not finish until the blocking op is
  // unblocked, because it is using all threads in the small pool.
  add_session_run_call(tp1, y, kSmallPool);

  blocking_op_state->AwaitState(1);  // Wait for the blocking op to Compute.

  // These will block on <BlockingOpState>.
  const int kBlockedThreads = 3;
  for (int i = 0; i < kBlockedThreads; ++i) {
    add_session_run_call(tp1, x, kSmallPool);
  }

  // Launch session calls using the other inter-op pool. These will finish
  // as they are in inter_op pool #2.
  thread::ThreadPool* tp2 = new thread::ThreadPool(Env::Default(), "tp2", 3);
  const int kUnblockedThreads = 4;
  for (int i = 0; i < kUnblockedThreads; ++i) {
    add_session_run_call(tp2, x, kLargePool);
  }
  delete tp2;  // joins tp2, so all large-pool runs are done here
  EXPECT_EQ(kUnblockedThreads, num_done.load());

  // Launch a session call using this thread. This will finish as it runs
  // synchronously in this thread.
  add_session_run_call(nullptr, x, kSyncPool);

  // Unblock the blocked op and wait for the blocked functions to finish.
  blocking_op_state->MoveToState(1, 2);
  delete tp1;

  // Final tally: kUnblockedThreads on the large pool, kBlockedThreads + 1
  // (the blocking run) on the small pool, and 1 synchronous run.
  EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1 + 1, num_done.load());
  delete blocking_op_state;
  blocking_op_state = nullptr;
}
2105 
// Per-session pools, blocking op used directly.
TEST(DirectSessionTest, TestSessionInterOpThreads) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/false,
                                /*use_global_pools=*/false);
}
2110 
// Per-session pools, blocking op wrapped in a function.
TEST(DirectSessionTest, TestSessionInterOpThreadsWithFunctions) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/true,
                                /*use_global_pools=*/false);
}
2115 
// Globally named pools shared across sessions, blocking op used directly.
TEST(DirectSessionTest, TestSessionInterOpGlobalPools) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/false,
                                /*use_global_pools=*/true);
}
2120 
// Globally named pools shared across sessions, blocking op in a function.
TEST(DirectSessionTest, TestSessionInterOpGlobalPoolsWithFunctions) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/true,
                                /*use_global_pools=*/true);
}
2125 
// Invalid inter-op pool configuration must be surfaced as an error: a bad
// pool index on a Run() call, and conflicting thread counts for a named
// global pool across sessions.
TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  (*options.config.mutable_device_count())["CPU"] = 2;

  options.config.add_session_inter_op_thread_pool();

  // Wrong pool number on Run call.
  {
    std::unique_ptr<Session> session(NewSession(options));
    TF_ASSERT_OK(session->Create(def));
    for (int pool_num : {-2, 1}) {  // both out of range: only pool 0 exists
      RunOptions run_options;
      run_options.set_inter_op_thread_pool(pool_num);
      std::vector<Tensor> outputs;
      Status s = session->Run(run_options, {} /* inputs */,
                              {x->name() + ":0"} /* output_names */, {},
                              &outputs, nullptr /* run_metadata */);
      EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
      EXPECT_TRUE(absl::StrContains(
          s.error_message(),
          strings::StrCat("Invalid inter_op_thread_pool: ", pool_num)));
    }
  }

  // Global name changes thread count.
  std::vector<std::unique_ptr<Session>> sessions;
  auto* pool_config = options.config.mutable_session_inter_op_thread_pool(0);
  pool_config->set_num_threads(0);
  pool_config->set_global_name("foo");
  sessions.emplace_back(NewSession(options));
  TF_ASSERT_OK(sessions.back()->Create(def));
  sessions.emplace_back(NewSession(options));  // repeat creation, okay.
  TF_ASSERT_OK(sessions.back()->Create(def));
  for (int pass = 0; pass < 2; ++pass) {
    // Any thread count that disagrees with the registered global pool "foo"
    // must be rejected at Create time.
    for (int num_threads = 1; num_threads < 128; ++num_threads) {
      pool_config->set_num_threads(num_threads);
      sessions.emplace_back(NewSession(options));
      auto create_status = sessions.back()->Create(def);
      ASSERT_FALSE(create_status.ok()) << create_status;
    }

    // Clear existing sessions before second pass; error still happens.
    sessions.clear();
  }
}
2181 
// After Close(), every entry point (Run, RunCallable) must fail with
// CANCELLED / "Session has been closed." instead of touching freed state.
TEST(DirectSessionTest, TestDirectSessionRunClose) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor init_value(DT_FLOAT, TensorShape({}));
  init_value.scalar<float>()() = {1.2f};
  Node* init_const = test::graph::Constant(&g, init_value);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, init_const);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Initialize the variable, then read it back to confirm the session is
  // functional before it is closed.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(init_value.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Make a callable handle before closing the session.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {}, {var_assign->name()}), &handle));

  // Close the session.
  TF_ASSERT_OK(session->Close());

  // A plain Run must now report the closed session...
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));

  // ...and so must the previously created callable.
  s = session->RunCallable(handle, {}, {}, nullptr);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));
}
2229 
// PRun against a closed session must fail with CANCELLED, even when the
// partial run was set up before Close().
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
  Graph g(OpRegistry::Global());

  // Helper: builds a scalar float tensor holding `v`.
  auto scalar_tensor = [](float v) {
    Tensor t(DT_FLOAT, TensorShape({}));
    t.scalar<float>()() = v;
    return t;
  };

  Node* first_const = test::graph::Constant(&g, scalar_tensor(1.0f));
  Node* first_identity = test::graph::Identity(&g, first_const);
  Node* second_const = test::graph::Constant(&g, scalar_tensor(2.0f));
  Node* second_identity = test::graph::Identity(&g, second_const);
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  string handle;
  TF_ASSERT_OK(session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle));

  // Close the session before any PRun call is made.
  TF_ASSERT_OK(session->Close());

  // The pending partial run must now be rejected.
  std::vector<Tensor> outputs;
  Status s = session->PRun(handle,
                           {{first_const->name(), scalar_tensor(11.0f)}},
                           {first_identity->name() + ":0"}, &outputs);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));
}
2277 
// Reset() clears session state; subsequent Run calls must fail as if the
// session had been closed.
TEST(DirectSessionTest, TestDirectSessionReset) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor init_value(DT_FLOAT, TensorShape({}));
  init_value.scalar<float>()() = {1.2f};
  Node* init_const = test::graph::Constant(&g, init_value);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, init_const);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Initialize the variable, then read it back to confirm the session is
  // functional before Reset().
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(init_value.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Reset the containers.
  TF_EXPECT_OK(Reset(options, {}));

  // Run the read on the variable to get an error.
  // TODO(suharshs): This test only works because we close the Session in Reset.
  // If we change the behavior of Reset to not close the Session, this test will
  // fail, since the Variable buffer is cached by var.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));
}
2318 
// LocalDeviceManager must expose a non-null DeviceMgr listing at least one
// device, even before the session has been given a graph.
TEST(DirectSessionTest, LocalDeviceManager) {
  SessionOptions options;
  std::unique_ptr<Session> session(NewSession(options));

  const DeviceMgr* device_mgr = nullptr;
  TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
  ASSERT_TRUE(device_mgr != nullptr);
  EXPECT_GT(device_mgr->ListDevices().size(), 0);
}
2328 
2329 // y = tf.square(x)
CreateGraphForYEqualsXSquared()2330 GraphDef CreateGraphForYEqualsXSquared() {
2331   GraphDef graph_def;
2332   const char* text_proto = R"EOF(
2333 node {
2334   name: "x"
2335   op: "Placeholder"
2336   attr { key: "dtype" value { type: DT_FLOAT } }
2337   attr { key: "shape" value { shape { unknown_rank: true } } }
2338 }
2339 node {
2340   name: "y"
2341   op: "Square"
2342   input: "x"
2343   attr { key: "T" value { type: DT_FLOAT } }
2344 }
2345 versions {
2346   producer: 26
2347 }
2348   )EOF";
2349 
2350   QCHECK(protobuf::TextFormat::ParseFromString(text_proto, &graph_def));
2351   return graph_def;
2352 }
2353 
// Returns true if the buffer backing `t` resides in GPU device memory
// (as reported by the CUDA or ROCm runtime); always false in builds
// without GPU support.
// Returns true if the backing buffer of `t` lives in GPU device memory.
// Asks the CUDA (or ROCm) runtime which memory space the data pointer
// belongs to; in builds without a GPU runtime this is trivially false.
bool IsCUDATensor(const Tensor& t) {
#ifdef GOOGLE_CUDA
  cudaPointerAttributes attributes;
  cudaError_t err =
      cudaPointerGetAttributes(&attributes, t.tensor_data().data());
  // cudaErrorInvalidValue means the pointer is unknown to the CUDA runtime,
  // i.e. ordinary host memory.
  if (err == cudaErrorInvalidValue) return false;
  CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  return (attributes.type == cudaMemoryTypeDevice);
#elif TENSORFLOW_USE_ROCM
  hipPointerAttribute_t attributes;
  hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
  // hipErrorInvalidValue likewise indicates plain host memory.
  if (err == hipErrorInvalidValue) return false;
  CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
  return (attributes.memoryType == hipMemoryTypeDevice);
#else
  // No GPU runtime compiled in.
  return false;
#endif
}
2375 
GPUDeviceName(Session * session)2376 string GPUDeviceName(Session* session) {
2377   std::vector<DeviceAttributes> devices;
2378   TF_CHECK_OK(session->ListDevices(&devices));
2379   for (const DeviceAttributes& d : devices) {
2380     if (d.device_type() == "GPU" || d.device_type() == "gpu") {
2381       return d.name();
2382     }
2383   }
2384   return "";
2385 }
2386 
// Feeds a host tensor into a callable that fetches y = x^2 into GPU memory,
// then feeds that device-resident result back through a second callable that
// fetches into host memory. Skipped when no GPU device is available.
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory) {
  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateGraphForYEqualsXSquared()));

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  // Holds the device-resident output of the first callable; reused as the
  // input of the second.
  Tensor gpu_tensor;

  {
    // Feed from host memory but request that the fetched y:0 stay in GPU
    // memory; fetch_skip_sync avoids synchronizing on the fetch.
    Session::CallableHandle feed_cpu_fetch_gpu;
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_cpu_fetch_gpu));
    Tensor input(DT_FLOAT, {});
    input.scalar<float>()() = 2.0f;
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(
        session->RunCallable(feed_cpu_fetch_gpu, {input}, &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_cpu_fetch_gpu));
    ASSERT_EQ(1, outputs.size());
    gpu_tensor = outputs[0];
    ASSERT_TRUE(IsCUDATensor(gpu_tensor));
  }

  {
    // Now feed the device-resident tensor (2.0^2 == 4.0) back in via x:0,
    // fetching the squared result (16.0) into host memory.
    Session::CallableHandle feed_gpu_fetch_cpu;
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_gpu_fetch_cpu));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(feed_gpu_fetch_cpu, {gpu_tensor},
                                      &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_gpu_fetch_cpu));
    ASSERT_EQ(1, outputs.size());
    // The output is in CPU/host memory, so it can be dereferenced.
    ASSERT_EQ(16.0, outputs[0].scalar<float>()());
  }
}
2433 
CreateIdentityGraphDef(DataType dtype)2434 GraphDef CreateIdentityGraphDef(DataType dtype) {
2435   GraphDef def;
2436 
2437   AttrValue dtype_attr;
2438   dtype_attr.set_type(dtype);
2439 
2440   AttrValue shape_attr;
2441   shape_attr.mutable_shape()->set_unknown_rank(true);
2442 
2443   auto* placeholder = def.add_node();
2444   placeholder->set_name("x");
2445   placeholder->set_op("Placeholder");
2446   placeholder->mutable_attr()->insert({"dtype", dtype_attr});
2447   placeholder->mutable_attr()->insert({"shape", shape_attr});
2448 
2449   auto* identity = def.add_node();
2450   identity->set_name("y");
2451   identity->set_op("Identity");
2452   identity->add_input("x");
2453   identity->mutable_attr()->insert({"T", dtype_attr});
2454 
2455   return def;
2456 }
2457 
// Creates an Identity graph for `dtype` and verifies that a tensor of that
// type can be fetched into GPU memory and then fed back from GPU memory,
// round-tripping the original host bytes. Skipped when no GPU is available.
void TestFeedAndFetchTensorsInDeviceMemory(
    const SessionOptions& session_options, DataType dtype) {
  std::unique_ptr<Session> session(NewSession(session_options));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
      << DataType_Name(dtype);

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  Tensor gpu_tensor;
  // NOTE(review): host_tensor's buffer is not explicitly initialized; the
  // byte comparison below checks whatever it holds — presumably intentional
  // since Identity only copies bytes. Confirm for types with padding.
  Tensor host_tensor(dtype, {3});
  {
    // Ask for the fetched tensor to be backed by device memory.
    // Even though the kernel that created the tensor produced it in host
    // memory.
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {host_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size()) << DataType_Name(dtype);
    gpu_tensor = outputs[0];
    ASSERT_TRUE(IsCUDATensor(gpu_tensor)) << DataType_Name(dtype);
  }

  {
    // Feed a tensor backed by device memory, even though the operations in the
    // graph expect it in host memory.
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {gpu_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size());
    // After the device round trip, the fetched tensor must contain exactly
    // the bytes of the original host tensor.
    const StringPiece actual_data = outputs[0].tensor_data();
    const StringPiece expected_data = host_tensor.tensor_data();
    EXPECT_EQ(expected_data.size(), actual_data.size()) << DataType_Name(dtype);
    EXPECT_EQ(0, memcmp(expected_data.data(), actual_data.data(),
                        std::min(expected_data.size(), actual_data.size())))
        << DataType_Name(dtype);
  }
}
2513 
TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(const SessionOptions & session_options,DataType dtype)2514 void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
2515     const SessionOptions& session_options, DataType dtype) {
2516   std::unique_ptr<Session> session(NewSession(session_options));
2517   const string gpu_device_name = GPUDeviceName(session.get());
2518   if (gpu_device_name.empty()) {
2519     LOG(INFO) << "Skipping test since no GPU is available";
2520     return;
2521   }
2522 
2523   TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
2524       << DataType_Name(dtype);
2525 
2526   CallableOptions opts;
2527   opts.add_feed("x:0");
2528   opts.add_fetch("y:0");
2529 
2530   // Fail when asking to fetch into GPU memory.
2531   {
2532     opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
2533     opts.set_fetch_skip_sync(true);
2534     Session::CallableHandle handle;
2535     Status status = session->MakeCallable(opts, &handle);
2536     EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
2537     EXPECT_TRUE(absl::StrContains(
2538         status.error_message(),
2539         strings::StrCat(
2540             "Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
2541             " as feeding/fetching from GPU devices is not yet supported for ",
2542             DataTypeString(dtype), " tensors")))
2543         << DataType_Name(dtype) << ", Status: " << status;
2544   }
2545 
2546   // Fail when feeding from GPU memory.
2547   {
2548     opts.clear_feed_devices();
2549     opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
2550     Session::CallableHandle handle;
2551     Status status = session->MakeCallable(opts, &handle);
2552     EXPECT_FALSE(status.ok());
2553     EXPECT_TRUE(absl::StrContains(
2554         status.error_message(),
2555         strings::StrCat(
2556             "Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
2557             " as feeding/fetching from GPU devices is not yet supported for ",
2558             DataTypeString(dtype), " tensors")))
2559         << DataType_Name(dtype) << ", Status: " << status;
2560   }
2561 }
2562 
TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(const SessionOptions & opts)2563 void TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(
2564     const SessionOptions& opts) {
2565   // Feeding/fetching on device does not work for all DataTypes as it
2566   // relies on the implementation of the _Arg and _Retval kernels which
2567   // are not registered for some types or consume/produce inputs/outputs
2568   // in host memory for some types.
2569   //
2570   // Run through all datatypes to validate that either:
2571   // (a) MakeCallable fails (because the given type cannot be fed/fetched
2572   //     in device memory),
2573   //     OR
2574   // (b) Succeeds: RunCallable should gladly accept inputs in device memory
2575   //     and produce output tensors in device memory.
2576   for (int i = DataType_MIN; i <= DataType_MAX; ++i) {
2577     if (!DataType_IsValid(i)) continue;
2578     const DataType dtype = static_cast<DataType>(i);
2579     switch (dtype) {
2580       case DT_INVALID:
2581         break;
2582       case DT_BFLOAT16:
2583       case DT_BOOL:
2584       case DT_COMPLEX128:
2585       case DT_COMPLEX64:
2586       case DT_DOUBLE:
2587       case DT_FLOAT:
2588       case DT_HALF:
2589       case DT_INT16:
2590       case DT_INT64:
2591       case DT_INT8:
2592       case DT_UINT16:
2593       case DT_UINT8:
2594         TestFeedAndFetchTensorsInDeviceMemory(opts, dtype);
2595         break;
2596       default:
2597         // Ignore all REF types since Tensors of this type aren't intended to
2598         // be fed (and attempting to create one via the Tensor constructor
2599         // will result in a LOG(FATAL)).
2600         if (!IsRefType(dtype)) {
2601           TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(opts, dtype);
2602         }
2603         break;
2604     }
2605   }
2606 }
2607 
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory_AllDataTypes) {
  // Strict placement: device feed/fetch must behave correctly even when ops
  // cannot be silently moved to another device.
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(false);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2613 
TEST(DirectSessionTest,
     FeedAndFetchTensorsInDeviceMemory_AllDataTypes_SoftPlacement) {
  // Same sweep as above, but with soft placement enabled.
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(true);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2620 
2621 // A simple benchmark for the overhead of `DirectSession::Run()` calls
2622 // with varying numbers of feeds/fetches.
// Benchmarks `num_feeds` independent Placeholder->Identity pairs, pinned to
// "/cpu:0" so that feed/fetch overhead — not CPU<->GPU copying — is what is
// measured.
//   - `use_make_callable`: time RunCallable instead of Session::Run.
//   - `inter_op_threads`: forwarded to the session config.
//   - `use_single_threaded_executor`: select the SINGLE_THREADED_EXECUTOR.
void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds,
                              bool use_make_callable, int inter_op_threads,
                              bool use_single_threaded_executor) {
  // All feeds reuse this single scalar value.
  Tensor value(DT_FLOAT, TensorShape());
  value.flat<float>()(0) = 37.0;

  std::vector<std::pair<string, Tensor>> inputs;
  inputs.reserve(num_feeds);
  std::vector<string> outputs;

  Graph g(OpRegistry::Global());
  for (int i = 0; i < num_feeds; ++i) {
    // NOTE(mrry): We pin nodes to the "/cpu:0" device, so as not to
    // measure CPU<->GPU copying overhead. We should also optimize and
    // monitor this overhead where possible, but that is not the
    // object of study in this benchmark.
    Node* placeholder;
    TF_CHECK_OK(NodeBuilder(g.NewName("Placeholder"), "Placeholder")
                    .Attr("shape", TensorShape())
                    .Attr("dtype", DT_FLOAT)
                    .Device("/cpu:0")
                    .Finalize(&g, &placeholder));
    Node* identity;
    TF_CHECK_OK(NodeBuilder(g.NewName("Identity"), "Identity")
                    .Input(placeholder)
                    .Attr("T", DT_FLOAT)
                    .Device("/cpu:0")
                    .Finalize(&g, &identity));
    inputs.push_back({placeholder->name() + ":0", value});
    outputs.push_back(identity->name() + ":0");
  }
  GraphDef gd;
  g.ToGraphDef(&gd);
  SessionOptions opts;
  opts.config.set_inter_op_parallelism_threads(inter_op_threads);
  if (use_single_threaded_executor) {
    opts.config.mutable_experimental()->set_executor_type(
        "SINGLE_THREADED_EXECUTOR");
  }
  std::unique_ptr<Session> session(NewSession(opts));
  TF_CHECK_OK(session->Create(gd));
  if (use_make_callable) {
    Session::CallableHandle handle;
    CallableOptions callable_options;
    std::vector<Tensor> input_tensors;
    for (const auto& input : inputs) {
      callable_options.add_feed(input.first);
      input_tensors.push_back(input.second);
    }
    for (const string& output : outputs) {
      callable_options.add_fetch(output);
    }
    TF_CHECK_OK(session->MakeCallable(callable_options, &handle));

    // MakeCallable already did the per-subgraph setup, so the timed loop
    // measures only the per-RunCallable overhead.
    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(
          session->RunCallable(handle, input_tensors, &output_values, nullptr));
    }
  } else {
    {
      // NOTE(mrry): Ignore the first run, which will incur the graph
      // partitioning/pruning overhead and skew the results.
      //
      // Note that we should also optimize and monitor the overhead on
      // the first run, which will impact application startup times, but
      // that is not the object of study in this benchmark.
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }

    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }
  }
}
2700 
BM_FeedFetch(::testing::benchmark::State & state)2701 void BM_FeedFetch(::testing::benchmark::State& state) {
2702   const int num_feeds = state.range(0);
2703 
2704   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false,
2705                            /* inter_op_threads */ 0,
2706                            /* use_single_threaded_executor */ false);
2707 }
BM_FeedFetchCallable(::testing::benchmark::State & state)2708 void BM_FeedFetchCallable(::testing::benchmark::State& state) {
2709   const int num_feeds = state.range(0);
2710 
2711   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2712                            /* inter_op_threads */ 0,
2713                            /* use_single_threaded_executor */ false);
2714 }
BM_FeedFetchCallableSingleThread(::testing::benchmark::State & state)2715 void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) {
2716   const int num_feeds = state.range(0);
2717 
2718   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2719                            /* inter_op_threads */ -1,
2720                            /* use_single_threaded_executor */ false);
2721 }
BM_FeedFetchCallableSingleThreadExecutor(::testing::benchmark::State & state)2722 void BM_FeedFetchCallableSingleThreadExecutor(
2723     ::testing::benchmark::State& state) {
2724   const int num_feeds = state.range(0);
2725 
2726   FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2727                            /* inter_op_threads */ -1,
2728                            /* use_single_threaded_executor */ true);
2729 }
2730 
// Register each benchmark with 1, 2, 5, and 10 feed/fetch pairs.
BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThread)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThreadExecutor)
    ->Arg(1)
    ->Arg(2)
    ->Arg(5)
    ->Arg(10);
2739 
2740 }  // namespace
2741 
// Exercises DirectSession's computation of the collective graph key when
// CollectiveReduce ops live inside library functions. Reads
// DirectSession::collective_graph_key_ (and its lock) directly, so this
// fixture sits outside the anonymous namespace above.
class DirectSessionCollectiveTest : public ::testing::Test {
 public:
  // Creates a graph with CollectiveOps inside functions and runs it.  Returns
  // the generated collective_graph_key.
  Status RunGraphWithCollectiveFunctions(bool add_unused_function,
                                         int64_t* collective_graph_key) {
    GraphDef g = CreateGraph(add_unused_function);
    const Tensor t1 =
        test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
    const Tensor t2 =
        test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
    auto session = CreateSession();
    TF_RETURN_IF_ERROR(session->Create(g));
    std::vector<Tensor> outputs;
    // Run both collective calls (one per CPU device) in a single step.
    TF_RETURN_IF_ERROR(
        session->Run({{"input0:0", t1}, {"input1:0", t2}}, {},
                     {"collective_call0:0", "collective_call1:0"}, &outputs));
    DirectSession* direct_session = static_cast<DirectSession*>(session.get());
    {
      // Read the key that the run above computed, under the session's lock.
      mutex_lock l(direct_session->collective_graph_key_lock_);
      *collective_graph_key = direct_session->collective_graph_key_;
    }
    return OkStatus();
  }

 private:
  // Creates a function with name `function_name` and a single CollectiveReduce
  // node with instance key set as `instance_key`.
  FunctionDef CollectiveFunction(const string& function_name,
                                 int instance_key) {
    return FunctionDefHelper::Define(
        // Function name
        function_name,
        // In def
        {"arg:float"},
        // Out def
        {"reduce:float"},
        // Attr def
        {},
        // Node def
        {{
            {"reduce"},
            "CollectiveReduce",
            {"arg"},
            {{"group_size", 2},
             {"group_key", 1},
             {"instance_key", instance_key},
             {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
             {"merge_op", "Add"},
             {"final_op", "Div"},
             {"T", DT_FLOAT}},
        }});
  }

  // Returns a float Placeholder NodeDef named "input<id>".
  NodeDef Input(int id) {
    AttrValue dtype_attr;
    SetAttrValue(DT_FLOAT, &dtype_attr);
    NodeDef input;
    input.set_name(strings::StrCat("input", id));
    input.set_op("Placeholder");
    input.mutable_attr()->insert({"dtype", dtype_attr});
    return input;
  }

  // Returns a NodeDef "collective_call<cpu_id>" that invokes function `op`
  // on `input`, placed on the given CPU device.
  NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) {
    NodeDef collective_call;
    collective_call.set_name(strings::StrCat("collective_call", cpu_id));
    collective_call.set_op(op);
    collective_call.add_input(input);
    collective_call.set_device(
        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id));
    return collective_call;
  }

  // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
  // CPU1, with instance_key 1, and appropriate placeholder inputs.  If
  // `add_unused_function` is true, adds another CollectiveFunction with
  // instance_key 2 that is not invoked in the graph.
  GraphDef CreateGraph(bool add_unused_function) {
    GraphDef g;
    FunctionDef collective_function =
        CollectiveFunction("CollectiveFunction1", 1);
    FunctionDefLibrary* lib = g.mutable_library();
    *lib->add_function() = collective_function;
    if (add_unused_function) {
      FunctionDef unused_function =
          CollectiveFunction("CollectiveFunction2", 2);
      *lib->add_function() = unused_function;
    }

    *g.add_node() = Input(0);
    *g.add_node() = Input(1);
    // CollectiveReduce on CPU0 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0);
    // CollectiveReduce on CPU1 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1);

    return g;
  }
};
2842 
TEST_F(DirectSessionCollectiveTest,
       TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
  // Adding an uncalled collective function to the library must not change
  // the collective graph key.
  int64_t key_without_unused;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key_without_unused));
  int64_t key_with_unused;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key_with_unused));
  ASSERT_EQ(key_without_unused, key_with_unused);
}
2851 
// An op that reports, on each of its required outputs, the number of its
// outputs that are required in the current subgraph.
2854 class StatefulOutputRequiredOp : public OpKernel {
2855  public:
StatefulOutputRequiredOp(OpKernelConstruction * ctx)2856   explicit StatefulOutputRequiredOp(OpKernelConstruction* ctx)
2857       : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)2858   void Compute(OpKernelContext* ctx) override {
2859     // The op counts the number of outputs required in the current subgraph,
2860     // and emits that number on each of its required outputs.
2861     Tensor count_outputs_required_t(int64_t{0});
2862     int64_t& count_outputs_required =
2863         count_outputs_required_t.scalar<int64_t>()();
2864     for (int i = 0; i < num_outputs(); ++i) {
2865       if (ctx->output_required(i)) ++count_outputs_required;
2866     }
2867     for (int i = 0; i < num_outputs(); ++i) {
2868       if (ctx->output_required(i)) ctx->set_output(i, count_outputs_required_t);
2869     }
2870   }
2871 };
2872 
// Register the op with `num_outs` (default 5) int64 outputs. It is marked
// stateful so that a single kernel instance is shared across the subgraphs
// created for different fetch sets.
REGISTER_KERNEL_BUILDER(Name("StatefulOutputRequired").Device(DEVICE_CPU),
                        StatefulOutputRequiredOp);
REGISTER_OP("StatefulOutputRequired")
    .Output("results : num_outs * int64")
    .Attr("num_outs : int = 5")
    .SetIsStateful();
2879 
TEST(DirectSessionTest, TestStatefulOutputRequiredOp) {
  // Build a graph containing a single StatefulOutputRequired node with the
  // default five outputs.
  GraphDef graph;
  protobuf::TextFormat::ParseFromString(
      R"pb(
        node { name: 'n' op: 'StatefulOutputRequired' device: '/device:CPU:0' }
        versions { producer: 9 }
      )pb",
      &graph);

  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(std::move(graph)));

  // As a stateful op, one StatefulOutputRequired kernel is created and shared
  // across subgraphs. Fetch successively longer prefixes of the op's outputs
  // and verify that each fetched value equals the number of outputs required.
  for (int num_required = 1; num_required <= 5; ++num_required) {
    std::vector<string> fetches;
    fetches.reserve(num_required);
    for (int output_idx = 0; output_idx < num_required; ++output_idx) {
      fetches.push_back(strings::StrCat("n:", output_idx));
    }
    std::vector<Tensor> results;
    TF_ASSERT_OK(session->Run({}, fetches, {}, &results));
    ASSERT_EQ(num_required, results.size());
    for (const Tensor& result : results) {
      ASSERT_EQ(num_required, result.scalar<int64_t>()());
    }
  }

  TF_ASSERT_OK(session->Close());
}
2914 
2915 }  // namespace tensorflow
2916