1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ 17 #define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ 18 19 #include <atomic> 20 #include <random> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/lib/core/notification.h" 27 #include "tensorflow/core/lib/core/threadpool.h" 28 #include "tensorflow/core/platform/macros.h" 29 #include "tensorflow/core/platform/thread_annotations.h" 30 31 namespace tensorflow { 32 33 // RecordYielder produces value records from a set of tfrecord files 34 // in a random order. 35 // 36 // It guarantees that: 37 // 1) all records in tfrecords are yielded within every epoch; 38 // 2) each record is yielded only once within every epoch; 39 // 3) the order in which records are yielded is highly randomized. 40 // 4) the peak memory usage is roughly avg record size * 41 // (opts.bufsize + opts.parellelism * 16). 42 // 43 // Usage example: 44 // RecordYielder::Options opts; 45 // opts.file_pattern = "input-*"; 46 // opts.seed = 301; 47 // opts.bufsize = 1000000; // A randomized buffer with 1M records. 48 // opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate 49 // // through all files. 50 // RecordYielder yielder(opts); 51 // string val; 52 // while (true) { 53 // yielder.YieldOne(&val); 54 // // process val 55 // } 56 // 57 // RecordYielder can be accessed by multiple threads concurrently. 58 class RecordYielder { 59 public: 60 struct Options { 61 // Glob pattern for tfrecords. 62 string file_pattern; 63 64 // Random seed. It determines how data files are shuffled and how 65 // records are shuffled. 66 int64 seed = 0; 67 68 // Each epoch, all files are first shuffled according to the 69 // random seed and the epoch number, and then all files are 70 // left-shifted by file_shuffle_shift_ratio * num_files slots. If 71 // file_shuffle_shift_ratio is not within [0, 1), the 72 // implementation clip it to [0, 1). 73 float file_shuffle_shift_ratio = 0; 74 75 // Randomization buffer keeps these many records. 76 uint64 bufsize = 1; 77 78 // Uses these many concurrent tfrecord iterators to iterate through 79 // tfrecords. 80 int32 parallelism = 1; 81 82 string compression_type; 83 }; 84 85 explicit RecordYielder(OpKernelConstruction* context, 86 const RecordYielder::Options& opts); 87 ~RecordYielder(); 88 89 RecordYielder(const RecordYielder&) = delete; 90 RecordYielder& operator=(const RecordYielder&) = delete; 91 92 // Yields one 'value'. 93 Status YieldOne(string* value); 94 95 // Returns the current epoch number. current_epoch()96 int64 current_epoch() const { return epoch_; } 97 98 private: 99 typedef RecordYielder ME; 100 101 Options opts_; 102 103 // Backgrounds threads. Owned. 104 thread::ThreadPool* thread_; 105 106 // Epoch number. 107 std::atomic<int64> epoch_; 108 109 mutex mu_; 110 111 // Turned to true when this is deleted. 112 bool stop_ GUARDED_BY(mu_) = false; 113 Status status_ GUARDED_BY(mu_); 114 115 // PRG used for randomization. 116 std::mt19937_64 rnd_ GUARDED_BY(mu_); 117 118 // Randomization buffer. 119 std::vector<string> buf_ GUARDED_BY(mu_); 120 121 // True iff we are draining an epoch. 122 bool epoch_end_ = false; 123 124 int64 num_records_added_in_epoch_ = 0; 125 int64 num_records_yielded_in_epoch_ = 0; 126 127 // Trigger when the main loop has exited. 128 Notification main_loop_done_; 129 130 // condition_variables. 131 condition_variable buf_empty_; BufEmpty()132 bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) { 133 return stop_ || buf_.empty(); 134 } 135 136 condition_variable buf_not_full_; BufNotFull()137 bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) { 138 return stop_ || buf_.size() < opts_.bufsize; 139 } 140 141 condition_variable buf_enough_; BufEnough()142 bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) { 143 // NOTE: Unless we are finishing an epoch, we want to make sure 144 // the buf_ contains enough randomized elements before yielding 145 // any. 146 return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) || 147 (!epoch_end_ && 148 buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2)); 149 } 150 151 void MainLoop(); 152 struct Shard; 153 void ShardLoop(Shard* shard); 154 bool ShouldFinish(const Status& s); 155 bool Add(std::vector<string>* values); 156 }; 157 158 } // namespace tensorflow 159 160 #endif // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ 161