/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kConstOp[] = "Const";
constexpr char kRangeOp[] = "RangeDataset";
constexpr char kBatchOp[] = "BatchDataset";
constexpr char kBatchV2Op[] = "BatchDatasetV2";
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
constexpr char kMapOp[] = "MapDataset";
constexpr char kParallelMapOp[] = "ParallelMapDatasetV2";
constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
constexpr char kPrefetchOp[] = "PrefetchDataset";
constexpr char kAttrNameF[] = "f";
constexpr char kAttrNameTarguments[] = "Targuments";
constexpr char kAttrNameOutputTypes[] = "output_types";
constexpr char kAttrNameOutputShapes[] = "output_shapes";
constexpr char kAttrNameInterOpParallelism[] = "use_inter_op_parallelism";
constexpr char kAttrNamePreserveCardinality[] = "preserve_cardinality";
constexpr char kAttrNameSloppy[] = "sloppy";
constexpr char kAttrNameValue[] = "value";
constexpr char kAttrNameDtype[] = "dtype";

using test::function::NDef;

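// Runs the MapVectorization optimizer on `item` and writes the rewritten
// graph to `output`. The `use_choose_fastest` flag is forwarded to the
// optimizer config; as exercised by the tests below, it controls whether the
// rewrite wraps the original and vectorized pipelines in a
// ChooseFastestBranchDataset.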
Status OptimizeWithMapVectorization(const GrapplerItem& item, GraphDef* output,
                                    bool use_choose_fastest) {
  MapVectorization optimizer;
  RewriterConfig_CustomGraphOptimizer config;
  if (use_choose_fastest) {
    (*config.mutable_parameter_map())["use_choose_fastest"].set_s("true");
  } else {
    (*config.mutable_parameter_map())["use_choose_fastest"].set_s("false");
  }
  TF_RETURN_IF_ERROR(optimizer.Init(&config));
  return optimizer.Optimize(nullptr, item, output);
}

// Adds a simple vectorizable map function that is akin to
// dataset.map(lambda x: tf.identity(x))
FunctionDef* AddMapFn(MutableGraphView* graph) {
  FunctionDef* map_fn = graph->graph()->mutable_library()->add_function();
  *map_fn = FunctionDefHelper::Create(
      /*function_name=*/"map_fn",
      /*in_def=*/{"x: int64"},
      /*out_def=*/{"res: int64"},
      /*attr_def=*/{},
      /*node_def=*/{{{"node"}, "Identity", {"x"}, {{"T", DT_INT64}}}},
      /*ret_def=*/{{"res", "node:output"}});

  return map_fn;
}

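// Adds a map node that applies `map_fn` to `input_dataset`. Builds a
// ParallelMapDatasetV2 node when `num_parallel_calls` > 0, and a plain
// MapDataset node otherwise.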
NodeDef* AddMapNode(MutableGraphView* graph, const string& input_dataset,
                    const string& map_fn, int num_parallel_calls = 0) {
  NodeDef result;
  if (num_parallel_calls) {
    auto num_parallel_calls_node =
        graph_utils::AddScalarConstNode(num_parallel_calls, graph);
    result =
        NDef(/*name=*/"map", /*op=*/kParallelMapOp,
             /*inputs=*/{input_dataset, num_parallel_calls_node->name()},
             /*attrs=*/
             {{kAttrNameF, FunctionDefHelper::FunctionRef(map_fn)},
              {kAttrNameTarguments, gtl::ArraySlice<DataType>({})},
              {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
              {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})},
              {kAttrNameInterOpParallelism, false},
              {kAttrNameSloppy, true},
              {kAttrNamePreserveCardinality, true}});
  } else {
    result =
        NDef(/*name=*/"map", /*op=*/kMapOp,
             /*inputs=*/{input_dataset},
             /*attrs=*/
             {{kAttrNameF, FunctionDefHelper::FunctionRef(map_fn)},
              {kAttrNameTarguments, gtl::ArraySlice<DataType>({})},
              {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
              {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})},
              {kAttrNameInterOpParallelism, false},
              {kAttrNamePreserveCardinality, true}});
  }

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

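// Adds a PrefetchDataset node with the given buffer size after
// `input_dataset`.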
NodeDef* AddPrefetchNode(MutableGraphView* graph, const string& input_dataset,
                         int64 buffer_size) {
  auto buffer_size_node = graph_utils::AddScalarConstNode(buffer_size, graph);
  NodeDef result =
      NDef(/*name=*/"prefetch", /*op=*/kPrefetchOp,
           /*inputs=*/{input_dataset, buffer_size_node->name()},
           /*attrs=*/
           {{kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
            {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})}});

  return graph->AddNode(std::move(result));
}

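// Adds a batch node after `input_dataset`. Builds a BatchDatasetV2 node (with
// drop_remainder set to true) when `v2` is true, and a BatchDataset node
// otherwise.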
NodeDef* AddBatchNode(MutableGraphView* graph, const string& input_dataset,
                      bool v2 = false, int64 batch_size = 10) {
  NodeDef result;
  auto batch_size_node = graph_utils::AddScalarConstNode(batch_size, graph);

  if (v2) {
    // BatchDatasetV2
    auto drop_remainder = graph_utils::AddScalarConstNode(true, graph);
    result = NDef(
        /*name=*/"batch", /*op=*/kBatchV2Op,
        /*inputs=*/
        {input_dataset, batch_size_node->name(), drop_remainder->name()},
        /*attrs=*/
        {{kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
         {kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{10, 1}})}});
  } else {
    result =
        NDef(/*name=*/"batch", /*op=*/kBatchOp,
             /*inputs=*/{input_dataset, batch_size_node->name()},
             /*attrs=*/
             {{kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
              {kAttrNameOutputShapes,
               gtl::ArraySlice<PartialTensorShape>({{v2 ? 10 : -1, 1}})}});
  }

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

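// Adds a RangeDataset node equivalent to range(0, 10, 1).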
NodeDef* AddRangeNode(MutableGraphView* graph) {
  auto start = graph_utils::AddScalarConstNode(static_cast<int64>(0), graph);
  auto stop = graph_utils::AddScalarConstNode(static_cast<int64>(10), graph);
  auto step = graph_utils::AddScalarConstNode(static_cast<int64>(1), graph);

  NodeDef result =
      NDef(/*name=*/"range", /*op=*/kRangeOp,
           /*inputs=*/{start->name(), stop->name(), step->name()},
           /*attrs=*/
           {{kAttrNameOutputShapes, gtl::ArraySlice<TensorShape>({{}})},
            {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})}});

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

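// Checks that the rewrite did not fire: the original map and batch nodes are
// still present, the map node reads from `map_input_name`, and the batch node
// reads from the map node.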
void CheckNotVectorized(const GraphDef& output, const string& map_op,
                        const string& batch_op, const string& map_input_name) {
  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(map_op, output).size(), 1);
  ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(batch_op, output).size(), 1);
  const NodeDef& map_node =
      output.node(graph_utils::FindGraphNodeWithOp(map_op, output));
  const NodeDef& batch_node =
      output.node(graph_utils::FindGraphNodeWithOp(batch_op, output));
  EXPECT_EQ(map_node.input(0), map_input_name);
  EXPECT_EQ(batch_node.input(0), map_node.name());
}

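// Checks that the nodes of `function` match the expected op names, in order.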
void CheckBranch(const FunctionDef& function, gtl::ArraySlice<string> ops) {
  for (int i = 0, size = ops.size(); i < size; ++i) {
    EXPECT_EQ(function.node_def(i).op(), ops[i]);
  }
}

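// Returns the function named `function_name` from the graph's function
// library, or nullptr if it does not exist.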
const FunctionDef* GetFunction(const GraphDef& graph,
                               const string& function_name) {
  int found =
      graph_utils::FindGraphFunctionWithName(function_name, graph.library());
  if (found == -1) {
    return nullptr;
  }
  return &graph.library().function(found);
}

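// Checks that a graph has undergone the map_vectorization transformation
// without the ChooseFastestBranchDataset wrapper, i.e. the main graph itself
// now contains the swapped pipeline:
//
//    input_node --> new batch --> new map --> ...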
void CheckVectorizedWithoutChooseFastest(
    const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
    const string& input_name) {
  std::vector<const NodeDef*> vectorized_branch;
  for (const auto& op : expected_vectorized_branch) {
    // This assumes that the vectorized op is the only node with that op type
    // in the graph. For our test cases, this is true (we don't have
    // superfluous map/batch nodes in other parts of the pipeline).
    ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 1);
    vectorized_branch.push_back(
        &output.node(graph_utils::FindGraphNodeWithOp(op, output)));
  }

  // Check that consecutive nodes in the vectorized branch are chained
  // together.
  for (int i = 0; i < vectorized_branch.size() - 1; ++i) {
    const NodeDef* node = vectorized_branch[i];
    const NodeDef* next_node = vectorized_branch[i + 1];
    ASSERT_EQ(next_node->input(0), node->name());
  }
  ASSERT_EQ(vectorized_branch[0]->input(0), input_name);

  const NodeDef* vectorized_map_node = vectorized_branch[1];
  string function_name =
      vectorized_map_node->attr().at(kAttrNameF).func().name();

  const FunctionDef* function = GetFunction(output, function_name);
  ASSERT_NE(function, nullptr);
  EXPECT_EQ(function->node_def(0).op(), "Identity");
}

// Checks that a graph has undergone the map_vectorization transformation
// successfully, whereby the new graph has the shape:
//
//    input_node -------------> choose_fastest --> ...
//                               |f0    |f1
//                               |      |
//                               |      +---> old map --> old batch
//                               |
//                               +--> new batch --> new map
//
void CheckVectorizedWithChooseFastest(
    const GraphDef& output, gtl::ArraySlice<string> expected_vectorized_branch,
    gtl::ArraySlice<string> expected_original_branch,
    const string& input_name) {
  for (const auto& op :
       {kBatchOp, kBatchV2Op, kMapOp, kParallelMapOp, kMapAndBatchOp}) {
    // Check that the dataset nodes have been removed from the main graph.
    ASSERT_EQ(graph_utils::FindAllGraphNodesWithOp(op, output).size(), 0);
  }
  ASSERT_EQ(
      graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output).size(), 1);
  const NodeDef& choose_fastest_node =
      output.node(graph_utils::FindGraphNodeWithOp(kChooseFastestOp, output));
  ASSERT_EQ(choose_fastest_node.input(0), input_name);

  const auto& functions_list = choose_fastest_node.attr().at("branches").list();

  // Branch 0: vectorized
  const FunctionDef* branch_0 =
      GetFunction(output, functions_list.func(0).name());
  ASSERT_NE(branch_0, nullptr);
  CheckBranch(*branch_0, expected_vectorized_branch);

  // Branch 1: original
  const FunctionDef* branch_1 =
      GetFunction(output, functions_list.func(1).name());
  ASSERT_NE(branch_1, nullptr);
  CheckBranch(*branch_1, expected_original_branch);

  const NodeDef& vectorized_map_node =
      branch_0->node_def(function_utils::FindFunctionNodeWithOp(
          expected_vectorized_branch[1], *branch_0));
  string function_name =
      vectorized_map_node.attr().at(kAttrNameF).func().name();

  const FunctionDef* function = GetFunction(output, function_name);
  ASSERT_NE(function, nullptr);
  EXPECT_EQ(function->node_def(0).op(), "Identity");
}

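// Test parameters: (num_parallel_calls, use_batch_v2, prefetch buffer size,
// use_choose_fastest).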
class MapThenBatchTest
    : public ::testing::TestWithParam<std::tuple<int, bool, int, bool>> {};

TEST_P(MapThenBatchTest, IsVectorized) {
  int num_parallel_calls = std::get<0>(GetParam());
  bool use_batch_v2 = std::get<1>(GetParam());
  int prefetch = std::get<2>(GetParam());
  bool use_choose_fastest = std::get<3>(GetParam());
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto range_dataset = AddRangeNode(&graph);
  auto map_fn = AddMapFn(&graph);
  auto dataset = AddMapNode(&graph, range_dataset->name(),
                            map_fn->signature().name(), num_parallel_calls);

  if (prefetch) {
    dataset = AddPrefetchNode(&graph, dataset->name(), prefetch);
  }
  dataset = AddBatchNode(&graph, dataset->name(), use_batch_v2);
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));

  std::vector<string> expected_original_branch;
  expected_original_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
                                                            : kMapOp);
  if (prefetch) {
    expected_original_branch.push_back(kPrefetchOp);
  }
  expected_original_branch.push_back(use_batch_v2 ? kBatchV2Op : kBatchOp);

  std::vector<string> expected_vectorized_branch;
  expected_vectorized_branch.push_back(use_batch_v2 ? kBatchV2Op : kBatchOp);
  expected_vectorized_branch.push_back(num_parallel_calls > 0 ? kParallelMapOp
                                                              : kMapOp);
  if (prefetch) {
    expected_vectorized_branch.push_back(kPrefetchOp);
  }

  if (use_choose_fastest) {
    CheckVectorizedWithChooseFastest(output, expected_vectorized_branch,
                                     expected_original_branch,
                                     range_dataset->name());
  } else {
    CheckVectorizedWithoutChooseFastest(output, expected_vectorized_branch,
                                        range_dataset->name());
  }
}

INSTANTIATE_TEST_SUITE_P(MapThenBatchTest, MapThenBatchTest,
                         ::testing::Combine(::testing::Values(0, 12),
                                            ::testing::Bool(),
                                            ::testing::Values(0, 20),
                                            ::testing::Bool()));

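// Parameterized on whether the ChooseFastestBranchDataset wrapper is enabled.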
class MapAndBatchTest : public ::testing::TestWithParam<bool> {};

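// Adds a fused MapAndBatchDataset node that applies `map_fn` to
// `input_dataset` with the given batch size and parallelism, dropping any
// remainder elements.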
NodeDef* AddMapAndBatchNode(MutableGraphView* graph,
                            const string& input_dataset, const string& map_fn,
                            int64 batch_size = 10,
                            int64 num_parallel_calls = 12) {
  auto batch_size_node = graph_utils::AddScalarConstNode(batch_size, graph);
  auto num_parallel_calls_node =
      graph_utils::AddScalarConstNode(num_parallel_calls, graph);
  auto drop_remainder = graph_utils::AddScalarConstNode(true, graph);

  NodeDef result =
      NDef(/*name=*/"map_and_batch",
           /*op=*/kMapAndBatchOp,
           /*inputs=*/
           {input_dataset, batch_size_node->name(),
            num_parallel_calls_node->name(), drop_remainder->name()},
           /*attrs=*/
           {{kAttrNameF, FunctionDefHelper::FunctionRef(map_fn)},
            {kAttrNameTarguments, gtl::ArraySlice<DataType>({})},
            {kAttrNameOutputTypes, gtl::ArraySlice<DataType>({DT_INT64})},
            {kAttrNameOutputShapes,
             gtl::ArraySlice<PartialTensorShape>({{10, 1}})}});

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

TEST_P(MapAndBatchTest, VectorizeMapAndBatch) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto range_node = AddRangeNode(&graph);
  auto map_fn = AddMapFn(&graph);
  auto map_and_batch_node = AddMapAndBatchNode(&graph, range_node->name(),
                                               map_fn->signature().name());
  ASSERT_NE(map_and_batch_node, nullptr);

  GraphDef output;
  bool use_choose_fastest = GetParam();

  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
  if (use_choose_fastest) {
    CheckVectorizedWithChooseFastest(output, {kBatchV2Op, kParallelMapOp},
                                     {kMapAndBatchOp}, range_node->name());
  } else {
    CheckVectorizedWithoutChooseFastest(output, {kBatchV2Op, kParallelMapOp},
                                        range_node->name());
  }
}

INSTANTIATE_TEST_SUITE_P(MapAndBatchTest, MapAndBatchTest, ::testing::Bool());

class ChainedMapAndBatchTest
    : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {};

// Tests:
// 1) map.batch.map.batch
// 2) map.batch.map_and_batch
// 3) map_and_batch.map.batch
// 4) map_and_batch.map_and_batch
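// The first two parameters select whether each stage is a fused
// MapAndBatchDataset; the third toggles the ChooseFastestBranchDataset
// wrapper.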
TEST_P(ChainedMapAndBatchTest, IsVectorized) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  auto input_node = AddRangeNode(&graph);

  auto map_fn = AddMapFn(&graph);

  auto make_map_and_batch = [&graph, map_fn](NodeDef* input, bool fuse) {
    if (fuse) {
      return AddMapAndBatchNode(&graph, input->name(),
                                map_fn->signature().name());
    }
    auto map_node =
        AddMapNode(&graph, input->name(), map_fn->signature().name(), true);
    auto batch_node = AddBatchNode(&graph, map_node->name(), true);
    return batch_node;
  };

  bool fuse_0 = std::get<0>(GetParam());
  bool fuse_1 = std::get<1>(GetParam());
  bool use_choose_fastest = std::get<2>(GetParam());
  auto map_and_batch_0 = make_map_and_batch(input_node, fuse_0);
  auto map_and_batch_1 = make_map_and_batch(map_and_batch_0, fuse_1);
  ASSERT_NE(map_and_batch_1, nullptr);

  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, use_choose_fastest));
  TF_ASSERT_OK(TopologicalSort(&output));

  if (use_choose_fastest) {
    std::vector<int> choose_fastest_nodes =
        graph_utils::FindAllGraphNodesWithOp(kChooseFastestOp, output);
    ASSERT_EQ(choose_fastest_nodes.size(), 2);

    std::vector<string> fused_sequence({kMapAndBatchOp});
    std::vector<string> unfused_sequence({kParallelMapOp, kBatchV2Op});
    const NodeDef& range_node =
        output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));
    const NodeDef& choose_fastest_0 = output.node(choose_fastest_nodes[0]);
    ASSERT_EQ(choose_fastest_0.input(0), range_node.name());
    const NodeDef& choose_fastest_1 = output.node(choose_fastest_nodes[1]);
    ASSERT_EQ(choose_fastest_1.input(0), choose_fastest_0.name());

    auto check_branches = [&output](const NodeDef& choose_fastest_node,
                                    gtl::ArraySlice<string> original_ops) {
      const auto& functions_list =
          choose_fastest_node.attr().at("branches").list();

      // Branch 0: vectorized
      const FunctionDef* branch_0 =
          GetFunction(output, functions_list.func(0).name());
      ASSERT_NE(branch_0, nullptr);
      CheckBranch(*branch_0, {kBatchV2Op, kParallelMapOp});

      // Branch 1: original
      const FunctionDef* branch_1 =
          GetFunction(output, functions_list.func(1).name());
      ASSERT_NE(branch_1, nullptr);
      CheckBranch(*branch_1, original_ops);
    };

    check_branches(choose_fastest_0,
                   fuse_0 ? fused_sequence : unfused_sequence);
    check_branches(choose_fastest_1,
                   fuse_1 ? fused_sequence : unfused_sequence);
  } else {
    std::vector<int> map_nodes =
        graph_utils::FindAllGraphNodesWithOp(kParallelMapOp, output);
    std::vector<int> batch_nodes =
        graph_utils::FindAllGraphNodesWithOp(kBatchV2Op, output);
    ASSERT_EQ(map_nodes.size(), 2);
    ASSERT_EQ(batch_nodes.size(), 2);

    const NodeDef& range_node =
        output.node(graph_utils::FindGraphNodeWithOp(kRangeOp, output));

    const NodeDef& batch_node_0 = output.node(batch_nodes[0]);
    EXPECT_EQ(batch_node_0.input(0), range_node.name());
    const NodeDef& map_node_0 = output.node(map_nodes[0]);
    EXPECT_EQ(map_node_0.input(0), batch_node_0.name());
    const NodeDef& batch_node_1 = output.node(batch_nodes[1]);
    EXPECT_EQ(batch_node_1.input(0), map_node_0.name());
    const NodeDef& map_node_1 = output.node(map_nodes[1]);
    EXPECT_EQ(map_node_1.input(0), batch_node_1.name());
  }
}

INSTANTIATE_TEST_SUITE_P(ChainedMapAndBatchTest, ChainedMapAndBatchTest,
                         ::testing::Combine(::testing::Bool(),
                                            ::testing::Bool(),
                                            ::testing::Bool()));

// Not all dataset types have "output_shapes" and "output_types" attrs
// defined. Adds a generic input node which may be missing either of these
// attrs.
NodeDef* AddArbitraryInputNode(MutableGraphView* graph,
                               std::vector<PartialTensorShape>* output_shapes,
                               std::vector<DataType>* output_types) {
  std::vector<std::pair<string, FunctionDefHelper::AttrValueWrapper>> attrs;
  if (output_shapes) {
    attrs.push_back({kAttrNameOutputShapes, *output_shapes});
  }
  if (output_types) {
    attrs.push_back({kAttrNameOutputTypes, *output_types});
  }

  NodeDef result = NDef(/*name=*/"input", /*op=*/"InputDataset",
                        /*inputs=*/{},
                        /*attrs=*/attrs);

  graph_utils::SetUniqueGraphNodeName(result.name(), graph->graph(), &result);
  return graph->AddNode(std::move(result));
}

TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShapes) {
  // Tests that the optimization doesn't break when the input to MapDataset
  // doesn't have an output_shapes attr defined. In this case, the map and
  // batch swap does not occur.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  std::vector<DataType> input_types({DT_INT64});
  auto input_node = AddArbitraryInputNode(&graph, nullptr, &input_types);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}

TEST(MapVectorizationTest, VectorizeWithUnknownRank) {
  // Tests that the optimization doesn't break when the input to MapDataset
  // has components with unknown rank. In this case, the optimization does not
  // occur.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  std::vector<PartialTensorShape> input_shapes({{}});
  std::vector<DataType> input_types({DT_INT64});
  auto input_node = AddArbitraryInputNode(&graph, &input_shapes, &input_types);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}

TEST(MapVectorizationTest, VectorizeWithUnknownDim) {
  // Tests that the optimization doesn't break when the input to MapDataset
  // has components with unknown dimensions. In this case, the optimization
  // does not occur.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  std::vector<PartialTensorShape> input_shapes({{-1, 2}});
  std::vector<DataType> input_types({DT_INT64});
  auto input_node = AddArbitraryInputNode(&graph, &input_shapes, &input_types);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckNotVectorized(output, map_node->op(), batch_node->op(),
                     input_node->name());
}

TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
  // Tests that the optimization doesn't break when the input doesn't have
  // an output_types attr defined. The output_types of the input node, even
  // if not present, can be inferred from the map function input signature.
  GrapplerItem item;
  MutableGraphView graph(&item.graph);
  std::vector<PartialTensorShape> input_shapes({{1}});
  auto input_node = AddArbitraryInputNode(&graph, &input_shapes, nullptr);
  auto map_fn = AddMapFn(&graph);
  auto map_node =
      AddMapNode(&graph, input_node->name(), map_fn->signature().name());
  auto batch_node = AddBatchNode(&graph, map_node->name());
  GraphDef output;
  TF_ASSERT_OK(OptimizeWithMapVectorization(item, &output, true));
  CheckVectorizedWithChooseFastest(
      output, /*expected_vectorized_branch=*/{batch_node->op(), map_node->op()},
      /*expected_original_branch=*/{map_node->op(), batch_node->op()},
      input_node->name());
}

// TODO(rachelim): Add test that has a polymorphic function.

}  // namespace
}  // namespace grappler
}  // namespace tensorflow