1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/match.h"
21 #include "tensorflow/core/data/dataset_utils.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.h"
24 #include "tensorflow/core/framework/function.pb.h"
25 #include "tensorflow/core/framework/metrics.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/mutable_graph_view.h"
30 #include "tensorflow/core/grappler/op_types.h"
31 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
32 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
33 #include "tensorflow/core/grappler/utils/functions.h"
34 #include "tensorflow/core/kernels/data/shard_dataset_op.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/errors.h"
37 
38 namespace tensorflow {
39 namespace grappler {
40 namespace {
41 
42 using tensorflow::data::AutoShardPolicy;
43 
44 constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
45 constexpr char kShardDatasetOpName[] = "ShardDataset";
46 constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
47 constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
48 constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
49 constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
50 constexpr char kFinalizeDatasetOpName[] = "FinalizeDataset";
51 constexpr char kOptionsDatasetOpName[] = "OptionsDataset";
52 constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
53 constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
54 constexpr char kTensorDatasetOpName[] = "TensorDataset";
55 constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
56 constexpr char kPlaceholderOpName[] = "Placeholder";
57 constexpr char kConstOpName[] = "Const";
58 
59 constexpr char kNumWorkersAttrName[] = "num_workers";
60 constexpr char kNumReplicasAttrName[] = "num_replicas";
61 constexpr char kIndexAttrName[] = "index";
62 constexpr char kAutoShardPolicyAttrName[] = "auto_shard_policy";
63 constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
64 constexpr char kOutputShapes[] = "output_shapes";
65 constexpr char kOutputTypes[] = "output_types";
66 
67 // clang-format off
68 constexpr std::array<const char*, 5> kReaderDatasetOps = {
69     "FixedLengthRecordDataset",
70     "RecordIODataset",
71     "SSTableDataset",
72     "TextLineDataset",
73     "TFRecordDataset"
74 };
75 
76 constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
77     "ConcatenateDataset",
78     "ZipDataset"
79 };
80 
81 constexpr std::array<const char*, 31> kPassThroughOps = {
82     "_Retval",
83     "AssertNextDataset",
84     "BatchDataset",
85     "CacheDataset",
86     "ExperimentalMapAndBatchDataset",
87     "ExperimentalParseExampleDataset",
88     "ExperimentalRebatchDataset",
89     "FilterDataset",
90     "FinalizeDataset",
91     "Identity",
92     "MapAndBatchDataset",
93     "MapDataset",
94     "MaxIntraOpParallelismDataset",
95     "ModelDataset",
96     "OptimizeDataset",
97     "OptionsDataset",
98     "PaddedBatchDataset",
99     "ParallelBatchDataset",
100     "ParallelMapDataset",
101     "ParseExampleDataset",
102     "PrefetchDataset",
103     "PrivateThreadPoolDataset",
104     "ReduceDataset",
105     "RebatchDataset",
106     "RepeatDataset",
107     "ShardDataset",
108     "ShuffleAndRepeatDataset",
109     "ShuffleDataset",
110     "SkipDataset",
111     "TakeDataset",
112     "WindowDataset",
113 };
114 
115 // TODO(frankchn): Process functions within kFuncDatasetOps as well.
116 constexpr std::array<const char*, 5> kFuncDatasetOps = {
117     "ExperimentalParallelInterleaveDataset",
118     "FlatMapDataset",
119     "InterleaveDataset",
120     "LegacyParallelInterleaveDataset",
121     "ParallelInterleaveDataset",
122 };
123 
124 constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
125     "GeneratorDataset",
126     "RangeDataset",
127     "SparseTensorsSliceDataset",
128     "TensorDataset",
129     "TensorSliceDataset",
130 };
131 // clang-format on
132 
133 Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
134                      int64_t index, AutoShardPolicy policy,
135                      int64_t num_replicas, GraphDef* output,
136                      AutoShardPolicy* policy_applied);
137 
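// Returns true if `node`'s op matches any version of one of the dataset op
// names in `arr`.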
138 template <std::size_t SIZE>
139 bool IsDatasetNodeOfType(const NodeDef& node,
140                          const std::array<const char*, SIZE>& arr) {
141   for (const auto& dataset_op_name : arr) {
142     if (tensorflow::data::MatchesAnyVersion(/*op_prefix=*/dataset_op_name,
143                                             /*op_to_match=*/node.op())) {
144       return true;
145     }
146   }
147   return false;
148 }
149 
150 // Adds a ShardDataset node before `add_before`.
151 Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
152                     int64_t num_workers, int64_t index) {
153   NodeDef new_node;
154   new_node.set_op(kShardDatasetOpName);
155   graph_utils::SetUniqueGraphNodeName(kShardDatasetOpName, graph->graph(),
156                                       &new_node);
157 
158   // Construct argument nodes
159   NodeDef* num_shards_node =
160       graph_utils::AddScalarConstNode<int64>(num_workers, graph);
161   NodeDef* index_node = graph_utils::AddScalarConstNode<int64>(index, graph);
162 
163   // Add inputs to new node
164   new_node.add_input(add_before.input(0));
165   new_node.add_input(num_shards_node->name());
166   new_node.add_input(index_node->name());
167 
168   // Ensure that each shard will have at least one element.
169   (*(new_node.mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty].set_b(
170       true);
171 
172   // Add shapes and other attributes
173   NodeDef* add_after = graph->GetNode(add_before.input(0));
174 
175   if (absl::StrContains(add_after->op(), "Dataset")) {
176     // We still may or may not have the right attributes because datasets like
177     // TFRecordDataset don't have an output type or shape, and by default we
178     // set them to DT_STRING and an unknown shape.
179     if (add_after->attr().count(kOutputShapes) > 0) {
180       graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
181     } else {
182       tensorflow::TensorShapeProto* shape =
183           (*(new_node.mutable_attr()))[kOutputShapes]
184               .mutable_list()
185               ->add_shape();
186       shape->set_unknown_rank(true);
187     }
188 
189     if (add_after->attr().count(kOutputTypes) > 0) {
190       graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
191     } else if (add_after->attr().count("Toutput_types") > 0) {
192       (*(new_node.mutable_attr()))[kOutputTypes] =
193           add_after->attr().at("Toutput_types");
194     } else {
195       (*(new_node.mutable_attr()))[kOutputTypes].mutable_list()->add_type(
196           tensorflow::DataType::DT_STRING);
197     }
198   } else {
199     // TODO(frankchn): Make this work for datasets where input(0) is a Const,
200     // and we need to shard the Const.
201     // This is probably not a dataset, so we bail because we can't infer the
202     // output types and shape.
203     return errors::NotFound(
204         "Unable to shard this input. You may need to wrap the inputs to your "
205         "reader dataset in a TensorSliceDataset. Input node is ",
206         add_after->DebugString());
207   }
208 
209   // Add new node into graph and update edges
210   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
211   TF_RETURN_IF_ERROR(
212       graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
213 
214   return Status::OK();
215 }
216 
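// Adds a ShuffleDataset node before `add_before`, wired to the given buffer
// size and seed input nodes and carrying the given `reshuffle_each_iteration`
// attribute.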
217 Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before,
218                          const string& buffer_size_node,
219                          const string& seed_node, const string& seed2_node,
220                          bool reshuffle_each_iteration) {
221   NodeDef* add_after = graph->GetNode(add_before.input(0));
222   NodeDef new_node;
223   new_node.set_op(kShuffleDatasetOpName);
224   graph_utils::SetUniqueGraphNodeName(kShuffleDatasetOpName, graph->graph(),
225                                       &new_node);
226 
227   new_node.add_input(add_before.input(0));
228   new_node.add_input(buffer_size_node);
229   new_node.add_input(seed_node);
230   new_node.add_input(seed2_node);
231 
232   graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
233   graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
234 
235   AttrValue reshuffle_attr;
236   reshuffle_attr.set_b(reshuffle_each_iteration);
237   (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;
238 
239   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
240 
241   TF_RETURN_IF_ERROR(
242       graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
243   return Status::OK();
244 }
245 
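// Adds a ShuffleDatasetV2 node before `add_before`, wired to the given buffer
// size and seed generator input nodes.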
246 Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before,
247                            const string& buffer_size_node,
248                            const string& seed_generator_node) {
249   NodeDef* add_after = graph->GetNode(add_before.input(0));
250   NodeDef new_node;
251   new_node.set_op(kShuffleDatasetV2OpName);
252   graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV2OpName, graph->graph(),
253                                       &new_node);
254 
255   new_node.add_input(add_before.input(0));
256   new_node.add_input(buffer_size_node);
257   new_node.add_input(seed_generator_node);
258 
259   graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
260   graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
261 
262   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
263 
264   TF_RETURN_IF_ERROR(
265       graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
266   return Status::OK();
267 }
268 
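// Adds a ShuffleDatasetV3 node before `add_before`, wired to the given buffer
// size, seed, and seed generator input nodes.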
269 Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before,
270                            const string& buffer_size_node,
271                            const string& seed_node, const string& seed2_node,
272                            const string& seed_generator_node,
273                            bool reshuffle_each_iteration) {
274   NodeDef* add_after = graph->GetNode(add_before.input(0));
275   NodeDef new_node;
276   new_node.set_op(kShuffleDatasetV3OpName);
277   graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV3OpName, graph->graph(),
278                                       &new_node);
279 
280   new_node.add_input(add_before.input(0));
281   new_node.add_input(buffer_size_node);
282   new_node.add_input(seed_node);
283   new_node.add_input(seed2_node);
284   new_node.add_input(seed_generator_node);
285 
286   graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
287   graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
288 
289   AttrValue reshuffle_attr;
290   reshuffle_attr.set_b(reshuffle_each_iteration);
291   (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;
292 
293   NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
294 
295   TF_RETURN_IF_ERROR(
296       graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
297   return Status::OK();
298 }
299 
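// Returns true if the function `f` attached to `node` (or a nested function
// dataset op inside it) contains a reader dataset op that reads from the
// function's `args_0` argument.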
300 bool ReaderOpInFunction(const NodeDef& node,
301                         const FunctionLibraryDefinition& flib) {
302   const FunctionDef* func = flib.Find(node.attr().at("f").func().name());
303   for (int i = 0; i < func->node_def_size(); i++) {
304     NodeDef node_in_func = func->node_def(i);
305     if (IsDatasetNodeOfType(node_in_func, kReaderDatasetOps) &&
306         node_in_func.input_size() > 0 &&
307         absl::StartsWith(node_in_func.input(0), "args_0")) {
308       return true;
309     }
310     if (IsDatasetNodeOfType(func->node_def(i), kFuncDatasetOps) &&
311         ReaderOpInFunction(func->node_def(i), flib)) {
312       return true;
313     }
314   }
315   return false;
316 }
317 
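// Searches `node` and its fanins for a ShuffleDataset op. If one is found,
// records its inputs and `reshuffle_each_iteration` attribute, bypasses it by
// rewiring its fanouts to its input, and marks it for deletion.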
318 Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
319                             absl::flat_hash_set<string>* nodes_to_delete,
320                             string* op_name, string* buffer_size_node,
321                             string* seed_node, string* seed2_node,
322                             bool* reshuffle_each_iteration) {
323   if (node.op() == kShuffleDatasetOpName) {
324     *op_name = node.op();
325     *buffer_size_node = node.input(1);
326     *seed_node = node.input(2);
327     *seed2_node = node.input(3);
328     *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
329     TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
330     nodes_to_delete->insert(node.name());
331   }
332 
333   for (const auto& fanin : graph->GetFanins(node, true)) {
334     TF_RETURN_IF_ERROR(RemoveShuffleDataset(
335         graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
336         seed_node, seed2_node, reshuffle_each_iteration));
337   }
338 
339   // TODO(frankchn): Traverse functions too.
340   return Status::OK();
341 }
342 
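// Same as RemoveShuffleDataset, but for ShuffleDatasetV2 (which takes a seed
// generator input instead of explicit seeds).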
343 Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node,
344                               absl::flat_hash_set<string>* nodes_to_delete,
345                               string* op_name, string* buffer_size_node,
346                               string* seed_generator_node) {
347   if (node.op() == kShuffleDatasetV2OpName) {
348     *op_name = node.op();
349     *buffer_size_node = node.input(1);
350     *seed_generator_node = node.input(2);
351     TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
352     nodes_to_delete->insert(node.name());
353   }
354 
355   for (const auto& fanin : graph->GetFanins(node, true)) {
356     TF_RETURN_IF_ERROR(
357         RemoveShuffleDatasetV2(graph, *fanin.node, nodes_to_delete, op_name,
358                                buffer_size_node, seed_generator_node));
359   }
360 
361   // TODO(frankchn): Traverse functions too.
362   return Status::OK();
363 }
364 
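// Same as RemoveShuffleDataset, but for ShuffleDatasetV3 (which takes both
// explicit seeds and a seed generator input).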
365 Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node,
366                               absl::flat_hash_set<string>* nodes_to_delete,
367                               string* op_name, string* buffer_size_node,
368                               string* seed_node, string* seed2_node,
369                               string* seed_generator_node,
370                               bool* reshuffle_each_iteration) {
371   if (node.op() == kShuffleDatasetV3OpName) {
372     *op_name = node.op();
373     *buffer_size_node = node.input(1);
374     *seed_node = node.input(2);
375     *seed2_node = node.input(3);
376     *seed_generator_node = node.input(4);
377     *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
378     TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
379     nodes_to_delete->insert(node.name());
380   }
381 
382   for (const auto& fanin : graph->GetFanins(node, true)) {
383     TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
384         graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
385         seed_node, seed2_node, seed_generator_node, reshuffle_each_iteration));
386   }
387 
388   // TODO(frankchn): Traverse functions too.
389   return Status::OK();
390 }
391 
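// Shards the pipeline at `node`: inserts a ShardDataset before it, and if a
// ShuffleDataset (V1, V2, or V3) is found upstream, removes it and re-inserts
// an equivalent shuffle after the new ShardDataset.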
392 Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
393                                 absl::flat_hash_set<string>* nodes_to_delete,
394                                 int64_t num_workers, int64_t index) {
395   string shuffle_op_name = "";
396   string buffer_size_node = "";
397   string seed_node = "";
398   string seed2_node = "";
399   string seed_generator_node = "";
400   bool reshuffle_each_iteration;
401 
402   TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
403   TF_RETURN_IF_ERROR(RemoveShuffleDataset(
404       graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
405       &seed_node, &seed2_node, &reshuffle_each_iteration));
406   if (shuffle_op_name.empty()) {
407     TF_RETURN_IF_ERROR(
408         RemoveShuffleDatasetV2(graph, node, nodes_to_delete, &shuffle_op_name,
409                                &buffer_size_node, &seed_generator_node));
410   }
411   if (shuffle_op_name.empty()) {
412     TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
413         graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
414         &seed_node, &seed2_node, &seed_generator_node,
415         &reshuffle_each_iteration));
416   }
417 
418   if (shuffle_op_name == kShuffleDatasetOpName) {
419     TF_RETURN_IF_ERROR(AddShuffleDataset(graph, node, buffer_size_node,
420                                          seed_node, seed2_node,
421                                          reshuffle_each_iteration));
422   } else if (shuffle_op_name == kShuffleDatasetV2OpName) {
423     TF_RETURN_IF_ERROR(AddShuffleDatasetV2(graph, node, buffer_size_node,
424                                            seed_generator_node));
425   } else if (shuffle_op_name == kShuffleDatasetV3OpName) {
426     TF_RETURN_IF_ERROR(AddShuffleDatasetV3(
427         graph, node, buffer_size_node, seed_node, seed2_node,
428         seed_generator_node, reshuffle_each_iteration));
429   }
430 
431   return Status::OK();
432 }
433 
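// Walks up through pass-through ops starting at `node`, looking for a function
// dataset op (e.g. FlatMapDataset) whose input is a TensorDataset or
// TensorSliceDataset fed by a Placeholder. Returns that node, or nullptr if no
// such pattern is found.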
434 const NodeDef* FindFuncAndTensorSliceDataset(
435     const NodeDef* node, int64_t num_workers, int64_t index,
436     FunctionLibraryDefinition* flib, MutableGraphView* graph,
437     absl::flat_hash_set<string>* nodes_to_delete) {
438   if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
439     const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
440     if (input_node->op() == kTensorSliceDatasetOpName ||
441         input_node->op() == kTensorDatasetOpName) {
442       const NodeDef* next_input_node =
443           graph_utils::GetInputNode(*input_node, *graph, 0);
444       if (next_input_node->op() == kPlaceholderOpName) {
445         return node;
446       }
447     }
448   }
449 
450   if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
451     return nullptr;
452   }
453 
454   // Sometimes there are other nodes between the last InterleaveDataset and the
455   // second to last FlatMapDataset, so we need to skip over those.
456   const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
457   return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
458                                        graph, nodes_to_delete);
459 }
460 
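// Walks up the dataset graph from `node` looking for a place to shard,
// delegating to ProcessDatasetSourceNode once a reader dataset (or a function
// dataset containing one) is found, and returning NotFound if an unshardable
// source or a non-dataset node is reached.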
461 Status RecursivelyHandleOp(const NodeDef& node, int64_t num_workers,
462                            int64_t index, FunctionLibraryDefinition* flib,
463                            MutableGraphView* graph,
464                            absl::flat_hash_set<string>* nodes_to_delete) {
465   if (node.op() == kAssertCardinalityDatasetOpName) {
466     LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
467                     "handled by the auto-shard rewrite and will be removed.";
468     nodes_to_delete->insert(node.name());
469     TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
470     const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
471     return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
472                                nodes_to_delete);
473   }
474 
475   if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
476     return errors::NotFound("Found an unshardable source dataset: ",
477                             node.DebugString());
478   }
479 
480   if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
481     for (int i = 0; i < node.input_size(); ++i) {
482       const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
483       TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index,
484                                              flib, graph, nodes_to_delete));
485     }
486     return Status::OK();
487   }
488 
489   // This handles the case of the following subgraph:
490   //   Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
491   //   (other preprocessing datasets) -> InterleaveDataset
492   // by inserting the shard node immediately after the FlatMapDataset.
493   //
494   // This is used for some training pipelines where a dataset is created with
495   // the following code:
496   //
497   // def make_dataset_pipeline():
498   //   file_globs = [...]
499   //   datasets = []
500   //   for file_glob in file_globs:
501   //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
502   //   dataset = Dataset.from_tensor_slices(datasets)
503   //   dataset = dataset.flat_map(lambda x: x)
504   //   dataset = ...  # additional preprocessing
505   //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
506   //   return dataset
507   if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
508     const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
509     const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
510         input_node, num_workers, index, flib, graph, nodes_to_delete);
511 
512     if (flat_map_node != nullptr) {
513       auto fanouts = graph->GetFanouts(*flat_map_node, false);
514       // FlatMapDataset should only be the input to one other dataset.
515       if (fanouts.size() == 1) {
516         return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
517                                         nodes_to_delete, num_workers, index);
518       }
519     }
520   }
521 
522   // This handles the case where a reader Dataset is contained within a
523   // FuncDataset (e.g. FlatMap, ParallelInterleave, etc.). For example:
524   //
525   // dataset = Dataset.list_files("/path/to/data")
526   // dataset = dataset.flat_map(core_readers.TFRecordDataset)
527   //
528   // where the list of files is passed in one-by-one as an argument to the
529   // function in flat_map.
530   if (IsDatasetNodeOfType(node, kFuncDatasetOps) &&
531       ReaderOpInFunction(node, *flib)) {
532     return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
533                                     index);
534   }
535 
536   if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
537     // We reached a reader dataset directly, so we try to shard its input 0.
538     return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
539                                     index);
540   }
541 
542   if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
543     return errors::NotFound(
544         "Did not find a shardable source, walked to ",
545         "a node which is not a dataset: ", node.DebugString(),
546         ". Consider either turning off auto-sharding or switching the "
547         "auto_shard_policy to DATA to shard this dataset. You can do this by "
548         "creating a new `tf.data.Options()` object then setting "
549         "`options.experimental_distribute.auto_shard_policy = "
550         "AutoShardPolicy.DATA` before applying the options object to the "
551         "dataset via `dataset.with_options(options)`.");
552   }
553 
554   const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
555   return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
556                              nodes_to_delete);
557 }
558 
559 // Recursively walk the dataset graph from sink to source, searching for
560 // the first (i.e. closest to the sink) occurrence of a ReaderDataset, such as
561 // CSVDataset, TFRecordDataset, etc. We then insert a ShardDataset op before
562 // that node's input, so that each worker only reads a subset of files.
563 // Additionally, we remove sources of randomness (e.g. ShuffleDataset) that
564 // occur upstream of the ShardDataset transformation to ensure that sharding
565 // returns a sensible result.
566 Status ShardByFile(const NodeDef& sink_node, int64_t num_workers, int64_t index,
567                    FunctionLibraryDefinition* flib, MutableGraphView* graph) {
568   absl::flat_hash_set<string> nodes_to_delete;
569   TF_RETURN_IF_ERROR(RecursivelyHandleOp(sink_node, num_workers, index, flib,
570                                          graph, &nodes_to_delete));
571   return graph->DeleteNodes(nodes_to_delete);
572 }
573 
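// Rewrites the RebatchDatasetV2 node directly upstream of `sink_node`, if any,
// into the legacy RebatchDataset op, which sharding by data requires.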
574 Status RewriteRebatchV2ToV1(const NodeDef& sink_node, int64_t num_replicas,
575                             MutableGraphView* graph) {
576   // The final node before AutoShardDataset is RebatchDataset.
577   // This is always the case as RebatchDataset and AutoShardDataset are internal
578   // APIs used directly by tf.distribute's input_lib. As such, instead of
579   // walking the entire dataset graph, we can walk up directly from the
580   // sink_node to get the RebatchDataset.
581   NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
582   if (input_node->op() != kRebatchDatasetV2OpName) {
583     return Status::OK();
584   }
585 
586   NodeDef* rebatch_node = input_node;
587   // Update RebatchDatasetV2 in place. Since Rebatch is an internal API, no
588   // other nodes should have it as an input.
589   rebatch_node->set_op(kRebatchDatasetOpName);
590   // Delete the `batch_sizes` and `drop_remainder` input.
591   rebatch_node->mutable_input()->DeleteSubrange(/*start=*/1, /*num=*/2);
592   // Add the `num_replicas` input.
593   if (num_replicas < 1) {
594     return errors::InvalidArgument(
595         "Cannot rewrite RebatchDatasetV2 to legacy RebatchDataset with invalid "
596         "num_replicas argument. `num_replicas` is ",
597         num_replicas, ", but expected to be >= 1.");
598   }
599   auto num_replicas_node = graph_utils::AddScalarConstNode(num_replicas, graph);
600   rebatch_node->add_input(num_replicas_node->name());
601 
602   // Set `use_fallback` attr. This attr is not used anywhere, so its value
603   // does not matter.
604   (*rebatch_node->mutable_attr())["use_fallback"].set_b(true);
605 
606   // Update the output_shapes attr to set all its batch dimensions to -1
607   // (unknown).
608   auto* shapes_attr =
609       gtl::FindOrNull(*rebatch_node->mutable_attr(), "output_shapes");
610   if (shapes_attr == nullptr) {
611     return errors::InvalidArgument(
612         "Cannot rewrite RebatchDatasetV2 with missing `output_shapes` attr.");
613   }
614   for (int i = 0; i < shapes_attr->list().shape_size(); ++i) {
615     auto* shape = shapes_attr->mutable_list()->mutable_shape(i);
616     if (shape->unknown_rank()) continue;
617     shape->mutable_dim(0)->set_size(-1);
618   }
619 
620   return Status::OK();
621 }
622 
623 Status ShardByData(const NodeDef& sink_node, int64_t num_workers, int64_t index,
624                    int64_t num_replicas, MutableGraphView* graph) {
625   const NodeDef* shard_before = &sink_node;
626   // We sometimes insert a PrefetchDataset, OptionsDataset, and FinalizeDataset
627   // at the end of the input pipeline before autosharding. When sharding by
628   // data, we should insert the shard before these datasets so that the
629   // right number of elements is prefetched.
630   NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
631   while (input_node->op() == kPrefetchDatasetOpName ||
632          input_node->op() == kOptionsDatasetOpName ||
633          input_node->op() == kFinalizeDatasetOpName) {
634     shard_before = input_node;
635     input_node = graph_utils::GetInputNode(*input_node, *graph);
636   }
637   // Sharding by data only works with legacy RebatchDataset. As such, we rewrite
638   // all instances of RebatchDatasetV2 to RebatchDataset.
639   TF_RETURN_IF_ERROR(RewriteRebatchV2ToV1(*shard_before, num_replicas, graph));
640   return AddShardNode(graph, *shard_before, num_workers, index);
641 }
642 
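// Shards the pipeline by data: inserts a ShardDataset node immediately before
// the sink node (skipping over any trailing Prefetch/Options/Finalize
// datasets), so each worker keeps only its own subset of elements.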
643 // Searches the dataset graph, replacing any occurrence of `shard(1, 0)` with
644 // `shard(num_workers, index)`.
645 Status ShardByHint(const NodeDef& sink_node, int64_t num_workers, int64_t index,
646                    int64_t num_replicas, MutableGraphView* graph) {
647   auto get_shard_node = [graph](const NodeDef& node) -> const NodeDef* {
648     if (node.op() != kShardDatasetOpName) return nullptr;
649     auto num_workers_node = graph->GetNode(node.input(1));
650     if (num_workers_node->op() != kConstOpName) return nullptr;
651     if (num_workers_node->attr().at("value").tensor().int64_val(0) !=
652         tensorflow::data::kShardHint)
653       return nullptr;
654     return &node;
655   };
656 
657   auto* num_workers_node =
658       graph_utils::AddScalarConstNode(static_cast<int64>(num_workers), graph);
659   auto* worker_index_node =
660       graph_utils::AddScalarConstNode(static_cast<int64>(index), graph);
661 
662   for (const NodeDef& node : graph->graph()->node()) {
663     const NodeDef* shard_node = get_shard_node(node);
664     if (!shard_node) continue;
665     auto mutable_node = graph->GetNode(shard_node->name());
666     *mutable_node->mutable_input(1) = num_workers_node->name();
667     *mutable_node->mutable_input(2) = worker_index_node->name();
668     // Ensure that each shard will have at least one element.
669     (*(mutable_node->mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty]
670         .set_b(true);
671   }
672   return Status::OK();
673 }
674 
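// Rewrites the dataset graph in `item` according to `policy`, writing the
// result to `output` and the policy that was actually applied to
// `policy_applied`. With AUTO, FILE sharding is attempted first and DATA
// sharding is used as the fallback if no shardable file-based source is found.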
675 Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
676                      int64_t index, AutoShardPolicy policy,
677                      int64_t num_replicas, GraphDef* output,
678                      AutoShardPolicy* policy_applied) {
679   *policy_applied = policy;
680   if (policy == AutoShardPolicy::OFF ||
681       (policy == AutoShardPolicy::FILE && num_workers == 1 && index == 0)) {
682     return Status::OK();
683   }
684 
685   *output = item.graph;
686   MutableGraphView graph(output);
687   FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
688 
689   NodeDef* sink_node;
690   TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
691 
692   switch (policy) {
693     case AutoShardPolicy::OFF:
694       return Status::OK();
695     case AutoShardPolicy::FILE:
696       return ShardByFile(*sink_node, num_workers, index, &flib, &graph);
697     case AutoShardPolicy::DATA:
698       return ShardByData(*sink_node, num_workers, index, num_replicas, &graph);
699     case AutoShardPolicy::HINT:
700       return ShardByHint(*sink_node, num_workers, index, num_replicas, &graph);
701     case AutoShardPolicy::AUTO:
702     default:
703       Status s = ShardByFile(*sink_node, num_workers, index, &flib, &graph);
704       if (errors::IsNotFound(s)) {
705         LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy "
706                         "as it failed to apply FILE sharding policy because of "
707                         "the following reason: "
708                      << s.error_message();
709         *policy_applied = AutoShardPolicy::DATA;
710         return ShardByData(*sink_node, num_workers, index, num_replicas,
711                            &graph);
712       }
713       *policy_applied = AutoShardPolicy::FILE;
714       return s;
715   }
716 }
717 
718 }  // anonymous namespace
719 
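// Reads num_workers, index, auto_shard_policy, and num_replicas from the
// rewriter config and validates them.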
720 Status AutoShard::Init(
721     const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
722   if (!config) return errors::InvalidArgument("RewriterConfig not found.");
723 
724   if ((config->parameter_map().find(kNumWorkersAttrName) ==
725        config->parameter_map().end())) {
726     return errors::InvalidArgument(kNumWorkersAttrName, " parameter missing.");
727   }
728 
729   if ((config->parameter_map().find(kIndexAttrName) ==
730        config->parameter_map().end())) {
731     return errors::InvalidArgument(kIndexAttrName, " parameter missing.");
732   }
733 
734   num_workers_ = config->parameter_map().at(kNumWorkersAttrName).i();
735   index_ = config->parameter_map().at(kIndexAttrName).i();
736   auto_shard_policy_ =
737       AutoShardPolicy(config->parameter_map().at(kAutoShardPolicyAttrName).i());
738   num_replicas_ = config->parameter_map().at(kNumReplicasAttrName).i();
739 
740   if (auto_shard_policy_ != AutoShardPolicy::OFF &&
741       auto_shard_policy_ != AutoShardPolicy::AUTO &&
742       auto_shard_policy_ != AutoShardPolicy::DATA &&
743       auto_shard_policy_ != AutoShardPolicy::FILE &&
744       auto_shard_policy_ != AutoShardPolicy::HINT) {
745     return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");
746   }
747 
748   if (num_workers_ < 1) {
749     return errors::InvalidArgument(kNumWorkersAttrName,
750                                    " should be >= 1, currently ", num_workers_);
751   }
752 
753   if (index_ < 0 || index_ >= num_workers_) {
754     return errors::InvalidArgument(kIndexAttrName, " should be >= 0 and < ",
755                                    num_workers_, ", currently ", index_);
756   }
757 
758   if (num_replicas_ < 0) {
759     return errors::InvalidArgument(kNumReplicasAttrName, " should be >= 0");
760   }
761 
762   return Status::OK();
763 }
764 
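// Runs the auto-shard rewrite on `item` and, on the first shard only, records
// the applied policy via the tf.data metrics.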
765 Status AutoShard::OptimizeAndCollectStats(Cluster* cluster,
766                                           const GrapplerItem& item,
767                                           GraphDef* output,
768                                           OptimizationStats* stats) {
769   *output = item.graph;
770   AutoShardPolicy policy_applied;
771   TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_,
772                                    auto_shard_policy_, num_replicas_, output,
773                                    &policy_applied));
774 
775   // Only record on the first shard to avoid duplication.
776   if (index_ == 0) {
777     // item.id is always the same, so we use the address of the cluster as the id.
778     string id = strings::StrCat(reinterpret_cast<uint64>(cluster));
779     metrics::RecordTFDataAutoShard(id, policy_applied, num_workers_,
780                                    num_replicas_);
781   }
782   stats->num_changes++;
783   return Status::OK();
784 }
785 
786 REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard");
787 
788 }  // namespace grappler
789 }  // namespace tensorflow
790