• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17 
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/graph/testlib.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/types.h"
27 #include "tensorflow/core/public/session.h"
28 
29 #if GOOGLE_CUDA && GOOGLE_TENSORRT
30 
31 namespace tensorflow {
32 namespace tensorrt {
33 namespace segment {
34 namespace test {
35 
36 class SegmentTest : public ::testing::Test {
37  protected:
MakeCandidateFn(const std::set<string> & node_names)38   std::function<Status(const Node*)> MakeCandidateFn(
39       const std::set<string>& node_names) {
40     return [node_names](const Node* node) -> Status {
41       if (node_names.find(node->name()) != node_names.end()) {
42         return Status::OK();
43       }
44       return errors::NotFound("Not a user specified candidate");
45     };
46   }
47 
MakeInputEdgeCandidateFn(const std::set<string> & node_names)48   std::function<bool(const Edge*)> MakeInputEdgeCandidateFn(
49       const std::set<string>& node_names) {
50     return [node_names](const Edge* in_edge) -> bool {
51       return node_names.find(in_edge->dst()->name()) != node_names.end();
52     };
53   }
54 
MakeOutputEdgeCandidateFn(const std::set<string> & node_names)55   std::function<bool(const Edge*)> MakeOutputEdgeCandidateFn(
56       const std::set<string>& node_names) {
57     return [node_names](const Edge* out_edge) -> bool {
58       return node_names.find(out_edge->src()->name()) != node_names.end();
59     };
60   }
61 
RunTest(const Graph * graph,const grappler::GraphProperties * graph_properties,const std::set<string> & candidates,const std::set<string> & input_candidates,const std::set<string> & output_candidates,const std::vector<std::set<string>> & expected_segments)62   void RunTest(const Graph* graph,
63                const grappler::GraphProperties* graph_properties,
64                const std::set<string>& candidates,
65                const std::set<string>& input_candidates,
66                const std::set<string>& output_candidates,
67                const std::vector<std::set<string>>& expected_segments) {
68     SegmentVector segments;
69     TF_EXPECT_OK(SegmentGraph(graph, graph_properties,
70                               MakeCandidateFn(candidates),
71                               MakeInputEdgeCandidateFn(input_candidates),
72                               MakeOutputEdgeCandidateFn(output_candidates),
73                               segment_options_, &segments));
74     ValidateSegment(segments, expected_segments);
75   }
76 
RunTest(const Graph * graph,const std::set<string> & candidates,const std::set<string> & input_candidates,const std::set<string> & output_candidates,const std::vector<std::set<string>> & expected_segments)77   void RunTest(const Graph* graph, const std::set<string>& candidates,
78                const std::set<string>& input_candidates,
79                const std::set<string>& output_candidates,
80                const std::vector<std::set<string>>& expected_segments) {
81     RunTest(graph, nullptr, candidates, input_candidates, output_candidates,
82             expected_segments);
83   }
84 
ValidateSegment(const SegmentVector & segments,const std::vector<std::set<string>> & expected_segments)85   void ValidateSegment(const SegmentVector& segments,
86                        const std::vector<std::set<string>>& expected_segments) {
87     EXPECT_EQ(expected_segments.size(), segments.size());
88     for (int i = 0; i < segments.size(); ++i) {
89       std::set<string> segment_node_names;
90       for (const Node* node : segments[i].nodes) {
91         segment_node_names.insert(node->name());
92       }
93       const auto& expected = expected_segments[i];
94       for (const auto& name : expected) {
95         EXPECT_TRUE(segment_node_names.count(name))
96             << "Segment " << i << " is missing expected node: " << name;
97       }
98       if (segment_node_names.size() == expected.size()) continue;
99       for (const auto& name : segment_node_names) {
100         EXPECT_TRUE(expected.count(name))
101             << "Unexpected node found in segment " << i << ": " << name;
102       }
103     }
104   }
105 
DisableImplicitBatchMode()106   void DisableImplicitBatchMode() {
107     segment_options_.use_implicit_batch = false;
108     segment_options_.allow_dynamic_non_batch_dim = true;
109   }
110 
EnableImplicitBatchModeForStaticEngine(int maximum_batch_size=1000)111   void EnableImplicitBatchModeForStaticEngine(int maximum_batch_size = 1000) {
112     segment_options_.use_implicit_batch = true;
113     segment_options_.maximum_batch_size = maximum_batch_size;
114     segment_options_.allow_dynamic_non_batch_dim = false;
115   }
116 
117   SegmentOptions segment_options_;
118 };
119 
operator -(const std::set<string> & lhs,const string & rhs)120 std::set<string> operator-(const std::set<string>& lhs, const string& rhs) {
121   std::set<string> result = lhs;
122   CHECK(result.erase(rhs));
123   return result;
124 }
125 
// A graph without any nodes offers nothing to cluster, so no segment is
// expected.
TEST_F(SegmentTest, Empty) {
  Scope s = Scope::NewRootScope();
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));
  DisableImplicitBatchMode();
  const std::set<string> no_nodes;
  RunTest(&g, no_nodes, no_nodes, no_nodes, /*expected_segments=*/{});
}
134 
TEST_F(SegmentTest, Simple) {
  // Graph under test ("x2" means both inputs of the Add come from the same
  // producer):
  //   add0 = feed x2       add1 = feed x2
  //   add2 = add0 + add1
  //   add3 = add0 + add2   add4 = add2 x2
  // add3 and add4 have no consumers other than the sink.
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  const std::set<string> every_add = {"add0", "add1", "add2", "add3", "add4"};
  const std::set<string> every_add_but_add1 = every_add - "add1";
  const std::set<string> every_add_but_add2 = every_add - "add2";
  const std::set<string> every_add_but_add3 = every_add - "add3";
  DisableImplicitBatchMode();

  // Every Add is accepted everywhere: one segment containing all of them.
  RunTest(&g, every_add, every_add, every_add, {every_add});

  // add1 is not a candidate: the remaining Adds still form a single segment.
  RunTest(&g, every_add_but_add1, every_add_but_add1, every_add_but_add1,
          {every_add_but_add1});

  // add1 is not a candidate and add2 is not an input-edge candidate: add0 and
  // add2 drop out as well, leaving only add3 and add4.
  RunTest(&g, every_add_but_add1, every_add_but_add2, every_add_but_add1,
          {{"add3", "add4"}});

  // add2 not being an input-edge candidate on its own changes nothing.
  RunTest(&g, every_add, every_add_but_add2, every_add, {every_add});

  // add1 not being an input-edge candidate removes add1 from the segment.
  RunTest(&g, every_add, every_add_but_add1, every_add, {every_add_but_add1});

  // add3 not being an output-edge candidate changes nothing, because its only
  // consumer is the sink.
  RunTest(&g, every_add, every_add, every_add_but_add3, {every_add});
}
182 
TEST_F(SegmentTest,WithDeviceAssignments)183 TEST_F(SegmentTest, WithDeviceAssignments) {
184   //           feed
185   //          //  \\
186   //       add0    add1
187   //        | \    /
188   //        |  add2
189   //        | /   \\
190   //       add3    add4
191   //          \    /
192   //          <sink>
193   Scope s = Scope::NewRootScope();
194   auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
195   auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
196   auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
197   auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
198   auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
199   auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
200 
201   const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4"};
202   DisableImplicitBatchMode();
203 
204   {
205     Graph g(OpRegistry::Global());
206     TF_EXPECT_OK(s.ToGraph(&g));
207     RunTest(&g, all_adds, all_adds, all_adds, {all_adds});
208   }
209 
210   {
211     // Assigning add1 to CPU to exclude it from the cluster.
212     add1.node()->set_assigned_device_name("/device:CPU:0");
213     Graph g(OpRegistry::Global());
214     TF_EXPECT_OK(s.ToGraph(&g));
215     RunTest(&g, all_adds, all_adds, all_adds, {all_adds - "add1"});
216     add1.node()->set_assigned_device_name("");
217   }
218 
219   {
220     // Assigning operations add3 and add4 to another GPU to exclude the
221     // operation from the cluster.
222     constexpr char kGpu0[] = "/device:GPU:0";
223     add0.node()->set_assigned_device_name(kGpu0);
224     add1.node()->set_assigned_device_name(kGpu0);
225     add2.node()->set_assigned_device_name(kGpu0);
226     constexpr char kGpu1[] = "/device:GPU:1";
227     add3.node()->set_assigned_device_name(kGpu1);
228     add4.node()->set_assigned_device_name(kGpu1);
229     Graph g(OpRegistry::Global());
230     TF_EXPECT_OK(s.ToGraph(&g));
231     RunTest(&g, all_adds, all_adds, all_adds, {{"add0", "add1", "add2"}});
232   }
233 
234   {
235     // Assigning the operations to two compatibile GPU devices resulting in
236     // one cluster with all operations.
237     constexpr char kGpuAny[] = "/device:GPU:*";
238     add3.node()->set_assigned_device_name(kGpuAny);
239     add4.node()->set_assigned_device_name(kGpuAny);
240     Graph g(OpRegistry::Global());
241     TF_EXPECT_OK(s.ToGraph(&g));
242     RunTest(&g, all_adds, all_adds, all_adds, {all_adds});
243   }
244 }
245 
TEST_F(SegmentTest, AvoidCycle) {
  // Graph under test ("x2" = both inputs from one producer):
  //   add0 = feed x2       add1 = feed x2
  //   add2 = add0 + add1
  //   add3 = add0 + add2   add4 = add2 x2
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  // add2 is excluded from the candidates. Because both add0 -> add3 and
  // add0 -> add2 -> add3 exist, a segment holding add0 and add3 without add2
  // would form a cycle through add2. No segment is expected at all.
  const std::set<string> without_add2 = {"add0", "add1", "add3", "add4"};
  DisableImplicitBatchMode();
  RunTest(&g, without_add2, without_add2, without_add2, {});
}
271 
TEST_F(SegmentTest, Multiple) {
  // Graph under test ("x2" = both inputs from one producer):
  //   add0 = feed x2       add1 = feed x2       add7 = feed x2
  //   add2 = add0 + add1   add5 = add2 + add7   add8 = add7 x2
  //   add3 = add0 + add2   add4 = add2 + add5   add6 = add5 + add8
  // add3, add4 and add6 have no consumers other than the sink.
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
  auto add7 = ops::Add(s.WithOpName("add7"), feed, feed);
  auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
  auto add5 = ops::Add(s.WithOpName("add5"), add2, add7);
  auto add8 = ops::Add(s.WithOpName("add8"), add7, add7);
  auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add2, add5);
  auto add6 = ops::Add(s.WithOpName("add6"), add5, add8);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4",
                                     "add5", "add6", "add7", "add8"};
  DisableImplicitBatchMode();

  // add5 is not a candidate: two separate segments are expected.
  const std::set<string> all_but_add5 = all_adds - "add5";
  RunTest(&g, all_but_add5, all_but_add5, all_but_add5,
          {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});

  // add8 is not a candidate and add6 is not an input-edge candidate: every
  // direct and indirect input of add6 is dropped from the segment.
  RunTest(&g, all_adds - "add8", all_adds - "add6", all_adds,
          {{"add3", "add4"}});

  // add3 is not a candidate and add0 is not an output-edge candidate: every
  // direct and indirect output of add0 is dropped from the segment.
  RunTest(&g, all_adds - "add3", all_adds, all_adds - "add0",
          {{"add1", "add7", "add8"}});
}
318 
TEST_F(SegmentTest, BigIfElse) {
  // Two parallel chains that fork at add0 and join at add7 ("x2" = both
  // inputs from one producer):
  //   add0 = feed x2
  //   add1 = add0 x2,  add2 = add1 x2,  add3 = add2 x2
  //   add4 = add0 x2,  add5 = add4 x2,  add6 = add5 x2
  //   add7 = add3 + add6, which feeds the sink.
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), add0, add0);
  auto add2 = ops::Add(s.WithOpName("add2"), add1, add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add2, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add0, add0);
  auto add5 = ops::Add(s.WithOpName("add5"), add4, add4);
  auto add6 = ops::Add(s.WithOpName("add6"), add5, add5);
  auto add7 = ops::Add(s.WithOpName("add7"), add3, add6);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
                                     "add4", "add5", "add6", "add7"};
  DisableImplicitBatchMode();

  // With add2 excluded from the node candidates, the graph splits into the
  // two segments {add0, add1} and {add3, add4, add5, add6, add7}.
  const std::set<string> node_candidates = all_adds - "add2";
  RunTest(&g, node_candidates, all_adds, all_adds,
          {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
}
353 
TEST_F(SegmentTest, IdentityOps) {
  // A straight chain: feed -> identity0 -> identity1 -> identity2 -> identity3.
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto identity0 = ops::Identity(s.WithOpName("identity0"), feed);
  auto identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
  auto identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
  auto identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  const std::set<string> all_identities = {"identity0", "identity1",
                                           "identity2", "identity3"};
  DisableImplicitBatchMode();
  // Every Identity is a candidate, but Identity ops do not count as effective
  // ops in a segment, so no segment is expected.
  RunTest(&g, all_identities, all_identities, all_identities, {});
}
371 
372 // Testing implicit batch mode segmentation: it excludes the add-2 operation
373 // with a dynamic non-batch dimension.
TEST_F(SegmentTest,ExcludeAddWithDynamicNonBatchDimension)374 TEST_F(SegmentTest, ExcludeAddWithDynamicNonBatchDimension) {
375   Scope s = Scope::NewRootScope();
376   auto feed_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3}));
377   auto feed_1_shape = ops::Placeholder::Shape(PartialTensorShape({-1, -1, 3}));
378   auto const_val = ops::Const<float>(s, {1.0}, {});
379   auto feed_0 =
380       ops::Placeholder(s.WithOpName("feed-1"), DT_FLOAT, feed_0_shape);
381   auto feed_1 =
382       ops::Placeholder(s.WithOpName("feed-2"), DT_FLOAT, feed_1_shape);
383   auto add_0 = ops::Add(s.WithOpName("add-0"), feed_0, const_val);
384   auto add_1 = ops::Add(s.WithOpName("add-1"), add_0, feed_0);
385   auto add_2 = ops::Add(s.WithOpName("add-2"), const_val, feed_1);
386 
387   grappler::GrapplerItem item;
388   item.fetch.push_back("add-2");
389   TF_EXPECT_OK(s.ToGraphDef(&item.graph));
390 
391   grappler::GraphProperties static_graph_properties(item);
392   TF_EXPECT_OK(static_graph_properties.InferStatically(true));
393 
394   Graph g(OpRegistry::Global());
395   TF_CHECK_OK(
396       ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));
397 
398   const std::set<string> all_nodes = {"add-0", "add-1", "add-2"};
399   EnableImplicitBatchModeForStaticEngine();
400   RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
401           {all_nodes - "add-2"});
402 }
403 
404 // Testing implicit batch mode segmentation: It excludes the reshape operation
405 // with a dynamic non-batch output dimension.
406 // TODO(bixia): hoist the check for reshape should not change batch size from
407 // the converter to the segmenter and add another test case for excluding
408 // a reshape without dynamic dimensions involved.
TEST_F(SegmentTest,ExcludeReshapeWithDynamicNonBatchDimensionInOutput)409 TEST_F(SegmentTest, ExcludeReshapeWithDynamicNonBatchDimensionInOutput) {
410   Scope s = Scope::NewRootScope();
411   auto feed_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3}));
412   auto const_val = ops::Const<float>(s, {1.0}, {});
413   auto feed_0 =
414       ops::Placeholder(s.WithOpName("feed-1"), DT_FLOAT, feed_0_shape);
415   auto add_0 = ops::Add(s.WithOpName("add-0"), feed_0, const_val);
416   auto reshape = ops::Reshape(s.WithOpName("reshape"), add_0, Input({6, -1}));
417   auto add_1 = ops::Add(s.WithOpName("add-1"), reshape, const_val);
418 
419   grappler::GrapplerItem item;
420   item.fetch.push_back("add-1");
421   TF_EXPECT_OK(s.ToGraphDef(&item.graph));
422 
423   grappler::GraphProperties static_graph_properties(item);
424   TF_EXPECT_OK(static_graph_properties.InferStatically(true));
425 
426   Graph g(OpRegistry::Global());
427   TF_CHECK_OK(
428       ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));
429 
430   const std::set<string> all_nodes = {"add-0", "reshape", "add-1"};
431   EnableImplicitBatchModeForStaticEngine();
432   RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {});
433 }
434 
// Implicit batch mode with rank-1 inputs: no segment is expected.
TEST_F(SegmentTest, RankOneCannotUseImplicitBatch) {
  Scope s = Scope::NewRootScope();
  // Both placeholders share the same rank-1 shape [3].
  const auto rank1_attrs = ops::Placeholder::Shape(TensorShape({3}));
  auto input_0 =
      ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, rank1_attrs);
  auto input_1 =
      ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, rank1_attrs);
  auto const_val = ops::Const(s.WithOpName("const-scalar"), 1.0f, {});
  auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, const_val);
  auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, const_val);

  grappler::GrapplerItem item;
  item.fetch.push_back("output-0");
  item.fetch.push_back("output-1");
  TF_EXPECT_OK(s.ToGraphDef(&item.graph));

  grappler::GraphProperties static_graph_properties(item);
  TF_EXPECT_OK(static_graph_properties.InferStatically(true));

  Graph g(OpRegistry::Global());
  TF_CHECK_OK(
      ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));

  const std::set<string> all_nodes = {"const-scalar", "output-0", "output-1"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {});
}
463 
// Two chains that share a scalar constant but whose inputs have different
// static batch sizes (2 and 5). Only output-0 is expected to be clustered with
// the constant.
TEST_F(SegmentTest, TwoChainsDiffBatchSizes) {
  Scope s = Scope::NewRootScope();
  auto batch2_shape = ops::Placeholder::Shape(TensorShape({2, 3}));
  auto batch5_shape = ops::Placeholder::Shape(TensorShape({5, 3}));
  auto input_0 =
      ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, batch2_shape);
  auto input_1 =
      ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, batch5_shape);
  auto const_val = ops::Const(s.WithOpName("const-scalar"), 1.0f, {});
  auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, const_val);
  auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, const_val);

  grappler::GrapplerItem item;
  item.fetch.push_back("output-0");
  item.fetch.push_back("output-1");
  TF_EXPECT_OK(s.ToGraphDef(&item.graph));

  grappler::GraphProperties static_graph_properties(item);
  TF_EXPECT_OK(static_graph_properties.InferStatically(true));

  Graph g(OpRegistry::Global());
  TF_CHECK_OK(
      ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));

  const std::set<string> all_nodes = {"const-scalar", "output-0", "output-1"};
  const std::vector<std::set<string>> expected_segments = {
      {"output-0", "const-scalar"}};

  EnableImplicitBatchModeForStaticEngine();
  RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
          expected_segments);

  // Lowering the maximum batch size to 1 is expected to give the same
  // segmentation: the converter creates engines based on the static batch
  // size.
  EnableImplicitBatchModeForStaticEngine(1);
  RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
          expected_segments);
}
498 
// Same-rank (rank-3) broadcasting with static shapes: input-0 ([2, 3, 1]) and
// input-1 ([1, 3, 4]) are each added to `multiple` (input-2 squared,
// [2, 3, 4]). All three compute nodes are expected in one segment.
TEST_F(SegmentTest, SameRankImplicitBroadcastingStaticBatchSize) {
  Scope s = Scope::NewRootScope();
  auto shape_231 = ops::Placeholder::Shape(TensorShape({2, 3, 1}));
  auto shape_134 = ops::Placeholder::Shape(TensorShape({1, 3, 4}));
  auto shape_234 = ops::Placeholder::Shape(TensorShape({2, 3, 4}));
  auto input_0 = ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, shape_231);
  auto input_1 = ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, shape_134);
  auto input_2 = ops::Placeholder(s.WithOpName("input-2"), DT_FLOAT, shape_234);
  auto multiple = ops::Mul(s.WithOpName("multiple"), input_2, input_2);
  auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, multiple);
  auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, multiple);

  grappler::GrapplerItem item;
  item.fetch.push_back("output-0");
  item.fetch.push_back("output-1");
  TF_EXPECT_OK(s.ToGraphDef(&item.graph));

  grappler::GraphProperties static_graph_properties(item);
  TF_EXPECT_OK(static_graph_properties.InferStatically(true));

  Graph g(OpRegistry::Global());
  TF_CHECK_OK(
      ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));

  const std::set<string> all_nodes = {"multiple", "output-0", "output-1"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
          /*expected_segments=*/{all_nodes});
}
531 
// Same-rank broadcasting with a dynamic batch dimension: input-0 ([?, 2]) is
// added to a [1, 1] constant. input-1 is part of the graph but is not consumed
// by any candidate node. All three candidates are expected in one segment.
TEST_F(SegmentTest, SameRankImplicitBroadcastingDynamicBatchSize) {
  Scope s = Scope::NewRootScope();
  auto dynamic_batch_shape =
      ops::Placeholder::Shape(PartialTensorShape({-1, 2}));
  auto static_shape = ops::Placeholder::Shape(TensorShape({1, 2}));
  auto input_0 =
      ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, dynamic_batch_shape);
  auto input_1 =
      ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, static_shape);
  auto const_val = ops::Const(s.WithOpName("const-val"), 1.0f, {1, 1});
  auto add_0 = ops::Add(s.WithOpName("add-0"), input_0, const_val);
  auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, add_0);

  grappler::GrapplerItem item;
  item.fetch.push_back("output-0");
  TF_EXPECT_OK(s.ToGraphDef(&item.graph));

  grappler::GraphProperties static_graph_properties(item);
  TF_EXPECT_OK(static_graph_properties.InferStatically(true));

  Graph g(OpRegistry::Global());
  TF_CHECK_OK(
      ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));

  const std::set<string> all_nodes = {"const-val", "add-0", "output-0"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes,
          /*expected_segments=*/{all_nodes});
}
560 
// A [2, 2] constant is added to input-0, whose batch dimension is dynamic
// ([?, 2]). In implicit batch mode no segment is expected.
TEST_F(SegmentTest, IncompatibleBatchSizes) {
  Scope s = Scope::NewRootScope();
  auto dynamic_batch_shape =
      ops::Placeholder::Shape(PartialTensorShape({-1, 2}));
  auto static_shape = ops::Placeholder::Shape(TensorShape({2, 2}));
  auto input_0 =
      ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, dynamic_batch_shape);
  auto input_1 =
      ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, static_shape);
  auto const_val = ops::Const(s.WithOpName("const-val"), 1.0f, {2, 2});
  auto add_0 = ops::Add(s.WithOpName("add-0"), input_0, const_val);
  auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, add_0);

  grappler::GrapplerItem item;
  item.fetch.push_back("output-0");
  TF_EXPECT_OK(s.ToGraphDef(&item.graph));

  grappler::GraphProperties static_graph_properties(item);
  TF_EXPECT_OK(static_graph_properties.InferStatically(true));

  Graph g(OpRegistry::Global());
  TF_CHECK_OK(
      ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g));

  const std::set<string> all_nodes = {"const-val", "add-0", "output-0"};
  EnableImplicitBatchModeForStaticEngine();
  RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {});
}
588 }  // namespace test
589 }  // namespace segment
590 }  // namespace tensorrt
591 }  // namespace tensorflow
592 
593 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
594