1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/optimizer_cse.h"
17
18 #include <utility>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/graph/graph.h"
25 #include "tensorflow/core/graph/testlib.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/lib/random/simple_philox.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/stringprintf.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34
35 namespace tensorflow {
36 namespace {
37
InitGraph(const string & s,Graph * graph)38 static void InitGraph(const string& s, Graph* graph) {
39 GraphDef graph_def;
40
41 auto parser = protobuf::TextFormat::Parser();
42 // parser.AllowRelaxedWhitespace(true);
43 CHECK(parser.MergeFromString(s, &graph_def)) << s;
44 GraphConstructorOptions opts;
45 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
46 }
47
48 class OptimizerCSETest : public ::testing::Test {
49 public:
OptimizerCSETest()50 OptimizerCSETest() : graph_(OpRegistry::Global()) {}
51
InitGraph(const string & s)52 void InitGraph(const string& s) {
53 ::tensorflow::InitGraph(s, &graph_);
54 original_ = CanonicalGraphString(&graph_);
55 }
56
IncludeNode(const Node * n)57 static bool IncludeNode(const Node* n) { return n->IsOp(); }
58
EdgeId(const Node * n,int index)59 static string EdgeId(const Node* n, int index) {
60 if (index == 0) {
61 return n->name();
62 } else if (index == Graph::kControlSlot) {
63 return strings::StrCat(n->name(), ":control");
64 } else {
65 return strings::StrCat(n->name(), ":", index);
66 }
67 }
68
CanonicalGraphString(Graph * g)69 string CanonicalGraphString(Graph* g) {
70 std::vector<string> nodes;
71 std::vector<string> edges;
72 for (const Node* n : g->nodes()) {
73 if (IncludeNode(n)) {
74 nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
75 }
76 }
77 for (const Edge* e : g->edges()) {
78 if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
79 edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
80 EdgeId(e->dst(), e->dst_input())));
81 }
82 }
83 // Canonicalize
84 std::sort(nodes.begin(), nodes.end());
85 std::sort(edges.begin(), edges.end());
86 return strings::StrCat(absl::StrJoin(nodes, ";"), "|",
87 absl::StrJoin(edges, ";"));
88 }
89
DoCSE(const std::function<bool (const Node *)> & consider_fn=nullptr)90 string DoCSE(const std::function<bool(const Node*)>& consider_fn = nullptr) {
91 string before = CanonicalGraphString(&graph_);
92 LOG(ERROR) << "Before rewrites: " << before;
93
94 OptimizeCSE(&graph_, consider_fn);
95
96 string result = CanonicalGraphString(&graph_);
97 LOG(ERROR) << "After rewrites: " << result;
98 return result;
99 }
100
OriginalGraph() const101 const string& OriginalGraph() const { return original_; }
102
103 Graph graph_;
104 string original_;
105 };
106
107 REGISTER_OP("Input").Output("o: float").SetIsStateful();
108
109 // Note that the "rules" in these tests are not meant to be logically correct
TEST_F(OptimizerCSETest, Simple) {
  // C and D are both Mul(A, B); the expected result keeps C and drops D.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
122
TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) {
  // C, D and E are all Mul(A, B); the expected result keeps only C.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
137
TEST_F(OptimizerCSETest, Simple_WithFixups) {
  // C and D are both Mul(A, B) and E consumes both of them. The expected
  // result drops D and has both inputs of E fed from C.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul);E(Mul)|"
      "A->C;B->C:1;C->E;C->E:1";
  EXPECT_EQ(DoCSE(), expected);
}
152
TEST_F(OptimizerCSETest, Simple_Commutative) {
  // C is Mul(A, B) and D is Mul(B, A); the expected result still keeps only
  // C, i.e. the swapped operand order does not prevent the merge.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
165
IsNotMultiply(const Node * n)166 static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; }
167
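// Like Simple_Commutative, but CSE is run with a filter that rejects Mul nodes,
// so the graph is expected to come back unchanged.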
TEST_F(OptimizerCSETest, Simple_Filtered) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }";
  InitGraph(graph_text);

  // IsNotMultiply is passed as the node filter; the optimized graph must be
  // identical to the original one.
  EXPECT_EQ(DoCSE(IsNotMultiply), OriginalGraph());
}
179
TEST_F(OptimizerCSETest, Simple_NotCommutative) {
  // C is Sub(A, B) and D is Sub(B, A); the graph is expected to stay as is.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }";
  InitGraph(graph_text);

  EXPECT_EQ(DoCSE(), OriginalGraph());
}
190
TEST_F(OptimizerCSETest, NotEquivalent_Ops) {
  // C and D take the same inputs but have different op types (Mul vs. Sub);
  // the graph is expected to stay as is.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  EXPECT_EQ(DoCSE(), OriginalGraph());
}
201
TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) {
  // C and D carry an extra 'shape' attr with identical contents, so they are
  // still expected to be merged.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] attr { key: 'shape'"
      " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] attr { key: 'shape'"
      " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
217
TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) {
  // C and D carry the same extra attrs ('a' and 't') with the same values,
  // listed in a different order; they are still expected to be merged.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 'a' value { i: 3 } }"
      " attr { key: 't' value { type: DT_INT32 } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 't' value { type: DT_INT32 } }"
      " attr { key: 'a' value { i: 3 } } }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
236
TEST_F(OptimizerCSETest, SameConstants) {
  // A and B are Const nodes with identical dtype, shape and value. The
  // expected result keeps only A and feeds both inputs of D from it.
  const string graph_text =
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value {"
      " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      " int_val: 0 } } } }"
      "node { name: 'B' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value {"
      " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      " int_val: 0 } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Const);D(Mul)|"
      "A->D;A->D:1";
  EXPECT_EQ(DoCSE(), expected);
}
256
TEST_F(OptimizerCSETest, DifferentConstants) {
  // A and B are Const nodes that differ only in their value (0 vs. 100000).
  // The expected result keeps both of them, i.e. no merge takes place.
  const string graph_text =
      "node { name: 'A' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value {"
      " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      " int_val: 0 } } } }"
      "node { name: 'B' op: 'Const' "
      " attr { key: 'dtype' value { type: DT_INT32 } }"
      " attr { key: 'value' value {"
      " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      " int_val: 100000 } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Const);B(Const);D(Mul)|"
      "A->D;B->D:1";
  EXPECT_EQ(DoCSE(), expected);
}
276
TEST_F(OptimizerCSETest, SameOps_DifferentAttrs1) {
  // C and D differ in the value of attr 'a' (3 vs. 4); the graph is expected
  // to stay as is.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 'a' value { i: 3 } }"
      " attr { key: 't' value { type: DT_INT32 } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 't' value { type: DT_INT32 } }"
      " attr { key: 'a' value { i: 4 } } }";
  InitGraph(graph_text);

  EXPECT_EQ(DoCSE(), OriginalGraph());
}
291
TEST_F(OptimizerCSETest, SameOps_DifferentAttrs2) {
  // C and D differ in the value of attr 't' (DT_FLOAT vs. DT_INT32); the
  // graph is expected to stay as is.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 'a' value { i: 3 } }"
      " attr { key: 't' value { type: DT_FLOAT } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 't' value { type: DT_INT32 } }"
      " attr { key: 'a' value { i: 3 } } }";
  InitGraph(graph_text);

  EXPECT_EQ(DoCSE(), OriginalGraph());
}
306
TEST_F(OptimizerCSETest, NotEquivalent_Inputs) {
  // D is Mul(A, B) and E is Mul(A, C): same op, different second input. The
  // graph is expected to stay as is.
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }";
  InitGraph(graph_text);

  EXPECT_EQ(DoCSE(), OriginalGraph());
}
318
TEST_F(OptimizerCSETest, Constant_Dedup) {
  // Four constants that pairwise differ in type, shape or value.
  Tensor a(DT_FLOAT, TensorShape({1}));
  a.flat<float>()(0) = 1.0;
  Tensor b(DT_DOUBLE, TensorShape({1}));  // Differs from `a` in type.
  b.flat<double>()(0) = 1.0;
  Tensor c(DT_FLOAT, TensorShape({1, 1}));  // Differs from `a` in shape.
  c.flat<float>()(0) = 1.0;
  Tensor d(DT_FLOAT, TensorShape({1}));  // Differs from `a` in value.
  d.flat<float>()(0) = 2.0;

  // Build a graph holding each constant twice, in mirrored order, so that
  // node i and node 7 - i carry the same tensor.
  Graph g(OpRegistry::Global());
  for (const auto& tensor : {a, b, c, d, d, c, b, a}) {
    test::graph::Constant(&g, tensor);  // Node name is n/_0, n/_1, ...
  }
  GraphDef gdef;
  test::graph::ToGraphDef(&g, &gdef);
  InitGraph(gdef.DebugString());

  const string expected_original =
      "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);"
      "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|";
  EXPECT_EQ(OriginalGraph(), expected_original);

  // Split the optimized canonical string into its node entries.
  const std::vector<string> pieces = str_util::Split(DoCSE(), ";|");
  const std::set<string> kept(pieces.begin(), pieces.end());

  // Number of nodes from a duplicate pair that are still present.
  const auto survivors = [&kept](const char* first, const char* second) {
    return kept.count(first) + kept.count(second);
  };

  // Expect exactly one of each type of node to be retained after CSE.
  EXPECT_EQ(survivors("n/_0(Const)", "n/_7(Const)"), 1);
  EXPECT_EQ(survivors("n/_1(Const)", "n/_6(Const)"), 1);
  EXPECT_EQ(survivors("n/_2(Const)", "n/_5(Const)"), 1);
  EXPECT_EQ(survivors("n/_3(Const)", "n/_4(Const)"), 1);
}
349
BM_CSE(::testing::benchmark::State & state)350 void BM_CSE(::testing::benchmark::State& state) {
351 const int op_nodes = state.range(0);
352 string s;
353 for (int in = 0; in < 10; in++) {
354 s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
355 }
356 random::PhiloxRandom philox(301, 17);
357 random::SimplePhilox rnd(&philox);
358 for (int op = 0; op < op_nodes; op++) {
359 s += strings::Printf(
360 "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
361 "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
362 op, rnd.Uniform(10), rnd.Uniform(10));
363 }
364
365 bool first = true;
366 for (auto i : state) {
367 state.PauseTiming();
368 Graph* graph = new Graph(OpRegistry::Global());
369 InitGraph(s, graph);
370 int N = graph->num_node_ids();
371 if (first) {
372 state.SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
373 first = false;
374 }
375 {
376 state.ResumeTiming();
377 OptimizeCSE(graph, nullptr);
378 state.PauseTiming();
379 }
380 delete graph;
381 state.ResumeTiming();
382 }
383 }
384 BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000);
385
386 } // namespace
387 } // namespace tensorflow
388