/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/dataset_utils.h"

#include <memory>
#include <queue>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kDelimiter[] = "@@";
constexpr char kComponent[] = "component";
constexpr char kNumElements[] = "num_elements";
constexpr char kNumComponents[] = "num_components";
constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";
constexpr char kOutput[] = "output";

}  // namespace

Status WriteElementsToCheckpoint(
    IteratorStateWriter* writer, StringPiece key_prefix,
    const std::vector<std::vector<Tensor>>& elements) {
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(key_prefix, kNumElements, elements.size()));
  for (int i = 0; i < elements.size(); ++i) {
    const std::vector<Tensor>& element = elements[i];
    std::string element_prefix = absl::StrCat(key_prefix, "::", i);
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(element_prefix, kNumComponents, element.size()));
    for (int j = 0; j < elements[i].size(); ++j) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j]));
    }
  }
  return Status::OK();
}

Status ReadElementsFromCheckpoint(IteratorStateReader* reader,
                                  StringPiece key_prefix,
                                  std::vector<std::vector<Tensor>>* elements) {
  int64 num_elements;
  TF_RETURN_IF_ERROR(
      reader->ReadScalar(key_prefix, kNumElements, &num_elements));
  elements->reserve(num_elements);
  for (int i = 0; i < num_elements; ++i) {
    std::string element_prefix = absl::StrCat(key_prefix, "::", i);
    int64 num_components;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(element_prefix, kNumComponents, &num_components));
    elements->emplace_back();
    std::vector<Tensor>& element = elements->at(i);
    element.reserve(num_components);
    for (int j = 0; j < num_components; ++j) {
      element.emplace_back();
      TF_RETURN_IF_ERROR(reader->ReadTensor(
          element_prefix, absl::StrCat(kComponent, "[", j, "]"),
          &element.back()));
    }
  }
  return Status::OK();
}
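
// Example (illustrative sketch, not code that ships with this file): the two
// functions above round-trip a buffer of elements through a checkpoint. A
// hypothetical iterator with prefix "Iterator::MyBuffer" holding two elements
// of two components each produces keys of the form:
//
//   Iterator::MyBuffer      | num_elements   -> 2
//   Iterator::MyBuffer::0   | num_components -> 2
//   Iterator::MyBuffer::0   | component[0], component[1]
//   Iterator::MyBuffer::1   | ...
//
// and could be saved and restored like so:
//
//   TF_RETURN_IF_ERROR(
//       WriteElementsToCheckpoint(writer, "Iterator::MyBuffer", buffer_));
//   ...
//   std::vector<std::vector<Tensor>> restored;
//   TF_RETURN_IF_ERROR(
//       ReadElementsFromCheckpoint(reader, "Iterator::MyBuffer", &restored));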

std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds) {
  if (seeds.first == 0 && seeds.second == 0) {
    return {random::New64(), random::New64()};
  }
  return seeds;
}
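
// Example (sketch): this mirrors the tf.data convention that the seed pair
// (0, 0) means "pick random seeds", while any other pair is used verbatim:
//
//   auto s1 = MaybeOverrideSeeds({0, 0});   // two fresh random seeds
//   auto s2 = MaybeOverrideSeeds({7, 42});  // returns {7, 42} unchanged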

Status VerifyTypeMatch(const DataType& expected, const DataType& received,
                       int index) {
  if (expected != received) {
    return errors::InvalidArgument("Data type mismatch at component ", index,
                                   ": expected ", DataTypeString(expected),
                                   " but got ", DataTypeString(received), ".");
  }
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
  }
  return Status::OK();
}

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
  }
  return Status::OK();
}

Status VerifyShapeCompatible(const PartialTensorShape& expected,
                             const PartialTensorShape& received, int index) {
  if (!expected.IsCompatibleWith(received)) {
    return errors::InvalidArgument("Incompatible shapes at component ", index,
                                   ": expected ", expected.DebugString(),
                                   " but got ", received.DebugString(), ".");
  }
  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
  }

  return Status::OK();
}

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    TF_RETURN_IF_ERROR(
        VerifyShapeCompatible(expected[i], received[i].shape(), i));
  }

  return Status::OK();
}
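
// Example (sketch): the overloads above let callers validate either declared
// signatures or concrete tensors against expected dtypes/shapes, e.g. when
// restoring iterator state:
//
//   TF_RETURN_IF_ERROR(VerifyTypesMatch(output_dtypes(), tensors));
//   TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes(), tensors));
//
// A DT_INT64 tensor checked against an expected DT_FLOAT slot fails with
// InvalidArgument, as does a shape-[3] tensor checked against expected [2].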

namespace {

// We assume that all keys are of the form <iterator_prefix>:<name>. We
// extract the iterator name by stripping everything after the final colon.
Status GetIteratorName(StringPiece key, string* name) {
  if (!str_util::StartsWith(key, data::kFullNameRandomHex)) {
    return errors::InvalidArgument("Save key: ", key,
                                   " not generated using full_name.");
  }
  std::vector<string> split_keys = str_util::Split(key, data::kPipe);
  if (split_keys.size() != 2) {
    return errors::InvalidArgument("Save key: ", key,
                                   " not generated using full_name.");
  }
  string real_key = split_keys[1];
  const int pos = real_key.rfind(kColon);
  *name = real_key.substr(0, pos);
  return Status::OK();
}

}  // namespace
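
// Example (sketch of the key layout produced by full_name; hex shown here is
// illustrative): a save key looks like "<random-hex>|<iterator_name>:<field>",
// so for a key such as
//
//   "60d899aa0d8ce4351e7c3b419e92d25b|Iterator::Shuffle:buffer_size"
//
// GetIteratorName() writes "Iterator::Shuffle" into `name`, i.e. everything
// in the pipe-delimited payload up to the final ':'.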

VariantTensorDataReader::VariantTensorDataReader(
    const std::vector<const tensorflow::VariantTensorData*>& data) {
  for (const auto& d : data) {
    string metadata;
    d->get_metadata(&metadata);
    auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
    const string name = keys[0];
    data_[name] = d;
    map_[name] = std::map<string, size_t>();
    for (size_t i = 1; i < keys.size(); ++i) {
      map_[name][keys[i]] = i - 1;
    }
  }
}
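
// Example (sketch): the reader expects the metadata layout written by
// VariantTensorDataWriter::MaybeFlush() below, i.e. the iterator name
// followed by the "@@"-delimited keys, in tensor order:
//
//   metadata = "Iterator::Range@@i@@num_elements"
//     -> map_["Iterator::Range"] = {"i": 0, "num_elements": 1}
//
// so ReadScalar("Iterator::Range", "i", &val) resolves to tensors(0).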

Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) const {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadScalar(StringPiece key,
                                           tstring* val) const {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const {
  return ReadTensorInternal(key, val);
}

Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
                                           int64* val) const {
  return ReadScalarInternal(name, key, val);
}

Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
                                           tstring* val) const {
  return ReadScalarInternal(name, key, val);
}

Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
                                           Tensor* val) const {
  return ReadTensorInternal(name, key, val);
}

bool VariantTensorDataReader::Contains(StringPiece key) const {
  string name;
  if (!GetIteratorName(key, &name).ok()) {
    return false;
  }
  return Contains(name, key);
}

bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const {
  string name(n);
  auto it = map_.find(name);
  if (it == map_.end()) {
    return false;
  }
  const auto& bucket = it->second;
  return bucket.find(string(key)) != bucket.end();
}

template <typename T>
Status VariantTensorDataReader::ReadScalarInternal(StringPiece key,
                                                   T* val) const {
  string name;
  TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
  return ReadScalarInternal(name, key, val);
}

Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
                                                   Tensor* val) const {
  string name;
  TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
  return ReadTensorInternal(name, key, val);
}

template <typename T>
Status VariantTensorDataReader::ReadScalarInternal(StringPiece n,
                                                   StringPiece key,
                                                   T* val) const {
  string name(n);
  auto it = map_.find(name);
  if (it == map_.end()) {
    return errors::NotFound(name);
  }
  const auto& bucket = it->second;
  auto key_it = bucket.find(string(key));
  if (key_it == bucket.end()) {
    return errors::NotFound(key);
  }
  *val = data_.at(name)->tensors(key_it->second).scalar<T>()();
  return Status::OK();
}

Status VariantTensorDataReader::ReadTensorInternal(StringPiece n,
                                                   StringPiece key,
                                                   Tensor* val) const {
  string name(n);
  auto it = map_.find(name);
  if (it == map_.end()) {
    return errors::NotFound(name);
  }
  const auto& bucket = it->second;
  auto key_it = bucket.find(string(key));
  if (key_it == bucket.end()) {
    return errors::NotFound(key);
  }
  *val = data_.at(name)->tensors(key_it->second);
  return Status::OK();
}

Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteScalar(StringPiece key,
                                            const tstring& val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteTensor(StringPiece key,
                                            const Tensor& val) {
  return WriteTensorInternal(key, val);
}

Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
                                            const int64 val) {
  return WriteScalarInternal(name, key, val);
}

Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
                                            const tstring& val) {
  return WriteScalarInternal(name, key, val);
}

Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
                                            const Tensor& val) {
  return WriteTensorInternal(name, key, val);
}

void VariantTensorDataWriter::MaybeFlush() {
  if (is_flushed_) return;
  for (auto& keys : keys_) {
    const string name = keys.first;
    string metadata = name;
    for (size_t i = 0; i < keys_[name].size(); ++i) {
      strings::StrAppend(&metadata, kDelimiter, keys_[name][i]);
    }
    data_[name]->set_metadata(metadata);
  }
  is_flushed_ = true;
}

void VariantTensorDataWriter::Reset() {
  is_flushed_ = false;
  data_.clear();
  keys_.clear();
}

void VariantTensorDataWriter::ReleaseData(
    std::vector<std::unique_ptr<VariantTensorData>>* variants) {
  MaybeFlush();
  for (auto& it : data_) {
    variants->push_back(std::move(it.second));
  }
  Reset();
}

void VariantTensorDataWriter::GetData(
    std::vector<const VariantTensorData*>* variants) {
  MaybeFlush();
  for (auto& it : data_) {
    variants->push_back(it.second.get());
  }
}
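
// Example (illustrative sketch): a typical serialize/deserialize round trip
// through these classes, as used by iterator checkpointing:
//
//   VariantTensorDataWriter writer;
//   TF_RETURN_IF_ERROR(iterator->Save(ctx, &writer));
//   std::vector<std::unique_ptr<VariantTensorData>> data;
//   writer.ReleaseData(&data);  // flushes metadata, then resets the writer
//
//   std::vector<const VariantTensorData*> views;
//   for (const auto& d : data) views.push_back(d.get());
//   VariantTensorDataReader reader(views);
//   TF_RETURN_IF_ERROR(iterator->Restore(ctx, &reader));
//
// GetData() is the non-consuming variant: it flushes but keeps ownership,
// after which further writes fail with FailedPrecondition.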

template <typename T>
Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key,
                                                    const T& val) {
  if (is_flushed_) {
    return errors::FailedPrecondition(
        "Cannot call WriteScalar after GetData or ReleaseData is called");
  }
  string name;
  TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
  return WriteScalarInternal(name, key, val);
}

Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
                                                    const Tensor& val) {
  if (is_flushed_) {
    return errors::FailedPrecondition(
        "Cannot call WriteTensor after GetData or ReleaseData is called");
  }
  string name;
  TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
  return WriteTensorInternal(name, key, val);
}

template <typename T>
Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name,
                                                    StringPiece key,
                                                    const T& val) {
  if (is_flushed_) {
    return errors::FailedPrecondition(
        "Cannot call WriteScalar after GetData or ReleaseData is called");
  }
  Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
  val_t.scalar<T>()() = val;
  return WriteTensorInternal(name, key, val_t);
}

Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n,
                                                    StringPiece key,
                                                    const Tensor& val) {
  if (is_flushed_) {
    return errors::FailedPrecondition(
        "Cannot call WriteTensor after GetData or ReleaseData is called");
  }
  DCHECK_EQ(key.find(kDelimiter), string::npos);
  string name(n);
  if (keys_.count(name) == 0) {
    keys_[name] = std::vector<string>();
  }
  keys_[name].push_back(string(key));
  if (data_.count(name) == 0) {
    data_[name] = absl::make_unique<VariantTensorData>();
    data_[name]->set_type_name("tensorflow::Iterator");
  }
  *(data_[name]->add_tensors()) = val;
  return Status::OK();
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a different function with "
                                       "the same signature already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}
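
// Example (sketch, assuming the usual FunctionLibraryDefinition setup):
// merging a captured function library into a base library. A function with an
// identical signature is silently replaced; a function that redefines the
// same name with a different signature fails with InvalidArgument.
//
//   FunctionLibraryDefinition base(OpRegistry::Global(),
//                                  FunctionDefLibrary());
//   TF_RETURN_IF_ERROR(AddToFunctionLibrary(&base, dataset_lib));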

std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}
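
// Example (sketch): wrapping an IteratorContext's runner so that any closure
// scheduled through it runs with intra-op parallelism capped at 2:
//
//   auto runner =
//       RunnerWithMaxParallelism(*ctx->runner(), /*max_parallelism=*/2);
//   runner([]() { /* body runs under ScopedPerThreadMaxParallelism(2) */ });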

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return Status::OK();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}
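
// Example (sketch, assuming the string constants declared in the header):
// parsing a user-facing policy string coming from an op attr and branching on
// the result:
//
//   DeterminismPolicy policy;
//   TF_RETURN_IF_ERROR(
//       DeterminismPolicy::FromString(DeterminismPolicy::kDeterministic,
//                                     &policy));
//   // policy.String() round-trips to the same constant; an unknown string
//   // fails with InvalidArgument.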

bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
  if (!absl::StartsWith(op_to_match, op_prefix)) {
    return false;
  }
  if (op_to_match.length() == op_prefix.length()) {
    return true;
  }
  size_t index = op_to_match.length() - 1;
  while (isdigit(op_to_match[index])) {
    index--;
  }
  return (op_to_match[index] == 'V') && (op_prefix.length() == index);
}
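
// Example: MatchesAnyVersion("BatchDataset", "BatchDataset") and
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") both return true, while
// MatchesAnyVersion("BatchDataset", "BatchDatasetXYZ") and
// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") return false: the
// suffix after the prefix must be exactly 'V' followed by digits.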

std::vector<tstring> SelectOptimizations(
    const string& job_name,
    const absl::flat_hash_map<string, uint64>& live_experiments,
    const std::vector<tstring>& optimizations_enabled,
    const std::vector<tstring>& optimizations_disabled,
    const std::vector<tstring>& optimizations_default,
    std::function<uint64(const string&)> hash_func) {
  std::vector<tstring> optimizations;
  if (job_name.empty()) {
    // If `job_name` is empty, apply the enabled and default optimizations
    // directly.
    optimizations.insert(optimizations.end(), optimizations_enabled.begin(),
                         optimizations_enabled.end());
    optimizations.insert(optimizations.end(), optimizations_default.begin(),
                         optimizations_default.end());
    return optimizations;
  }

  // If `job_name` is non-empty, we determine which optimizations to apply to
  // this job based on the enable/disable settings from tf.data.Options, the
  // opt-in/opt-out settings from environment variables, and the rollout
  // conditions from `live_experiments`.
  const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
  const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
  string opt_ins_raw;
  if (opt_ins_raw_cs != nullptr) {
    opt_ins_raw = string(opt_ins_raw_cs);
  }
  string opt_outs_raw;
  if (opt_outs_raw_cs != nullptr) {
    opt_outs_raw = string(opt_outs_raw_cs);
  }

  // The set of optimizations to apply.
  absl::flat_hash_set<tstring> optimizations_set;

  // Creates the opt-in and opt-out settings.
  std::vector<string> opt_ins, opt_outs;
  if (opt_ins_raw == "all") {
    for (auto& pair : live_experiments) {
      opt_ins.push_back(pair.first);
    }
  } else {
    opt_ins = str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty());
  }
  if (opt_outs_raw == "all") {
    for (auto& pair : live_experiments) {
      opt_outs.push_back(pair.first);
    }
  } else {
    opt_outs = str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty());
  }

  // Warn about opt-ins and opt-outs that do not name a live experiment.
  for (auto& optimization : opt_ins) {
    if (live_experiments.find(optimization) == live_experiments.end()) {
      LOG(WARNING) << "The experiment \"" << optimization
                   << "\" is opted in but it is not a live experiment.";
    }
  }
  for (auto& optimization : opt_outs) {
    if (live_experiments.find(optimization) == live_experiments.end()) {
      LOG(WARNING) << "The experiment \"" << optimization
                   << "\" is opted out but it is not a live experiment.";
    }
  }

  // Warn if the opt-in settings conflict with the opt-out settings.
  for (auto& optimization : opt_ins) {
    if (std::find(opt_outs.begin(), opt_outs.end(), optimization) !=
        opt_outs.end()) {
      LOG(WARNING) << "The experiment \"" << optimization
                   << "\" is set in both \"TF_DATA_EXPERIMENT_OPT_IN\" and "
                      "\"TF_DATA_EXPERIMENT_OPT_OUT\". Unless the experiment "
                      "corresponds to an explicitly enabled optimization, it "
                      "is not applied.";
    }
  }

  // Warn if the enable/disable settings from tf.data.Options conflict with
  // the user's opt-in/opt-out settings; when they do, the tf.data.Options
  // settings take precedence.
  for (auto& optimization : optimizations_enabled) {
    if (std::find(opt_outs.begin(), opt_outs.end(), optimization) !=
        opt_outs.end()) {
      LOG(WARNING) << "The optimization \"" << optimization
                   << "\" is opted out, but is still applied since"
                      " it is enabled through tf.data.Options.";
    }
  }
  for (auto& optimization : optimizations_disabled) {
    if (std::find(opt_ins.begin(), opt_ins.end(), optimization) !=
        opt_ins.end()) {
      LOG(WARNING) << "The optimization \"" << optimization
                   << "\" is opted in, but is not applied since"
                      " it is disabled through tf.data.Options.";
    }
  }

  // Add the enabled optimizations.
  optimizations_set.insert(optimizations_enabled.begin(),
                           optimizations_enabled.end());

  // Add the default optimizations that are not explicitly opted out.
  for (auto& optimization : optimizations_default) {
    if (std::find(opt_outs.begin(), opt_outs.end(), optimization) ==
        opt_outs.end()) {
      optimizations_set.insert(optimization);
    }
  }

  // Add the live experiments stochastically if they are neither opted in nor
  // opted out.
  for (auto& pair : live_experiments) {
    string experiment = pair.first;
    // Skip experiments that are explicitly opted out.
    if (std::find(opt_outs.begin(), opt_outs.end(), experiment) !=
        opt_outs.end()) {
      continue;
    }
    // Skip experiments whose transformations are explicitly disabled.
    if (std::find(optimizations_disabled.begin(), optimizations_disabled.end(),
                  experiment) != optimizations_disabled.end()) {
      continue;
    }
    // Apply experiments that are explicitly opted in.
    if (std::find(opt_ins.begin(), opt_ins.end(), experiment) !=
        opt_ins.end()) {
      optimizations_set.insert(experiment);
      continue;
    }
    // Otherwise, apply the experiment stochastically based on the job name
    // and the experiment's rollout percentage.
    if (hash_func(strings::StrCat(job_name, experiment)) % 100 < pair.second) {
      optimizations_set.insert(experiment);
    }
  }

  optimizations.insert(optimizations.end(), optimizations_set.begin(),
                       optimizations_set.end());
  return optimizations;
}
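
// Example (illustrative sketch, with made-up experiment names): given
//
//   live_experiments       = {{"exp_a", 50}}  // 50% rollout
//   optimizations_enabled  = {"map_fusion"}
//   optimizations_disabled = {}
//   optimizations_default  = {"noop_elimination"}
//
// a job named "trainer" gets {"map_fusion", "noop_elimination"}, plus "exp_a"
// iff hash_func("trainerexp_a") % 100 < 50, unless the user exported
// TF_DATA_EXPERIMENT_OPT_OUT=exp_a (or "all").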

void StripDevicePlacement(FunctionDefLibrary* library) {
  for (auto& function : (*library->mutable_function())) {
    for (auto& node : (*function.mutable_node_def())) {
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
}

Status CopyPartialBatch(int64 num_elements, const Tensor& value,
                        Tensor* output) {
  switch (value.dtype()) {
#define HANDLE_TYPE(type)                                         \
  case DataTypeToEnum<type>::value: {                             \
    auto output_t = output->flat_outer_dims<type>();              \
    auto value_t = value.flat_outer_dims<type>();                 \
    for (size_t i = 0; i < num_elements; i++) {                   \
      output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    }                                                             \
    return Status::OK();                                          \
  }
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
  return Status::OK();
}

Status ReadBatch(int64 batch_size, const string& iterator_prefix,
                 const string& batch_prefix, IteratorContext* ctx,
                 IteratorStateReader* reader, std::vector<Tensor>* batch) {
  int64 output_size;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      &output_size));
  batch->reserve(output_size);
  for (int i = 0; i < output_size; i++) {
    Tensor t;
    TF_RETURN_IF_ERROR(reader->ReadTensor(
        FullName(iterator_prefix,
                 strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
        &t));
    // If the batch was not full, we may have stored only the relevant slice.
    // Since tensors in `BatchResult.output` are expected to have the leading
    // dimension of size batch_size, we build a larger tensor and copy the
    // slice read from the checkpoint into it.
    if (t.dim_size(0) < batch_size) {
      TensorShape component_shape(t.shape());
      component_shape.set_dim(0, batch_size);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
      TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
      batch->emplace_back(std::move(new_t));
    } else {
      batch->emplace_back(std::move(t));
    }
  }
  return Status::OK();
}

Status WriteBatch(int64 batch_size, int64 num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix,
               strings::StrCat(batch_prefix, "_", kOutputSize)),
      batch->size()));
  for (int i = 0; i < batch->size(); i++) {
    // If the batch is not full, we only store the first `num_elements`
    // values. The rest of the batch tensor is *uninitialized*, and reading it
    // would trigger MSan errors.
    if (num_elements < batch_size) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          FullName(iterator_prefix,
                   strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
          (*batch)[i].Slice(0, num_elements)));
    } else {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          FullName(iterator_prefix,
                   strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
          (*batch)[i]));
    }
  }
  return Status::OK();
}
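
// Example (sketch): WriteBatch/ReadBatch round-trip one in-flight batch using
// keys derived from the iterator prefix, e.g. for batch_prefix "batch_0":
//
//   <iterator_prefix> : batch_0_output_size -> number of components
//   <iterator_prefix> : batch_0_output_<i>  -> component i (possibly a
//                                              partial slice of batch_size)
//
// ReadBatch re-inflates any partial slice back to a full `batch_size` leading
// dimension via CopyPartialBatch, so downstream code can assume a full-size
// batch tensor.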

Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status) {
  int64 code_int;
  TF_RETURN_IF_ERROR(reader->ReadScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      &code_int));
  error::Code code = static_cast<error::Code>(code_int);

  if (code != error::Code::OK) {
    tstring error_message;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        &error_message));
    *status = Status(code, error_message);
  } else {
    *status = Status::OK();
  }
  return Status::OK();
}

Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  TF_RETURN_IF_ERROR(writer->WriteScalar(
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
      static_cast<int64>(status.code())));
  if (!status.ok()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
        status.error_message()));
  }
  return Status::OK();
}
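
// Example (sketch): persisting a non-OK element status and restoring it; the
// "status_0" key is hypothetical:
//
//   TF_RETURN_IF_ERROR(WriteStatus(prefix(), "status_0",
//                                  errors::OutOfRange("End of sequence"),
//                                  writer));
//   ...
//   Status restored;
//   TF_RETURN_IF_ERROR(ReadStatus(prefix(), "status_0", reader, &restored));
//   // restored.code() == error::OUT_OF_RANGE; the message is only stored
//   // (and read back) for non-OK codes.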

Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
                    const Status& status, IteratorContext* ctx,
                    std::vector<Tensor>* output, bool* end_of_sequence,
                    std::vector<Tensor>* batch) {
  if (num_elements == 0) {
    if (status.ok() || errors::IsOutOfRange(status)) {
      *end_of_sequence = true;
      return Status::OK();
    } else {
      *end_of_sequence = false;
      return status;
    }
  }
  if (!status.ok() && !errors::IsOutOfRange(status)) {
    *end_of_sequence = false;
    return status;
  }
  if (num_elements < batch_size) {
    if (drop_remainder) {
      *end_of_sequence = true;
      return Status::OK();
    }
    for (size_t i = 0; i < batch->size(); ++i) {
      TensorShape component_shape((*batch)[i].shape());
      component_shape.set_dim(0, num_elements);
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
                           component_shape);
      if (!output->back().IsInitialized()) {
        return errors::ResourceExhausted(
            "Failed to allocate memory for the batch of component ", i);
      }
      TF_RETURN_IF_ERROR(
          CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
    }
  } else {
    *output = std::move(*batch);
  }
  *end_of_sequence = false;
  return Status::OK();
}
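
// Example (sketch): ProcessBatch implements the end-of-input policy shared by
// the batching kernels. With batch_size = 4:
//
//   num_elements = 0, status OK/OutOfRange -> *end_of_sequence = true
//   num_elements = 3, drop_remainder       -> *end_of_sequence = true
//   num_elements = 3, !drop_remainder      -> emits a trimmed batch of 3
//   num_elements = 4                       -> moves the full batch to output
//
// Any status other than OK/OutOfRange is propagated to the caller as-is.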

Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
                 std::vector<Tensor>* out_tensors,
                 std::vector<std::vector<Tensor>>* batch_elements) {
  const size_t num_tuple_components = (*batch_elements)[0].size();
  out_tensors->reserve(num_tuple_components);
  const int64 num_batch_elements = batch_elements->size();
  for (size_t component_index = 0; component_index < num_tuple_components;
       ++component_index) {
    const Tensor& first_element = (*batch_elements)[0][component_index];
    TensorShape batch_component_shape({num_batch_elements});
    // NOTE(mrry): Copy the shape of the first element here, because
    // `first_element.shape()` will become undefined after the 0th batch
    // element is moved into the output batch.
    TensorShape first_element_shape(first_element.shape());
    batch_component_shape.AppendShape(first_element_shape);
    out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(),
                              batch_component_shape);
    if (!out_tensors->back().IsInitialized()) {
      return errors::ResourceExhausted(
          "Failed to allocate memory for the batch of component ",
          component_index);
    }
    Tensor& batch_component = out_tensors->back();
    // Build the output tuple component by copying one slice from each input
    // element in the batch.
    auto copy_element_fn = [component_index, &batch_elements,
                            &batch_component](int index) {
      TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
          std::move((*batch_elements)[index][component_index]),
          &batch_component, index));
      return Status::OK();
    };
    Status status;
    std::unique_ptr<BlockingCounter> counter;
    std::unique_ptr<mutex> status_mu;
    if (TF_PREDICT_FALSE(parallel_copy)) {
      counter = std::make_unique<BlockingCounter>(num_batch_elements);
      status_mu = std::make_unique<mutex>();
    }
    for (size_t i = 0; i < num_batch_elements; ++i) {
      if ((*batch_elements)[i][component_index].shape() !=
          first_element_shape) {
        return errors::InvalidArgument(
            "Cannot batch tensors with different shapes in component ",
            component_index, ". First element had shape ",
            first_element_shape.DebugString(), " and element ", i,
            " had shape ",
            (*batch_elements)[i][component_index].shape().DebugString(), ".");
      }
      if (TF_PREDICT_FALSE(parallel_copy)) {
        (*ctx->runner())(
            [i, &status, &status_mu, &counter, &copy_element_fn]() {
              Status s = copy_element_fn(i);
              {
                mutex_lock l(*status_mu);
                status.Update(s);
              }
              counter->DecrementCount();
            });
      } else {
        status.Update(copy_element_fn(i));
      }
    }
    if (TF_PREDICT_FALSE(parallel_copy)) {
      counter->Wait();
    }
    TF_RETURN_IF_ERROR(status);
  }
  return Status::OK();
}
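
// Example (sketch): callers typically set `parallel_copy` when elements are
// large, trading scheduling overhead for copy throughput:
//
//   std::vector<Tensor> batch;
//   TF_RETURN_IF_ERROR(CopyBatch(/*parallel_copy=*/true, ctx, &batch,
//                                &batch_elements));
//
// Each element's component is copied into slice i of the output via
// batch_util::CopyElementToSlice; with parallel_copy the copies are scheduled
// on ctx->runner() and joined with a BlockingCounter, with failures merged
// into a mutex-guarded shared Status.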

}  // namespace data
}  // namespace tensorflow