1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/direct_session.h"
17
18 #include <map>
19 #include <memory>
20 #include <random>
21 #include <string>
22 #include <thread> // NOLINT
23 #include <unordered_map>
24 #include <vector>
25
26 #include "absl/memory/memory.h"
27 #include "absl/strings/match.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/function_testlib.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_testutil.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/graph/costmodel.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/node_builder.h"
40 #include "tensorflow/core/graph/testlib.h"
41 #include "tensorflow/core/kernels/ops_util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/core/threadpool.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/stacktrace.h"
49 #include "tensorflow/core/platform/test.h"
50 #include "tensorflow/core/platform/test_benchmark.h"
51 #include "tensorflow/core/protobuf/error_codes.pb.h"
52 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
53 #include "tensorflow/core/public/session.h"
54 #include "tensorflow/core/public/session_options.h"
55 #include "tensorflow/core/util/device_name_utils.h"
56
57 #if GOOGLE_CUDA
58 #include "third_party/gpus/cuda/include/cuda.h"
59 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
60 #elif TENSORFLOW_USE_ROCM
61 #include "rocm/include/hip/hip_runtime.h"
62 #endif // GOOGLE_CUDA
63
64 namespace tensorflow {
65 namespace {
66
MakeCallableOptions(gtl::ArraySlice<string> feeds,gtl::ArraySlice<string> fetches,gtl::ArraySlice<string> targets)67 CallableOptions MakeCallableOptions(gtl::ArraySlice<string> feeds,
68 gtl::ArraySlice<string> fetches,
69 gtl::ArraySlice<string> targets) {
70 CallableOptions ret;
71 for (const string& feed : feeds) {
72 ret.add_feed(feed);
73 }
74 for (const string& fetch : fetches) {
75 ret.add_fetch(fetch);
76 }
77 for (const string& target : targets) {
78 ret.add_target(target);
79 }
80 return ret;
81 }
82
DefaultSessionOptions()83 SessionOptions DefaultSessionOptions() {
84 SessionOptions options;
85 (*options.config.mutable_device_count())["CPU"] = 2;
86 return options;
87 }
88
CreateSession()89 std::unique_ptr<Session> CreateSession() {
90 return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
91 }
92
93 class DirectSessionMinusAXTest : public ::testing::Test {
94 public:
Initialize(std::initializer_list<float> a_values)95 void Initialize(std::initializer_list<float> a_values) {
96 Graph graph(OpRegistry::Global());
97
98 Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
99 test::FillValues<float>(&a_tensor, a_values);
100 Node* a = test::graph::Constant(&graph, a_tensor);
101 a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
102 a_ = a->name();
103
104 Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
105 test::FillValues<float>(&x_tensor, {1, 1});
106 Node* x = test::graph::Constant(&graph, x_tensor);
107 x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
108 x_ = x->name();
109
110 // y = A * x
111 Node* y = test::graph::Matmul(&graph, a, x, false, false);
112 y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
113 y_ = y->name();
114
115 Node* y_neg = test::graph::Unary(&graph, "Neg", y);
116 y_neg_ = y_neg->name();
117 y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
118
119 Node* z = test::graph::Unary(&graph, "Identity", y_neg);
120 z_ = z->name();
121 z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
122
123 graph.ToGraphDef(&def_);
124 }
125
126 string a_;
127 string x_;
128 string y_;
129 string y_neg_;
130 string z_;
131 GraphDef def_;
132 };
133
// Basic Run(): fetch y and additionally execute the non-fetched target y_neg.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(inputs, output_names, target_nodes, &outputs));

  // Exactly one fetched tensor, initialized and holding the expected value:
  // y(0,0) = 3*1 + 2*1 = 5.
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  auto mat = outputs[0].matrix<float>();
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));
}
155
// Exercises the MakeCallable/RunCallable/ReleaseCallable lifecycle, including
// the error paths for a missing fetch vector, a released handle, and an
// unknown handle.
//
// Fix: the inner loop variable previously shadowed the outer loop's `i`;
// the loops now use distinct names.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Run the test twice to ensure that the Make/Run/Release cycle is hermetic.
  for (int cycle = 0; cycle < 2; ++cycle) {
    // Request two targets: one fetch output and one non-fetched output.
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(
        MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

    for (int run = 0; run < 2; ++run) {
      std::vector<Tensor> outputs;
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));

      ASSERT_EQ(1, outputs.size());
      // The first output should be initialized and have the correct
      // output.
      auto mat = outputs[0].matrix<float>();
      ASSERT_TRUE(outputs[0].IsInitialized());
      EXPECT_FLOAT_EQ(5.0, mat(0, 0));
    }

    // Passing a null fetch-tensor vector is an invalid argument.
    Status s = session->RunCallable(handle, {}, nullptr, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "`fetch_tensors` must be provided"));

    TF_ASSERT_OK(session->ReleaseCallable(handle));

    // Running a released handle fails.
    std::vector<Tensor> outputs;
    s = session->RunCallable(handle, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(
        s.error_message(),
        "Attempted to run callable after handle was released"));

    // Running a handle that was never created fails.
    s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "No such callable handle"));
  }
}
201
// With optimize_for_static_graph set, runs succeed normally but any later
// attempt to Extend() the graph is rejected.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_OptimizeForStaticGraph) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_optimize_for_static_graph(true);
  auto session = absl::WrapUnique(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(inputs, {y_ + ":0"}, {y_neg_}, &outputs));

  // The single fetched tensor is initialized and correct: y(0,0) = 5.
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, outputs[0].matrix<float>()(0, 0));

  // Extending a static-graph session must fail.
  Status s = session->Extend({});
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "optimize_for_static_graph"));
}
231
// With disable_output_partition_graphs set, ordinary runs succeed but a run
// whose RunOptions requests partition graphs is rejected.
TEST_F(DirectSessionMinusAXTest,
       RunSimpleNetwork_DisableOutputPartitionGraphs) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_disable_output_partition_graphs(
      true);
  auto session = absl::WrapUnique(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(inputs, output_names, target_nodes, &outputs));

  // The single fetched tensor is initialized and correct: y(0,0) = 5.
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, outputs[0].matrix<float>()(0, 0));

  // The Run() call should fail when `output_partition_graphs` is set to true.
  RunOptions run_options;
  run_options.set_output_partition_graphs(true);
  RunMetadata run_metadata;
  Status s = session->Run(run_options, inputs, output_names, target_nodes,
                          &outputs, &run_metadata);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "disable_output_partition_graphs"));
}
269
// A callable made before Finalize() remains usable afterwards, but a new
// callable cannot be created on a finalized session.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithCallables) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // The callable is usable after finalization.
  for (int iteration = 0; iteration < 2; ++iteration) {
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));

    // The single fetched tensor is initialized and correct: y(0,0) = 5.
    ASSERT_EQ(1, outputs.size());
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(5.0, outputs[0].matrix<float>()(0, 0));
  }

  TF_ASSERT_OK(session->ReleaseCallable(handle));

  // Making a new callable fails because the session has been finalized.
  Status s =
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle);
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Session has been finalized."));
}
306
// After Finalize(), re-running an already-built subgraph succeeds, but
// running a different subgraph fails.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithRun) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &outputs));

  // The single fetched tensor is initialized and correct: y(0,0) = 5.
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  auto mat = outputs[0].matrix<float>();
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // Running the exact same subgraph succeeds after finalization.
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &outputs));
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  mat = outputs[0].matrix<float>();
  EXPECT_FLOAT_EQ(5.0, mat(0, 0));

  // Running a different subgraph fails because the session has been finalized.
  Status s = session->Run({}, {y_ + ":0"}, {}, &outputs);
  EXPECT_TRUE(errors::IsFailedPrecondition(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Session has been finalized."));
}
339
// Exercises CallableOptions.tensor_connection(): the success paths (rewiring
// one tensor's output to stand in for another) and every error path
// (cycles, unknown nodes/edges, and double-fed tensors).
TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);".
    CallableOptions callable_options;
    TensorConnection* conn = callable_options.add_tensor_connection();
    conn->set_from_tensor(a_ + ":0");
    conn->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    ASSERT_EQ(1, outputs.size());
    ASSERT_TRUE(outputs[0].IsInitialized());
    // Fetched value is -A elementwise.
    auto neg_a = outputs[0].matrix<float>();
    EXPECT_FLOAT_EQ(-3.0, neg_a(0, 0));
    EXPECT_FLOAT_EQ(-2.0, neg_a(0, 1));
    EXPECT_FLOAT_EQ(1.0, neg_a(1, 0));
    EXPECT_FLOAT_EQ(0.0, neg_a(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);"; also fetch the result of a.
    CallableOptions callable_options;
    TensorConnection* conn = callable_options.add_tensor_connection();
    conn->set_from_tensor(a_ + ":0");
    conn->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(a_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    ASSERT_EQ(2, outputs.size());

    // First fetch: A itself.
    ASSERT_TRUE(outputs[0].IsInitialized());
    auto a_fetch = outputs[0].matrix<float>();
    EXPECT_FLOAT_EQ(3.0, a_fetch(0, 0));
    EXPECT_FLOAT_EQ(2.0, a_fetch(0, 1));
    EXPECT_FLOAT_EQ(-1.0, a_fetch(1, 0));
    EXPECT_FLOAT_EQ(0.0, a_fetch(1, 1));

    // Second fetch: -A via the rewired Neg.
    ASSERT_TRUE(outputs[1].IsInitialized());
    auto neg_fetch = outputs[1].matrix<float>();
    EXPECT_FLOAT_EQ(-3.0, neg_fetch(0, 0));
    EXPECT_FLOAT_EQ(-2.0, neg_fetch(0, 1));
    EXPECT_FLOAT_EQ(1.0, neg_fetch(1, 0));
    EXPECT_FLOAT_EQ(0.0, neg_fetch(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Wire the output of "Neg(Matmul(a, x))" to the output of "a",
    // creating an invalid cycle.
    CallableOptions callable_options;
    TensorConnection* conn = callable_options.add_tensor_connection();
    conn->set_from_tensor(y_ + ":0");
    conn->set_to_tensor(a_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "would create a cycle"));
  }

  {
    // Attempt to wire a non-existent node to a node that does exist.
    CallableOptions callable_options;
    TensorConnection* conn = callable_options.add_tensor_connection();
    conn->set_from_tensor("unknown_node:0");
    conn->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown node"));
  }

  {
    // Attempt to wire a non-existent output from a node that does
    // exist to another node.
    CallableOptions callable_options;
    TensorConnection* conn = callable_options.add_tensor_connection();
    conn->set_from_tensor(a_ + ":17");
    conn->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown edge"));
  }

  {
    // Attempt to wire a tensor to a node that doesn't exist.
    CallableOptions callable_options;
    TensorConnection* conn = callable_options.add_tensor_connection();
    conn->set_from_tensor(a_ + ":0");
    conn->set_to_tensor("unknown_node:0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsNotFound(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "unable to find feed output"));
  }

  {
    // Attempt to wire two tensors to the same tensor.
    CallableOptions callable_options;
    TensorConnection* first_conn = callable_options.add_tensor_connection();
    first_conn->set_from_tensor(a_ + ":0");
    first_conn->set_to_tensor(y_neg_ + ":0");
    TensorConnection* second_conn = callable_options.add_tensor_connection();
    second_conn->set_from_tensor(x_ + ":0");
    second_conn->set_to_tensor(y_neg_ + ":0");
    callable_options.add_fetch(z_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }

  {
    // Attempt to wire a tensor to a tensor that is also being fed.
    CallableOptions callable_options;
    TensorConnection* conn = callable_options.add_tensor_connection();
    conn->set_from_tensor(a_ + ":0");
    conn->set_to_tensor(y_ + ":0");
    callable_options.add_feed(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }
}
491
// Feeds a value for x and checks that the fetched y = A * x uses it.
TEST_F(DirectSessionMinusAXTest, TestFeed) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Build the fed tensor [5, 6]^T.
  //
  // Note that the input being fed is on the second device.
  Tensor t(DT_FLOAT, TensorShape({2, 1}));
  t.matrix<float>()(0, 0) = 5;
  t.matrix<float>()(1, 0) = 6;
  std::vector<std::pair<string, Tensor>> inputs = {{x_, t}};
  std::vector<Tensor> outputs;

  // Run the graph with x fed and y fetched.
  TF_ASSERT_OK(session->Run(inputs, {y_ + ":0"}, {}, &outputs));

  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, mat(0, 0));
  EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
520
// Same as TestFeed, but via the MakeCallable/RunCallable API.
//
// Fix: the original built a local `CallableOptions callable_options` (with
// add_feed/add_fetch) that was never passed to MakeCallable — the actual call
// used MakeCallableOptions(). The dead local has been removed.
TEST_F(DirectSessionMinusAXTest, TestFeed_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Make a callable that feeds x and fetches y.
  //
  // Note that the input being fed is on the second device.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(MakeCallableOptions({x_}, {y_ + ":0"}, {}),
                                     &handle));

  // Build the fed tensor [5, 6]^T.
  Tensor t(DT_FLOAT, TensorShape({2, 1}));
  t.matrix<float>()(0, 0) = 5;
  t.matrix<float>()(1, 0) = 6;
  std::vector<Tensor> inputs = {t};
  std::vector<Tensor> outputs;

  // Run the callable
  TF_ASSERT_OK(session->RunCallable(handle, inputs, &outputs, nullptr));

  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, mat(0, 0));
  EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
553
// Runs the same graph 1000 times from each of 4 pool threads concurrently and
// checks every result.
//
// Fix: the thread pool was managed with raw new/delete; it is now held in a
// std::unique_ptr (RAII) — resetting the pointer destroys the pool, which
// waits for all scheduled work to finish, exactly as `delete tp` did.
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  auto tp = absl::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      // y(0,0) = 1*1 + 2*1 = 3.
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Destroying the pool waits for the functions to finish.
  tp.reset();
}
585
// Runs the same callable 1000 times from each of 4 pool threads concurrently.
//
// Fix: the thread pool was managed with raw new/delete; it is now held in a
// std::unique_ptr (RAII) — resetting the pointer destroys the pool, which
// waits for all scheduled work to finish, exactly as `delete tp` did.
TEST_F(DirectSessionMinusAXTest, TestConcurrency_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  auto tp = absl::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  Session::CallableHandle handle;
  TF_ASSERT_OK(
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle));

  // Run the callable 1000 times in 4 different threads concurrently.
  auto fn = [&session, handle]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<Tensor> outputs;
      // Run the graph
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      // y(0,0) = 1*1 + 2*1 = 3.
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Destroying the pool waits for the functions to finish.
  tp.reset();
}
618
// Same concurrency test as TestConcurrency, but with per-session thread
// pools enabled via ConfigProto.use_per_session_threads.
//
// Fix: the thread pool was managed with raw new/delete; it is now held in a
// std::unique_ptr (RAII) — resetting the pointer destroys the pool, which
// waits for all scheduled work to finish, exactly as `delete tp` did.
TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
  Initialize({1, 2, 3, 4});

  SessionOptions options;
  options.config.set_use_per_session_threads(true);
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  auto tp = absl::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      // y(0,0) = 1*1 + 2*1 = 3.
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Destroying the pool waits for the functions to finish.
  tp.reset();
}
655
// A second Create() on the same session must fail.
TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  // First Create() succeeds.
  TF_ASSERT_OK(session->Create(def_));

  // Second is not.
  ASSERT_FALSE(session->Create(def_).ok());
}
665
// Run() without a prior Create() must fail.
TEST_F(DirectSessionMinusAXTest, ForgetToCreate) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;
  ASSERT_FALSE(session->Run(inputs, {y_ + ":0"}, {y_neg_}, &outputs).ok());
}
674
// Create() must reject a graph with a node assigned to a device that does not
// exist (cpu:2 when only two CPUs are configured); fixing the placement makes
// the same graph runnable.
TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
  GraphDef def;
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  a_tensor.flat<float>().setRandom();
  Node* a = test::graph::Constant(&graph, a_tensor);
  a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  x_tensor.flat<float>().setRandom();
  Node* x = test::graph::Constant(&graph, x_tensor);
  x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

  // Skip placing y.
  Node* y = test::graph::Matmul(&graph, a, x, false, false);
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");

  graph.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);

  // Should return an error.
  ASSERT_FALSE(session->Create(def).ok());

  // Fix placement and run again
  def.Clear();
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
  graph.ToGraphDef(&def);
  session.reset(NewSession(options));
  TF_ASSERT_OK(session->Create(def));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {y->name() + ":0"}, {}, &outputs));
}
709
// Run() with SOFTWARE_TRACE RunOptions populates per-device step stats in the
// returned RunMetadata (one entry per CPU device).
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::SOFTWARE_TRACE);
  RunMetadata run_metadata;
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  TF_ASSERT_OK(session->Run(run_options, inputs, output_names, target_nodes,
                            &outputs, &run_metadata));

  // The single fetched tensor is initialized and correct: y(0,0) = 5.
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, outputs[0].matrix<float>()(0, 0));

  // Checks RunMetadata is well-formed
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
743
// Same as RunSimpleNetworkWithOpts, but with tracing enabled through the
// callable's embedded RunOptions.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  CallableOptions callable_options =
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_});
  callable_options.mutable_run_options()->set_trace_level(
      RunOptions::SOFTWARE_TRACE);
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));

  RunMetadata run_metadata;
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, &run_metadata));

  // The single fetched tensor is initialized and correct: y(0,0) = 5.
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, outputs[0].matrix<float>()(0, 0));

  // Checks RunMetadata is well-formed
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
775
// Run() with experimental use_run_handler_pool enabled produces the same
// result as a plain run.
TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<string> output_names = {y_ + ":0"};
  std::vector<string> target_nodes = {y_neg_};
  std::vector<Tensor> outputs;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  TF_ASSERT_OK(session->Run(run_options, inputs, output_names, target_nodes,
                            &outputs, nullptr));

  // The single fetched tensor is initialized and correct: y(0,0) = 5.
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, outputs[0].matrix<float>()(0, 0));
}
803
// A Variable assigned in one Run() retains its value in a later Run() on the
// same session.
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
  Graph g(OpRegistry::Global());
  Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
  var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  // Constant of ten 20.0s used to initialize the variable.
  Tensor twenty(DT_FLOAT, TensorShape({10}));
  for (int i = 0; i < 10; ++i) {
    twenty.flat<float>()(i) = 20.0;
  }

  Node* twenty_node = test::graph::Constant(&g, twenty);
  twenty_node->set_assigned_device_name(
      "/job:localhost/replica:0/task:0/cpu:0");

  Node* init = test::graph::Assign(&g, var, twenty_node);
  init->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;

  // Initialize the variable
  TF_ASSERT_OK(session->Run(inputs, {init->name()}, {}, &outputs));

  // Get the variable's data
  TF_ASSERT_OK(session->Run(inputs, {var->name() + ":0"}, {}, &outputs));
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_EQ(20.0, outputs[0].flat<float>()(0));
}
842
// Verifies feed/fetch behavior of Session::Run():
//  - fetching works with no feeds,
//  - fetch order determines output order,
//  - feeds override constants regardless of feed order, and
//  - feeding the same tensor twice is rejected with InvalidArgument.
TEST(DirectSessionTest, MultipleFeedTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding.
  Status s = session->Run(
      {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Same fetches in the opposite order: outputs must follow fetch order.
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: feed order must not matter.
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feed is an error.
  s = session->Run(
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
915
// Callable-based variant of MultipleFeedTest: exercises the same feed/fetch
// combinations through MakeCallable/RunCallable instead of Run(), including
// the rejection of a duplicate feed at MakeCallable time.
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  Session::CallableHandle handle;
  std::vector<Tensor> outputs;

  // Fetch without feeding.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {first_identity->name() + ":0", second_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Same fetches in the opposite order: outputs must follow fetch order.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {second_identity->name() + ":0", first_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), second_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_11, value_22}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: feed order must not change results.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {second_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_22, value_11}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feed fails at MakeCallable.
  Status s = session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
998
// Verifies that a single source tensor may be wired, via tensor_connection,
// into two different destination tensors of the same callable.
TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {1.0, 2.0, 3.0, 4.0});
  Node* a = test::graph::Constant(&graph, a_tensor);

  // Placeholder-like constants whose values will be overridden by the
  // tensor connections below; -1.0 must never appear in the result.
  Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
  test::FillValues<float>(&dummy_tensor, {-1.0});

  Node* left = test::graph::Constant(&graph, dummy_tensor);
  Node* right = test::graph::Constant(&graph, dummy_tensor);

  // y = left + right
  Node* y = test::graph::Add(&graph, left, right);

  GraphDef def;
  graph.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  CallableOptions callable_options;
  // Directly wire the output of node a to the outputs of nodes left
  // and right, making the callable graph into "a + a;".
  TensorConnection* c_left = callable_options.add_tensor_connection();
  c_left->set_from_tensor(a->name() + ":0");
  c_left->set_to_tensor(left->name() + ":0");
  TensorConnection* c_right = callable_options.add_tensor_connection();
  c_right->set_from_tensor(a->name() + ":0");
  c_right->set_to_tensor(right->name() + ":0");

  callable_options.add_fetch(y->name() + ":0");

  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  // Each element is a + a, i.e. double the original constant.
  EXPECT_FLOAT_EQ(2.0, mat(0, 0));
  EXPECT_FLOAT_EQ(4.0, mat(0, 1));
  EXPECT_FLOAT_EQ(6.0, mat(1, 0));
  EXPECT_FLOAT_EQ(8.0, mat(1, 1));
  TF_ASSERT_OK(session->ReleaseCallable(handle));
}
1047
// Verifies that the same tensor can appear more than once in the fetch list
// of a single Run() call, with each fetch producing an initialized output.
TEST(DirectSessionTest, FetchMultipleTimes) {
  Graph g(OpRegistry::Global());
  Tensor seven_tensor(DT_INT32, TensorShape());
  seven_tensor.flat<int32>()(0) = 7;
  Node* seven_node = test::graph::Constant(&g, seven_tensor);

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  const std::vector<std::pair<string, Tensor>> inputs;
  std::vector<Tensor> outputs;

  // Fetch the same node twice in a single step.
  auto seven = seven_node->name();
  Status s = session->Run(inputs, {seven, seven}, {}, &outputs);
  TF_ASSERT_OK(s);

  EXPECT_EQ(2, outputs.size());
  // size_t index avoids the signed/unsigned comparison with outputs.size().
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Tensor& t = outputs[i];
    ASSERT_TRUE(t.IsInitialized()) << i;
    EXPECT_EQ(7, t.flat<int32>()(0)) << i;
  }
}
1075
// Same combinations as MultipleFeedTest, but some Run() calls use
// inter_op_thread_pool = -1 (synchronous execution on the caller's thread);
// results must be identical either way.
TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
  GraphDef def;
  Graph g(OpRegistry::Global());
  RunOptions run_options;
  // -1 selects the caller-thread (synchronous) executor.
  run_options.set_inter_op_thread_pool(-1);

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding (synchronous run).
  Status s = session->Run(
      run_options, {},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Opposite fetch order, default (asynchronous) run.
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feed must fail in sync mode too.
  s = session->Run(
      run_options,
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
1152
// Test op whose kernel reports the session's SessionMetadata (see
// SessionMetadataReaderOp below). Stateful so it is never constant-folded.
REGISTER_OP("SessionMetadataReader")
    .Input("x: int64")
    .Output("y: string")
    .SetIsStateful()
    .Doc(R"doc(SessionMetadataReader returns the session metadata.

x: int64
y: string
)doc");
1162
// Kernel for SessionMetadataReader: emits the executing session's
// SessionMetadata serialized with DebugString(), or "" when the session was
// created without metadata.
class SessionMetadataReaderOp : public OpKernel {
 public:
  explicit SessionMetadataReaderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    Tensor* out_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("y", TensorShape({}), &out_tensor));
    // session_metadata() is null when no metadata was configured; the empty
    // string distinguishes that case for the tests.
    if (ctx->session_metadata() != nullptr) {
      out_tensor->scalar<tstring>()() = ctx->session_metadata()->DebugString();
    } else {
      out_tensor->scalar<tstring>()() = "";
    }
  }
};
// Register the kernel on both CPU and GPU so device placement in the tests
// below never fails for lack of a kernel.
REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_CPU),
                        SessionMetadataReaderOp);
REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_GPU),
                        SessionMetadataReaderOp);
1181
// Wraps SessionMetadataReader in a FunctionDef so the *ViaFunction tests can
// verify that session metadata is also visible inside function bodies.
FunctionDef SessionMetadataReaderOpFn() {
  return FunctionDefHelper::Define(
      // Name
      "SessionMetadataReaderFn",
      // Args
      {"x: int64"},
      // Return values
      {"y: string"},
      // Attr def
      {},
      // Nodes
      {{{"y"}, "SessionMetadataReader", {"x"}, {}}});
}
1195
// A session created without metadata must expose a null session_metadata() to
// kernels, observed here as an empty string from SessionMetadataReader.
TEST(DirectSessionTest, SessionMetadataAbsent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Verify Run() succeeded before reading outputs[0]; the original test left
  // `s` unchecked, so a failed run would read an empty outputs vector.
  TF_ASSERT_OK(s);

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1213
// Same as SessionMetadataAbsent, but the reader op runs inside a function
// body; the absence of metadata must propagate through function execution.
TEST(DirectSessionTest, SessionMetadataAbsentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Verify Run() succeeded before reading outputs[0]; the original test left
  // `s` unchecked, so a failed run would read an empty outputs vector.
  TF_ASSERT_OK(s);

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1235
// A session configured with SessionMetadata must expose it to kernels, and
// (when RunMetadata is requested) echo it back in RunMetadata.
TEST(DirectSessionTest, SessionMetadataPresent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);

  // Run without requesting RunMetadata.
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                     nullptr /*=run_metadata*/);
  // Check the status before touching `outputs`; the original test left both
  // Run() statuses unchecked.
  TF_ASSERT_OK(s);
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());

  // Run with requesting RunMetadata.
  RunMetadata metadata;
  s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                &metadata /*=run_metadata*/);
  TF_ASSERT_OK(s);
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
  // RunMetadata must carry the same session metadata the kernel observed.
  EXPECT_EQ(session_metadata->name(), metadata.session_metadata().name());
  EXPECT_EQ(session_metadata->version(), metadata.session_metadata().version());
}
1275
// Same as SessionMetadataPresent, but the reader op runs inside a function
// body; metadata must be visible through function execution as well.
TEST(DirectSessionTest, SessionMetadataPresentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);

  // Run without requesting RunMetadata.
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                     nullptr /*=run_metadata*/);
  // Check the status before touching `outputs`; the original test left both
  // Run() statuses unchecked.
  TF_ASSERT_OK(s);
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());

  // Run with requesting RunMetadata.
  RunMetadata metadata;
  s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs,
                &metadata /*=run_metadata*/);
  TF_ASSERT_OK(s);
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
  // RunMetadata must carry the same session metadata the kernel observed.
  EXPECT_EQ(session_metadata->name(), metadata.session_metadata().name());
  EXPECT_EQ(session_metadata->version(), metadata.session_metadata().version());
}
1319
// The (name, version) pair in SessionMetadata acts as a process-wide unique
// key: two live sessions may not share it, but the key is reusable once the
// first session is destroyed, and sessions without metadata are unrestricted.
TEST(DirectSessionTest, SessionMetadataKey) {
  auto session_options0 = DefaultSessionOptions();
  auto* session_metadata0 = session_options0.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata0->set_name("name");
  Session* sess0_ptr;
  ASSERT_TRUE(NewSession(session_options0, &sess0_ptr).ok());
  auto sess0 = absl::WrapUnique(sess0_ptr);

  // Trying to use the same metadata (name, version) will cause an error.
  Session* dup_ptr;
  EXPECT_TRUE(
      errors::IsInvalidArgument(NewSession(session_options0, &dup_ptr)));

  // A new (name, version) is fine.
  auto session_options1 = DefaultSessionOptions();
  auto* session_metadata1 = session_options1.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata1->set_name("name");
  session_metadata1->set_version(1);
  Session* sess1_ptr;
  EXPECT_TRUE(NewSession(session_options1, &sess1_ptr).ok());
  auto sess1 = absl::WrapUnique(sess1_ptr);

  // If the previous session, using the same (name, version) is gone, then it's
  // fine.
  sess0 = nullptr;
  EXPECT_TRUE(NewSession(session_options0, &dup_ptr).ok());
  auto dup = absl::WrapUnique(dup_ptr);

  // Sessions without metadata options are always fine.
  auto sess_without_metadata0 = CreateSession();
  EXPECT_NE(sess_without_metadata0, nullptr);
  auto sess_without_metadata1 = CreateSession();
  EXPECT_NE(sess_without_metadata1, nullptr);
}
1356
// SessionMetadata with a negative version is invalid and must make session
// creation fail with InvalidArgument.
TEST(DirectSessionTest, SessionMetadataInvalid) {
  const auto valid_session_options = DefaultSessionOptions();
  Session* sess_ptr;
  ASSERT_TRUE(NewSession(valid_session_options, &sess_ptr).ok());
  auto sess = absl::WrapUnique(sess_ptr);

  auto invalid_session_options = valid_session_options;
  auto* invalid_metadata =
      invalid_session_options.config.mutable_experimental()
          ->mutable_session_metadata();
  invalid_metadata->set_name("name");
  // Version should be >= 0.
  invalid_metadata->set_version(-1);
  Session* error_sess_ptr;
  EXPECT_TRUE(errors::IsInvalidArgument(
      NewSession(invalid_session_options, &error_sess_ptr)));
}
1374
// Test op whose kernel reports which thread executed Compute (see ThreadIDOp
// below); used to verify synchronous execution runs on the caller's thread.
REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc(
ThreadID returns the thread ID that called compute.

x: int64
y: int64
)doc");
1381
// The ThreadID kernel returns the thread ID that executed Compute.
// The value is std::hash of std::this_thread::get_id(), so it can be compared
// against the same hash computed on the test's own thread.
class ThreadIDOp : public OpKernel {
 public:
  explicit ThreadIDOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    Tensor* out_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("y", TensorShape({}), &out_tensor));
    std::hash<std::thread::id> hasher;
    out_tensor->scalar<int64_t>()() =
        static_cast<int64_t>(hasher(std::this_thread::get_id()));
  }
};
REGISTER_KERNEL_BUILDER(Name("ThreadID").Device(DEVICE_CPU), ThreadIDOp);
1396
// With inter_op_thread_pool = -1 in RunOptions, the graph must execute
// synchronously on the calling thread, verified via the ThreadID kernel.
TEST(DirectSessionTest, SessionSyncRun) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Verify Run() succeeded before reading outputs[0]; the original test left
  // `s` unchecked.
  TF_ASSERT_OK(s);

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64_t>(hasher(std::this_thread::get_id())),
            static_cast<int64_t>(outputs[0].scalar<int64_t>()()));
}
1416
1417 REGISTER_OP("ExpensiveNoop").SetIsStateful();
1418
1419 class ExpensiveNoopOp : public OpKernel {
1420 public:
1421 using OpKernel::OpKernel;
IsExpensive()1422 bool IsExpensive() override { return true; }
Compute(OpKernelContext * ctx)1423 void Compute(OpKernelContext* ctx) override {
1424 const string& stack_trace = tensorflow::CurrentStackTrace();
1425 const string process_method = "ExecutorState::Process()";
1426 size_t pos = 0;
1427 int frame_count = 0;
1428 while ((pos = stack_trace.find("ExecutorState::Process()", pos)) !=
1429 string::npos) {
1430 ++frame_count;
1431 ++pos;
1432 }
1433 OP_REQUIRES(ctx, frame_count <= 1,
1434 errors::Internal(
1435 "Recursive call to ExecutorState::Process() detected."));
1436 }
1437 };
1438
1439 REGISTER_KERNEL_BUILDER(Name("ExpensiveNoop").Device(DEVICE_CPU),
1440 ExpensiveNoopOp);
1441
// Runs a small diamond of "expensive" no-op kernels in synchronous mode
// (inter_op_thread_pool = -1); ExpensiveNoopOp itself asserts that the
// executor does not recursively re-enter ExecutorState::Process().
TEST(DirectSessionTest, SessionSyncRun_DeepGraph) {
  Graph g(OpRegistry::Global());

  // Adds an ExpensiveNoop node with control edges from `control_deps`.
  auto make_expensive_noop = [&g](gtl::ArraySlice<Node*> control_deps) {
    Node* ret;
    auto builder = NodeBuilder(g.NewName("N"), "ExpensiveNoop");
    for (Node* control_dep : control_deps) {
      builder = builder.ControlInput(control_dep);
    }
    TF_CHECK_OK(builder.Finalize(&g, &ret));
    return ret;
  };

  // (The original test declared and reserved a `nodes` vector that was never
  // used; it has been removed.)
  Node* base = make_expensive_noop({});

  Node* child_1 = make_expensive_noop({base});
  Node* child_2 = make_expensive_noop({base});

  GraphDef def;
  g.ToGraphDef(&def);

  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);

  EXPECT_TRUE(sess->Run(run_opts, {}, {}, {child_1->name(), child_2->name()},
                        &outputs, nullptr)
                  .ok());
}
1476
// With inter_op_parallelism_threads = -1 in the session config (rather than
// per-run options), every Run() executes synchronously on the caller thread.
TEST(DirectSessionTest, SyncSession) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64_t>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  SessionOptions options;
  options.config.set_inter_op_parallelism_threads(-1);
  std::unique_ptr<Session> sess(NewSession(options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  auto s = sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr);
  // Verify Run() succeeded before reading outputs[0]; the original test left
  // `s` unchecked.
  TF_ASSERT_OK(s);

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64_t>(hasher(std::this_thread::get_id())),
            static_cast<int64_t>(outputs[0].scalar<int64_t>()()));
}
1497
// Test op whose kernel deliberately violates the op contract (see DarthOp).
REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
Darth promises one return value.

