/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

static mutex* get_dataset_experiment_registry_lock() {
  static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED);
  return &dataset_experiment_registry_lock;
}

static absl::flat_hash_map<string, int64>* get_dataset_experiments() {
  static absl::flat_hash_map<string, int64>* experiments =
      new absl::flat_hash_map<string, int64>;
  return experiments;
}

// Use the "Opt" suffix so that these names are not confused with the enums in
// the Options proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";

void DefaultOptimizationGraphRewrites(
    const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
    absl::flat_hash_set<tstring>* optimization_disabled,
    absl::flat_hash_set<tstring>* optimization_default) {
  const auto& optimization_options = options.optimization_options();
  if (optimization_options.optional_apply_default_optimizations_case() !=
          OptimizationOptions::kApplyDefaultOptimizations ||
      optimization_options.apply_default_optimizations()) {
    if (optimization_options.optional_map_and_batch_fusion_case() !=
        OptimizationOptions::kMapAndBatchFusion) {
      optimization_default->insert(kMapAndBatchFusionOpt);
    }
    if (optimization_options.optional_noop_elimination_case() !=
        OptimizationOptions::kNoopElimination) {
      optimization_default->insert(kNoopEliminationOpt);
    }
    if (optimization_options.optional_map_parallelization_case() !=
        OptimizationOptions::kMapParallelization) {
      optimization_default->insert(kMapParallelizationOpt);
    }
    if (optimization_options.optional_shuffle_and_repeat_fusion_case() !=
        OptimizationOptions::kShuffleAndRepeatFusion) {
      optimization_default->insert(kShuffleAndRepeatFusionOpt);
    }
  }
  if (optimization_options.optional_filter_fusion_case() ==
      OptimizationOptions::kFilterFusion) {
    if (optimization_options.filter_fusion()) {
      optimization_enabled->insert(kFilterFusionOpt);
    } else {
      optimization_disabled->insert(kFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_batch_fusion_case() ==
      OptimizationOptions::kMapAndBatchFusion) {
    if (optimization_options.map_and_batch_fusion()) {
      optimization_enabled->insert(kMapAndBatchFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndBatchFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_filter_fusion_case() ==
      OptimizationOptions::kMapAndFilterFusion) {
    if (optimization_options.map_and_filter_fusion()) {
      optimization_enabled->insert(kMapAndFilterFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_parallelization_case() ==
      OptimizationOptions::kMapParallelization) {
    if (optimization_options.map_parallelization()) {
      optimization_enabled->insert(kMapParallelizationOpt);
    } else {
      optimization_disabled->insert(kMapParallelizationOpt);
    }
  }
  if (optimization_options.optional_map_fusion_case() ==
      OptimizationOptions::kMapFusion) {
    if (optimization_options.map_fusion()) {
      optimization_enabled->insert(kMapFusionOpt);
    } else {
      optimization_disabled->insert(kMapFusionOpt);
    }
  }
  if (optimization_options.optional_noop_elimination_case() ==
      OptimizationOptions::kNoopElimination) {
    if (optimization_options.noop_elimination()) {
      optimization_enabled->insert(kNoopEliminationOpt);
    } else {
      optimization_disabled->insert(kNoopEliminationOpt);
    }
  }
  if (optimization_options.optional_parallel_batch_case() ==
      OptimizationOptions::kParallelBatch) {
    if (optimization_options.parallel_batch()) {
      optimization_enabled->insert(kParallelBatchOpt);
    } else {
      optimization_disabled->insert(kParallelBatchOpt);
    }
  }
  if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
      OptimizationOptions::kShuffleAndRepeatFusion) {
    if (optimization_options.shuffle_and_repeat_fusion()) {
      optimization_enabled->insert(kShuffleAndRepeatFusionOpt);
    } else {
      optimization_disabled->insert(kShuffleAndRepeatFusionOpt);
    }
  }
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}
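
// Illustrative sketch of the heuristic above, using a hypothetical op name
// ("FooDataset" is not a real op):
//
//   OpDef op_def;
//   op_def.set_name("FooDataset");                   // matches the suffix
//   op_def.add_output_arg()->set_type(DT_VARIANT);   // single variant output
//   // IsOpAllowlisted(&op_def) == true: treated as a stateless source
//   // dataset op even if is_stateful is set.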

}  // namespace

std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}
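
// Example: the all-zero seed pair requests fresh randomness; any other pair
// is passed through unchanged.
//
//   auto fresh = MaybeOverrideSeeds({0, 0});  // two newly drawn random seeds
//   auto fixed = MaybeOverrideSeeds({1, 2});  // returns {1, 2} as-is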

Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return Status::OK();
}
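
// Illustrative use of the overloads above, with made-up values:
//
//   DataTypeVector expected = {DT_INT64, DT_STRING};
//   DataTypeVector received = {DT_INT64, DT_FLOAT};
//   Status s = VerifyTypesMatch(expected, received);
//   // `s` is an InvalidArgument error naming component 1. A size mismatch
//   // between the two vectors is reported before any per-component check.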

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return Status::OK();
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}
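
// Note on the overloads above: a function already present in `base` is
// replaced only if the incoming definition has an identical signature;
// otherwise the merge fails with InvalidArgument. A hypothetical caller:
//
//   FunctionDefLibrary to_add;  // populated elsewhere
//   TF_RETURN_IF_ERROR(AddToFunctionLibrary(&flib_def, to_add));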

Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return Status::OK();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return Status::OK();
}

Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return Status::OK();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return Status::OK();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return Status::OK();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}
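
// The check above is structural: "If" and "While" nodes are only as stateful
// as their branch/body functions, while "Assert" and allowlisted ops are
// always treated as stateless. For example, a node running a stateful op such
// as "RandomUniform" (assuming it is not allowlisted) would produce
// errors::FailedPrecondition("RandomUniform is stateful.").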

std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}
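
// Minimal usage sketch for the wrapper above (the inline runner is a
// hypothetical stand-in for a real thread-pool runner):
//
//   auto inline_runner = [](std::function<void()> fn) { fn(); };
//   auto limited = RunnerWithMaxParallelism(inline_runner,
//                                           /*max_parallelism=*/2);
//   limited([] {
//     // Runs under ScopedPerThreadMaxParallelism(2).
//   });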

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return Status::OK();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}
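
// Round-trip sketch for the helpers above:
//
//   DeterminismPolicy policy;
//   TF_RETURN_IF_ERROR(DeterminismPolicy::FromString(
//       DeterminismPolicy::kDeterministic, &policy));
//   // policy.String() == DeterminismPolicy::kDeterministic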

bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}
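
// Examples of the matching rule above:
//
//   MatchesAnyVersion("BatchDataset", "BatchDataset")    // true: exact match
//   MatchesAnyVersion("BatchDataset", "BatchDatasetV2")  // true: 'V' + digits
//   MatchesAnyVersion("BatchDataset", "BatchDatasetXY")  // false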

absl::flat_hash_set<string> GetExperiments() {
  return GetExperiments(port::JobName(),
                        [](const tstring& str) { return Hash64(str); });
}

absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func) {
  absl::flat_hash_set<string> experiments;

  if (job_name.empty()) {
    return experiments;
  }

  // Parse the opt-in and opt-out settings.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // Identify opted out experiments.
  absl::flat_hash_map<string, int64> live_experiments =
      DatasetExperimentRegistry::Experiments();
  absl::flat_hash_set<string> opt_outs;
  if (opt_outs_raw == "all") {
    for (const auto& pair : live_experiments) {
      opt_outs.insert(pair.first);
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) {
      opt_outs.insert(experiment);
    }
  }

  // Include opted in experiments unless they are opted out.
  if (opt_ins_raw == "all") {
    for (const auto& pair : live_experiments) {
      auto experiment = pair.first;
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) {
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  }

  // Stochastically include live experiments unless they are opted out.
  for (const auto& pair : live_experiments) {
    auto& experiment = pair.first;
    if ((hash_func(strings::StrCat(job_name, experiment)) % 100 <
         pair.second) &&
        !opt_outs.contains(experiment)) {
      experiments.insert(experiment);
    }
  }

  return experiments;
}
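
// Opt-in/opt-out sketch: both environment variables consumed above take a
// comma-separated list of experiment names, or the literal "all". For
// instance,
//
//   TF_DATA_EXPERIMENT_OPT_IN=all
//   TF_DATA_EXPERIMENT_OPT_OUT=enable_gradient_descent
//
// enables every registered experiment except "enable_gradient_descent", in
// addition to whatever the per-job stochastic rollout selects.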

void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) {
  if (!experiments.empty()) {
    VLOG(1) << "The input pipeline is subject to tf.data experiments. "
               "Please see `go/tf-data-experiments` for more details.";
  }
  for (auto& experiment : experiments) {
    VLOG(1) << "The experiment \"" << experiment << "\" is applied.";
    metrics::RecordTFDataExperiment(experiment);
  }
}

void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default) {
  DefaultOptimizationGraphRewrites(options, optimizations_enabled,
                                   optimizations_disabled,
                                   optimizations_default);
  if (options.optional_deterministic_case() == Options::kDeterministic) {
    if (options.deterministic()) {
      optimizations_disabled->insert(kMakeSloppyOpt);
    } else {
      optimizations_enabled->insert(kMakeSloppyOpt);
    }
  }
  if (options.optional_slack_case() == Options::kSlack) {
    if (options.slack()) {
      optimizations_enabled->insert(kSlackOpt);
    } else {
      optimizations_disabled->insert(kSlackOpt);
    }
  }
}

absl::flat_hash_set<tstring> SelectOptimizations(
    const absl::flat_hash_set<string>& experiments,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_disabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  absl::flat_hash_set<tstring> optimizations;

  // Add the enabled and default optimizations.
  optimizations.insert(optimizations_enabled.begin(),
                       optimizations_enabled.end());
  optimizations.insert(optimizations_default.begin(),
                       optimizations_default.end());

  // Add experiments unless they correspond to a disabled optimization.
  for (auto& experiment : experiments) {
    if (!optimizations_disabled.contains(experiment)) {
      optimizations.insert(experiment);
    }
  }

  return optimizations;
}
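
// Illustrative combination performed above, with hypothetical sets:
//
//   enabled     = {"map_fusion"}
//   default     = {"noop_elimination"}
//   disabled    = {"map_parallelization"}
//   experiments = {"map_parallelization", "autotune_buffer_sizes"}
//
// yields {"map_fusion", "noop_elimination", "autotune_buffer_sizes"}: enabled
// and default optimizations always apply, and experiments apply unless they
// name a disabled optimization.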

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return Status::OK();                                          \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return Status::OK();
}

Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch) {
  int64_t output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(
        reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix),
                           strings::StrCat(kOutput, "_", i), &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the
    // slice read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return Status::OK();
}

Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements` values.
    // The rest of the batch tensor is *uninitialized*, and accessing it would
    // raise MSan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i),
                              (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i), (*batch)[i]));
    }
  }
  return Status::OK();
}
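
// Checkpoint layout sketch for the two batch helpers above: for a batch
// prefix "batch", WriteBatch stores a scalar under "batch_output_size" and
// one tensor per component under "output_0", "output_1", ...; ReadBatch
// restores them, re-inflating a partial final batch to the full `batch_size`
// leading dimension via CopyPartialBatch.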

Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64_t code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = Status::OK();
  }
  return Status::OK();
}

Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return Status::OK();
}
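
// Similarly, WriteStatus persists a status under "<prefix>_code" and, for
// non-OK statuses, "<prefix>_msg"; ReadStatus reverses this, yielding
// Status::OK() when the stored code is error::Code::OK. A hypothetical pair:
//
//   TF_RETURN_IF_ERROR(WriteStatus(iterator_prefix, "last", status, writer));
//   Status restored;
//   TF_RETURN_IF_ERROR(ReadStatus(iterator_prefix, "last", reader, &restored));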

Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return Status::OK();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return Status::OK();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return Status::OK();
}
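
// Decision table for ProcessBatch, derived from the branches above:
//
//   num_elements == 0, status OK or OutOfRange -> end_of_sequence, OK
//   status any other error                     -> propagate the error
//   partial batch, drop_remainder == true      -> end_of_sequence, OK
//   partial batch, drop_remainder == false     -> copy the slice into output
//   full batch                                 -> move the batch into output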

Status CopyBatch(IteratorContext* ctx,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  static bool in_experiment =
      GetExperiments().contains("parallelize_batch_copy");
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                            &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    if (parallel_copy ||
        (in_experiment && first_element.AllocatedBytes() > (1 << 15))) {
      Status status;
      mutex status_mu;
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = ctx->runner_threadpool_size();
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*ctx->runner())([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return Status::OK();
}

absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
  absl::flat_hash_set<tstring> configs;
  const auto& autotune_options = options.autotune_options();
  std::vector<tstring> autotune_only_optimizations = {
      kAutotuneBufferSizesOpt, kBatchParallelizationOpt,
      kDisablePrefetchLegacyAutotuneOpt, kEnableGradientDescentOpt,
      kMapParallelizationOpt};

  if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled &&
      !autotune_options.enabled()) {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false"));
    }
  } else {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true"));
    }
  }
  if (options.slack()) {
    int num_devices = 1;
    if (options.distribute_options().optional_num_devices_case() ==
        DistributeOptions::kNumDevices) {
      num_devices = options.distribute_options().num_devices();
    }
    configs.insert(
        absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices));
  }
  return configs;
}
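
// Each config produced above is a colon-separated triple understood by the
// tf.data graph-rewrite pipeline, e.g. "map_parallelization:autotune:false"
// when autotuning is explicitly disabled, or "slack:slack_period:8" for a
// hypothetical eight-device setup.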

bool ShouldConfigureMaxIntraOpParallelism(const Options& options) {
  return options.threading_options().optional_max_intra_op_parallelism_case() ==
         ThreadingOptions::kMaxIntraOpParallelism;
}

bool ShouldUsePrivateThreadPool(const Options& options) {
  return options.threading_options().optional_private_threadpool_size_case() ==
         ThreadingOptions::kPrivateThreadpoolSize;
}

bool ShouldUseAutotuning(const Options& options) {
  return options.autotune_options().optional_enabled_case() !=
             AutotuneOptions::kEnabled ||
         options.autotune_options().enabled();
}

bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  return (options.optimization_options()
                  .optional_apply_default_optimizations_case() !=
              OptimizationOptions::kApplyDefaultOptimizations ||
          options.optimization_options().apply_default_optimizations() ||
          !optimizations_enabled.empty() || !optimizations_default.empty());
}

// static
void DatasetExperimentRegistry::Register(const string& experiment,
                                         int64_t rollout_pct) {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  get_dataset_experiments()->insert(std::make_pair(experiment, rollout_pct));
}

// static
absl::flat_hash_map<string, int64> DatasetExperimentRegistry::Experiments() {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  return *get_dataset_experiments();
}

namespace {

REGISTER_DATASET_EXPERIMENT("enable_gradient_descent", 0);
REGISTER_DATASET_EXPERIMENT("parallelize_batch_copy", 100);
REGISTER_DATASET_EXPERIMENT("max_parallelism", 50);
}  // namespace
}  // namespace data
}  // namespace tensorflow