1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/noop_elimination.h"
17 #include <tuple>
18 #include "tensorflow/core/framework/attr_value_util.h"
19 #include "tensorflow/core/grappler/grappler_item.h"
20 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace tensorflow {
25 namespace grappler {
26 namespace {
27
GetCommonAttributes()28 std::vector<std::pair<string, AttrValue>> GetCommonAttributes() {
29 AttrValue shapes_attr, types_attr;
30 SetAttrValue("output_shapes", &shapes_attr);
31 SetAttrValue("output_types", &types_attr);
32 std::vector<std::pair<string, AttrValue>> commonAttributes = {
33 {"output_shapes", shapes_attr}, {"output_types", types_attr}};
34
35 return commonAttributes;
36 }
37
MakeUnaryNode(StringPiece node_type,int count,string input_node,MutableGraphView * graph)38 NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node,
39 MutableGraphView *graph) {
40 NodeDef *node_count = graph_utils::AddScalarConstNode<int64>(count, graph);
41 return graph_utils::AddNode("", node_type,
42 {std::move(input_node), node_count->name()},
43 GetCommonAttributes(), graph);
44 }
45
MakeUnaryNonConstNode(StringPiece node_type,string input_node,MutableGraphView * graph)46 NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node,
47 MutableGraphView *graph) {
48 NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph);
49 return graph_utils::AddNode("", node_type,
50 {std::move(input_node), node_count->name()},
51 GetCommonAttributes(), graph);
52 }
53
MakeCacheNode(string input_node,MutableGraphView * graph)54 NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) {
55 NodeDef *node_filename =
56 graph_utils::AddScalarConstNode<StringPiece>("", graph);
57 return graph_utils::AddNode("", "CacheDataset",
58 {std::move(input_node), node_filename->name()},
59 GetCommonAttributes(), graph);
60 }
61
MakeRangeNode(MutableGraphView * graph)62 NodeDef *MakeRangeNode(MutableGraphView *graph) {
63 auto *start_node = graph_utils::AddScalarConstNode<int64>(0, graph);
64 auto *stop_node = graph_utils::AddScalarConstNode<int64>(10, graph);
65 auto *step_node = graph_utils::AddScalarConstNode<int64>(1, graph);
66
67 std::vector<string> range_inputs = {start_node->name(), stop_node->name(),
68 step_node->name()};
69
70 return graph_utils::AddNode("", "RangeDataset", range_inputs,
71 GetCommonAttributes(), graph);
72 }
73
74 struct NoOpLastEliminationTest
75 : ::testing::TestWithParam<std::tuple<string, int, bool>> {};
76
77 // This test checks whether the no-op elimination correctly handles
78 // transformations at the end of the pipeline.
TEST_P(NoOpLastEliminationTest,EliminateLastNoOpNode)79 TEST_P(NoOpLastEliminationTest, EliminateLastNoOpNode) {
80 GrapplerItem item;
81 MutableGraphView graph(&item.graph);
82
83 const string &node_type = std::get<0>(GetParam());
84 const int node_count = std::get<1>(GetParam());
85 const bool should_keep_node = std::get<2>(GetParam());
86
87 NodeDef *range_node = MakeRangeNode(&graph);
88
89 NodeDef *node =
90 MakeUnaryNode(node_type, node_count, range_node->name(), &graph);
91
92 NoOpElimination optimizer;
93 GraphDef output;
94 TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
95
96 EXPECT_EQ(graph_utils::ContainsGraphNodeWithName(node->name(), output),
97 should_keep_node);
98 }
99
100 INSTANTIATE_TEST_CASE_P(
101 BasicRemovalTest, NoOpLastEliminationTest,
102 ::testing::Values(std::make_tuple("TakeDataset", -3, false),
103 std::make_tuple("TakeDataset", -1, false),
104 std::make_tuple("TakeDataset", 0, true),
105 std::make_tuple("TakeDataset", 3, true),
106 std::make_tuple("SkipDataset", -1, true),
107 std::make_tuple("SkipDataset", 0, false),
108 std::make_tuple("SkipDataset", 3, true),
109 std::make_tuple("PrefetchDataset", 0, false),
110 std::make_tuple("PrefetchDataset", 1, true),
111 std::make_tuple("RepeatDataset", 1, false),
112 std::make_tuple("RepeatDataset", 2, true)));
113
114 struct NoOpMiddleEliminationTest
115 : ::testing::TestWithParam<std::tuple<string, int, bool>> {};
116
117 // This test checks whether the no-op elimination correctly handles
118 // transformations int the middle of the pipeline.
TEST_P(NoOpMiddleEliminationTest,EliminateMiddleNoOpNode)119 TEST_P(NoOpMiddleEliminationTest, EliminateMiddleNoOpNode) {
120 GrapplerItem item;
121 MutableGraphView graph(&item.graph);
122
123 const string &node_type = std::get<0>(GetParam());
124 const int node_count = std::get<1>(GetParam());
125 const bool should_keep_node = std::get<2>(GetParam());
126
127 NodeDef *range_node = MakeRangeNode(&graph);
128
129 NodeDef *node =
130 MakeUnaryNode(node_type, node_count, range_node->name(), &graph);
131
132 NodeDef *cache_node = MakeCacheNode(node->name(), &graph);
133 NoOpElimination optimizer;
134 GraphDef output;
135 TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
136
137 EXPECT_EQ(graph_utils::ContainsGraphNodeWithName(node->name(), output),
138 should_keep_node);
139 EXPECT_TRUE(
140 graph_utils::ContainsGraphNodeWithName(cache_node->name(), output));
141
142 NodeDef cache_node_out = output.node(
143 graph_utils::FindGraphNodeWithName(cache_node->name(), output));
144
145 EXPECT_EQ(cache_node_out.input_size(), 2);
146 auto last_node_input = (should_keep_node ? node : range_node)->name();
147 EXPECT_EQ(cache_node_out.input(0), last_node_input);
148 }
149
150 INSTANTIATE_TEST_CASE_P(
151 BasicRemovalTest, NoOpMiddleEliminationTest,
152 ::testing::Values(std::make_tuple("TakeDataset", -1, false),
153 std::make_tuple("TakeDataset", -3, false),
154 std::make_tuple("TakeDataset", 0, true),
155 std::make_tuple("TakeDataset", 3, true),
156 std::make_tuple("SkipDataset", -1, true),
157 std::make_tuple("SkipDataset", 0, false),
158 std::make_tuple("SkipDataset", 3, true),
159 std::make_tuple("PrefetchDataset", 0, false),
160 std::make_tuple("PrefetchDataset", 1, true),
161 std::make_tuple("RepeatDataset", 1, false),
162 std::make_tuple("RepeatDataset", 2, true)));
163
164 using NodesTypes = std::tuple<std::pair<string, int>, std::pair<string, int>>;
165 struct NoOpMultipleEliminationTest : ::testing::TestWithParam<NodesTypes> {};
166
167 // This test checks whether the no-op elimination correctly removes
168 // multiple noop nodes.
TEST_P(NoOpMultipleEliminationTest,EliminateMultipleNoOpNode)169 TEST_P(NoOpMultipleEliminationTest, EliminateMultipleNoOpNode) {
170 GrapplerItem item;
171 MutableGraphView graph(&item.graph);
172
173 static_assert(std::tuple_size<NodesTypes>::value == 2,
174 "Make sure to include everything in the test");
175 const std::vector<std::pair<string, int>> noop_nodes = {
176 std::get<0>(GetParam()), std::get<1>(GetParam())};
177
178 NodeDef *range_node = MakeRangeNode(&graph);
179
180 NodeDef *previous = range_node;
181 std::vector<string> nodes_to_remove;
182 nodes_to_remove.reserve(noop_nodes.size());
183
184 for (const auto &noop_node : noop_nodes) {
185 NodeDef *node = MakeUnaryNode(noop_node.first, noop_node.second,
186 previous->name(), &graph);
187 nodes_to_remove.push_back(node->name());
188 previous = node;
189 }
190
191 NodeDef *cache_node = MakeCacheNode(previous->name(), &graph);
192 NoOpElimination optimizer;
193 GraphDef output;
194 TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
195
196 for (const auto &noop_node_name : nodes_to_remove)
197 EXPECT_FALSE(
198 graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
199
200 EXPECT_TRUE(
201 graph_utils::ContainsGraphNodeWithName(cache_node->name(), output));
202
203 NodeDef cache_node_out = output.node(
204 graph_utils::FindGraphNodeWithName(cache_node->name(), output));
205
206 EXPECT_EQ(cache_node_out.input_size(), 2);
207 EXPECT_EQ(cache_node_out.input(0), range_node->name());
208 }
209
210 const auto *const kTakeNode = new std::pair<string, int>{"TakeDataset", -1};
211 const auto *const kSkipNode = new std::pair<string, int>{"SkipDataset", 0};
212 const auto *const kRepeatNode = new std::pair<string, int>{"RepeatDataset", 1};
213 const auto *const kPrefetchNode =
214 new std::pair<string, int>{"PrefetchDataset", 0};
215
216 INSTANTIATE_TEST_CASE_P(
217 BasicRemovalTest, NoOpMultipleEliminationTest,
218 ::testing::Combine(::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode,
219 *kPrefetchNode),
220 ::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode,
221 *kPrefetchNode)));
222
223 struct NoOpPlaceholdersTest
224 : ::testing::TestWithParam<std::tuple<string, string>> {};
225
TEST_P(NoOpPlaceholdersTest,NonConstNoOpNode)226 TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) {
227 GrapplerItem item;
228 MutableGraphView graph(&item.graph);
229
230 static_assert(std::tuple_size<NodesTypes>::value == 2,
231 "Make sure to include everything in the test");
232 const std::vector<string> noop_nodes = {std::get<0>(GetParam()),
233 std::get<1>(GetParam())};
234 NodeDef *range_node = MakeRangeNode(&graph);
235 std::vector<string> nodes_to_keep;
236 nodes_to_keep.reserve(noop_nodes.size());
237 NodeDef *previous = range_node;
238
239 for (const auto &noop_node : noop_nodes) {
240 NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph);
241 nodes_to_keep.push_back(node->name());
242 previous = node;
243 }
244
245 NoOpElimination optimizer;
246 GraphDef output;
247 TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
248 for (const auto &noop_node_name : nodes_to_keep)
249 EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
250 }
251
252 INSTANTIATE_TEST_CASE_P(
253 DoNotRemovePlaceholders, NoOpPlaceholdersTest,
254 ::testing::Combine(::testing::Values("TakeDataset", "SkipDataset",
255 "RepeatDataset", "PrefetchDataset"),
256 ::testing::Values("TakeDataset", "SkipDataset",
257 "RepeatDataset", "PrefetchDataset")));
258
259 } // namespace
260 } // namespace grappler
261 } // namespace tensorflow
262