/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/auto_shard.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/kernels/data/shard_dataset_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace grappler {
namespace {

using tensorflow::data::AutoShardPolicy;

constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
constexpr char kShardDatasetOpName[] = "ShardDataset";
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
constexpr char kFinalizeDatasetOpName[] = "FinalizeDataset";
constexpr char kOptionsDatasetOpName[] = "OptionsDataset";
constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
constexpr char kTensorDatasetOpName[] = "TensorDataset";
constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
constexpr char kPlaceholderOpName[] = "Placeholder";
constexpr char kConstOpName[] = "Const";

constexpr char kNumWorkersAttrName[] = "num_workers";
constexpr char kNumReplicasAttrName[] = "num_replicas";
constexpr char kIndexAttrName[] = "index";
constexpr char kAutoShardPolicyAttrName[] = "auto_shard_policy";
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";

// clang-format off
constexpr std::array<const char*, 5> kReaderDatasetOps = {
    "FixedLengthRecordDataset",
    "RecordIODataset",
    "SSTableDataset",
    "TextLineDataset",
    "TFRecordDataset"
};

constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
    "ConcatenateDataset",
    "ZipDataset"
};

constexpr std::array<const char*, 31> kPassThroughOps = {
    "_Retval",
    "AssertNextDataset",
    "BatchDataset",
    "CacheDataset",
    "ExperimentalMapAndBatchDataset",
    "ExperimentalParseExampleDataset",
    "ExperimentalRebatchDataset",
    "FilterDataset",
    "FinalizeDataset",
    "Identity",
    "MapAndBatchDataset",
    "MapDataset",
    "MaxIntraOpParallelismDataset",
    "ModelDataset",
    "OptimizeDataset",
    "OptionsDataset",
    "PaddedBatchDataset",
    "ParallelBatchDataset",
    "ParallelMapDataset",
    "ParseExampleDataset",
    "PrefetchDataset",
    "PrivateThreadPoolDataset",
    "ReduceDataset",
    "RebatchDataset",
    "RepeatDataset",
    "ShardDataset",
    "ShuffleAndRepeatDataset",
    "ShuffleDataset",
    "SkipDataset",
    "TakeDataset",
    "WindowDataset",
};

// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 5> kFuncDatasetOps = {
    "ExperimentalParallelInterleaveDataset",
    "FlatMapDataset",
    "InterleaveDataset",
    "LegacyParallelInterleaveDataset",
    "ParallelInterleaveDataset",
};

constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
    "GeneratorDataset",
    "RangeDataset",
    "SparseTensorsSliceDataset",
    "TensorDataset",
    "TensorSliceDataset",
};
// clang-format on

Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
                     int64_t index, AutoShardPolicy policy,
                     int64_t num_replicas, GraphDef* output,
                     AutoShardPolicy* policy_applied);
template <std::size_t SIZE>
bool IsDatasetNodeOfType(const NodeDef& node,
                         const std::array<const char*, SIZE>& arr) {
  for (const auto& dataset_op_name : arr) {
    if (tensorflow::data::MatchesAnyVersion(/*op_prefix=*/dataset_op_name,
                                            /*op_to_match=*/node.op())) {
      return true;
    }
  }
  return false;
}

// Adds a ShardDataset node before `add_before`.
Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
                    int64_t num_workers, int64_t index) {
  NodeDef new_node;
  new_node.set_op(kShardDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShardDatasetOpName, graph->graph(),
                                      &new_node);

  // Construct argument nodes
  NodeDef* num_shards_node =
      graph_utils::AddScalarConstNode<int64>(num_workers, graph);
  NodeDef* index_node = graph_utils::AddScalarConstNode<int64>(index, graph);

  // Add inputs to new node
  new_node.add_input(add_before.input(0));
  new_node.add_input(num_shards_node->name());
  new_node.add_input(index_node->name());

  // Ensure that each shard will have at least one element.
  (*(new_node.mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty].set_b(
      true);

  // Add shapes and other attributes
  NodeDef* add_after = graph->GetNode(add_before.input(0));

  if (absl::StrContains(add_after->op(), "Dataset")) {
    // We still may not have the right attributes here, because datasets like
    // TFRecordDataset don't carry an output type or shape; in that case we
    // default to DT_STRING and an unknown shape.
    if (add_after->attr().count(kOutputShapes) > 0) {
      graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
    } else {
      tensorflow::TensorShapeProto* shape =
          (*(new_node.mutable_attr()))[kOutputShapes]
              .mutable_list()
              ->add_shape();
      shape->set_unknown_rank(true);
    }

    if (add_after->attr().count(kOutputTypes) > 0) {
      graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
    } else if (add_after->attr().count("Toutput_types") > 0) {
      (*(new_node.mutable_attr()))[kOutputTypes] =
          add_after->attr().at("Toutput_types");
    } else {
      (*(new_node.mutable_attr()))[kOutputTypes].mutable_list()->add_type(
          tensorflow::DataType::DT_STRING);
    }
  } else {
    // TODO(frankchn): Make this work for datasets where input(0) is a Const,
    // and we need to shard the Const.
    // This is probably not a dataset, so we bail because we can't infer the
    // output types and shape.
    return errors::NotFound(
        "Unable to shard this input. You may need to wrap the inputs to your "
        "reader dataset in a TensorSliceDataset. Input node is ",
        add_after->DebugString());
  }

  // Add new node into graph and update edges
  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));

  return Status::OK();
}

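// Re-inserts a ShuffleDataset node before `add_before`, wiring it to the given
// buffer size and seed inputs and restoring the reshuffle_each_iteration
// setting. The V2/V3 variants below do the same for the corresponding op
// versions (seed generator resource vs. explicit seeds).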
Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before,
                         const string& buffer_size_node,
                         const string& seed_node, const string& seed2_node,
                         bool reshuffle_each_iteration) {
  NodeDef* add_after = graph->GetNode(add_before.input(0));
  NodeDef new_node;
  new_node.set_op(kShuffleDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetOpName, graph->graph(),
                                      &new_node);

  new_node.add_input(add_before.input(0));
  new_node.add_input(buffer_size_node);
  new_node.add_input(seed_node);
  new_node.add_input(seed2_node);

  graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
  graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

  AttrValue reshuffle_attr;
  reshuffle_attr.set_b(reshuffle_each_iteration);
  (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;

  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
  return Status::OK();
}

Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before,
                           const string& buffer_size_node,
                           const string& seed_generator_node) {
  NodeDef* add_after = graph->GetNode(add_before.input(0));
  NodeDef new_node;
  new_node.set_op(kShuffleDatasetV2OpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV2OpName, graph->graph(),
                                      &new_node);

  new_node.add_input(add_before.input(0));
  new_node.add_input(buffer_size_node);
  new_node.add_input(seed_generator_node);

  graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
  graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
  return Status::OK();
}

Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before,
                           const string& buffer_size_node,
                           const string& seed_node, const string& seed2_node,
                           const string& seed_generator_node,
                           bool reshuffle_each_iteration) {
  NodeDef* add_after = graph->GetNode(add_before.input(0));
  NodeDef new_node;
  new_node.set_op(kShuffleDatasetV3OpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV3OpName, graph->graph(),
                                      &new_node);

  new_node.add_input(add_before.input(0));
  new_node.add_input(buffer_size_node);
  new_node.add_input(seed_node);
  new_node.add_input(seed2_node);
  new_node.add_input(seed_generator_node);

  graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
  graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);

  AttrValue reshuffle_attr;
  reshuffle_attr.set_b(reshuffle_each_iteration);
  (*new_node.mutable_attr())[kReshuffleEachIteration] = reshuffle_attr;

  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
  return Status::OK();
}

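// Returns true if the function attached to `node` (attr "f"), or any function
// dataset nested inside it, contains a reader dataset op whose first input is
// the function argument (i.e. the filenames are passed in from the caller).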
bool ReaderOpInFunction(const NodeDef& node,
                        const FunctionLibraryDefinition& flib) {
  const FunctionDef* func = flib.Find(node.attr().at("f").func().name());
  for (int i = 0; i < func->node_def_size(); i++) {
    NodeDef node_in_func = func->node_def(i);
    if (IsDatasetNodeOfType(node_in_func, kReaderDatasetOps) &&
        node_in_func.input_size() > 0 &&
        absl::StartsWith(node_in_func.input(0), "args_0")) {
      return true;
    }
    if (IsDatasetNodeOfType(func->node_def(i), kFuncDatasetOps) &&
        ReaderOpInFunction(func->node_def(i), flib)) {
      return true;
    }
  }
  return false;
}

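// Walks upstream from `node` looking for a ShuffleDataset. If one is found,
// its buffer size, seeds, and reshuffle_each_iteration attribute are recorded
// so that an equivalent shuffle can be re-added after the ShardDataset, and
// the original shuffle node is spliced out of the graph and marked for
// deletion. The V2/V3 variants below handle the corresponding op versions.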
Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
                            absl::flat_hash_set<string>* nodes_to_delete,
                            string* op_name, string* buffer_size_node,
                            string* seed_node, string* seed2_node,
                            bool* reshuffle_each_iteration) {
  if (node.op() == kShuffleDatasetOpName) {
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(RemoveShuffleDataset(
        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, reshuffle_each_iteration));
  }

  // TODO(frankchn): Traverse functions too.
  return Status::OK();
}

Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_generator_node) {
  if (node.op() == kShuffleDatasetV2OpName) {
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_generator_node = node.input(2);
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(
        RemoveShuffleDatasetV2(graph, *fanin.node, nodes_to_delete, op_name,
                               buffer_size_node, seed_generator_node));
  }

  // TODO(frankchn): Traverse functions too.
  return Status::OK();
}

Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_node, string* seed2_node,
                              string* seed_generator_node,
                              bool* reshuffle_each_iteration) {
  if (node.op() == kShuffleDatasetV3OpName) {
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *seed_generator_node = node.input(4);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  for (const auto& fanin : graph->GetFanins(node, true)) {
    TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
        graph, *fanin.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, seed_generator_node, reshuffle_each_iteration));
  }

  // TODO(frankchn): Traverse functions too.
  return Status::OK();
}

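// Shards the pipeline at `node`: inserts a ShardDataset before it, removes any
// upstream ShuffleDataset (V1/V2/V3), and re-adds an equivalent shuffle after
// the shard so that shuffling happens on already-sharded data.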
Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
                                absl::flat_hash_set<string>* nodes_to_delete,
                                int64_t num_workers, int64_t index) {
  string shuffle_op_name = "";
  string buffer_size_node = "";
  string seed_node = "";
  string seed2_node = "";
  string seed_generator_node = "";
  bool reshuffle_each_iteration;

  TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
  TF_RETURN_IF_ERROR(RemoveShuffleDataset(
      graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
      &seed_node, &seed2_node, &reshuffle_each_iteration));
  if (shuffle_op_name.empty()) {
    TF_RETURN_IF_ERROR(
        RemoveShuffleDatasetV2(graph, node, nodes_to_delete, &shuffle_op_name,
                               &buffer_size_node, &seed_generator_node));
  }
  if (shuffle_op_name.empty()) {
    TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
        graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
        &seed_node, &seed2_node, &seed_generator_node,
        &reshuffle_each_iteration));
  }

  if (shuffle_op_name == kShuffleDatasetOpName) {
    TF_RETURN_IF_ERROR(AddShuffleDataset(graph, node, buffer_size_node,
                                         seed_node, seed2_node,
                                         reshuffle_each_iteration));
  } else if (shuffle_op_name == kShuffleDatasetV2OpName) {
    TF_RETURN_IF_ERROR(AddShuffleDatasetV2(graph, node, buffer_size_node,
                                           seed_generator_node));
  } else if (shuffle_op_name == kShuffleDatasetV3OpName) {
    TF_RETURN_IF_ERROR(AddShuffleDatasetV3(
        graph, node, buffer_size_node, seed_node, seed2_node,
        seed_generator_node, reshuffle_each_iteration));
  }

  return Status::OK();
}

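// Walks up through pass-through datasets starting at `node`, looking for a
// function dataset (e.g. FlatMapDataset) whose input is a TensorDataset or
// TensorSliceDataset fed by a Placeholder. Returns that function dataset, or
// nullptr if no such pattern is found.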
const NodeDef* FindFuncAndTensorSliceDataset(
    const NodeDef* node, int64_t num_workers, int64_t index,
    FunctionLibraryDefinition* flib, MutableGraphView* graph,
    absl::flat_hash_set<string>* nodes_to_delete) {
  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
    const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
    if (input_node->op() == kTensorSliceDatasetOpName ||
        input_node->op() == kTensorDatasetOpName) {
      const NodeDef* next_input_node =
          graph_utils::GetInputNode(*input_node, *graph, 0);
      if (next_input_node->op() == kPlaceholderOpName) {
        return node;
      }
    }
  }

  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
    return nullptr;
  }

  // Sometimes there are other nodes between the last InterleaveDataset and the
  // second to last FlatMapDataset, so we need to skip over those.
  const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
  return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
                                       graph, nodes_to_delete);
}

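// Walks the dataset graph from `node` toward its sources, deciding where to
// insert the ShardDataset for FILE-based sharding. Recurses through
// pass-through and multi-input datasets, shards at reader datasets (or at
// function datasets that wrap readers), and returns a NotFound error if it
// reaches an unshardable source.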
Status RecursivelyHandleOp(const NodeDef& node, int64_t num_workers,
                           int64_t index, FunctionLibraryDefinition* flib,
                           MutableGraphView* graph,
                           absl::flat_hash_set<string>* nodes_to_delete) {
  if (node.op() == kAssertCardinalityDatasetOpName) {
    LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
                    "handled by the auto-shard rewrite and will be removed.";
    nodes_to_delete->insert(node.name());
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                               nodes_to_delete);
  }

  if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
    return errors::NotFound("Found an unshardable source dataset: ",
                            node.DebugString());
  }

  if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
    for (int i = 0; i < node.input_size(); ++i) {
      const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
      TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index,
                                             flib, graph, nodes_to_delete));
    }
    return Status::OK();
  }

  // This handles the case for the following subgraph:
  // Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
  // (other preprocessing datasets) -> InterleaveDataset
  // and then inserting the shard node immediately after the FlatMapDataset.
  //
  // This is used for some training pipelines where a dataset is created with
  // the following code:
  //
  // def make_dataset_pipeline():
  //   file_globs = [...]
  //   datasets = []
  //   for file_glob in file_globs:
  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
  //   dataset = Dataset.from_tensor_slices(datasets)
  //   dataset = dataset.flat_map(lambda x: x)
  //   dataset = ...  # additional preprocessing
  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
  //   return dataset
  if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
        input_node, num_workers, index, flib, graph, nodes_to_delete);

    if (flat_map_node != nullptr) {
      auto fanouts = graph->GetFanouts(*flat_map_node, false);
      // FlatMapDataset should only be the input to one other dataset.
      if (fanouts.size() == 1) {
        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
                                        nodes_to_delete, num_workers, index);
      }
    }
  }

  // This handles the case where a reader Dataset is contained within a
  // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
  //
  // dataset = Dataset.list_files("/path/to/data")
  // dataset = dataset.flat_map(core_readers.TFRecordDataset)
  //
  // where the list of files is passed in one-by-one as an argument to the
  // function in flat_map.
  if (IsDatasetNodeOfType(node, kFuncDatasetOps) &&
      ReaderOpInFunction(node, *flib)) {
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
    // We reached a reader dataset directly and we try to shard input 0.
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
    return errors::NotFound(
        "Did not find a shardable source, walked to ",
        "a node which is not a dataset: ", node.DebugString(),
        ". Consider either turning off auto-sharding or switching the "
        "auto_shard_policy to DATA to shard this dataset. You can do this by "
        "creating a new `tf.data.Options()` object then setting "
        "`options.experimental_distribute.auto_shard_policy = "
        "AutoShardPolicy.DATA` before applying the options object to the "
        "dataset via `dataset.with_options(options)`.");
  }

  const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
  return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                             nodes_to_delete);
}

// Recursively walks the dataset graph from sink to source, searching for
// the first (i.e. closest to the sink) occurrence of a ReaderDataset, such as
// CSVDataset, TFRecordDataset, etc. We then insert a ShardDataset op before
// that node's input, so that each worker only reads a subset of files.
// Additionally, we remove sources of randomness (e.g. ShuffleDataset) that
// occur upstream of the ShardDataset transformation to ensure that sharding
// returns a sensible result.
Status ShardByFile(const NodeDef& sink_node, int64_t num_workers, int64_t index,
                   FunctionLibraryDefinition* flib, MutableGraphView* graph) {
  absl::flat_hash_set<string> nodes_to_delete;
  TF_RETURN_IF_ERROR(RecursivelyHandleOp(sink_node, num_workers, index, flib,
                                         graph, &nodes_to_delete));
  return graph->DeleteNodes(nodes_to_delete);
}

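// Rewrites a trailing RebatchDatasetV2 (if present) into the legacy
// RebatchDataset op, which is the form that data-based sharding supports.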
Status RewriteRebatchV2ToV1(const NodeDef& sink_node, int64_t num_replicas,
                            MutableGraphView* graph) {
  // The final node before AutoShardDataset is RebatchDataset.
  // This is always the case as RebatchDataset and AutoShardDataset are internal
  // APIs used directly by tf.distribute's input_lib. As such, instead of
  // walking the entire dataset graph, we can walk up directly from the
  // sink_node to get the RebatchDataset.
  NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
  if (input_node->op() != kRebatchDatasetV2OpName) {
    return Status::OK();
  }

  NodeDef* rebatch_node = input_node;
  // Update RebatchDatasetV2 in place. Since Rebatch is an internal API, no
  // other nodes should have it as an input.
  rebatch_node->set_op(kRebatchDatasetOpName);
  // Delete the `batch_sizes` and `drop_remainder` input.
  rebatch_node->mutable_input()->DeleteSubrange(/*start=*/1, /*num=*/2);
  // Add the `num_replicas` input.
  if (num_replicas < 1) {
    return errors::InvalidArgument(
        "Cannot rewrite RebatchDatasetV2 to legacy RebatchDataset with invalid "
        "num_replicas argument. `num_replicas` is ",
        num_replicas, ", but expected to be >= 1.");
  }
  auto num_replicas_node = graph_utils::AddScalarConstNode(num_replicas, graph);
  rebatch_node->add_input(num_replicas_node->name());

  // Set `use_fallback` attr. This attr is not used anywhere, so its value
  // does not matter.
  (*rebatch_node->mutable_attr())["use_fallback"].set_b(true);

  // Update the output_shapes attr to set all its batch dimensions to -1
  // (unknown).
  auto* shapes_attr =
      gtl::FindOrNull(*rebatch_node->mutable_attr(), "output_shapes");
  if (shapes_attr == nullptr) {
    return errors::InvalidArgument(
        "Cannot rewrite RebatchDatasetV2 with missing `output_shapes` attr.");
  }
  for (int i = 0; i < shapes_attr->list().shape_size(); ++i) {
    auto* shape = shapes_attr->mutable_list()->mutable_shape(i);
    if (shape->unknown_rank()) continue;
    shape->mutable_dim(0)->set_size(-1);
  }

  return Status::OK();
}

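// Shards the pipeline by data: inserts a ShardDataset at the end of the
// pipeline (before any trailing Prefetch/Options/Finalize datasets), so every
// worker runs the full pipeline but keeps only the elements assigned to its
// `index`.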
Status ShardByData(const NodeDef& sink_node, int64_t num_workers, int64_t index,
                   int64_t num_replicas, MutableGraphView* graph) {
  const NodeDef* shard_before = &sink_node;
  // We sometimes insert a PrefetchDataset, OptionsDataset, and FinalizeDataset
  // at the end of the input pipeline before autosharding. When sharding by
  // data, we should insert the shard before these datasets so that the
  // right number of elements is prefetched.
  NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
  while (input_node->op() == kPrefetchDatasetOpName ||
         input_node->op() == kOptionsDatasetOpName ||
         input_node->op() == kFinalizeDatasetOpName) {
    shard_before = input_node;
    input_node = graph_utils::GetInputNode(*input_node, *graph);
  }
  // Sharding by data only works with legacy RebatchDataset. As such, we rewrite
  // all instances of RebatchDatasetV2 to RebatchDataset.
  TF_RETURN_IF_ERROR(RewriteRebatchV2ToV1(*shard_before, num_replicas, graph));
  return AddShardNode(graph, *shard_before, num_workers, index);
}

// Searches the dataset graph for ShardDataset nodes whose `num_shards` input
// is the SHARD_HINT sentinel and rewrites them to `shard(num_workers, index)`.
Status ShardByHint(const NodeDef& sink_node, int64_t num_workers, int64_t index,
                   int64_t num_replicas, MutableGraphView* graph) {
  auto get_shard_node = [graph](const NodeDef& node) -> const NodeDef* {
    if (node.op() != kShardDatasetOpName) return nullptr;
    auto num_workers_node = graph->GetNode(node.input(1));
    if (num_workers_node->op() != kConstOpName) return nullptr;
    if (num_workers_node->attr().at("value").tensor().int64_val(0) !=
        tensorflow::data::kShardHint)
      return nullptr;
    return &node;
  };

  auto* num_workers_node =
      graph_utils::AddScalarConstNode(static_cast<int64>(num_workers), graph);
  auto* worker_index_node =
      graph_utils::AddScalarConstNode(static_cast<int64>(index), graph);

  for (const NodeDef& node : graph->graph()->node()) {
    const NodeDef* shard_node = get_shard_node(node);
    if (!shard_node) continue;
    auto mutable_node = graph->GetNode(shard_node->name());
    *mutable_node->mutable_input(1) = num_workers_node->name();
    *mutable_node->mutable_input(2) = worker_index_node->name();
    // Ensure that each shard will have at least one element.
    (*(mutable_node->mutable_attr()))[data::ShardDatasetOp::kRequireNonEmpty]
        .set_b(true);
  }
  return Status::OK();
}

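// Applies the requested auto-shard policy to the dataset graph in `item`,
// writing the rewritten graph to `output` and recording in `policy_applied`
// which policy was actually used (AUTO falls back from FILE to DATA when no
// shardable file-based source is found).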
Status OptimizeGraph(const GrapplerItem& item, int64_t num_workers,
                     int64_t index, AutoShardPolicy policy,
                     int64_t num_replicas, GraphDef* output,
                     AutoShardPolicy* policy_applied) {
  *policy_applied = policy;
  if (policy == AutoShardPolicy::OFF ||
      (policy == AutoShardPolicy::FILE && num_workers == 1 && index == 0)) {
    return Status::OK();
  }

  *output = item.graph;
  MutableGraphView graph(output);
  FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());

  NodeDef* sink_node;
  TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));

  switch (policy) {
    case AutoShardPolicy::OFF:
      return Status::OK();
    case AutoShardPolicy::FILE:
      return ShardByFile(*sink_node, num_workers, index, &flib, &graph);
    case AutoShardPolicy::DATA:
      return ShardByData(*sink_node, num_workers, index, num_replicas, &graph);
    case AutoShardPolicy::HINT:
      return ShardByHint(*sink_node, num_workers, index, num_replicas, &graph);
    case AutoShardPolicy::AUTO:
    default:
      Status s = ShardByFile(*sink_node, num_workers, index, &flib, &graph);
      if (errors::IsNotFound(s)) {
        LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy "
                        "as it failed to apply FILE sharding policy because of "
                        "the following reason: "
                     << s.error_message();
        *policy_applied = AutoShardPolicy::DATA;
        return ShardByData(*sink_node, num_workers, index, num_replicas,
                           &graph);
      }
      *policy_applied = AutoShardPolicy::FILE;
      return s;
  }
}

}  // anonymous namespace

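// Reads the num_workers, index, auto_shard_policy, and num_replicas parameters
// from the rewriter config and validates them.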
Status AutoShard::Init(
    const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
  if (!config) return errors::InvalidArgument("RewriterConfig not found.");

  if ((config->parameter_map().find(kNumWorkersAttrName) ==
       config->parameter_map().end())) {
    return errors::InvalidArgument(kNumWorkersAttrName, " parameter missing.");
  }

  if ((config->parameter_map().find(kIndexAttrName) ==
       config->parameter_map().end())) {
    return errors::InvalidArgument(kIndexAttrName, " parameter missing.");
  }

  num_workers_ = config->parameter_map().at(kNumWorkersAttrName).i();
  index_ = config->parameter_map().at(kIndexAttrName).i();
  auto_shard_policy_ =
      AutoShardPolicy(config->parameter_map().at(kAutoShardPolicyAttrName).i());
  num_replicas_ = config->parameter_map().at(kNumReplicasAttrName).i();

  if (auto_shard_policy_ != AutoShardPolicy::OFF &&
      auto_shard_policy_ != AutoShardPolicy::AUTO &&
      auto_shard_policy_ != AutoShardPolicy::DATA &&
      auto_shard_policy_ != AutoShardPolicy::FILE &&
      auto_shard_policy_ != AutoShardPolicy::HINT) {
    return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");
  }

  if (num_workers_ < 1) {
    return errors::InvalidArgument(kNumWorkersAttrName,
                                   " should be >= 1, currently ", num_workers_);
  }

  if (index_ < 0 || index_ >= num_workers_) {
    return errors::InvalidArgument(kIndexAttrName, " should be >= 0 and < ",
                                   num_workers_, ", currently ", index_);
  }

  if (num_replicas_ < 0) {
    return errors::InvalidArgument(kNumReplicasAttrName, " should be >= 0");
  }

  return Status::OK();
}

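// Entry point for the optimizer: runs OptimizeGraph with the configured
// parameters and, on the first shard only, records the applied policy via the
// tf.data metrics.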
Status AutoShard::OptimizeAndCollectStats(Cluster* cluster,
                                          const GrapplerItem& item,
                                          GraphDef* output,
                                          OptimizationStats* stats) {
  *output = item.graph;
  AutoShardPolicy policy_applied;
  TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_,
                                   auto_shard_policy_, num_replicas_, output,
                                   &policy_applied));

  // Only record on the first shard to avoid duplication.
  if (index_ == 0) {
    // item.id is always the same so we use the address of the cluster as id.
    string id = strings::StrCat(reinterpret_cast<uint64>(cluster));
    metrics::RecordTFDataAutoShard(id, policy_applied, num_workers_,
                                   num_replicas_);
  }
  stats->num_changes++;
  return Status::OK();
}

REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard");

}  // namespace grappler
}  // namespace tensorflow