• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/noop_elimination.h"
17 #include <tuple>
18 #include "tensorflow/core/framework/attr_value_util.h"
19 #include "tensorflow/core/grappler/grappler_item.h"
20 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace {
27 
GetCommonAttributes()28 std::vector<std::pair<string, AttrValue>> GetCommonAttributes() {
29   AttrValue shapes_attr, types_attr;
30   SetAttrValue("output_shapes", &shapes_attr);
31   SetAttrValue("output_types", &types_attr);
32   std::vector<std::pair<string, AttrValue>> commonAttributes = {
33       {"output_shapes", shapes_attr}, {"output_types", types_attr}};
34 
35   return commonAttributes;
36 }
37 
MakeUnaryNode(StringPiece node_type,int count,string input_node,MutableGraphView * graph)38 NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node,
39                        MutableGraphView *graph) {
40   NodeDef *node_count = graph_utils::AddScalarConstNode<int64>(count, graph);
41   return graph_utils::AddNode("", node_type,
42                               {std::move(input_node), node_count->name()},
43                               GetCommonAttributes(), graph);
44 }
45 
MakeUnaryNonConstNode(StringPiece node_type,string input_node,MutableGraphView * graph)46 NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node,
47                                MutableGraphView *graph) {
48   NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph);
49   return graph_utils::AddNode("", node_type,
50                               {std::move(input_node), node_count->name()},
51                               GetCommonAttributes(), graph);
52 }
53 
MakeCacheNode(string input_node,MutableGraphView * graph)54 NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) {
55   NodeDef *node_filename =
56       graph_utils::AddScalarConstNode<StringPiece>("", graph);
57   return graph_utils::AddNode("", "CacheDataset",
58                               {std::move(input_node), node_filename->name()},
59                               GetCommonAttributes(), graph);
60 }
61 
MakeRangeNode(MutableGraphView * graph)62 NodeDef *MakeRangeNode(MutableGraphView *graph) {
63   auto *start_node = graph_utils::AddScalarConstNode<int64>(0, graph);
64   auto *stop_node = graph_utils::AddScalarConstNode<int64>(10, graph);
65   auto *step_node = graph_utils::AddScalarConstNode<int64>(1, graph);
66 
67   std::vector<string> range_inputs = {start_node->name(), stop_node->name(),
68                                       step_node->name()};
69 
70   return graph_utils::AddNode("", "RangeDataset", range_inputs,
71                               GetCommonAttributes(), graph);
72 }
73 
74 struct NoOpLastEliminationTest
75     : ::testing::TestWithParam<std::tuple<string, int, bool>> {};
76 
77 // This test checks whether the no-op elimination correctly handles
78 // transformations at the end of the pipeline.
TEST_P(NoOpLastEliminationTest,EliminateLastNoOpNode)79 TEST_P(NoOpLastEliminationTest, EliminateLastNoOpNode) {
80   GrapplerItem item;
81   MutableGraphView graph(&item.graph);
82 
83   const string &node_type = std::get<0>(GetParam());
84   const int node_count = std::get<1>(GetParam());
85   const bool should_keep_node = std::get<2>(GetParam());
86 
87   NodeDef *range_node = MakeRangeNode(&graph);
88 
89   NodeDef *node =
90       MakeUnaryNode(node_type, node_count, range_node->name(), &graph);
91 
92   NoOpElimination optimizer;
93   GraphDef output;
94   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
95 
96   EXPECT_EQ(graph_utils::ContainsGraphNodeWithName(node->name(), output),
97             should_keep_node);
98 }
99 
100 INSTANTIATE_TEST_CASE_P(
101     BasicRemovalTest, NoOpLastEliminationTest,
102     ::testing::Values(std::make_tuple("TakeDataset", -3, false),
103                       std::make_tuple("TakeDataset", -1, false),
104                       std::make_tuple("TakeDataset", 0, true),
105                       std::make_tuple("TakeDataset", 3, true),
106                       std::make_tuple("SkipDataset", -1, true),
107                       std::make_tuple("SkipDataset", 0, false),
108                       std::make_tuple("SkipDataset", 3, true),
109                       std::make_tuple("PrefetchDataset", 0, false),
110                       std::make_tuple("PrefetchDataset", 1, true),
111                       std::make_tuple("RepeatDataset", 1, false),
112                       std::make_tuple("RepeatDataset", 2, true)));
113 
114 struct NoOpMiddleEliminationTest
115     : ::testing::TestWithParam<std::tuple<string, int, bool>> {};
116 
117 // This test checks whether the no-op elimination correctly handles
118 // transformations int the middle of the pipeline.
TEST_P(NoOpMiddleEliminationTest,EliminateMiddleNoOpNode)119 TEST_P(NoOpMiddleEliminationTest, EliminateMiddleNoOpNode) {
120   GrapplerItem item;
121   MutableGraphView graph(&item.graph);
122 
123   const string &node_type = std::get<0>(GetParam());
124   const int node_count = std::get<1>(GetParam());
125   const bool should_keep_node = std::get<2>(GetParam());
126 
127   NodeDef *range_node = MakeRangeNode(&graph);
128 
129   NodeDef *node =
130       MakeUnaryNode(node_type, node_count, range_node->name(), &graph);
131 
132   NodeDef *cache_node = MakeCacheNode(node->name(), &graph);
133   NoOpElimination optimizer;
134   GraphDef output;
135   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
136 
137   EXPECT_EQ(graph_utils::ContainsGraphNodeWithName(node->name(), output),
138             should_keep_node);
139   EXPECT_TRUE(
140       graph_utils::ContainsGraphNodeWithName(cache_node->name(), output));
141 
142   NodeDef cache_node_out = output.node(
143       graph_utils::FindGraphNodeWithName(cache_node->name(), output));
144 
145   EXPECT_EQ(cache_node_out.input_size(), 2);
146   auto last_node_input = (should_keep_node ? node : range_node)->name();
147   EXPECT_EQ(cache_node_out.input(0), last_node_input);
148 }
149 
150 INSTANTIATE_TEST_CASE_P(
151     BasicRemovalTest, NoOpMiddleEliminationTest,
152     ::testing::Values(std::make_tuple("TakeDataset", -1, false),
153                       std::make_tuple("TakeDataset", -3, false),
154                       std::make_tuple("TakeDataset", 0, true),
155                       std::make_tuple("TakeDataset", 3, true),
156                       std::make_tuple("SkipDataset", -1, true),
157                       std::make_tuple("SkipDataset", 0, false),
158                       std::make_tuple("SkipDataset", 3, true),
159                       std::make_tuple("PrefetchDataset", 0, false),
160                       std::make_tuple("PrefetchDataset", 1, true),
161                       std::make_tuple("RepeatDataset", 1, false),
162                       std::make_tuple("RepeatDataset", 2, true)));
163 
164 using NodesTypes = std::tuple<std::pair<string, int>, std::pair<string, int>>;
165 struct NoOpMultipleEliminationTest : ::testing::TestWithParam<NodesTypes> {};
166 
167 // This test checks whether the no-op elimination correctly removes
168 // multiple noop nodes.
TEST_P(NoOpMultipleEliminationTest,EliminateMultipleNoOpNode)169 TEST_P(NoOpMultipleEliminationTest, EliminateMultipleNoOpNode) {
170   GrapplerItem item;
171   MutableGraphView graph(&item.graph);
172 
173   static_assert(std::tuple_size<NodesTypes>::value == 2,
174                 "Make sure to include everything in the test");
175   const std::vector<std::pair<string, int>> noop_nodes = {
176       std::get<0>(GetParam()), std::get<1>(GetParam())};
177 
178   NodeDef *range_node = MakeRangeNode(&graph);
179 
180   NodeDef *previous = range_node;
181   std::vector<string> nodes_to_remove;
182   nodes_to_remove.reserve(noop_nodes.size());
183 
184   for (const auto &noop_node : noop_nodes) {
185     NodeDef *node = MakeUnaryNode(noop_node.first, noop_node.second,
186                                   previous->name(), &graph);
187     nodes_to_remove.push_back(node->name());
188     previous = node;
189   }
190 
191   NodeDef *cache_node = MakeCacheNode(previous->name(), &graph);
192   NoOpElimination optimizer;
193   GraphDef output;
194   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
195 
196   for (const auto &noop_node_name : nodes_to_remove)
197     EXPECT_FALSE(
198         graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
199 
200   EXPECT_TRUE(
201       graph_utils::ContainsGraphNodeWithName(cache_node->name(), output));
202 
203   NodeDef cache_node_out = output.node(
204       graph_utils::FindGraphNodeWithName(cache_node->name(), output));
205 
206   EXPECT_EQ(cache_node_out.input_size(), 2);
207   EXPECT_EQ(cache_node_out.input(0), range_node->name());
208 }
209 
210 const auto *const kTakeNode = new std::pair<string, int>{"TakeDataset", -1};
211 const auto *const kSkipNode = new std::pair<string, int>{"SkipDataset", 0};
212 const auto *const kRepeatNode = new std::pair<string, int>{"RepeatDataset", 1};
213 const auto *const kPrefetchNode =
214     new std::pair<string, int>{"PrefetchDataset", 0};
215 
216 INSTANTIATE_TEST_CASE_P(
217     BasicRemovalTest, NoOpMultipleEliminationTest,
218     ::testing::Combine(::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode,
219                                          *kPrefetchNode),
220                        ::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode,
221                                          *kPrefetchNode)));
222 
223 struct NoOpPlaceholdersTest
224     : ::testing::TestWithParam<std::tuple<string, string>> {};
225 
TEST_P(NoOpPlaceholdersTest,NonConstNoOpNode)226 TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) {
227   GrapplerItem item;
228   MutableGraphView graph(&item.graph);
229 
230   static_assert(std::tuple_size<NodesTypes>::value == 2,
231                 "Make sure to include everything in the test");
232   const std::vector<string> noop_nodes = {std::get<0>(GetParam()),
233                                           std::get<1>(GetParam())};
234   NodeDef *range_node = MakeRangeNode(&graph);
235   std::vector<string> nodes_to_keep;
236   nodes_to_keep.reserve(noop_nodes.size());
237   NodeDef *previous = range_node;
238 
239   for (const auto &noop_node : noop_nodes) {
240     NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph);
241     nodes_to_keep.push_back(node->name());
242     previous = node;
243   }
244 
245   NoOpElimination optimizer;
246   GraphDef output;
247   TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
248   for (const auto &noop_node_name : nodes_to_keep)
249     EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
250 }
251 
252 INSTANTIATE_TEST_CASE_P(
253     DoNotRemovePlaceholders, NoOpPlaceholdersTest,
254     ::testing::Combine(::testing::Values("TakeDataset", "SkipDataset",
255                                          "RepeatDataset", "PrefetchDataset"),
256                        ::testing::Values("TakeDataset", "SkipDataset",
257                                          "RepeatDataset", "PrefetchDataset")));
258 
259 }  // namespace
260 }  // namespace grappler
261 }  // namespace tensorflow
262