x: float
y: float
)doc");
1504
// The DarthOp kernel violates its promise to return one-value.
// Compute allocates no output, so the runtime must detect the missing output
// and fail the step with an Internal error.
class DarthOp : public OpKernel {
 public:
  explicit DarthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {}
};
REGISTER_KERNEL_BUILDER(Name("Darth").Device(DEVICE_CPU), DarthOp);
1512
// Running a kernel that fails to produce its declared output must surface as
// an Internal error from Session::Run().
TEST(DirectSessionTest, DarthKernel) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_FLOAT, TensorShape({}));
  vx.scalar<float>()() = 1.0;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "Darth", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  auto s = sess->Run({}, {y->name() + ":0"}, {}, &outputs);
  EXPECT_TRUE(errors::IsInternal(s));
}
1527
// Have the Darth op in the graph placed on GPU, but don't run it.
// Without place_pruned_graph, the whole graph is placed at Create() time and
// the GPU-assigned node fails placement; with the option set, only the pruned
// subgraph for the requested fetches is placed, so the run succeeds.
TEST(DirectSessionTest, PlacePrunedGraph) {
  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    // By default, we place the entire graph, so we should fail the
    // call to Create.
    SessionOptions options;
    std::unique_ptr<Session> sess(NewSession(options));
    auto s = sess->Create(def);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
  }

  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    SessionOptions options;
    // Set the option to place pruned graphs, we should expect this
    // to run.
    options.config.mutable_graph_options()->set_place_pruned_graph(true);
    std::unique_ptr<Session> sess(NewSession(options));
    TF_ASSERT_OK(sess->Create(def));
    std::vector<Tensor> outputs;
    // Fetch only x, so the GPU-assigned Darth node is pruned away.
    auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs);
    TF_EXPECT_OK(s);
  }
}
1569
// Basic partial-run flow: PRunSetup declares all feeds/fetches up front, then
// successive PRun calls feed and fetch incrementally; values fed in an early
// PRun (value_11) remain available to later fetches (third_identity).
TEST(DirectSessionTest, PartialRunTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  // third = first_identity + second_identity.
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed first_const, fetch first_identity
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));

  // Feed second_const, fetch second_identity and third_identity
  s = session->PRun(
      handle, {{second_const->name(), value_22}},
      {second_identity->name() + ":0", third_identity->name() + ":0"},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(22.0, outputs[0].flat<float>()(0));
  // third_identity sees the value fed in the previous PRun step.
  ASSERT_EQ(11.0 + 22.0, outputs[1].flat<float>()(0));
}
1625
// A PRun fetch that depends on a declared-but-not-yet-supplied feed
// (second_const) must fail with InvalidArgument rather than block or return
// a bogus value.
TEST(DirectSessionTest, PartialRunMissingFeed) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  // third_identity requires BOTH feeds to be computable.
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  string handle;
  Status s = session->PRunSetup({first_const->name(), second_const->name()},
                                {third_identity->name() + ":0"}, {}, &handle);
  TF_ASSERT_OK(s);

  // Feed first_const, fetch third_identity
  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {third_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));
}
1665
// Partial run where the feed is a secondary output of a multi-output node
// (Switch output :1): fetching before the feed is supplied must fail, and
// feeding the declared output must then let the fetch succeed.
TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor bool_value(DT_BOOL, TensorShape({}));
  bool_value.scalar<bool>()() = true;
  Node* bool_const = test::graph::Constant(&g, bool_value);
  Node* switch_node = test::graph::Switch(&g, bool_const, bool_const);
  // Identity on Switch output :1 (the "true" branch).
  Node* fourth_identity = test::graph::Identity(&g, switch_node, 1);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  string handle;
  Status s = session->PRunSetup({switch_node->name() + ":1"},
                                {fourth_identity->name() + ":0"}, {}, &handle);
  TF_ASSERT_OK(s);

  // Fetch fourth_identity without feeds.
  s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));

  // Feed switch_node:1 and fetch fourth_identity.
  s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}},
                    {fourth_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(true, outputs[0].flat<bool>()(0));
}
1702
// Exercises session-state ops: GetSessionHandleV2 stores a tensor in
// per-session state and returns a handle; GetSessionTensor retrieves it via
// the (string) handle; DeleteSessionTensor removes it.
TEST(DirectSessionTest, RunHandleTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  // identity0 = Identity(1.0f)
  Tensor value0(DT_FLOAT, TensorShape({}));
  value0.scalar<float>()() = 1.0;
  Node* const0 = test::graph::Constant(&g, value0);
  Node* identity0 = test::graph::Identity(&g, const0);

  // node4 = GetSessionHandleV2(identity0 + 2.0f): stashes 3.0f in session
  // state and emits a resource handle for it.
  Tensor value1(DT_FLOAT, TensorShape({}));
  value1.scalar<float>()() = 2.0;
  Node* const1 = test::graph::Constant(&g, value1);
  Node* node3 = test::graph::Add(&g, identity0, const1);
  Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);

  // const2 is a scalar-string constant used as a feed slot for the handle;
  // its (empty) value is replaced at Run() time.
  Tensor value2(DT_STRING, TensorShape({}));
  Node* const2 = test::graph::Constant(&g, value2);
  // node6 = GetSessionTensor(handle) + 2.0f
  Node* node5 = test::graph::GetSessionTensor(&g, const2);
  Node* node6 = test::graph::Add(&g, node5, const1);

  // node7 deletes the stored tensor named by the fed handle.
  Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // First run call: Create a handle.
  std::vector<Tensor> outputs;
  Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs.size());

  // Convert the returned ResourceHandle into the string form that
  // GetSessionTensor/DeleteSessionTensor expect as input.
  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Second run call: Use a handle.
  std::vector<Tensor> outputs1;
  s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
                   {}, &outputs1);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs1.size());
  // (1.0 + 2.0) stored in state, plus 2.0 again after retrieval.
  ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));

  // Third run call: Delete a handle.
  std::vector<Tensor> outputs2;
  s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
                   &outputs2);
  ASSERT_TRUE(s.ok());
}
1755
// Same session-state scenario as RunHandleTest above.
// NOTE(review): despite the "_Callable" suffix, this body uses Run() rather
// than MakeCallable()/RunCallable() — confirm against upstream whether the
// callable variant was lost.
TEST(DirectSessionTest, RunHandleTest_Callable) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  // identity0 = Identity(1.0f)
  Tensor value0(DT_FLOAT, TensorShape({}));
  value0.scalar<float>()() = 1.0;
  Node* const0 = test::graph::Constant(&g, value0);
  Node* identity0 = test::graph::Identity(&g, const0);

  // node4 = GetSessionHandleV2(identity0 + 2.0f): stashes 3.0f in session
  // state and emits a resource handle for it.
  Tensor value1(DT_FLOAT, TensorShape({}));
  value1.scalar<float>()() = 2.0;
  Node* const1 = test::graph::Constant(&g, value1);
  Node* node3 = test::graph::Add(&g, identity0, const1);
  Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);

  // const2 is a scalar-string constant acting as the feed slot for the
  // handle string supplied at Run() time.
  Tensor value2(DT_STRING, TensorShape({}));
  Node* const2 = test::graph::Constant(&g, value2);
  Node* node5 = test::graph::GetSessionTensor(&g, const2);
  Node* node6 = test::graph::Add(&g, node5, const1);

  // node7 deletes the stored tensor named by the fed handle.
  Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // First run call: Create a handle.
  std::vector<Tensor> outputs;
  Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs.size());

  // Convert the ResourceHandle to its string form for the state-reading ops.
  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Second run call: Use a handle.
  std::vector<Tensor> outputs1;
  s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
                   {}, &outputs1);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs1.size());
  // (1.0 + 2.0) stored in state, plus 2.0 again after retrieval.
  ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));

  // Third run call: Delete a handle.
  std::vector<Tensor> outputs2;
  s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
                   &outputs2);
  ASSERT_TRUE(s.ok());
}
1808
// Feeding a variable that is also the target of an Assign makes the
// assignment meaningless, so graph setup must reject the Run call.
TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
  Graph graph(OpRegistry::Global());

  Node* a = test::graph::Var(&graph, DT_FLOAT, {});
  Node* b = test::graph::Constant(&graph, {});

  Tensor zero(DT_FLOAT, {});
  test::FillValues<float>(&zero, {0});

  // a = b
  Node* assign = test::graph::Assign(&graph, a, b);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  // The graph is invalid since variable `a` is fed (replaced by a constant)
  // while also being assigned to. The return Status of session->Run should
  // flag this as an invalid argument.
  std::vector<Tensor> outputs;
  Status s = session->Run({{a->name(), zero}}, {assign->name()}, {}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
}
1830
// Verifies both timeout mechanisms — the session-wide
// operation_timeout_in_ms and the per-call RunOptions timeout_in_ms — by
// dequeuing from an empty queue, which would otherwise block forever.
TEST(DirectSessionTest, TimeoutSession) {
  GraphDef graph;
  // Creates a graph with one FIFOQueue and one dequeue op.
  // Note the dequeue op's own timeout_ms is -1 (wait indefinitely), so any
  // deadline must come from the session/run options.
  protobuf::TextFormat::ParseFromString(R"pb(
    node {
      name: 'fifo_queue'
      op: 'FIFOQueue'
      device: '/device:CPU:0'
      attr {
        key: 'capacity'
        value { i: 10 }
      }
      attr {
        key: 'component_types'
        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'container'
        value { s: '' }
      }
      attr {
        key: 'shapes'
        value { list {} }
      }
      attr {
        key: 'shared_name'
        value { s: '' }
      }
    }
    node {
      name: 'fifo_queue_Dequeue'
      op: 'QueueDequeue'
      input: 'fifo_queue'
      device: '/device:CPU:0'
      attr {
        key: 'component_types'
        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'timeout_ms'
        value { i: -1 }
      }
    }
    versions { producer: 9 }
  )pb",
                                        &graph);

  {
    // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
    SessionOptions options;
    (*options.config.mutable_device_count())["CPU"] = 2;
    options.config.set_operation_timeout_in_ms(100);

    std::unique_ptr<Session> session(NewSession(options));
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));

    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
    TF_ASSERT_OK(session->Close());
  }

  {
    // Creates a session with no operation_timeout_in_ms; the deadline comes
    // from RunOptions instead.
    auto session = CreateSession();
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));
    RunOptions run_options;
    run_options.set_timeout_in_ms(20);
    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"},
                             nullptr, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
    TF_ASSERT_OK(session->Close());
  }
}
1908
// Test kernel that spins until the step's CancellationManager reports
// cancellation, then signals `notification`. Used to check that a timed-out
// step is cancelled and that the kernel is still allowed to run to
// completion (clean shutdown) after the deadline fires.
class CancellationMgrPollingOp : public OpKernel {
 public:
  explicit CancellationMgrPollingOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    CancellationManager* cm = ctx->cancellation_manager();
    // Poll in 1ms steps; relies on the step being cancelled externally.
    while (!cm->IsCancelled()) {
      ctx->env()->SleepForMicroseconds(1000);
    }
    notification.Notify();
  }
  // Signalled exactly once, when Compute observes cancellation.
  static Notification notification;
};
Notification CancellationMgrPollingOp::notification;

