/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// DistributionSampler allows generating a discrete random variable with a given
// distribution.
// The values taken by the variable are [0, N) and relative weights for each
// value are specified using a vector of size N.
//
// The algorithm takes O(N) time to precompute data at construction time and
// takes O(1) time (2 random number generations, 2 lookups) for each sample.
// The data structure takes O(N) memory.
//
// In contrast, util/random/weighted-picker.h provides O(lg N) sampling.
// The advantage of that implementation is that weights can be adjusted
// dynamically, while DistributionSampler doesn't allow weight adjustment.
//
// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2.
30 31 #ifndef TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ 32 #define TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ 33 34 #include <memory> 35 #include <utility> 36 37 #include "tensorflow/core/lib/gtl/array_slice.h" 38 #include "tensorflow/core/lib/random/simple_philox.h" 39 #include "tensorflow/core/platform/logging.h" 40 #include "tensorflow/core/platform/macros.h" 41 #include "tensorflow/core/platform/types.h" 42 43 namespace tensorflow { 44 namespace random { 45 46 class DistributionSampler { 47 public: 48 explicit DistributionSampler(const gtl::ArraySlice<float>& weights); 49 ~DistributionSampler()50 ~DistributionSampler() {} 51 Sample(SimplePhilox * rand)52 int Sample(SimplePhilox* rand) const { 53 float r = rand->RandFloat(); 54 // Since n is typically low, we don't bother with UnbiasedUniform. 55 int idx = rand->Uniform(num_); 56 if (r < prob(idx)) return idx; 57 // else pick alt from that bucket. 58 DCHECK_NE(-1, alt(idx)); 59 return alt(idx); 60 } 61 num()62 int num() const { return num_; } 63 64 private: prob(int idx)65 float prob(int idx) const { 66 DCHECK_LT(idx, num_); 67 return data_[idx].first; 68 } 69 alt(int idx)70 int alt(int idx) const { 71 DCHECK_LT(idx, num_); 72 return data_[idx].second; 73 } 74 set_prob(int idx,float f)75 void set_prob(int idx, float f) { 76 DCHECK_LT(idx, num_); 77 data_[idx].first = f; 78 } 79 set_alt(int idx,int val)80 void set_alt(int idx, int val) { 81 DCHECK_LT(idx, num_); 82 data_[idx].second = val; 83 } 84 85 int num_; 86 std::unique_ptr<std::pair<float, int>[]> data_; 87 88 TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler); 89 }; 90 91 } // namespace random 92 } // namespace tensorflow 93 94 #endif // TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ 95