1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
17
18 #include "absl/strings/str_format.h"
19 #include "absl/strings/str_join.h"
20
21 namespace tensorflow {
22 namespace grappler {
23 namespace graph_analyzer {
24 namespace test {
25
26 //=== Helper methods to construct the nodes.
27
MakeNodeConst(const string & name)28 NodeDef MakeNodeConst(const string& name) {
29 NodeDef n;
30 n.set_name(name);
31 n.set_op("Const");
32 return n;
33 }
34
MakeNode2Arg(const string & name,const string & opcode,const string & arg1,const string & arg2)35 NodeDef MakeNode2Arg(const string& name, const string& opcode,
36 const string& arg1, const string& arg2) {
37 NodeDef n;
38 n.set_name(name);
39 n.set_op(opcode);
40 n.add_input(arg1);
41 n.add_input(arg2);
42 return n;
43 }
44
MakeNode4Arg(const string & name,const string & opcode,const string & arg1,const string & arg2,const string & arg3,const string & arg4)45 NodeDef MakeNode4Arg(const string& name, const string& opcode,
46 const string& arg1, const string& arg2, const string& arg3,
47 const string& arg4) {
48 NodeDef n;
49 n.set_name(name);
50 n.set_op(opcode);
51 n.add_input(arg1);
52 n.add_input(arg2);
53 n.add_input(arg3);
54 n.add_input(arg4);
55 return n;
56 }
57
58 // Not really a 2-argument but convenient to construct.
MakeNodeShapeN(const string & name,const string & arg1,const string & arg2)59 NodeDef MakeNodeShapeN(const string& name, const string& arg1,
60 const string& arg2) {
61 // This opcode is multi-input but not commutative.
62 return MakeNode2Arg(name, "ShapeN", arg1, arg2);
63 }
64
65 // Not really a 2-argument but convenient to construct.
MakeNodeIdentityN(const string & name,const string & arg1,const string & arg2)66 NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
67 const string& arg2) {
68 // The argument is of a list type.
69 return MakeNode2Arg(name, "IdentityN", arg1, arg2);
70 }
71
MakeNodeQuantizedConcat(const string & name,const string & arg1,const string & arg2,const string & arg3,const string & arg4)72 NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
73 const string& arg2, const string& arg3,
74 const string& arg4) {
75 // This opcode has multiple multi-inputs.
76 return MakeNode4Arg(name, "QuantizedConcat", arg1, arg2, arg3, arg4);
77 }
78
79 //=== Helper methods for analysing the structures.
80
DumpLinkMap(const GenNode::LinkMap & link_map)81 std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map) {
82 // This will order the entries first.
83 std::map<string, string> ordered;
84 for (const auto& link : link_map) {
85 string key = string(link.first);
86
87 // Order the other sides too. They may be repeating, so store them
88 // in a multiset.
89 std::multiset<string> others;
90 for (const auto& other : link.second) {
91 others.emplace(
92 absl::StrFormat("%s[%s]", other.node->name(), string(other.port)));
93 }
94 ordered[key] = absl::StrJoin(others, ", ");
95 }
96 // Now dump the result in a predictable order.
97 std::vector<string> result;
98 result.reserve(ordered.size());
99 for (const auto& link : ordered) {
100 result.emplace_back(link.first + ": " + link.second);
101 }
102 return result;
103 }
104
DumpLinkHashMap(const SigNode::LinkHashMap & link_hash_map)105 std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map) {
106 // The entries in this map are ordered by hash value which might change
107 // at any point. Re-order them by the link tag.
108 std::map<SigNode::LinkTag, size_t> tags;
109 for (const auto& entry : link_hash_map) {
110 tags[entry.second.tag] = entry.first;
111 }
112
113 std::vector<string> result;
114 for (const auto& id : tags) {
115 // For predictability, the nodes need to be sorted.
116 std::vector<string> nodes;
117 for (const auto& peer : link_hash_map.at(id.second).peers) {
118 nodes.emplace_back(peer->name());
119 }
120 std::sort(nodes.begin(), nodes.end());
121 result.emplace_back(string(id.first.local) + ":" + string(id.first.remote) +
122 ": " + absl::StrJoin(nodes, ", "));
123 }
124 return result;
125 }
126
DumpHashedPeerVector(const SigNode::HashedPeerVector & hashed_peers)127 std::vector<string> DumpHashedPeerVector(
128 const SigNode::HashedPeerVector& hashed_peers) {
129 std::vector<string> result;
130
131 // Each subset of nodes with the same hash has to be sorted by name.
132 // Other than that, the vector is already ordered by full tags.
133 size_t last_hash = 0;
134 // Index, since iterators may get invalidated on append.
135 size_t subset_start = 0;
136
137 for (const auto& entry : hashed_peers) {
138 if (entry.link_hash != last_hash) {
139 std::sort(result.begin() + subset_start, result.end());
140 subset_start = result.size();
141 }
142 result.emplace_back(entry.peer->name());
143 }
144 std::sort(result.begin() + subset_start, result.end());
145
146 return result;
147 }
148
TestGraphs()149 TestGraphs::TestGraphs() {
150 {
151 GraphDef& graph = graph_3n_self_control_;
152 // The topology includes a loop and a link to self.
153 (*graph.add_node()) = MakeNodeConst("node1");
154 (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
155 auto node3 = graph.add_node();
156 *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
157 node3->add_input("^node3"); // The control link goes back to self.
158 }
159 {
160 GraphDef& graph = graph_multi_input_;
161 // The topology includes a loop and a link to self.
162 (*graph.add_node()) = MakeNodeConst("const1_1");
163 (*graph.add_node()) = MakeNodeConst("const1_2");
164 (*graph.add_node()) = MakeNodeAddN("add1", "const1_1", "const1_2");
165
166 (*graph.add_node()) = MakeNodeConst("const2_1");
167 (*graph.add_node()) = MakeNodeConst("const2_2");
168 (*graph.add_node()) = MakeNodeConst("const2_3");
169
170 auto add2 = graph.add_node();
171 *add2 = MakeNodeAddN("add2", "const2_1", "const2_2");
172 // The 3rd node is connected twice, to 4 links total.
173 add2->add_input("const2_3");
174 add2->add_input("const2_3");
175
176 (*graph.add_node()) = MakeNodeSub("sub", "add1", "add2");
177 }
178 {
179 GraphDef& graph = graph_all_or_none_;
180 // The topology includes a loop and a link to self.
181 (*graph.add_node()) = MakeNodeConst("const1_1");
182 (*graph.add_node()) = MakeNodeConst("const1_2");
183 auto pass1 = graph.add_node();
184 *pass1 = MakeNodeIdentityN("pass1", "const1_1", "const1_2");
185
186 (*graph.add_node()) = MakeNodeConst("const2_1");
187 (*graph.add_node()) = MakeNodeConst("const2_2");
188 (*graph.add_node()) = MakeNodeConst("const2_3");
189
190 auto pass2 = graph.add_node();
191 *pass2 = MakeNodeIdentityN("pass2", "const2_1", "const2_2");
192 // The 3rd node is connected twice, to 4 links total.
193 pass2->add_input("const2_3");
194 pass2->add_input("const2_3");
195
196 // Add the control links, they get handled separately than the normal
197 // links.
198 pass1->add_input("^const2_1");
199 pass1->add_input("^const2_2");
200 pass1->add_input("^const2_3");
201
202 (*graph.add_node()) = MakeNodeSub("sub", "pass1", "pass2");
203 }
204 {
205 GraphDef& graph = graph_circular_onedir_;
206 (*graph.add_node()) = MakeNodeMul("node1", "node5", "node5");
207 (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
208 (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
209 (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
210 (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
211 }
212 {
213 GraphDef& graph = graph_circular_bidir_;
214 // The left and right links are intentionally mixed up.
215 (*graph.add_node()) = MakeNodeMul("node1", "node5", "node2");
216 (*graph.add_node()) = MakeNodeMul("node2", "node3", "node1");
217 (*graph.add_node()) = MakeNodeMul("node3", "node2", "node4");
218 (*graph.add_node()) = MakeNodeMul("node4", "node5", "node3");
219 (*graph.add_node()) = MakeNodeMul("node5", "node4", "node1");
220 }
221 {
222 GraphDef& graph = graph_linear_;
223 (*graph.add_node()) = MakeNodeConst("node1");
224 (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
225 (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
226 (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
227 (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
228 }
229 {
230 GraphDef& graph = graph_cross_;
231 (*graph.add_node()) = MakeNodeConst("node1");
232 (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
233 (*graph.add_node()) = MakeNodeConst("node3");
234 (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
235 (*graph.add_node()) = MakeNodeConst("node5");
236 (*graph.add_node()) = MakeNodeMul("node6", "node5", "node5");
237 (*graph.add_node()) = MakeNodeConst("node7");
238 (*graph.add_node()) = MakeNodeMul("node8", "node7", "node7");
239
240 auto center = graph.add_node();
241 *center = MakeNodeMul("node9", "node2", "node4");
242 center->add_input("node6");
243 center->add_input("node8");
244 }
245 {
246 GraphDef& graph = graph_small_cross_;
247 (*graph.add_node()) = MakeNodeConst("node1");
248 (*graph.add_node()) = MakeNodeConst("node2");
249 (*graph.add_node()) = MakeNodeConst("node3");
250 (*graph.add_node()) = MakeNodeConst("node4");
251
252 auto center = graph.add_node();
253 *center = MakeNodeMul("node5", "node1", "node2");
254 center->add_input("node3");
255 center->add_input("node4");
256 }
257 {
258 GraphDef& graph = graph_for_link_order_;
259 (*graph.add_node()) = MakeNodeConst("node1");
260 (*graph.add_node()) = MakeNodeConst("node2");
261 (*graph.add_node()) = MakeNodeConst("node3");
262 (*graph.add_node()) = MakeNodeConst("node4");
263
264 // One group of equivalent links.
265 auto center = graph.add_node();
266 *center = MakeNodeMul("node5", "node1", "node2");
267 center->add_input("node3");
268 center->add_input("node4");
269
270 // Multiple groups, separated by unique links.
271 auto center2 = graph.add_node();
272 *center2 = MakeNodeMul("node6", "node1", "node2");
273 center2->add_input("node2:1");
274 center2->add_input("node3:2");
275 center2->add_input("node4:2");
276 center2->add_input("node4:3");
277 }
278 {
279 GraphDef& graph = graph_sun_;
280 (*graph.add_node()) = MakeNodeConst("node1");
281 (*graph.add_node()) = MakeNodeConst("node2");
282 (*graph.add_node()) = MakeNodeConst("node3");
283 (*graph.add_node()) = MakeNodeConst("node4");
284 (*graph.add_node()) = MakeNodeConst("node5");
285 (*graph.add_node()) = MakeNodeSub("node6", "node1", "node10");
286 (*graph.add_node()) = MakeNodeSub("node7", "node2", "node6");
287 (*graph.add_node()) = MakeNodeSub("node8", "node3", "node7");
288 (*graph.add_node()) = MakeNodeSub("node9", "node4", "node8");
289 (*graph.add_node()) = MakeNodeSub("node10", "node5", "node9");
290 }
291 }
292
293 } // end namespace test
294 } // end namespace graph_analyzer
295 } // end namespace grappler
296 } // end namespace tensorflow
297