REGISTER_KERNEL_BUILDER(Name("CancellationMgrPollingOp").Device(DEVICE_CPU),
                        CancellationMgrPollingOp);
REGISTER_OP("CancellationMgrPollingOp").Doc("");
1929
// After a Run() deadline fires, the in-flight op must observe cancellation
// and be allowed to finish cleanly rather than being torn down mid-compute.
TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
  GraphDef graph;
  // Creates a graph with a single op that polls the cancellation manager.
  protobuf::TextFormat::ParseFromString(R"pb(
    node {
      name: 'cm_polling'
      op: 'CancellationMgrPollingOp'
      device: '/device:CPU:0'
    }
    versions { producer: 9 }
  )pb",
                                        &graph);

  // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
  SessionOptions options;
  options.config.set_operation_timeout_in_ms(100);
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(graph));

  // Verifies that the error code is DEADLINE_EXCEEDED.
  Status s = session->Run({}, {}, {"cm_polling"}, nullptr);
  ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());

  // Verify that the op ran to completion.
  ASSERT_TRUE(CancellationMgrPollingOp::notification.HasBeenNotified());

  TF_ASSERT_OK(session->Close());
}
1959
// Shared body for the inter-op thread-pool tests. Builds a graph containing
// a BlockingOp (directly, or wrapped in a function when `use_function_lib`),
// configures two inter-op pools — pool 0 unrestricted ("large"), pool 1
// limited to a single thread ("small") — and verifies that saturating the
// small pool does not prevent work on the large pool or synchronous
// (caller-thread) runs from completing. With `use_global_pools`, the pools
// are named/shared across sessions and each run creates its own session.
static void TestSessionInterOpThreadsImpl(bool use_function_lib,
                                          bool use_global_pools) {
  using test::function::blocking_op_state;
  using test::function::BlockingOpState;

  FunctionDefLibrary library_graph_def;
  if (use_function_lib) {
    *library_graph_def.add_function() = test::function::BlockingOpFn();
  }

  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  Node* y;
  if (use_function_lib) {
    y = test::graph::Unary(&g, "BlockingOpFn", x);
  } else {
    y = test::graph::Unary(&g, "BlockingOp", x);
  }
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;

  // Create session with two inter-op thread pools.
  SessionOptions options;
  // Turn off optimizations so that the blocking op doesn't get invoked during
  // graph setup.
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  options.config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_constant_folding(RewriterConfig::OFF);
  (*options.config.mutable_device_count())["CPU"] = 2;
  (*options.config.mutable_device_count())["GPU"] = 0;

  auto* p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("large pool");
  p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("small pool");
  p->set_num_threads(1);  // The "small" pool has exactly one thread.
  const int kSyncPool = -1;   // Run synchronously on the caller thread.
  const int kLargePool = 0;
  const int kSmallPool = 1;

  std::vector<std::unique_ptr<Session>> sessions;
  if (!use_global_pools) {
    // Non-global pools: one session shared by every run below.
    sessions.emplace_back(NewSession(options));
    TF_ASSERT_OK(sessions.back()->Create(def));
  }
  mutex sessions_mu;  // Guards `sessions` when runs create sessions lazily.

  std::atomic<int32> num_done(0);
  // Runs session to compute <node>:0 using inter_op thread pool <pool>.
  auto add_session_run_call =
      [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done](
          thread::ThreadPool* tp, Node* node, int inter_op_pool) {
        auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu,
                   inter_op_pool, node, &num_done]() {
          RunOptions run_options;
          run_options.set_inter_op_thread_pool(inter_op_pool);
          std::vector<Tensor> outputs;

          Session* session;
          if (use_global_pools) {
            // Global pools: each run gets a fresh session; keep it alive in
            // `sessions` so it outlives the run.
            std::unique_ptr<Session> s(NewSession(options));
            TF_ASSERT_OK(s->Create(def));
            session = s.get();

            mutex_lock l(sessions_mu);
            sessions.emplace_back(std::move(s));
          } else {
            session = sessions[0].get();
          }

          Status s = session->Run(run_options, {} /* inputs */,
                                  {node->name() + ":0"} /* output_names */, {},
                                  &outputs, nullptr /* run_metadata */);
          TF_CHECK_OK(s);
          ASSERT_EQ(1, outputs.size());
          auto flat = outputs[0].flat<float>();
          EXPECT_FLOAT_EQ(1.2, flat(0));
          num_done.fetch_add(1);
        };
        // tp == nullptr means run synchronously on this thread.
        if (tp != nullptr) {
          tp->Schedule(fn);
        } else {
          fn();
        }
      };

  // For blocking states:
  // - Starts at 0, BlockingOp::Compute will move to 1.
  // - This main thread will wait for 1, then move to 2 when other ops are done.
  // Moving to 2 unblocks the blocking op, which then moves to state 3.

  // Run the graph once on the non-limited pool.
  thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 1);
  blocking_op_state = new BlockingOpState();
  add_session_run_call(tp1, y, kLargePool);
  blocking_op_state->AwaitState(1);
  blocking_op_state->MoveToState(1, 2);
  blocking_op_state->AwaitState(3);
  blocking_op_state->MoveToState(3, 0);
  delete tp1;  // Joins the pool, ensuring the first run fully completed.
  num_done = 0;

  tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);

  // Launch a session run call. It will not finish until the blocking op is
  // unblocked, because it is using all threads in the small pool.
  add_session_run_call(tp1, y, kSmallPool);

  blocking_op_state->AwaitState(1);  // Wait for the blocking op to Compute.

  // These will block on <BlockingOpState>.
  const int kBlockedThreads = 3;
  for (int i = 0; i < kBlockedThreads; ++i) {
    add_session_run_call(tp1, x, kSmallPool);
  }

  // Launch session calls using the other inter-op pool. These will finish
  // as they are in inter_op pool #2.
  thread::ThreadPool* tp2 = new thread::ThreadPool(Env::Default(), "tp2", 3);
  const int kUnblockedThreads = 4;
  for (int i = 0; i < kUnblockedThreads; ++i) {
    add_session_run_call(tp2, x, kLargePool);
  }
  delete tp2;  // Joins: all large-pool runs are done at this point.
  EXPECT_EQ(kUnblockedThreads, num_done.load());

  // Launch a session call using this thread. This will finish as it runs
  // synchronously in this thread.
  add_session_run_call(nullptr, x, kSyncPool);

  // Unblock the blocked op and wait for the blocked functions to finish.
  blocking_op_state->MoveToState(1, 2);
  delete tp1;  // Joins tp1, draining the previously-blocked small-pool runs.

  // Total: 4 large-pool + 3 blocked small-pool + 1 blocking run + 1 sync run.
  EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1 + 1, num_done.load());
  delete blocking_op_state;
  blocking_op_state = nullptr;
}
2105
// Session-local pools, blocking op used directly (no function library).
TEST(DirectSessionTest, TestSessionInterOpThreads) {
  const bool use_function_lib = false;
  const bool use_global_pools = false;
  TestSessionInterOpThreadsImpl(use_function_lib, use_global_pools);
}
2110
// Session-local pools, blocking op wrapped in a function-library function.
TEST(DirectSessionTest, TestSessionInterOpThreadsWithFunctions) {
  const bool use_function_lib = true;
  const bool use_global_pools = false;
  TestSessionInterOpThreadsImpl(use_function_lib, use_global_pools);
}
2115
// Globally-named (cross-session) pools, blocking op used directly.
TEST(DirectSessionTest, TestSessionInterOpGlobalPools) {
  const bool use_function_lib = false;
  const bool use_global_pools = true;
  TestSessionInterOpThreadsImpl(use_function_lib, use_global_pools);
}
2120
// Globally-named (cross-session) pools, blocking op wrapped in a function.
TEST(DirectSessionTest, TestSessionInterOpGlobalPoolsWithFunctions) {
  const bool use_function_lib = true;
  const bool use_global_pools = true;
  TestSessionInterOpThreadsImpl(use_function_lib, use_global_pools);
}
2125
// Invalid inter-op pool configurations: out-of-range pool indices in
// RunOptions must fail the run, and re-registering a global pool name with a
// different thread count must fail session creation.
TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  (*options.config.mutable_device_count())["CPU"] = 2;

  // Exactly one configured pool, so only index 0 (and -1 for "run
  // synchronously") are valid at Run() time.
  options.config.add_session_inter_op_thread_pool();

  // Wrong pool number on Run call.
  {
    std::unique_ptr<Session> session(NewSession(options));
    TF_ASSERT_OK(session->Create(def));
    // pool_num takes the two invalid values -2 and 1.
    for (int pool_num = -2; pool_num <= 1; pool_num += 3) {
      RunOptions run_options;
      run_options.set_inter_op_thread_pool(pool_num);
      std::vector<Tensor> outputs;
      Status s = session->Run(run_options, {} /* inputs */,
                              {x->name() + ":0"} /* output_names */, {},
                              &outputs, nullptr /* run_metadata */);
      EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
      EXPECT_TRUE(absl::StrContains(
          s.error_message(),
          strings::StrCat("Invalid inter_op_thread_pool: ", pool_num)));
    }
  }

  // Global name changes thread count.
  std::vector<std::unique_ptr<Session>> sessions;
  auto* pool_config = options.config.mutable_session_inter_op_thread_pool(0);
  pool_config->set_num_threads(0);
  pool_config->set_global_name("foo");
  sessions.emplace_back(NewSession(options));
  TF_ASSERT_OK(sessions.back()->Create(def));
  sessions.emplace_back(NewSession(options));  // repeat creation, okay.
  TF_ASSERT_OK(sessions.back()->Create(def));
  // Any num_threads different from the registered value (0) must be rejected.
  for (int pass = 0; pass < 2; ++pass) {
    for (int i = 1; i < 128; ++i) {
      pool_config->set_num_threads(i);
      sessions.emplace_back(NewSession(options));
      auto status = sessions.back()->Create(def);
      ASSERT_FALSE(status.ok()) << status;
    }

    // Clear existing sessions before second pass; error still happens.
    sessions.clear();
  }
}
2181
// After Close(), both Run() and RunCallable() (with a handle made before the
// close) must fail with CANCELLED / "Session has been closed."
TEST(DirectSessionTest, TestDirectSessionRunClose) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Make a callable handle before closing the session.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {}, {var_assign->name()}), &handle));

  // Close the session.
  TF_ASSERT_OK(session->Close());

  // Run the read on the variable to get an error.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));

  // Run the read as a callable to verify that we get the same error.
  s = session->RunCallable(handle, {}, {}, nullptr);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));
}
2229
// A PRun() using a handle from PRunSetup() must fail with CANCELLED once the
// session has been closed.
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  // third = Identity(1.0f) + Identity(2.0f), with per-input identities so
  // each stage can be fed/fetched independently in a partial run.
  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Close the session.
  TF_ASSERT_OK(session->Close());

  // Feed first_const, fetch first_identity: the handle is now useless.
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));
}
2277
// Reset() on the session's containers closes the session; subsequent Run()
// calls must fail with CANCELLED.
TEST(DirectSessionTest, TestDirectSessionReset) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Reset the containers.
  TF_EXPECT_OK(Reset(options, {}));

  // Run the read on the variable to get an error.
  // TODO(suharshs): This test only works because we close the Session in Reset.
  // If we change the behavior of Reset to not close the Session, this test will
  // fail, since the Variable buffer is cached by var.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ(s.code(), error::CANCELLED);
  EXPECT_TRUE(absl::StrContains(s.error_message(), "Session has been closed."));
}
2318
// A freshly created direct session must expose a non-null DeviceMgr that
// lists at least one local device.
TEST(DirectSessionTest, LocalDeviceManager) {
  std::unique_ptr<Session> session(NewSession(SessionOptions()));

  const DeviceMgr* device_mgr = nullptr;
  TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
  ASSERT_NE(device_mgr, nullptr);
  EXPECT_FALSE(device_mgr->ListDevices().empty());
}
2328
// y = tf.square(x)
// Returns a two-node GraphDef: a float Placeholder `x` of unknown rank
// feeding a Square op `y`. QCHECK-fails (aborts) if the embedded text proto
// does not parse.
GraphDef CreateGraphForYEqualsXSquared() {
  GraphDef graph_def;
  const char* text_proto = R"EOF(
node {
  name: "x"
  op: "Placeholder"
  attr { key: "dtype" value { type: DT_FLOAT } }
  attr { key: "shape" value { shape { unknown_rank: true } } }
}
node {
  name: "y"
  op: "Square"
  input: "x"
  attr { key: "T" value { type: DT_FLOAT } }
}
versions {
  producer: 26
}
  )EOF";

  QCHECK(protobuf::TextFormat::ParseFromString(text_proto, &graph_def));
  return graph_def;
}
2353
// Returns true iff `t`'s buffer lives in GPU device memory, determined by
// querying the CUDA/ROCm runtime for the pointer's memory space. Always
// false in builds without GPU support.
// (NOTE(review): the previous comment here described string-tensor graphs and
// did not match this helper; it appears to have been left over from an
// adjacent function.)
bool IsCUDATensor(const Tensor& t) {
#ifdef GOOGLE_CUDA
  cudaPointerAttributes attributes;
  cudaError_t err =
      cudaPointerGetAttributes(&attributes, t.tensor_data().data());
  // cudaErrorInvalidValue means the pointer is not known to the CUDA
  // runtime, i.e. ordinary host memory.
  if (err == cudaErrorInvalidValue) return false;
  CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  return (attributes.type == cudaMemoryTypeDevice);
#elif TENSORFLOW_USE_ROCM
  hipPointerAttribute_t attributes;
  hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
  if (err == hipErrorInvalidValue) return false;
  CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
  return (attributes.memoryType == hipMemoryTypeDevice);
#else
  return false;
#endif
}
2375
GPUDeviceName(Session * session)2376 string GPUDeviceName(Session* session) {
2377 std::vector<DeviceAttributes> devices;
2378 TF_CHECK_OK(session->ListDevices(&devices));
2379 for (const DeviceAttributes& d : devices) {
2380 if (d.device_type() == "GPU" || d.device_type() == "gpu") {
2381 return d.name();
2382 }
2383 }
2384 return "";
2385 }
2386
// End-to-end check of feeding/fetching float tensors in GPU device memory
// via CallableOptions: first fetch y=x^2 into GPU memory, then feed that
// GPU-resident tensor back in and fetch the result on the host.
// Skipped (passes trivially) when no GPU is available.
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory) {
  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateGraphForYEqualsXSquared()));

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  Tensor gpu_tensor;

  {
    // Feed x on the host, ask for y to be left in GPU memory.
    Session::CallableHandle feed_cpu_fetch_gpu;
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_cpu_fetch_gpu));
    Tensor input(DT_FLOAT, {});
    input.scalar<float>()() = 2.0f;
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(
        session->RunCallable(feed_cpu_fetch_gpu, {input}, &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_cpu_fetch_gpu));
    ASSERT_EQ(1, outputs.size());
    gpu_tensor = outputs[0];
    ASSERT_TRUE(IsCUDATensor(gpu_tensor));
  }

  {
    // Feed the GPU-resident tensor (holding 4.0) back in; fetch on host.
    Session::CallableHandle feed_gpu_fetch_cpu;
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_gpu_fetch_cpu));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(feed_gpu_fetch_cpu, {gpu_tensor},
                                      &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_gpu_fetch_cpu));
    ASSERT_EQ(1, outputs.size());
    // The output is in CPU/host memory, so it can be dereferenced.
    // (2^2)^2 == 16.
    ASSERT_EQ(16.0, outputs[0].scalar<float>()());
  }
}
2433
CreateIdentityGraphDef(DataType dtype)2434 GraphDef CreateIdentityGraphDef(DataType dtype) {
2435 GraphDef def;
2436
2437 AttrValue dtype_attr;
2438 dtype_attr.set_type(dtype);
2439
2440 AttrValue shape_attr;
2441 shape_attr.mutable_shape()->set_unknown_rank(true);
2442
2443 auto* placeholder = def.add_node();
2444 placeholder->set_name("x");
2445 placeholder->set_op("Placeholder");
2446 placeholder->mutable_attr()->insert({"dtype", dtype_attr});
2447 placeholder->mutable_attr()->insert({"shape", shape_attr});
2448
2449 auto* identity = def.add_node();
2450 identity->set_name("y");
2451 identity->set_op("Identity");
2452 identity->add_input("x");
2453 identity->mutable_attr()->insert({"T", dtype_attr});
2454
2455 return def;
2456 }
2457
// For a dtype whose _Arg/_Retval kernels support device memory: round-trips
// a host tensor through an Identity graph, first fetching into GPU memory,
// then feeding that GPU tensor back and comparing the host bytes.
// Skipped (passes trivially) when no GPU is available.
void TestFeedAndFetchTensorsInDeviceMemory(
    const SessionOptions& session_options, DataType dtype) {
  std::unique_ptr<Session> session(NewSession(session_options));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
      << DataType_Name(dtype);

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  Tensor gpu_tensor;
  Tensor host_tensor(dtype, {3});
  {
    // Ask for the fetched tensor to be backed by device memory.
    // Even though the kernel that created the tensor produced it in host
    // memory.
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {host_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size()) << DataType_Name(dtype);
    gpu_tensor = outputs[0];
    ASSERT_TRUE(IsCUDATensor(gpu_tensor)) << DataType_Name(dtype);
  }

  {
    // Feed a tensor backed by device memory, even though the operations in the
    // graph expect it in host memory.
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {gpu_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size());
    // Identity round-trip: the fetched host bytes must match the original.
    const StringPiece actual_data = outputs[0].tensor_data();
    const StringPiece expected_data = host_tensor.tensor_data();
    EXPECT_EQ(expected_data.size(), actual_data.size()) << DataType_Name(dtype);
    EXPECT_EQ(0, memcmp(expected_data.data(), actual_data.data(),
                        std::min(expected_data.size(), actual_data.size())))
        << DataType_Name(dtype);
  }
}
2513
TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(const SessionOptions & session_options,DataType dtype)2514 void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
2515 const SessionOptions& session_options, DataType dtype) {
2516 std::unique_ptr<Session> session(NewSession(session_options));
2517 const string gpu_device_name = GPUDeviceName(session.get());
2518 if (gpu_device_name.empty()) {
2519 LOG(INFO) << "Skipping test since no GPU is available";
2520 return;
2521 }
2522
2523 TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
2524 << DataType_Name(dtype);
2525
2526 CallableOptions opts;
2527 opts.add_feed("x:0");
2528 opts.add_fetch("y:0");
2529
2530 // Fail when asking to fetch into GPU memory.
2531 {
2532 opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
2533 opts.set_fetch_skip_sync(true);
2534 Session::CallableHandle handle;
2535 Status status = session->MakeCallable(opts, &handle);
2536 EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
2537 EXPECT_TRUE(absl::StrContains(
2538 status.error_message(),
2539 strings::StrCat(
2540 "Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
2541 " as feeding/fetching from GPU devices is not yet supported for ",
2542 DataTypeString(dtype), " tensors")))
2543 << DataType_Name(dtype) << ", Status: " << status;
2544 }
2545
2546 // Fail when feeding from GPU memory.
2547 {
2548 opts.clear_feed_devices();
2549 opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
2550 Session::CallableHandle handle;
2551 Status status = session->MakeCallable(opts, &handle);
2552 EXPECT_FALSE(status.ok());
2553 EXPECT_TRUE(absl::StrContains(
2554 status.error_message(),
2555 strings::StrCat(
2556 "Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
2557 " as feeding/fetching from GPU devices is not yet supported for ",
2558 DataTypeString(dtype), " tensors")))
2559 << DataType_Name(dtype) << ", Status: " << status;
2560 }
2561 }
2562
TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(const SessionOptions & opts)2563 void TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(
2564 const SessionOptions& opts) {
2565 // Feeding/fetching on device does not work for all DataTypes as it
2566 // relies on the implementation of the _Arg and _Retval kernels which
2567 // are not registered for some types or consume/produce inputs/outputs
2568 // in host memory for some types.
2569 //
2570 // Run through all datatypes to validate that either:
2571 // (a) MakeCallable fails (because the given type cannot be fed/fetched
2572 // in device memory),
2573 // OR
2574 // (b) Succeeds: RunCallable should gladly accept inputs in device memory
2575 // and produce output tensors in device memory.
2576 for (int i = DataType_MIN; i <= DataType_MAX; ++i) {
2577 if (!DataType_IsValid(i)) continue;
2578 const DataType dtype = static_cast<DataType>(i);
2579 switch (dtype) {
2580 case DT_INVALID:
2581 break;
2582 case DT_BFLOAT16:
2583 case DT_BOOL:
2584 case DT_COMPLEX128:
2585 case DT_COMPLEX64:
2586 case DT_DOUBLE:
2587 case DT_FLOAT:
2588 case DT_HALF:
2589 case DT_INT16:
2590 case DT_INT64:
2591 case DT_INT8:
2592 case DT_UINT16:
2593 case DT_UINT8:
2594 TestFeedAndFetchTensorsInDeviceMemory(opts, dtype);
2595 break;
2596 default:
2597 // Ignore all REF types since Tensors of this type aren't intended to
2598 // be fed (and attempting to create one via the Tensor constructor
2599 // will result in a LOG(FATAL)).
2600 if (!IsRefType(dtype)) {
2601 TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(opts, dtype);
2602 }
2603 break;
2604 }
2605 }
2606 }
2607
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory_AllDataTypes) {
  // Exercise on-device feeding/fetching across every DataType with soft
  // placement disabled.
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(false);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2613
TEST(DirectSessionTest,
     FeedAndFetchTensorsInDeviceMemory_AllDataTypes_SoftPlacement) {
  // Same as the _AllDataTypes test above, but with soft placement enabled.
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(true);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2620
2621 // A simple benchmark for the overhead of `DirectSession::Run()` calls
2622 // with varying numbers of feeds/fetches.
// Shared driver for the BM_FeedFetch* benchmarks below.
//
// Builds a graph of `num_feeds` independent (Placeholder -> Identity) pairs
// pinned to CPU, then times one Session call per benchmark iteration.
//
// Args:
//   state: benchmark state; each iteration performs exactly one
//     Run()/RunCallable() with `num_feeds` feeds and fetches.
//   num_feeds: number of Placeholder/Identity pairs fed and fetched.
//   use_make_callable: if true, time RunCallable() on a pre-made callable;
//     otherwise time Run() with string-keyed feeds/fetches (after one
//     untimed warm-up Run — see note below).
//   inter_op_threads: value passed to set_inter_op_parallelism_threads()
//     (0 = default; callers pass -1 for the "SingleThread" variants — see
//     the BM_* wrappers below).
//   use_single_threaded_executor: if true, select the
//     "SINGLE_THREADED_EXECUTOR" experimental executor type.
void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds,
                              bool use_make_callable, int inter_op_threads,
                              bool use_single_threaded_executor) {
  // A single scalar value reused for every feed.
  Tensor value(DT_FLOAT, TensorShape());
  value.flat<float>()(0) = 37.0;

  std::vector<std::pair<string, Tensor>> inputs;
  inputs.reserve(num_feeds);
  std::vector<string> outputs;

  Graph g(OpRegistry::Global());
  for (int i = 0; i < num_feeds; ++i) {
    // NOTE(mrry): We pin nodes to the "/cpu:0" device, so as not to
    // measure CPU<->GPU copying overhead. We should also optimize and
    // monitor this overhead where possible, but that is not the
    // object of study in this benchmark.
    Node* placeholder;
    TF_CHECK_OK(NodeBuilder(g.NewName("Placeholder"), "Placeholder")
                    .Attr("shape", TensorShape())
                    .Attr("dtype", DT_FLOAT)
                    .Device("/cpu:0")
                    .Finalize(&g, &placeholder));
    Node* identity;
    TF_CHECK_OK(NodeBuilder(g.NewName("Identity"), "Identity")
                    .Input(placeholder)
                    .Attr("T", DT_FLOAT)
                    .Device("/cpu:0")
                    .Finalize(&g, &identity));
    inputs.push_back({placeholder->name() + ":0", value});
    outputs.push_back(identity->name() + ":0");
  }
  GraphDef gd;
  g.ToGraphDef(&gd);
  SessionOptions opts;
  opts.config.set_inter_op_parallelism_threads(inter_op_threads);
  if (use_single_threaded_executor) {
    opts.config.mutable_experimental()->set_executor_type(
        "SINGLE_THREADED_EXECUTOR");
  }
  std::unique_ptr<Session> session(NewSession(opts));
  TF_CHECK_OK(session->Create(gd));
  if (use_make_callable) {
    // Callable path: graph pruning/partitioning happens once in
    // MakeCallable(), outside the timed loop.
    Session::CallableHandle handle;
    CallableOptions callable_options;
    std::vector<Tensor> input_tensors;
    for (const auto& input : inputs) {
      callable_options.add_feed(input.first);
      input_tensors.push_back(input.second);
    }
    for (const string& output : outputs) {
      callable_options.add_fetch(output);
    }
    TF_CHECK_OK(session->MakeCallable(callable_options, &handle));

    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(
          session->RunCallable(handle, input_tensors, &output_values, nullptr));
    }
  } else {
    {
      // NOTE(mrry): Ignore the first run, which will incur the graph
      // partitioning/pruning overhead and skew the results.
      //
      // Note that we should also optimize and monitor the overhead on
      // the first run, which will impact application startup times, but
      // that is not the object of study in this benchmark.
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }

    for (auto s : state) {
      std::vector<Tensor> output_values;
      TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
    }
  }
}
2700
BM_FeedFetch(::testing::benchmark::State & state)2701 void BM_FeedFetch(::testing::benchmark::State& state) {
2702 const int num_feeds = state.range(0);
2703
2704 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false,
2705 /* inter_op_threads */ 0,
2706 /* use_single_threaded_executor */ false);
2707 }
BM_FeedFetchCallable(::testing::benchmark::State & state)2708 void BM_FeedFetchCallable(::testing::benchmark::State& state) {
2709 const int num_feeds = state.range(0);
2710
2711 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2712 /* inter_op_threads */ 0,
2713 /* use_single_threaded_executor */ false);
2714 }
BM_FeedFetchCallableSingleThread(::testing::benchmark::State & state)2715 void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) {
2716 const int num_feeds = state.range(0);
2717
2718 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2719 /* inter_op_threads */ -1,
2720 /* use_single_threaded_executor */ false);
2721 }
BM_FeedFetchCallableSingleThreadExecutor(::testing::benchmark::State & state)2722 void BM_FeedFetchCallableSingleThreadExecutor(
2723 ::testing::benchmark::State& state) {
2724 const int num_feeds = state.range(0);
2725
2726 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2727 /* inter_op_threads */ -1,
2728 /* use_single_threaded_executor */ true);
2729 }
2730
// Register each benchmark variant at 1, 2, 5, and 10 feeds/fetches
// (state.range(0) in FeedFetchBenchmarkHelper).
BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThread)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThreadExecutor)
    ->Arg(1)
    ->Arg(2)
    ->Arg(5)
    ->Arg(10);
