/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

static mutex* get_dataset_experiment_registry_lock() {
  static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED);
  return &dataset_experiment_registry_lock;
}

static absl::flat_hash_map<string, int64_t>* get_dataset_experiments() {
  static absl::flat_hash_map<string, int64_t>* experiments =
      new absl::flat_hash_map<string, int64_t>;
  return experiments;
}

// Use "Opt" suffix so that they are not confused with the enums in Options
// proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kInjectPrefetchOpt[] = "inject_prefetch";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";
constexpr char kMakeDeterministicOpt[] = "make_deterministic";
constexpr char kFilterParallelizationOpt[] = "filter_parallelization";

void DefaultOptimizationGraphRewrites(
    const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
    absl::flat_hash_set<tstring>* optimization_disabled,
    absl::flat_hash_set<tstring>* optimization_default) {
  const auto& optimization_options = options.optimization_options();
  if (optimization_options.optional_apply_default_optimizations_case() !=
          OptimizationOptions::kApplyDefaultOptimizations ||
      optimization_options.apply_default_optimizations()) {
    if (optimization_options.optional_map_and_batch_fusion_case() !=
        OptimizationOptions::kMapAndBatchFusion) {
      optimization_default->insert(kMapAndBatchFusionOpt);
    }
    if (optimization_options.optional_noop_elimination_case() !=
        OptimizationOptions::kNoopElimination) {
      optimization_default->insert(kNoopEliminationOpt);
    }
    if (optimization_options.optional_map_parallelization_case() !=
        OptimizationOptions::kMapParallelization) {
      optimization_default->insert(kMapParallelizationOpt);
    }
    if (optimization_options.optional_shuffle_and_repeat_fusion_case() !=
        OptimizationOptions::kShuffleAndRepeatFusion) {
      optimization_default->insert(kShuffleAndRepeatFusionOpt);
    }
    if (optimization_options.optional_parallel_batch_case() !=
        OptimizationOptions::kParallelBatch) {
      optimization_default->insert(kParallelBatchOpt);
    }
  }
  if (OpDeterminismRequired()) {
    optimization_enabled->insert(kMakeDeterministicOpt);
  }
  if (optimization_options.optional_filter_fusion_case() ==
      OptimizationOptions::kFilterFusion) {
    if (optimization_options.filter_fusion()) {
      optimization_enabled->insert(kFilterFusionOpt);
    } else {
      optimization_disabled->insert(kFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_batch_fusion_case() ==
      OptimizationOptions::kMapAndBatchFusion) {
    if (optimization_options.map_and_batch_fusion()) {
      optimization_enabled->insert(kMapAndBatchFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndBatchFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_filter_fusion_case() ==
      OptimizationOptions::kMapAndFilterFusion) {
    if (optimization_options.map_and_filter_fusion()) {
      optimization_enabled->insert(kMapAndFilterFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_parallelization_case() ==
      OptimizationOptions::kMapParallelization) {
    if (optimization_options.map_parallelization()) {
      optimization_enabled->insert(kMapParallelizationOpt);
    } else {
      optimization_disabled->insert(kMapParallelizationOpt);
    }
  }
  if (optimization_options.optional_filter_parallelization_case() ==
      OptimizationOptions::kFilterParallelization) {
    if (optimization_options.filter_parallelization()) {
      optimization_enabled->insert(kFilterParallelizationOpt);
    } else {
      optimization_disabled->insert(kFilterParallelizationOpt);
    }
  }
  if (optimization_options.optional_map_fusion_case() ==
      OptimizationOptions::kMapFusion) {
    if (optimization_options.map_fusion()) {
      optimization_enabled->insert(kMapFusionOpt);
    } else {
      optimization_disabled->insert(kMapFusionOpt);
    }
  }
  if (optimization_options.optional_noop_elimination_case() ==
      OptimizationOptions::kNoopElimination) {
    if (optimization_options.noop_elimination()) {
      optimization_enabled->insert(kNoopEliminationOpt);
    } else {
      optimization_disabled->insert(kNoopEliminationOpt);
    }
  }
  if (optimization_options.optional_parallel_batch_case() ==
      OptimizationOptions::kParallelBatch) {
    if (optimization_options.parallel_batch()) {
      optimization_enabled->insert(kParallelBatchOpt);
    } else {
      optimization_disabled->insert(kParallelBatchOpt);
    }
  }
  if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
      OptimizationOptions::kShuffleAndRepeatFusion) {
    if (optimization_options.shuffle_and_repeat_fusion()) {
      optimization_enabled->insert(kShuffleAndRepeatFusionOpt);
    } else {
      optimization_disabled->insert(kShuffleAndRepeatFusionOpt);
    }
  }
  if (optimization_options.optional_inject_prefetch_case() ==
      OptimizationOptions::kInjectPrefetch) {
    if (optimization_options.inject_prefetch()) {
      optimization_enabled->insert(kInjectPrefetchOpt);
    } else {
      optimization_disabled->insert(kInjectPrefetchOpt);
    }
  }
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
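// For example (hypothetical op name): an op named `FooDataset` with a single
// DT_VARIANT output is treated as stateless by this heuristic even if its
// OpDef is marked stateful.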
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

}  // namespace

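// Illustrative example: the all-zero seed pair is treated as "unset" and
// replaced with freshly generated random seeds, so
//   MaybeOverrideSeeds({0, 0});   // returns two random, non-repeatable seeds
//   MaybeOverrideSeeds({42, 7});  // returned unchanged as {42, 7}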
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}

Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return OkStatus();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return OkStatus();
}

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return OkStatus();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return OkStatus();
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return OkStatus();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return OkStatus();
}

Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return OkStatus();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return OkStatus();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return OkStatus();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

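// Illustrative usage sketch (names hypothetical): wraps an existing runner so
// that every closure it schedules runs under ScopedPerThreadMaxParallelism:
//   auto limited =
//       RunnerWithMaxParallelism(std::move(my_runner), /*max_parallelism=*/4);
//   limited([] { /* work that observes the per-thread parallelism limit */ });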
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}

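// Illustrative round trip between the string and enum representations:
//   DeterminismPolicy policy;
//   TF_RETURN_IF_ERROR(DeterminismPolicy::FromString(
//       DeterminismPolicy::kDeterministic, &policy));
//   policy.String();  // == DeterminismPolicy::kDeterministic
// Any other string yields an InvalidArgument error.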
Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return OkStatus();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}

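// Illustrative examples of the version-suffix matching below:
//   MatchesAnyVersion("BatchDataset", "BatchDataset")       --> true
//   MatchesAnyVersion("BatchDataset", "BatchDatasetV2")     --> true
//   MatchesAnyVersion("BatchDataset", "BatchDatasetFoo")    --> false
//   MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") --> false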
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}

absl::flat_hash_set<string> GetExperiments() {
  return GetExperiments(port::JobName(),
                        [](const tstring& str) { return Hash64(str); });
}

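// Selects experiments for `job_name`. Rollout is stochastic: an experiment
// registered with rollout percentage N is enabled when
// hash(job_name + experiment) % 100 < N. Two environment variables override
// this, for example (experiment names below are illustrative):
//   TF_DATA_EXPERIMENT_OPT_OUT=all           disables every experiment
//   TF_DATA_EXPERIMENT_OPT_IN=all            enables every registered experiment
//   TF_DATA_EXPERIMENT_OPT_IN=foo,bar        force-enables the named experiments
//   TF_DATA_EXPERIMENT_OPT_OUT=all_except_opt_in
//                                            disables everything not opted in
// Opt-outs take precedence over opt-ins.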
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func) {
  absl::flat_hash_set<string> experiments;
  if (job_name.empty()) {
    return experiments;
  }

  // Parse the opt-in and opt-out settings.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // Identify opted out experiments.
  absl::flat_hash_map<string, int64_t> live_experiments =
      DatasetExperimentRegistry::Experiments();
  absl::flat_hash_set<string> opt_outs;
  if (opt_outs_raw == "all") {
    for (const auto& pair : live_experiments) {
      opt_outs.insert(pair.first);
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) {
      opt_outs.insert(experiment);
    }
  }

  // Include opted in experiments unless they are opted out.
  if (opt_ins_raw == "all") {
    for (const auto& pair : live_experiments) {
      auto experiment = pair.first;
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) {
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  }

  if (opt_outs_raw == "all_except_opt_in") {
    return experiments;
  }
  // Stochastically include live experiments unless they are opted out.
  for (const auto& pair : live_experiments) {
    auto& experiment = pair.first;
    if ((hash_func(strings::StrCat(job_name, experiment)) % 100 <
         pair.second) &&
        !opt_outs.contains(experiment)) {
      experiments.insert(experiment);
    }
  }

  return experiments;
}

void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) {
  if (!experiments.empty()) {
    constexpr float TEN_MINUTES = 60.0 * 10.0;
    LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
        << "The input pipeline is subject to the following tf.data experiments:"
        << " " << absl::StrJoin(experiments, ", ") << ". "
        << "See `go/tf-data-experiments` for more details.";
  }
  for (auto& experiment : experiments) {
    metrics::RecordTFDataExperiment(experiment);
  }
}

void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default) {
  DefaultOptimizationGraphRewrites(options, optimizations_enabled,
                                   optimizations_disabled,
                                   optimizations_default);
  if (!OpDeterminismRequired() &&
      options.optional_deterministic_case() == Options::kDeterministic &&
      !options.deterministic()) {
    optimizations_enabled->insert(kMakeSloppyOpt);
  }
  if (options.optional_slack_case() == Options::kSlack) {
    if (options.slack()) {
      optimizations_enabled->insert(kSlackOpt);
    } else {
      optimizations_disabled->insert(kSlackOpt);
    }
  }
}

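// Note (assumption about the rationale): Tensor::SubSlice returns a view that
// shares the original buffer, and a view starting at an arbitrary byte offset
// may not satisfy the alignment that downstream consumers expect, hence the
// deep copy for unaligned slices below.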
Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index) {
  Tensor slice = tensor.SubSlice(index);
  if (slice.IsAligned()) {
    return slice;
  } else {
    return tensorflow::tensor::DeepCopy(slice);
  }
}

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return OkStatus();                                            \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return OkStatus();
}

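// Restores a batch written by WriteBatch below. The scalar
// "<batch_prefix>_output_size" records how many components were saved, and
// each component i is stored as a tensor named "output_<i>". Components whose
// leading dimension is smaller than `batch_size` were saved as partial slices
// and are copied into a freshly allocated full-size tensor here.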
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch) {
  int64_t output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(
        reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix),
                           strings::StrCat(kOutput, "_", i), &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the
    // slice read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return OkStatus();
}

Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements` values.
    // The rest of the batch tensor is *uninitialized* and accessing that will
    // raise msan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i),
                              (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i), (*batch)[i]));
    }
  }
  return OkStatus();
}

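// Restores a Status written by WriteStatus below: "<prefix>_code" holds the
// error code as an integer and, for non-OK statuses, "<prefix>_msg" holds the
// error message.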
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64_t code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = OkStatus();
  }
  return OkStatus();
}

Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64_t>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return OkStatus();
}

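// Finalizes a (possibly partial) batch: an empty batch with an OK or
// OutOfRange status signals end of sequence; a partial batch is either
// dropped (when drop_remainder is set) or trimmed to `num_elements` rows; a
// full batch is moved into `output` unchanged.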
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return OkStatus();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return OkStatus();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return OkStatus();
}

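// Assembles a batch in two phases: first allocate one output tensor per tuple
// component (shape [num_batch_elements] + element_shape), then copy every
// element into its slice. The copy is sharded across `params.runner` threads
// when `parallel_copy` is set and a single element is larger than 32 KiB
// (1 << 15 bytes); otherwise it runs inline.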
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(params.allocator, first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                            &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    if (parallel_copy && first_element.AllocatedBytes() > (1 << 15)) {
      Status status;
      mutex status_mu;
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = params.runner_threadpool_size;
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*params.runner)([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return OkStatus();
}

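// Produces graph-rewrite configs of the form
// "<rewrite_name>:autotune:<true|false>" for the autotune-only rewrites, plus
// "slack:slack_period:<num_devices>" when slack is enabled. For example, with
// autotuning explicitly disabled this yields entries such as
// "inject_prefetch:autotune:false".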
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
  absl::flat_hash_set<tstring> configs;
  const auto& autotune_options = options.autotune_options();
  std::vector<tstring> autotune_only_optimizations = {
      kAutotuneBufferSizesOpt,
      kBatchParallelizationOpt,
      kDisablePrefetchLegacyAutotuneOpt,
      kEnableGradientDescentOpt,
      kFilterParallelizationOpt,
      kMapParallelizationOpt,
      kInjectPrefetchOpt};

  if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled &&
      !autotune_options.enabled()) {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false"));
    }
  } else {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true"));
    }
  }
  if (options.slack()) {
    int num_devices = 1;
    if (options.distribute_options().optional_num_devices_case() ==
        DistributeOptions::kNumDevices) {
      num_devices = options.distribute_options().num_devices();
    }
    configs.insert(
        absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices));
  }
  return configs;
}

bool ShouldConfigureMaxIntraOpParallelism(const Options& options) {
  return options.threading_options().optional_max_intra_op_parallelism_case() ==
         ThreadingOptions::kMaxIntraOpParallelism;
}

bool ShouldUsePrivateThreadPool(const Options& options) {
  return options.threading_options().optional_private_threadpool_size_case() ==
         ThreadingOptions::kPrivateThreadpoolSize;
}

bool ShouldUseAutotuning(const Options& options) {
  return options.autotune_options().optional_enabled_case() !=
             AutotuneOptions::kEnabled ||
         options.autotune_options().enabled();
}

bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  return (options.optimization_options()
                  .optional_apply_default_optimizations_case() !=
              OptimizationOptions::kApplyDefaultOptimizations ||
          options.optimization_options().apply_default_optimizations() ||
          !optimizations_enabled.empty() || !optimizations_default.empty());
}

int64 GetAutotuneDefaultParallelism(IteratorContext* ctx) {
  return std::min(kAutotuneDefaultParallelism, ctx->runner_threadpool_size());
}

// static
void DatasetExperimentRegistry::Register(const string& experiment,
                                         int64_t rollout_pct) {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  get_dataset_experiments()->insert(std::make_pair(experiment, rollout_pct));
}

// static
absl::flat_hash_map<string, int64_t> DatasetExperimentRegistry::Experiments() {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  return *get_dataset_experiments();
}

namespace {

REGISTER_DATASET_EXPERIMENT("allow_small_function_optimizations", 0);
REGISTER_DATASET_EXPERIMENT("autotune_buffer_optimization", 0);
REGISTER_DATASET_EXPERIMENT(kFilterParallelizationOpt, 0);
REGISTER_DATASET_EXPERIMENT("inject_prefetch", 100);
REGISTER_DATASET_EXPERIMENT("min_outer_interleave_parallelism", 0);
REGISTER_DATASET_EXPERIMENT("reduce_interleave_prefetch", 0);
REGISTER_DATASET_EXPERIMENT("serialize_input_cycle_length", 0);
REGISTER_DATASET_EXPERIMENT("stage_based_autotune", 0);
}  // namespace
}  // namespace data
}  // namespace tensorflow