1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ 17 #define TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ 18 19 #include <vector> 20 21 #include "tensorflow/core/lib/core/status.h" 22 #include "tensorflow/core/lib/gtl/array_slice.h" 23 #include "tensorflow/core/lib/random/distribution_sampler.h" 24 #include "tensorflow/core/lib/random/random_distributions.h" 25 #include "tensorflow/core/lib/random/weighted_picker.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/mutex.h" 28 #include "tensorflow/core/platform/thread_annotations.h" 29 #include "tensorflow/core/platform/types.h" 30 31 namespace tensorflow { 32 33 class Env; 34 35 // Abstract subclass for sampling from the set of non-negative integers 36 // [0, range) 37 class RangeSampler { 38 public: RangeSampler(int64 range)39 explicit RangeSampler(int64 range) : range_(range) { CHECK_GT(range_, 0); } 40 virtual ~RangeSampler(); 41 42 // Sample a single value 43 virtual int64 Sample(random::SimplePhilox* rnd) const = 0; 44 45 // The probability that a single call to Sample() returns the given value. 46 // Assumes that value is in [0, range). No range checking is done. 47 virtual float Probability(int64 value) const = 0; 48 49 // Fill "batch" with samples from the distribution. 50 // If unique=true, then we re-pick each element until we get a 51 // value distinct from all previously picked values in the batch. 52 void SampleBatch(random::SimplePhilox* rnd, bool unique, 53 gtl::MutableArraySlice<int64> batch) const; 54 55 // Fill "batch" with samples from the distribution, and report 56 // "expected counts". 57 // 58 // The "expected count" of a value is an estimate of the expected 59 // number of occurrences of the value in the batch returned by a 60 // call to this function with the given parameters. If unique=true, 61 // the expected count is an inclusion probability. For details on 62 // this estimation, see the comment to "ExpectedCountHelper" in the 63 // .cc file. 64 // 65 // Expected counts for the elements of the returned "batch" are reported 66 // in the aligned array "batch_expected_count". 67 // 68 // The user can optionally provide "extras", containing values in the range. 69 // The expected counts for the extras are reported in the aligned array 70 // "extras_expected_count". 71 // 72 // "batch_expected_count" must have size equal to 0 or to the size of "batch". 73 // "extras" and "extras_expected_count" must have equal size. 74 void SampleBatchGetExpectedCount( 75 random::SimplePhilox* rnd, bool unique, 76 gtl::MutableArraySlice<int64> batch, 77 gtl::MutableArraySlice<float> batch_expected_count, 78 gtl::ArraySlice<int64> extras, 79 gtl::MutableArraySlice<float> extras_expected_count) const; 80 81 // Same as SampleBatchGetExpectedCount (see above), but with avoided values. 82 // We repick to avoid all of the values in "avoided_values". 83 // "avoided_values" is only supported with unique=true. If 84 // unique=false, then avoided_values must be empty. 85 virtual void SampleBatchGetExpectedCountAvoid( 86 random::SimplePhilox* rnd, bool unique, 87 gtl::MutableArraySlice<int64> batch, 88 gtl::MutableArraySlice<float> batch_expected_count, 89 gtl::ArraySlice<int64> extras, 90 gtl::MutableArraySlice<float> extras_expected_count, 91 gtl::ArraySlice<int64> avoided_values) const; 92 93 // Does this sampler need to be updated with values, e.g. UnigramSampler NeedsUpdates()94 virtual bool NeedsUpdates() const { return false; } 95 96 // Updates the underlying distribution Update(gtl::ArraySlice<int64> values)97 virtual void Update(gtl::ArraySlice<int64> values) { 98 LOG(FATAL) << "Update not supported for this sampler type."; 99 } 100 range()101 int64 range() { return range_; } 102 103 protected: 104 const int64 range_; 105 }; 106 107 // An AllSampler only samples batches of size equal to range. 108 // It returns the entire range. 109 // It cannot sample single values. 110 class AllSampler : public RangeSampler { 111 public: 112 explicit AllSampler(int64 range); 113 ~AllSampler()114 ~AllSampler() override {} 115 Sample(random::SimplePhilox * rnd)116 int64 Sample(random::SimplePhilox* rnd) const override { 117 LOG(FATAL) << "Should not be called"; 118 return 0; 119 } 120 Probability(int64 value)121 float Probability(int64 value) const override { 122 LOG(FATAL) << "Should not be called"; 123 return 0; 124 } 125 126 void SampleBatchGetExpectedCountAvoid( 127 random::SimplePhilox* rnd, bool unique, 128 gtl::MutableArraySlice<int64> batch, 129 gtl::MutableArraySlice<float> batch_expected_count, 130 gtl::ArraySlice<int64> extras, 131 gtl::MutableArraySlice<float> extras_expected_count, 132 gtl::ArraySlice<int64> avoided_values) const override; 133 }; 134 135 class UniformSampler : public RangeSampler { 136 public: 137 explicit UniformSampler(int64 range); 138 ~UniformSampler()139 ~UniformSampler() override {} 140 141 int64 Sample(random::SimplePhilox* rnd) const override; 142 143 float Probability(int64 value) const override; 144 145 private: 146 const float inv_range_; 147 }; 148 149 class LogUniformSampler : public RangeSampler { 150 public: 151 explicit LogUniformSampler(int64 range); 152 ~LogUniformSampler()153 ~LogUniformSampler() override {} 154 155 int64 Sample(random::SimplePhilox* rnd) const override; 156 157 float Probability(int64 value) const override; 158 159 private: 160 const double log_range_; 161 }; 162 163 // Thread-unsafe unigram sampler 164 class ThreadUnsafeUnigramSampler : public RangeSampler { 165 public: 166 explicit ThreadUnsafeUnigramSampler(int64 range); ~ThreadUnsafeUnigramSampler()167 ~ThreadUnsafeUnigramSampler() override {} 168 169 int64 Sample(random::SimplePhilox* rnd) const override; 170 171 float Probability(int64 value) const override; 172 NeedsUpdates()173 bool NeedsUpdates() const override { return true; } 174 void Update(gtl::ArraySlice<int64> values) override; 175 176 private: 177 random::WeightedPicker picker_; 178 }; 179 180 // Thread-safe unigram sampler 181 class UnigramSampler : public RangeSampler { 182 public: 183 explicit UnigramSampler(int64 range); ~UnigramSampler()184 ~UnigramSampler() override {} 185 186 int64 Sample(random::SimplePhilox* rnd) const override; 187 188 float Probability(int64 value) const override; 189 190 // Overriding at a high level results in far fewer lock acquisitions. 191 void SampleBatchGetExpectedCountAvoid( 192 random::SimplePhilox* rnd, bool unique, 193 gtl::MutableArraySlice<int64> batch, 194 gtl::MutableArraySlice<float> batch_expected_count, 195 gtl::ArraySlice<int64> extras, 196 gtl::MutableArraySlice<float> extras_expected_count, 197 gtl::ArraySlice<int64> avoided_values) const override; 198 NeedsUpdates()199 bool NeedsUpdates() const override { return true; } 200 void Update(gtl::ArraySlice<int64> values) override; 201 202 private: 203 ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_); 204 mutable mutex mu_; 205 }; 206 207 // A unigram sampler that uses a fixed unigram distribution read from a 208 // file or passed in as an in-memory array instead of building up the 209 // distribution from data on the fly. There is also an option to skew the 210 // distribution by applying a distortion power to the weights. 211 class FixedUnigramSampler : public RangeSampler { 212 public: 213 // The vocab_file is assumed to be a CSV, with the last entry of each row a 214 // value representing the counts or probabilities for the corresponding ID. 215 FixedUnigramSampler(Env* env, int64 range, const string& vocab_file, 216 float distortion, int32 num_reserved_ids, 217 int32 num_shards, int32 shard); 218 219 FixedUnigramSampler(int64 range, const std::vector<float>& unigrams, 220 float distortion, int32 num_reserved_ids, 221 int32 num_shards, int32 shard); 222 223 float Probability(int64 value) const override; 224 225 int64 Sample(random::SimplePhilox* rnd) const override; 226 227 private: 228 // Underlying distribution sampler. 229 std::unique_ptr<random::DistributionSampler> dist_sampler_; 230 // Weights for individual samples. The probability of a sample i is defined 231 // as weights_.at(i) / total_weight_. 232 std::vector<float> weights_; 233 // The total weights of all samples. 234 float total_weight_; 235 // Sharding information of the sampler. The whole vocabulary is sharded 236 // into num_shards_ smaller ranges and each sampler is responsible for one 237 // such smaller range, identified by the shard number. 238 int32 num_shards_; 239 int32 shard_; 240 241 // Fill the sampler with the appropriate number of reserved IDs. 242 void FillReservedIds(int32 num_reserved_ids); 243 // Load IDs to sample from a CSV file. It is assumed that the last item of 244 // each row contains a count or probability for the corresponding ID. 245 Status LoadFromFile(Env* env, const string& vocab_file, float distortion); 246 // Load from an in-memory array. 247 void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion); 248 }; 249 250 } // namespace tensorflow 251 252 #endif // TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ 253