2739
2740 } // namespace
2741
// Test fixture for verifying how DirectSession derives its
// collective_graph_key from graphs that invoke CollectiveReduce through
// function calls.
class DirectSessionCollectiveTest : public ::testing::Test {
 public:
  // Creates a graph with CollectiveOps inside functions and runs it. Returns
  // the generated collective_graph_key.
  //
  // If `add_unused_function` is true, the graph's function library also
  // contains a collective function that is never called, so tests can check
  // whether unused functions affect the key.
  Status RunGraphWithCollectiveFunctions(bool add_unused_function,
                                         int64_t* collective_graph_key) {
    GraphDef g = CreateGraph(add_unused_function);
    const Tensor t1 =
        test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
    const Tensor t2 =
        test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
    auto session = CreateSession();
    TF_RETURN_IF_ERROR(session->Create(g));
    std::vector<Tensor> outputs;
    TF_RETURN_IF_ERROR(
        session->Run({{"input0:0", t1}, {"input1:0", t2}}, {},
                     {"collective_call0:0", "collective_call1:0"}, &outputs));
    // Read the key recorded by the session; this test is a friend of
    // DirectSession, and the field is guarded by its lock.
    DirectSession* direct_session = static_cast<DirectSession*>(session.get());
    {
      mutex_lock l(direct_session->collective_graph_key_lock_);
      *collective_graph_key = direct_session->collective_graph_key_;
    }
    return OkStatus();
  }

