1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/direct_session.h"
17
18 #include <map>
19 #include <memory>
20 #include <random>
21 #include <string>
22 #include <thread> // NOLINT
23 #include <unordered_map>
24 #include <vector>
25
26 #include "absl/memory/memory.h"
27 #include "absl/strings/match.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/function_testlib.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_testutil.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/graph/costmodel.h"
38 #include "tensorflow/core/graph/graph.h"
39 #include "tensorflow/core/graph/node_builder.h"
40 #include "tensorflow/core/graph/testlib.h"
41 #include "tensorflow/core/kernels/ops_util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/core/threadpool.h"
46 #include "tensorflow/core/lib/strings/str_util.h"
47 #include "tensorflow/core/platform/protobuf.h"
48 #include "tensorflow/core/platform/stacktrace.h"
49 #include "tensorflow/core/platform/test.h"
50 #include "tensorflow/core/platform/test_benchmark.h"
51 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
52 #include "tensorflow/core/public/session.h"
53 #include "tensorflow/core/public/session_options.h"
54 #include "tensorflow/core/util/device_name_utils.h"
55
56 #if GOOGLE_CUDA
57 #include "third_party/gpus/cuda/include/cuda.h"
58 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
59 #elif TENSORFLOW_USE_ROCM
60 #include "rocm/include/hip/hip_runtime.h"
61 #endif // GOOGLE_CUDA
62
63 namespace tensorflow {
64 namespace {
65
MakeCallableOptions(gtl::ArraySlice<string> feeds,gtl::ArraySlice<string> fetches,gtl::ArraySlice<string> targets)66 CallableOptions MakeCallableOptions(gtl::ArraySlice<string> feeds,
67 gtl::ArraySlice<string> fetches,
68 gtl::ArraySlice<string> targets) {
69 CallableOptions ret;
70 for (const string& feed : feeds) {
71 ret.add_feed(feed);
72 }
73 for (const string& fetch : fetches) {
74 ret.add_fetch(fetch);
75 }
76 for (const string& target : targets) {
77 ret.add_target(target);
78 }
79 return ret;
80 }
81
DefaultSessionOptions()82 SessionOptions DefaultSessionOptions() {
83 SessionOptions options;
84 (*options.config.mutable_device_count())["CPU"] = 2;
85 return options;
86 }
87
CreateSession()88 std::unique_ptr<Session> CreateSession() {
89 return std::unique_ptr<Session>(NewSession(DefaultSessionOptions()));
90 }
91
// Test fixture that builds the graph
//
//   y = A * x;  y_neg = Neg(y);  z = Identity(y_neg)
//
// with nodes deliberately assigned across two CPU devices (cpu:0 and cpu:1)
// so the tests exercise graph partitioning and cross-device execution.
class DirectSessionMinusAXTest : public ::testing::Test {
 public:
  // Builds the graph, filling the 2x2 matrix A (row-major) from `a_values`
  // and fixing x to the column vector [1, 1]^T, then serializes it into
  // `def_`. Each node's generated name is recorded in the corresponding
  // string member for use in feed/fetch/target specifications.
  void Initialize(std::initializer_list<float> a_values) {
    Graph graph(OpRegistry::Global());

    Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
    test::FillValues<float>(&a_tensor, a_values);
    Node* a = test::graph::Constant(&graph, a_tensor);
    a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    a_ = a->name();

    // x lives on the second device, so A * x crosses devices.
    Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
    test::FillValues<float>(&x_tensor, {1, 1});
    Node* x = test::graph::Constant(&graph, x_tensor);
    x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
    x_ = x->name();

    // y = A * x
    Node* y = test::graph::Matmul(&graph, a, x, false, false);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
    y_ = y->name();

    Node* y_neg = test::graph::Unary(&graph, "Neg", y);
    y_neg_ = y_neg->name();
    y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    Node* z = test::graph::Unary(&graph, "Identity", y_neg);
    z_ = z->name();
    z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");

    graph.ToGraphDef(&def_);
  }

  // Node names recorded by Initialize().
  string a_;
  string x_;
  string y_;
  string y_neg_;
  string z_;
  // Serialized graph, suitable for Session::Create().
  GraphDef def_;
};
132
// Runs the A*x graph end to end, fetching y and targeting the non-fetched
// y_neg node, and checks the fetched matmul result.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Request two targets: one fetch output and one non-fetched output.
  std::vector<Tensor> fetched;
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &fetched));

  // The fetched output should be initialized and hold [3,2;-1,0] * [1;1],
  // whose first element is 5.
  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));
}
154
// Exercises the MakeCallable/RunCallable/ReleaseCallable API on the A*x
// graph, including the error paths for a missing fetch vector, a released
// handle, and a handle that was never created.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Run the test twice to ensure that the Make/Run/Release cycle is hermetic.
  for (int i = 0; i < 2; ++i) {
    // Request two targets: one fetch output and one non-fetched output.
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(
        MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

    // NOTE: the inner loop index is `j`; the original code reused `i`,
    // shadowing the outer loop variable.
    for (int j = 0; j < 2; ++j) {
      std::vector<Tensor> outputs;
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));

      ASSERT_EQ(1, outputs.size());
      // The first output should be initialized and have the correct
      // output.
      auto mat = outputs[0].matrix<float>();
      ASSERT_TRUE(outputs[0].IsInitialized());
      EXPECT_FLOAT_EQ(5.0, mat(0, 0));
    }

    // Omitting the fetch-output vector is an invalid-argument error.
    Status s = session->RunCallable(handle, {}, nullptr, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(),
                                  "`fetch_tensors` must be provided"));

    TF_ASSERT_OK(session->ReleaseCallable(handle));

    // Running a released handle is rejected.
    std::vector<Tensor> outputs;
    s = session->RunCallable(handle, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(
        s.error_message(),
        "Attempted to run callable after handle was released"));

    // A handle that was never returned by MakeCallable is also rejected.
    s = session->RunCallable(handle + 1, {}, &outputs, nullptr);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "No such callable handle"));
  }
}
200
// With optimize_for_static_graph set, the graph runs normally but any
// subsequent Extend() call must fail with FailedPrecondition.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_OptimizeForStaticGraph) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_optimize_for_static_graph(true);
  auto session = absl::WrapUnique(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fetch y and also target the non-fetched y_neg node.
  std::vector<Tensor> fetched;
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &fetched));

  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));

  // Extending a static-graph session is rejected.
  Status status = session->Extend({});
  EXPECT_TRUE(errors::IsFailedPrecondition(status));
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "optimize_for_static_graph"));
}
230
// With disable_output_partition_graphs set, an ordinary Run() succeeds but a
// Run() that requests partition graphs fails with InvalidArgument.
TEST_F(DirectSessionMinusAXTest,
       RunSimpleNetwork_DisableOutputPartitionGraphs) {
  Initialize({3, 2, -1, 0});
  SessionOptions options(DefaultSessionOptions());
  options.config.mutable_experimental()->set_disable_output_partition_graphs(
      true);
  auto session = absl::WrapUnique(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fetch y and also target the non-fetched y_neg node.
  std::vector<string> fetch_names = {y_ + ":0"};
  std::vector<string> target_names = {y_neg_};
  std::vector<Tensor> fetched;
  TF_ASSERT_OK(session->Run({}, fetch_names, target_names, &fetched));

  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));

  // The Run() call should fail when `output_partition_graphs` is set to true.
  RunOptions run_options;
  run_options.set_output_partition_graphs(true);
  RunMetadata run_metadata;
  Status status = session->Run(run_options, {}, fetch_names, target_names,
                               &fetched, &run_metadata);

  EXPECT_TRUE(errors::IsInvalidArgument(status));
  EXPECT_TRUE(absl::StrContains(status.error_message(),
                                "disable_output_partition_graphs"));
}
268
// A callable made before Finalize() remains usable afterwards, but making a
// new callable on a finalized session fails.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithCallables) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fetch y and target the non-fetched y_neg node.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_}), &handle));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // The callable is usable after finalization.
  for (int iteration = 0; iteration < 2; ++iteration) {
    std::vector<Tensor> fetched;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &fetched, nullptr));

    ASSERT_EQ(1, fetched.size());
    ASSERT_TRUE(fetched[0].IsInitialized());
    EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));
  }

  TF_ASSERT_OK(session->ReleaseCallable(handle));

  // Making a new callable fails because the session has been finalized.
  Status status =
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle);
  EXPECT_TRUE(errors::IsFailedPrecondition(status));
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "Session has been finalized."));
}
305
// After Finalize(), re-running an already-created subgraph succeeds, but
// running a different subgraph fails.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithRun) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fetch y and target the non-fetched y_neg node.
  std::vector<Tensor> fetched;
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &fetched));

  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));

  // Finalize the session.
  TF_ASSERT_OK(session->Finalize());

  // Running the exact same subgraph succeeds after finalization.
  TF_ASSERT_OK(session->Run({}, {y_ + ":0"}, {y_neg_}, &fetched));
  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));

  // Running a different subgraph fails because the session has been finalized.
  Status status = session->Run({}, {y_ + ":0"}, {}, &fetched);
  EXPECT_TRUE(errors::IsFailedPrecondition(status));
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "Session has been finalized."));
}
338
// Exercises CallableOptions::tensor_connection, which rewires the output of
// one tensor to stand in for another inside a callable: first two working
// configurations, then five invalid ones (cycle, unknown node, unknown
// output index, unknown destination, and double-feeding a tensor).
TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);".
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    ASSERT_EQ(1, outputs.size());
    // Expect -A elementwise, since y_neg = Neg(y) and y was replaced by a.
    auto mat = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(-3.0, mat(0, 0));
    EXPECT_FLOAT_EQ(-2.0, mat(0, 1));
    EXPECT_FLOAT_EQ(1.0, mat(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Directly wire the output of node a to the output of node y, making the
    // callable graph into "Neg(a);"; also fetch the result of a.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(a_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
    // Outputs arrive in fetch order: a first, then Neg(a).
    ASSERT_EQ(2, outputs.size());
    auto mat_a = outputs[0].matrix<float>();
    ASSERT_TRUE(outputs[0].IsInitialized());
    EXPECT_FLOAT_EQ(3.0, mat_a(0, 0));
    EXPECT_FLOAT_EQ(2.0, mat_a(0, 1));
    EXPECT_FLOAT_EQ(-1.0, mat_a(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat_a(1, 1));

    auto mat_y_neg = outputs[1].matrix<float>();
    ASSERT_TRUE(outputs[1].IsInitialized());
    EXPECT_FLOAT_EQ(-3.0, mat_y_neg(0, 0));
    EXPECT_FLOAT_EQ(-2.0, mat_y_neg(0, 1));
    EXPECT_FLOAT_EQ(1.0, mat_y_neg(1, 0));
    EXPECT_FLOAT_EQ(0.0, mat_y_neg(1, 1));
    TF_ASSERT_OK(session->ReleaseCallable(handle));
  }

  {
    // Wire the output of "Neg(Matmul(a, x))" to the output of "a",
    // creating an invalid cycle.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(y_ + ":0");
    c->set_to_tensor(a_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "would create a cycle"));
  }

  {
    // Attempt to wire a non-existent node to a node that does exist.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor("unknown_node:0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown node"));
  }

  {
    // Attempt to wire a non-existent output from a node that does
    // exist to another node.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    // Output index 17 is out of range for the constant node `a`.
    c->set_from_tensor(a_ + ":17");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "unknown edge"));
  }

  {
    // Attempt to wire a tensor to a node that doesn't exist.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor("unknown_node:0");
    callable_options.add_fetch(y_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsNotFound(s));
    EXPECT_TRUE(
        absl::StrContains(s.error_message(), "unable to find feed output"));
  }

  {
    // Attempt to wire two tensors to the same tensor.
    CallableOptions callable_options;
    TensorConnection* c1 = callable_options.add_tensor_connection();
    c1->set_from_tensor(a_ + ":0");
    c1->set_to_tensor(y_neg_ + ":0");
    TensorConnection* c2 = callable_options.add_tensor_connection();
    c2->set_from_tensor(x_ + ":0");
    c2->set_to_tensor(y_neg_ + ":0");
    callable_options.add_fetch(z_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }

  {
    // Attempt to wire a tensor to a tensor that is also being fed.
    CallableOptions callable_options;
    TensorConnection* c = callable_options.add_tensor_connection();
    c->set_from_tensor(a_ + ":0");
    c->set_to_tensor(y_ + ":0");
    callable_options.add_feed(y_ + ":0");
    callable_options.add_fetch(y_neg_ + ":0");

    Session::CallableHandle handle;
    Status s = session->MakeCallable(callable_options, &handle);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
    EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
  }
}
490
// Feeds a value for x via Session::Run() and checks y = A * x.
TEST_F(DirectSessionMinusAXTest, TestFeed) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  TF_ASSERT_OK(session->Create(def_));

  // Build the fed value for x. Note that the tensor being fed belongs to a
  // node placed on the second device.
  Tensor x_value(DT_FLOAT, TensorShape({2, 1}));
  x_value.matrix<float>()(0, 0) = 5;
  x_value.matrix<float>()(1, 0) = 6;

  // Run the graph, feeding x and fetching y.
  std::vector<Tensor> fetched;
  TF_ASSERT_OK(session->Run({{x_, x_value}}, {y_ + ":0"}, {}, &fetched));

  ASSERT_EQ(1, fetched.size());
  auto result = fetched[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, result(0, 0));
  EXPECT_FLOAT_EQ(39.0, result(1, 0));
}
519
// Feeds a value for x via the callable API and checks y = A * x.
TEST_F(DirectSessionMinusAXTest, TestFeed_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  TF_ASSERT_OK(session->Create(def_));

  // Fill in the input and ask for the output
  //
  // Note that the input being fed is on the second device.
  // (A hand-built CallableOptions local that duplicated this and was never
  // used has been removed; MakeCallableOptions builds the real options.)
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(MakeCallableOptions({x_}, {y_ + ":0"}, {}),
                                     &handle));
  Tensor t(DT_FLOAT, TensorShape({2, 1}));
  t.matrix<float>()(0, 0) = 5;
  t.matrix<float>()(1, 0) = 6;
  std::vector<Tensor> inputs = {t};
  std::vector<Tensor> outputs;

  // Run the callable
  TF_ASSERT_OK(session->RunCallable(handle, inputs, &outputs, nullptr));

  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();

  // Expect outputs to be; 1*5 + 2*6, 3*5 + 4*6
  EXPECT_FLOAT_EQ(17.0, mat(0, 0));
  EXPECT_FLOAT_EQ(39.0, mat(1, 0));
}
552
// Runs the same graph 1000 times from each of 4 pool threads concurrently
// and checks every fetched result.
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Own the pool with a unique_ptr (RAII) instead of the original raw
  // new/delete pair, so it is released on every exit path.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Destroying the pool waits for all scheduled closures to finish.
  tp.reset();
}
584
// Runs the same callable 1000 times from each of 4 pool threads concurrently
// and checks every fetched result.
TEST_F(DirectSessionMinusAXTest, TestConcurrency_Callable) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Own the pool with a unique_ptr (RAII) instead of the original raw
  // new/delete pair, so it is released on every exit path.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  Session::CallableHandle handle;
  TF_ASSERT_OK(
      session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle));

  // Run the callable 1000 times in 4 different threads concurrently.
  auto fn = [&session, handle]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<Tensor> outputs;
      // Run the graph
      TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Destroying the pool waits for all scheduled closures to finish.
  tp.reset();
}
617
// Same concurrency test, but with the session configured to use its own
// per-session thread pools instead of the shared ones.
TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
  Initialize({1, 2, 3, 4});

  SessionOptions options;
  options.config.set_use_per_session_threads(true);
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));

  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Own the pool with a unique_ptr (RAII) instead of the original raw
  // new/delete pair, so it is released on every exit path.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);

  // Run the graph 1000 times in 4 different threads concurrently.
  std::vector<string> output_names = {y_ + ":0"};
  auto fn = [&session, output_names]() {
    for (int i = 0; i < 1000; ++i) {
      std::vector<std::pair<string, Tensor>> inputs;
      std::vector<Tensor> outputs;
      // Run the graph
      Status s = session->Run(inputs, output_names, {}, &outputs);
      TF_ASSERT_OK(s);
      ASSERT_EQ(1, outputs.size());
      auto mat = outputs[0].matrix<float>();
      EXPECT_FLOAT_EQ(3.0, mat(0, 0));
    }
  };

  for (int i = 0; i < 4; ++i) {
    tp->Schedule(fn);
  }

  // Destroying the pool waits for all scheduled closures to finish.
  tp.reset();
}
654
// A second Create() call on the same session must fail.
TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  // The first Create() succeeds; the second is rejected.
  TF_ASSERT_OK(session->Create(def_));
  ASSERT_FALSE(session->Create(def_).ok());
}
664
// Calling Run() before Create() must fail.
TEST_F(DirectSessionMinusAXTest, ForgetToCreate) {
  Initialize({1, 2, 3, 4});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  std::vector<Tensor> fetched;
  ASSERT_FALSE(session->Run({}, {y_ + ":0"}, {y_neg_}, &fetched).ok());
}
673
// Creating a session fails when a node is assigned to a device that does not
// exist; fixing the placement and recreating the session then succeeds.
TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
  GraphDef def;
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  a_tensor.flat<float>().setRandom();
  Node* a = test::graph::Constant(&graph, a_tensor);
  a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
  x_tensor.flat<float>().setRandom();
  Node* x = test::graph::Constant(&graph, x_tensor);
  x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
  // Skip placing y.
  // cpu:2 does not exist: the options below configure only two CPU devices.
  Node* y = test::graph::Matmul(&graph, a, x, false, false);
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");

  graph.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  // Should return an error.
  ASSERT_FALSE(session->Create(def).ok());

  // Fix placement and run again
  // (re-serialize the graph after moving y to an existing device, and use a
  // fresh session since the failed Create() consumed the old one).
  def.Clear();
  y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
  graph.ToGraphDef(&def);
  session.reset(NewSession(options));
  TF_ASSERT_OK(session->Create(def));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run({}, {y->name() + ":0"}, {}, &outputs));
}
708
// Runs the A*x graph with SOFTWARE_TRACE enabled and checks that the
// returned RunMetadata contains step stats for both devices.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fetch y and target the non-fetched y_neg node.
  std::vector<string> fetch_names = {y_ + ":0"};
  std::vector<string> target_names = {y_neg_};
  std::vector<Tensor> fetched;

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::SOFTWARE_TRACE);
  RunMetadata run_metadata;
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  TF_ASSERT_OK(session->Run(run_options, {}, fetch_names, target_names,
                            &fetched, &run_metadata));

  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));

  // Checks RunMetadata is well-formed
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
742
// Same as RunSimpleNetworkWithOpts, but driven through the callable API with
// the trace level set in CallableOptions.
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fetch y and target the non-fetched y_neg node, with tracing enabled.
  CallableOptions callable_options =
      MakeCallableOptions({}, {y_ + ":0"}, {y_neg_});
  callable_options.mutable_run_options()->set_trace_level(
      RunOptions::SOFTWARE_TRACE);
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));

  RunMetadata run_metadata;
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0);

  std::vector<Tensor> fetched;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &fetched, &run_metadata));

  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));

  // Checks RunMetadata is well-formed
  ASSERT_TRUE(run_metadata.has_step_stats());
  EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
