• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/range_sampler.h"
17 
18 #include <cmath>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/io/inputbuffer.h"
25 #include "tensorflow/core/lib/strings/numbers.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
32 
33 using gtl::ArraySlice;
34 using gtl::MutableArraySlice;
35 
~RangeSampler()36 RangeSampler::~RangeSampler() {}
37 
SampleBatch(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch) const38 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
39                                gtl::MutableArraySlice<int64> batch) const {
40   SampleBatchGetExpectedCount(
41       rnd, unique, batch, gtl::MutableArraySlice<float>(),
42       gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>());
43 }
44 
SampleBatchGetExpectedCount(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch,gtl::MutableArraySlice<float> batch_expected_count,gtl::ArraySlice<int64> extras,gtl::MutableArraySlice<float> extras_expected_count) const45 void RangeSampler::SampleBatchGetExpectedCount(
46     random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch,
47     gtl::MutableArraySlice<float> batch_expected_count,
48     gtl::ArraySlice<int64> extras,
49     gtl::MutableArraySlice<float> extras_expected_count) const {
50   SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
51                                    extras, extras_expected_count,
52                                    gtl::ArraySlice<int64>());
53 }
54 
namespace {

// Approximates the expected number of times a value whose per-draw
// probability is `p` shows up in the output of SampleBatch.
//
// `num_tries` is the number of draws that were made to fill a batch of
// `batch_size` values.
//  - When num_tries == batch_size (always the case for unique=false, where
//    every draw is kept, and also for unique=true when no draw was rejected),
//    the result is p * batch_size.
//  - Otherwise (unique=true with rejected duplicates) we assume, inexactly,
//    that filling a batch always takes exactly num_tries draws. The value then
//    appears at least once with probability 1 - (1 - p)^num_tries.
float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  if (num_tries != batch_size) {
    // 1 - (1-p)^n == -expm1(n * log1p(-p)). This form stays accurate when p
    // is tiny, where evaluating (1-p)^n directly would lose precision.
    const float log_miss_prob = std::log1p(-p);
    return -std::expm1(num_tries * log_miss_prob);
  }
  return p * batch_size;
}

}  // namespace
77 
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const78 void RangeSampler::SampleBatchGetExpectedCountAvoid(
79     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
80     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
81     MutableArraySlice<float> extras_expected_count,
82     ArraySlice<int64> avoided_values) const {
83   const int batch_size = batch.size();
84   int num_tries;
85 
86   if (unique) {
87     CHECK_LE(static_cast<int64>(batch_size + avoided_values.size()), range_);
88     std::unordered_set<int64> used(batch_size);
89     used.insert(avoided_values.begin(), avoided_values.end());
90     int num_picked = 0;
91     num_tries = 0;
92     while (num_picked < batch_size) {
93       num_tries++;
94       CHECK_LT(num_tries, kint32max);
95       int64 value = Sample(rnd);
96       if (gtl::InsertIfNotPresent(&used, value)) {
97         batch[num_picked++] = value;
98       }
99     }
100   } else {
101     CHECK_EQ(avoided_values.size(), size_t{0})
102         << "avoided_values only supported with unique=true";
103     for (int i = 0; i < batch_size; i++) {
104       batch[i] = Sample(rnd);
105     }
106     num_tries = batch_size;
107   }
108   // Compute the expected counts of the batch and the extra values
109   if (!batch_expected_count.empty()) {
110     CHECK_EQ(batch_size, batch_expected_count.size());
111     for (int i = 0; i < batch_size; i++) {
112       batch_expected_count[i] =
113           ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
114     }
115   }
116   CHECK_EQ(extras.size(), extras_expected_count.size());
117   for (size_t i = 0; i < extras.size(); i++) {
118     extras_expected_count[i] =
119         ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
120   }
121 }
122 
AllSampler(int64 range)123 AllSampler::AllSampler(int64 range) : RangeSampler(range) {}
124 
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const125 void AllSampler::SampleBatchGetExpectedCountAvoid(
126     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
127     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
128     MutableArraySlice<float> extras_expected_count,
129     ArraySlice<int64> avoided_values) const {
130   const int batch_size = batch.size();
131   CHECK_EQ(range_, batch_size);
132   for (int i = 0; i < batch_size; i++) {
133     batch[i] = i;
134   }
135   if (!batch_expected_count.empty()) {
136     CHECK_EQ(batch_size, batch_expected_count.size());
137     for (int i = 0; i < batch_size; i++) {
138       batch_expected_count[i] = 1;
139     }
140   }
141   CHECK_EQ(size_t{0}, avoided_values.size());
142   CHECK_EQ(extras.size(), extras_expected_count.size());
143   for (size_t i = 0; i < extras.size(); i++) {
144     extras_expected_count[i] = 1;
145   }
146 }
147 
UniformSampler(int64 range)148 UniformSampler::UniformSampler(int64 range)
149     : RangeSampler(range), inv_range_(1.0 / range) {}
150 
Sample(random::SimplePhilox * rnd) const151 int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
152   return rnd->Uniform64(range_);
153 }
154 
Probability(int64 value) const155 float UniformSampler::Probability(int64 value) const { return inv_range_; }
156 
LogUniformSampler(int64 range)157 LogUniformSampler::LogUniformSampler(int64 range)
158     : RangeSampler(range), log_range_(log1p(range)) {}
159 
Sample(random::SimplePhilox * rnd) const160 int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
161   const int64 value =
162       static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1;
163   DCHECK_GE(value, 0);
164   // Mathematically, value should be <= range_, but might not be due to some
165   // floating point roundoff, so we mod by range_.  In practice this case
166   // happens never regardless of the value of range_, including and up to
167   // DBL_MAX.  But we include it as a guarantee of the function's output.
168   return value % range_;
169 }
170 
Probability(int64 value) const171 float LogUniformSampler::Probability(int64 value) const {
172   // value is returned iff the call to UniformDouble(log_range_) in the
173   // Sample() function returns a value between log(value + 1)
174   // and log(value + 2).   The probability of this is:
175   // (log(value + 2) - log(value + 1)) / log_range
176   // To avoid two calls to log(), we compute this as follows:
177   return (log((value + 2.0) / (value + 1.0))) / log_range_;
178 }
179 
ThreadUnsafeUnigramSampler(int64 range)180 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range)
181     : RangeSampler(range), picker_(range) {
182   CHECK_LT(range, kint32max);
183 }
184 
Sample(random::SimplePhilox * rnd) const185 int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
186   return picker_.Pick(rnd);
187 }
188 
Probability(int64 value) const189 float ThreadUnsafeUnigramSampler::Probability(int64 value) const {
190   return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
191 }
192 
Update(ArraySlice<int64> values)193 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) {
194   int num_updates = std::min(static_cast<int>(values.size()),
195                              kint32max - picker_.total_weight());
196   for (int i = 0; i < num_updates; i++) {
197     const int64 value = values[i];
198     picker_.set_weight(value, picker_.get_weight(value) + 1);
199   }
200 }
201 
202 // Thread-safe unigram sampler
UnigramSampler(int64 range)203 UnigramSampler::UnigramSampler(int64 range)
204     : RangeSampler(range), unsafe_sampler_(range) {
205   CHECK_LT(range, kint32max);
206 }
207 
Sample(random::SimplePhilox * rnd) const208 int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const {
209   mutex_lock lock(mu_);  // could use reader lock
210   return unsafe_sampler_.Sample(rnd);
211 }
212 
Probability(int64 value) const213 float UnigramSampler::Probability(int64 value) const {
214   mutex_lock lock(mu_);  // could use reader lock
215   return unsafe_sampler_.Probability(value);
216 }
217 
218 // Overriding at a high level results in far fewer lock acquisitions.
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const219 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
220     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
221     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
222     MutableArraySlice<float> extras_expected_count,
223     ArraySlice<int64> avoided_values) const {
224   mutex_lock lock(mu_);  // could use reader lock
225   unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
226       rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
227       avoided_values);
228 }
229 
Update(ArraySlice<int64> values)230 void UnigramSampler::Update(ArraySlice<int64> values) {
231   mutex_lock lock(mu_);
232   unsafe_sampler_.Update(values);
233 }
234 
FixedUnigramSampler(Env * env,int64 range,const string & vocab_file,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)235 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range,
236                                          const string& vocab_file,
237                                          float distortion,
238                                          int32 num_reserved_ids,
239                                          int32 num_shards, int32 shard)
240     : RangeSampler(range),
241       total_weight_(0.0),
242       num_shards_(num_shards),
243       shard_(shard) {
244   FillReservedIds(num_reserved_ids);
245   // TODO(vanhoucke): make this non-crashing.
246   TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
247   CHECK_EQ(range, weights_.size());
248   dist_sampler_.reset(new random::DistributionSampler(weights_));
249 }
250 
FixedUnigramSampler(int64 range,const std::vector<float> & unigrams,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)251 FixedUnigramSampler::FixedUnigramSampler(int64 range,
252                                          const std::vector<float>& unigrams,
253                                          float distortion,
254                                          int32 num_reserved_ids,
255                                          int32 num_shards, int32 shard)
256     : RangeSampler(range),
257       total_weight_(0.0),
258       num_shards_(num_shards),
259       shard_(shard) {
260   FillReservedIds(num_reserved_ids);
261   LoadFromUnigrams(unigrams, distortion);
262   // TODO(vanhoucke): make this non-crashing.
263   CHECK_EQ(range, weights_.size());
264   dist_sampler_.reset(new random::DistributionSampler(weights_));
265 }
266 
Probability(int64 value) const267 float FixedUnigramSampler::Probability(int64 value) const {
268   if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
269     return 0.0;
270   }
271   return weights_.at(value) / total_weight_;
272 }
273 
Sample(random::SimplePhilox * rnd) const274 int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
275   return dist_sampler_->Sample(rnd);
276 }
277 
FillReservedIds(int32 num_reserved_ids)278 void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) {
279   for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) {
280     if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
281   }
282 }
283 
LoadFromFile(Env * env,const string & vocab_file,float distortion)284 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
285                                          float distortion) {
286   std::unique_ptr<RandomAccessFile> file;
287   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
288 
289   io::InputBuffer in(file.get(), 262144 /*bytes*/);
290   string line;
291   int32 word_id = weights_.size();
292   while (in.ReadLine(&line).ok()) {
293     // The vocabulary file should be in csv like format, with the last
294     // field the weight associated with the word.
295     std::vector<string> cols = str_util::Split(line, ',');
296     if (cols.empty()) continue;
297     // Skip entries that do not belong to this shard.
298     if (word_id % num_shards_ == shard_) {
299       float w = 0.0;
300       if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
301         return errors::InvalidArgument("Wrong vocabulary format at line: ",
302                                        line);
303       }
304       w = std::pow(w, distortion);
305       total_weight_ += w;
306       weights_.push_back(w);
307     }
308     ++word_id;
309   }
310   return Status::OK();
311 }
312 
LoadFromUnigrams(const std::vector<float> & unigrams,float distortion)313 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
314                                            float distortion) {
315   int32 word_id = weights_.size();
316   for (float w : unigrams) {
317     // Skip entries that do not belong to this shard.
318     if (word_id % num_shards_ == shard_) {
319       w = std::pow(w, distortion);
320       total_weight_ += w;
321       weights_.push_back(w);
322     }
323     ++word_id;
324   }
325 }
326 
327 }  // namespace tensorflow
328