 private:
  // Creates a function with name `function_name` and a single CollectiveReduce
  // node with instance key set as `instance_key`.
  FunctionDef CollectiveFunction(const string& function_name,
                                 int instance_key) {
    return FunctionDefHelper::Define(
        // Function name
        function_name,
        // In def
        {"arg:float"},
        // Out def
        {"reduce:float"},
        // Attr def
        {},
        // Node def
        {{
            {"reduce"},
            "CollectiveReduce",
            {"arg"},
            {{"group_size", 2},
             {"group_key", 1},
             {"instance_key", instance_key},
             {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
             {"merge_op", "Add"},
             {"final_op", "Div"},
             {"T", DT_FLOAT}},
        }});
  }

  // Builds a float Placeholder node named "input<id>".
  NodeDef Input(int id) {
    AttrValue dtype_attr;
    SetAttrValue(DT_FLOAT, &dtype_attr);
    NodeDef input;
    input.set_name(strings::StrCat("input", id));
    input.set_op("Placeholder");
    input.mutable_attr()->insert({"dtype", dtype_attr});
    return input;
  }

  // Builds a call node "collective_call<cpu_id>" invoking function `op` on
  // `input`, placed on local CPU device `cpu_id`.
  NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) {
    NodeDef collective_call;
    collective_call.set_name(strings::StrCat("collective_call", cpu_id));
    collective_call.set_op(op);
    collective_call.add_input(input);
    collective_call.set_device(
        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id));
    return collective_call;
  }

  // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
  // CPU1, with instance_key 1, and appropriate placeholder inputs. If
  // `add_unused_function` is true, adds another CollectiveFunction with
  // instance_key 2 that is not invoked in the graph.
  GraphDef CreateGraph(bool add_unused_function) {
    GraphDef g;
    FunctionDef collective_function =
        CollectiveFunction("CollectiveFunction1", 1);
    FunctionDefLibrary* lib = g.mutable_library();
    *lib->add_function() = collective_function;
    if (add_unused_function) {
      FunctionDef unused_function =
          CollectiveFunction("CollectiveFunction2", 2);
      *lib->add_function() = unused_function;
    }

    *g.add_node() = Input(0);
    *g.add_node() = Input(1);
    // CollectiveReduce on CPU0 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0);
    // CollectiveReduce on CPU1 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1);

    return g;
  }
};
2842
TEST_F(DirectSessionCollectiveTest,
       TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
  // Adding a collective function that is never invoked must not change the
  // collective_graph_key: the key is derived only from called functions.
  int64_t key_without_unused_fn;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key_without_unused_fn));
  int64_t key_with_unused_fn;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key_with_unused_fn));
  ASSERT_EQ(key_without_unused_fn, key_with_unused_fn);
}
2851
// An op that emits, on each of its required outputs, the number of outputs
// the current subgraph requires from it. Used below to test the
// OpKernelContext::output_required() plumbing.
2854 class StatefulOutputRequiredOp : public OpKernel {
2855 public:
StatefulOutputRequiredOp(OpKernelConstruction * ctx)2856 explicit StatefulOutputRequiredOp(OpKernelConstruction* ctx)
2857 : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)2858 void Compute(OpKernelContext* ctx) override {
2859 // The op counts the number of outputs required in the current subgraph,
2860 // and emits that number on each of its required outputs.
2861 Tensor count_outputs_required_t(int64_t{0});
2862 int64_t& count_outputs_required =
2863 count_outputs_required_t.scalar<int64_t>()();
2864 for (int i = 0; i < num_outputs(); ++i) {
2865 if (ctx->output_required(i)) ++count_outputs_required;
2866 }
2867 for (int i = 0; i < num_outputs(); ++i) {
2868 if (ctx->output_required(i)) ctx->set_output(i, count_outputs_required_t);
2869 }
2870 }
2871 };
2872
// Register the kernel (CPU only) and the op. The op produces `num_outs`
// (default 5) int64 outputs and is marked stateful so that a single kernel
// instance is shared across the subgraphs created in the test below.
REGISTER_KERNEL_BUILDER(Name("StatefulOutputRequired").Device(DEVICE_CPU),
                        StatefulOutputRequiredOp);
