• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
17 #define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
18 
19 #include <atomic>
20 #include <random>
21 #include <string>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/notification.h"
27 #include "tensorflow/core/lib/core/threadpool.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/platform/thread_annotations.h"
30 
31 namespace tensorflow {
32 
33 // RecordYielder produces value records from a set of tfrecord files
34 // in a random order.
35 //
36 // It guarantees that:
37 //   1) all records in tfrecords are yielded within every epoch;
38 //   2) each record is yielded only once within every epoch;
39 //   3) the order in which records are yielded is highly randomized.
40 //   4) the peak memory usage is roughly avg record size *
41 //      (opts.bufsize + opts.parellelism * 16).
42 //
43 // Usage example:
44 //   RecordYielder::Options opts;
45 //   opts.file_pattern = "input-*";
46 //   opts.seed = 301;
47 //   opts.bufsize = 1000000;    // A randomized buffer with 1M records.
48 //   opts.parallelism = 8;      // Uses 8 tfrecord iterators to iterate
49 //                              // through all files.
50 //   RecordYielder yielder(opts);
51 //   string val;
52 //   while (true) {
53 //     yielder.YieldOne(&val);
54 //     // process val
55 //   }
56 //
57 // RecordYielder can be accessed by multiple threads concurrently.
58 class RecordYielder {
59  public:
60   struct Options {
61     // Glob pattern for tfrecords.
62     string file_pattern;
63 
64     // Random seed. It determines how data files are shuffled and how
65     // records are shuffled.
66     int64 seed = 0;
67 
68     // Each epoch, all files are first shuffled according to the
69     // random seed and the epoch number, and then all files are
70     // left-shifted by file_shuffle_shift_ratio * num_files slots.  If
71     // file_shuffle_shift_ratio is not within [0, 1), the
72     // implementation clip it to [0, 1).
73     float file_shuffle_shift_ratio = 0;
74 
75     // Randomization buffer keeps these many records.
76     uint64 bufsize = 1;
77 
78     // Uses these many concurrent tfrecord iterators to iterate through
79     // tfrecords.
80     int32 parallelism = 1;
81 
82     string compression_type;
83   };
84 
85   explicit RecordYielder(OpKernelConstruction* context,
86                          const RecordYielder::Options& opts);
87   ~RecordYielder();
88 
89   RecordYielder(const RecordYielder&) = delete;
90   RecordYielder& operator=(const RecordYielder&) = delete;
91 
92   // Yields one 'value'.
93   Status YieldOne(string* value);
94 
95   // Returns the current epoch number.
current_epoch()96   int64 current_epoch() const { return epoch_; }
97 
98  private:
99   typedef RecordYielder ME;
100 
101   Options opts_;
102 
103   // Backgrounds threads. Owned.
104   thread::ThreadPool* thread_;
105 
106   // Epoch number.
107   std::atomic<int64> epoch_;
108 
109   mutex mu_;
110 
111   // Turned to true when this is deleted.
112   bool stop_ GUARDED_BY(mu_) = false;
113   Status status_ GUARDED_BY(mu_);
114 
115   // PRG used for randomization.
116   std::mt19937_64 rnd_ GUARDED_BY(mu_);
117 
118   // Randomization buffer.
119   std::vector<string> buf_ GUARDED_BY(mu_);
120 
121   // True iff we are draining an epoch.
122   bool epoch_end_ = false;
123 
124   int64 num_records_added_in_epoch_ = 0;
125   int64 num_records_yielded_in_epoch_ = 0;
126 
127   // Trigger when the main loop has exited.
128   Notification main_loop_done_;
129 
130   // condition_variables.
131   condition_variable buf_empty_;
BufEmpty()132   bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) {
133     return stop_ || buf_.empty();
134   }
135 
136   condition_variable buf_not_full_;
BufNotFull()137   bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) {
138     return stop_ || buf_.size() < opts_.bufsize;
139   }
140 
141   condition_variable buf_enough_;
BufEnough()142   bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) {
143     // NOTE: Unless we are finishing an epoch, we want to make sure
144     // the buf_ contains enough randomized elements before yielding
145     // any.
146     return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) ||
147            (!epoch_end_ &&
148             buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2));
149   }
150 
151   void MainLoop();
152   struct Shard;
153   void ShardLoop(Shard* shard);
154   bool ShouldFinish(const Status& s);
155   bool Add(std::vector<string>* values);
156 };
157 
158 }  // namespace tensorflow
159 
160 #endif  // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
161