• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17 
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/graph/testlib.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/types.h"
27 #include "tensorflow/core/public/session.h"
28 
29 #if GOOGLE_CUDA
30 #if GOOGLE_TENSORRT
31 
32 namespace tensorflow {
33 namespace tensorrt {
34 namespace segment {
35 namespace test {
36 
37 class SegmentTest : public ::testing::Test {
38  protected:
MakeCandidateFn(const std::set<string> & node_names)39   std::function<Status(const Node*)> MakeCandidateFn(
40       const std::set<string>& node_names) {
41     return [node_names](const Node* node) -> Status {
42       if (node_names.find(node->name()) != node_names.end()) {
43         return Status::OK();
44       }
45       return errors::NotFound("");
46     };
47   }
48 
MakeInputEdgeCandidateFn(const std::set<string> & node_names)49   std::function<bool(const Edge*)> MakeInputEdgeCandidateFn(
50       const std::set<string>& node_names) {
51     return [node_names](const Edge* in_edge) -> bool {
52       return node_names.find(in_edge->dst()->name()) != node_names.end();
53     };
54   }
55 
MakeOutputEdgeCandidateFn(const std::set<string> & node_names)56   std::function<bool(const Edge*)> MakeOutputEdgeCandidateFn(
57       const std::set<string>& node_names) {
58     return [node_names](const Edge* out_edge) -> bool {
59       return node_names.find(out_edge->src()->name()) != node_names.end();
60     };
61   }
62 
RunTest(const Graph * graph,const std::set<string> & candidates,const std::set<string> & input_candidates,const std::set<string> & output_candidates,const std::vector<std::set<string>> & expected_segments)63   void RunTest(const Graph* graph, const std::set<string>& candidates,
64                const std::set<string>& input_candidates,
65                const std::set<string>& output_candidates,
66                const std::vector<std::set<string>>& expected_segments) {
67     SegmentNodesVector segments;
68     TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates),
69                               MakeInputEdgeCandidateFn(input_candidates),
70                               MakeOutputEdgeCandidateFn(output_candidates),
71                               default_options_, &segments));
72     ValidateSegment(segments, expected_segments);
73   }
74 
ValidateSegment(const SegmentNodesVector & segments,const std::vector<std::set<string>> & expected_segments)75   void ValidateSegment(const SegmentNodesVector& segments,
76                        const std::vector<std::set<string>>& expected_segments) {
77     EXPECT_EQ(expected_segments.size(), segments.size());
78     for (int i = 0; i < segments.size(); ++i) {
79       std::set<string> segment_node_names;
80       for (const Node* node : segments[i]) {
81         segment_node_names.insert(node->name());
82       }
83       const auto& expected = expected_segments[i];
84       for (const auto& name : expected) {
85         EXPECT_TRUE(segment_node_names.count(name))
86             << "Segment " << i << " is missing expected node: " << name;
87       }
88       if (segment_node_names.size() == expected.size()) continue;
89       for (const auto& name : segment_node_names) {
90         EXPECT_TRUE(expected.count(name))
91             << "Unexpected node found in segment " << i << ": " << name;
92       }
93     }
94   }
95 
96   SegmentOptions default_options_;
97 };
98 
operator -(const std::set<string> & lhs,const string & rhs)99 std::set<string> operator-(const std::set<string>& lhs, const string& rhs) {
100   std::set<string> result = lhs;
101   CHECK(result.erase(rhs));
102   return result;
103 }
104 
TEST_F(SegmentTest, Empty) {
  Scope root = Scope::NewRootScope();
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  // No ops were added and nothing is a candidate (node, input edge, or output
  // edge), so SegmentGraph() must not produce any segment.
  RunTest(&graph, /*candidates=*/{}, /*input_candidates=*/{},
          /*output_candidates=*/{}, /*expected_segments=*/{});
}
112 
TEST_F(SegmentTest, Simple) {
  // Graph under test. A doubled edge ("//", "\\") means the producer feeds
  // both inputs of the consumer. The diagram is a block comment because a
  // "//" comment line ending in a backslash is spliced with the next line.
  /*
            feed
           //  \\
        add0    add1
         | \    /
         |  add2
         | /   \\
        add3    add4
           \    /
           <sink>
  */
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  // All Add operations are candidates, and we expect all of them to be
  // collapsed into a single segment.
  const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4"};
  RunTest(&g, all_adds, all_adds, all_adds, {all_adds});

  // Make add1 not a candidate, and we expect all other Add operations to be
  // collapsed into a single segment.
  auto without_add1 = all_adds - "add1";
  RunTest(&g, without_add1, without_add1, without_add1, {without_add1});

  // Make add1 not a candidate and add2 not an input candidate. The edge
  // add1 -> add2 then enters the segment at add2, so we expect add2 and its
  // input add0 to be removed from the segment.
  auto without_add2 = all_adds - "add2";
  RunTest(&g, without_add1, without_add2, without_add1, {{"add3", "add4"}});

  // Making add2 not an input candidate by itself has no effect.
  RunTest(&g, all_adds, without_add2, all_adds, {all_adds});

  // Making add1 not an input candidate removes add1 from the segment.
  RunTest(&g, all_adds, without_add1, all_adds, {without_add1});

  // Making add3 not an output candidate has no effect, since its output goes
  // to the sink.
  auto without_add3 = all_adds - "add3";
  RunTest(&g, all_adds, all_adds, without_add3, {all_adds});
}
159 
TEST_F(SegmentTest, AvoidCycle) {
  // Graph under test. A doubled edge ("//", "\\") means the producer feeds
  // both inputs of the consumer. The diagram is a block comment because a
  // "//" comment line ending in a backslash is spliced with the next line.
  /*
            feed
           //  \\
        add0    add1
         | \    /
         |  add2
         |  /  \\
        add3    add4
           \    /
           <sink>
  */
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  // add2 is not a TRT candidate. add0 and add3 are connected both directly
  // and through add2 (add0 -> add2 -> add3), so grouping them without add2
  // would create a cycle. No segments should be generated.
  const std::set<string> without_add2 = {"add0", "add1", "add3", "add4"};
  RunTest(&g, without_add2, without_add2, without_add2, {});
}
184 
TEST_F(SegmentTest, Multiple) {
  // Graph under test. A doubled edge ("//", "||", "\\") means the producer
  // feeds both inputs of the consumer. The diagram is a block comment because
  // a "//" comment line ending in a backslash is spliced with the next line.
  /*
               feed
            //  ||  \\
         add0  add1  add7
         |  \  /     / \\
         |  add2    /   \\
         |   || \   |   ||
         |   ||  add5  add8
         |  /  \ /  \   /
         add3  add4  add6
            \   |   /
              <sink>
  */
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
  auto add7 = ops::Add(s.WithOpName("add7"), feed, feed);
  auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
  auto add5 = ops::Add(s.WithOpName("add5"), add2, add7);
  auto add8 = ops::Add(s.WithOpName("add8"), add7, add7);
  auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add2, add5);
  auto add6 = ops::Add(s.WithOpName("add6"), add5, add8);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4",
                                     "add5", "add6", "add7", "add8"};
  // Make add5 not a TRT candidate, and we expect two segments.
  auto without_add5 = all_adds - "add5";
  RunTest(&g, without_add5, without_add5, without_add5,
          {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});

  // Make add8 not a candidate and add6 not an input candidate; then all direct
  // and indirect inputs of add6 are removed from the segment.
  auto without_add8 = all_adds - "add8";
  auto without_add6 = all_adds - "add6";
  RunTest(&g, without_add8, without_add6, all_adds, {{"add3", "add4"}});

  // Make add3 not a candidate and add0 not an output candidate; then all
  // direct and indirect outputs of add0 are removed from the segment.
  auto without_add3 = all_adds - "add3";
  auto without_add0 = all_adds - "add0";
  RunTest(&g, without_add3, all_adds, without_add0, {{"add1", "add7", "add8"}});
}
230 
TEST_F(SegmentTest, BigIfElse) {
  // Graph under test. A doubled edge ("//", "||", "\\") means the producer
  // feeds both inputs of the consumer. add3 and add6 each feed one input of
  // add7. The diagram is a block comment because a "//" comment line ending
  // in a backslash is spliced with the next line.
  /*
             feed
              ||
             add0
           //    \\
         add1    add4
          ||      ||
         add2    add5
          ||      ||
         add3    add6
            \    /
             add7
              |
            <sink>
  */
  Scope s = Scope::NewRootScope();
  auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(s.WithOpName("add1"), add0, add0);
  auto add2 = ops::Add(s.WithOpName("add2"), add1, add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add2, add2);
  auto add4 = ops::Add(s.WithOpName("add4"), add0, add0);
  auto add5 = ops::Add(s.WithOpName("add5"), add4, add4);
  auto add6 = ops::Add(s.WithOpName("add6"), add5, add5);
  auto add7 = ops::Add(s.WithOpName("add7"), add3, add6);
  Graph g(OpRegistry::Global());
  TF_EXPECT_OK(s.ToGraph(&g));

  const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
                                     "add4", "add5", "add6", "add7"};
  // Make add2 not a TRT candidate, and we expect 2 segments. add0 cannot join
  // the add3..add7 segment: the path add0 -> add1 -> add2 -> add3 would leave
  // and re-enter it through the non-candidate add2.
  RunTest(&g, all_adds - "add2", all_adds, all_adds,
          {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
}
264 
TEST_F(SegmentTest, IdentityOps) {
  // Builds the linear chain: feed -> identity0 -> identity1 -> identity2
  // -> identity3.
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto identity0 = ops::Identity(root.WithOpName("identity0"), feed);
  auto identity1 = ops::Identity(root.WithOpName("identity1"), identity0);
  auto identity2 = ops::Identity(root.WithOpName("identity2"), identity1);
  auto identity3 = ops::Identity(root.WithOpName("identity3"), identity2);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  // Every Identity node is a node, input-edge, and output-edge candidate, yet
  // no segment is expected: Identity ops are not counted as effective ops in
  // a segment.
  const std::set<string> identities = {"identity0", "identity1", "identity2",
                                       "identity3"};
  RunTest(&graph, identities, identities, identities,
          /*expected_segments=*/{});
}
281 
282 }  // namespace test
283 }  // namespace segment
284 }  // namespace tensorrt
285 }  // namespace tensorflow
286 
287 #endif  // GOOGLE_TENSORRT
288 #endif  // GOOGLE_CUDA
289