REGISTER_OP("StatefulOutputRequired")
    .Output("results : num_outs * int64")
    .Attr("num_outs : int = 5")
    .SetIsStateful();
2879
TEST(DirectSessionTest, TestStatefulOutputRequiredOp) {
  GraphDef graph;
  // Creates a graph with a StatefulOutputRequired op with 5 outputs.
  // BUG FIX: the bool result of ParseFromString() was previously discarded;
  // a malformed pbtxt would have left `graph` empty and produced a confusing
  // failure later. Fail fast instead.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      R"pb(
        node { name: 'n' op: 'StatefulOutputRequired' device: '/device:CPU:0' }
        versions { producer: 9 }
      )pb",
      &graph));

  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(std::move(graph)));

  // As a stateful op, a single StatefulOutputRequired kernel will be created
  // and shared across multiple subgraphs. We create 5 different subgraphs,
  // fetching different prefixes of the output of the op.
  for (int num_outputs_required = 1; num_outputs_required <= 5;
       ++num_outputs_required) {
    std::vector<string> fetch_tensor_names;
    fetch_tensor_names.reserve(num_outputs_required);
    for (int output_idx = 0; output_idx < num_outputs_required; ++output_idx) {
      fetch_tensor_names.push_back(strings::StrCat("n:", output_idx));
    }
    std::vector<Tensor> fetch_tensors;
    TF_ASSERT_OK(session->Run({}, fetch_tensor_names, {}, &fetch_tensors));
    ASSERT_EQ(num_outputs_required, fetch_tensors.size());
    // Every fetched output must report the same required-output count.
    for (const Tensor& t : fetch_tensors) {
      ASSERT_EQ(num_outputs_required, t.scalar<int64_t>()());
    }
  }

  TF_ASSERT_OK(session->Close());
}
2914
2915 } // namespace tensorflow
2916