774
// Runs the A*x graph with the experimental run-handler pool enabled.
TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
  Initialize({3, 2, -1, 0});
  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def_));

  // Fetch y and target the non-fetched y_neg node.
  std::vector<string> fetch_names = {y_ + ":0"};
  std::vector<string> target_names = {y_neg_};

  // Prepares RunOptions and RunMetadata
  RunOptions run_options;
  run_options.mutable_experimental()->set_use_run_handler_pool(true);

  std::vector<Tensor> fetched;
  TF_ASSERT_OK(session->Run(run_options, {}, fetch_names, target_names,
                            &fetched, nullptr));

  ASSERT_EQ(1, fetched.size());
  ASSERT_TRUE(fetched[0].IsInitialized());
  EXPECT_FLOAT_EQ(5.0, fetched[0].matrix<float>()(0, 0));
}
802
// A value assigned to a Variable in one Run() call remains visible to a
// later Run() call on the same session.
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
  Graph g(OpRegistry::Global());
  Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
  var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  // A constant vector of ten 20.0 values.
  Tensor twenty(DT_FLOAT, TensorShape({10}));
  twenty.flat<float>().setConstant(20.0f);

  Node* twenty_node = test::graph::Constant(&g, twenty);
  twenty_node->set_assigned_device_name(
      "/job:localhost/replica:0/task:0/cpu:0");

  Node* init = test::graph::Assign(&g, var, twenty_node);
  init->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");

  GraphDef def;
  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Initialize the variable
  TF_ASSERT_OK(session->Run({}, {init->name()}, {}, &outputs));

  // Get the variable's data
  TF_ASSERT_OK(session->Run({}, {var->name() + ":0"}, {}, &outputs));
  ASSERT_EQ(1, outputs.size());
  ASSERT_TRUE(outputs[0].IsInitialized());
  EXPECT_EQ(20.0, outputs[0].flat<float>()(0));
}
841
// Exercises Run() with multiple feeds: output order must follow fetch order
// regardless of feed order, and feeding the same tensor twice is an error.
TEST(DirectSessionTest, MultipleFeedTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding: the constants' own values flow through.
  Status s = session->Run(
      {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Reversed fetch order: outputs must follow the fetch list order.
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: feed order must not matter.
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feed is InvalidArgument.
  s = session->Run(
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
914
// Same scenarios as MultipleFeedTest, but driven through the
// MakeCallable/RunCallable API instead of Session::Run.
TEST(DirectSessionTest, MultipleFeedTest_Callable) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  Session::CallableHandle handle;
  std::vector<Tensor> outputs;

  // Fetch without feeding: the constants' own values flow through.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {first_identity->name() + ":0", second_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Reversed fetch order: outputs must follow the fetch list order.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {}, {second_identity->name() + ":0", first_identity->name() + ":0"},
          {}),
      &handle));
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), second_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_11, value_22}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]: feed values follow the feed list order.
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions(
          {second_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle));
  TF_ASSERT_OK(
      session->RunCallable(handle, {value_22, value_11}, &outputs, nullptr));
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feed is rejected at
  // MakeCallable time (not at RunCallable time).
  Status s = session->MakeCallable(
      MakeCallableOptions(
          {first_const->name(), first_const->name()},
          {first_identity->name() + ":0", second_identity->name() + ":0"}, {}),
      &handle);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
