• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/range_sampler.h"
17 
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/gtl/map_util.h"
23 #include "tensorflow/core/lib/io/inputbuffer.h"
24 #include "tensorflow/core/lib/strings/numbers.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 
32 using gtl::ArraySlice;
33 using gtl::MutableArraySlice;
34 
~RangeSampler()35 RangeSampler::~RangeSampler() {}
36 
SampleBatch(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch) const37 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
38                                gtl::MutableArraySlice<int64> batch) const {
39   SampleBatchGetExpectedCount(
40       rnd, unique, batch, gtl::MutableArraySlice<float>(),
41       gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>());
42 }
43 
SampleBatchGetExpectedCount(random::SimplePhilox * rnd,bool unique,gtl::MutableArraySlice<int64> batch,gtl::MutableArraySlice<float> batch_expected_count,gtl::ArraySlice<int64> extras,gtl::MutableArraySlice<float> extras_expected_count) const44 void RangeSampler::SampleBatchGetExpectedCount(
45     random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch,
46     gtl::MutableArraySlice<float> batch_expected_count,
47     gtl::ArraySlice<int64> extras,
48     gtl::MutableArraySlice<float> extras_expected_count) const {
49   SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
50                                    extras, extras_expected_count,
51                                    gtl::ArraySlice<int64>());
52 }
53 
namespace {

// Approximate expected number of occurrences, in one sampled batch, of a value
// whose per-draw probability is `p`.
//
// `num_tries` is how many draws were actually made to fill `batch_size` slots.
// - If num_tries == batch_size, the batch is treated as batch_size independent
//   draws and the result is p * batch_size. This always holds when
//   unique=false, and also when unique=true and no draw was rejected.
// - Otherwise (unique=true with rejected draws), we pretend every batch takes
//   exactly num_tries draws. That is an approximation, not an exact result.
//   The value is then present with probability 1 - (1 - p)^num_tries.
float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  const bool no_draw_rejected = (num_tries == batch_size);
  if (!no_draw_rejected) {
    // -expm1(n * log1p(-p)) == 1 - (1 - p)^n, without the cancellation the
    // naive form suffers for small p.
    return -expm1(num_tries * log1p(-p));
  }
  return p * batch_size;
}

}  // namespace
76 
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const77 void RangeSampler::SampleBatchGetExpectedCountAvoid(
78     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
79     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
80     MutableArraySlice<float> extras_expected_count,
81     ArraySlice<int64> avoided_values) const {
82   const int batch_size = batch.size();
83   int num_tries;
84 
85   if (unique) {
86     CHECK_LE(batch_size + avoided_values.size(), range_);
87     std::unordered_set<int64> used(batch_size);
88     used.insert(avoided_values.begin(), avoided_values.end());
89     int num_picked = 0;
90     num_tries = 0;
91     while (num_picked < batch_size) {
92       num_tries++;
93       CHECK_LT(num_tries, kint32max);
94       int64 value = Sample(rnd);
95       if (gtl::InsertIfNotPresent(&used, value)) {
96         batch[num_picked++] = value;
97       }
98     }
99   } else {
100     CHECK_EQ(avoided_values.size(), size_t{0})
101         << "avoided_values only supported with unique=true";
102     for (int i = 0; i < batch_size; i++) {
103       batch[i] = Sample(rnd);
104     }
105     num_tries = batch_size;
106   }
107   // Compute the expected counts of the batch and the extra values
108   if (!batch_expected_count.empty()) {
109     CHECK_EQ(batch_size, batch_expected_count.size());
110     for (int i = 0; i < batch_size; i++) {
111       batch_expected_count[i] =
112           ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
113     }
114   }
115   CHECK_EQ(extras.size(), extras_expected_count.size());
116   for (size_t i = 0; i < extras.size(); i++) {
117     extras_expected_count[i] =
118         ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
119   }
120 }
121 
AllSampler(int64 range)122 AllSampler::AllSampler(int64 range) : RangeSampler(range) {}
123 
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const124 void AllSampler::SampleBatchGetExpectedCountAvoid(
125     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
126     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
127     MutableArraySlice<float> extras_expected_count,
128     ArraySlice<int64> avoided_values) const {
129   const int batch_size = batch.size();
130   CHECK_EQ(range_, batch_size);
131   for (int i = 0; i < batch_size; i++) {
132     batch[i] = i;
133   }
134   if (!batch_expected_count.empty()) {
135     CHECK_EQ(batch_size, batch_expected_count.size());
136     for (int i = 0; i < batch_size; i++) {
137       batch_expected_count[i] = 1;
138     }
139   }
140   CHECK_EQ(size_t{0}, avoided_values.size());
141   CHECK_EQ(extras.size(), extras_expected_count.size());
142   for (size_t i = 0; i < extras.size(); i++) {
143     extras_expected_count[i] = 1;
144   }
145 }
146 
UniformSampler(int64 range)147 UniformSampler::UniformSampler(int64 range)
148     : RangeSampler(range), inv_range_(1.0 / range) {}
149 
Sample(random::SimplePhilox * rnd) const150 int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
151   return rnd->Uniform64(range_);
152 }
153 
Probability(int64 value) const154 float UniformSampler::Probability(int64 value) const { return inv_range_; }
155 
LogUniformSampler(int64 range)156 LogUniformSampler::LogUniformSampler(int64 range)
157     : RangeSampler(range), log_range_(log1p(range)) {}
158 
Sample(random::SimplePhilox * rnd) const159 int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
160   const int64 value =
161       static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1;
162   CHECK_GE(value, 0);
163   // Mathematically, value should be <= range_, but might not be due to some
164   // floating point roundoff, so we mod by range_.
165   return value % range_;
166 }
167 
Probability(int64 value) const168 float LogUniformSampler::Probability(int64 value) const {
169   // value is returned iff the call to UniformDouble(log_range_) in the
170   // Sample() function returns a value between log(value + 1)
171   // and log(value + 2).   The probability of this is:
172   // (log(value + 2) - log(value + 1)) / log_range
173   // To avoid two calls to log(), we compute this as follows:
174   return (log((value + 2.0) / (value + 1.0))) / log_range_;
175 }
176 
ThreadUnsafeUnigramSampler(int64 range)177 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range)
178     : RangeSampler(range), picker_(range) {
179   CHECK_LT(range, kint32max);
180 }
181 
Sample(random::SimplePhilox * rnd) const182 int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
183   return picker_.Pick(rnd);
184 }
185 
Probability(int64 value) const186 float ThreadUnsafeUnigramSampler::Probability(int64 value) const {
187   return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
188 }
189 
Update(ArraySlice<int64> values)190 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) {
191   int num_updates = std::min(static_cast<int>(values.size()),
192                              kint32max - picker_.total_weight());
193   for (int i = 0; i < num_updates; i++) {
194     const int64 value = values[i];
195     picker_.set_weight(value, picker_.get_weight(value) + 1);
196   }
197 }
198 
199 // Thread-safe unigram sampler
UnigramSampler(int64 range)200 UnigramSampler::UnigramSampler(int64 range)
201     : RangeSampler(range), unsafe_sampler_(range) {
202   CHECK_LT(range, kint32max);
203 }
204 
Sample(random::SimplePhilox * rnd) const205 int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const {
206   mutex_lock lock(mu_);  // could use reader lock
207   return unsafe_sampler_.Sample(rnd);
208 }
209 
Probability(int64 value) const210 float UnigramSampler::Probability(int64 value) const {
211   mutex_lock lock(mu_);  // could use reader lock
212   return unsafe_sampler_.Probability(value);
213 }
214 
215 // Overriding at a high level results in far fewer lock acquisitions.
SampleBatchGetExpectedCountAvoid(random::SimplePhilox * rnd,bool unique,MutableArraySlice<int64> batch,MutableArraySlice<float> batch_expected_count,ArraySlice<int64> extras,MutableArraySlice<float> extras_expected_count,ArraySlice<int64> avoided_values) const216 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
217     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
218     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
219     MutableArraySlice<float> extras_expected_count,
220     ArraySlice<int64> avoided_values) const {
221   mutex_lock lock(mu_);  // could use reader lock
222   unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
223       rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
224       avoided_values);
225 }
226 
Update(ArraySlice<int64> values)227 void UnigramSampler::Update(ArraySlice<int64> values) {
228   mutex_lock lock(mu_);
229   unsafe_sampler_.Update(values);
230 }
231 
FixedUnigramSampler(Env * env,int64 range,const string & vocab_file,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)232 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range,
233                                          const string& vocab_file,
234                                          float distortion,
235                                          int32 num_reserved_ids,
236                                          int32 num_shards, int32 shard)
237     : RangeSampler(range),
238       total_weight_(0.0),
239       num_shards_(num_shards),
240       shard_(shard) {
241   FillReservedIds(num_reserved_ids);
242   // TODO(vanhoucke): make this non-crashing.
243   TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
244   CHECK_EQ(range, weights_.size());
245   dist_sampler_.reset(new random::DistributionSampler(weights_));
246 }
247 
FixedUnigramSampler(int64 range,const std::vector<float> & unigrams,float distortion,int32 num_reserved_ids,int32 num_shards,int32 shard)248 FixedUnigramSampler::FixedUnigramSampler(int64 range,
249                                          const std::vector<float>& unigrams,
250                                          float distortion,
251                                          int32 num_reserved_ids,
252                                          int32 num_shards, int32 shard)
253     : RangeSampler(range),
254       total_weight_(0.0),
255       num_shards_(num_shards),
256       shard_(shard) {
257   FillReservedIds(num_reserved_ids);
258   LoadFromUnigrams(unigrams, distortion);
259   // TODO(vanhoucke): make this non-crashing.
260   CHECK_EQ(range, weights_.size());
261   dist_sampler_.reset(new random::DistributionSampler(weights_));
262 }
263 
Probability(int64 value) const264 float FixedUnigramSampler::Probability(int64 value) const {
265   if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
266     return 0.0;
267   }
268   return weights_.at(value) / total_weight_;
269 }
270 
Sample(random::SimplePhilox * rnd) const271 int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
272   return dist_sampler_->Sample(rnd);
273 }
274 
FillReservedIds(int32 num_reserved_ids)275 void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) {
276   for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) {
277     if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
278   }
279 }
280 
LoadFromFile(Env * env,const string & vocab_file,float distortion)281 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
282                                          float distortion) {
283   std::unique_ptr<RandomAccessFile> file;
284   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
285 
286   io::InputBuffer in(file.get(), 262144 /*bytes*/);
287   string line;
288   int32 word_id = weights_.size();
289   while (in.ReadLine(&line).ok()) {
290     // The vocabulary file should be in csv like format, with the last
291     // field the weight associated with the word.
292     std::vector<string> cols = str_util::Split(line, ',');
293     if (cols.empty()) continue;
294     // Skip entries that do not belong to this shard.
295     if (word_id % num_shards_ == shard_) {
296       float w = 0.0;
297       if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
298         return errors::InvalidArgument("Wrong vocabulary format at line: ",
299                                        line);
300       }
301       w = pow(w, distortion);
302       total_weight_ += w;
303       weights_.push_back(w);
304     }
305     ++word_id;
306   }
307   return Status::OK();
308 }
309 
LoadFromUnigrams(const std::vector<float> & unigrams,float distortion)310 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
311                                            float distortion) {
312   int32 word_id = weights_.size();
313   for (float w : unigrams) {
314     // Skip entries that do not belong to this shard.
315     if (word_id % num_shards_ == shard_) {
316       w = pow(w, distortion);
317       total_weight_ += w;
318       weights_.push_back(w);
319     }
320     ++word_id;
321   }
322 }
323 
324 }  // namespace tensorflow
325