• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/algorithm.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
23 #include "tensorflow/core/graph/benchmark_testlib.h"
24 #include "tensorflow/core/graph/graph.h"
25 #include "tensorflow/core/graph/graph_def_builder.h"
26 #include "tensorflow/core/graph/subgraph.h"
27 #include "tensorflow/core/kernels/ops_util.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32 
33 // TODO(josh11b): Test setting the "device" field of a NodeDef.
34 // TODO(josh11b): Test that feeding won't prune targets.
35 
36 namespace tensorflow {
37 namespace {
38 
39 REGISTER_OP("TestParams").Output("o: float");
40 REGISTER_OP("TestInput").Output("a: float").Output("b: float");
41 REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
42 REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
43 REGISTER_OP("TestBinary")
44     .Input("a: float")
45     .Input("b: float")
46     .Output("o: float");
47 
48 // Compares that the order of nodes in 'inputs' respects the
49 // pair orders described in 'ordered_pairs'.
ExpectBefore(const std::vector<std::pair<string,string>> & ordered_pairs,const std::vector<Node * > & inputs,string * error)50 bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs,
51                   const std::vector<Node*>& inputs, string* error) {
52   for (const std::pair<string, string>& pair : ordered_pairs) {
53     const string& before_node = pair.first;
54     const string& after_node = pair.second;
55     bool seen_before = false;
56     bool seen_both = false;
57     for (const Node* node : inputs) {
58       if (!seen_before && after_node == node->name()) {
59         *error = strings::StrCat("Saw ", after_node, " before ", before_node);
60         return false;
61       }
62 
63       if (before_node == node->name()) {
64         seen_before = true;
65       } else if (after_node == node->name()) {
66         seen_both = seen_before;
67         break;
68       }
69     }
70     if (!seen_both) {
71       *error = strings::StrCat("didn't see either ", before_node, " or ",
72                                after_node);
73       return false;
74     }
75   }
76 
77   return true;
78 }
79 
// Builds a small DAG with data and control edges, then checks two things.
// GetReversePostOrder must place every node after its predecessors.
// GetPostOrder must satisfy the mirrored constraints.
// Each half also confirms that ExpectBefore rejects an ordering the traversal
// must not produce.
TEST(AlgorithmTest, ReversePostOrder) {
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
  // Nodes are created in a fixed order: W1, W2, input, t1, t2, t3.
  Node* w1 = ops::SourceOp("TestParams", builder.opts().WithName("W1"));
  Node* w2 = ops::SourceOp("TestParams", builder.opts().WithName("W2"));
  Node* input = ops::SourceOp(
      "TestInput", builder.opts().WithName("input").WithControlInput(w1));
  // Every TestMul reads output 1 of "input".
  Node* t1 =
      ops::BinaryOp("TestMul", w1, {input, 1}, builder.opts().WithName("t1"));
  ops::BinaryOp("TestMul", w1, {input, 1},
                builder.opts().WithName("t2").WithControlInput(t1));
  ops::BinaryOp("TestMul", w2, {input, 1}, builder.opts().WithName("t3"));

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(builder, &graph));

  std::vector<Node*> order;
  string error;

  // Reverse post order: in each pair the first node must precede the second.
  GetReversePostOrder(graph, &order);
  const std::vector<std::pair<string, string>> expected_reverse_post_order = {
      {"W1", "input"}, {"W1", "t1"},    {"W1", "t2"}, {"W1", "t3"},
      {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}};
  EXPECT_TRUE(ExpectBefore(expected_reverse_post_order, order, &error))
      << error;
  // A constraint the traversal must violate is rejected.
  const std::vector<std::pair<string, string>> impossible_reverse_post_order =
      {{"input", "W1"}};
  EXPECT_FALSE(ExpectBefore(impossible_reverse_post_order, order, &error));

  // Post order: every constraint above, with the pair members swapped.
  GetPostOrder(graph, &order);
  const std::vector<std::pair<string, string>> expected_post_order = {
      {"input", "W1"}, {"t1", "W1"},    {"t2", "W1"}, {"t3", "W1"},
      {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}};
  EXPECT_TRUE(ExpectBefore(expected_post_order, order, &error)) << error;
  // A constraint the traversal must violate is rejected.
  const std::vector<std::pair<string, string>> impossible_post_order = {
      {"W1", "t3"}};
  EXPECT_FALSE(ExpectBefore(impossible_post_order, order, &error));
}
123 
// Checks that GetReversePostOrder, given the NodeComparatorName stable
// comparator, always orders the sibling nodes "t2" before "t3". The number of
// other nodes created between t2 and t3 varies from one iteration to the next.
TEST(AlgorithmTest, ReversePostOrderStable) {
  const int64 run_count = 100;
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  for (int64 i = 0; i < run_count; ++i) {
    // One source of nondeterminism is an unordered set keyed by pointers. For
    // example, the iteration order of a FlatSet<Node*> depends on the raw
    // pointer values of the nodes. A stable post order is supposed to remove
    // this nondeterminism by enforcing a deterministic ordering, here via the
    // comparator passed to GetReversePostOrder below.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    string error;
    Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
    Node* input =
        SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t2"));
    // Insert a different number of nodes between the creation of t2 and t3 on
    // each iteration. This varies the memory distance between t2 and t3. If
    // stable DFS were not implemented correctly, their relative order would
    // then become randomized.
    for (int64 j = 0; j < i; ++j) {
      BinaryOp("TestMul", w1, {input, 1},
               b.opts().WithName(strings::StrCat("internal", j)));
    }

    BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3"));

    Graph g(OpRegistry::Global());
    TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
    std::vector<Node*> order;

    // Test reverse post order generates expected ordering.
    GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorName());
    // Surface ExpectBefore's diagnostic and the failing iteration.
    EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error))
        << error << " (iteration " << i << ")";
  }
}
159 
// Checks that GetPostOrder and GetReversePostOrder honor an edge filter.
// The test builds the chain n0 -> n1 -> n2 -> n3, plus an edge n0 -> n3.
// It then adds an edge n3 -> n1, which closes a cycle.
// The filter hides exactly that edge, so both traversals see the acyclic
// graph and produce the expected fixed order.
TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
  Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
  Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
  Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
  Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));

  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));

  // n0..n3 were created by the builder, so their counterparts in `g` are
  // looked up by id. This relies on GraphDefBuilderToGraph preserving node
  // ids; confirm if that changes.
  g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);

  std::vector<Node*> post_order;
  // Drop only the back edge n3 -> n1 that was added above.
  auto edge_filter = [n1, n3](const Edge& e) {
    return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
  };

  std::vector<Node*> expected_post_order = {
      g.sink_node(),          g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
      g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};

  std::vector<Node*> expected_reverse_post_order = expected_post_order;
  std::reverse(expected_reverse_post_order.begin(),
               expected_reverse_post_order.end());

  GetPostOrder(g, &post_order, /*stable_comparator=*/{},
               /*edge_filter=*/edge_filter);

  // EXPECT_EQ (not CHECK_EQ) reports a mismatch as a test failure instead of
  // aborting the whole test binary.
  ASSERT_EQ(expected_post_order.size(), post_order.size());
  for (size_t i = 0; i < post_order.size(); ++i) {
    EXPECT_EQ(post_order[i], expected_post_order[i])
        << post_order[i]->name() << " vs. " << expected_post_order[i]->name();
  }

  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
                      /*edge_filter=*/edge_filter);

  ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
  for (size_t i = 0; i < reverse_post_order.size(); ++i) {
    EXPECT_EQ(reverse_post_order[i], expected_reverse_post_order[i])
        << reverse_post_order[i]->name() << " vs. "
        << expected_reverse_post_order[i]->name();
  }
}
205 
BM_PruneForReverseReachability(::testing::benchmark::State & state)206 void BM_PruneForReverseReachability(::testing::benchmark::State& state) {
207   const int num_nodes = state.range(0);
208   const int num_edges_per_node = state.range(1);
209   const GraphDef graph_def =
210       test::CreateGraphDef(num_nodes, num_edges_per_node);
211   const auto registry = OpRegistry::Global();
212   GraphConstructorOptions opts;
213   for (auto s : state) {
214     state.PauseTiming();
215     Graph graph(registry);
216     TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
217     std::unordered_set<const Node*> visited;
218     visited.insert(graph.FindNodeId(graph.num_nodes() - 1));
219     state.ResumeTiming();
220     PruneForReverseReachability(&graph, std::move(visited));
221   }
222 }
223 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 2);
224 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 2);
225 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 2);
226 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 2);
227 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 2);
228 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 4);
229 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 4);
230 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 4);
231 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 4);
232 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 4);
233 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 8);
234 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 8);
235 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 8);
236 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 8);
237 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 8);
238 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 16);
239 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 6, 16);
240 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 9, 16);
241 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 12, 16);
242 BENCHMARK(BM_PruneForReverseReachability)->ArgPair(1 << 15, 16);
243 
244 }  // namespace
245 }  // namespace tensorflow
246