• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
17 
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/function_testlib.h"
20 #include "tensorflow/core/grappler/costs/graph_properties.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/platform/test.h"
23 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 namespace {
28 
29 using ::tensorflow::test::function::GDef;
30 using ::tensorflow::test::function::NDef;
31 
32 class GraphOptimizerStageTest : public ::testing::Test {};
33 
34 struct FakeResult {};
35 
36 // NoOp optimizer stage that supports all the node types and does nothing
37 class FakeOptimizerStage : public GraphOptimizerStage<FakeResult> {
38  public:
FakeOptimizerStage(const string & optimizer_name,const string & stage_name,const GraphOptimizerContext & ctx)39   explicit FakeOptimizerStage(const string& optimizer_name,
40                               const string& stage_name,
41                               const GraphOptimizerContext& ctx)
42       : GraphOptimizerStage(optimizer_name, stage_name, ctx) {}
43   ~FakeOptimizerStage() override = default;
44 
IsSupported(const NodeDef * node) const45   bool IsSupported(const NodeDef* node) const override { return true; }
TrySimplify(NodeDef * node,FakeResult * result)46   Status TrySimplify(NodeDef* node, FakeResult* result) override {
47     return Status::OK();
48   }
49 };
50 
// A name with no '/' separator parses to an empty scope and the name itself.
TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScopeInRoot) {
  const auto parsed = ParseNodeScopeAndName("Add");
  EXPECT_EQ("", parsed.scope);
  EXPECT_EQ("Add", parsed.name);
}
56 
// A scoped name is split into its scope prefix ("a/b/c") and the trailing node
// name ("Add").
TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScopeInScope) {
  const auto parsed = ParseNodeScopeAndName("a/b/c/Add");
  EXPECT_EQ("a/b/c", parsed.scope);
  EXPECT_EQ("Add", parsed.name);
}
62 
// OptimizedNodeName places the new node under "<scope>/<optimizer>/" and
// prefixes it with the stage name. It can optionally append extra node names
// or insert a rewrite-rule tag.
TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
  // Every context member is left null. This test expects name construction
  // not to need any of them.
  GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                            /*optimized_graph*/ nullptr,
                            /*graph_properties*/ nullptr,
                            /*node_map*/ nullptr,
                            /*feed_nodes*/ nullptr,
                            /*opt_level*/ RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);
  const auto parsed = ParseNodeScopeAndName("a/b/c/Add");

  // No rewrite rule: the plain name, then the name with extra node names
  // appended.
  EXPECT_EQ("a/b/c/my_opt/my_stg_Add", stage.OptimizedNodeName(parsed));
  const std::vector<string> extra_names = {"Mul", "Sqrt"};
  EXPECT_EQ("a/b/c/my_opt/my_stg_Add_Mul_Sqrt",
            stage.OptimizedNodeName(parsed, extra_names));

  // With a rewrite rule, the rule tag sits between the stage name and the
  // node name.
  const string rule = "my_rewrite";
  EXPECT_EQ("a/b/c/my_opt/my_stg_my_rewrite_Add",
            stage.OptimizedNodeName(parsed, rule));
}
84 
// UniqueOptimizedNodeName must not collide with names already present in the
// NodeMap. When the plain optimized name is taken, a "_unique<N>" suffix is
// appended.
TEST_F(GraphOptimizerStageTest, UniqueOptimizedNodeName) {
  // Both plain optimized names (with and without the rewrite rule) are taken.
  GraphDef graph = GDef(
      {NDef("a/b/c/A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_my_rewrite_A", "NotImportant", {})},
      /*funcs=*/{});
  NodeMap existing_nodes(&graph);

  // Only the NodeMap is supplied; the other context members stay null.
  GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                            /*optimized_graph*/ nullptr,
                            /*graph_properties*/ nullptr,
                            /*node_map*/ &existing_nodes,
                            /*feed_nodes*/ nullptr,
                            /*opt_level*/ RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);
  const auto parsed = ParseNodeScopeAndName("a/b/c/A");

  EXPECT_EQ("a/b/c/my_opt/my_stg_A_unique0",
            stage.UniqueOptimizedNodeName(parsed));

  // NOTE(review): the expected suffix is "_unique1" even though
  // "..._my_rewrite_A_unique0" is not in the graph. This suggests the suffix
  // counter is shared across calls on the same stage. Confirm against the
  // implementation.
  const string rule = "my_rewrite";
  EXPECT_EQ("a/b/c/my_opt/my_stg_my_rewrite_A_unique1",
            stage.UniqueOptimizedNodeName(parsed, rule));
}
111 
// When some "_unique<N>" candidates are also taken, UniqueOptimizedNodeName
// must skip past them.
TEST_F(GraphOptimizerStageTest, UniqueOptimizedNodeNameWithUsedNodeNames) {
  // Taken names: both plain optimized names, "..._A_unique0" and
  // "..._my_rewrite_A_unique1".
  GraphDef graph = GDef(
      {NDef("a/b/c/A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_A_unique0", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_my_rewrite_A", "NotImportant", {}),
       NDef("a/b/c/my_opt/my_stg_my_rewrite_A_unique1", "NotImportant", {})},
      /*funcs=*/{});
  NodeMap existing_nodes(&graph);

  // Only the NodeMap is supplied; the other context members stay null.
  GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                            /*optimized_graph*/ nullptr,
                            /*graph_properties*/ nullptr,
                            /*node_map*/ &existing_nodes,
                            /*feed_nodes*/ nullptr,
                            /*opt_level*/ RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);
  const auto parsed = ParseNodeScopeAndName("a/b/c/A");

  // "_unique0" is taken, so the next index is expected.
  EXPECT_EQ("a/b/c/my_opt/my_stg_A_unique1",
            stage.UniqueOptimizedNodeName(parsed));

  // NOTE(review): "_unique2" is expected here although "..._my_rewrite_A_unique0"
  // is free. This is consistent with a suffix counter shared across calls on
  // the same stage; confirm.
  const string rule = "my_rewrite";
  EXPECT_EQ("a/b/c/my_opt/my_stg_my_rewrite_A_unique2",
            stage.UniqueOptimizedNodeName(parsed, rule));
}
140 
// GetInputNode looks nodes up by name. GetTensorProperties returns the output
// properties inferred by GraphProperties.
TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
  // Graph under test: Add(a, b) over two 2x2 float variables.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  auto var_a = ops::Variable(root.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto var_b = ops::Variable(root.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto sum = ops::Add(root.WithOpName("Add"), var_a, var_b);

  GrapplerItem item;
  TF_CHECK_OK(root.ToGraphDef(&item.graph));

  GraphProperties graph_properties(item);
  TF_CHECK_OK(graph_properties.InferStatically(/*assume_valid_feeds*/ false));

  NodeMap node_map(&item.graph);
  GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                            /*optimized_graph*/ &item.graph,
                            /*graph_properties*/ &graph_properties,
                            /*node_map*/ &node_map,
                            /*feed_nodes*/ nullptr,
                            /*opt_level*/ RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);

  // Node lookup: "Add" must be found with its two inputs in order.
  NodeDef* add_node = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add", &add_node));
  ASSERT_EQ(2, add_node->input_size());
  EXPECT_EQ("a", add_node->input(0));
  EXPECT_EQ("b", add_node->input(1));

  // Property lookup: Add yields a plain float, while the variable outputs are
  // reference-typed floats.
  const OpInfo::TensorProperties* add_props = nullptr;
  TF_CHECK_OK(stage.GetTensorProperties("Add", &add_props));
  EXPECT_EQ(DT_FLOAT, add_props->dtype());

  const OpInfo::TensorProperties* a_props = nullptr;
  TF_CHECK_OK(stage.GetTensorProperties("a:0", &a_props));
  EXPECT_EQ(DT_FLOAT_REF, a_props->dtype());

  const OpInfo::TensorProperties* b_props = nullptr;
  TF_CHECK_OK(stage.GetTensorProperties("b:0", &b_props));
  EXPECT_EQ(DT_FLOAT_REF, b_props->dtype());
}
182 
// Nodes created through AddCopyNode / AddEmptyNode must be reachable by name
// through the stage right away. AddEmptyNode must also de-duplicate a name
// that is already taken.
TEST_F(GraphOptimizerStageTest, AddNodes) {
  // Graph under test: Add(a, b) over two 2x2 float variables.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto add = ops::Add(s.WithOpName("Add"), a, b);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds*/ false));

  NodeMap node_map(&item.graph);

  GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                            /*optimized_graph*/ &item.graph,
                            /*graph_properties*/ &properties,
                            /*node_map*/ &node_map,
                            /*feed_nodes*/ nullptr,
                            /*opt_level*/ RewriterConfig::ON);
  FakeOptimizerStage stage("my_opt", "my_stg", ctx);

  NodeDef* add_node = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add", &add_node));

  // Add a new copy node.
  NodeDef* add_node_copy = stage.AddCopyNode("Add_1", add_node);
  ASSERT_NE(add_node_copy, nullptr);
  EXPECT_EQ(add_node_copy->name(), "Add_1");
  EXPECT_EQ(add_node_copy->op(), "Add");
  // Check the copy's own input count before indexing into its inputs.
  ASSERT_EQ(add_node_copy->input_size(), 2);
  EXPECT_EQ(add_node_copy->input(0), "a");
  EXPECT_EQ(add_node_copy->input(1), "b");

  // The copy must be available for by-name lookup.
  NodeDef* add_node_copy_by_name = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add_1", &add_node_copy_by_name));
  EXPECT_EQ(add_node_copy, add_node_copy_by_name);

  // Add a new empty node.
  NodeDef* empty_node = stage.AddEmptyNode("Add_2");
  ASSERT_NE(empty_node, nullptr);
  EXPECT_EQ(empty_node->name(), "Add_2");
  EXPECT_EQ(empty_node->input_size(), 0);

  // The empty node must be available for by-name lookup.
  NodeDef* empty_node_by_name = nullptr;
  TF_CHECK_OK(stage.GetInputNode("Add_2", &empty_node_by_name));
  EXPECT_EQ(empty_node, empty_node_by_name);

  // AddEmptyNode adds a unique suffix when the requested name already exists.
  NodeDef* unique_empty_node = stage.AddEmptyNode("Add_2");
  ASSERT_NE(unique_empty_node, nullptr);
  EXPECT_EQ(unique_empty_node->name(), "Add_2_0");
}
236 
237 }  // namespace
238 }  // end namespace grappler
239 }  // end namespace tensorflow
240