/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/dataset_utils.h"

#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

static mutex* get_dataset_experiment_registry_lock() {
  static mutex dataset_experiment_registry_lock(LINKER_INITIALIZED);
  return &dataset_experiment_registry_lock;
}

static absl::flat_hash_map<string, int64>* get_dataset_experiments() {
  static absl::flat_hash_map<string, int64>* experiments =
      new absl::flat_hash_map<string, int64>;
  return experiments;
}

// Use "Opt" suffix so that they are not confused with the enums in Options
// proto.
constexpr char kMapAndBatchFusionOpt[] = "map_and_batch_fusion";
constexpr char kNoopEliminationOpt[] = "noop_elimination";
constexpr char kMapParallelizationOpt[] = "map_parallelization";
constexpr char kShuffleAndRepeatFusionOpt[] = "shuffle_and_repeat_fusion";
constexpr char kFilterFusionOpt[] = "filter_fusion";
constexpr char kMapAndFilterFusionOpt[] = "map_and_filter_fusion";
constexpr char kMapFusionOpt[] = "map_fusion";
constexpr char kParallelBatchOpt[] = "parallel_batch";
constexpr char kAutotuneBufferSizesOpt[] = "autotune_buffer_sizes";
constexpr char kDisablePrefetchLegacyAutotuneOpt[] =
    "disable_prefetch_legacy_autotune";
constexpr char kMakeSloppyOpt[] = "make_sloppy";
constexpr char kUseChooseFastestOpt[] = "use_choose_fastest";
constexpr char kBatchParallelizationOpt[] = "batch_parallelization";
constexpr char kEnableGradientDescentOpt[] = "enable_gradient_descent";
constexpr char kAutotuneOpt[] = "autotune";
constexpr char kSlackOpt[] = "slack";
constexpr char kSlackPeriodOpt[] = "slack_period";

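// Sorts each optimization from `options` into the enabled, disabled, or
// default set of graph rewrites, depending on whether the user explicitly
// set the corresponding option.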
void DefaultOptimizationGraphRewrites(
    const Options& options, absl::flat_hash_set<tstring>* optimization_enabled,
    absl::flat_hash_set<tstring>* optimization_disabled,
    absl::flat_hash_set<tstring>* optimization_default) {
  const auto& optimization_options = options.optimization_options();
  if (optimization_options.optional_apply_default_optimizations_case() !=
          OptimizationOptions::kApplyDefaultOptimizations ||
      optimization_options.apply_default_optimizations()) {
    if (optimization_options.optional_map_and_batch_fusion_case() !=
        OptimizationOptions::kMapAndBatchFusion) {
      optimization_default->insert(kMapAndBatchFusionOpt);
    }
    if (optimization_options.optional_noop_elimination_case() !=
        OptimizationOptions::kNoopElimination) {
      optimization_default->insert(kNoopEliminationOpt);
    }
    if (optimization_options.optional_map_parallelization_case() !=
        OptimizationOptions::kMapParallelization) {
      optimization_default->insert(kMapParallelizationOpt);
    }
    if (optimization_options.optional_shuffle_and_repeat_fusion_case() !=
        OptimizationOptions::kShuffleAndRepeatFusion) {
      optimization_default->insert(kShuffleAndRepeatFusionOpt);
    }
  }
  if (optimization_options.optional_filter_fusion_case() ==
      OptimizationOptions::kFilterFusion) {
    if (optimization_options.filter_fusion()) {
      optimization_enabled->insert(kFilterFusionOpt);
    } else {
      optimization_disabled->insert(kFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_batch_fusion_case() ==
      OptimizationOptions::kMapAndBatchFusion) {
    if (optimization_options.map_and_batch_fusion()) {
      optimization_enabled->insert(kMapAndBatchFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndBatchFusionOpt);
    }
  }
  if (optimization_options.optional_map_and_filter_fusion_case() ==
      OptimizationOptions::kMapAndFilterFusion) {
    if (optimization_options.map_and_filter_fusion()) {
      optimization_enabled->insert(kMapAndFilterFusionOpt);
    } else {
      optimization_disabled->insert(kMapAndFilterFusionOpt);
    }
  }
  if (optimization_options.optional_map_parallelization_case() ==
      OptimizationOptions::kMapParallelization) {
    if (optimization_options.map_parallelization()) {
      optimization_enabled->insert(kMapParallelizationOpt);
    } else {
      optimization_disabled->insert(kMapParallelizationOpt);
    }
  }
  if (optimization_options.optional_map_fusion_case() ==
      OptimizationOptions::kMapFusion) {
    if (optimization_options.map_fusion()) {
      optimization_enabled->insert(kMapFusionOpt);
    } else {
      optimization_disabled->insert(kMapFusionOpt);
    }
  }
  if (optimization_options.optional_noop_elimination_case() ==
      OptimizationOptions::kNoopElimination) {
    if (optimization_options.noop_elimination()) {
      optimization_enabled->insert(kNoopEliminationOpt);
    } else {
      optimization_disabled->insert(kNoopEliminationOpt);
    }
  }
  if (optimization_options.optional_parallel_batch_case() ==
      OptimizationOptions::kParallelBatch) {
    if (optimization_options.parallel_batch()) {
      optimization_enabled->insert(kParallelBatchOpt);
    } else {
      optimization_disabled->insert(kParallelBatchOpt);
    }
  }
  if (optimization_options.optional_shuffle_and_repeat_fusion_case() ==
      OptimizationOptions::kShuffleAndRepeatFusion) {
    if (optimization_options.shuffle_and_repeat_fusion()) {
      optimization_enabled->insert(kShuffleAndRepeatFusionOpt);
    } else {
      optimization_disabled->insert(kShuffleAndRepeatFusionOpt);
    }
  }
}

// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
// allowlist source dataset ops which have been marked stateful due to
// b/65524810. Also looks up the `op_def->name` in the global
// `AllowlistedStatefulOpRegistry`.
bool IsOpAllowlisted(const OpDef* op_def) {
  return (op_def->output_arg_size() == 1 &&
          op_def->output_arg(0).type() == DT_VARIANT &&
          (absl::EndsWith(op_def->name(), "Dataset") ||
           absl::EndsWith(op_def->name(), "DatasetV2"))) ||
         AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
}

}  // namespace

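// Returns the seeds unchanged unless both are zero, in which case both are
// replaced with freshly generated random seeds.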
std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}

Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return Status::OK();
}

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return Status::OK();
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

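// Returns OK if `function_def` uses no stateful ops; otherwise returns a
// FailedPrecondition error naming the first stateful op found. Allowlisted
// ops and `Assert` are treated as stateless, and the functions referenced by
// `If`/`While` nodes are checked recursively.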
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def) {
  if (!function_def.signature().is_stateful()) {
    return Status::OK();
  }

  for (const NodeDef& node_def : function_def.node_def()) {
    TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
  }
  return Status::OK();
}

Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node) {
  const OpDef* op_def;

  // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
  // `LookUpOpDef` errors here.
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
      IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
      op_def->name() == "Assert") {
    return Status::OK();
  }

  if (op_def->name() == "If") {
    const FunctionDef* then_func =
        library.Find(node.attr().at("then_branch").func().name());
    const FunctionDef* else_func =
        library.Find(node.attr().at("else_branch").func().name());
    if (then_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
    }
    if (else_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
    }
    return Status::OK();
  }

  if (op_def->name() == "While") {
    const FunctionDef* cond_func =
        library.Find(node.attr().at("cond").func().name());
    const FunctionDef* body_func =
        library.Find(node.attr().at("body").func().name());
    if (cond_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
    }
    if (body_func != nullptr) {
      TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
    }
    return Status::OK();
  }

  return errors::FailedPrecondition(op_def->name(), " is stateful.");
}

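// Wraps `runner` so that every closure it executes runs inside a
// `ScopedPerThreadMaxParallelism` scope limiting intra-op parallelism to
// `max_parallelism`.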
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return Status::OK();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}

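// Returns true if `op_to_match` is either exactly `op_prefix` or `op_prefix`
// followed by a version suffix of the form "V<digits>".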
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}

absl::flat_hash_set<string> GetExperiments() {
  return GetExperiments(port::JobName(),
                        [](const tstring& str) { return Hash64(str); });
}

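// Computes the experiments to apply for the given job. Experiments named in
// TF_DATA_EXPERIMENT_OPT_IN are included and experiments named in
// TF_DATA_EXPERIMENT_OPT_OUT are excluded; the remaining registered
// experiments are included stochastically based on `hash_func` applied to the
// job name and the experiment's rollout percentage.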
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func) {
  absl::flat_hash_set<string> experiments;

  if (job_name.empty()) {
    return experiments;
  }

  // Parse the opt-in and opt-out settings.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // Identify opted out experiments.
  absl::flat_hash_map<string, int64> live_experiments =
      DatasetExperimentRegistry::Experiments();
  absl::flat_hash_set<string> opt_outs;
  if (opt_outs_raw == "all") {
    for (const auto& pair : live_experiments) {
      opt_outs.insert(pair.first);
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty())) {
      opt_outs.insert(experiment);
    }
  }

  // Include opted in experiments unless they are opted out.
  if (opt_ins_raw == "all") {
    for (const auto& pair : live_experiments) {
      auto experiment = pair.first;
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  } else {
    for (const auto& experiment :
         str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty())) {
      if (!opt_outs.contains(experiment)) {
        experiments.insert(experiment);
      }
    }
  }

  // Stochastically include live experiments unless they are opted out.
  for (const auto& pair : live_experiments) {
    auto& experiment = pair.first;
    if ((hash_func(strings::StrCat(job_name, experiment)) % 100 <
         pair.second) &&
        !opt_outs.contains(experiment)) {
      experiments.insert(experiment);
    }
  }

  return experiments;
}

void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments) {
  if (!experiments.empty()) {
    VLOG(1) << "The input pipeline is subject to tf.data experiments. "
               "Please see `go/tf-data-experiments` for more details.";
  }
  for (auto& experiment : experiments) {
    VLOG(1) << "The experiment \"" << experiment << "\" is applied.";
    metrics::RecordTFDataExperiment(experiment);
  }
}

void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default) {
  DefaultOptimizationGraphRewrites(options, optimizations_enabled,
                                   optimizations_disabled,
                                   optimizations_default);
  if (options.optional_deterministic_case() == Options::kDeterministic) {
    if (options.deterministic()) {
      optimizations_disabled->insert(kMakeSloppyOpt);
    } else {
      optimizations_enabled->insert(kMakeSloppyOpt);
    }
  }
  if (options.optional_slack_case() == Options::kSlack) {
    if (options.slack()) {
      optimizations_enabled->insert(kSlackOpt);
    } else {
      optimizations_disabled->insert(kSlackOpt);
    }
  }
}

absl::flat_hash_set<tstring> SelectOptimizations(
    const absl::flat_hash_set<string>& experiments,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_disabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  absl::flat_hash_set<tstring> optimizations;

  // Add the enabled and default optimizations.
  optimizations.insert(optimizations_enabled.begin(),
                       optimizations_enabled.end());
  optimizations.insert(optimizations_default.begin(),
                       optimizations_default.end());

  // Add experiments unless they correspond to a disabled optimization.
  for (auto& experiment : experiments) {
    if (!optimizations_disabled.contains(experiment)) {
      optimizations.insert(experiment);
    }
  }

  return optimizations;
}

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

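// Copies the first `num_elements` slices of `value` into `output` along the
// leading (batch) dimension.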
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return Status::OK();                                          \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return Status::OK();
}

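// Restores a batch of tensors from `reader`. Components that were
// checkpointed as partial batches are copied into tensors whose leading
// dimension is `batch_size`.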
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch) {
  int64_t output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(
        reader->ReadTensor(ctx->flr(), FullName(iterator_prefix, batch_prefix),
                           strings::StrCat(kOutput, "_", i), &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the
    // slice read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return Status::OK();
}

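// Checkpoints a batch of tensors, writing only the first `num_elements`
// slices of each component when the batch is partial.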
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements` values.
    // The rest of the batch tensor is *uninitialized* and accessing that will
    // raise msan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i),
                              (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(
          writer->WriteTensor(FullName(iterator_prefix, batch_prefix),
                              strings::StrCat(kOutput, "_", i), (*batch)[i]));
    }
  }
  return Status::OK();
}

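// Restores a `Status` written by `WriteStatus`: reads the error code and, for
// non-OK codes, the error message.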
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64_t code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = Status::OK();
  }
  return Status::OK();
}

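// Checkpoints `status` by writing its error code and, for non-OK statuses,
// its error message.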
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return Status::OK();
}

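// Finalizes a batch: propagates errors other than end-of-sequence, signals
// end of sequence for empty batches and for dropped partial batches, and
// otherwise truncates a partial batch to `num_elements` before returning it
// in `output`.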
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return Status::OK();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return Status::OK();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return Status::OK();
}

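// Builds one batched tensor per tuple component by stacking the corresponding
// tensors of `batch_elements` along a new leading dimension. Copies run in
// parallel on `ctx->runner()` when `parallel_copy` is set, or when the
// "parallelize_batch_copy" experiment is active and elements are larger than
// 32 KiB.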
Status CopyBatch(IteratorContext* ctx,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors) {
  static bool in_experiment =
      GetExperiments().contains("parallelize_batch_copy");
  const size_t num_tuple_components = batch_elements.at(0).size();
  out_tensors->reserve(num_tuple_components);
  const int64_t num_batch_elements = batch_elements.size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    TensorShape batch_component_shape({num_batch_elements});
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
  }
  if (allocation_callback) {
    TF_RETURN_IF_ERROR(allocation_callback());
  }
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    Tensor& batch_component = out_tensors->at(component_index);
    const Tensor& first_element = batch_elements.at(0)[component_index];
    TensorShape first_element_shape(first_element.shape());
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements, &batch_component,
                            &first_element_shape](int index) {
      if (batch_elements.at(index)[component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", index,
            " had shape ",
            batch_elements.at(index)[component_index].shape().DebugString(),
            ".");
      }
      return batch_util::CopyElementToSlice(
          std::move(batch_elements.at(index)[component_index]),
          &batch_component, index);
    };
    if (parallel_copy ||
        (in_experiment && first_element.AllocatedBytes() > (1 << 15))) {
      Status status;
      mutex status_mu;
      BlockingCounter counter(num_batch_elements);
      const auto num_threads = ctx->runner_threadpool_size();
      const auto slice_size = num_batch_elements / num_threads;
      int64_t offset = 0;
      for (size_t i = 0; i < num_threads; ++i) {
        int64_t length = slice_size;
        // When the number of threads does not divide the number of elements
        // evenly, the size of some slices is incremented to guarantee their
        // sizes add up to the total number of elements.
        if (i < num_batch_elements % num_threads) ++length;
        (*ctx->runner())([offset, length, &status, &status_mu, &counter,
                          &copy_element_fn]() {
          for (size_t j = offset; j < offset + length; ++j) {
            {
              Status s = copy_element_fn(j);
              mutex_lock l(status_mu);
              status.Update(s);
            }
            counter.DecrementCount();
          }
        });
        offset += length;
      }
      counter.Wait();
      TF_RETURN_IF_ERROR(status);
    } else {
      for (size_t i = 0; i < num_batch_elements; ++i) {
        TF_RETURN_IF_ERROR(copy_element_fn(i));
      }
    }
  }
  return Status::OK();
}

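// Translates `options` into per-rewrite config strings of the form
// "<rewrite>:<attr>:<value>" (e.g. "map_parallelization:autotune:false")
// consumed by the tf.data graph rewrites.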
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options) {
  absl::flat_hash_set<tstring> configs;
  const auto& autotune_options = options.autotune_options();
  std::vector<tstring> autotune_only_optimizations = {
      kAutotuneBufferSizesOpt, kBatchParallelizationOpt,
      kDisablePrefetchLegacyAutotuneOpt, kEnableGradientDescentOpt,
      kMapParallelizationOpt};

  if (autotune_options.optional_enabled_case() == AutotuneOptions::kEnabled &&
      !autotune_options.enabled()) {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":false"));
    }
  } else {
    for (const auto& optimization : autotune_only_optimizations) {
      configs.insert(
          absl::StrCat(optimization.data(), ":", kAutotuneOpt, ":true"));
    }
  }
  if (options.slack()) {
    int num_devices = 1;
    if (options.distribute_options().optional_num_devices_case() ==
        DistributeOptions::kNumDevices) {
      num_devices = options.distribute_options().num_devices();
    }
    configs.insert(
        absl::StrCat(kSlackOpt, ":", kSlackPeriodOpt, ":", num_devices));
  }
  return configs;
}

bool ShouldConfigureMaxIntraOpParallelism(const Options& options) {
  return options.threading_options()
             .optional_max_intra_op_parallelism_case() ==
         ThreadingOptions::kMaxIntraOpParallelism;
}

bool ShouldUsePrivateThreadPool(const Options& options) {
  return options.threading_options()
             .optional_private_threadpool_size_case() ==
         ThreadingOptions::kPrivateThreadpoolSize;
}

bool ShouldUseAutotuning(const Options& options) {
  return options.autotune_options().optional_enabled_case() !=
             AutotuneOptions::kEnabled ||
         options.autotune_options().enabled();
}

bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  return (options.optimization_options()
                  .optional_apply_default_optimizations_case() !=
              OptimizationOptions::kApplyDefaultOptimizations ||
          options.optimization_options().apply_default_optimizations() ||
          !optimizations_enabled.empty() || !optimizations_default.empty());
}

// static
void DatasetExperimentRegistry::Register(const string& experiment,
                                         int64_t rollout_pct) {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  get_dataset_experiments()->insert(std::make_pair(experiment, rollout_pct));
}

// static
absl::flat_hash_map<string, int64> DatasetExperimentRegistry::Experiments() {
  mutex_lock l(*get_dataset_experiment_registry_lock());
  return *get_dataset_experiments();
}

namespace {

REGISTER_DATASET_EXPERIMENT("enable_gradient_descent", 0);
REGISTER_DATASET_EXPERIMENT("parallelize_batch_copy", 100);
REGISTER_DATASET_EXPERIMENT("max_parallelism", 50);

}  // namespace
}  // namespace data
}  // namespace tensorflow