• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/data/dataset_utils.h"
17 
18 #include <memory>
19 #include <queue>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/dataset.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/framework/op_def_builder.h"
29 #include "tensorflow/core/framework/op_def_util.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/tensor.pb.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/lib/core/blocking_counter.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/lib/strings/proto_serialization.h"
38 #include "tensorflow/core/platform/regexp.h"
39 #include "tensorflow/core/util/work_sharder.h"
40 
41 namespace tensorflow {
42 namespace data {
namespace {

// Separator placed between the iterator name and each key when
// VariantTensorDataWriter serializes its metadata string.
constexpr char kDelimiter[] = "@@";

// Key fragments used when (de)serializing lists of elements.
constexpr char kNumElements[] = "num_elements";
constexpr char kNumComponents[] = "num_components";
constexpr char kComponent[] = "component";

// Key fragments used when (de)serializing batches.
constexpr char kOutputSize[] = "output_size";
constexpr char kOutput[] = "output";

// Not referenced in this excerpt; presumably used by status serialization
// elsewhere in the file -- TODO confirm.
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";

}  // namespace
55 
WriteElementsToCheckpoint(IteratorStateWriter * writer,StringPiece key_prefix,const std::vector<std::vector<Tensor>> & elements)56 Status WriteElementsToCheckpoint(
57     IteratorStateWriter* writer, StringPiece key_prefix,
58     const std::vector<std::vector<Tensor>>& elements) {
59   TF_RETURN_IF_ERROR(
60       writer->WriteScalar(key_prefix, kNumElements, elements.size()));
61   for (int i = 0; i < elements.size(); ++i) {
62     const std::vector<Tensor>& element = elements[i];
63     std::string element_prefix = absl::StrCat(key_prefix, "::", i);
64     TF_RETURN_IF_ERROR(
65         writer->WriteScalar(element_prefix, kNumComponents, element.size()));
66     for (int j = 0; j < elements[i].size(); ++j) {
67       TF_RETURN_IF_ERROR(writer->WriteTensor(
68           element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j]));
69     }
70   }
71   return Status::OK();
72 }
73 
ReadElementsFromCheckpoint(IteratorStateReader * reader,StringPiece key_prefix,std::vector<std::vector<Tensor>> * elements)74 Status ReadElementsFromCheckpoint(IteratorStateReader* reader,
75                                   StringPiece key_prefix,
76                                   std::vector<std::vector<Tensor>>* elements) {
77   int64 num_elements;
78   TF_RETURN_IF_ERROR(
79       reader->ReadScalar(key_prefix, kNumElements, &num_elements));
80   elements->reserve(num_elements);
81   for (int i = 0; i < num_elements; ++i) {
82     std::string element_prefix = absl::StrCat(key_prefix, "::", i);
83     int64 num_components;
84     TF_RETURN_IF_ERROR(
85         reader->ReadScalar(element_prefix, kNumComponents, &num_components));
86     elements->emplace_back();
87     std::vector<Tensor>& element = elements->at(i);
88     element.reserve(num_components);
89     for (int j = 0; j < num_components; ++j) {
90       element.emplace_back();
91       TF_RETURN_IF_ERROR(reader->ReadTensor(
92           element_prefix, absl::StrCat(kComponent, "[", j, "]"),
93           &element.back()));
94     }
95   }
96   return Status::OK();
97 }
98 
MaybeOverrideSeeds(std::pair<int64,int64> seeds)99 std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds) {
100   if (seeds.first == 0 && seeds.second == 0) {
101     return {random::New64(), random::New64()};
102   }
103   return seeds;
104 }
105 
VerifyTypeMatch(const DataType & expected,const DataType & received,int index)106 Status VerifyTypeMatch(const DataType& expected, const DataType& received,
107                        int index) {
108   if (expected != received) {
109     return errors::InvalidArgument("Data type mismatch at component ", index,
110                                    ": expected ", DataTypeString(expected),
111                                    " but got ", DataTypeString(received), ".");
112   }
113   return Status::OK();
114 }
115 
VerifyTypesMatch(const DataTypeVector & expected,const DataTypeVector & received)116 Status VerifyTypesMatch(const DataTypeVector& expected,
117                         const DataTypeVector& received) {
118   if (expected.size() != received.size()) {
119     return errors::InvalidArgument(
120         "Number of components does not match: expected ", expected.size(),
121         " types but got ", received.size(), ".");
122   }
123   for (size_t i = 0; i < expected.size(); ++i) {
124     TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
125   }
126   return Status::OK();
127 }
128 
VerifyTypesMatch(const DataTypeVector & expected,const std::vector<Tensor> & received)129 Status VerifyTypesMatch(const DataTypeVector& expected,
130                         const std::vector<Tensor>& received) {
131   if (expected.size() != received.size()) {
132     return errors::InvalidArgument(
133         "Number of components does not match: expected ", expected.size(),
134         " types but got ", received.size(), ".");
135   }
136   for (size_t i = 0; i < expected.size(); ++i) {
137     TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
138   }
139   return Status::OK();
140 }
141 
VerifyShapeCompatible(const PartialTensorShape & expected,const PartialTensorShape & received,int index)142 Status VerifyShapeCompatible(const PartialTensorShape& expected,
143                              const PartialTensorShape& received, int index) {
144   if (!expected.IsCompatibleWith(received)) {
145     return errors::InvalidArgument("Incompatible shapes at component ", index,
146                                    ": expected ", expected.DebugString(),
147                                    " but got ", received.DebugString(), ".");
148   }
149   return Status::OK();
150 }
151 
VerifyShapesCompatible(const std::vector<PartialTensorShape> & expected,const std::vector<PartialTensorShape> & received)152 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
153                               const std::vector<PartialTensorShape>& received) {
154   if (expected.size() != received.size()) {
155     return errors::InvalidArgument(
156         "Number of components does not match: expected ", expected.size(),
157         " shapes but got ", received.size(), ".");
158   }
159   for (size_t i = 0; i < expected.size(); ++i) {
160     TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
161   }
162 
163   return Status::OK();
164 }
165 
VerifyShapesCompatible(const std::vector<PartialTensorShape> & expected,const std::vector<Tensor> & received)166 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
167                               const std::vector<Tensor>& received) {
168   if (expected.size() != received.size()) {
169     return errors::InvalidArgument(
170         "Number of components does not match: expected ", expected.size(),
171         " shapes but got ", received.size(), ".");
172   }
173   for (size_t i = 0; i < expected.size(); ++i) {
174     TF_RETURN_IF_ERROR(
175         VerifyShapeCompatible(expected[i], received[i].shape(), i));
176   }
177 
178   return Status::OK();
179 }
180 
181 namespace {
182 
183 // We assume that all keys are of the form <iterator_prefix>:<name>. We extract
184 // the iterator name by getting rid of everything post the final colon.
GetIteratorName(StringPiece key,string * name)185 Status GetIteratorName(StringPiece key, string* name) {
186   if (!str_util::StartsWith(key, data::kFullNameRandomHex)) {
187     return errors::InvalidArgument("Save key: ", key,
188                                    " not generated using full_name.");
189   }
190   std::vector<string> split_keys = str_util::Split(key, data::kPipe);
191   if (split_keys.size() != 2) {
192     return errors::InvalidArgument("Save key: ", key,
193                                    " not generated using full_name.");
194   }
195   string real_key = split_keys[1];
196   const int pos = real_key.rfind(kColon);
197   *name = real_key.substr(0, pos);
198   return Status::OK();
199 }
200 
201 }  // namespace
202 
VariantTensorDataReader(const std::vector<const tensorflow::VariantTensorData * > & data)203 VariantTensorDataReader::VariantTensorDataReader(
204     const std::vector<const tensorflow::VariantTensorData*>& data) {
205   for (const auto& d : data) {
206     string metadata;
207     d->get_metadata(&metadata);
208     auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
209     const string name = keys[0];
210     data_[name] = d;
211     map_[name] = std::map<string, size_t>();
212     for (size_t i = 1; i < keys.size(); ++i) {
213       map_[name][keys[i]] = i - 1;
214     }
215   }
216 }
217 
ReadScalar(StringPiece key,int64 * val) const218 Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) const {
219   return ReadScalarInternal(key, val);
220 }
221 
ReadScalar(StringPiece key,tstring * val) const222 Status VariantTensorDataReader::ReadScalar(StringPiece key,
223                                            tstring* val) const {
224   return ReadScalarInternal(key, val);
225 }
226 
ReadTensor(StringPiece key,Tensor * val) const227 Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const {
228   return ReadTensorInternal(key, val);
229 }
230 
ReadScalar(StringPiece name,StringPiece key,int64 * val) const231 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
232                                            int64* val) const {
233   return ReadScalarInternal(name, key, val);
234 }
235 
ReadScalar(StringPiece name,StringPiece key,tstring * val) const236 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
237                                            tstring* val) const {
238   return ReadScalarInternal(name, key, val);
239 }
240 
ReadTensor(StringPiece name,StringPiece key,Tensor * val) const241 Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
242                                            Tensor* val) const {
243   return ReadTensorInternal(name, key, val);
244 }
245 
Contains(StringPiece key) const246 bool VariantTensorDataReader::Contains(StringPiece key) const {
247   string name;
248   if (!GetIteratorName(key, &name).ok()) {
249     return false;
250   }
251   return Contains(name, key);
252 }
253 
Contains(StringPiece n,StringPiece key) const254 bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const {
255   string name(n);
256   auto it = map_.find(name);
257   if (it == map_.end()) {
258     return false;
259   }
260   const auto& bucket = it->second;
261   return bucket.find(string(key)) != bucket.end();
262 }
263 
264 template <typename T>
ReadScalarInternal(StringPiece key,T * val) const265 Status VariantTensorDataReader::ReadScalarInternal(StringPiece key,
266                                                    T* val) const {
267   string name;
268   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
269   return ReadScalarInternal(name, key, val);
270 }
271 
ReadTensorInternal(StringPiece key,Tensor * val) const272 Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
273                                                    Tensor* val) const {
274   string name;
275   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
276   return ReadTensorInternal(name, key, val);
277 }
278 
279 template <typename T>
ReadScalarInternal(StringPiece n,StringPiece key,T * val) const280 Status VariantTensorDataReader::ReadScalarInternal(StringPiece n,
281                                                    StringPiece key,
282                                                    T* val) const {
283   string name(n);
284   auto it = map_.find(name);
285   if (it == map_.end()) {
286     return errors::NotFound(name);
287   }
288   const auto& bucket = it->second;
289   auto key_it = bucket.find(string(key));
290   if (key_it == bucket.end()) {
291     return errors::NotFound(key);
292   }
293   *val = data_.at(name)->tensors(key_it->second).scalar<T>()();
294   return Status::OK();
295 }
296 
ReadTensorInternal(StringPiece n,StringPiece key,Tensor * val) const297 Status VariantTensorDataReader::ReadTensorInternal(StringPiece n,
298                                                    StringPiece key,
299                                                    Tensor* val) const {
300   string name(n);
301   auto it = map_.find(name);
302   if (it == map_.end()) {
303     return errors::NotFound(name);
304   }
305   const auto& bucket = it->second;
306   auto key_it = bucket.find(string(key));
307   if (key_it == bucket.end()) {
308     return errors::NotFound(key);
309   }
310   *val = data_.at(name)->tensors(key_it->second);
311   return Status::OK();
312 }
313 
WriteScalar(StringPiece key,const int64 val)314 Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) {
315   return WriteScalarInternal(key, val);
316 }
317 
WriteScalar(StringPiece key,const tstring & val)318 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
319                                             const tstring& val) {
320   return WriteScalarInternal(key, val);
321 }
322 
WriteTensor(StringPiece key,const Tensor & val)323 Status VariantTensorDataWriter::WriteTensor(StringPiece key,
324                                             const Tensor& val) {
325   return WriteTensorInternal(key, val);
326 }
327 
WriteScalar(StringPiece name,StringPiece key,const int64 val)328 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
329                                             const int64 val) {
330   return WriteScalarInternal(name, key, val);
331 }
332 
WriteScalar(StringPiece name,StringPiece key,const tstring & val)333 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
334                                             const tstring& val) {
335   return WriteScalarInternal(name, key, val);
336 }
337 
WriteTensor(StringPiece name,StringPiece key,const Tensor & val)338 Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
339                                             const Tensor& val) {
340   return WriteTensorInternal(name, key, val);
341 }
342 
MaybeFlush()343 void VariantTensorDataWriter::MaybeFlush() {
344   if (is_flushed_) return;
345   for (auto& keys : keys_) {
346     const string name = keys.first;
347     string metadata = name;
348     for (size_t i = 0; i < keys_[name].size(); ++i) {
349       strings::StrAppend(&metadata, kDelimiter, keys_[name][i]);
350     }
351     data_[name]->set_metadata(metadata);
352   }
353   is_flushed_ = true;
354 }
355 
Reset()356 void VariantTensorDataWriter::Reset() {
357   is_flushed_ = false;
358   data_.clear();
359   keys_.clear();
360 }
361 
ReleaseData(std::vector<std::unique_ptr<VariantTensorData>> * variants)362 void VariantTensorDataWriter::ReleaseData(
363     std::vector<std::unique_ptr<VariantTensorData>>* variants) {
364   MaybeFlush();
365   for (auto& it : data_) {
366     variants->push_back(std::move(it.second));
367   }
368   Reset();
369 }
370 
GetData(std::vector<const VariantTensorData * > * variants)371 void VariantTensorDataWriter::GetData(
372     std::vector<const VariantTensorData*>* variants) {
373   MaybeFlush();
374   for (auto& it : data_) {
375     variants->push_back(it.second.get());
376   }
377 }
378 
379 template <typename T>
WriteScalarInternal(StringPiece key,const T & val)380 Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key,
381                                                     const T& val) {
382   if (is_flushed_) {
383     return errors::FailedPrecondition(
384         "Cannot call WriteScalar after GetData or ReleaseData is called");
385   }
386   string name;
387   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
388   return WriteScalarInternal(name, key, val);
389 }
390 
WriteTensorInternal(StringPiece key,const Tensor & val)391 Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
392                                                     const Tensor& val) {
393   if (is_flushed_) {
394     return errors::FailedPrecondition(
395         "Cannot call WriteTensor after GetData or ReleaseData is called");
396   }
397   string name;
398   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
399   return WriteTensorInternal(name, key, val);
400 }
401 
402 template <typename T>
WriteScalarInternal(StringPiece name,StringPiece key,const T & val)403 Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name,
404                                                     StringPiece key,
405                                                     const T& val) {
406   if (is_flushed_) {
407     return errors::FailedPrecondition(
408         "Cannot call WriteScalar after GetData or ReleaseData is called");
409   }
410   Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
411   val_t.scalar<T>()() = val;
412   return WriteTensorInternal(name, key, val_t);
413 }
414 
WriteTensorInternal(StringPiece n,StringPiece key,const Tensor & val)415 Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n,
416                                                     StringPiece key,
417                                                     const Tensor& val) {
418   if (is_flushed_) {
419     return errors::FailedPrecondition(
420         "Cannot call WriteTensor after GetData or ReleaseData is called");
421   }
422   DCHECK_EQ(key.find(kDelimiter), string::npos);
423   string name(n);
424   if (keys_.count(name) == 0) {
425     keys_[name] = std::vector<string>();
426   }
427   keys_[name].push_back(string(key));
428   if (data_.count(name) == 0) {
429     data_[name] = absl::make_unique<VariantTensorData>();
430     data_[name]->set_type_name("tensorflow::Iterator");
431   }
432   *(data_[name]->add_tensors()) = val;
433   return Status::OK();
434 }
435 
AddToFunctionLibrary(FunctionLibraryDefinition * base,const FunctionLibraryDefinition & to_add)436 Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
437                             const FunctionLibraryDefinition& to_add) {
438   for (const auto& fn : to_add.ListFunctionNames()) {
439     if (auto found = base->Find(fn)) {
440       if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
441         return errors::InvalidArgument("Cannot add function '", fn,
442                                        "' because a different function with "
443                                        "the same signature already exists.");
444       }
445       TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
446     }
447   }
448   return base->AddLibrary(to_add);
449 }
450 
AddToFunctionLibrary(FunctionLibraryDefinition * base,const FunctionDefLibrary & to_add)451 Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
452                             const FunctionDefLibrary& to_add) {
453   for (const auto& fd : to_add.function()) {
454     if (auto found = base->Find(fd.signature().name())) {
455       if (!OpDefEqual(found->signature(), fd.signature())) {
456         return errors::InvalidArgument("Cannot add function '",
457                                        fd.signature().name(),
458                                        "' because a different function with "
459                                        "the same signature already exists.");
460       }
461       TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
462     }
463   }
464   return base->AddLibrary(to_add);
465 }
466 
RunnerWithMaxParallelism(std::function<void (std::function<void ()>)> runner,int max_parallelism)467 std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
468     std::function<void(std::function<void()>)> runner, int max_parallelism) {
469   return std::bind(
470       [max_parallelism](
471           // Note: `runner` is a const reference to avoid copying it.
472           const std::function<void(std::function<void()>)>& runner,
473           std::function<void()> fn) {
474         std::function<void()> scoped_fn = std::bind(
475             [max_parallelism](const std::function<void()>& fn) {
476               ScopedPerThreadMaxParallelism scope(max_parallelism);
477               fn();
478             },
479             std::move(fn));
480         runner(std::move(scoped_fn));
481       },
482       std::move(runner), std::placeholders::_1);
483 }
484 
FromString(const std::string & s,DeterminismPolicy * out)485 Status DeterminismPolicy::FromString(const std::string& s,
486                                      DeterminismPolicy* out) {
487   DeterminismPolicy::Type type;
488   if (s == DeterminismPolicy::kDeterministic) {
489     type = DeterminismPolicy::Type::kDeterministic;
490   } else if (s == DeterminismPolicy::kNondeterministic) {
491     type = DeterminismPolicy::Type::kNondeterministic;
492   } else if (s == DeterminismPolicy::kDefault) {
493     type = DeterminismPolicy::Type::kDefault;
494   } else {
495     return errors::InvalidArgument("Unrecognized determinism policy: ", s);
496   }
497   *out = DeterminismPolicy(type);
498   return Status::OK();
499 }
500 
DeterminismPolicy(bool is_deterministic)501 DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
502   if (is_deterministic) {
503     determinism_ = DeterminismPolicy::Type::kDeterministic;
504   } else {
505     determinism_ = DeterminismPolicy::Type::kNondeterministic;
506   }
507 }
508 
String() const509 std::string DeterminismPolicy::String() const {
510   switch (determinism_) {
511     case DeterminismPolicy::Type::kDeterministic:
512       return DeterminismPolicy::kDeterministic;
513     case DeterminismPolicy::Type::kNondeterministic:
514       return DeterminismPolicy::kNondeterministic;
515     case DeterminismPolicy::Type::kDefault:
516       return DeterminismPolicy::kDefault;
517     default:
518       LOG(ERROR) << "Unrecognized determinism value";
519       return "Unrecognized";
520   }
521 }
522 
MatchesAnyVersion(StringPiece op_prefix,StringPiece op_to_match)523 bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match) {
524   if (!absl::StartsWith(op_to_match, op_prefix)) {
525     return false;
526   }
527   if (op_to_match.length() == op_prefix.length()) {
528     return true;
529   }
530   size_t index = op_to_match.length() - 1;
531   while (isdigit(op_to_match[index])) {
532     index--;
533   }
534   return (op_to_match[index] == 'V') && (op_prefix.length() == index);
535 }
536 
SelectOptimizations(const string & job_name,const absl::flat_hash_map<string,uint64> & live_experiments,const std::vector<tstring> & optimizations_enabled,const std::vector<tstring> & optimizations_disabled,const std::vector<tstring> & optimizations_default,std::function<uint64 (const string &)> hash_func)537 std::vector<tstring> SelectOptimizations(
538     const string& job_name,
539     const absl::flat_hash_map<string, uint64>& live_experiments,
540     const std::vector<tstring>& optimizations_enabled,
541     const std::vector<tstring>& optimizations_disabled,
542     const std::vector<tstring>& optimizations_default,
543     std::function<uint64(const string&)> hash_func) {
544   std::vector<tstring> optimizations;
545   if (job_name.empty()) {
546     // If `job_name` is empty, apply the enabled and default optimizations
547     // directly.
548     optimizations.insert(optimizations.end(), optimizations_enabled.begin(),
549                          optimizations_enabled.end());
550     optimizations.insert(optimizations.end(), optimizations_default.begin(),
551                          optimizations_default.end());
552     return optimizations;
553   }
554 
555   // If `job_name` is non-empty, we determine which optimizations to apply to
556   // this job based on the enable/disable settings from tf.data.Options, the
557   // opt in/out settings from environment variables, and rollout condition from
558   // `live_experiments`.
559   const char* opt_ins_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
560   const char* opt_outs_raw_cs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
561   string opt_ins_raw;
562   if (opt_ins_raw_cs != nullptr) {
563     opt_ins_raw = string(opt_ins_raw_cs);
564   }
565   string opt_outs_raw;
566   if (opt_outs_raw_cs != nullptr) {
567     opt_outs_raw = string(opt_outs_raw_cs);
568   }
569 
570   // Creates a set of optimizations.
571   absl::flat_hash_set<tstring> optimizations_set;
572 
573   // Creates the opt in and opt out settings.
574   std::vector<string> opt_ins, opt_outs;
575   if (opt_ins_raw == "all") {
576     for (auto& pair : live_experiments) {
577       opt_ins.push_back(pair.first);
578     }
579   } else {
580     opt_ins = str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty());
581   }
582   if (opt_outs_raw == "all") {
583     for (auto& pair : live_experiments) {
584       opt_outs.push_back(pair.first);
585     }
586   } else {
587     opt_outs = str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty());
588   }
589 
590   // Checks if the opt in and opt out experiments are live experiments.
591   for (auto& optimization : opt_ins) {
592     if (live_experiments.find(optimization) == live_experiments.end()) {
593       LOG(WARNING) << "The experiment \"" << optimization
594                    << "\" is opted in but it is not a live experiment.";
595     }
596   }
597   for (auto& optimization : opt_outs) {
598     if (live_experiments.find(optimization) == live_experiments.end()) {
599       LOG(WARNING) << "The experiment \"" << optimization
600                    << "\" is opted out but it is not a live experiment.";
601     }
602   }
603 
604   // Checks if the opt in settings conflict with opt out settings.
605   for (auto& optimization : opt_ins) {
606     if (std::find(opt_outs.begin(), opt_outs.end(), optimization) !=
607         opt_outs.end()) {
608       LOG(WARNING) << "The experiment \"" << optimization
609                    << "\" is set in both \"TF_DATA_EXPERIMENT_OPT_IN\" and "
610                       "\"TF_DATA_EXPERIMENT_OPT_OUT\". Unless the experiment "
611                       "corresponds to an explicitly enabled optimization, it "
612                       "is not applied.";
613     }
614   }
615 
616   // Checks if the enable/disable settings from tf.data.Options conflict with
617   // user opt in/out settings. In which case we assume tf.data.Options settings
618   // have higher priority to overwrite.
619   for (auto& optimization : optimizations_enabled) {
620     if (std::find(opt_outs.begin(), opt_outs.end(), optimization) !=
621         opt_outs.end()) {
622       LOG(WARNING) << "The optimization \"" << optimization
623                    << "\" is opt out, but is still applied since"
624                       " it is enabled through tf.data.Options.";
625     }
626   }
627   for (auto& optimization : optimizations_disabled) {
628     if (std::find(opt_ins.begin(), opt_ins.end(), optimization) !=
629         opt_ins.end()) {
630       LOG(WARNING) << "The optimization \"" << optimization
631                    << "\" is opt in, but is not applied since"
632                       " it is disabled through tf.data.Options.";
633     }
634   }
635 
636   // Add the enabled optimizations.
637   optimizations_set.insert(optimizations_enabled.begin(),
638                            optimizations_enabled.end());
639 
640   // Add the default optimizations that are not explicitly opted out.
641   for (auto& optimization : optimizations_default) {
642     if (std::find(opt_outs.begin(), opt_outs.end(), optimization) ==
643         opt_outs.end()) {
644       optimizations_set.insert(optimization);
645     }
646   }
647 
648   // Add the live experiments stochastically if they are neither opted in nor
649   // opted out.
650   for (auto& pair : live_experiments) {
651     string experiment = pair.first;
652     // Skip experiments that are explicitly opted out.
653     if (std::find(opt_outs.begin(), opt_outs.end(), experiment) !=
654         opt_outs.end()) {
655       continue;
656     }
657     // Skip experiments whose transformations are explicitly disabled.
658     if (std::find(optimizations_disabled.begin(), optimizations_disabled.end(),
659                   experiment) != optimizations_disabled.end()) {
660       continue;
661     }
662     // Apply experiments that are explicitly opted in.
663     if (std::find(opt_ins.begin(), opt_ins.end(), experiment) !=
664         opt_ins.end()) {
665       optimizations_set.insert(experiment);
666       continue;
667     }
668     // Otherwise, apply experiment stochastically based on job name and
669     // experiment roll out percentage.
670     if (hash_func(strings::StrCat(job_name, experiment)) % 100 < pair.second) {
671       optimizations_set.insert(experiment);
672     }
673   }
674 
675   optimizations.insert(optimizations.end(), optimizations_set.begin(),
676                        optimizations_set.end());
677   return optimizations;
678 }
679 
StripDevicePlacement(FunctionDefLibrary * library)680 void StripDevicePlacement(FunctionDefLibrary* library) {
681   for (auto& function : (*library->mutable_function())) {
682     for (auto& node : (*function.mutable_node_def())) {
683       if (!node.device().empty()) {
684         *node.mutable_device() = "";
685       }
686     }
687   }
688 }
689 
CopyPartialBatch(int64 num_elements,const Tensor & value,Tensor * output)690 Status CopyPartialBatch(int64 num_elements, const Tensor& value,
691                         Tensor* output) {
692   switch (value.dtype()) {
693 #define HANDLE_TYPE(type)                                         \
694   case DataTypeToEnum<type>::value: {                             \
695     auto output_t = output->flat_outer_dims<type>();              \
696     auto value_t = value.flat_outer_dims<type>();                 \
697     for (size_t i = 0; i < num_elements; i++) {                   \
698       output_t.template chip<0>(i) = value_t.template chip<0>(i); \
699     }                                                             \
700     return Status::OK();                                          \
701   }
702     TF_CALL_DATASET_TYPES(HANDLE_TYPE);
703 #undef HANDLE_TYPE
704     default:
705       return errors::InvalidArgument("Unsupported data type: ",
706                                      DataTypeString(value.dtype()));
707   }
708   return Status::OK();
709 }
710 
ReadBatch(int64 batch_size,const string & iterator_prefix,const string & batch_prefix,IteratorContext * ctx,IteratorStateReader * reader,std::vector<Tensor> * batch)711 Status ReadBatch(int64 batch_size, const string& iterator_prefix,
712                  const string& batch_prefix, IteratorContext* ctx,
713                  IteratorStateReader* reader, std::vector<Tensor>* batch) {
714   int64 output_size;
715   TF_RETURN_IF_ERROR(reader->ReadScalar(
716       FullName(iterator_prefix,
717                strings::StrCat(batch_prefix, "_", kOutputSize)),
718       &output_size));
719   batch->reserve(output_size);
720   for (int i = 0; i < output_size; i++) {
721     Tensor t;
722     TF_RETURN_IF_ERROR(reader->ReadTensor(
723         FullName(iterator_prefix,
724                  strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
725         &t));
726     // If the batch was not full, we may have stored only the relevant slice.
727     // Since tensors in `BatchResult.output` are expected to have the leading
728     // dimension of size batch_size, we build a larger tensor and copy the slice
729     // read from the checkpoint into it.
730     if (t.dim_size(0) < batch_size) {
731       TensorShape component_shape(t.shape());
732       component_shape.set_dim(0, batch_size);
733       AllocatorAttributes attr;
734       attr.set_gpu_compatible(true);
735       Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape);
736       TF_RETURN_IF_ERROR(CopyPartialBatch(t.dim_size(0), t, &new_t));
737       batch->emplace_back(std::move(new_t));
738     } else {
739       batch->emplace_back(std::move(t));
740     }
741   }
742   return Status::OK();
743 }
744 
WriteBatch(int64 batch_size,int64 num_elements,const string & iterator_prefix,const string & batch_prefix,IteratorStateWriter * writer,std::vector<Tensor> * batch)745 Status WriteBatch(int64 batch_size, int64 num_elements,
746                   const string& iterator_prefix, const string& batch_prefix,
747                   IteratorStateWriter* writer, std::vector<Tensor>* batch) {
748   TF_RETURN_IF_ERROR(writer->WriteScalar(
749       FullName(iterator_prefix,
750                strings::StrCat(batch_prefix, "_", kOutputSize)),
751       batch->size()));
752   for (int i = 0; i < batch->size(); i++) {
753     // If the batch is not full, we only store the first `num_elements` values.
754     // The rest of the batch tensor is *uninitialized* and accessing that will
755     // raise msan errors.
756     if (num_elements < batch_size) {
757       TF_RETURN_IF_ERROR(writer->WriteTensor(
758           FullName(iterator_prefix,
759                    strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
760           (*batch)[i].Slice(0, num_elements)));
761     } else {
762       TF_RETURN_IF_ERROR(writer->WriteTensor(
763           FullName(iterator_prefix,
764                    strings::StrCat(batch_prefix, "_", kOutput, "_", i)),
765           (*batch)[i]));
766     }
767   }
768   return Status::OK();
769 }
770 
ReadStatus(const string & iterator_prefix,const string & prefix,IteratorStateReader * reader,Status * status)771 Status ReadStatus(const string& iterator_prefix, const string& prefix,
772                   IteratorStateReader* reader, Status* status) {
773   int64 code_int;
774   TF_RETURN_IF_ERROR(reader->ReadScalar(
775       FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)),
776       &code_int));
777   error::Code code = static_cast<error::Code>(code_int);
778 
779   if (code != error::Code::OK) {
780     tstring error_message;
781     TF_RETURN_IF_ERROR(reader->ReadScalar(
782         FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)),
783         &error_message));
784     *status = Status(code, error_message);
785   } else {
786     *status = Status::OK();
787   }
788   return Status::OK();
789 }
790 
// Saves `status` under `iterator_prefix` / `prefix`. The code is always
// written; the error message is written only when `status` is not OK.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer) {
  const string code_key =
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(code_key, static_cast<int64>(status.code())));
  if (status.ok()) {
    return Status::OK();
  }
  const string message_key =
      FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage));
  return writer->WriteScalar(message_key, status.error_message());
}
803 
ProcessBatch(int64 batch_size,int64 num_elements,bool drop_remainder,const Status & status,IteratorContext * ctx,std::vector<Tensor> * output,bool * end_of_sequence,std::vector<Tensor> * batch)804 Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
805                     const Status& status, IteratorContext* ctx,
806                     std::vector<Tensor>* output, bool* end_of_sequence,
807                     std::vector<Tensor>* batch) {
808   if (num_elements == 0) {
809     if (status.ok() || errors::IsOutOfRange(status)) {
810       *end_of_sequence = true;
811       return Status::OK();
812     } else {
813       *end_of_sequence = false;
814       return status;
815     }
816   }
817   if (!status.ok() && !errors::IsOutOfRange(status)) {
818     *end_of_sequence = false;
819     return status;
820   }
821   if (num_elements < batch_size) {
822     if (drop_remainder) {
823       *end_of_sequence = true;
824       return Status::OK();
825     }
826     for (size_t i = 0; i < batch->size(); ++i) {
827       TensorShape component_shape((*batch)[i].shape());
828       component_shape.set_dim(0, num_elements);
829       AllocatorAttributes attr;
830       attr.set_gpu_compatible(true);
831       output->emplace_back(ctx->allocator(attr), (*batch)[i].dtype(),
832                            component_shape);
833       if (!output->back().IsInitialized()) {
834         return errors::ResourceExhausted(
835             "Failed to allocate memory for the batch of component ", i);
836       }
837       TF_RETURN_IF_ERROR(
838           CopyPartialBatch(num_elements, (*batch)[i], &output->back()));
839     }
840   } else {
841     *output = std::move(*batch);
842   }
843   *end_of_sequence = false;
844   return Status::OK();
845 }
846 
CopyBatch(bool parallel_copy,IteratorContext * ctx,std::vector<Tensor> * out_tensors,std::vector<std::vector<Tensor>> * batch_elements)847 Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
848                  std::vector<Tensor>* out_tensors,
849                  std::vector<std::vector<Tensor>>* batch_elements) {
850   const size_t num_tuple_components = (*batch_elements)[0].size();
851   out_tensors->reserve(num_tuple_components);
852   const int64 num_batch_elements = batch_elements->size();
853   for (size_t component_index = 0; component_index < num_tuple_components;
854        ++component_index) {
855     const Tensor& first_element = (*batch_elements)[0][component_index];
856     TensorShape batch_component_shape({num_batch_elements});
857     // NOTE(mrry): Copy the shape of the first element here, because
858     // `first_element.shape()` will become undefined after the 0th batch element
859     // is moved into the output batch.
860     TensorShape first_element_shape(first_element.shape());
861     batch_component_shape.AppendShape(first_element_shape);
862     out_tensors->emplace_back(ctx->allocator({}), first_element.dtype(),
863                               batch_component_shape);
864     if (!out_tensors->back().IsInitialized()) {
865       return errors::ResourceExhausted(
866           "Failed to allocate memory for the batch of component ",
867           component_index);
868     }
869     Tensor& batch_component = out_tensors->back();
870     // Build the output tuple component by copying one slice from each input
871     // element in the batch.
872     auto copy_element_fn = [component_index, &batch_elements,
873                             &batch_component](int index) {
874       TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
875           std::move((*batch_elements)[index][component_index]),
876           &batch_component, index));
877       return Status::OK();
878     };
879     Status status;
880     std::unique_ptr<BlockingCounter> counter;
881     std::unique_ptr<mutex> status_mu;
882     if (TF_PREDICT_FALSE(parallel_copy)) {
883       counter = std::make_unique<BlockingCounter>(num_batch_elements);
884       status_mu = std::make_unique<mutex>();
885     }
886     for (size_t i = 0; i < num_batch_elements; ++i) {
887       if ((*batch_elements)[i][component_index].shape() !=
888           first_element_shape) {
889         return errors::InvalidArgument(
890             "Cannot batch tensors with different shapes in component ",
891             component_index, ". First element had shape ",
892             first_element_shape.DebugString(), " and element ", i,
893             " had shape ",
894             (*batch_elements)[i][component_index].shape().DebugString(), ".");
895       }
896       if (TF_PREDICT_FALSE(parallel_copy)) {
897         (*ctx->runner())(
898             [i, &status, &status_mu, &counter, &copy_element_fn]() {
899               Status s = copy_element_fn(i);
900               {
901                 mutex_lock l(*status_mu);
902                 status.Update(s);
903               }
904               counter->DecrementCount();
905             });
906       } else {
907         status.Update(copy_element_fn(i));
908       }
909     }
910     if (TF_PREDICT_FALSE(parallel_copy)) {
911       counter->Wait();
912     }
913     TF_RETURN_IF_ERROR(status);
914   }
915   return Status::OK();
916 }
917 
918 }  // namespace data
919 }  // namespace tensorflow
920