1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/range_sampler.h"
17
18 #include <cmath>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/io/inputbuffer.h"
25 #include "tensorflow/core/lib/strings/numbers.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/types.h"
30
31 namespace tensorflow {
32
33 using gtl::ArraySlice;
34 using gtl::MutableArraySlice;
35
~RangeSampler()36 RangeSampler::~RangeSampler() {}
37
SampleBatch(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch) const38 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
39 gtl::MutableArraySlice<int64> batch) const {
40 SampleBatchGetExpectedCount(
41 rnd, unique, batch, gtl::MutableArraySlice<float>(),
42 gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>());
43 }
44
SampleBatchGetExpectedCount(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch,gtl::MutableArraySlice<float> batch_expected_count,gtl::ArraySlice<int64> extras,gtl::MutableArraySlice<float> extras_expected_count) const45 void RangeSampler::SampleBatchGetExpectedCount(
46 random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch,
47 gtl::MutableArraySlice<float> batch_expected_count,
48 gtl::ArraySlice<int64> extras,
49 gtl::MutableArraySlice<float> extras_expected_count) const {
50 SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
51 extras, extras_expected_count,
52 gtl::ArraySlice<int64>());
53 }
54
namespace {

// Estimates how often a value with sampling probability `p` is expected to
// appear in the output of SampleBatch.
//
// `batch_size` is the number of values requested and `num_tries` is the
// observed number of draws it took to produce them.
//
// When num_tries == batch_size (always the case for unique=false) the
// estimate is p * batch_size.
//
// Otherwise we assume (falsely) that collecting batch_size distinct values
// _always_ takes exactly num_tries draws, so the probability that the value
// ends up in the batch is 1 - (1 - p)^num_tries.
static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  if (num_tries != batch_size) {
    // Numerically stable evaluation of 1 - (1 - p)^num_tries:
    // (1 - p)^n == exp(n * log1p(-p)), and 1 - exp(x) == -expm1(x).
    const float log_miss = std::log1p(-p);
    return -std::expm1(num_tries * log_miss);
  }
  return p * batch_size;
}

}  // namespace
77
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const78 void RangeSampler::SampleBatchGetExpectedCountAvoid(
79 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
80 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
81 MutableArraySlice<float> extras_expected_count,
82 ArraySlice<int64> avoided_values) const {
83 const int batch_size = batch.size();
84 int num_tries;
85
86 if (unique) {
87 CHECK_LE(static_cast<int64>(batch_size + avoided_values.size()), range_);
88 std::unordered_set<int64> used(batch_size);
89 used.insert(avoided_values.begin(), avoided_values.end());
90 int num_picked = 0;
91 num_tries = 0;
92 while (num_picked < batch_size) {
93 num_tries++;
94 CHECK_LT(num_tries, kint32max);
95 int64_t value = Sample(rnd);
96 if (gtl::InsertIfNotPresent(&used, value)) {
97 batch[num_picked++] = value;
98 }
99 }
100 } else {
101 CHECK_EQ(avoided_values.size(), size_t{0})
102 << "avoided_values only supported with unique=true";
103 for (int i = 0; i < batch_size; i++) {
104 batch[i] = Sample(rnd);
105 }
106 num_tries = batch_size;
107 }
108 // Compute the expected counts of the batch and the extra values
109 if (!batch_expected_count.empty()) {
110 CHECK_EQ(batch_size, batch_expected_count.size());
111 for (int i = 0; i < batch_size; i++) {
112 batch_expected_count[i] =
113 ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
114 }
115 }
116 CHECK_EQ(extras.size(), extras_expected_count.size());
117 for (size_t i = 0; i < extras.size(); i++) {
118 extras_expected_count[i] =
119 ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
120 }
121 }
122
AllSampler(int64_t range)123 AllSampler::AllSampler(int64_t range) : RangeSampler(range) {}
124
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const125 void AllSampler::SampleBatchGetExpectedCountAvoid(
126 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
127 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
128 MutableArraySlice<float> extras_expected_count,
129 ArraySlice<int64> avoided_values) const {
130 const int batch_size = batch.size();
131 CHECK_EQ(range_, batch_size);
132 for (int i = 0; i < batch_size; i++) {
133 batch[i] = i;
134 }
135 if (!batch_expected_count.empty()) {
136 CHECK_EQ(batch_size, batch_expected_count.size());
137 for (int i = 0; i < batch_size; i++) {
138 batch_expected_count[i] = 1;
139 }
140 }
141 CHECK_EQ(size_t{0}, avoided_values.size());
142 CHECK_EQ(extras.size(), extras_expected_count.size());
143 for (size_t i = 0; i < extras.size(); i++) {
144 extras_expected_count[i] = 1;
145 }
146 }
147
UniformSampler(int64_t range)148 UniformSampler::UniformSampler(int64_t range)
149 : RangeSampler(range), inv_range_(1.0 / range) {}
150
Sample(random::SimplePhilox * rnd) const151 int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
152 return rnd->Uniform64(range_);
153 }
154
Probability(int64_t value) const155 float UniformSampler::Probability(int64_t value) const { return inv_range_; }
156
LogUniformSampler(int64_t range)157 LogUniformSampler::LogUniformSampler(int64_t range)
158 : RangeSampler(range), log_range_(log1p(range)) {}
159
Sample(random::SimplePhilox * rnd) const160 int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
161 const int64_t value =
162 static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1;
163 DCHECK_GE(value, 0);
164 // Mathematically, value should be <= range_, but might not be due to some
165 // floating point roundoff, so we mod by range_. In practice this case
166 // happens never regardless of the value of range_, including and up to
167 // DBL_MAX. But we include it as a guarantee of the function's output.
168 return value % range_;
169 }
170
Probability(int64_t value) const171 float LogUniformSampler::Probability(int64_t value) const {
172 // value is returned iff the call to UniformDouble(log_range_) in the
173 // Sample() function returns a value between log(value + 1)
174 // and log(value + 2). The probability of this is:
175 // (log(value + 2) - log(value + 1)) / log_range
176 // To avoid two calls to log(), we compute this as follows:
177 return (log((value + 2.0) / (value + 1.0))) / log_range_;
178 }
179
ThreadUnsafeUnigramSampler(int64_t range)180 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64_t range)
181 : RangeSampler(range), picker_(range) {
182 CHECK_LT(range, kint32max);
183 }
184
Sample(random::SimplePhilox * rnd) const185 int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
186 return picker_.Pick(rnd);
187 }
188
Probability(int64_t value) const189 float ThreadUnsafeUnigramSampler::Probability(int64_t value) const {
190 return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
191 }
192
Update(ArraySlice<int64> values)193 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) {
194 int num_updates = std::min(static_cast<int>(values.size()),
195 kint32max - picker_.total_weight());
196 for (int i = 0; i < num_updates; i++) {
197 const int64_t value = values[i];
198 picker_.set_weight(value, picker_.get_weight(value) + 1);
199 }
200 }
201
202 // Thread-safe unigram sampler
UnigramSampler(int64_t range)203 UnigramSampler::UnigramSampler(int64_t range)
204 : RangeSampler(range), unsafe_sampler_(range) {
205 CHECK_LT(range, kint32max);
206 }
207
Sample(random::SimplePhilox * rnd) const208 int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const {
209 tf_shared_lock lock(mu_);
210 return unsafe_sampler_.Sample(rnd);
211 }
212
Probability(int64_t value) const213 float UnigramSampler::Probability(int64_t value) const {
214 tf_shared_lock lock(mu_);
215 return unsafe_sampler_.Probability(value);
216 }
217
218 // Overriding at a high level results in far fewer lock acquisitions.
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const219 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
220 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
221 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
222 MutableArraySlice<float> extras_expected_count,
223 ArraySlice<int64> avoided_values) const {
224 tf_shared_lock lock(mu_);
225 unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
226 rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
227 avoided_values);
228 }
229
Update(ArraySlice<int64> values)230 void UnigramSampler::Update(ArraySlice<int64> values) {
231 mutex_lock lock(mu_);
232 unsafe_sampler_.Update(values);
233 }
234
FixedUnigramSampler(Env * env,int64_t range,const string & vocab_file,float distortion,int32_t num_reserved_ids,int32_t num_shards,int32_t shard)235 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64_t range,
236 const string& vocab_file,
237 float distortion,
238 int32_t num_reserved_ids,
239 int32_t num_shards, int32_t shard)
240 : RangeSampler(range),
241 total_weight_(0.0),
242 num_shards_(num_shards),
243 shard_(shard) {
244 FillReservedIds(num_reserved_ids);
245 // TODO(vanhoucke): make this non-crashing.
246 TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
247 CHECK_EQ(range, weights_.size());
248 dist_sampler_.reset(new random::DistributionSampler(weights_));
249 }
250
FixedUnigramSampler(int64_t range,const std::vector<float> & unigrams,float distortion,int32_t num_reserved_ids,int32_t num_shards,int32_t shard)251 FixedUnigramSampler::FixedUnigramSampler(int64_t range,
252 const std::vector<float>& unigrams,
253 float distortion,
254 int32_t num_reserved_ids,
255 int32_t num_shards, int32_t shard)
256 : RangeSampler(range),
257 total_weight_(0.0),
258 num_shards_(num_shards),
259 shard_(shard) {
260 FillReservedIds(num_reserved_ids);
261 LoadFromUnigrams(unigrams, distortion);
262 // TODO(vanhoucke): make this non-crashing.
263 CHECK_EQ(range, weights_.size());
264 dist_sampler_.reset(new random::DistributionSampler(weights_));
265 }
266
Probability(int64_t value) const267 float FixedUnigramSampler::Probability(int64_t value) const {
268 if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
269 return 0.0;
270 }
271 return weights_.at(value) / total_weight_;
272 }
273
Sample(random::SimplePhilox * rnd) const274 int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
275 return dist_sampler_->Sample(rnd);
276 }
277
FillReservedIds(int32_t num_reserved_ids)278 void FixedUnigramSampler::FillReservedIds(int32_t num_reserved_ids) {
279 for (int32_t word_id = 0; word_id < num_reserved_ids; ++word_id) {
280 if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
281 }
282 }
283
LoadFromFile(Env * env,const string & vocab_file,float distortion)284 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
285 float distortion) {
286 std::unique_ptr<RandomAccessFile> file;
287 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
288
289 io::InputBuffer in(file.get(), 262144 /*bytes*/);
290 string line;
291 int32_t word_id = weights_.size();
292 while (in.ReadLine(&line).ok()) {
293 // The vocabulary file should be in csv like format, with the last
294 // field the weight associated with the word.
295 std::vector<string> cols = str_util::Split(line, ',');
296 if (cols.empty()) continue;
297 // Skip entries that do not belong to this shard.
298 if (word_id % num_shards_ == shard_) {
299 float w = 0.0;
300 if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
301 return errors::InvalidArgument("Wrong vocabulary format at line: ",
302 line);
303 }
304 w = std::pow(w, distortion);
305 total_weight_ += w;
306 weights_.push_back(w);
307 }
308 ++word_id;
309 }
310 return Status::OK();
311 }
312
LoadFromUnigrams(const std::vector<float> & unigrams,float distortion)313 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
314 float distortion) {
315 int32_t word_id = weights_.size();
316 for (float w : unigrams) {
317 // Skip entries that do not belong to this shard.
318 if (word_id % num_shards_ == shard_) {
319 w = std::pow(w, distortion);
320 total_weight_ += w;
321 weights_.push_back(w);
322 }
323 ++word_id;
324 }
325 }
326
327 } // namespace tensorflow
328