• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/grappler_test.h"
17 
18 #include <memory>
19 
20 #include "absl/algorithm/container.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/grappler/utils.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
25 #include "tensorflow/core/public/session.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
30 namespace {
CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef> * want,protobuf::RepeatedPtrField<NodeDef> * got)31 void CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef>* want,
32                        protobuf::RepeatedPtrField<NodeDef>* got) {
33   auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool {
34     return n1.name() < n2.name();
35   };
36 
37   std::sort(want->begin(), want->end(), comparator);
38   std::sort(got->begin(), got->end(), comparator);
39 
40   ASSERT_EQ(want->size(), got->size());
41 
42   for (int i = 0; i < want->size(); ++i) {
43     NodeDef& want_node = (*want)[i];
44     NodeDef& got_node = (*got)[i];
45 
46     EXPECT_EQ(want_node.op(), got_node.op());
47     EXPECT_EQ(want_node.name(), got_node.name());
48     EXPECT_EQ(want_node.device(), got_node.device());
49     ASSERT_EQ(want_node.input_size(), got_node.input_size());
50 
51     // Order of control dependencies doesn't matter, so we sort them first.
52     const auto is_control = [](const string& input) -> bool {
53       return ParseTensorName(input).index() < 0;
54     };
55 
56     auto want_inputs = want_node.mutable_input();
57     auto got_inputs = got_node.mutable_input();
58     std::sort(absl::c_find_if(*want_inputs, is_control), want_inputs->end());
59     std::sort(absl::c_find_if(*got_inputs, is_control), got_inputs->end());
60 
61     for (int j = 0; j < want_node.input_size(); ++j) {
62       const TensorId want_tensor = ParseTensorName(want_node.input(j));
63       const TensorId got_tensor = ParseTensorName(got_node.input(j));
64       EXPECT_EQ(want_tensor.ToString(), got_tensor.ToString());
65     }
66   }
67 }
68 }  // namespace
69 
GrapplerTest()70 GrapplerTest::GrapplerTest() {
71   // Turn off all the automatic optimizations to ensure that we run the graph
72   // exactly as it is given to us. This ensures that we can compare the results
73   // before and after manual optimization, without any of the automatic
74   // optimizations interfering in the comparison.
75   RewriterConfig* cfg =
76       options_.config.mutable_graph_options()->mutable_rewrite_options();
77   // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
78   // off.
79   cfg->set_arithmetic_optimization(RewriterConfig::OFF);
80   cfg->set_constant_folding(RewriterConfig::OFF);
81   cfg->set_debug_stripper(RewriterConfig::OFF);
82   cfg->set_dependency_optimization(RewriterConfig::OFF);
83   cfg->set_function_optimization(RewriterConfig::OFF);
84   cfg->set_implementation_selector(RewriterConfig::OFF);
85   cfg->set_layout_optimizer(RewriterConfig::OFF);
86   cfg->set_loop_optimization(RewriterConfig::OFF);
87   cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
88 }
89 
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names) const90 std::vector<Tensor> GrapplerTest::EvaluateNodes(
91     const GraphDef& graph, const std::vector<string>& node_names) const {
92   return EvaluateNodes(graph, node_names, {});
93 }
94 
EvaluateNodes(const GraphDef & graph,const std::vector<string> & node_names,const std::vector<std::pair<string,Tensor>> & inputs) const95 std::vector<Tensor> GrapplerTest::EvaluateNodes(
96     const GraphDef& graph, const std::vector<string>& node_names,
97     const std::vector<std::pair<string, Tensor>>& inputs) const {
98   std::unique_ptr<tensorflow::Session> session(NewSession(options_));
99   TF_CHECK_OK(session->Create(graph));
100   RunOptions run_options;
101   std::vector<Tensor> output_tensors;
102   TF_CHECK_OK(session->Run(run_options, inputs, node_names, node_names,
103                            &output_tensors, nullptr));
104   TF_CHECK_OK(session->Close());
105   return output_tensors;
106 }
107 
EvaluateFetchNodes(const GrapplerItem & item) const108 std::vector<Tensor> GrapplerTest::EvaluateFetchNodes(
109     const GrapplerItem& item) const {
110   std::unique_ptr<tensorflow::Session> session(NewSession(options_));
111   TF_CHECK_OK(session->Create(item.graph));
112   RunOptions run_options;
113   if (!item.init_ops.empty()) {
114     std::vector<Tensor> dummy;
115     TF_CHECK_OK(
116         session->Run(run_options, {}, {}, item.init_ops, &dummy, nullptr));
117   }
118   std::vector<Tensor> output_tensors;
119   TF_CHECK_OK(session->Run(run_options, item.feed, item.fetch, {},
120                            &output_tensors, nullptr));
121   TF_CHECK_OK(session->Close());
122   return output_tensors;
123 }
124 
AddNode(const string & name,const string & op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,GraphDef * graph) const125 NodeDef* GrapplerTest::AddNode(
126     const string& name, const string& op, const std::vector<string>& inputs,
127     const std::vector<std::pair<string, AttrValue>>& attributes,
128     GraphDef* graph) const {
129   NodeDef* node = graph->add_node();
130   node->set_name(name);
131   node->set_op(op);
132   for (const string& input : inputs) {
133     node->add_input(input);
134   }
135   for (auto attr : attributes) {
136     (*node->mutable_attr())[attr.first] = attr.second;
137   }
138   return node;
139 }
140 
CompareGraphs(GraphDef want,GraphDef got) const141 void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
142   CompareGraphNodes(want.mutable_node(), got.mutable_node());
143 }
144 
CompareFunctions(FunctionDef want,FunctionDef got) const145 void GrapplerTest::CompareFunctions(FunctionDef want, FunctionDef got) const {
146   CompareGraphNodes(want.mutable_node_def(), got.mutable_node_def());
147 }
148 
CompareNodes(const NodeDef & want,const NodeDef & got) const149 void GrapplerTest::CompareNodes(const NodeDef& want, const NodeDef& got) const {
150   EXPECT_EQ(want.name(), got.name());
151   EXPECT_EQ(want.op(), got.op());
152 
153   std::vector<string> want_inputs(want.input().begin(), want.input().end());
154   std::vector<string> got_inputs(got.input().begin(), got.input().end());
155   EXPECT_EQ(want_inputs, got_inputs);
156 
157   const auto attr_name = [](const std::pair<const string, AttrValue>& attr) {
158     return attr.first;
159   };
160 
161   std::vector<string> want_attrs;
162   std::vector<string> got_attrs;
163   absl::c_transform(want.attr(), std::back_inserter(want_attrs), attr_name);
164   absl::c_transform(got.attr(), std::back_inserter(got_attrs), attr_name);
165   absl::c_sort(want_attrs);
166   absl::c_sort(got_attrs);
167   EXPECT_EQ(want_attrs, got_attrs);
168 
169   for (const string& attr : want_attrs) {
170     EXPECT_TRUE(AreAttrValuesEqual(want.attr().at(attr), got.attr().at(attr)));
171   }
172 }
173 
IsNodesDirectlyConnected(const NodeMap & node_map,const string & src,const string & dst,int position)174 bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
175                                             const string& src,
176                                             const string& dst, int position) {
177   const NodeDef* src_node = node_map.GetNode(src);
178   const NodeDef* dst_node = node_map.GetNode(dst);
179   EXPECT_TRUE(src_node != nullptr) << src << " node not found";
180   EXPECT_TRUE(dst_node != nullptr) << dst << " node not found";
181   return src_node && dst_node && dst_node->input(position) == src_node->name();
182 }
183 
CountOpNodes(const GraphDef & graph,const string & op)184 int GrapplerTest::CountOpNodes(const GraphDef& graph, const string& op) {
185   return std::count_if(graph.node().begin(), graph.node().end(),
186                        [&op](const NodeDef& node) { return node.op() == op; });
187 }
188 
189 }  // namespace grappler
190 }  // namespace tensorflow
191