997
// A single tensor (a:0) may be the source of multiple tensor_connections in
// a callable: both inputs of the Add are rewired to it, computing a + a.
TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
  Graph graph(OpRegistry::Global());

  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a_tensor, {1.0, 2.0, 3.0, 4.0});
  Node* a = test::graph::Constant(&graph, a_tensor);

  Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
  test::FillValues<float>(&dummy_tensor, {-1.0});

  Node* left = test::graph::Constant(&graph, dummy_tensor);
  Node* right = test::graph::Constant(&graph, dummy_tensor);

  // y = left + right; both inputs will be rewired to `a` below.
  Node* y = test::graph::Add(&graph, left, right);

  GraphDef def;
  graph.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  CallableOptions callable_options;
  // Directly wire the output of node a to the outputs of nodes left
  // and right, making the callable graph into "a + a;".
  TensorConnection* c_left = callable_options.add_tensor_connection();
  c_left->set_from_tensor(a->name() + ":0");
  c_left->set_to_tensor(left->name() + ":0");
  TensorConnection* c_right = callable_options.add_tensor_connection();
  c_right->set_from_tensor(a->name() + ":0");
  c_right->set_to_tensor(right->name() + ":0");

  callable_options.add_fetch(y->name() + ":0");

  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());
  auto mat = outputs[0].matrix<float>();
  ASSERT_TRUE(outputs[0].IsInitialized());
  // Each element equals twice the corresponding element of a_tensor.
  EXPECT_FLOAT_EQ(2.0, mat(0, 0));
  EXPECT_FLOAT_EQ(4.0, mat(0, 1));
  EXPECT_FLOAT_EQ(6.0, mat(1, 0));
  EXPECT_FLOAT_EQ(8.0, mat(1, 1));
  TF_ASSERT_OK(session->ReleaseCallable(handle));
}
1046
// Fetching the same tensor name twice in one Run() call must yield two
// independent, initialized outputs carrying the same value.
TEST(DirectSessionTest, FetchMultipleTimes) {
  Graph graph(OpRegistry::Global());
  Tensor seven(DT_INT32, TensorShape());
  seven.flat<int32>()(0) = 7;
  Node* seven_node = test::graph::Constant(&graph, seven);

  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(graph_def));

  const std::vector<std::pair<string, Tensor>> no_inputs;
  std::vector<Tensor> fetched;

  // Ask for the same node twice in a single call.
  const auto fetch_name = seven_node->name();
  TF_ASSERT_OK(session->Run(no_inputs, {fetch_name, fetch_name}, {}, &fetched));

  EXPECT_EQ(2, fetched.size());
  for (int i = 0; i < fetched.size(); ++i) {
    const Tensor& result = fetched[i];
    ASSERT_TRUE(result.IsInitialized()) << i;
    EXPECT_EQ(7, result.flat<int32>()(0)) << i;
  }
}
1074
// Variant of MultipleFeedTest where some Run() calls use the synchronous
// path (inter_op_thread_pool = -1) and others use the default async path.
TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) {
  GraphDef def;
  Graph g(OpRegistry::Global());
  RunOptions run_options;
  // -1 selects synchronous execution on the caller's thread.
  run_options.set_inter_op_thread_pool(-1);

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Fetch without feeding (sync path).
  Status s = session->Run(
      run_options, {},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(2.0, outputs[1].flat<float>()(0));

  // Reversed fetch order (default async path).
  s = session->Run(
      {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(1.0, outputs[1].flat<float>()(0));

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed [first_const, second_const]
  s = session->Run(
      {{first_const->name(), value_11}, {second_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [second_const, first_const]
  s = session->Run(
      {{second_const->name(), value_22}, {first_const->name(), value_11}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
  ASSERT_EQ(22.0, outputs[1].flat<float>()(0));

  // Feed [first_const, first_const]: duplicate feed is rejected on the
  // sync path too.
  s = session->Run(
      run_options,
      {{first_const->name(), value_11}, {first_const->name(), value_22}},
      {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
      &outputs, nullptr);
  EXPECT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(absl::StrContains(s.error_message(), "fed more than once"));
}
1151
// Test op that echoes the session metadata (if any) visible to its kernel.
REGISTER_OP("SessionMetadataReader")
    .Input("x: int64")
    .Output("y: string")
    .SetIsStateful()
    .Doc(R"doc(SessionMetadataReader returns the session metadata.

x: int64
y: string
)doc");

// Kernel for SessionMetadataReader: outputs the DebugString() of the
// session metadata, or the empty string when no metadata is present.
class SessionMetadataReaderOp : public OpKernel {
 public:
  explicit SessionMetadataReaderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    Tensor* out_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("y", TensorShape({}), &out_tensor));
    if (ctx->session_metadata() != nullptr) {
      out_tensor->scalar<tstring>()() = ctx->session_metadata()->DebugString();
    } else {
      out_tensor->scalar<tstring>()() = "";
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_CPU),
                        SessionMetadataReaderOp);
REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_GPU),
                        SessionMetadataReaderOp);
1180
// Wraps SessionMetadataReader in a FunctionDef so tests can check that
// session metadata is also visible to kernels invoked via function calls.
FunctionDef SessionMetadataReaderOpFn() {
  return FunctionDefHelper::Define(
      // Name
      "SessionMetadataReaderFn",
      // Args
      {"x: int64"},
      // Return values
      {"y: string"},
      // Attr def
      {},
      // Nodes
      {{{"y"}, "SessionMetadataReader", {"x"}, {}}});
}
1194
// A session created without metadata exposes a null session_metadata() to
// kernels, so SessionMetadataReader returns the empty string.
TEST(DirectSessionTest, SessionMetadataAbsent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);  // Run synchronously on this thread.
  // Check the Run() status before dereferencing outputs[0]; an unchecked
  // failure here previously crashed the test with an opaque error.
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1212
// Same as SessionMetadataAbsent, but the reader op runs inside a function
// call; metadata must be absent there as well.
TEST(DirectSessionTest, SessionMetadataAbsentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);  // Run synchronously on this thread.
  // Check the Run() status before dereferencing outputs[0].
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());

  EXPECT_EQ("", outputs[0].scalar<tstring>()());
}
1234
// Metadata configured on SessionOptions must be visible to kernels via
// OpKernelContext::session_metadata().
TEST(DirectSessionTest, SessionMetadataPresent) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReader", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);  // Run synchronously on this thread.
  // Check the Run() status before dereferencing outputs[0].
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());

  // The kernel emits DebugString(); round-trip it back into a proto.
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
}
1261
// Same as SessionMetadataPresent, but the reader op runs inside a function
// call; metadata must propagate into function execution too.
TEST(DirectSessionTest, SessionMetadataPresentViaFunction) {
  FunctionDefLibrary library_graph_def;
  *library_graph_def.add_function() = SessionMetadataReaderOpFn();
  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "SessionMetadataReaderFn", x);
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;
  auto session_options = DefaultSessionOptions();
  auto* session_metadata =
      session_options.config.mutable_experimental()->mutable_session_metadata();
  session_metadata->set_name("name");
  session_metadata->set_version(1);
  auto sess = std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);  // Run synchronously on this thread.
  // Check the Run() status before dereferencing outputs[0].
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());

  // The kernel emits DebugString(); round-trip it back into a proto.
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      outputs[0].scalar<tstring>()(), &read_metadata));
  EXPECT_EQ("name", read_metadata.name());
  EXPECT_EQ(1, read_metadata.version());
}
1292
// The (name, version) pair in session metadata acts as a uniqueness key:
// two live sessions may not share it, but the key is released when the
// owning session is destroyed.
TEST(DirectSessionTest, SessionMetadataKey) {
  auto session_options0 = DefaultSessionOptions();
  auto* session_metadata0 = session_options0.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata0->set_name("name");
  Session* sess0_ptr;
  ASSERT_TRUE(NewSession(session_options0, &sess0_ptr).ok());
  auto sess0 = absl::WrapUnique(sess0_ptr);

  // Trying to use the same metadata (name, version) will cause an error.
  Session* dup_ptr;
  EXPECT_TRUE(
      errors::IsInvalidArgument(NewSession(session_options0, &dup_ptr)));

  // A new (name, version) is fine.
  auto session_options1 = DefaultSessionOptions();
  auto* session_metadata1 = session_options1.config.mutable_experimental()
                                ->mutable_session_metadata();
  session_metadata1->set_name("name");
  session_metadata1->set_version(1);
  Session* sess1_ptr;
  EXPECT_TRUE(NewSession(session_options1, &sess1_ptr).ok());
  auto sess1 = absl::WrapUnique(sess1_ptr);

  // If the previous session, using the same (name, version) is gone, then it's
  // fine.
  sess0 = nullptr;  // Destroys sess0, releasing its metadata key.
  EXPECT_TRUE(NewSession(session_options0, &dup_ptr).ok());
  auto dup = absl::WrapUnique(dup_ptr);

  // Sessions without metadata options are always fine.
  auto sess_without_metadata0 = CreateSession();
  EXPECT_NE(sess_without_metadata0, nullptr);
  auto sess_without_metadata1 = CreateSession();
  EXPECT_NE(sess_without_metadata1, nullptr);
}
1329
// Session metadata with a negative version must be rejected at
// session-creation time with InvalidArgument.
TEST(DirectSessionTest, SessionMetadataInvalid) {
  const auto valid_session_options = DefaultSessionOptions();
  Session* sess_ptr;
  ASSERT_TRUE(NewSession(valid_session_options, &sess_ptr).ok());
  auto sess = absl::WrapUnique(sess_ptr);

  auto invalid_session_options = valid_session_options;
  auto* invalid_metadata =
      invalid_session_options.config.mutable_experimental()
          ->mutable_session_metadata();
  invalid_metadata->set_name("name");
  // Version should be >= 0.
  invalid_metadata->set_version(-1);
  Session* error_sess_ptr;
  EXPECT_TRUE(errors::IsInvalidArgument(
      NewSession(invalid_session_options, &error_sess_ptr)));
}
1347
// Test op whose kernel reports the id of the thread that ran Compute(),
// used to verify synchronous (caller-thread) execution.
REGISTER_OP("ThreadID").Input("x: int64").Output("y: int64").Doc(R"doc(
ThreadID returns the thread ID that called compute.

x: int64
y: int64
)doc");

