1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/algorithm.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
23 #include "tensorflow/core/graph/benchmark_testlib.h"
24 #include "tensorflow/core/graph/graph.h"
25 #include "tensorflow/core/graph/graph_def_builder.h"
26 #include "tensorflow/core/graph/subgraph.h"
27 #include "tensorflow/core/kernels/ops_util.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32
33 // TODO(josh11b): Test setting the "device" field of a NodeDef.
34 // TODO(josh11b): Test that feeding won't prune targets.
35
36 namespace tensorflow {
37 namespace {
38
39 REGISTER_OP("TestParams").Output("o: float");
40 REGISTER_OP("TestInput").Output("a: float").Output("b: float");
41 REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
42 REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
43 REGISTER_OP("TestBinary")
44 .Input("a: float")
45 .Input("b: float")
46 .Output("o: float");
47
48 // Compares that the order of nodes in 'inputs' respects the
49 // pair orders described in 'ordered_pairs'.
ExpectBefore(const std::vector<std::pair<string,string>> & ordered_pairs,const std::vector<Node * > & inputs,string * error)50 bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs,
51 const std::vector<Node*>& inputs, string* error) {
52 for (const std::pair<string, string>& pair : ordered_pairs) {
53 const string& before_node = pair.first;
54 const string& after_node = pair.second;
55 bool seen_before = false;
56 bool seen_both = false;
57 for (const Node* node : inputs) {
58 if (!seen_before && after_node == node->name()) {
59 *error = strings::StrCat("Saw ", after_node, " before ", before_node);
60 return false;
61 }
62
63 if (before_node == node->name()) {
64 seen_before = true;
65 } else if (after_node == node->name()) {
66 seen_both = seen_before;
67 break;
68 }
69 }
70 if (!seen_both) {
71 *error = strings::StrCat("didn't see either ", before_node, " or ",
72 after_node);
73 return false;
74 }
75 }
76
77 return true;
78 }
79
// Builds a small graph mixing data and control edges, then verifies that
// GetReversePostOrder and GetPostOrder both respect every dependency, and that
// ExpectBefore rejects a pair listed in the wrong direction.
TEST(AlgorithmTest, ReversePostOrder) {
  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
  // Graph under test:
  //   W1 --control--> input
  //   t1 = TestMul(W1, input:1)
  //   t2 = TestMul(W1, input:1), with a control edge from t1
  //   t3 = TestMul(W2, input:1)
  Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
  Node* w2 = SourceOp("TestParams", b.opts().WithName("W2"));
  Node* input =
      SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
  Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1"));
  BinaryOp("TestMul", w1, {input, 1},
           b.opts().WithName("t2").WithControlInput(t1));
  BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3"));

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));

  std::vector<Node*> order;
  string error;

  // Reverse post order: each node precedes the nodes that depend on it.
  GetReversePostOrder(g, &order);
  const std::vector<std::pair<string, string>> dependency_first = {
      {"W1", "input"}, {"W1", "t1"},    {"W1", "t2"}, {"W1", "t3"},
      {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}};
  EXPECT_TRUE(ExpectBefore(dependency_first, order, &error)) << error;
  // Flipping one dependency must make the check fail.
  EXPECT_FALSE(ExpectBefore({{"input", "W1"}}, order, &error));

  // Post order: each node follows the nodes that depend on it.
  GetPostOrder(g, &order);
  const std::vector<std::pair<string, string>> dependent_first = {
      {"input", "W1"}, {"t1", "W1"},    {"t2", "W1"}, {"t3", "W1"},
      {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}};
  EXPECT_TRUE(ExpectBefore(dependent_first, order, &error)) << error;
  // Flipping one dependency must make the check fail.
  EXPECT_FALSE(ExpectBefore({{"W1", "t3"}}, order, &error));
}
123
// Checks that GetReversePostOrder with a name-based stable comparator keeps
// sibling nodes t2 and t3 in name order, no matter how many other nodes are
// created between them.
TEST(AlgorithmTest, ReversePostOrderStable) {
  int64 run_count = 100;
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  for (int64 i = 0; i < run_count; ++i) {
    // One source of nondeterminism is an unordered set keyed by a pointer
    // type: for example, the iteration order of FlatSet<Node*> depends on the
    // raw pointer values of the nodes. A stable post order is supposed to
    // remove this nondeterminism by imposing an explicit ordering, here by
    // node name (NodeComparatorName).
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    string error;
    Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
    Node* input =
        SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t2"));
    // Insert a different number of nodes between the creation of t2 and t3
    // on each iteration. This varies the memory distance between t2 and t3
    // enough that their relative order would be randomized if stable DFS
    // were not implemented correctly.
    for (int64 j = 0; j < i; ++j) {
      BinaryOp("TestMul", w1, {input, 1},
               b.opts().WithName(strings::StrCat("internal", j)));
    }

    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3"));

    Graph g(OpRegistry::Global());
    TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
    std::vector<Node*> order;

    // The reverse post order must place t2 before t3.
    GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorName());
    // Stream the diagnostic (and iteration) so a failure is actionable.
    EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error))
        << "iteration " << i << ": " << error;
  }
}
159
// Checks that an edge filter hides edges from the traversal. A back edge
// n3 -> n1 is added, which would close the cycle n1 -> n2 -> n3 -> n1. The
// filter drops that edge, so both orders are computed as if it did not exist.
TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
  Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
  Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
  Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
  Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));

  // n0..n3 were produced by the builder, so their counterparts in 'g' are
  // looked up by id.
  g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);

  std::vector<Node*> post_order;
  // Accept every edge except the back edge n3 -> n1 added above.
  auto edge_filter = [&](const Edge& e) {
    return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
  };

  const std::vector<Node*> expected_post_order = {
      g.sink_node(),          g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
      g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};
  // Build the reversed expectation directly from reverse iterators.
  const std::vector<Node*> expected_reverse_post_order(
      expected_post_order.rbegin(), expected_post_order.rend());

  GetPostOrder(g, &post_order, /*stable_comparator=*/{},
               /*edge_filter=*/edge_filter);

  // Use gtest assertions (not CHECK_EQ) so a mismatch is reported as a test
  // failure instead of aborting the whole test binary.
  ASSERT_EQ(expected_post_order.size(), post_order.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    EXPECT_EQ(post_order[i], expected_post_order[i])
        << post_order[i]->name() << " vs. " << expected_post_order[i]->name();
  }

  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
                      /*edge_filter=*/edge_filter);

  ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
  for (size_t i = 0; i < reverse_post_order.size(); ++i) {
    EXPECT_EQ(reverse_post_order[i], expected_reverse_post_order[i])
        << reverse_post_order[i]->name() << " vs. "
        << expected_reverse_post_order[i]->name();
  }
}
205
BM_PruneForReverseReachability(::testing::benchmark::State & state)206 void BM_PruneForReverseReachability(::testing::benchmark::State& state) {
207 const int num_nodes = state.range(0);
208 const int num_edges_per_node = state.range(1);
209 const GraphDef graph_def =
210 test::CreateGraphDef(num_nodes, num_edges_per_node);
211 const auto registry = OpRegistry::Global();
212 GraphConstructorOptions opts;
213 for (auto s : state) {
214 state.PauseTiming();
215 Graph graph(registry);
216 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
217 std::unordered_set<const Node*> visited;
218 visited.insert(graph.FindNodeId(graph.num_nodes() - 1));
219 state.ResumeTiming();
220 PruneForReverseReachability(&graph, std::move(visited));
221 }
222 }
223 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 2);
224 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 2);
225 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 2);
226 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 2);
227 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 2);
228 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 4);
229 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 4);
230 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 4);
231 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 4);
232 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 4);
233 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 8);
234 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 8);
235 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 8);
236 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 8);
237 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 8);
238 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 16);
239 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 16);
240 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 16);
241 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 16);
242 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 16);
243
244 } // namespace
245 } // namespace tensorflow
246