• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/graph/optimizer_cse.h"

#include <algorithm>
#include <functional>
#include <set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
34 
35 namespace tensorflow {
36 namespace {
37 
InitGraph(const string & s,Graph * graph)38 static void InitGraph(const string& s, Graph* graph) {
39   GraphDef graph_def;
40 
41   auto parser = protobuf::TextFormat::Parser();
42   //  parser.AllowRelaxedWhitespace(true);
43   CHECK(parser.MergeFromString(s, &graph_def)) << s;
44   GraphConstructorOptions opts;
45   TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
46 }
47 
48 class OptimizerCSETest : public ::testing::Test {
49  public:
OptimizerCSETest()50   OptimizerCSETest() : graph_(OpRegistry::Global()) {}
51 
InitGraph(const string & s)52   void InitGraph(const string& s) {
53     ::tensorflow::InitGraph(s, &graph_);
54     original_ = CanonicalGraphString(&graph_);
55   }
56 
IncludeNode(const Node * n)57   static bool IncludeNode(const Node* n) { return n->IsOp(); }
58 
EdgeId(const Node * n,int index)59   static string EdgeId(const Node* n, int index) {
60     if (index == 0) {
61       return n->name();
62     } else if (index == Graph::kControlSlot) {
63       return strings::StrCat(n->name(), ":control");
64     } else {
65       return strings::StrCat(n->name(), ":", index);
66     }
67   }
68 
CanonicalGraphString(Graph * g)69   string CanonicalGraphString(Graph* g) {
70     std::vector<string> nodes;
71     std::vector<string> edges;
72     for (const Node* n : g->nodes()) {
73       if (IncludeNode(n)) {
74         nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
75       }
76     }
77     for (const Edge* e : g->edges()) {
78       if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
79         edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
80                                         EdgeId(e->dst(), e->dst_input())));
81       }
82     }
83     // Canonicalize
84     std::sort(nodes.begin(), nodes.end());
85     std::sort(edges.begin(), edges.end());
86     return strings::StrCat(absl::StrJoin(nodes, ";"), "|",
87                            absl::StrJoin(edges, ";"));
88   }
89 
DoCSE(const std::function<bool (const Node *)> & consider_fn=nullptr)90   string DoCSE(const std::function<bool(const Node*)>& consider_fn = nullptr) {
91     string before = CanonicalGraphString(&graph_);
92     LOG(ERROR) << "Before rewrites: " << before;
93 
94     OptimizeCSE(&graph_, consider_fn);
95 
96     string result = CanonicalGraphString(&graph_);
97     LOG(ERROR) << "After rewrites:  " << result;
98     return result;
99   }
100 
OriginalGraph() const101   const string& OriginalGraph() const { return original_; }
102 
103   Graph graph_;
104   string original_;
105 };
106 
107 REGISTER_OP("Input").Output("o: float").SetIsStateful();
108 
109 // Note that the "rules" in these tests are not meant to be logically correct
// C and D are identical Mul nodes over (A, B); after CSE only C remains.
TEST_F(OptimizerCSETest, Simple) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
122 
// Three identical Mul nodes (C, D, E) over (A, B) collapse to the single
// node C.
TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
137 
// D duplicates C and is removed; E, which consumed both C and D, must be
// rewired so that both of its inputs come from C.
TEST_F(OptimizerCSETest, Simple_WithFixups) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul);E(Mul)|"
      "A->C;B->C:1;C->E;C->E:1";
  EXPECT_EQ(DoCSE(), expected);
}
152 
// D is Mul(B, A) while C is Mul(A, B); the swapped operand order must not
// prevent the two from being merged.
TEST_F(OptimizerCSETest, Simple_Commutative) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
165 
IsNotMultiply(const Node * n)166 static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; }
167 
168 // Like Simple_Commutative,
TEST_F(OptimizerCSETest,Simple_Filtered)169 TEST_F(OptimizerCSETest, Simple_Filtered) {
170   InitGraph(
171       "node { name: 'A' op: 'Input'}"
172       "node { name: 'B' op: 'Input'}"
173       "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
174       " input: ['A', 'B'] }"
175       "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
176       " input: ['B', 'A'] }");
177   EXPECT_EQ(DoCSE(IsNotMultiply), OriginalGraph());
178 }
179 
// Sub(A, B) and Sub(B, A) differ only in operand order; they must not be
// merged, so the graph is expected to be unchanged.
TEST_F(OptimizerCSETest, Simple_NotCommutative) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }";
  InitGraph(graph_text);

  const string optimized = DoCSE();
  EXPECT_EQ(optimized, OriginalGraph());
}
190 
// C (Mul) and D (Sub) share inputs but are different ops; the graph is
// expected to be unchanged.
TEST_F(OptimizerCSETest, NotEquivalent_Ops) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string optimized = DoCSE();
  EXPECT_EQ(optimized, OriginalGraph());
}
201 
// Nodes that carry an extra attr ('shape') are still merged when the attr
// values are identical on both nodes.
TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] attr { key: 'shape'"
      "    value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] attr { key: 'shape'"
      "    value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
217 
// Matching attrs listed in a different order on the two nodes ('a' then 't'
// versus 't' then 'a') must still allow the nodes to be merged.
TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      "    attr { key: 'a' value { i: 3 } }"
      "    attr { key: 't' value { type: DT_INT32 } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      "    attr { key: 't' value { type: DT_INT32 } }"
      "    attr { key: 'a' value { i: 3 } } }";
  InitGraph(graph_text);

  const string expected =
      "A(Input);B(Input);C(Mul)|"
      "A->C;B->C:1";
  EXPECT_EQ(DoCSE(), expected);
}
236 
// Const nodes A and B hold identical tensors, so B is folded into A and the
// Mul node D ends up taking both of its inputs from A.
TEST_F(OptimizerCSETest, SameConstants) {
  const string graph_text =
      "node { name: 'A' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Const);D(Mul)|"
      "A->D;A->D:1";
  EXPECT_EQ(DoCSE(), expected);
}
256 
// Const nodes A and B hold different values (0 versus 100000), so they must
// not be merged: both constants and both of D's input edges are expected to
// remain.
TEST_F(OptimizerCSETest, DifferentConstants) {
  const string graph_text =
      "node { name: 'A' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 100000 } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
      " input: ['A', 'B'] }";
  InitGraph(graph_text);

  const string expected =
      "A(Const);B(Const);D(Mul)|"
      "A->D;B->D:1";
  EXPECT_EQ(DoCSE(), expected);
}
276 
// C and D differ in the integer attr 'a' (3 versus 4); they must not be
// merged, so the graph is expected to be unchanged.
TEST_F(OptimizerCSETest, SameOps_DifferentAttrs1) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      "    attr { key: 'a' value { i: 3 } }"
      "    attr { key: 't' value { type: DT_INT32 } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      "    attr { key: 't' value { type: DT_INT32 } }"
      "    attr { key: 'a' value { i: 4 } } }";
  InitGraph(graph_text);

  const string optimized = DoCSE();
  EXPECT_EQ(optimized, OriginalGraph());
}
291 
// C and D differ in the type attr 't' (DT_FLOAT versus DT_INT32); they must
// not be merged, so the graph is expected to be unchanged.
TEST_F(OptimizerCSETest, SameOps_DifferentAttrs2) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      "    attr { key: 'a' value { i: 3 } }"
      "    attr { key: 't' value { type: DT_FLOAT } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      "    attr { key: 't' value { type: DT_INT32 } }"
      "    attr { key: 'a' value { i: 3 } } }";
  InitGraph(graph_text);

  const string optimized = DoCSE();
  EXPECT_EQ(optimized, OriginalGraph());
}
306 
// D is Mul(A, B) and E is Mul(A, C): same op, different second input. They
// must not be merged, so the graph is expected to be unchanged.
TEST_F(OptimizerCSETest, NotEquivalent_Inputs) {
  const string graph_text =
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }";
  InitGraph(graph_text);

  const string optimized = DoCSE();
  EXPECT_EQ(optimized, OriginalGraph());
}
318 
// Builds eight Const nodes from four distinct tensors (each tensor used
// twice) and checks that CSE keeps exactly one node per distinct tensor.
TEST_F(OptimizerCSETest, Constant_Dedup) {
  // Baseline tensor, plus three variants that each differ from it in one
  // respect: dtype, shape, or value.
  Tensor baseline(DT_FLOAT, TensorShape({1}));
  baseline.flat<float>()(0) = 1.0;
  Tensor other_dtype(DT_DOUBLE, TensorShape({1}));
  other_dtype.flat<double>()(0) = 1.0;
  Tensor other_shape(DT_FLOAT, TensorShape({1, 1}));
  other_shape.flat<float>()(0) = 1.0;
  Tensor other_value(DT_FLOAT, TensorShape({1}));
  other_value.flat<float>()(0) = 2.0;

  // Insertion order is palindromic, so node n/_k and node n/_(7-k) are built
  // from the same tensor. Nodes are named n/_0, n/_1, ...
  Graph constants_graph(OpRegistry::Global());
  for (const auto& tensor :
       {baseline, other_dtype, other_shape, other_value, other_value,
        other_shape, other_dtype, baseline}) {
    test::graph::Constant(&constants_graph, tensor);
  }

  // Round-trip through GraphDef text so the fixture's graph_ holds the nodes.
  GraphDef constants_def;
  test::graph::ToGraphDef(&constants_graph, &constants_def);
  InitGraph(constants_def.DebugString());

  const string expected_original =
      "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);"
      "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|";
  EXPECT_EQ(OriginalGraph(), expected_original);

  const std::vector<string> tokens = str_util::Split(DoCSE(), ";|");
  const std::set<string> survivors(tokens.begin(), tokens.end());

  // For each pair of equivalent nodes, exactly one of the two must survive.
  EXPECT_EQ(survivors.count("n/_0(Const)") + survivors.count("n/_7(Const)"),
            1);
  EXPECT_EQ(survivors.count("n/_1(Const)") + survivors.count("n/_6(Const)"),
            1);
  EXPECT_EQ(survivors.count("n/_2(Const)") + survivors.count("n/_5(Const)"),
            1);
  EXPECT_EQ(survivors.count("n/_3(Const)") + survivors.count("n/_4(Const)"),
            1);
}
349 
BM_CSE(::testing::benchmark::State & state)350 void BM_CSE(::testing::benchmark::State& state) {
351   const int op_nodes = state.range(0);
352   string s;
353   for (int in = 0; in < 10; in++) {
354     s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
355   }
356   random::PhiloxRandom philox(301, 17);
357   random::SimplePhilox rnd(&philox);
358   for (int op = 0; op < op_nodes; op++) {
359     s += strings::Printf(
360         "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
361         "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
362         op, rnd.Uniform(10), rnd.Uniform(10));
363   }
364 
365   bool first = true;
366   for (auto i : state) {
367     state.PauseTiming();
368     Graph* graph = new Graph(OpRegistry::Global());
369     InitGraph(s, graph);
370     int N = graph->num_node_ids();
371     if (first) {
372       state.SetLabel(strings::StrCat("Per graph node.  Nodes: ", N));
373       first = false;
374     }
375     {
376       state.ResumeTiming();
377       OptimizeCSE(graph, nullptr);
378       state.PauseTiming();
379     }
380     delete graph;
381     state.ResumeTiming();
382   }
383 }
384 BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000);
385 
386 }  // namespace
387 }  // namespace tensorflow
388