• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
17 #define TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/lib/random/distribution_sampler.h"
24 #include "tensorflow/core/lib/random/random_distributions.h"
25 #include "tensorflow/core/lib/random/weighted_picker.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/thread_annotations.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
32 
33 class Env;
34 
35 // Abstract subclass for sampling from the set of non-negative integers
36 // [0, range)
37 class RangeSampler {
38  public:
RangeSampler(int64 range)39   explicit RangeSampler(int64 range) : range_(range) { CHECK_GT(range_, 0); }
40   virtual ~RangeSampler();
41 
42   // Sample a single value
43   virtual int64 Sample(random::SimplePhilox* rnd) const = 0;
44 
45   // The probability that a single call to Sample() returns the given value.
46   // Assumes that value is in [0, range).  No range checking is done.
47   virtual float Probability(int64 value) const = 0;
48 
49   // Fill "batch" with samples from the distribution.
50   // If unique=true, then we re-pick each element until we get a
51   // value distinct from all previously picked values in the batch.
52   void SampleBatch(random::SimplePhilox* rnd, bool unique,
53                    gtl::MutableArraySlice<int64> batch) const;
54 
55   // Fill "batch" with samples from the distribution, and report
56   // "expected counts".
57   //
58   // The "expected count" of a value is an estimate of the expected
59   // number of occurrences of the value in the batch returned by a
60   // call to this function with the given parameters.  If unique=true,
61   // the expected count is an inclusion probability.  For details on
62   // this estimation, see the comment to "ExpectedCountHelper" in the
63   // .cc file.
64   //
65   // Expected counts for the elements of the returned "batch" are reported
66   // in the aligned array "batch_expected_count".
67   //
68   // The user can optionally provide "extras", containing values in the range.
69   // The expected counts for the extras are reported in the aligned array
70   // "extras_expected_count".
71   //
72   // "batch_expected_count" must have size equal to 0 or to the size of "batch".
73   // "extras" and "extras_expected_count" must have equal size.
74   void SampleBatchGetExpectedCount(
75       random::SimplePhilox* rnd, bool unique,
76       gtl::MutableArraySlice<int64> batch,
77       gtl::MutableArraySlice<float> batch_expected_count,
78       gtl::ArraySlice<int64> extras,
79       gtl::MutableArraySlice<float> extras_expected_count) const;
80 
81   // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
82   // We repick to avoid all of the values in "avoided_values".
83   // "avoided_values" is only supported with unique=true.  If
84   // unique=false, then avoided_values must be empty.
85   virtual void SampleBatchGetExpectedCountAvoid(
86       random::SimplePhilox* rnd, bool unique,
87       gtl::MutableArraySlice<int64> batch,
88       gtl::MutableArraySlice<float> batch_expected_count,
89       gtl::ArraySlice<int64> extras,
90       gtl::MutableArraySlice<float> extras_expected_count,
91       gtl::ArraySlice<int64> avoided_values) const;
92 
93   // Does this sampler need to be updated with values, e.g. UnigramSampler
NeedsUpdates()94   virtual bool NeedsUpdates() const { return false; }
95 
96   // Updates the underlying distribution
Update(gtl::ArraySlice<int64> values)97   virtual void Update(gtl::ArraySlice<int64> values) {
98     LOG(FATAL) << "Update not supported for this sampler type.";
99   }
100 
range()101   int64 range() { return range_; }
102 
103  protected:
104   const int64 range_;
105 };
106 
107 // An AllSampler only samples batches of size equal to range.
108 // It returns the entire range.
109 // It cannot sample single values.
110 class AllSampler : public RangeSampler {
111  public:
112   explicit AllSampler(int64 range);
113 
~AllSampler()114   ~AllSampler() override {}
115 
Sample(random::SimplePhilox * rnd)116   int64 Sample(random::SimplePhilox* rnd) const override {
117     LOG(FATAL) << "Should not be called";
118     return 0;
119   }
120 
Probability(int64 value)121   float Probability(int64 value) const override {
122     LOG(FATAL) << "Should not be called";
123     return 0;
124   }
125 
126   void SampleBatchGetExpectedCountAvoid(
127       random::SimplePhilox* rnd, bool unique,
128       gtl::MutableArraySlice<int64> batch,
129       gtl::MutableArraySlice<float> batch_expected_count,
130       gtl::ArraySlice<int64> extras,
131       gtl::MutableArraySlice<float> extras_expected_count,
132       gtl::ArraySlice<int64> avoided_values) const override;
133 };
134 
135 class UniformSampler : public RangeSampler {
136  public:
137   explicit UniformSampler(int64 range);
138 
~UniformSampler()139   ~UniformSampler() override {}
140 
141   int64 Sample(random::SimplePhilox* rnd) const override;
142 
143   float Probability(int64 value) const override;
144 
145  private:
146   const float inv_range_;
147 };
148 
149 class LogUniformSampler : public RangeSampler {
150  public:
151   explicit LogUniformSampler(int64 range);
152 
~LogUniformSampler()153   ~LogUniformSampler() override {}
154 
155   int64 Sample(random::SimplePhilox* rnd) const override;
156 
157   float Probability(int64 value) const override;
158 
159  private:
160   const double log_range_;
161 };
162 
163 // Thread-unsafe unigram sampler
164 class ThreadUnsafeUnigramSampler : public RangeSampler {
165  public:
166   explicit ThreadUnsafeUnigramSampler(int64 range);
~ThreadUnsafeUnigramSampler()167   ~ThreadUnsafeUnigramSampler() override {}
168 
169   int64 Sample(random::SimplePhilox* rnd) const override;
170 
171   float Probability(int64 value) const override;
172 
NeedsUpdates()173   bool NeedsUpdates() const override { return true; }
174   void Update(gtl::ArraySlice<int64> values) override;
175 
176  private:
177   random::WeightedPicker picker_;
178 };
179 
180 // Thread-safe unigram sampler
181 class UnigramSampler : public RangeSampler {
182  public:
183   explicit UnigramSampler(int64 range);
~UnigramSampler()184   ~UnigramSampler() override {}
185 
186   int64 Sample(random::SimplePhilox* rnd) const override;
187 
188   float Probability(int64 value) const override;
189 
190   // Overriding at a high level results in far fewer lock acquisitions.
191   void SampleBatchGetExpectedCountAvoid(
192       random::SimplePhilox* rnd, bool unique,
193       gtl::MutableArraySlice<int64> batch,
194       gtl::MutableArraySlice<float> batch_expected_count,
195       gtl::ArraySlice<int64> extras,
196       gtl::MutableArraySlice<float> extras_expected_count,
197       gtl::ArraySlice<int64> avoided_values) const override;
198 
NeedsUpdates()199   bool NeedsUpdates() const override { return true; }
200   void Update(gtl::ArraySlice<int64> values) override;
201 
202  private:
203   ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_);
204   mutable mutex mu_;
205 };
206 
207 // A unigram sampler that uses a fixed unigram distribution read from a
208 // file or passed in as an in-memory array instead of building up the
209 // distribution from data on the fly. There is also an option to skew the
210 // distribution by applying a distortion power to the weights.
211 class FixedUnigramSampler : public RangeSampler {
212  public:
213   // The vocab_file is assumed to be a CSV, with the last entry of each row a
214   // value representing the counts or probabilities for the corresponding ID.
215   FixedUnigramSampler(Env* env, int64 range, const string& vocab_file,
216                       float distortion, int32 num_reserved_ids,
217                       int32 num_shards, int32 shard);
218 
219   FixedUnigramSampler(int64 range, const std::vector<float>& unigrams,
220                       float distortion, int32 num_reserved_ids,
221                       int32 num_shards, int32 shard);
222 
223   float Probability(int64 value) const override;
224 
225   int64 Sample(random::SimplePhilox* rnd) const override;
226 
227  private:
228   // Underlying distribution sampler.
229   std::unique_ptr<random::DistributionSampler> dist_sampler_;
230   // Weights for individual samples. The probability of a sample i is defined
231   // as weights_.at(i) / total_weight_.
232   std::vector<float> weights_;
233   // The total weights of all samples.
234   float total_weight_;
235   // Sharding information of the sampler. The whole vocabulary is sharded
236   // into num_shards_ smaller ranges and each sampler is responsible for one
237   // such smaller range, identified by the shard number.
238   int32 num_shards_;
239   int32 shard_;
240 
241   // Fill the sampler with the appropriate number of reserved IDs.
242   void FillReservedIds(int32 num_reserved_ids);
243   // Load IDs to sample from a CSV file. It is assumed that the last item of
244   // each row contains a count or probability for the corresponding ID.
245   Status LoadFromFile(Env* env, const string& vocab_file, float distortion);
246   // Load from an in-memory array.
247   void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);
248 };
249 
250 }  // namespace tensorflow
251 
252 #endif  // TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
253