// The ThreadID kernel returns the thread ID that executed Compute.
class ThreadIDOp : public OpKernel {
 public:
  explicit ThreadIDOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    Tensor* out_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("y", TensorShape({}), &out_tensor));
    // std::thread::id is opaque; hash it to get a comparable integer.
    std::hash<std::thread::id> hasher;
    out_tensor->scalar<int64>()() =
        static_cast<int64>(hasher(std::this_thread::get_id()));
  }
};
REGISTER_KERNEL_BUILDER(Name("ThreadID").Device(DEVICE_CPU), ThreadIDOp);
1369
// With inter_op_thread_pool = -1 in RunOptions, ops execute synchronously on
// the calling thread, so ThreadID must observe this thread's id.
TEST(DirectSessionTest, SessionSyncRun) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);
  // Check the Run() status before dereferencing outputs[0]; an unchecked
  // failure here previously crashed the test with an opaque error.
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())),
            static_cast<int64>(outputs[0].scalar<int64>()()));
}
1389
1390 REGISTER_OP("ExpensiveNoop").SetIsStateful();
1391
1392 class ExpensiveNoopOp : public OpKernel {
1393 public:
1394 using OpKernel::OpKernel;
IsExpensive()1395 bool IsExpensive() override { return true; }
Compute(OpKernelContext * ctx)1396 void Compute(OpKernelContext* ctx) override {
1397 const string& stack_trace = tensorflow::CurrentStackTrace();
1398 const string process_method = "ExecutorState::Process()";
1399 size_t pos = 0;
1400 int frame_count = 0;
1401 while ((pos = stack_trace.find("ExecutorState::Process()", pos)) !=
1402 string::npos) {
1403 ++frame_count;
1404 ++pos;
1405 }
1406 OP_REQUIRES(ctx, frame_count <= 1,
1407 errors::Internal(
1408 "Recursive call to ExecutorState::Process() detected."));
1409 }
1410 };
1411
1412 REGISTER_KERNEL_BUILDER(Name("ExpensiveNoop").Device(DEVICE_CPU),
1413 ExpensiveNoopOp);
1414
// Synchronous execution of a graph of "expensive" ops must not recurse into
// ExecutorState::Process() (ExpensiveNoopOp checks the stack trace).
TEST(DirectSessionTest, SessionSyncRun_DeepGraph) {
  Graph g(OpRegistry::Global());

  // Builds an ExpensiveNoop node with the given control dependencies.
  auto make_expensive_noop = [&g](gtl::ArraySlice<Node*> control_deps) {
    Node* ret;
    auto builder = NodeBuilder(g.NewName("N"), "ExpensiveNoop");
    for (Node* control_dep : control_deps) {
      builder = builder.ControlInput(control_dep);
    }
    TF_CHECK_OK(builder.Finalize(&g, &ret));
    return ret;
  };

  // A small diamond-ish shape: two children both depending on `base`.
  // (The previously declared `nodes` vector was never used and is removed.)
  Node* base = make_expensive_noop({});

  Node* child_1 = make_expensive_noop({base});
  Node* child_2 = make_expensive_noop({base});

  GraphDef def;
  g.ToGraphDef(&def);

  auto sess = CreateSession();
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  run_opts.set_inter_op_thread_pool(-1);  // Run synchronously on this thread.

  // TF_EXPECT_OK prints the failing status, unlike EXPECT_TRUE(...ok()).
  TF_EXPECT_OK(sess->Run(run_opts, {}, {}, {child_1->name(), child_2->name()},
                         &outputs, nullptr));
}
1449
// With inter_op_parallelism_threads = -1 in SessionOptions (rather than
// RunOptions), all Run() calls execute synchronously on the caller's thread.
TEST(DirectSessionTest, SyncSession) {
  Graph g(OpRegistry::Global());
  Tensor vx(DT_INT64, TensorShape({}));
  vx.scalar<int64>()() = 17;
  Node* x = test::graph::Constant(&g, vx);
  Node* y = test::graph::Unary(&g, "ThreadID", x);
  GraphDef def;
  g.ToGraphDef(&def);
  SessionOptions options;
  options.config.set_inter_op_parallelism_threads(-1);
  std::unique_ptr<Session> sess(NewSession(options));
  TF_ASSERT_OK(sess->Create(def));
  std::vector<Tensor> outputs;
  RunOptions run_opts;
  // Check the Run() status before dereferencing outputs[0]; an unchecked
  // failure here previously crashed the test with an opaque error.
  TF_ASSERT_OK(
      sess->Run(run_opts, {}, {y->name() + ":0"}, {}, &outputs, nullptr));
  ASSERT_EQ(1, outputs.size());

  std::hash<std::thread::id> hasher;
  EXPECT_EQ(static_cast<int64>(hasher(std::this_thread::get_id())),
            static_cast<int64>(outputs[0].scalar<int64>()()));
}
1470
// Test op whose kernel deliberately breaks its contract by never producing
// the promised output.
REGISTER_OP("Darth").Input("x: float").Output("y: float").Doc(R"doc(
Darth promises one return value.

x: float
y: float
)doc");

