1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/grappler_test.h"
17
18 #include <memory>
19
20 #include "absl/algorithm/container.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/grappler/utils.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
25 #include "tensorflow/core/public/session.h"
26
27 namespace tensorflow {
28 namespace grappler {
29
30 namespace {
CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef> * want,protobuf::RepeatedPtrField<NodeDef> * got)31 void CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef>* want,
32 protobuf::RepeatedPtrField<NodeDef>* got) {
33 auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool {
34 return n1.name() < n2.name();
35 };
36
37 std::sort(want->begin(), want->end(), comparator);
38 std::sort(got->begin(), got->end(), comparator);
39
40 ASSERT_EQ(want->size(), got->size());
41
42 for (int i = 0; i < want->size(); ++i) {
43 NodeDef& want_node = (*want)[i];
44 NodeDef& got_node = (*got)[i];
45
46 EXPECT_EQ(want_node.op(), got_node.op());
47 EXPECT_EQ(want_node.name(), got_node.name());
48 EXPECT_EQ(want_node.device(), got_node.device());
49 ASSERT_EQ(want_node.input_size(), got_node.input_size())
50 << "want_node =\n"
51 << want_node.DebugString() << "\ngot_node =\n"
52 << got_node.DebugString();
53
54 // Order of control dependencies doesn't matter, so we sort them first.
55 const auto is_control = [](const string& input) -> bool {
56 return ParseTensorName(input).index() < 0;
57 };
58
59 auto want_inputs = want_node.mutable_input();
60 auto got_inputs = got_node.mutable_input();
61 std::sort(absl::c_find_if(*want_inputs, is_control), want_inputs->end());
62 std::sort(absl::c_find_if(*got_inputs, is_control), got_inputs->end());
63
64 for (int j = 0; j < want_node.input_size(); ++j) {
65 const TensorId want_tensor = ParseTensorName(want_node.input(j));
66 const TensorId got_tensor = ParseTensorName(got_node.input(j));
67 EXPECT_EQ(want_tensor.ToString(), got_tensor.ToString());
68 }
69 }
70 }
71
SetAllOptimizers(RewriterConfig * cfg,RewriterConfig::Toggle value)72 void SetAllOptimizers(RewriterConfig* cfg, RewriterConfig::Toggle value) {
73 cfg->set_arithmetic_optimization(value);
74 cfg->set_auto_mixed_precision(value);
75 cfg->set_auto_mixed_precision_mkl(value);
76 cfg->set_common_subgraph_elimination(value);
77 cfg->set_constant_folding(value);
78 cfg->set_debug_stripper(value);
79 cfg->set_dependency_optimization(value);
80 cfg->set_function_optimization(value);
81 cfg->set_implementation_selector(value);
82 cfg->set_layout_optimizer(value);
83 cfg->set_loop_optimization(value);
84 cfg->set_pin_to_host_optimization(value);
85 cfg->set_remapping(value);
86 cfg->set_scoped_allocator_optimization(value);
87 cfg->set_shape_optimization(value);
88 }
89 } // namespace
90
GrapplerTest()91 GrapplerTest::GrapplerTest() {
92 // Turn off all the automatic optimizations to ensure that we run the graph
93 // exactly as it is given to us. This ensures that we can compare the
94 // results before and after manual optimization, without any of the
95 // automatic optimizations interfering in the comparison.
96 DisableAllOptimizers();
97 }
98
DisableAllOptimizers()99 void GrapplerTest::DisableAllOptimizers() {
100 SetAllOptimizers(
101 options_.config.mutable_graph_options()->mutable_rewrite_options(),
102 RewriterConfig::OFF);
103 }
104
EnableAllOptimizers()105 void GrapplerTest::EnableAllOptimizers() {
106 SetAllOptimizers(
107 options_.config.mutable_graph_options()->mutable_rewrite_options(),
108 RewriterConfig::ON);
109 }
110
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names) const111 std::vector<Tensor> GrapplerTest::EvaluateNodes(
112 const GraphDef& graph, const std::vector<string>& node_names) const {
113 return EvaluateNodes(graph, node_names, {});
114 }
115
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names,const std::vector<std::pair<string,Tensor>> & inputs) const116 std::vector<Tensor> GrapplerTest::EvaluateNodes(
117 const GraphDef& graph, const std::vector<string>& node_names,
118 const std::vector<std::pair<string, Tensor>>& inputs) const {
119 std::unique_ptr<tensorflow::Session> session(NewSession(options_));
120 TF_CHECK_OK(session->Create(graph));
121 RunOptions run_options;
122 std::vector<Tensor> output_tensors;
123 TF_CHECK_OK(session->Run(run_options, inputs, node_names, node_names,
124 &output_tensors, nullptr));
125 TF_CHECK_OK(session->Close());
126 return output_tensors;
127 }
128
EvaluateFetchNodes(const GrapplerItem & item) const129 std::vector<Tensor> GrapplerTest::EvaluateFetchNodes(
130 const GrapplerItem& item) const {
131 std::unique_ptr<tensorflow::Session> session(NewSession(options_));
132 TF_CHECK_OK(session->Create(item.graph));
133 RunOptions run_options;
134 if (!item.init_ops.empty()) {
135 std::vector<Tensor> dummy;
136 TF_CHECK_OK(
137 session->Run(run_options, {}, {}, item.init_ops, &dummy, nullptr));
138 }
139 std::vector<Tensor> output_tensors;
140 TF_CHECK_OK(session->Run(run_options, item.feed, item.fetch, {},
141 &output_tensors, nullptr));
142 TF_CHECK_OK(session->Close());
143 return output_tensors;
144 }
145
AddNode(const string & name,const string & op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,GraphDef * graph) const146 NodeDef* GrapplerTest::AddNode(
147 const string& name, const string& op, const std::vector<string>& inputs,
148 const std::vector<std::pair<string, AttrValue>>& attributes,
149 GraphDef* graph) const {
150 NodeDef* node = graph->add_node();
151 node->set_name(name);
152 node->set_op(op);
153 for (const string& input : inputs) {
154 node->add_input(input);
155 }
156 for (auto attr : attributes) {
157 (*node->mutable_attr())[attr.first] = attr.second;
158 }
159 return node;
160 }
161
CompareGraphs(GraphDef want,GraphDef got) const162 void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
163 CompareGraphNodes(want.mutable_node(), got.mutable_node());
164 }
165
CompareFunctions(FunctionDef want,FunctionDef got) const166 void GrapplerTest::CompareFunctions(FunctionDef want, FunctionDef got) const {
167 CompareGraphNodes(want.mutable_node_def(), got.mutable_node_def());
168 }
169
CompareNodes(const NodeDef & want,const NodeDef & got) const170 void GrapplerTest::CompareNodes(const NodeDef& want, const NodeDef& got) const {
171 EXPECT_EQ(want.name(), got.name());
172 EXPECT_EQ(want.op(), got.op());
173
174 std::vector<string> want_inputs(want.input().begin(), want.input().end());
175 std::vector<string> got_inputs(got.input().begin(), got.input().end());
176 EXPECT_EQ(want_inputs, got_inputs);
177
178 const auto attr_name = [](const std::pair<const string, AttrValue>& attr) {
179 return attr.first;
180 };
181
182 std::vector<string> want_attrs;
183 std::vector<string> got_attrs;
184 absl::c_transform(want.attr(), std::back_inserter(want_attrs), attr_name);
185 absl::c_transform(got.attr(), std::back_inserter(got_attrs), attr_name);
186 absl::c_sort(want_attrs);
187 absl::c_sort(got_attrs);
188 EXPECT_EQ(want_attrs, got_attrs);
189
190 for (const string& attr : want_attrs) {
191 EXPECT_TRUE(AreAttrValuesEqual(want.attr().at(attr), got.attr().at(attr)));
192 }
193 }
194
IsNodesDirectlyConnected(const NodeMap & node_map,const string & src,const string & dst,int position)195 bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
196 const string& src,
197 const string& dst, int position) {
198 const NodeDef* src_node = node_map.GetNode(src);
199 const NodeDef* dst_node = node_map.GetNode(dst);
200 EXPECT_TRUE(src_node != nullptr) << src << " node not found";
201 EXPECT_TRUE(dst_node != nullptr) << dst << " node not found";
202 return src_node && dst_node && dst_node->input(position) == src_node->name();
203 }
204
CountOpNodes(const GraphDef & graph,const string & op)205 int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) {
206 return std::count_if(graph.node().begin(), graph.node().end(),
207 [&op](const NodeDef& node) { return node.op() == op; });
208 }
209
210 } // namespace grappler
211 } // namespace tensorflow
212