1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/range_sampler.h"
17
18 #include <unordered_set>
19 #include <vector>
20
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/gtl/map_util.h"
23 #include "tensorflow/core/lib/io/inputbuffer.h"
24 #include "tensorflow/core/lib/strings/numbers.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/types.h"
29
30 namespace tensorflow {
31
32 using gtl::ArraySlice;
33 using gtl::MutableArraySlice;
34
~RangeSampler()35 RangeSampler::~RangeSampler() {}
36
SampleBatch(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch) const37 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
38 gtl::MutableArraySlice<int64> batch) const {
39 SampleBatchGetExpectedCount(
40 rnd, unique, batch, gtl::MutableArraySlice<float>(),
41 gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>());
42 }
43
SampleBatchGetExpectedCount(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch,gtl::MutableArraySlice<float> batch_expected_count,gtl::ArraySlice<int64> extras,gtl::MutableArraySlice<float> extras_expected_count) const44 void RangeSampler::SampleBatchGetExpectedCount(
45 random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch,
46 gtl::MutableArraySlice<float> batch_expected_count,
47 gtl::ArraySlice<int64> extras,
48 gtl::MutableArraySlice<float> extras_expected_count) const {
49 SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
50 extras, extras_expected_count,
51 gtl::ArraySlice<int64>());
52 }
53
namespace {

// Approximates how many times a value with per-draw probability `p` is
// expected to appear in the output of SampleBatch.
//
// `num_tries` is the number of draws actually made to fill `batch_size`
// entries. When every draw was kept (num_tries == batch_size, which is always
// the case for unique=false) the answer is simply p * batch_size.
//
// Otherwise we pretend, not quite correctly, that filling a batch always takes
// exactly `num_tries` draws. Under that assumption the value appears in the
// batch with probability 1 - (1 - p)^num_tries.
float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  const bool every_draw_kept = (num_tries == batch_size);
  if (!every_draw_kept) {
    // 1 - (1 - p)^num_tries, written as -expm1(num_tries * log1p(-p)) so the
    // result stays accurate when p is very small.
    return -expm1(num_tries * log1p(-p));
  }
  return p * batch_size;
}

}  // namespace
76
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const77 void RangeSampler::SampleBatchGetExpectedCountAvoid(
78 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
79 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
80 MutableArraySlice<float> extras_expected_count,
81 ArraySlice<int64> avoided_values) const {
82 const int batch_size = batch.size();
83 int num_tries;
84
85 if (unique) {
86 CHECK_LE(batch_size + avoided_values.size(), range_);
87 std::unordered_set<int64> used(batch_size);
88 used.insert(avoided_values.begin(), avoided_values.end());
89 int num_picked = 0;
90 num_tries = 0;
91 while (num_picked < batch_size) {
92 num_tries++;
93 CHECK_LT(num_tries, kint32max);
94 int64 value = Sample(rnd);
95 if (gtl::InsertIfNotPresent(&used, value)) {
96 batch[num_picked++] = value;
97 }
98 }
99 } else {
100 CHECK_EQ(avoided_values.size(), size_t{0})
101 << "avoided_values only supported with unique=true";
102 for (int i = 0; i < batch_size; i++) {
103 batch[i] = Sample(rnd);
104 }
105 num_tries = batch_size;
106 }
107 // Compute the expected counts of the batch and the extra values
108 if (!batch_expected_count.empty()) {
109 CHECK_EQ(batch_size, batch_expected_count.size());
110 for (int i = 0; i < batch_size; i++) {
111 batch_expected_count[i] =
112 ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
113 }
114 }
115 CHECK_EQ(extras.size(), extras_expected_count.size());
116 for (size_t i = 0; i < extras.size(); i++) {
117 extras_expected_count[i] =
118 ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
119 }
120 }
121
AllSampler(int64 range)122 AllSampler::AllSampler(int64 range) : RangeSampler(range) {}
123
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const124 void AllSampler::SampleBatchGetExpectedCountAvoid(
125 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
126 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
127 MutableArraySlice<float> extras_expected_count,
128 ArraySlice<int64> avoided_values) const {
129 const int batch_size = batch.size();
130 CHECK_EQ(range_, batch_size);
131 for (int i = 0; i < batch_size; i++) {
132 batch[i] = i;
133 }
134 if (!batch_expected_count.empty()) {
135 CHECK_EQ(batch_size, batch_expected_count.size());
136 for (int i = 0; i < batch_size; i++) {
137 batch_expected_count[i] = 1;
138 }
139 }
140 CHECK_EQ(size_t{0}, avoided_values.size());
141 CHECK_EQ(extras.size(), extras_expected_count.size());
142 for (size_t i = 0; i < extras.size(); i++) {
143 extras_expected_count[i] = 1;
144 }
145 }
146
UniformSampler(int64 range)147 UniformSampler::UniformSampler(int64 range)
148 : RangeSampler(range), inv_range_(1.0 / range) {}
149
Sample(random::SimplePhilox * rnd) const150 int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
151 return rnd->Uniform64(range_);
152 }
153
Probability(int64 value) const154 float UniformSampler::Probability(int64 value) const { return inv_range_; }
155
LogUniformSampler(int64 range)156 LogUniformSampler::LogUniformSampler(int64 range)
157 : RangeSampler(range), log_range_(log1p(range)) {}
158
Sample(random::SimplePhilox * rnd) const159 int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
160 const int64 value =
161 static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1;
162 CHECK_GE(value, 0);
163 // Mathematically, value should be <= range_, but might not be due to some
164 // floating point roundoff, so we mod by range_.
165 return value % range_;
166 }
167
Probability(int64 value) const168 float LogUniformSampler::Probability(int64 value) const {
169 // value is returned iff the call to UniformDouble(log_range_) in the
170 // Sample() function returns a value between log(value + 1)
171 // and log(value + 2). The probability of this is:
172 // (log(value + 2) - log(value + 1)) / log_range
173 // To avoid two calls to log(), we compute this as follows:
174 return (log((value + 2.0) / (value + 1.0))) / log_range_;
175 }
176
ThreadUnsafeUnigramSampler(int64 range)177 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range)
178 : RangeSampler(range), picker_(range) {
179 CHECK_LT(range, kint32max);
180 }
181
Sample(random::SimplePhilox * rnd) const182 int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
183 return picker_.Pick(rnd);
184 }
185
Probability(int64 value) const186 float ThreadUnsafeUnigramSampler::Probability(int64 value) const {
187 return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
188 }
189
Update(ArraySlice<int64> values)190 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) {
191 int num_updates = std::min(static_cast<int>(values.size()),
192 kint32max - picker_.total_weight());
193 for (int i = 0; i < num_updates; i++) {
194 const int64 value = values[i];
195 picker_.set_weight(value, picker_.get_weight(value) + 1);
196 }
197 }
198
199 // Thread-safe unigram sampler
UnigramSampler(int64 range)200 UnigramSampler::UnigramSampler(int64 range)
201 : RangeSampler(range), unsafe_sampler_(range) {
202 CHECK_LT(range, kint32max);
203 }
204
Sample(random::SimplePhilox * rnd) const205 int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const {
206 mutex_lock lock(mu_); // could use reader lock
207 return unsafe_sampler_.Sample(rnd);
208 }
209
Probability(int64 value) const210 float UnigramSampler::Probability(int64 value) const {
211 mutex_lock lock(mu_); // could use reader lock
212 return unsafe_sampler_.Probability(value);
213 }
214
215 // Overriding at a high level results in far fewer lock acquisitions.
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const216 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
217 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
218 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
219 MutableArraySlice<float> extras_expected_count,
220 ArraySlice<int64> avoided_values) const {
221 mutex_lock lock(mu_); // could use reader lock
222 unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
223 rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
224 avoided_values);
225 }
226
Update(ArraySlice<int64> values)227 void UnigramSampler::Update(ArraySlice<int64> values) {
228 mutex_lock lock(mu_);
229 unsafe_sampler_.Update(values);
230 }
231
FixedUnigramSampler(Env * env,int64 range,const string & vocab_file,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)232 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range,
233 const string& vocab_file,
234 float distortion,
235 int32 num_reserved_ids,
236 int32 num_shards, int32 shard)
237 : RangeSampler(range),
238 total_weight_(0.0),
239 num_shards_(num_shards),
240 shard_(shard) {
241 FillReservedIds(num_reserved_ids);
242 // TODO(vanhoucke): make this non-crashing.
243 TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
244 CHECK_EQ(range, weights_.size());
245 dist_sampler_.reset(new random::DistributionSampler(weights_));
246 }
247
FixedUnigramSampler(int64 range,const std::vector<float> & unigrams,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)248 FixedUnigramSampler::FixedUnigramSampler(int64 range,
249 const std::vector<float>& unigrams,
250 float distortion,
251 int32 num_reserved_ids,
252 int32 num_shards, int32 shard)
253 : RangeSampler(range),
254 total_weight_(0.0),
255 num_shards_(num_shards),
256 shard_(shard) {
257 FillReservedIds(num_reserved_ids);
258 LoadFromUnigrams(unigrams, distortion);
259 // TODO(vanhoucke): make this non-crashing.
260 CHECK_EQ(range, weights_.size());
261 dist_sampler_.reset(new random::DistributionSampler(weights_));
262 }
263
Probability(int64 value) const264 float FixedUnigramSampler::Probability(int64 value) const {
265 if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
266 return 0.0;
267 }
268 return weights_.at(value) / total_weight_;
269 }
270
Sample(random::SimplePhilox * rnd) const271 int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
272 return dist_sampler_->Sample(rnd);
273 }
274
FillReservedIds(int32 num_reserved_ids)275 void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) {
276 for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) {
277 if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
278 }
279 }
280
LoadFromFile(Env * env,const string & vocab_file,float distortion)281 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
282 float distortion) {
283 std::unique_ptr<RandomAccessFile> file;
284 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
285
286 io::InputBuffer in(file.get(), 262144 /*bytes*/);
287 string line;
288 int32 word_id = weights_.size();
289 while (in.ReadLine(&line).ok()) {
290 // The vocabulary file should be in csv like format, with the last
291 // field the weight associated with the word.
292 std::vector<string> cols = str_util::Split(line, ',');
293 if (cols.empty()) continue;
294 // Skip entries that do not belong to this shard.
295 if (word_id % num_shards_ == shard_) {
296 float w = 0.0;
297 if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
298 return errors::InvalidArgument("Wrong vocabulary format at line: ",
299 line);
300 }
301 w = pow(w, distortion);
302 total_weight_ += w;
303 weights_.push_back(w);
304 }
305 ++word_id;
306 }
307 return Status::OK();
308 }
309
LoadFromUnigrams(const std::vector<float> & unigrams,float distortion)310 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
311 float distortion) {
312 int32 word_id = weights_.size();
313 for (float w : unigrams) {
314 // Skip entries that do not belong to this shard.
315 if (word_id % num_shards_ == shard_) {
316 w = pow(w, distortion);
317 total_weight_ += w;
318 weights_.push_back(w);
319 }
320 ++word_id;
321 }
322 }
323
324 } // namespace tensorflow
325