// The DarthOp kernel violates its promise to return one-value.
class DarthOp : public OpKernel {
 public:
  explicit DarthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  // Intentionally empty: allocates no output, so the executor will detect
  // the missing value at runtime.
  void Compute(OpKernelContext* ctx) override {}
};
REGISTER_KERNEL_BUILDER(Name("Darth").Device(DEVICE_CPU), DarthOp);
1485
// Running a kernel that never produces its promised output must surface an
// Internal error from the executor.
TEST(DirectSessionTest, DarthKernel) {
  Graph graph(OpRegistry::Global());
  Tensor one(DT_FLOAT, TensorShape({}));
  one.scalar<float>()() = 1.0;
  Node* input = test::graph::Constant(&graph, one);
  Node* darth = test::graph::Unary(&graph, "Darth", input);
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  auto session = CreateSession();
  TF_ASSERT_OK(session->Create(graph_def));
  std::vector<Tensor> fetched;
  const Status status =
      session->Run({}, {darth->name() + ":0"}, {}, &fetched);
  EXPECT_TRUE(errors::IsInternal(status));
}
1500
// Have the Darth op in the graph placed on GPU, but don't run it.
TEST(DirectSessionTest, PlacePrunedGraph) {
  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    // Darth has no GPU kernel, so full-graph placement must fail.
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    // By default, we place the entire graph, so we should fail the
    // call to Create.
    SessionOptions options;
    std::unique_ptr<Session> sess(NewSession(options));
    auto s = sess->Create(def);
    EXPECT_TRUE(errors::IsInvalidArgument(s));
  }

  {
    Graph g(OpRegistry::Global());
    Tensor vx(DT_FLOAT, TensorShape({}));
    vx.scalar<float>()() = 1.0;
    Node* x = test::graph::Constant(&g, vx);
    Node* y = test::graph::Unary(&g, "Darth", x);
    y->set_assigned_device_name("/job:localhost/replica:0/task:0/device:GPU:0");
    GraphDef def;
    g.ToGraphDef(&def);

    SessionOptions options;
    // Set the option to place pruned graphs, we should expect this
    // to run.
    options.config.mutable_graph_options()->set_place_pruned_graph(true);
    std::unique_ptr<Session> sess(NewSession(options));
    TF_ASSERT_OK(sess->Create(def));
    std::vector<Tensor> outputs;
    // Only x is fetched, so the unplaceable Darth node is pruned away.
    auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs);
    TF_EXPECT_OK(s);
  }
}
1542
// Basic partial-run flow: PRunSetup declares all feeds/fetches up front,
// then successive PRun calls feed and fetch incrementally, with state
// (first_identity's value) carried between the steps.
TEST(DirectSessionTest, PartialRunTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  // third = first_identity + second_identity.
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Declare everything that may be fed or fetched during this partial run.
  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Feed first_const, fetch first_identity
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(11.0, outputs[0].flat<float>()(0));

  // Feed second_const, fetch second_identity and third_identity
  s = session->PRun(
      handle, {{second_const->name(), value_22}},
      {second_identity->name() + ":0", third_identity->name() + ":0"},
      &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(2, outputs.size());
  ASSERT_EQ(22.0, outputs[0].flat<float>()(0));
  // third_identity sees the value fed for first_const in the earlier step.
  ASSERT_EQ(11.0 + 22.0, outputs[1].flat<float>()(0));
}
1598
// Fetching a tensor whose inputs have not all been fed yet must fail with
// InvalidArgument rather than blocking or returning garbage.
TEST(DirectSessionTest, PartialRunMissingFeed) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  // third depends on BOTH feeds declared in PRunSetup below.
  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  string handle;
  Status s = session->PRunSetup({first_const->name(), second_const->name()},
                                {third_identity->name() + ":0"}, {}, &handle);
  TF_ASSERT_OK(s);

  // Feed first_const, fetch third_identity
  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  // second_const was never fed, so third_identity is not computable yet.
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {third_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));
}
1638
// Feeding a specific output slot (switch_node:1) of a multi-output node:
// the fetch is uncomputable until that slot is fed, then succeeds.
TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor bool_value(DT_BOOL, TensorShape({}));
  bool_value.scalar<bool>()() = true;
  Node* bool_const = test::graph::Constant(&g, bool_value);
  Node* switch_node = test::graph::Switch(&g, bool_const, bool_const);
  // fourth_identity reads output slot 1 of the Switch.
  Node* fourth_identity = test::graph::Identity(&g, switch_node, 1);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  string handle;
  Status s = session->PRunSetup({switch_node->name() + ":1"},
                                {fourth_identity->name() + ":0"}, {}, &handle);
  TF_ASSERT_OK(s);

  // Fetch fourth_identity without feeds.
  s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "can't be computed from the feeds"));

  // Feed switch_node:1 and fetch fourth_identity.
  s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}},
                    {fourth_identity->name() + ":0"}, &outputs);
  TF_ASSERT_OK(s);
  ASSERT_EQ(1, outputs.size());
  ASSERT_EQ(true, outputs[0].flat<bool>()(0));
}
1675
// Exercises the session-tensor handle ops: create a handle with
// GetSessionHandleV2, read it back via GetSessionTensor in a later Run()
// call, then delete it with DeleteSessionTensor.
TEST(DirectSessionTest, RunHandleTest) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor value0(DT_FLOAT, TensorShape({}));
  value0.scalar<float>()() = 1.0;
  Node* const0 = test::graph::Constant(&g, value0);
  Node* identity0 = test::graph::Identity(&g, const0);

  Tensor value1(DT_FLOAT, TensorShape({}));
  value1.scalar<float>()() = 2.0;
  Node* const1 = test::graph::Constant(&g, value1);
  // node3 = 1.0 + 2.0; node4 stores it in the session and yields a handle.
  Node* node3 = test::graph::Add(&g, identity0, const1);
  Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);

  // const2 is a placeholder-like string constant; the handle is fed here.
  Tensor value2(DT_STRING, TensorShape({}));
  Node* const2 = test::graph::Constant(&g, value2);
  Node* node5 = test::graph::GetSessionTensor(&g, const2);
  Node* node6 = test::graph::Add(&g, node5, const1);

  Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // First run call: Create a handle.
  std::vector<Tensor> outputs;
  Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs.size());

  // Convert the ResourceHandle into the string form GetSessionTensor feeds on.
  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Second run call: Use a handle.
  std::vector<Tensor> outputs1;
  s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
                   {}, &outputs1);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs1.size());
  // (1.0 + 2.0) stored in the handle, plus 2.0 again.
  ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));

  // Third run call: Delete a handle.
  std::vector<Tensor> outputs2;
  s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
                   &outputs2);
  ASSERT_TRUE(s.ok());
}
1728
// Same tensor-store round trip as RunHandleTest above (store via
// GetSessionHandleV2, read via GetSessionTensor, delete via
// DeleteSessionTensor).
// NOTE(review): despite the `_Callable` suffix, this test body is identical to
// RunHandleTest and uses Session::Run rather than MakeCallable/RunCallable —
// confirm against upstream whether a callable-based variant was intended.
TEST(DirectSessionTest, RunHandleTest_Callable) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor value0(DT_FLOAT, TensorShape({}));
  value0.scalar<float>()() = 1.0;
  Node* const0 = test::graph::Constant(&g, value0);
  Node* identity0 = test::graph::Identity(&g, const0);

  Tensor value1(DT_FLOAT, TensorShape({}));
  value1.scalar<float>()() = 2.0;
  Node* const1 = test::graph::Constant(&g, value1);
  // node3 = 1 + 2 = 3; node4 stores node3 in session state, output is a
  // handle to the stored tensor.
  Node* node3 = test::graph::Add(&g, identity0, const1);
  Node* node4 = test::graph::Unary(&g, "GetSessionHandleV2", node3);

  // const2 holds the (fed-at-runtime) string handle consumed by node5/node7.
  Tensor value2(DT_STRING, TensorShape({}));
  Node* const2 = test::graph::Constant(&g, value2);
  Node* node5 = test::graph::GetSessionTensor(&g, const2);
  Node* node6 = test::graph::Add(&g, node5, const1);

  Node* node7 = test::graph::Unary(&g, "DeleteSessionTensor", const2);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // First run call: Create a handle.
  std::vector<Tensor> outputs;
  Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs.size());

  // Wrap the returned ResourceHandle's name into a DT_STRING scalar that can
  // be fed in place of const2.
  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
  Tensor string_handle(DT_STRING, {});
  string_handle.flat<tstring>().setConstant(resource_handle.name());

  // Second run call: Use a handle; expect 3 + 2 = 5.
  std::vector<Tensor> outputs1;
  s = session->Run({{const2->name(), string_handle}}, {node6->name() + ":0"},
                   {}, &outputs1);
  ASSERT_TRUE(s.ok());
  ASSERT_EQ(1, outputs1.size());
  ASSERT_EQ(5.0, outputs1[0].flat<float>()(0));

  // Third run call: Delete a handle.
  std::vector<Tensor> outputs2;
  s = session->Run({{const2->name(), string_handle}}, {}, {node7->name()},
                   &outputs2);
  ASSERT_TRUE(s.ok());
}
1781
TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
  Graph graph(OpRegistry::Global());

  // Variable `a` and constant `b`.
  Node* a = test::graph::Var(&graph, DT_FLOAT, {});
  Node* b = test::graph::Constant(&graph, {});

  Tensor zero(DT_FLOAT, {});
  test::FillValues<float>(&zero, {0});

  // a = b
  Node* assign = test::graph::Assign(&graph, a, b);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);

  // Feeding `a` replaces it with a constant for this run, so assigning to it
  // is invalid; session->Run should flag this as an invalid argument.
  // NOTE(review): the graph is never handed to the session via Create(), so
  // Run() may instead fail with the "session not created with a graph"
  // error — that is also InvalidArgument, so the assertion passes either
  // way; confirm which failure is actually intended.
  std::vector<Tensor> outputs;
  Status s = session->Run({{a->name(), zero}}, {assign->name()}, {}, &outputs);
  ASSERT_TRUE(errors::IsInvalidArgument(s));
}
1803
// Verifies that a dequeue from an empty queue is cut off with
// DEADLINE_EXCEEDED, both via the session-wide operation_timeout_in_ms and
// via a per-call RunOptions timeout.
TEST(DirectSessionTest, TimeoutSession) {
  GraphDef graph;
  // Creates a graph with one FIFOQueue and one dequeue op. The queue is never
  // filled, so the dequeue (timeout_ms: -1, i.e. wait forever) blocks until
  // the session-level deadline fires.
  // NOTE(review): the ParseFromString return value is unchecked here, unlike
  // CreateGraphForYEqualsXSquared below which wraps it in QCHECK.
  protobuf::TextFormat::ParseFromString(R"proto(
    node {
      name: 'fifo_queue'
      op: 'FIFOQueue'
      device: '/device:CPU:0'
      attr {
        key: 'capacity'
        value { i: 10 }
      }
      attr {
        key: 'component_types'
        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'container'
        value { s: '' }
      }
      attr {
        key: 'shapes'
        value { list {} }
      }
      attr {
        key: 'shared_name'
        value { s: '' }
      }
    }
    node {
      name: 'fifo_queue_Dequeue'
      op: 'QueueDequeue'
      input: 'fifo_queue'
      device: '/device:CPU:0'
      attr {
        key: 'component_types'
        value { list { type: DT_FLOAT } }
      }
      attr {
        key: 'timeout_ms'
        value { i: -1 }
      }
    }
    versions { producer: 9 }
  )proto", &graph);

  {
    // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
    SessionOptions options;
    (*options.config.mutable_device_count())["CPU"] = 2;
    options.config.set_operation_timeout_in_ms(100);

    std::unique_ptr<Session> session(NewSession(options));
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));

    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
    TF_ASSERT_OK(session->Close());
  }

  {
    // Creates a session with no operation_timeout_in_ms; the deadline comes
    // from the per-call RunOptions instead.
    auto session = CreateSession();
    ASSERT_TRUE(session != nullptr);
    TF_ASSERT_OK(session->Create(graph));
    RunOptions run_options;
    run_options.set_timeout_in_ms(20);
    // Verifies that the error code is DEADLINE_EXCEEDED.
    Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"},
                             nullptr, nullptr);
    ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
    TF_ASSERT_OK(session->Close());
  }
}
1880
1881 // Accesses the cancellation manager for the step after the step has been
1882 // cancelled.
1883 class CancellationMgrPollingOp : public OpKernel {
1884 public:
CancellationMgrPollingOp(OpKernelConstruction * ctx)1885 explicit CancellationMgrPollingOp(OpKernelConstruction* ctx)
1886 : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)1887 void Compute(OpKernelContext* ctx) override {
1888 CancellationManager* cm = ctx->cancellation_manager();
1889 while (!cm->IsCancelled()) {
1890 ctx->env()->SleepForMicroseconds(1000);
1891 }
1892 notification.Notify();
1893 }
1894 static Notification notification;
1895 };
1896 Notification CancellationMgrPollingOp::notification;
1897
1898 REGISTER_KERNEL_BUILDER(Name("CancellationMgrPollingOp").Device(DEVICE_CPU),
1899 CancellationMgrPollingOp);
1900 REGISTER_OP("CancellationMgrPollingOp").Doc("");
1901
// A timed-out Run must return DEADLINE_EXCEEDED, but in-flight ops are
// cancelled cooperatively and allowed to finish (see CancellationMgrPollingOp
// above, which only completes after observing cancellation).
TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
  GraphDef graph;
  // Creates a graph with a single CancellationMgrPollingOp node.
  protobuf::TextFormat::ParseFromString(R"proto(
    node {
      name: 'cm_polling'
      op: 'CancellationMgrPollingOp'
      device: '/device:CPU:0'
    }
    versions { producer: 9 }
  )proto", &graph);

  // Creates a session with operation_timeout_in_ms set to 100 milliseconds.
  SessionOptions options;
  options.config.set_operation_timeout_in_ms(100);
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(graph));

  // Verifies that the error code is DEADLINE_EXCEEDED.
  Status s = session->Run({}, {}, {"cm_polling"}, nullptr);
  ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());

  // Verify that the op ran to completion.
  ASSERT_TRUE(CancellationMgrPollingOp::notification.HasBeenNotified());

  TF_ASSERT_OK(session->Close());
}
1930
// Core driver for the TestSessionInterOp* tests below.
//
// Builds a graph with a constant `x` and a BlockingOp `y` (optionally wrapped
// in a function library), configures two inter-op thread pools — pool 0
// ("large", default size) and pool 1 ("small", one thread) — then verifies
// that saturating the small pool with a blocking op does not prevent work
// from completing on the large pool or synchronously on the calling thread.
// With use_global_pools, each Run uses a fresh session sharing named global
// pools instead of session-owned pools.
static void TestSessionInterOpThreadsImpl(bool use_function_lib,
                                          bool use_global_pools) {
  using test::function::blocking_op_state;
  using test::function::BlockingOpState;

  FunctionDefLibrary library_graph_def;
  if (use_function_lib) {
    *library_graph_def.add_function() = test::function::BlockingOpFn();
  }

  FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
  Graph g(&flib);
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  Node* y;
  if (use_function_lib) {
    y = test::graph::Unary(&g, "BlockingOpFn", x);
  } else {
    y = test::graph::Unary(&g, "BlockingOp", x);
  }
  GraphDef def;
  g.ToGraphDef(&def);
  *def.mutable_library() = library_graph_def;

  // Create session with two inter-op thread pools.
  SessionOptions options;
  // Turn off optimizations so that the blocking op doesn't get invoked during
  // graph setup.
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  options.config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_constant_folding(RewriterConfig::OFF);
  (*options.config.mutable_device_count())["CPU"] = 2;
  (*options.config.mutable_device_count())["GPU"] = 0;

  // Pool 0 keeps the default thread count; pool 1 is restricted to a single
  // thread so that one blocking op saturates it.
  auto* p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("large pool");
  p = options.config.add_session_inter_op_thread_pool();
  if (use_global_pools) p->set_global_name("small pool");
  p->set_num_threads(1);
  const int kSyncPool = -1;  // run synchronously on the caller's thread
  const int kLargePool = 0;
  const int kSmallPool = 1;

  std::vector<std::unique_ptr<Session>> sessions;
  if (!use_global_pools) {
    // Non-global mode: all runs share a single session created up front.
    sessions.emplace_back(NewSession(options));
    TF_ASSERT_OK(sessions.back()->Create(def));
  }
  mutex sessions_mu;  // guards `sessions` when runs create sessions lazily

  std::atomic<int32> num_done(0);
  // Runs session to compute <node>:0 using inter_op thread pool <pool>.
  auto add_session_run_call =
      [use_global_pools, &def, &options, &sessions, &sessions_mu, &num_done](
          thread::ThreadPool* tp, Node* node, int inter_op_pool) {
        auto fn = [use_global_pools, &def, &options, &sessions, &sessions_mu,
                   inter_op_pool, node, &num_done]() {
          RunOptions run_options;
          run_options.set_inter_op_thread_pool(inter_op_pool);
          std::vector<Tensor> outputs;

          Session* session;
          if (use_global_pools) {
            // Global-pool mode: every run gets its own session; sessions are
            // kept alive in `sessions` so the global pools stay shared.
            std::unique_ptr<Session> s(NewSession(options));
            TF_ASSERT_OK(s->Create(def));
            session = s.get();

            mutex_lock l(sessions_mu);
            sessions.emplace_back(std::move(s));
          } else {
            session = sessions[0].get();
          }

          Status s = session->Run(run_options, {} /* inputs */,
                                  {node->name() + ":0"} /* output_names */, {},
                                  &outputs, nullptr /* run_metadata */);
          TF_CHECK_OK(s);
          ASSERT_EQ(1, outputs.size());
          auto flat = outputs[0].flat<float>();
          EXPECT_FLOAT_EQ(1.2, flat(0));
          num_done.fetch_add(1);
        };
        // tp == nullptr means "run synchronously on this thread".
        if (tp != nullptr) {
          tp->Schedule(fn);
        } else {
          fn();
        }
      };

  // For blocking states:
  // - Starts at 0, BlockingOp::Compute will move to 1.
  // - This main thread will wait for 1, then move to 2 when other ops are done.
  //   Moving to 2 unblocks the blocking op, which then moves to state 3.

  // Run the graph once on the non-limited pool.
  thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 1);
  blocking_op_state = new BlockingOpState();
  add_session_run_call(tp1, y, kLargePool);
  blocking_op_state->AwaitState(1);
  blocking_op_state->MoveToState(1, 2);
  blocking_op_state->AwaitState(3);
  blocking_op_state->MoveToState(3, 0);
  // Deleting the pool joins its thread, so the first run is fully done here.
  delete tp1;
  num_done = 0;

  tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);

  // Launch a session run call. It will not finish until the blocking op is
  // unblocked, because it is using all threads in the small pool.
  add_session_run_call(tp1, y, kSmallPool);

  blocking_op_state->AwaitState(1);  // Wait for the blocking op to Compute.

  // These will block on <BlockingOpState>.
  const int kBlockedThreads = 3;
  for (int i = 0; i < kBlockedThreads; ++i) {
    add_session_run_call(tp1, x, kSmallPool);
  }

  // Launch session calls using the other inter-op pool. These will finish
  // as they are in inter_op pool #2.
  thread::ThreadPool* tp2 = new thread::ThreadPool(Env::Default(), "tp2", 3);
  const int kUnblockedThreads = 4;
  for (int i = 0; i < kUnblockedThreads; ++i) {
    add_session_run_call(tp2, x, kLargePool);
  }
  // Joining tp2 proves the large pool made progress while the small pool was
  // still wedged on the blocking op.
  delete tp2;
  EXPECT_EQ(kUnblockedThreads, num_done.load());

  // Launch a session call using this thread. This will finish as it runs
  // synchronously in this thread.
  add_session_run_call(nullptr, x, kSyncPool);

  // Unblock the blocked op and wait for the blocked functions to finish.
  blocking_op_state->MoveToState(1, 2);
  delete tp1;

  // All runs accounted for: unblocked + blocked + blocking op + sync call.
  EXPECT_EQ(kUnblockedThreads + kBlockedThreads + 1 + 1, num_done.load());
  delete blocking_op_state;
  blocking_op_state = nullptr;
}
2076
// Session-owned pools, ops used directly (no function library).
TEST(DirectSessionTest, TestSessionInterOpThreads) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/false,
                                /*use_global_pools=*/false);
}
2081
// Session-owned pools, blocking op invoked through a function library.
TEST(DirectSessionTest, TestSessionInterOpThreadsWithFunctions) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/true,
                                /*use_global_pools=*/false);
}
2086
// Named global pools shared across sessions, ops used directly.
TEST(DirectSessionTest, TestSessionInterOpGlobalPools) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/false,
                                /*use_global_pools=*/true);
}
2091
// Named global pools shared across sessions, with a function library.
TEST(DirectSessionTest, TestSessionInterOpGlobalPoolsWithFunctions) {
  TestSessionInterOpThreadsImpl(/*use_function_lib=*/true,
                                /*use_global_pools=*/true);
}
2096
// Covers misconfiguration of session_inter_op_thread_pool: out-of-range pool
// indices in RunOptions, and conflicting thread counts for a shared global
// pool.
TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* x = test::graph::Constant(&g, t);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(OptimizerOptions::L0);
  (*options.config.mutable_device_count())["CPU"] = 2;

  // Exactly one configured pool, so only index 0 (and -1 for "sync") is
  // valid at Run time.
  options.config.add_session_inter_op_thread_pool();

  // Wrong pool number on Run call.
  {
    std::unique_ptr<Session> session(NewSession(options));
    TF_ASSERT_OK(session->Create(def));
    // pool_num takes the values -2 and 1, both outside the valid range.
    for (int pool_num = -2; pool_num <= 1; pool_num += 3) {
      RunOptions run_options;
      run_options.set_inter_op_thread_pool(pool_num);
      std::vector<Tensor> outputs;
      Status s = session->Run(run_options, {} /* inputs */,
                              {x->name() + ":0"} /* output_names */, {},
                              &outputs, nullptr /* run_metadata */);
      EXPECT_EQ(
          strings::StrCat("Invalid argument: Invalid inter_op_thread_pool: ",
                          pool_num),
          s.ToString());
    }
  }

  // Global name changes thread count.
  std::vector<std::unique_ptr<Session>> sessions;
  auto* pool_config = options.config.mutable_session_inter_op_thread_pool(0);
  pool_config->set_num_threads(0);
  pool_config->set_global_name("foo");
  sessions.emplace_back(NewSession(options));
  TF_ASSERT_OK(sessions.back()->Create(def));
  sessions.emplace_back(NewSession(options));  // repeat creation, okay.
  TF_ASSERT_OK(sessions.back()->Create(def));
  // Any non-zero thread count now conflicts with the already-created global
  // pool "foo", so session creation must fail.
  for (int pass = 0; pass < 2; ++pass) {
    for (int i = 1; i < 128; ++i) {
      pool_config->set_num_threads(i);
      sessions.emplace_back(NewSession(options));
      auto status = sessions.back()->Create(def);
      ASSERT_FALSE(status.ok()) << status;
    }

    // Clear existing sessions before second pass; error still happens.
    sessions.clear();
  }
}
2152
// After Close(), both Run() and RunCallable() (even with a handle made
// before the close) must fail with "Cancelled: Session has been closed.".
TEST(DirectSessionTest, TestDirectSessionRunClose) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Make a callable handle before closing the session.
  Session::CallableHandle handle;
  TF_ASSERT_OK(session->MakeCallable(
      MakeCallableOptions({}, {}, {var_assign->name()}), &handle));

  // Close the session.
  TF_ASSERT_OK(session->Close());

  // Run the read on the variable to get an error.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());

  // Run the read as a callable to verify that we get the same error.
  s = session->RunCallable(handle, {}, {}, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
2198
// A partial run whose setup succeeded must still fail with the
// "Session has been closed." error when PRun is called after Close().
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  // first_const -> first_identity and second_const -> second_identity feed
  // third = first_identity + second_identity -> third_identity.
  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  g.ToGraphDef(&def);

  auto session = CreateSession();
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  // Set up the partial run before closing; setup itself must succeed.
  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Close the session.
  TF_ASSERT_OK(session->Close());

  // Feed first_const, fetch first_identity: must fail now.
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
2245
// Reset() on the target clears containers and (currently) closes existing
// sessions, so a subsequent Run must fail as if Close() had been called.
TEST(DirectSessionTest, TestDirectSessionReset) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2f};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  g.ToGraphDef(&def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Reset the containers.
  TF_EXPECT_OK(Reset(options, {}));

  // Run the read on the variable to get an error.
  // TODO(suharshs): This test only works because we close the Session in Reset.
  // If we change the behavior of Reset to not close the Session, this test will
  // fail, since the Variable buffer is cached by var.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
2285
// Sanity check: a direct session exposes a non-null local DeviceMgr that
// lists at least one device.
TEST(DirectSessionTest, LocalDeviceManager) {
  std::unique_ptr<Session> session(NewSession(SessionOptions()));

  const DeviceMgr* device_mgr = nullptr;
  TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
  ASSERT_TRUE(device_mgr != nullptr);
  EXPECT_GT(device_mgr->ListDevices().size(), 0);
}
2295
2296 // y = tf.square(x)
CreateGraphForYEqualsXSquared()2297 GraphDef CreateGraphForYEqualsXSquared() {
2298 GraphDef graph_def;
2299 const char* text_proto = R"EOF(
2300 node {
2301 name: "x"
2302 op: "Placeholder"
2303 attr { key: "dtype" value { type: DT_FLOAT } }
2304 attr { key: "shape" value { shape { unknown_rank: true } } }
2305 }
2306 node {
2307 name: "y"
2308 op: "Square"
2309 input: "x"
2310 attr { key: "T" value { type: DT_FLOAT } }
2311 }
2312 versions {
2313 producer: 26
2314 }
2315 )EOF";
2316
2317 QCHECK(protobuf::TextFormat::ParseFromString(text_proto, &graph_def));
2318 return graph_def;
2319 }
2320
2321 // A graph that consumes and produces string tensors
2322 // (which are not GPU-compatible, i.e., there are no
2323 // GPU kernels for these operations).
// Returns true iff the tensor's backing buffer resides in GPU device memory,
// as reported by the CUDA/ROCm runtime. Always false in CPU-only builds.
bool IsCUDATensor(const Tensor& t) {
#ifdef GOOGLE_CUDA
  cudaPointerAttributes attributes;
  cudaError_t err =
      cudaPointerGetAttributes(&attributes, t.tensor_data().data());
  // cudaErrorInvalidValue means the runtime does not know this pointer,
  // i.e. it is ordinary host memory.
  if (err == cudaErrorInvalidValue) return false;
  CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  return (attributes.type == cudaMemoryTypeDevice);
#elif TENSORFLOW_USE_ROCM
  hipPointerAttribute_t attributes;
  hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
  // Same convention as the CUDA branch: unknown pointer => host memory.
  if (err == hipErrorInvalidValue) return false;
  CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
  return (attributes.memoryType == hipMemoryTypeDevice);
#else
  return false;
#endif
}
2342
GPUDeviceName(Session * session)2343 string GPUDeviceName(Session* session) {
2344 std::vector<DeviceAttributes> devices;
2345 TF_CHECK_OK(session->ListDevices(&devices));
2346 for (const DeviceAttributes& d : devices) {
2347 if (d.device_type() == "GPU" || d.device_type() == "gpu") {
2348 return d.name();
2349 }
2350 }
2351 return "";
2352 }
2353
// End-to-end check of CallableOptions feed_devices/fetch_devices: first fetch
// y = x^2 directly into GPU memory, then feed that GPU-resident tensor back
// in and fetch the result on the host. Skipped when no GPU is available.
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory) {
  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateGraphForYEqualsXSquared()));

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  Tensor gpu_tensor;

  {
    // Feed 2.0 from host memory; fetch y = 4.0 directly into GPU memory.
    Session::CallableHandle feed_cpu_fetch_gpu;
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_cpu_fetch_gpu));
    Tensor input(DT_FLOAT, {});
    input.scalar<float>()() = 2.0f;
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(
        session->RunCallable(feed_cpu_fetch_gpu, {input}, &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_cpu_fetch_gpu));
    ASSERT_EQ(1, outputs.size());
    gpu_tensor = outputs[0];
    ASSERT_TRUE(IsCUDATensor(gpu_tensor));
  }

  {
    // Feed the GPU-resident 4.0 back in; fetch y = 16.0 on the host.
    Session::CallableHandle feed_gpu_fetch_cpu;
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    TF_ASSERT_OK(session->MakeCallable(opts, &feed_gpu_fetch_cpu));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(feed_gpu_fetch_cpu, {gpu_tensor},
                                      &outputs, nullptr));
    TF_ASSERT_OK(session->ReleaseCallable(feed_gpu_fetch_cpu));
    ASSERT_EQ(1, outputs.size());
    // The output is in CPU/host memory, so it can be dereferenced.
    ASSERT_EQ(16.0, outputs[0].scalar<float>()());
  }
}
2400
CreateIdentityGraphDef(DataType dtype)2401 GraphDef CreateIdentityGraphDef(DataType dtype) {
2402 GraphDef def;
2403
2404 AttrValue dtype_attr;
2405 dtype_attr.set_type(dtype);
2406
2407 AttrValue shape_attr;
2408 shape_attr.mutable_shape()->set_unknown_rank(true);
2409
2410 auto* placeholder = def.add_node();
2411 placeholder->set_name("x");
2412 placeholder->set_op("Placeholder");
2413 placeholder->mutable_attr()->insert({"dtype", dtype_attr});
2414 placeholder->mutable_attr()->insert({"shape", shape_attr});
2415
2416 auto* identity = def.add_node();
2417 identity->set_name("y");
2418 identity->set_op("Identity");
2419 identity->add_input("x");
2420 identity->mutable_attr()->insert({"T", dtype_attr});
2421
2422 return def;
2423 }
2424
// Success-path helper: for a dtype that supports device-memory transfer,
// verifies that an identity graph can fetch its output into GPU memory and
// then accept that GPU-resident tensor back as a feed, producing bytes equal
// to the original host tensor. Skipped when no GPU is available.
void TestFeedAndFetchTensorsInDeviceMemory(
    const SessionOptions& session_options, DataType dtype) {
  std::unique_ptr<Session> session(NewSession(session_options));
  const string gpu_device_name = GPUDeviceName(session.get());
  if (gpu_device_name.empty()) {
    LOG(INFO) << "Skipping test since no GPU is available";
    return;
  }

  TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
      << DataType_Name(dtype);

  CallableOptions opts;
  opts.add_feed("x:0");
  opts.add_fetch("y:0");

  Tensor gpu_tensor;
  // NOTE: host_tensor's 3 elements are never explicitly initialized; only
  // the byte-for-byte round trip is compared below.
  Tensor host_tensor(dtype, {3});
  {
    // Ask for the fetched tensor to be backed by device memory.
    // Even though the kernel that created the tensor produced it in host
    // memory.
    opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
    opts.set_fetch_skip_sync(true);
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {host_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size()) << DataType_Name(dtype);
    gpu_tensor = outputs[0];
    ASSERT_TRUE(IsCUDATensor(gpu_tensor)) << DataType_Name(dtype);
  }

  {
    // Feed a tensor backed by device memory, even though the operations in the
    // graph expect it in host memory.
    opts.clear_fetch_devices();
    opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
    Session::CallableHandle handle;
    TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->RunCallable(handle, {gpu_tensor}, &outputs, nullptr))
        << DataType_Name(dtype);
    TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
    ASSERT_EQ(1, outputs.size());
    // Compare raw bytes: the identity round trip through GPU memory must
    // preserve the original host tensor exactly.
    const StringPiece actual_data = outputs[0].tensor_data();
    const StringPiece expected_data = host_tensor.tensor_data();
    EXPECT_EQ(expected_data.size(), actual_data.size()) << DataType_Name(dtype);
    EXPECT_EQ(0, memcmp(expected_data.data(), actual_data.data(),
                        std::min(expected_data.size(), actual_data.size())))
        << DataType_Name(dtype);
  }
}
2480
TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(const SessionOptions & session_options,DataType dtype)2481 void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
2482 const SessionOptions& session_options, DataType dtype) {
2483 std::unique_ptr<Session> session(NewSession(session_options));
2484 const string gpu_device_name = GPUDeviceName(session.get());
2485 if (gpu_device_name.empty()) {
2486 LOG(INFO) << "Skipping test since no GPU is available";
2487 return;
2488 }
2489
2490 TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
2491 << DataType_Name(dtype);
2492
2493 CallableOptions opts;
2494 opts.add_feed("x:0");
2495 opts.add_fetch("y:0");
2496
2497 // Fail when asking to fetch into GPU memory.
2498 {
2499 opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
2500 opts.set_fetch_skip_sync(true);
2501 Session::CallableHandle handle;
2502 Status status = session->MakeCallable(opts, &handle);
2503 EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
2504 EXPECT_TRUE(absl::StrContains(
2505 status.error_message(),
2506 strings::StrCat(
2507 "Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
2508 " as feeding/fetching from GPU devices is not yet supported for ",
2509 DataTypeString(dtype), " tensors")))
2510 << DataType_Name(dtype) << ", Status: " << status;
2511 }
2512
2513 // Fail when feeding from GPU memory.
2514 {
2515 opts.clear_feed_devices();
2516 opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
2517 Session::CallableHandle handle;
2518 Status status = session->MakeCallable(opts, &handle);
2519 EXPECT_FALSE(status.ok());
2520 EXPECT_TRUE(absl::StrContains(
2521 status.error_message(),
2522 strings::StrCat(
2523 "Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
2524 " as feeding/fetching from GPU devices is not yet supported for ",
2525 DataTypeString(dtype), " tensors")))
2526 << DataType_Name(dtype) << ", Status: " << status;
2527 }
2528 }
2529
// Dispatches every DataType value to either the success-path helper (types
// whose _Arg/_Retval kernels support device memory) or the failure-path
// helper (everything else except DT_INVALID and reference types).
void TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(
    const SessionOptions& opts) {
  // Feeding/fetching on device does not work for all DataTypes as it
  // relies on the implementation of the _Arg and _Retval kernels which
  // are not registered for some types or consume/produce inputs/outputs
  // in host memory for some types.
  //
  // Run through all datatypes to validate that either:
  // (a) MakeCallable fails (because the given type cannot be fed/fetched
  //     in device memory),
  //     OR
  // (b) Succeeds: RunCallable should gladly accept inputs in device memory
  //     and produce output tensors in device memory.
  for (int i = DataType_MIN; i <= DataType_MAX; ++i) {
    if (!DataType_IsValid(i)) continue;
    const DataType dtype = static_cast<DataType>(i);
    switch (dtype) {
      case DT_INVALID:
        // Not a real type; skip entirely.
        break;
      // The types below are expected to succeed at device-memory
      // feed/fetch.
      case DT_BFLOAT16:
      case DT_BOOL:
      case DT_COMPLEX128:
      case DT_COMPLEX64:
      case DT_DOUBLE:
      case DT_FLOAT:
      case DT_HALF:
      case DT_INT16:
      case DT_INT64:
      case DT_INT8:
      case DT_UINT16:
      case DT_UINT8:
        TestFeedAndFetchTensorsInDeviceMemory(opts, dtype);
        break;
      default:
        // Ignore all REF types since Tensors of this type aren't intended to
        // be fed (and attempting to create one via the Tensor constructor
        // will result in a LOG(FATAL)).
        if (!IsRefType(dtype)) {
          TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(opts, dtype);
        }
        break;
    }
  }
}
2574
TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory_AllDataTypes) {
  // Exercise device-memory feed/fetch for every DataType with hard placement:
  // ops must run exactly where they were placed.
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(false);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2580
TEST(DirectSessionTest,
     FeedAndFetchTensorsInDeviceMemory_AllDataTypes_SoftPlacement) {
  // Same sweep as above, but with soft placement enabled so the placer may
  // move ops whose placement constraints cannot be satisfied.
  SessionOptions session_options;
  session_options.config.set_allow_soft_placement(true);
  TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(session_options);
}
2587
2588 // A simple benchmark for the overhead of `DirectSession::Run()` calls
2589 // with varying numbers of feeds/fetches.
FeedFetchBenchmarkHelper(::testing::benchmark::State & state,int num_feeds,bool use_make_callable,int inter_op_threads,bool use_single_threaded_executor)2590 void FeedFetchBenchmarkHelper(::testing::benchmark::State& state, int num_feeds,
2591 bool use_make_callable, int inter_op_threads,
2592 bool use_single_threaded_executor) {
2593 Tensor value(DT_FLOAT, TensorShape());
2594 value.flat<float>()(0) = 37.0;
2595
2596 std::vector<std::pair<string, Tensor>> inputs;
2597 inputs.reserve(num_feeds);
2598 std::vector<string> outputs;
2599
2600 Graph g(OpRegistry::Global());
2601 for (int i = 0; i < num_feeds; ++i) {
2602 // NOTE(mrry): We pin nodes to the "/cpu:0" device, so as not to
2603 // measure CPU<->GPU copying overhead. We should also optimize and
2604 // monitor this overhead where possible, but that is not the
2605 // object of study in this benchmark.
2606 Node* placeholder;
2607 TF_CHECK_OK(NodeBuilder(g.NewName("Placeholder"), "Placeholder")
2608 .Attr("shape", TensorShape())
2609 .Attr("dtype", DT_FLOAT)
2610 .Device("/cpu:0")
2611 .Finalize(&g, &placeholder));
2612 Node* identity;
2613 TF_CHECK_OK(NodeBuilder(g.NewName("Identity"), "Identity")
2614 .Input(placeholder)
2615 .Attr("T", DT_FLOAT)
2616 .Device("/cpu:0")
2617 .Finalize(&g, &identity));
2618 inputs.push_back({placeholder->name() + ":0", value});
2619 outputs.push_back(identity->name() + ":0");
2620 }
2621 GraphDef gd;
2622 g.ToGraphDef(&gd);
2623 SessionOptions opts;
2624 opts.config.set_inter_op_parallelism_threads(inter_op_threads);
2625 if (use_single_threaded_executor) {
2626 opts.config.mutable_experimental()->set_executor_type(
2627 "SINGLE_THREADED_EXECUTOR");
2628 }
2629 std::unique_ptr<Session> session(NewSession(opts));
2630 TF_CHECK_OK(session->Create(gd));
2631 if (use_make_callable) {
2632 Session::CallableHandle handle;
2633 CallableOptions callable_options;
2634 std::vector<Tensor> input_tensors;
2635 for (const auto& input : inputs) {
2636 callable_options.add_feed(input.first);
2637 input_tensors.push_back(input.second);
2638 }
2639 for (const string& output : outputs) {
2640 callable_options.add_fetch(output);
2641 }
2642 TF_CHECK_OK(session->MakeCallable(callable_options, &handle));
2643
2644 for (auto s : state) {
2645 std::vector<Tensor> output_values;
2646 TF_CHECK_OK(
2647 session->RunCallable(handle, input_tensors, &output_values, nullptr));
2648 }
2649 } else {
2650 {
2651 // NOTE(mrry): Ignore the first run, which will incur the graph
2652 // partitioning/pruning overhead and skew the results.
2653 //
2654 // Note that we should also optimize and monitor the overhead on
2655 // the first run, which will impact application startup times, but
2656 // that is not the object of study in this benchmark.
2657 std::vector<Tensor> output_values;
2658 TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
2659 }
2660
2661 for (auto s : state) {
2662 std::vector<Tensor> output_values;
2663 TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
2664 }
2665 }
2666 }
2667
BM_FeedFetch(::testing::benchmark::State & state)2668 void BM_FeedFetch(::testing::benchmark::State& state) {
2669 const int num_feeds = state.range(0);
2670
2671 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ false,
2672 /* inter_op_threads */ 0,
2673 /* use_single_threaded_executor */ false);
2674 }
BM_FeedFetchCallable(::testing::benchmark::State & state)2675 void BM_FeedFetchCallable(::testing::benchmark::State& state) {
2676 const int num_feeds = state.range(0);
2677
2678 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2679 /* inter_op_threads */ 0,
2680 /* use_single_threaded_executor */ false);
2681 }
BM_FeedFetchCallableSingleThread(::testing::benchmark::State & state)2682 void BM_FeedFetchCallableSingleThread(::testing::benchmark::State& state) {
2683 const int num_feeds = state.range(0);
2684
2685 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2686 /* inter_op_threads */ -1,
2687 /* use_single_threaded_executor */ false);
2688 }
BM_FeedFetchCallableSingleThreadExecutor(::testing::benchmark::State & state)2689 void BM_FeedFetchCallableSingleThreadExecutor(
2690 ::testing::benchmark::State& state) {
2691 const int num_feeds = state.range(0);
2692
2693 FeedFetchBenchmarkHelper(state, num_feeds, /* use_make_callable */ true,
2694 /* inter_op_threads */ -1,
2695 /* use_single_threaded_executor */ true);
2696 }
2697
// Register each benchmark variant at 1, 2, 5, and 10 feed/fetch pairs.
BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThread)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallableSingleThreadExecutor)
    ->Arg(1)
    ->Arg(2)
    ->Arg(5)
    ->Arg(10);
