1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
18
19 #include <string>
20 #include <vector>
21
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
24 #include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
25 #include "tensorflow/core/grappler/op_types.h"
26
27 namespace tensorflow {
28 namespace grappler {
29 namespace graph_analyzer {
30 namespace test {
31
32 //=== Helper methods to construct the nodes.
33
34 NodeDef MakeNodeConst(const string& name);
35
36 NodeDef MakeNode2Arg(const string& name, const string& opcode,
37 const string& arg1, const string& arg2);
38
39 NodeDef MakeNode4Arg(const string& name, const string& opcode,
40 const string& arg1, const string& arg2, const string& arg3,
41 const string& arg4);
42
MakeNodeMul(const string & name,const string & arg1,const string & arg2)43 inline NodeDef MakeNodeMul(const string& name, const string& arg1,
44 const string& arg2) {
45 return MakeNode2Arg(name, "Mul", arg1, arg2);
46 }
47
48 // Not really a 2-argument but convenient to construct.
MakeNodeAddN(const string & name,const string & arg1,const string & arg2)49 inline NodeDef MakeNodeAddN(const string& name, const string& arg1,
50 const string& arg2) {
51 return MakeNode2Arg(name, "AddN", arg1, arg2);
52 }
53
MakeNodeSub(const string & name,const string & arg1,const string & arg2)54 inline NodeDef MakeNodeSub(const string& name, const string& arg1,
55 const string& arg2) {
56 return MakeNode2Arg(name, "Sub", arg1, arg2);
57 }
58
59 // Has 2 honest outputs.
MakeNodeBroadcastGradientArgs(const string & name,const string & arg1,const string & arg2)60 inline NodeDef MakeNodeBroadcastGradientArgs(const string& name,
61 const string& arg1,
62 const string& arg2) {
63 return MakeNode2Arg(name, "BroadcastGradientArgs", arg1, arg2);
64 }
65
66 NodeDef MakeNodeShapeN(const string& name, const string& arg1,
67 const string& arg2);
68
69 NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
70 const string& arg2);
71
72 NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
73 const string& arg2, const string& arg3,
74 const string& arg4);
75
76 //=== A container of pre-constructed graphs.
77
78 class TestGraphs {
79 public:
80 TestGraphs();
81
82 // Graph with 3 nodes and a control link to self (which is not valid in
83 // reality but adds excitement to the tests).
84 GraphDef graph_3n_self_control_;
85 // Graph that has the multi-input links.
86 GraphDef graph_multi_input_;
87 // Graph that has the all-or-none nodes.
88 GraphDef graph_all_or_none_;
89 // All the nodes are connected in a circle that goes in one direction.
90 GraphDef graph_circular_onedir_;
91 // All the nodes are connected in a circle that goes in both directions.
92 GraphDef graph_circular_bidir_;
93 // The nodes are connected in a line.
94 GraphDef graph_linear_;
95 // The nodes are connected in a cross shape.
96 GraphDef graph_cross_;
97 GraphDef graph_small_cross_;
98 // For testing the ordering of links at the end of signature generation,
99 // a variation of a cross.
100 GraphDef graph_for_link_order_;
101 // Sun-shaped, a ring with "rays".
102 GraphDef graph_sun_;
103 };
104
105 //=== Helper methods for analysing the structures.
106
107 std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map);
108
109 // Also checks for the consistency of hash values.
110 std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map);
111
112 std::vector<string> DumpHashedPeerVector(
113 const SigNode::HashedPeerVector& hashed_peers);
114
115 } // end namespace test
116 } // end namespace graph_analyzer
117 } // end namespace grappler
118 } // end namespace tensorflow
119
120 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
121