• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/grappler_test.h"
17 
18 #include <memory>
19 
20 #include "absl/algorithm/container.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/grappler/utils.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
25 #include "tensorflow/core/public/session.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
30 namespace {
CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef> * want,protobuf::RepeatedPtrField<NodeDef> * got)31 void CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef>* want,
32                        protobuf::RepeatedPtrField<NodeDef>* got) {
33   auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool {
34     return n1.name() < n2.name();
35   };
36 
37   std::sort(want->begin(), want->end(), comparator);
38   std::sort(got->begin(), got->end(), comparator);
39 
40   ASSERT_EQ(want->size(), got->size());
41 
42   for (int i = 0; i < want->size(); ++i) {
43     NodeDef& want_node = (*want)[i];
44     NodeDef& got_node = (*got)[i];
45 
46     EXPECT_EQ(want_node.op(), got_node.op());
47     EXPECT_EQ(want_node.name(), got_node.name());
48     EXPECT_EQ(want_node.device(), got_node.device());
49     ASSERT_EQ(want_node.input_size(), got_node.input_size())
50         << "want_node =\n"
51         << want_node.DebugString() << "\ngot_node =\n"
52         << got_node.DebugString();
53 
54     // Order of control dependencies doesn't matter, so we sort them first.
55     const auto is_control = [](const string& input) -> bool {
56       return ParseTensorName(input).index() < 0;
57     };
58 
59     auto want_inputs = want_node.mutable_input();
60     auto got_inputs = got_node.mutable_input();
61     std::sort(absl::c_find_if(*want_inputs, is_control), want_inputs->end());
62     std::sort(absl::c_find_if(*got_inputs, is_control), got_inputs->end());
63 
64     for (int j = 0; j < want_node.input_size(); ++j) {
65       const TensorId want_tensor = ParseTensorName(want_node.input(j));
66       const TensorId got_tensor = ParseTensorName(got_node.input(j));
67       EXPECT_EQ(want_tensor.ToString(), got_tensor.ToString());
68     }
69   }
70 }
71 
SetAllOptimizers(RewriterConfig * cfg,RewriterConfig::Toggle value)72 void SetAllOptimizers(RewriterConfig* cfg, RewriterConfig::Toggle value) {
73   cfg->set_arithmetic_optimization(value);
74   cfg->set_auto_mixed_precision(value);
75   cfg->set_auto_mixed_precision_mkl(value);
76   cfg->set_common_subgraph_elimination(value);
77   cfg->set_constant_folding(value);
78   cfg->set_debug_stripper(value);
79   cfg->set_dependency_optimization(value);
80   cfg->set_function_optimization(value);
81   cfg->set_implementation_selector(value);
82   cfg->set_layout_optimizer(value);
83   cfg->set_loop_optimization(value);
84   cfg->set_pin_to_host_optimization(value);
85   cfg->set_remapping(value);
86   cfg->set_scoped_allocator_optimization(value);
87   cfg->set_shape_optimization(value);
88 }
89 }  // namespace
90 
GrapplerTest()91 GrapplerTest::GrapplerTest() {
92   // Turn off all the automatic optimizations to ensure that we run the graph
93   // exactly as it is given to us. This ensures that we can compare the
94   // results before and after manual optimization, without any of the
95   // automatic optimizations interfering in the comparison.
96   DisableAllOptimizers();
97 }
98 
DisableAllOptimizers()99 void GrapplerTest::DisableAllOptimizers() {
100   SetAllOptimizers(
101       options_.config.mutable_graph_options()->mutable_rewrite_options(),
102       RewriterConfig::OFF);
103 }
104 
EnableAllOptimizers()105 void GrapplerTest::EnableAllOptimizers() {
106   SetAllOptimizers(
107       options_.config.mutable_graph_options()->mutable_rewrite_options(),
108       RewriterConfig::ON);
109 }
110 
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names) const111 std::vector<Tensor> GrapplerTest::EvaluateNodes(
112     const GraphDef& graph, const std::vector<string>& node_names) const {
113   return EvaluateNodes(graph, node_names, {});
114 }
115 
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names,const std::vector<std::pair<string,Tensor>> & inputs) const116 std::vector<Tensor> GrapplerTest::EvaluateNodes(
117     const GraphDef& graph, const std::vector<string>& node_names,
118     const std::vector<std::pair<string, Tensor>>& inputs) const {
119   std::unique_ptr<tensorflow::Session> session(NewSession(options_));
120   TF_CHECK_OK(session->Create(graph));
121   RunOptions run_options;
122   std::vector<Tensor> output_tensors;
123   TF_CHECK_OK(session->Run(run_options, inputs, node_names, node_names,
124                            &output_tensors, nullptr));
125   TF_CHECK_OK(session->Close());
126   return output_tensors;
127 }
128 
EvaluateFetchNodes(const GrapplerItem & item) const129 std::vector<Tensor> GrapplerTest::EvaluateFetchNodes(
130     const GrapplerItem& item) const {
131   std::unique_ptr<tensorflow::Session> session(NewSession(options_));
132   TF_CHECK_OK(session->Create(item.graph));
133   RunOptions run_options;
134   if (!item.init_ops.empty()) {
135     std::vector<Tensor> dummy;
136     TF_CHECK_OK(
137         session->Run(run_options, {}, {}, item.init_ops, &dummy, nullptr));
138   }
139   std::vector<Tensor> output_tensors;
140   TF_CHECK_OK(session->Run(run_options, item.feed, item.fetch, {},
141                            &output_tensors, nullptr));
142   TF_CHECK_OK(session->Close());
143   return output_tensors;
144 }
145 
AddNode(const string & name,const string & op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,GraphDef * graph) const146 NodeDef* GrapplerTest::AddNode(
147     const string& name, const string& op, const std::vector<string>& inputs,
148     const std::vector<std::pair<string, AttrValue>>& attributes,
149     GraphDef* graph) const {
150   NodeDef* node = graph->add_node();
151   node->set_name(name);
152   node->set_op(op);
153   for (const string& input : inputs) {
154     node->add_input(input);
155   }
156   for (auto attr : attributes) {
157     (*node->mutable_attr())[attr.first] = attr.second;
158   }
159   return node;
160 }
161 
CompareGraphs(GraphDef want,GraphDef got) const162 void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
163   CompareGraphNodes(want.mutable_node(), got.mutable_node());
164 }
165 
CompareFunctions(FunctionDef want,FunctionDef got) const166 void GrapplerTest::CompareFunctions(FunctionDef want, FunctionDef got) const {
167   CompareGraphNodes(want.mutable_node_def(), got.mutable_node_def());
168 }
169 
CompareNodes(const NodeDef & want,const NodeDef & got) const170 void GrapplerTest::CompareNodes(const NodeDef& want, const NodeDef& got) const {
171   EXPECT_EQ(want.name(), got.name());
172   EXPECT_EQ(want.op(), got.op());
173 
174   std::vector<string> want_inputs(want.input().begin(), want.input().end());
175   std::vector<string> got_inputs(got.input().begin(), got.input().end());
176   EXPECT_EQ(want_inputs, got_inputs);
177 
178   const auto attr_name = [](const std::pair<const string, AttrValue>& attr) {
179     return attr.first;
180   };
181 
182   std::vector<string> want_attrs;
183   std::vector<string> got_attrs;
184   absl::c_transform(want.attr(), std::back_inserter(want_attrs), attr_name);
185   absl::c_transform(got.attr(), std::back_inserter(got_attrs), attr_name);
186   absl::c_sort(want_attrs);
187   absl::c_sort(got_attrs);
188   EXPECT_EQ(want_attrs, got_attrs);
189 
190   for (const string& attr : want_attrs) {
191     EXPECT_TRUE(AreAttrValuesEqual(want.attr().at(attr), got.attr().at(attr)));
192   }
193 }
194 
IsNodesDirectlyConnected(const NodeMap & node_map,const string & src,const string & dst,int position)195 bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
196                                             const string& src,
197                                             const string& dst, int position) {
198   const NodeDef* src_node = node_map.GetNode(src);
199   const NodeDef* dst_node = node_map.GetNode(dst);
200   EXPECT_TRUE(src_node != nullptr) << src << " node not found";
201   EXPECT_TRUE(dst_node != nullptr) << dst << " node not found";
202   return src_node && dst_node && dst_node->input(position) == src_node->name();
203 }
204 
CountOpNodes(const GraphDef & graph,const string & op)205 int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) {
206   return std::count_if(graph.node().begin(), graph.node().end(),
207                        [&op](const NodeDef& node) { return node.op() == op; });
208 }
209 
210 }  // namespace grappler
211 }  // namespace tensorflow
212