1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/graph/testlib.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/types.h"
27 #include "tensorflow/core/public/session.h"
28
29 #if GOOGLE_CUDA && GOOGLE_TENSORRT
30
31 namespace tensorflow {
32 namespace tensorrt {
33 namespace segment {
34 namespace test {
35
36 class SegmentTest : public ::testing::Test {
37 protected:
MakeCandidateFn(const std::set<string> & node_names)38 std::function<Status(const Node*)> MakeCandidateFn(
39 const std::set<string>& node_names) {
40 return [node_names](const Node* node) -> Status {
41 if (node_names.find(node->name()) != node_names.end()) {
42 return Status::OK();
43 }
44 return errors::NotFound("Not a user specified candidate");
45 };
46 }
47
MakeInputEdgeCandidateFn(const std::set<string> & node_names)48 std::function<bool(const Edge*)> MakeInputEdgeCandidateFn(
49 const std::set<string>& node_names) {
50 return [node_names](const Edge* in_edge) -> bool {
51 return node_names.find(in_edge->dst()->name()) != node_names.end();
52 };
53 }
54
MakeOutputEdgeCandidateFn(const std::set<string> & node_names)55 std::function<bool(const Edge*)> MakeOutputEdgeCandidateFn(
56 const std::set<string>& node_names) {
57 return [node_names](const Edge* out_edge) -> bool {
58 return node_names.find(out_edge->src()->name()) != node_names.end();
59 };
60 }
61
RunTest(const Graph * graph,const grappler::GraphProperties * graph_properties,const std::set<string> & candidates,const std::set<string> & input_candidates,const std::set<string> & output_candidates,const std::vector<std::set<string>> & expected_segments)62 void RunTest(const Graph* graph,
63 const grappler::GraphProperties* graph_properties,
64 const std::set<string>& candidates,
65 const std::set<string>& input_candidates,
66 const std::set<string>& output_candidates,
67 const std::vector<std::set<string>>& expected_segments) {
68 SegmentVector segments;
69 TF_EXPECT_OK(SegmentGraph(graph, graph_properties,
70 MakeCandidateFn(candidates),
71 MakeInputEdgeCandidateFn(input_candidates),
72 MakeOutputEdgeCandidateFn(output_candidates),
73 segment_options_, &segments));
74 ValidateSegment(segments, expected_segments);
75 }
76
RunTest(const Graph * graph,const std::set<string> & candidates,const std::set<string> & input_candidates,const std::set<string> & output_candidates,const std::vector<std::set<string>> & expected_segments)77 void RunTest(const Graph* graph, const std::set<string>& candidates,
78 const std::set<string>& input_candidates,
79 const std::set<string>& output_candidates,
80 const std::vector<std::set<string>>& expected_segments) {
81 RunTest(graph, nullptr, candidates, input_candidates, output_candidates,
82 expected_segments);
83 }
84
ValidateSegment(const SegmentVector & segments,const std::vector<std::set<string>> & expected_segments)85 void ValidateSegment(const SegmentVector& segments,
86 const std::vector<std::set<string>>& expected_segments) {
87 EXPECT_EQ(expected_segments.size(), segments.size());
88 for (int i = 0; i < segments.size(); ++i) {
89 std::set<string> segment_node_names;
90 for (const Node* node : segments[i].nodes) {
91 segment_node_names.insert(node->name());
92 }
93 const auto& expected = expected_segments[i];
94 for (const auto& name : expected) {
95 EXPECT_TRUE(segment_node_names.count(name))
96 << "Segment " << i << " is missing expected node: " << name;
97 }
98 if (segment_node_names.size() == expected.size()) continue;
99 for (const auto& name : segment_node_names) {
100 EXPECT_TRUE(expected.count(name))
101 << "Unexpected node found in segment " << i << ": " << name;
102 }
103 }
104 }
105
DisableImplicitBatchMode()106 void DisableImplicitBatchMode() {
107 segment_options_.use_implicit_batch = false;
108 segment_options_.allow_dynamic_non_batch_dim = true;
109 }
110
EnableImplicitBatchModeForStaticEngine(int maximum_batch_size=1000)111 void EnableImplicitBatchModeForStaticEngine(int maximum_batch_size = 1000) {
112 segment_options_.use_implicit_batch = true;
113 segment_options_.maximum_batch_size = maximum_batch_size;
114 segment_options_.allow_dynamic_non_batch_dim = false;
115 }
116
117 SegmentOptions segment_options_;
118 };
119
operator -(const std::set<string> & lhs,const string & rhs)120 std::set<string> operator-(const std::set<string>& lhs, const string& rhs) {
121 std::set<string> result = lhs;
122 CHECK(result.erase(rhs));
123 return result;
124 }
125
TEST_F(SegmentTest, Empty) {
  // A graph built from a fresh root scope contains no user ops, so no segment
  // can be formed.
  Scope root = Scope::NewRootScope();
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  DisableImplicitBatchMode();
  const std::set<string> no_nodes;
  RunTest(&graph, no_nodes, no_nodes, no_nodes,
          std::vector<std::set<string>>());
}
134
TEST_F(SegmentTest, Simple) {
  // Graph under test (each Add takes two inputs):
  //   add0 = feed + feed      add1 = feed + feed
  //   add2 = add0 + add1
  //   add3 = add0 + add2      add4 = add2 + add2
  // add3 and add4 have no consumers other than the sink.
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(root.WithOpName("add2"), add0, add1);
  ops::Add(root.WithOpName("add3"), add0, add2);
  ops::Add(root.WithOpName("add4"), add2, add2);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  const std::set<string> every_add = {"add0", "add1", "add2", "add3", "add4"};
  const std::set<string> no_add1 = every_add - "add1";
  const std::set<string> no_add2 = every_add - "add2";
  const std::set<string> no_add3 = every_add - "add3";
  DisableImplicitBatchMode();

  // Every Add is a candidate: all of them collapse into a single segment.
  RunTest(&graph, every_add, every_add, every_add, {every_add});

  // add1 is not a candidate: the other Adds still form a single segment.
  RunTest(&graph, no_add1, no_add1, no_add1, {no_add1});

  // add1 is not a candidate and add2 is not an input candidate: add0 and add2
  // are expected to be removed from the segment.
  RunTest(&graph, no_add1, no_add2, no_add1, {{"add3", "add4"}});

  // add2 alone not being an input candidate changes nothing.
  RunTest(&graph, every_add, no_add2, every_add, {every_add});

  // add1 not being an input candidate removes only add1.
  RunTest(&graph, every_add, no_add1, every_add, {no_add1});

  // add3 not being an output candidate changes nothing, since its only output
  // goes to the sink.
  RunTest(&graph, every_add, every_add, no_add3, {every_add});
}
182
TEST_F(SegmentTest, WithDeviceAssignments) {
  // Same graph as the Simple test:
  //   add0 = feed + feed      add1 = feed + feed
  //   add2 = add0 + add1
  //   add3 = add0 + add2      add4 = add2 + add2
  // The sub-cases below mutate device assignments on the scope's nodes and
  // rebuild the Graph each time, so they must run in this order.
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(root.WithOpName("add2"), add0, add1);
  auto add3 = ops::Add(root.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(root.WithOpName("add4"), add2, add2);

  const std::set<string> every_add = {"add0", "add1", "add2", "add3", "add4"};
  DisableImplicitBatchMode();

  // Materializes the current scope into a fresh Graph and segments it with
  // every Add as a candidate.
  auto check_segments = [&](const std::vector<std::set<string>>& expected) {
    Graph graph(OpRegistry::Global());
    TF_EXPECT_OK(root.ToGraph(&graph));
    RunTest(&graph, every_add, every_add, every_add, expected);
  };

  // No device assignments: one cluster with every Add.
  check_segments({every_add});

  // add1 placed on the CPU is excluded from the cluster.
  add1.node()->set_assigned_device_name("/device:CPU:0");
  check_segments({every_add - "add1"});
  add1.node()->set_assigned_device_name("");

  // add0..add2 on one GPU and add3/add4 on another GPU: add3 and add4 are
  // excluded from the cluster.
  constexpr char kGpu0[] = "/device:GPU:0";
  constexpr char kGpu1[] = "/device:GPU:1";
  add0.node()->set_assigned_device_name(kGpu0);
  add1.node()->set_assigned_device_name(kGpu0);
  add2.node()->set_assigned_device_name(kGpu0);
  add3.node()->set_assigned_device_name(kGpu1);
  add4.node()->set_assigned_device_name(kGpu1);
  check_segments({{"add0", "add1", "add2"}});

  // Moving add3/add4 to a wildcard GPU device compatible with GPU:0 yields a
  // single cluster with all operations again.
  constexpr char kGpuAny[] = "/device:GPU:*";
  add3.node()->set_assigned_device_name(kGpuAny);
  add4.node()->set_assigned_device_name(kGpuAny);
  check_segments({every_add});
}
245
TEST_F(SegmentTest, AvoidCycle) {
  // Graph under test:
  //   add0 = feed + feed      add1 = feed + feed
  //   add2 = add0 + add1
  //   add3 = add0 + add2      add4 = add2 + add2
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(root.WithOpName("add2"), add0, add1);
  ops::Add(root.WithOpName("add3"), add0, add2);
  ops::Add(root.WithOpName("add4"), add2, add2);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  // add2 is not a TRT candidate. add0 -> add2 -> add3 passes through it, so
  // (as the test name says) merging around it would form a cycle; no segment
  // is expected.
  const std::set<string> no_add2 = {"add0", "add1", "add3", "add4"};
  DisableImplicitBatchMode();
  RunTest(&graph, no_add2, no_add2, no_add2, {});
}
271
TEST_F(SegmentTest, Multiple) {
  // Graph under test:
  //   add0 = feed + feed      add1 = feed + feed      add7 = feed + feed
  //   add2 = add0 + add1
  //   add5 = add2 + add7      add8 = add7 + add7
  //   add3 = add0 + add2      add4 = add2 + add5      add6 = add5 + add8
  // add3, add4 and add6 have no consumers other than the sink.
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), feed, feed);
  auto add7 = ops::Add(root.WithOpName("add7"), feed, feed);
  auto add2 = ops::Add(root.WithOpName("add2"), add0, add1);
  auto add5 = ops::Add(root.WithOpName("add5"), add2, add7);
  auto add8 = ops::Add(root.WithOpName("add8"), add7, add7);
  ops::Add(root.WithOpName("add3"), add0, add2);
  ops::Add(root.WithOpName("add4"), add2, add5);
  ops::Add(root.WithOpName("add6"), add5, add8);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  const std::set<string> every_add = {"add0", "add1", "add2", "add3", "add4",
                                      "add5", "add6", "add7", "add8"};
  DisableImplicitBatchMode();

  // add5 is not a TRT candidate: two segments are expected.
  const std::set<string> no_add5 = every_add - "add5";
  RunTest(&graph, no_add5, no_add5, no_add5,
          {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});

  // add8 is not a candidate and add6 is not an input candidate: every direct
  // and indirect input of add6 is removed from the segment.
  RunTest(&graph, every_add - "add8", every_add - "add6", every_add,
          {{"add3", "add4"}});

  // add3 is not a candidate and add0 is not an output candidate: every direct
  // and indirect output of add0 is removed from the segment.
  RunTest(&graph, every_add - "add3", every_add, every_add - "add0",
          {{"add1", "add7", "add8"}});
}
318
TEST_F(SegmentTest, BigIfElse) {
  // Two chains that fork at add0 and join again at add7:
  //   add0 = feed + feed
  //   add1 = add0 + add0      add4 = add0 + add0
  //   add2 = add1 + add1      add5 = add4 + add4
  //   add3 = add2 + add2      add6 = add5 + add5
  //   add7 = add3 + add6      (consumed only by the sink)
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), add0, add0);
  auto add2 = ops::Add(root.WithOpName("add2"), add1, add1);
  auto add3 = ops::Add(root.WithOpName("add3"), add2, add2);
  auto add4 = ops::Add(root.WithOpName("add4"), add0, add0);
  auto add5 = ops::Add(root.WithOpName("add5"), add4, add4);
  auto add6 = ops::Add(root.WithOpName("add6"), add5, add5);
  ops::Add(root.WithOpName("add7"), add3, add6);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  // Excluding add2 from the candidates splits the graph into two segments.
  const std::set<string> every_add = {"add0", "add1", "add2", "add3",
                                      "add4", "add5", "add6", "add7"};
  const std::set<string> no_add2 = every_add - "add2";
  DisableImplicitBatchMode();
  RunTest(&graph, no_add2, every_add, every_add,
          {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
}
353
TEST_F(SegmentTest, IdentityOps) {
  // A linear chain: feed -> identity0 -> identity1 -> identity2 -> identity3.
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto identity0 = ops::Identity(root.WithOpName("identity0"), feed);
  auto identity1 = ops::Identity(root.WithOpName("identity1"), identity0);
  auto identity2 = ops::Identity(root.WithOpName("identity2"), identity1);
  ops::Identity(root.WithOpName("identity3"), identity2);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  // Identity ops do not count as effective ops in a segment, so even with all
  // of them as candidates no segment is formed.
  const std::set<string> every_identity = {"identity0", "identity1",
                                           "identity2", "identity3"};
  DisableImplicitBatchMode();
  RunTest(&graph, every_identity, every_identity, every_identity, {});
}
371
372 // Testing implicit batch mode segmentation: it excludes the add-2 operation
373 // with a dynamic non-batch dimension.
TEST_F(SegmentTest, ExcludeAddWithDynamicNonBatchDimension) {
  Scope root = Scope::NewRootScope();
  // feed-1 has only a dynamic batch dimension; feed-2 additionally has a
  // dynamic non-batch dimension (dimension 1).
  auto static_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3}));
  auto dynamic_shape =
      ops::Placeholder::Shape(PartialTensorShape({-1, -1, 3}));
  auto one = ops::Const<float>(root, {1.0}, {});
  auto feed_static =
      ops::Placeholder(root.WithOpName("feed-1"), DT_FLOAT, static_shape);
  auto feed_dynamic =
      ops::Placeholder(root.WithOpName("feed-2"), DT_FLOAT, dynamic_shape);
  auto add_0 = ops::Add(root.WithOpName("add-0"), feed_static, one);
  ops::Add(root.WithOpName("add-1"), add_0, feed_static);
  ops::Add(root.WithOpName("add-2"), one, feed_dynamic);

  grappler::GrapplerItem grappler_item;
  grappler_item.fetch.push_back("add-2");
  TF_EXPECT_OK(root.ToGraphDef(&grappler_item.graph));

  grappler::GraphProperties properties(grappler_item);
  TF_EXPECT_OK(properties.InferStatically(true));

  Graph graph(OpRegistry::Global());
  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                     grappler_item.graph, &graph));

  // add-2 consumes feed-2, so implicit batch mode keeps only add-0 and add-1.
  const std::set<string> all_nodes = {"add-0", "add-1", "add-2"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes,
          {all_nodes - "add-2"});
}
403
// Testing implicit batch mode segmentation: it excludes the reshape operation
// with a dynamic non-batch output dimension.
// TODO(bixia): hoist the check that a reshape must not change the batch size
// from the converter into the segmenter, and add another test case for
// excluding a reshape that involves no dynamic dimensions.
TEST_F(SegmentTest, ExcludeReshapeWithDynamicNonBatchDimensionInOutput) {
  Scope root = Scope::NewRootScope();
  auto feed_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3}));
  auto one = ops::Const<float>(root, {1.0}, {});
  auto feed =
      ops::Placeholder(root.WithOpName("feed-1"), DT_FLOAT, feed_shape);
  auto add_0 = ops::Add(root.WithOpName("add-0"), feed, one);
  // Reshaping the dynamic-batch input to {6, -1} leaves dimension 1 of the
  // reshape output unknown, i.e. a dynamic non-batch output dimension.
  auto reshape =
      ops::Reshape(root.WithOpName("reshape"), add_0, Input({6, -1}));
  ops::Add(root.WithOpName("add-1"), reshape, one);

  grappler::GrapplerItem grappler_item;
  grappler_item.fetch.push_back("add-1");
  TF_EXPECT_OK(root.ToGraphDef(&grappler_item.graph));

  grappler::GraphProperties properties(grappler_item);
  TF_EXPECT_OK(properties.InferStatically(true));

  Graph graph(OpRegistry::Global());
  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                     grappler_item.graph, &graph));

  // No segment is expected in implicit batch mode.
  const std::set<string> all_nodes = {"add-0", "reshape", "add-1"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes, {});
}
434
TEST_F(SegmentTest, RankOneCannotUseImplicitBatch) {
  Scope root = Scope::NewRootScope();
  // Both inputs are rank-1 tensors of shape {3}.
  auto shape_0 = ops::Placeholder::Shape(TensorShape({3}));
  auto shape_1 = ops::Placeholder::Shape(TensorShape({3}));
  auto in_0 = ops::Placeholder(root.WithOpName("input-0"), DT_FLOAT, shape_0);
  auto in_1 = ops::Placeholder(root.WithOpName("input-1"), DT_FLOAT, shape_1);
  auto scalar = ops::Const(root.WithOpName("const-scalar"), 1.0f, {});
  ops::Add(root.WithOpName("output-0"), in_0, scalar);
  ops::Add(root.WithOpName("output-1"), in_1, scalar);

  grappler::GrapplerItem grappler_item;
  for (const char* fetch_name : {"output-0", "output-1"}) {
    grappler_item.fetch.push_back(fetch_name);
  }
  TF_EXPECT_OK(root.ToGraphDef(&grappler_item.graph));

  grappler::GraphProperties properties(grappler_item);
  TF_EXPECT_OK(properties.InferStatically(true));

  Graph graph(OpRegistry::Global());
  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                     grappler_item.graph, &graph));

  // As the test name states, rank-1 inputs cannot use implicit batch mode, so
  // no segment is expected.
  const std::set<string> all_nodes = {"const-scalar", "output-0", "output-1"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes, {});
}
463
TEST_F(SegmentTest, TwoChainsDiffBatchSizes) {
  Scope root = Scope::NewRootScope();
  // Two chains with different static batch sizes (2 and 5) that share a
  // scalar constant.
  auto shape_0 = ops::Placeholder::Shape(TensorShape({2, 3}));
  auto shape_1 = ops::Placeholder::Shape(TensorShape({5, 3}));
  auto in_0 = ops::Placeholder(root.WithOpName("input-0"), DT_FLOAT, shape_0);
  auto in_1 = ops::Placeholder(root.WithOpName("input-1"), DT_FLOAT, shape_1);
  auto scalar = ops::Const(root.WithOpName("const-scalar"), 1.0f, {});
  ops::Add(root.WithOpName("output-0"), in_0, scalar);
  ops::Add(root.WithOpName("output-1"), in_1, scalar);

  grappler::GrapplerItem grappler_item;
  for (const char* fetch_name : {"output-0", "output-1"}) {
    grappler_item.fetch.push_back(fetch_name);
  }
  TF_EXPECT_OK(root.ToGraphDef(&grappler_item.graph));

  grappler::GraphProperties properties(grappler_item);
  TF_EXPECT_OK(properties.InferStatically(true));

  Graph graph(OpRegistry::Global());
  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                     grappler_item.graph, &graph));

  const std::set<string> all_nodes = {"const-scalar", "output-0", "output-1"};
  const std::vector<std::set<string>> expected = {
      {"output-0", "const-scalar"}};

  // Default maximum batch size.
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes, expected);

  // Same result with a maximum batch size of 1: the converter will create
  // engines based on the static batch size.
  EnableImplicitBatchModeForStaticEngine(1);
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes, expected);
}
498
TEST_F(SegmentTest, SameRankImplicitBroadcastingStaticBatchSize) {
  Scope root = Scope::NewRootScope();
  // Rank-3 inputs whose shapes broadcast against {2, 3, 4}.
  auto shape_0 = ops::Placeholder::Shape(TensorShape({2, 3, 1}));
  auto shape_1 = ops::Placeholder::Shape(TensorShape({1, 3, 4}));
  auto shape_2 = ops::Placeholder::Shape(TensorShape({2, 3, 4}));
  auto in_0 = ops::Placeholder(root.WithOpName("input-0"), DT_FLOAT, shape_0);
  auto in_1 = ops::Placeholder(root.WithOpName("input-1"), DT_FLOAT, shape_1);
  auto in_2 = ops::Placeholder(root.WithOpName("input-2"), DT_FLOAT, shape_2);
  auto product = ops::Mul(root.WithOpName("multiple"), in_2, in_2);
  ops::Add(root.WithOpName("output-0"), in_0, product);
  ops::Add(root.WithOpName("output-1"), in_1, product);

  grappler::GrapplerItem grappler_item;
  for (const char* fetch_name : {"output-0", "output-1"}) {
    grappler_item.fetch.push_back(fetch_name);
  }
  TF_EXPECT_OK(root.ToGraphDef(&grappler_item.graph));

  grappler::GraphProperties properties(grappler_item);
  TF_EXPECT_OK(properties.InferStatically(true));

  Graph graph(OpRegistry::Global());
  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                     grappler_item.graph, &graph));

  // All three ops are expected to land in one segment.
  const std::set<string> all_nodes = {"multiple", "output-0", "output-1"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes, {all_nodes});
}
531
TEST_F(SegmentTest, SameRankImplicitBroadcastingDynamicBatchSize) {
  Scope root = Scope::NewRootScope();
  auto dynamic_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2}));
  auto static_shape = ops::Placeholder::Shape(TensorShape({1, 2}));
  auto in_0 =
      ops::Placeholder(root.WithOpName("input-0"), DT_FLOAT, dynamic_shape);
  // NOTE(review): input-1 is created but not consumed by any op.
  ops::Placeholder(root.WithOpName("input-1"), DT_FLOAT, static_shape);
  // A {1, 1} constant broadcasts against the dynamic-batch input.
  auto ones = ops::Const(root.WithOpName("const-val"), 1.0f, {1, 1});
  auto add_0 = ops::Add(root.WithOpName("add-0"), in_0, ones);
  ops::Add(root.WithOpName("output-0"), in_0, add_0);

  grappler::GrapplerItem grappler_item;
  grappler_item.fetch.push_back("output-0");
  TF_EXPECT_OK(root.ToGraphDef(&grappler_item.graph));

  grappler::GraphProperties properties(grappler_item);
  TF_EXPECT_OK(properties.InferStatically(true));

  Graph graph(OpRegistry::Global());
  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                     grappler_item.graph, &graph));

  // The constant and both Adds are expected to form one segment.
  const std::set<string> all_nodes = {"const-val", "add-0", "output-0"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes,
          {{"const-val", "add-0", "output-0"}});
}
560
TEST_F(SegmentTest, IncompatibleBatchSizes) {
  Scope root = Scope::NewRootScope();
  auto dynamic_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2}));
  auto static_shape = ops::Placeholder::Shape(TensorShape({2, 2}));
  auto in_0 =
      ops::Placeholder(root.WithOpName("input-0"), DT_FLOAT, dynamic_shape);
  // NOTE(review): input-1 is created but not consumed by any op.
  ops::Placeholder(root.WithOpName("input-1"), DT_FLOAT, static_shape);
  // A {2, 2} constant fixes a leading dimension of 2, while input-0's
  // leading (batch) dimension is dynamic.
  auto twos = ops::Const(root.WithOpName("const-val"), 1.0f, {2, 2});
  auto add_0 = ops::Add(root.WithOpName("add-0"), in_0, twos);
  ops::Add(root.WithOpName("output-0"), in_0, add_0);

  grappler::GrapplerItem grappler_item;
  grappler_item.fetch.push_back("output-0");
  TF_EXPECT_OK(root.ToGraphDef(&grappler_item.graph));

  grappler::GraphProperties properties(grappler_item);
  TF_EXPECT_OK(properties.InferStatically(true));

  Graph graph(OpRegistry::Global());
  TF_CHECK_OK(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                     grappler_item.graph, &graph));

  // No segment is expected in implicit batch mode.
  const std::set<string> all_nodes = {"const-val", "add-0", "output-0"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&graph, &properties, all_nodes, all_nodes, all_nodes, {});
}
588 } // namespace test
589 } // namespace segment
590 } // namespace tensorrt
591 } // namespace tensorflow
592
593 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
594