1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/graph/testlib.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/types.h"
27 #include "tensorflow/core/public/session.h"
28
29 #if GOOGLE_CUDA
30 #if GOOGLE_TENSORRT
31
32 namespace tensorflow {
33 namespace tensorrt {
34 namespace segment {
35 namespace test {
36
37 class SegmentTest : public ::testing::Test {
38 protected:
MakeCandidateFn(const std::set<string> & node_names)39 std::function<Status(const Node*)> MakeCandidateFn(
40 const std::set<string>& node_names) {
41 return [node_names](const Node* node) -> Status {
42 if (node_names.find(node->name()) != node_names.end()) {
43 return Status::OK();
44 }
45 return errors::NotFound("");
46 };
47 }
48
MakeInputEdgeCandidateFn(const std::set<string> & node_names)49 std::function<bool(const Edge*)> MakeInputEdgeCandidateFn(
50 const std::set<string>& node_names) {
51 return [node_names](const Edge* in_edge) -> bool {
52 return node_names.find(in_edge->dst()->name()) != node_names.end();
53 };
54 }
55
MakeOutputEdgeCandidateFn(const std::set<string> & node_names)56 std::function<bool(const Edge*)> MakeOutputEdgeCandidateFn(
57 const std::set<string>& node_names) {
58 return [node_names](const Edge* out_edge) -> bool {
59 return node_names.find(out_edge->src()->name()) != node_names.end();
60 };
61 }
62
RunTest(const Graph * graph,const std::set<string> & candidates,const std::set<string> & input_candidates,const std::set<string> & output_candidates,const std::vector<std::set<string>> & expected_segments)63 void RunTest(const Graph* graph, const std::set<string>& candidates,
64 const std::set<string>& input_candidates,
65 const std::set<string>& output_candidates,
66 const std::vector<std::set<string>>& expected_segments) {
67 SegmentNodesVector segments;
68 TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates),
69 MakeInputEdgeCandidateFn(input_candidates),
70 MakeOutputEdgeCandidateFn(output_candidates),
71 default_options_, &segments));
72 ValidateSegment(segments, expected_segments);
73 }
74
ValidateSegment(const SegmentNodesVector & segments,const std::vector<std::set<string>> & expected_segments)75 void ValidateSegment(const SegmentNodesVector& segments,
76 const std::vector<std::set<string>>& expected_segments) {
77 EXPECT_EQ(expected_segments.size(), segments.size());
78 for (int i = 0; i < segments.size(); ++i) {
79 std::set<string> segment_node_names;
80 for (const Node* node : segments[i]) {
81 segment_node_names.insert(node->name());
82 }
83 const auto& expected = expected_segments[i];
84 for (const auto& name : expected) {
85 EXPECT_TRUE(segment_node_names.count(name))
86 << "Segment " << i << " is missing expected node: " << name;
87 }
88 if (segment_node_names.size() == expected.size()) continue;
89 for (const auto& name : segment_node_names) {
90 EXPECT_TRUE(expected.count(name))
91 << "Unexpected node found in segment " << i << ": " << name;
92 }
93 }
94 }
95
96 SegmentOptions default_options_;
97 };
98
operator -(const std::set<string> & lhs,const string & rhs)99 std::set<string> operator-(const std::set<string>& lhs, const string& rhs) {
100 std::set<string> result = lhs;
101 CHECK(result.erase(rhs));
102 return result;
103 }
104
TEST_F(SegmentTest, Empty) {
  // The scope defines no ops, so nothing is offered as a candidate and the
  // segmenter must not produce any segment.
  Scope root = Scope::NewRootScope();
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  const std::set<string> no_nodes;
  RunTest(&graph, no_nodes, no_nodes, no_nodes, /*expected_segments=*/{});
}
112
TEST_F(SegmentTest, Simple) {
  // Graph under test:
  //           feed
  //          //  \\
  //        add0    add1
  //        |  \    /
  //        |   add2
  //        |   /  \\
  //       add3     add4
  //          \     /
  //          <sink>
  Scope root = Scope::NewRootScope();
  // Ops are created in a fixed order; the names given via WithOpName are the
  // ones referenced by the candidate sets below.
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(root.WithOpName("add2"), add0, add1);
  auto add3 = ops::Add(root.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(root.WithOpName("add4"), add2, add2);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4"};
  const std::set<string> no_add1 = all_adds - "add1";
  const std::set<string> no_add2 = all_adds - "add2";
  const std::set<string> no_add3 = all_adds - "add3";

  // Every Add is accepted as a node, as an input-edge target and as an
  // output-edge source: all five Adds collapse into one segment.
  RunTest(&graph, all_adds, all_adds, all_adds, {all_adds});

  // With add1 excluded, the remaining four Adds form a single segment.
  RunTest(&graph, no_add1, no_add1, no_add1, {no_add1});

  // With add1 excluded and the input edges of add2 rejected, add0 and add2
  // drop out as well, leaving only add3 and add4.
  RunTest(&graph, no_add1, no_add2, no_add1, {{"add3", "add4"}});

  // Rejecting the input edges of add2 on its own changes nothing; both of
  // add2's producers (add0 and add1) are part of the segment.
  RunTest(&graph, all_adds, no_add2, all_adds, {all_adds});

  // Rejecting the input edges of add1 (which come from feed) removes add1.
  RunTest(&graph, all_adds, no_add1, all_adds, {no_add1});

  // Rejecting the output edges of add3 changes nothing, because add3's only
  // consumer is the sink.
  RunTest(&graph, all_adds, all_adds, no_add3, {all_adds});
}
159
TEST_F(SegmentTest, AvoidCycle) {
  // Graph under test:
  //           feed
  //          //  \\
  //        add0    add1
  //        |  \    /
  //        |   add2
  //        |   /  \\
  //       add3     add4
  //          \     /
  //          <sink>
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), feed, feed);
  auto add2 = ops::Add(root.WithOpName("add2"), add0, add1);
  auto add3 = ops::Add(root.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(root.WithOpName("add4"), add2, add2);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  // add2 is not a TRT candidate. add0 feeds add3 directly, but fusing them
  // would form a cycle through add2 (add0 -> add2 -> add3), so no segment is
  // expected.
  const std::set<string> no_add2 = {"add0", "add1", "add3", "add4"};
  RunTest(&graph, no_add2, no_add2, no_add2, /*expected_segments=*/{});
}
184
TEST_F(SegmentTest, Multiple) {
  // Graph under test:
  //                  feed
  //               //  ||  \\
  //           add0  add1  add7
  //           |  \  /      / \\
  //           |  add2     /   \\
  //           |  || \    |    ||
  //           |  ||  add5    add8
  //           |  /  \ /  \    /
  //          add3  add4   add6
  //             \   |   /
  //              <sink>
  Scope root = Scope::NewRootScope();
  // Creation order is fixed; names match the candidate sets below.
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), feed, feed);
  auto add7 = ops::Add(root.WithOpName("add7"), feed, feed);
  auto add2 = ops::Add(root.WithOpName("add2"), add0, add1);
  auto add5 = ops::Add(root.WithOpName("add5"), add2, add7);
  auto add8 = ops::Add(root.WithOpName("add8"), add7, add7);
  auto add3 = ops::Add(root.WithOpName("add3"), add0, add2);
  auto add4 = ops::Add(root.WithOpName("add4"), add2, add5);
  auto add6 = ops::Add(root.WithOpName("add6"), add5, add8);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4",
                                     "add5", "add6", "add7", "add8"};
  const std::set<string> no_add0 = all_adds - "add0";
  const std::set<string> no_add3 = all_adds - "add3";
  const std::set<string> no_add5 = all_adds - "add5";
  const std::set<string> no_add6 = all_adds - "add6";
  const std::set<string> no_add8 = all_adds - "add8";

  // Excluding add5 splits the graph into two segments.
  RunTest(&graph, no_add5, no_add5, no_add5,
          {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});

  // Excluding add8 and rejecting the input edges of add6 removes add6 together
  // with all of its direct and transitive producers from the segment.
  RunTest(&graph, no_add8, no_add6, all_adds, {{"add3", "add4"}});

  // Excluding add3 and rejecting the output edges of add0 removes add0
  // together with all of its direct and transitive consumers.
  RunTest(&graph, no_add3, all_adds, no_add0, {{"add1", "add7", "add8"}});
}
230
TEST_F(SegmentTest, BigIfElse) {
  // Graph under test: two three-op branches that fork at add0 and join at
  // add7.
  //           feed
  //            ||
  //           add0
  //          //  \\
  //       add1    add4
  //        ||      ||
  //       add2    add5
  //        ||      ||
  //       add3    add6
  //          \\  //
  //           add7
  //            ||
  //          <sink>
  Scope root = Scope::NewRootScope();
  auto feed = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  auto add0 = ops::Add(root.WithOpName("add0"), feed, feed);
  auto add1 = ops::Add(root.WithOpName("add1"), add0, add0);
  auto add2 = ops::Add(root.WithOpName("add2"), add1, add1);
  auto add3 = ops::Add(root.WithOpName("add3"), add2, add2);
  auto add4 = ops::Add(root.WithOpName("add4"), add0, add0);
  auto add5 = ops::Add(root.WithOpName("add5"), add4, add4);
  auto add6 = ops::Add(root.WithOpName("add6"), add5, add5);
  auto add7 = ops::Add(root.WithOpName("add7"), add3, add6);
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
                                     "add4", "add5", "add6", "add7"};
  const std::set<string> no_add2 = all_adds - "add2";

  // With add2 excluded from the candidates, two segments are expected.
  const std::vector<std::set<string>> expected = {
      {"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}};
  RunTest(&graph, no_add2, all_adds, all_adds, expected);
}
264
TEST_F(SegmentTest, IdentityOps) {
  // A linear chain: feed -> identity0 -> identity1 -> identity2 -> identity3.
  Scope root = Scope::NewRootScope();
  Output previous = ops::Placeholder(root.WithOpName("feed"), DT_FLOAT);
  for (int i = 0; i < 4; ++i) {
    previous = ops::Identity(root.WithOpName("identity" + std::to_string(i)),
                             previous);
  }
  Graph graph(OpRegistry::Global());
  TF_EXPECT_OK(root.ToGraph(&graph));

  const std::set<string> identities = {"identity0", "identity1", "identity2",
                                       "identity3"};
  // Identity ops do not count as effective ops in a segment, so even with all
  // of them accepted everywhere no segment is formed.
  RunTest(&graph, identities, identities, identities,
          /*expected_segments=*/{});
}
281
282 } // namespace test
283 } // namespace segment
284 } // namespace tensorrt
285 } // namespace tensorflow
286
287 #endif // GOOGLE_TENSORRT
288 #endif // GOOGLE_CUDA
289