2706
2707 } // namespace
2708
// Fixture for tests that inspect the collective_graph_key computed by
// DirectSession when CollectiveOps appear inside called (and uncalled)
// functions.
class DirectSessionCollectiveTest : public ::testing::Test {
 public:
  // Creates a graph with CollectiveOps inside functions and runs it. Returns
  // the generated collective_graph_key.
  Status RunGraphWithCollectiveFunctions(bool add_unused_function,
                                         int64* collective_graph_key) {
    GraphDef g = CreateGraph(add_unused_function);
    // One 8-element input tensor per participating CPU device.
    const Tensor t1 =
        test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
    const Tensor t2 =
        test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
    auto session = CreateSession();
    TF_RETURN_IF_ERROR(session->Create(g));
    std::vector<Tensor> outputs;
    // Nothing is fetched (second argument empty); the two collective calls
    // are passed as target nodes so both devices participate in the reduce.
    TF_RETURN_IF_ERROR(
        session->Run({{"input0:0", t1}, {"input1:0", t2}}, {},
                     {"collective_call0:0", "collective_call1:0"}, &outputs));
    // Downcast is safe: CreateSession() produces a DirectSession in this test
    // binary, and this fixture is a friend with access to its internals.
    DirectSession* direct_session = static_cast<DirectSession*>(session.get());
    {
      // collective_graph_key_ is guarded by collective_graph_key_lock_.
      mutex_lock l(direct_session->collective_graph_key_lock_);
      *collective_graph_key = direct_session->collective_graph_key_;
    }
    return Status::OK();
  }

