/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kConstOp[] = "Const";
constexpr char kRangeOp[] = "RangeDataset";
constexpr char kBatchOp[] = "BatchDataset";
constexpr char kBatchV2Op[] = "BatchDatasetV2";
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
constexpr char kMapOp[] = "MapDataset";
constexpr char kParallelMapOp[] = "ParallelMapDatasetV2";
constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
constexpr char kPrefetchOp[] = "PrefetchDataset";
constexpr char kAttrNameF[] = "f";
constexpr char kAttrNameTarguments[] = "Targuments";
constexpr char kAttrNameOutputTypes[] = "output_types";
constexpr char kAttrNameOutputShapes[] = "output_shapes";
constexpr char kAttrNameInterOpParallelism[] = "use_inter_op_parallelism";
constexpr char kAttrNamePreserveCardinality[] = "preserve_cardinality";
constexpr char kAttrNameSloppy[] = "sloppy";
constexpr char kAttrNameValue[] = "value";
constexpr char kAttrNameDtype[] = "dtype";

using test::function::NDef;

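// Runs the MapVectorization optimizer on `item`, with the optimizer's
// "use_choose_fastest" parameter set according to `use_choose_fastest`.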
Status OptimizeWithMapVectorization(const GrapplerItem& item, GraphDef* output,
                                    bool use_choose_fastest) {
  MapVectorization optimizer;
  RewriterConfig_CustomGraphOptimizer config;
  (*config.mutable_parameter_map())["use_choose_fastest"].set_s(
      use_choose_fastest ? "true" : "false");
  TF_RETURN_IF_ERROR(optimizer.Init(&config));
  return optimizer.Optimize(nullptr, item, output);
}

// Adds a simple vectorizable map function that is akin to
// dataset.map(lambda x: tf.identity(x)).
FunctionDef* AddMapFn(MutableGraphView* graph) {
  FunctionDef* map_fn = graph->graph()->mutable_library()->add_function();
  *map_fn = FunctionDefHelper::Create(
      /*function_name=*/"map_fn",
      /*in_def=*/{"x: int64"},
      /*out_def=*/{"res: int64"},
      /*attr_def=*/{},
      /*node_def=*/{{{"node"}, "Identity", {"x"}, {{"T", DT_INT64}}}},
      /*ret_def=*/{{"res", "node:output"}});

  return map_fn;
}

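// Adds a map node that applies `map_fn` to `input_dataset`. If
// `num_parallel_calls` is greater than 0, a ParallelMapDatasetV2 node (with a
// scalar Const for the parallelism) is added; otherwise, a plain MapDataset
// node is added.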
NodeDef* AddMapNode(MutableGraphView* graph, const string& input_dataset,
                    const string& map_fn, int num_parallel_calls = 0) {
  NodeDef result;
  if (num_parallel_calls) {
    auto num_parallel_calls_node =
        graph_utils::AddScalarConstNode(num_parallel_calls, graph);
    result =
        NDef(/*name=*/"map", /*op=*/kParallelMapOp,
             /*inputs=*/{input_dataset, num_parallel_calls_node->name()},
             /*attrs=*/
             {{kAttrNameF, FunctionDefHelper::FunctionRef(map_fn)},
              {kAttrNameTarguments, gtl::ArraySlice<DataType>({})},
              {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
              {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})},
              {kAttrNameInterOpParallelism, false},
              {kAttrNameSloppy, true},
              {kAttrNamePreserveCardinality, true}});
  } else {
    result =
        NDef(/*name=*/"map", /*op=*/kMapOp,
             /*inputs=*/{input_dataset},
             /*attrs=*/
             {{kAttrNameF, FunctionDefHelper::FunctionRef(map_fn)},
              {kAttrNameTarguments, gtl::ArraySlice<DataType>({})},
              {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
              {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})},
              {kAttrNameInterOpParallelism, false},
              {kAttrNamePreserveCardinality, true}});
  }

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

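// Adds a PrefetchDataset node that prefetches `buffer_size` elements from
// `input_dataset`.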
NodeDef* AddPrefetchNode(MutableGraphView* graph, const string& input_dataset,
                         int64 buffer_size) {
  auto buffer_size_node = graph_utils::AddScalarConstNode(buffer_size, graph);
  NodeDef result =
      NDef(/*name=*/"prefetch", /*op=*/kPrefetchOp,
           /*inputs=*/{input_dataset, buffer_size_node->name()},
           /*attrs=*/
           {{kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
            {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})}});

  return graph->AddNode(std::move(result));
}

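// Adds a batch node that batches `input_dataset` with the given `batch_size`.
// If `v2` is true, a BatchDatasetV2 node with drop_remainder=true is added;
// otherwise, a BatchDataset node is added.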
NodeDef* AddBatchNode(MutableGraphView* graph, const string& input_dataset,
                      bool v2 = false, int64 batch_size = 10) {
  NodeDef result;
  auto batch_size_node = graph_utils::AddScalarConstNode(batch_size, graph);

  if (v2) {
    // BatchDatasetV2
    auto drop_remainder = graph_utils::AddScalarConstNode(true, graph);
    result = NDef(
        /*name=*/"batch", /*op=*/kBatchV2Op,
        /*inputs=*/
        {input_dataset, batch_size_node->name(), drop_remainder->name()},
        /*attrs=*/
        {{kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
         {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{10, 1}})}});
  } else {
    // BatchDataset (v1) has no drop_remainder input, so the batch dimension of
    // the output shape is unknown.
    result =
        NDef(/*name=*/"batch", /*op=*/kBatchOp,
             /*inputs=*/{input_dataset, batch_size_node->name()},
             /*attrs=*/
             {{kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
              {kAttrNameOutputShapes,
               gtl::ArraySlice<PartialTensorShape>({{-1, 1}})}});
  }

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

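// Adds a RangeDataset node equivalent to tf.data.Dataset.range(0, 10, 1).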
NodeDef* AddRangeNode(MutableGraphView* graph) {
  auto start = graph_utils::AddScalarConstNode(static_cast<int64>(0), graph);
  auto stop = graph_utils::AddScalarConstNode(static_cast<int64>(10), graph);
  auto step = graph_utils::AddScalarConstNode(static_cast<int64>(1), graph);

  NodeDef result =
      NDef(/*name=*/"range", /*op=*/kRangeOp,
           /*inputs=*/{start->name(), stop->name(), step->name()},
           /*attrs=*/
           {{kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})},
            {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})}});

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

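// Checks that the map and batch nodes are unchanged: the graph contains
// exactly one `map_op` node reading from `map_input_name` and exactly one
// `batch_op` node reading from the map node.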
void CheckNotVectorized(const GraphDef& output, const string& map_op,
                        const string& batch_op, const string& map_input_name) {
  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(map_op, output).size(), 1);
  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(batch_op, output).size(), 1);
  const NodeDef& map_node =
      output.node(graph_utils::FindGraphNodeWithOp(map_op, output));
  const NodeDef& batch_node =
      output.node(graph_utils::FindGraphNodeWithOp(batch_op, output));
  EXPECT_EQ(map_node.input(0), map_input_name);
  EXPECT_EQ(batch_node.input(0), map_node.name());
}

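// Checks that the nodes of `function` have the expected ops, in order.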
void CheckBranch(const FunctionDef& function, gtl::ArraySlice<string> ops) {
  for (int i = 0, size = ops.size(); i < size; ++i) {
    EXPECT_EQ(function.node_def(i).op(), ops[i]);
  }
}

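// Returns the function named `function_name` from the graph's function
// library, or nullptr if it doesn't exist.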
const FunctionDef* GetFunction(const GraphDef& graph,
                               const string& function_name) {
  int found =
      graph_utils::FindGraphFunctionWithName(function_name, graph.library());
  if (found == -1) {
    return nullptr;
  }
  return &graph.library().function(found);
}

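// Checks that the graph has been vectorized without the
// ChooseFastestBranchDataset wrapper: the ops in `expected_vectorized_branch`
// (new batch, new map, and optionally prefetch) form a chain starting at
// `input_name`, and the new map function's first node is Identity.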
void CheckVectorizedWithoutChooseFastest(
    const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
    const string& input_name) {
  std::vector<const NodeDef*> vectorized_branch;
  for (const auto& op : expected_vectorized_branch) {
    // This assumes that the vectorized op is the only node with this op type
    // in the graph. For our test cases, this is true (we don't have
    // superfluous map/batch nodes in other parts of the pipeline).
    ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 1);
    vectorized_branch.push_back(
        &output.node(graph_utils::FindGraphNodeWithOp(op, output)));
  }

  // Check that the vectorized branch forms a chain rooted at the input.
  for (int i = 0; i < vectorized_branch.size() - 1; ++i) {
    const NodeDef* node = vectorized_branch[i];
    const NodeDef* next_node = vectorized_branch[i + 1];
    ASSERT_EQ(next_node->input(0), node->name());
  }
  ASSERT_EQ(vectorized_branch[0]->input(0), input_name);

  const NodeDef* vectorized_map_node = vectorized_branch[1];
  string function_name =
      vectorized_map_node->attr().at(kAttrNameF).func().name();

  const FunctionDef* function = GetFunction(output, function_name);
  ASSERT_NE(function, nullptr);
  EXPECT_EQ(function->node_def(0).op(), "Identity");
}

// Checks that a graph has undergone the map_vectorization transformation
// successfully, whereby the new graph has the shape:
//
//   input_node --------> choose_fastest --> ...
//                          |f0      |f1
//                          |        |
//                          |        +--> old map --> old batch
//                          |
//                          +--> new batch --> new map
//
void CheckVectorizedWithChooseFastest(
    const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
    gtl::ArraySlice<string> expected_original_branch,
    const string& input_name) {
  for (const auto& op :
       {kBatchOp, kBatchV2Op, kMapOp, kParallelMapOp, kMapAndBatchOp}) {
    // Check that the dataset nodes have been removed from the main graph.
    ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 0);
  }
  ASSERT_EQ(
      graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
  const NodeDef& choose_fastest_node =
      output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
  ASSERT_EQ(choose_fastest_node.input(0), input_name);

  const auto& functions_list = choose_fastest_node.attr().at("branches").list();

  // Branch 0: vectorized
  const FunctionDef* branch_0 =
      GetFunction(output, functions_list.func(0).name());
  ASSERT_NE(branch_0, nullptr);
  CheckBranch(*branch_0, expected_vectorized_branch);

  // Branch 1: original
  const FunctionDef* branch_1 =
      GetFunction(output, functions_list.func(1).name());
  ASSERT_NE(branch_1, nullptr);
  CheckBranch(*branch_1, expected_original_branch);

  const NodeDef& vectorized_map_node =
      branch_0->node_def(function_utils::FindFunctionNodeWithOp(
          expected_vectorized_branch[1], *branch_0));
  string function_name =
      vectorized_map_node.attr().at(kAttrNameF).func().name();

  const FunctionDef* function = GetFunction(output, function_name);
  ASSERT_NE(function, nullptr);
  EXPECT_EQ(function->node_def(0).op(), "Identity");
}

class MapThenBatchTest
    : public ::testing::TestWithParam<std::tuple<int, bool, int, bool>> {};

TEST_P(MapThenBatchTest, IsVectorized) {
  int num_parallel_calls = std::get<0>(GetParam());
  bool use_batch_v2 = std::get<1>(GetParam());
  int prefetch = std::get<2>(GetParam());
  bool use_choose_fastest = std::get<3>(GetParam());
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto range_dataset = AddRangeNode(&graph);
  auto map_fn = AddMapFn(&graph);
  auto dataset = AddMapNode(&graph, range_dataset->name(),
                            map_fn->signature().name(), num_parallel_calls);

  if (prefetch) {
    dataset = AddPrefetchNode(&graph, dataset->name(), prefetch);
  }
  dataset = AddBatchNode(&graph, dataset->name(), use_batch_v2);
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));

  std::vector<string> expected_original_branch;
  expected_original_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
                                                            : kMapOp);
  if (prefetch) {
    expected_original_branch.push_back(kPrefetchOp);
  }
  expected_original_branch.push_back(use_batch_v2 ? kBatchV2Op : kBatchOp);

  std::vector<string> expected_vectorized_branch;
  expected_vectorized_branch.push_back(use_batch_v2 ? kBatchV2Op : kBatchOp);
  expected_vectorized_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
                                                              : kMapOp);
  if (prefetch) {
    expected_vectorized_branch.push_back(kPrefetchOp);
  }

  if (use_choose_fastest) {
    CheckVectorizedWithChooseFastest(output, expected_vectorized_branch,
                                     expected_original_branch,
                                     range_dataset->name());
  } else {
    CheckVectorizedWithoutChooseFastest(output, expected_vectorized_branch,
                                        range_dataset->name());
  }
}

INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
                         ::testing::Combine(::testing::Values(0, 12),
                                            ::testing::Bool(),
                                            ::testing::Values(0, 20),
                                            ::testing::Bool()));

class MapAndBatchTest : public ::testing::TestWithParam<bool> {};

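// Adds a MapAndBatchDataset node that applies `map_fn` to `input_dataset`
// with the given batch size and parallelism, dropping the remainder.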
NodeDef* AddMapAndBatchNode(MutableGraphView* graph,
                            const string& input_dataset, const string& map_fn,
                            int64 batch_size = 10,
                            int64 num_parallel_calls = 12) {
  auto batch_size_node = graph_utils::AddScalarConstNode(batch_size, graph);
  auto num_parallel_calls_node =
      graph_utils::AddScalarConstNode(num_parallel_calls, graph);
  auto drop_remainder = graph_utils::AddScalarConstNode(true, graph);

  NodeDef result =
      NDef(/*name=*/"map_and_batch",
           /*op=*/kMapAndBatchOp,
           /*inputs=*/
           {input_dataset, batch_size_node->name(),
            num_parallel_calls_node->name(), drop_remainder->name()},
           /*attrs=*/
           {{kAttrNameF, FunctionDefHelper::FunctionRef(map_fn)},
            {kAttrNameTarguments, gtl::ArraySlice<DataType>({})},
            {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
            {kAttrNameOutputShapes,
             gtl::ArraySlice<PartialTensorShape>({{10, 1}})}});

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

TEST_P(MapAndBatchTest, VectorizeMapAndBatch) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto range_node = AddRangeNode(&graph);
  auto map_fn = AddMapFn(&graph);
  auto map_and_batch_node = AddMapAndBatchNode(&graph, range_node->name(),
                                               map_fn->signature().name());
  ASSERT_NE(map_and_batch_node, nullptr);

  GraphDef output;
  bool use_choose_fastest = GetParam();

  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
  if (use_choose_fastest) {
    CheckVectorizedWithChooseFastest(output, {kBatchV2Op, kParallelMapOp},
                                     {kMapAndBatchOp}, range_node->name());
  } else {
    CheckVectorizedWithoutChooseFastest(output, {kBatchV2Op, kParallelMapOp},
                                        range_node->name());
  }
}

INSTANTIATE_TEST_SUITE_P(MapAndBatchTest, MapAndBatchTest, ::testing::Bool());

class ChainedMapAndBatchTest
    : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {};

// Tests:
// 1) map.batch.map.batch
// 2) map.batch.map_and_batch
// 3) map_and_batch.map.batch
// 4) map_and_batch.map_and_batch
TEST_P(ChainedMapAndBatchTest, IsVectorized) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto input_node = AddRangeNode(&graph);

  auto map_fn = AddMapFn(&graph);

  auto make_map_and_batch = [&graph, map_fn](NodeDef* input, bool fuse) {
    if (fuse) {
      return AddMapAndBatchNode(&graph, input->name(),
                                map_fn->signature().name());
    }
    auto map_node = AddMapNode(&graph, input->name(),
                               map_fn->signature().name(),
                               /*num_parallel_calls=*/1);
    auto batch_node = AddBatchNode(&graph, map_node->name(), /*v2=*/true);
    return batch_node;
  };

  bool fuse_0 = std::get<0>(GetParam());
  bool fuse_1 = std::get<1>(GetParam());
  bool use_choose_fastest = std::get<2>(GetParam());
  auto map_and_batch_0 = make_map_and_batch(input_node, fuse_0);
  auto map_and_batch_1 = make_map_and_batch(map_and_batch_0, fuse_1);
  ASSERT_NE(map_and_batch_1, nullptr);

  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
  TF_ASSERT_OK(TopologicalSort(&output));

  if (use_choose_fastest) {
    std::vector<int> choose_fastest_nodes =
        graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output);
    ASSERT_EQ(choose_fastest_nodes.size(), 2);

    std::vector<string> fused_sequence({kMapAndBatchOp});
    std::vector<string> unfused_sequence({kParallelMapOp, kBatchV2Op});
    const NodeDef& range_node =
        output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
    const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
    ASSERT_EQ(choose_fastest_0.input(0), range_node.name());
    const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
    ASSERT_EQ(choose_fastest_1.input(0), choose_fastest_0.name());

    auto check_branches = [&output](const NodeDef& choose_fastest_node,
                                    gtl::ArraySlice<string> original_ops) {
      const auto& functions_list =
          choose_fastest_node.attr().at("branches").list();

      // Branch 0: vectorized
      const FunctionDef* branch_0 =
          GetFunction(output, functions_list.func(0).name());
      ASSERT_NE(branch_0, nullptr);
      CheckBranch(*branch_0, {kBatchV2Op, kParallelMapOp});

      // Branch 1: original
      const FunctionDef* branch_1 =
          GetFunction(output, functions_list.func(1).name());
      ASSERT_NE(branch_1, nullptr);
      CheckBranch(*branch_1, original_ops);
    };

    check_branches(choose_fastest_0,
                   fuse_0 ? fused_sequence : unfused_sequence);
    check_branches(choose_fastest_1,
                   fuse_1 ? fused_sequence : unfused_sequence);
  } else {
    std::vector<int> map_nodes =
        graph_utils::FindAllGraphNodesWithOp(kParallelMapOp, output);
    std::vector<int> batch_nodes =
        graph_utils::FindAllGraphNodesWithOp(kBatchV2Op, output);
    ASSERT_EQ(map_nodes.size(), 2);
    ASSERT_EQ(batch_nodes.size(), 2);

    const NodeDef& range_node =
        output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));

    const NodeDef& batch_node_0 = output.node(batch_nodes[0]);
    EXPECT_EQ(batch_node_0.input(0), range_node.name());
    const NodeDef& map_node_0 = output.node(map_nodes[0]);
    EXPECT_EQ(map_node_0.input(0), batch_node_0.name());
    const NodeDef& batch_node_1 = output.node(batch_nodes[1]);
    EXPECT_EQ(batch_node_1.input(0), map_node_0.name());
    const NodeDef& map_node_1 = output.node(map_nodes[1]);
    EXPECT_EQ(map_node_1.input(0), batch_node_1.name());
  }
}

INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
                         ::testing::Combine(::testing::Bool(),
                                            ::testing::Bool(),
                                            ::testing::Bool()));

// Not all dataset types have "output_shapes" and "output_types" attrs defined.
// Adds a generic input node which may be missing one or both of these attrs,
// depending on whether `output_shapes` and `output_types` are provided.
NodeDef* AddArbitraryInputNode(MutableGraphView* graph,
                               std::vector<PartialTensorShape>* output_shapes,
                               std::vector<DataType>* output_types) {
  std::vector<std::pair<string, FunctionDefHelper::AttrValueWrapper>> attrs;
  if (output_shapes) {
    attrs.push_back({kAttrNameOutputShapes, *output_shapes});
  }
  if (output_types) {
    attrs.push_back({kAttrNameOutputTypes, *output_types});
  }

  NodeDef result = NDef(/*name=*/"input", /*op=*/"InputDataset",
                        /*inputs=*/{},
                        /*attrs=*/attrs);

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShapes) {
  // Tests that the optimization doesn't break when the input to MapDataset
  // doesn't have an output_shapes attr defined. In this case, the map and
  // batch swap does not occur.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  std::vector<DataType> input_types({DT_INT64});
  auto input_node = AddArbitraryInputNode(&graph, nullptr, &input_types);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}

TEST(MapVectorizationTest, VectorizeWithUnknownRank) {
  // Tests that the optimization doesn't break when the input to MapDataset
  // has components with unknown rank. In this case, the optimization does not
  // occur.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  // A default-constructed PartialTensorShape has unknown rank.
  std::vector<PartialTensorShape> input_shapes({{}});
  std::vector<DataType> input_types({DT_INT64});
  auto input_node = AddArbitraryInputNode(&graph, &input_shapes, &input_types);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}

TEST(MapVectorizationTest, VectorizeWithUnknownDim) {
  // Tests that the optimization doesn't break when the input to MapDataset
  // has components with unknown dimensions. In this case, the optimization
  // does not occur.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  std::vector<PartialTensorShape> input_shapes({{-1, 2}});
  std::vector<DataType> input_types({DT_INT64});
  auto input_node = AddArbitraryInputNode(&graph, &input_shapes, &input_types);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}

TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
  // Tests that the optimization doesn't break when the input doesn't have
  // an output_types attr defined. The output_types of the input node, even
  // if not present, can be inferred from the map function input signature.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  std::vector<PartialTensorShape> input_shapes({{1}});
  auto input_node = AddArbitraryInputNode(&graph, &input_shapes, nullptr);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckVectorizedWithChooseFastest(
      output, /*expected_vectorized_branch=*/{batch_node->op(), map_node->op()},
      /*expected_original_branch=*/{map_node->op(), batch_node->op()},
      input_node->name());
}

// TODO(rachelim): Add test that has a polymorphic function.

}  // namespace
}  // namespace grappler
}  // namespace tensorflow