1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/range_sampler.h"
17
18 #include <cmath>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/io/inputbuffer.h"
25 #include "tensorflow/core/lib/strings/numbers.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/types.h"
30
31 namespace tensorflow {
32
33 using gtl::ArraySlice;
34 using gtl::MutableArraySlice;
35
~RangeSampler()36 RangeSampler::~RangeSampler() {}
37
SampleBatch(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch) const38 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
39 gtl::MutableArraySlice<int64> batch) const {
40 SampleBatchGetExpectedCount(
41 rnd, unique, batch, gtl::MutableArraySlice<float>(),
42 gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>());
43 }
44
SampleBatchGetExpectedCount(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch,gtl::MutableArraySlice<float> batch_expected_count,gtl::ArraySlice<int64> extras,gtl::MutableArraySlice<float> extras_expected_count) const45 void RangeSampler::SampleBatchGetExpectedCount(
46 random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch,
47 gtl::MutableArraySlice<float> batch_expected_count,
48 gtl::ArraySlice<int64> extras,
49 gtl::MutableArraySlice<float> extras_expected_count) const {
50 SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
51 extras, extras_expected_count,
52 gtl::ArraySlice<int64>());
53 }
54
namespace {

// Approximates how many times a value with per-draw probability `p` appears
// in a batch produced by SampleBatch.
//
// When every draw was kept (num_tries == batch_size, which always holds for
// unique=false) the answer is simply p * batch_size.
//
// Otherwise num_tries is the observed number of draws needed to collect
// batch_size distinct values. Pretending (inaccurately) that a batch always
// takes exactly num_tries draws, the chance that the value shows up at least
// once is 1 - (1 - p)^num_tries.
float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  const bool every_draw_kept = (batch_size == num_tries);
  if (every_draw_kept) return p * batch_size;
  // 1 - (1-p)^n == -(exp(n * log(1-p)) - 1); expm1/log1p keep this accurate
  // when p is tiny.
  const auto log_miss = std::log1p(-p);
  return -std::expm1(num_tries * log_miss);
}

}  // namespace
77
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const78 void RangeSampler::SampleBatchGetExpectedCountAvoid(
79 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
80 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
81 MutableArraySlice<float> extras_expected_count,
82 ArraySlice<int64> avoided_values) const {
83 const int batch_size = batch.size();
84 int num_tries;
85
86 if (unique) {
87 CHECK_LE(static_cast<int64>(batch_size + avoided_values.size()), range_);
88 std::unordered_set<int64> used(batch_size);
89 used.insert(avoided_values.begin(), avoided_values.end());
90 int num_picked = 0;
91 num_tries = 0;
92 while (num_picked < batch_size) {
93 num_tries++;
94 CHECK_LT(num_tries, kint32max);
95 int64 value = Sample(rnd);
96 if (gtl::InsertIfNotPresent(&used, value)) {
97 batch[num_picked++] = value;
98 }
99 }
100 } else {
101 CHECK_EQ(avoided_values.size(), size_t{0})
102 << "avoided_values only supported with unique=true";
103 for (int i = 0; i < batch_size; i++) {
104 batch[i] = Sample(rnd);
105 }
106 num_tries = batch_size;
107 }
108 // Compute the expected counts of the batch and the extra values
109 if (!batch_expected_count.empty()) {
110 CHECK_EQ(batch_size, batch_expected_count.size());
111 for (int i = 0; i < batch_size; i++) {
112 batch_expected_count[i] =
113 ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
114 }
115 }
116 CHECK_EQ(extras.size(), extras_expected_count.size());
117 for (size_t i = 0; i < extras.size(); i++) {
118 extras_expected_count[i] =
119 ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
120 }
121 }
122
AllSampler(int64 range)123 AllSampler::AllSampler(int64 range) : RangeSampler(range) {}
124
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const125 void AllSampler::SampleBatchGetExpectedCountAvoid(
126 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
127 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
128 MutableArraySlice<float> extras_expected_count,
129 ArraySlice<int64> avoided_values) const {
130 const int batch_size = batch.size();
131 CHECK_EQ(range_, batch_size);
132 for (int i = 0; i < batch_size; i++) {
133 batch[i] = i;
134 }
135 if (!batch_expected_count.empty()) {
136 CHECK_EQ(batch_size, batch_expected_count.size());
137 for (int i = 0; i < batch_size; i++) {
138 batch_expected_count[i] = 1;
139 }
140 }
141 CHECK_EQ(size_t{0}, avoided_values.size());
142 CHECK_EQ(extras.size(), extras_expected_count.size());
143 for (size_t i = 0; i < extras.size(); i++) {
144 extras_expected_count[i] = 1;
145 }
146 }
147
UniformSampler(int64 range)148 UniformSampler::UniformSampler(int64 range)
149 : RangeSampler(range), inv_range_(1.0 / range) {}
150
Sample(random::SimplePhilox * rnd) const151 int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
152 return rnd->Uniform64(range_);
153 }
154
Probability(int64 value) const155 float UniformSampler::Probability(int64 value) const { return inv_range_; }
156
LogUniformSampler(int64 range)157 LogUniformSampler::LogUniformSampler(int64 range)
158 : RangeSampler(range), log_range_(log1p(range)) {}
159
Sample(random::SimplePhilox * rnd) const160 int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
161 const int64 value =
162 static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1;
163 DCHECK_GE(value, 0);
164 // Mathematically, value should be <= range_, but might not be due to some
165 // floating point roundoff, so we mod by range_. In practice this case
166 // happens never regardless of the value of range_, including and up to
167 // DBL_MAX. But we include it as a guarantee of the function's output.
168 return value % range_;
169 }
170
Probability(int64 value) const171 float LogUniformSampler::Probability(int64 value) const {
172 // value is returned iff the call to UniformDouble(log_range_) in the
173 // Sample() function returns a value between log(value + 1)
174 // and log(value + 2). The probability of this is:
175 // (log(value + 2) - log(value + 1)) / log_range
176 // To avoid two calls to log(), we compute this as follows:
177 return (log((value + 2.0) / (value + 1.0))) / log_range_;
178 }
179
ThreadUnsafeUnigramSampler(int64 range)180 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range)
181 : RangeSampler(range), picker_(range) {
182 CHECK_LT(range, kint32max);
183 }
184
Sample(random::SimplePhilox * rnd) const185 int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
186 return picker_.Pick(rnd);
187 }
188
Probability(int64 value) const189 float ThreadUnsafeUnigramSampler::Probability(int64 value) const {
190 return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
191 }
192
Update(ArraySlice<int64> values)193 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) {
194 int num_updates = std::min(static_cast<int>(values.size()),
195 kint32max - picker_.total_weight());
196 for (int i = 0; i < num_updates; i++) {
197 const int64 value = values[i];
198 picker_.set_weight(value, picker_.get_weight(value) + 1);
199 }
200 }
201
202 // Thread-safe unigram sampler
UnigramSampler(int64 range)203 UnigramSampler::UnigramSampler(int64 range)
204 : RangeSampler(range), unsafe_sampler_(range) {
205 CHECK_LT(range, kint32max);
206 }
207
Sample(random::SimplePhilox * rnd) const208 int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const {
209 mutex_lock lock(mu_); // could use reader lock
210 return unsafe_sampler_.Sample(rnd);
211 }
212
Probability(int64 value) const213 float UnigramSampler::Probability(int64 value) const {
214 mutex_lock lock(mu_); // could use reader lock
215 return unsafe_sampler_.Probability(value);
216 }
217
218 // Overriding at a high level results in far fewer lock acquisitions.
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const219 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
220 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
221 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
222 MutableArraySlice<float> extras_expected_count,
223 ArraySlice<int64> avoided_values) const {
224 mutex_lock lock(mu_); // could use reader lock
225 unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
226 rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
227 avoided_values);
228 }
229
Update(ArraySlice<int64> values)230 void UnigramSampler::Update(ArraySlice<int64> values) {
231 mutex_lock lock(mu_);
232 unsafe_sampler_.Update(values);
233 }
234
FixedUnigramSampler(Env * env,int64 range,const string & vocab_file,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)235 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range,
236 const string& vocab_file,
237 float distortion,
238 int32 num_reserved_ids,
239 int32 num_shards, int32 shard)
240 : RangeSampler(range),
241 total_weight_(0.0),
242 num_shards_(num_shards),
243 shard_(shard) {
244 FillReservedIds(num_reserved_ids);
245 // TODO(vanhoucke): make this non-crashing.
246 TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
247 CHECK_EQ(range, weights_.size());
248 dist_sampler_.reset(new random::DistributionSampler(weights_));
249 }
250
FixedUnigramSampler(int64 range,const std::vector<float> & unigrams,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)251 FixedUnigramSampler::FixedUnigramSampler(int64 range,
252 const std::vector<float>& unigrams,
253 float distortion,
254 int32 num_reserved_ids,
255 int32 num_shards, int32 shard)
256 : RangeSampler(range),
257 total_weight_(0.0),
258 num_shards_(num_shards),
259 shard_(shard) {
260 FillReservedIds(num_reserved_ids);
261 LoadFromUnigrams(unigrams, distortion);
262 // TODO(vanhoucke): make this non-crashing.
263 CHECK_EQ(range, weights_.size());
264 dist_sampler_.reset(new random::DistributionSampler(weights_));
265 }
266
Probability(int64 value) const267 float FixedUnigramSampler::Probability(int64 value) const {
268 if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
269 return 0.0;
270 }
271 return weights_.at(value) / total_weight_;
272 }
273
Sample(random::SimplePhilox * rnd) const274 int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
275 return dist_sampler_->Sample(rnd);
276 }
277
FillReservedIds(int32 num_reserved_ids)278 void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) {
279 for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) {
280 if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
281 }
282 }
283
LoadFromFile(Env * env,const string & vocab_file,float distortion)284 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
285 float distortion) {
286 std::unique_ptr<RandomAccessFile> file;
287 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
288
289 io::InputBuffer in(file.get(), 262144 /*bytes*/);
290 string line;
291 int32 word_id = weights_.size();
292 while (in.ReadLine(&line).ok()) {
293 // The vocabulary file should be in csv like format, with the last
294 // field the weight associated with the word.
295 std::vector<string> cols = str_util::Split(line, ',');
296 if (cols.empty()) continue;
297 // Skip entries that do not belong to this shard.
298 if (word_id % num_shards_ == shard_) {
299 float w = 0.0;
300 if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
301 return errors::InvalidArgument("Wrong vocabulary format at line: ",
302 line);
303 }
304 w = std::pow(w, distortion);
305 total_weight_ += w;
306 weights_.push_back(w);
307 }
308 ++word_id;
309 }
310 return Status::OK();
311 }
312
LoadFromUnigrams(const std::vector<float> & unigrams,float distortion)313 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
314 float distortion) {
315 int32 word_id = weights_.size();
316 for (float w : unigrams) {
317 // Skip entries that do not belong to this shard.
318 if (word_id % num_shards_ == shard_) {
319 w = std::pow(w, distortion);
320 total_weight_ += w;
321 weights_.push_back(w);
322 }
323 ++word_id;
324 }
325 }
326
327 } // namespace tensorflow
328