 private:
  // Creates a function with name `function_name` and a single CollectiveReduce
  // node with instance key set as `instance_key`.
  FunctionDef CollectiveFunction(const string& function_name,
                                 int instance_key) {
    return FunctionDefHelper::Define(
        // Function name
        function_name,
        // In def
        {"arg:float"},
        // Out def
        {"reduce:float"},
        // Attr def
        {},
        // Node def
        {{
            {"reduce"},
            "CollectiveReduce",
            {"arg"},
            {{"group_size", 2},
             {"group_key", 1},
             {"instance_key", instance_key},
             {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
             {"merge_op", "Add"},
             {"final_op", "Div"},
             {"T", DT_FLOAT}},
        }});
  }

  // Returns a float Placeholder NodeDef named "input<id>".
  NodeDef Input(int id) {
    AttrValue dtype_attr;
    SetAttrValue(DT_FLOAT, &dtype_attr);
    NodeDef input;
    input.set_name(strings::StrCat("input", id));
    input.set_op("Placeholder");
    input.mutable_attr()->insert({"dtype", dtype_attr});
    return input;
  }

  // Returns a NodeDef named "collective_call<cpu_id>" that invokes function
  // `op` on `input`, pinned to local CPU device `cpu_id`.
  NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) {
    NodeDef collective_call;
    collective_call.set_name(strings::StrCat("collective_call", cpu_id));
    collective_call.set_op(op);
    collective_call.add_input(input);
    collective_call.set_device(
        strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id));
    return collective_call;
  }

  // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
  // CPU1, with instance_key 1, and appropriate placeholder inputs. If
  // `add_unused_function` is true, adds another CollectiveFunction with
  // instance_key 2 that is not invoked in the graph.
  GraphDef CreateGraph(bool add_unused_function) {
    GraphDef g;
    FunctionDef collective_function =
        CollectiveFunction("CollectiveFunction1", 1);
    FunctionDefLibrary* lib = g.mutable_library();
    *lib->add_function() = collective_function;
    if (add_unused_function) {
      // Present in the library but never called from the graph.
      FunctionDef unused_function =
          CollectiveFunction("CollectiveFunction2", 2);
      *lib->add_function() = unused_function;
    }

    *g.add_node() = Input(0);
    *g.add_node() = Input(1);
    // CollectiveReduce on CPU0 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0);
    // CollectiveReduce on CPU1 with instance_key 1.
    *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1);

    return g;
  }
};
2809
TEST_F(DirectSessionCollectiveTest,
       TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
  // The collective graph key must be identical whether or not an uncalled
  // collective function is present in the function library.
  int64 key_without_unused;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key_without_unused));
  int64 key_with_unused;
  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key_with_unused));
  ASSERT_EQ(key_without_unused, key_with_unused);
}
2818
// An op that counts how many of its outputs are required by the current
// subgraph and emits that count on each required output. (NOTE(review): the
// previous comment here described cancellation-manager polling, which does
// not match the class below -- it appears to belong to a different op.)
2821 class StatefulOutputRequiredOp : public OpKernel {
2822 public:
StatefulOutputRequiredOp(OpKernelConstruction * ctx)2823 explicit StatefulOutputRequiredOp(OpKernelConstruction* ctx)
2824 : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)2825 void Compute(OpKernelContext* ctx) override {
2826 // The op counts the number of outputs required in the current subgraph,
2827 // and emits that number on each of its required outputs.
2828 Tensor count_outputs_required_t(int64{0});
2829 int64& count_outputs_required = count_outputs_required_t.scalar<int64>()();
2830 for (int i = 0; i < num_outputs(); ++i) {
2831 if (ctx->output_required(i)) ++count_outputs_required;
2832 }
2833 for (int i = 0; i < num_outputs(); ++i) {
2834 if (ctx->output_required(i)) ctx->set_output(i, count_outputs_required_t);
2835 }
2836 }
2837 };
2838
// Register the test kernel on CPU and declare its op: a stateful op producing
// `num_outs` (default 5) int64 outputs.
REGISTER_KERNEL_BUILDER(Name("StatefulOutputRequired").Device(DEVICE_CPU),
                        StatefulOutputRequiredOp);
