• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
17 
18 #include "absl/strings/str_format.h"
19 #include "absl/strings/str_join.h"
20 
21 namespace tensorflow {
22 namespace grappler {
23 namespace graph_analyzer {
24 namespace test {
25 
26 //=== Helper methods to construct the nodes.
27 
MakeNodeConst(const string & name)28 NodeDef MakeNodeConst(const string& name) {
29   NodeDef n;
30   n.set_name(name);
31   n.set_op("Const");
32   return n;
33 }
34 
MakeNode2Arg(const string & name,const string & opcode,const string & arg1,const string & arg2)35 NodeDef MakeNode2Arg(const string& name, const string& opcode,
36                      const string& arg1, const string& arg2) {
37   NodeDef n;
38   n.set_name(name);
39   n.set_op(opcode);
40   n.add_input(arg1);
41   n.add_input(arg2);
42   return n;
43 }
44 
MakeNode4Arg(const string & name,const string & opcode,const string & arg1,const string & arg2,const string & arg3,const string & arg4)45 NodeDef MakeNode4Arg(const string& name, const string& opcode,
46                      const string& arg1, const string& arg2, const string& arg3,
47                      const string& arg4) {
48   NodeDef n;
49   n.set_name(name);
50   n.set_op(opcode);
51   n.add_input(arg1);
52   n.add_input(arg2);
53   n.add_input(arg3);
54   n.add_input(arg4);
55   return n;
56 }
57 
58 // Not really a 2-argument but convenient to construct.
MakeNodeShapeN(const string & name,const string & arg1,const string & arg2)59 NodeDef MakeNodeShapeN(const string& name, const string& arg1,
60                        const string& arg2) {
61   // This opcode is multi-input but not commutative.
62   return MakeNode2Arg(name, "ShapeN", arg1, arg2);
63 }
64 
65 // Not really a 2-argument but convenient to construct.
MakeNodeIdentityN(const string & name,const string & arg1,const string & arg2)66 NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
67                           const string& arg2) {
68   // The argument is of a list type.
69   return MakeNode2Arg(name, "IdentityN", arg1, arg2);
70 }
71 
MakeNodeQuantizedConcat(const string & name,const string & arg1,const string & arg2,const string & arg3,const string & arg4)72 NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
73                                 const string& arg2, const string& arg3,
74                                 const string& arg4) {
75   // This opcode has multiple multi-inputs.
76   return MakeNode4Arg(name, "QuantizedConcat", arg1, arg2, arg3, arg4);
77 }
78 
79 //=== Helper methods for analysing the structures.
80 
DumpLinkMap(const GenNode::LinkMap & link_map)81 std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map) {
82   // This will order the entries first.
83   std::map<string, string> ordered;
84   for (const auto& link : link_map) {
85     string key = string(link.first);
86 
87     // Order the other sides too. They may be repeating, so store them
88     // in a multiset.
89     std::multiset<string> others;
90     for (const auto& other : link.second) {
91       others.emplace(
92           absl::StrFormat("%s[%s]", other.node->name(), string(other.port)));
93     }
94     ordered[key] = absl::StrJoin(others, ", ");
95   }
96   // Now dump the result in a predictable order.
97   std::vector<string> result;
98   result.reserve(ordered.size());
99   for (const auto& link : ordered) {
100     result.emplace_back(link.first + ": " + link.second);
101   }
102   return result;
103 }
104 
DumpLinkHashMap(const SigNode::LinkHashMap & link_hash_map)105 std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map) {
106   // The entries in this map are ordered by hash value which might change
107   // at any point. Re-order them by the link tag.
108   std::map<SigNode::LinkTag, size_t> tags;
109   for (const auto& entry : link_hash_map) {
110     tags[entry.second.tag] = entry.first;
111   }
112 
113   std::vector<string> result;
114   for (const auto& id : tags) {
115     // For predictability, the nodes need to be sorted.
116     std::vector<string> nodes;
117     for (const auto& peer : link_hash_map.at(id.second).peers) {
118       nodes.emplace_back(peer->name());
119     }
120     std::sort(nodes.begin(), nodes.end());
121     result.emplace_back(string(id.first.local) + ":" + string(id.first.remote) +
122                         ": " + absl::StrJoin(nodes, ", "));
123   }
124   return result;
125 }
126 
DumpHashedPeerVector(const SigNode::HashedPeerVector & hashed_peers)127 std::vector<string> DumpHashedPeerVector(
128     const SigNode::HashedPeerVector& hashed_peers) {
129   std::vector<string> result;
130 
131   // Each subset of nodes with the same hash has to be sorted by name.
132   // Other than that, the vector is already ordered by full tags.
133   size_t last_hash = 0;
134   // Index, since iterators may get invalidated on append.
135   size_t subset_start = 0;
136 
137   for (const auto& entry : hashed_peers) {
138     if (entry.link_hash != last_hash) {
139       std::sort(result.begin() + subset_start, result.end());
140       subset_start = result.size();
141     }
142     result.emplace_back(entry.peer->name());
143   }
144   std::sort(result.begin() + subset_start, result.end());
145 
146   return result;
147 }
148 
TestGraphs()149 TestGraphs::TestGraphs() {
150   {
151     GraphDef& graph = graph_3n_self_control_;
152     // The topology includes a loop and a link to self.
153     (*graph.add_node()) = MakeNodeConst("node1");
154     (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
155     auto node3 = graph.add_node();
156     *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
157     node3->add_input("^node3");  // The control link goes back to self.
158   }
159   {
160     GraphDef& graph = graph_multi_input_;
161     // The topology includes a loop and a link to self.
162     (*graph.add_node()) = MakeNodeConst("const1_1");
163     (*graph.add_node()) = MakeNodeConst("const1_2");
164     (*graph.add_node()) = MakeNodeAddN("add1", "const1_1", "const1_2");
165 
166     (*graph.add_node()) = MakeNodeConst("const2_1");
167     (*graph.add_node()) = MakeNodeConst("const2_2");
168     (*graph.add_node()) = MakeNodeConst("const2_3");
169 
170     auto add2 = graph.add_node();
171     *add2 = MakeNodeAddN("add2", "const2_1", "const2_2");
172     // The 3rd node is connected twice, to 4 links total.
173     add2->add_input("const2_3");
174     add2->add_input("const2_3");
175 
176     (*graph.add_node()) = MakeNodeSub("sub", "add1", "add2");
177   }
178   {
179     GraphDef& graph = graph_all_or_none_;
180     // The topology includes a loop and a link to self.
181     (*graph.add_node()) = MakeNodeConst("const1_1");
182     (*graph.add_node()) = MakeNodeConst("const1_2");
183     auto pass1 = graph.add_node();
184     *pass1 = MakeNodeIdentityN("pass1", "const1_1", "const1_2");
185 
186     (*graph.add_node()) = MakeNodeConst("const2_1");
187     (*graph.add_node()) = MakeNodeConst("const2_2");
188     (*graph.add_node()) = MakeNodeConst("const2_3");
189 
190     auto pass2 = graph.add_node();
191     *pass2 = MakeNodeIdentityN("pass2", "const2_1", "const2_2");
192     // The 3rd node is connected twice, to 4 links total.
193     pass2->add_input("const2_3");
194     pass2->add_input("const2_3");
195 
196     // Add the control links, they get handled separately than the normal
197     // links.
198     pass1->add_input("^const2_1");
199     pass1->add_input("^const2_2");
200     pass1->add_input("^const2_3");
201 
202     (*graph.add_node()) = MakeNodeSub("sub", "pass1", "pass2");
203   }
204   {
205     GraphDef& graph = graph_circular_onedir_;
206     (*graph.add_node()) = MakeNodeMul("node1", "node5", "node5");
207     (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
208     (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
209     (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
210     (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
211   }
212   {
213     GraphDef& graph = graph_circular_bidir_;
214     // The left and right links are intentionally mixed up.
215     (*graph.add_node()) = MakeNodeMul("node1", "node5", "node2");
216     (*graph.add_node()) = MakeNodeMul("node2", "node3", "node1");
217     (*graph.add_node()) = MakeNodeMul("node3", "node2", "node4");
218     (*graph.add_node()) = MakeNodeMul("node4", "node5", "node3");
219     (*graph.add_node()) = MakeNodeMul("node5", "node4", "node1");
220   }
221   {
222     GraphDef& graph = graph_linear_;
223     (*graph.add_node()) = MakeNodeConst("node1");
224     (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
225     (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
226     (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
227     (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
228   }
229   {
230     GraphDef& graph = graph_cross_;
231     (*graph.add_node()) = MakeNodeConst("node1");
232     (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
233     (*graph.add_node()) = MakeNodeConst("node3");
234     (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
235     (*graph.add_node()) = MakeNodeConst("node5");
236     (*graph.add_node()) = MakeNodeMul("node6", "node5", "node5");
237     (*graph.add_node()) = MakeNodeConst("node7");
238     (*graph.add_node()) = MakeNodeMul("node8", "node7", "node7");
239 
240     auto center = graph.add_node();
241     *center = MakeNodeMul("node9", "node2", "node4");
242     center->add_input("node6");
243     center->add_input("node8");
244   }
245   {
246     GraphDef& graph = graph_small_cross_;
247     (*graph.add_node()) = MakeNodeConst("node1");
248     (*graph.add_node()) = MakeNodeConst("node2");
249     (*graph.add_node()) = MakeNodeConst("node3");
250     (*graph.add_node()) = MakeNodeConst("node4");
251 
252     auto center = graph.add_node();
253     *center = MakeNodeMul("node5", "node1", "node2");
254     center->add_input("node3");
255     center->add_input("node4");
256   }
257   {
258     GraphDef& graph = graph_for_link_order_;
259     (*graph.add_node()) = MakeNodeConst("node1");
260     (*graph.add_node()) = MakeNodeConst("node2");
261     (*graph.add_node()) = MakeNodeConst("node3");
262     (*graph.add_node()) = MakeNodeConst("node4");
263 
264     // One group of equivalent links.
265     auto center = graph.add_node();
266     *center = MakeNodeMul("node5", "node1", "node2");
267     center->add_input("node3");
268     center->add_input("node4");
269 
270     // Multiple groups, separated by unique links.
271     auto center2 = graph.add_node();
272     *center2 = MakeNodeMul("node6", "node1", "node2");
273     center2->add_input("node2:1");
274     center2->add_input("node3:2");
275     center2->add_input("node4:2");
276     center2->add_input("node4:3");
277   }
278   {
279     GraphDef& graph = graph_sun_;
280     (*graph.add_node()) = MakeNodeConst("node1");
281     (*graph.add_node()) = MakeNodeConst("node2");
282     (*graph.add_node()) = MakeNodeConst("node3");
283     (*graph.add_node()) = MakeNodeConst("node4");
284     (*graph.add_node()) = MakeNodeConst("node5");
285     (*graph.add_node()) = MakeNodeSub("node6", "node1", "node10");
286     (*graph.add_node()) = MakeNodeSub("node7", "node2", "node6");
287     (*graph.add_node()) = MakeNodeSub("node8", "node3", "node7");
288     (*graph.add_node()) = MakeNodeSub("node9", "node4", "node8");
289     (*graph.add_node()) = MakeNodeSub("node10", "node5", "node9");
290   }
291 }
292 
293 }  // end namespace test
294 }  // end namespace graph_analyzer
295 }  // end namespace grappler
296 }  // end namespace tensorflow
297