1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
17
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/function_testlib.h"
20 #include "tensorflow/core/grappler/costs/graph_properties.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/platform/test.h"
23 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
24
25 namespace tensorflow {
26 namespace grappler {
27 namespace {
28
29 using ::tensorflow::test::function::GDef;
30 using ::tensorflow::test::function::NDef;
31
32 class GraphOptimizerStageTest : public ::testing::Test {};
33
// Result type plugged into GraphOptimizerStage<> by FakeOptimizerStage.
// The fake stage never produces a result, so this type carries no data.
struct FakeResult {
};
35
36 // NoOp optimizer stage that supports all the node types and does nothing
37 class FakeOptimizerStage : public GraphOptimizerStage<FakeResult> {
38 public:
FakeOptimizerStage(const string & optimizer_name,const string & stage_name,const GraphOptimizerContext & ctx)39 explicit FakeOptimizerStage(const string& optimizer_name,
40 const string& stage_name,
41 const GraphOptimizerContext& ctx)
42 : GraphOptimizerStage(optimizer_name, stage_name, ctx) {}
43 ~FakeOptimizerStage() override = default;
44
IsSupported(const NodeDef * node) const45 bool IsSupported(const NodeDef* node) const override { return true; }
TrySimplify(NodeDef * node,FakeResult * result)46 Status TrySimplify(NodeDef* node, FakeResult* result) override {
47 return Status::OK();
48 }
49 };
50
TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScopeInRoot) {
  // A node name without any '/' separator parses to an empty scope.
  const auto parsed = ParseNodeScopeAndName("Add");
  EXPECT_EQ("", parsed.scope);
  EXPECT_EQ("Add", parsed.name);
}
56
TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScopeInScope) {
  // Everything before the last '/' is the scope; the remainder is the name.
  const auto parsed = ParseNodeScopeAndName("a/b/c/Add");
  EXPECT_EQ("a/b/c", parsed.scope);
  EXPECT_EQ("Add", parsed.name);
}
62
TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
  // This test passes null for every context pointer; only the optimizer and
  // stage names feed into the expected names below.
  GraphOptimizerContext ctx(/*nodes_to_preserve=*/nullptr,
                            /*optimized_graph=*/nullptr,
                            /*graph_properties=*/nullptr,
                            /*node_map=*/nullptr,
                            /*feed_nodes=*/nullptr,
                            /*opt_level=*/RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);

  const auto add = ParseNodeScopeAndName("a/b/c/Add");

  // Plain form: <scope>/<optimizer>/<stage>_<name>.
  EXPECT_EQ(stage.OptimizedNodeName(add), "a/b/c/my_opt/my_stg_Add");

  // Additional node names are appended after the node's own name.
  const std::vector<string> extra_names = {"Mul", "Sqrt"};
  EXPECT_EQ(stage.OptimizedNodeName(add, extra_names),
            "a/b/c/my_opt/my_stg_Add_Mul_Sqrt");

  // A rewrite rule is placed between the stage name and the node name.
  const string rewrite_rule = "my_rewrite";
  EXPECT_EQ(stage.OptimizedNodeName(add, rewrite_rule),
            "a/b/c/my_opt/my_stg_my_rewrite_Add");
}
84
TEST_F(GraphOptimizerStageTest, UniqueOptimizedNodeName) {
  // The plain optimized names are already taken in this graph, so the stage
  // is expected to fall back to "_unique<N>" suffixed names.
  GraphDef graph = GDef(
      {NDef("a/b/c/A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_my_rewrite_A", "NotImportant", {})},
      /*funcs=*/{});
  NodeMap node_map(&graph);

  GraphOptimizerContext ctx(/*nodes_to_preserve=*/nullptr,
                            /*optimized_graph=*/nullptr,
                            /*graph_properties=*/nullptr,
                            /*node_map=*/&node_map,
                            /*feed_nodes=*/nullptr,
                            /*opt_level=*/RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);

  const auto node_a = ParseNodeScopeAndName("a/b/c/A");

  // NOTE: the expected suffixes depend on call order; keep these in sequence.
  EXPECT_EQ(stage.UniqueOptimizedNodeName(node_a),
            "a/b/c/my_opt/my_stg_A_unique0");

  const string rewrite_rule = "my_rewrite";
  EXPECT_EQ(stage.UniqueOptimizedNodeName(node_a, rewrite_rule),
            "a/b/c/my_opt/my_stg_my_rewrite_A_unique1");
}
111
TEST_F(GraphOptimizerStageTest, UniqueOptimizedNodeNameWithUsedNodeNames) {
  // Besides the plain optimized names, some "_unique<N>" names are also
  // already present, so the stage must skip past them.
  GraphDef graph = GDef(
      {NDef("a/b/c/A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_A_unique0", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_my_rewrite_A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_my_rewrite_A_unique1", "NotImportant", {})},
      /*funcs=*/{});
  NodeMap node_map(&graph);

  GraphOptimizerContext ctx(/*nodes_to_preserve=*/nullptr,
                            /*optimized_graph=*/nullptr,
                            /*graph_properties=*/nullptr,
                            /*node_map=*/&node_map,
                            /*feed_nodes=*/nullptr,
                            /*opt_level=*/RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);

  const auto node_a = ParseNodeScopeAndName("a/b/c/A");

  // NOTE: the expected suffixes depend on call order; keep these in sequence.
  EXPECT_EQ(stage.UniqueOptimizedNodeName(node_a),
            "a/b/c/my_opt/my_stg_A_unique1");

  const string rewrite_rule = "my_rewrite";
  EXPECT_EQ(stage.UniqueOptimizedNodeName(node_a, rewrite_rule),
            "a/b/c/my_opt/my_stg_my_rewrite_A_unique2");
}
140
TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
  // Graph under test: Add(a, b) over two 2x2 float variables.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto var_a = ops::Variable(root.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto var_b = ops::Variable(root.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto sum = ops::Add(root.WithOpName("Add"), var_a, var_b);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  NodeMap node_map(&item.graph);

  GraphOptimizerContext ctx(/*nodes_to_preserve=*/nullptr,
                            /*optimized_graph=*/&item.graph,
                            /*graph_properties=*/&properties,
                            /*node_map=*/&node_map,
                            /*feed_nodes=*/nullptr,
                            /*opt_level=*/RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);

  // The Add node can be looked up by name and keeps its two inputs.
  NodeDef* add_node = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add", &add_node));
  ASSERT_EQ(add_node->input_size(), 2);
  EXPECT_EQ(add_node->input(0), "a");
  EXPECT_EQ(add_node->input(1), "b");

  // Add produces a plain float tensor...
  const OpInfo::TensorProperties* add_properties = nullptr;
  TF_CHECK_OK(stage.GetTensorProperties("Add", &add_properties));
  EXPECT_EQ(add_properties->dtype(), DT_FLOAT);

  // ...while each variable's first output is a float reference.
  for (const char* variable_output : {"a:0", "b:0"}) {
    const OpInfo::TensorProperties* variable_properties = nullptr;
    TF_CHECK_OK(
        stage.GetTensorProperties(variable_output, &variable_properties));
    EXPECT_EQ(variable_properties->dtype(), DT_FLOAT_REF);
  }
}
182
TEST_F(GraphOptimizerStageTest, AddNodes) {
  // Graph under test: Add(a, b) over two 2x2 float variables.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto add = ops::Add(s.WithOpName("Add"), a, b);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/false));

  NodeMap node_map(&item.graph);

  GraphOptimizerContext ctx(/*nodes_to_preserve=*/nullptr,
                            /*optimized_graph=*/&item.graph,
                            /*graph_properties=*/&properties,
                            /*node_map=*/&node_map,
                            /*feed_nodes=*/nullptr,
                            /*opt_level=*/RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);

  NodeDef* add_node = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add", &add_node));

  // Add a new copy node: it must be a distinct node carrying the new name and
  // the original's op and inputs.
  NodeDef* add_node_copy = stage.AddCopyNode("Add_1", add_node);
  EXPECT_NE(add_node_copy, add_node);
  EXPECT_EQ(add_node_copy->name(), "Add_1");
  EXPECT_EQ(add_node_copy->op(), "Add");
  // Guard the indexed accesses below on the *copy*, which is what they read.
  ASSERT_EQ(add_node_copy->input_size(), 2);
  EXPECT_EQ(add_node_copy->input(0), "a");
  EXPECT_EQ(add_node_copy->input(1), "b");

  // It must be available for by-name lookup.
  NodeDef* add_node_copy_by_name = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add_1", &add_node_copy_by_name));
  EXPECT_EQ(add_node_copy, add_node_copy_by_name);

  // Add a new empty node.
  NodeDef* empty_node = stage.AddEmptyNode("Add_2");
  EXPECT_EQ(empty_node->name(), "Add_2");
  EXPECT_EQ(empty_node->input_size(), 0);

  // It must be available for by-name lookup.
  NodeDef* empty_node_by_name = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add_2", &empty_node_by_name));
  EXPECT_EQ(empty_node, empty_node_by_name);

  // AddEmptyNode adds a unique suffix when the requested name already exists.
  NodeDef* unique_empty_node = stage.AddEmptyNode("Add_2");
  EXPECT_EQ(unique_empty_node->name(), "Add_2_0");
}
236
237 } // namespace
238 } // end namespace grappler
239 } // end namespace tensorflow
240