REGISTER_OP("StatefulOutputRequired")
    .Output("results : num_outs * int64")
    .Attr("num_outs : int = 5")
    .SetIsStateful();
2845
TEST(DirectSessionTest, TestStatefulOutputRequiredOp) {
  GraphDef graph;
  // Creates a graph with a StatefulOutputRequired op with 5 outputs.
  // Fix: the boolean result of ParseFromString was previously ignored; a
  // malformed proto would have left `graph` empty and produced a confusing
  // failure later in the test.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      R"proto(
        node { name: 'n' op: 'StatefulOutputRequired' device: '/device:CPU:0' }
        versions { producer: 9 }
      )proto",
      &graph));

  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(std::move(graph)));

  // As a stateful op, a single StatefulOutputRequired kernel will be created
  // and shared across multiple subgraphs. We create 5 different subgraphs,
  // fetching different prefixes of the output of the op.
  for (int num_outputs_required = 1; num_outputs_required <= 5;
       ++num_outputs_required) {
    std::vector<string> fetch_tensor_names;
    fetch_tensor_names.reserve(num_outputs_required);
    for (int output_idx = 0; output_idx < num_outputs_required; ++output_idx) {
      fetch_tensor_names.push_back(strings::StrCat("n:", output_idx));
    }
    std::vector<Tensor> fetch_tensors;
    TF_ASSERT_OK(session->Run({}, fetch_tensor_names, {}, &fetch_tensors));
    ASSERT_EQ(num_outputs_required, fetch_tensors.size());
    // Every fetched scalar must equal the number of outputs required in this
    // particular subgraph.
    for (const Tensor& t : fetch_tensors) {
      ASSERT_EQ(num_outputs_required, t.scalar<int64>()());
    }
  }

  TF_ASSERT_OK(session->Close());
}
2880
2881 } // namespace tensorflow
2882