/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

static mutex* get_dataset_experiment_registry_lock() {
  static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED);
  return &dataset_experiment_registry_lock;
}

static absl::flat_hash_map<string, int64_t>* get_dataset_experiments() {
  static absl::flat_hash_map<string, int64_t>* experiments =
      new absl::flat_hash_map<string, int64_t>;
  return experiments;
}

// Use "Opt" suffix so that they are not confused with the enums in Options
// proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kInjectPrefetchOpt[] = "inject_prefetch";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";

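// Partitions the graph rewrites implied by `options` into three sets:
// rewrites the user explicitly enabled, rewrites the user explicitly
// disabled, and rewrites applied by default because the user expressed no
// preference for them.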
void DefaultOptimizationGraphRewrites(
    const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
    absl::flat_hash_set<tstring>* optimization_disabled,
    absl::flat_hash_set<tstring>* optimization_default) {
  const auto& optimization_options = options.optimization_options();
  if (optimization_options.optional_apply_default_optimizations_case() !=
          OptimizationOptions::kApplyDefaultOptimizations ||
      optimization_options.apply_default_optimizations()) {
    if (optimization_options.optional_map_and_batch_fusion_case() !=
        OptimizationOptions::kMapAndBatchFusion) {
      optimization_default->insert(kMapAndBatchFusionOpt);
    }
    if (optimization_options.optional_noop_elimination_case() !=
        OptimizationOptions::kNoopElimination) {
      optimization_default->insert(kNoopEliminationOpt);
    }
    if (optimization_options.optional_map_parallelization_case() !=
        OptimizationOptions::kMapParallelization) {
      optimization_default->insert(kMapParallelizationOpt);
    }
    if (optimization_options.optional_shuffle_and_repeat_fusion_case() !=
        OptimizationOptions::kShuffleAndRepeatFusion) {
      optimization_default->insert(kShuffleAndRepeatFusionOpt);
    }
    if (optimization_options.optional_parallel_batch_case() !=
        OptimizationOptions::kParallelBatch) {
      optimization_default->insert(kParallelBatchOpt);
    }
  }
  if (OpDeterminismRequired()) {
    optimization_enabled->insert(kMakeDeterministicOpt);
  }
  if (optimization_options.optional_filter_fusion_case() ==
      OptimizationOptions::kFilterFusion) {
    if (optimization_options.filter_fusion()) {
      optimization_enabled->insert(kFilterFusionOpt);
    } else {
      optimization_disabled->insert(kFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_batch_fusion_case() ==
      OptimizationOptions::kMapAndBatchFusion) {
    if (optimization_options.map_and_batch_fusion()) {
      optimization_enabled->insert(kMapAndBatchFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndBatchFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_filter_fusion_case() ==
      OptimizationOptions::kMapAndFilterFusion) {
    if (optimization_options.map_and_filter_fusion()) {
      optimization_enabled->insert(kMapAndFilterFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_parallelization_case() ==
      OptimizationOptions::kMapParallelization) {
    if (optimization_options.map_parallelization()) {
      optimization_enabled->insert(kMapParallelizationOpt);
    } else {
      optimization_disabled->insert(kMapParallelizationOpt);
    }
  }
  if (optimization_options.optional_filter_parallelization_case() ==
      OptimizationOptions::kFilterParallelization) {
    if (optimization_options.filter_parallelization()) {
      optimization_enabled->insert(kFilterParallelizationOpt);
    } else {
      optimization_disabled->insert(kFilterParallelizationOpt);
    }
  }
  if (optimization_options.optional_map_fusion_case() ==
      OptimizationOptions::kMapFusion) {
    if (optimization_options.map_fusion()) {
      optimization_enabled->insert(kMapFusionOpt);
    } else {
      optimization_disabled->insert(kMapFusionOpt);
    }
  }
  if (optimization_options.optional_noop_elimination_case() ==
      OptimizationOptions::kNoopElimination) {
    if (optimization_options.noop_elimination()) {
      optimization_enabled->insert(kNoopEliminationOpt);
    } else {
      optimization_disabled->insert(kNoopEliminationOpt);
    }
  }
  if (optimization_options.optional_parallel_batch_case() ==
      OptimizationOptions::kParallelBatch) {
    if (optimization_options.parallel_batch()) {
      optimization_enabled->insert(kParallelBatchOpt);
    } else {
      optimization_disabled->insert(kParallelBatchOpt);
    }
  }
  if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
      OptimizationOptions::kShuffleAndRepeatFusion) {
    if (optimization_options.shuffle_and_repeat_fusion()) {
      optimization_enabled->insert(kShuffleAndRepeatFusionOpt);
    } else {
      optimization_disabled->insert(kShuffleAndRepeatFusionOpt);
    }
  }
  if (optimization_options.optional_inject_prefetch_case() ==
      OptimizationOptions::kInjectPrefetch) {
    if (optimization_options.inject_prefetch()) {
      optimization_enabled->insert(kInjectPrefetchOpt);
    } else {
      optimization_disabled->insert(kInjectPrefetchOpt);
    }
  }
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

}  // namespace

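// If both seeds are zero, the caller has not requested a fixed seed, so
// replace them with freshly generated random seeds; otherwise return the
// seeds unchanged.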
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}

Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return OkStatus();
}

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return OkStatus();
}

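// Adds the functions in `to_add` to `base`. A function that already exists
// in `base` is replaced by the definition from `to_add` when the two
// signatures match; a name collision with a different signature is reported
// as an InvalidArgument error.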
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

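// Returns OkStatus() if `function_def` can be treated as stateless, and a
// FailedPrecondition error naming the offending op otherwise. The check
// recurses into the branch and body functions of "If" and "While" nodes via
// `IsNodeStateful` below.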
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return OkStatus();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return OkStatus();
}

Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return OkStatus();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return OkStatus();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return OkStatus();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

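// Wraps `runner` so that every closure it executes first installs a
// ScopedPerThreadMaxParallelism guard, capping the intra-op parallelism
// available to work scheduled from that closure at `max_parallelism`.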
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return OkStatus();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}

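// Returns true if `op_to_match` is `op_prefix` itself or a versioned variant
// of it, i.e. `op_prefix` followed by "V<digits>". For example,
// MatchesAnyVersion("BatchDataset", "BatchDataset") and
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") return true, while
// MatchesAnyVersion("BatchDataset", "BatchDatasetXYZ") returns false.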
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}

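// Selects the tf.data experiments that apply to this job. Experiments listed
// in the TF_DATA_EXPERIMENT_OPT_IN / TF_DATA_EXPERIMENT_OPT_OUT environment
// variables (comma-separated, or the special value "all") are force-enabled
// or force-disabled; the remaining registered experiments are enabled
// pseudo-randomly, keyed on a hash of the job name and experiment name, at
// each experiment's registered rollout percentage.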
absl::flat_hash_set<string> GetExperiments() {
  return GetExperiments(port::JobName(),
                        [](const tstring& str) { return Hash64(str); });
}

absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func) {
  absl::flat_hash_set<string> experiments;
  if (job_name.empty()) {
    return experiments;
  }

  // Parse the opt-in and opt-out settings.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // Identify opted out experiments.
  absl::flat_hash_map<string, int64_t> live_experiments =
      DatasetExperimentRegistry::Experiments();
  absl::flat_hash_set<string> opt_outs;
  if (opt_outs_raw == "all") {
    for (const auto& pair : live_experiments) {
      opt_outs.insert(pair.first);
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) {
      opt_outs.insert(experiment);
    }
  }

  // Include opted in experiments unless they are opted out.
  if (opt_ins_raw == "all") {
    for (const auto& pair : live_experiments) {
      auto experiment = pair.first;
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) {
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  }

  if (opt_outs_raw == "all_except_opt_in") {
    return experiments;
  }
  // Stochastically include live experiments unless they are opted out.
  for (const auto& pair : live_experiments) {
    auto& experiment = pair.first;
    if ((hash_func(strings::StrCat(job_name, experiment)) % 100 <
         pair.second) &&
        !opt_outs.contains(experiment)) {
      experiments.insert(experiment);
    }
  }

  return experiments;
}

void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) {
  if (!experiments.empty()) {
    constexpr float TEN_MINUTES = 60.0 * 10.0;
    LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
        << "The input pipeline is subject to the following tf.data experiments:"
        << " " << absl::StrJoin(experiments, ", ") << ". "
        << "See `go/tf-data-experiments` for more details.";
  }
  for (auto& experiment : experiments) {
    metrics::RecordTFDataExperiment(experiment);
  }
}

void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default) {
  DefaultOptimizationGraphRewrites(options, optimizations_enabled,
                                   optimizations_disabled,
                                   optimizations_default);
  if (!OpDeterminismRequired() &&
      options.optional_deterministic_case() == Options::kDeterministic &&
      !options.deterministic()) {
    optimizations_enabled->insert(kMakeSloppyOpt);
  }
  if (options.optional_slack_case() == Options::kSlack) {
    if (options.slack()) {
      optimizations_enabled->insert(kSlackOpt);
    } else {
      optimizations_disabled->insert(kSlackOpt);
    }
  }
}

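// Returns `tensor.SubSlice(index)` directly when the slice is aligned, and a
// deep copy of the slice otherwise, so callers always receive an aligned
// tensor.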
Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index) {
  Tensor slice = tensor.SubSlice(index);
  if (slice.IsAligned()) {
    return slice;
  } else {
    return tensorflow::tensor::DeepCopy(slice);
  }
}

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

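// Copies the first `num_elements` slices along dimension 0 of `value` into
// `output`, dispatching on the tensor's dtype. Both tensors are expected to
// have the same trailing (per-element) shape.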
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return OkStatus();                                            \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return OkStatus();
}

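// Restores a batch of tensors from checkpoint state written by `WriteBatch`
// below. Tensors that were saved as partial batches are copied back into
// freshly allocated tensors whose leading dimension is `batch_size`.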
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch) {
  int64_t output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(
        reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix),
                           strings::StrCat(kOutput, "_", i), &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the slice
    // read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return OkStatus();
}

Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements` values.
    // The rest of the batch tensor is *uninitialized* and accessing that will
    // raise msan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i),
                              (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i), (*batch)[i]));
    }
  }
  return OkStatus();
}

Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64_t code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = OkStatus();
  }
  return OkStatus();
}

Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64_t>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return OkStatus();
}

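// Finalizes a batch: propagates non-OK, non-OutOfRange statuses, signals end
// of sequence for empty batches, drops or right-sizes a partial final batch
// depending on `drop_remainder`, and otherwise moves `batch` into `output`.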
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return OkStatus();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return OkStatus();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return OkStatus();
}

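// Batches the elements in `batch_elements` into `out_tensors`, one output
// tensor per tuple component. Per-component slices are copied serially by
// default; when `parallel_copy` is set and a component's first element is
// larger than 32 KiB, the copies are sharded across the runner threadpool.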
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(params.allocator, first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                            &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    if (parallel_copy && first_element.AllocatedBytes() > (1 << 15)) {
      Status status;
      mutex status_mu;
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = params.runner_threadpool_size;
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*params.runner)([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return OkStatus();
}

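// Builds the rewrite-config strings passed to the tf.data graph rewriters.
// Each entry has the form "<optimization>:<parameter>:<value>", e.g.
// "map_parallelization:autotune:true" or "slack:slack_period:2".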
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
  absl::flat_hash_set<tstring> configs;
  const auto& autotune_options = options.autotune_options();
  std::vector<tstring> autotune_only_optimizations = {
      kAutotuneBufferSizesOpt,
      kBatchParallelizationOpt,
      kDisablePrefetchLegacyAutotuneOpt,
      kEnableGradientDescentOpt,
      kFilterParallelizationOpt,
      kMapParallelizationOpt,
      kInjectPrefetchOpt};

  if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled &&
      !autotune_options.enabled()) {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false"));
    }
  } else {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true"));
    }
  }
  if (options.slack()) {
    int num_devices = 1;
    if (options.distribute_options().optional_num_devices_case() ==
        DistributeOptions::kNumDevices) {
      num_devices = options.distribute_options().num_devices();
    }
    configs.insert(
        absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices));
  }
  return configs;
}

bool ShouldConfigureMaxIntraOpParallelism(const Options& options) {
  return options.threading_options().optional_max_intra_op_parallelism_case() ==
         ThreadingOptions::kMaxIntraOpParallelism;
}

bool ShouldUsePrivateThreadPool(const Options& options) {
  return options.threading_options().optional_private_threadpool_size_case() ==
         ThreadingOptions::kPrivateThreadpoolSize;
}

bool ShouldUseAutotuning(const Options& options) {
  return options.autotune_options().optional_enabled_case() !=
             AutotuneOptions::kEnabled ||
         options.autotune_options().enabled();
}

bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  return (options.optimization_options()
                  .optional_apply_default_optimizations_case() !=
              OptimizationOptions::kApplyDefaultOptimizations ||
          options.optimization_options().apply_default_optimizations() ||
          !optimizations_enabled.empty() || !optimizations_default.empty());
}

int64 GetAutotuneDefaultParallelism(IteratorContext* ctx) {
  return std::min(kAutotuneDefaultParallelism, ctx->runner_threadpool_size());
}

// static
void DatasetExperimentRegistry::Register(const string& experiment,
                                         int64_t rollout_pct) {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  get_dataset_experiments()->insert(std::make_pair(experiment, rollout_pct));
}

// static
absl::flat_hash_map<string, int64_t> DatasetExperimentRegistry::Experiments() {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  return *get_dataset_experiments();
}

namespace {

REGISTER_DATASET_EXPERIMENT("allow_small_function_optimizations", 0);
REGISTER_DATASET_EXPERIMENT("autotune_buffer_optimization", 0);
REGISTER_DATASET_EXPERIMENT(kFilterParallelizationOpt, 0);
REGISTER_DATASET_EXPERIMENT("inject_prefetch", 100);
REGISTER_DATASET_EXPERIMENT("min_outer_interleave_parallelism", 0);
REGISTER_DATASET_EXPERIMENT("reduce_interleave_prefetch", 0);
REGISTER_DATASET_EXPERIMENT("serialize_input_cycle_length", 0);
REGISTER_DATASET_EXPERIMENT("stage_based_autotune", 0);
}  // namespace
}  // namespace data
}  // namespace tensorflow