// Copyright 2016 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SRC_WEIGHTED_RESERVOIR_SAMPLER_H_
#define SRC_WEIGHTED_RESERVOIR_SAMPLER_H_

#include <cassert>
#include <cstdint>
#include <random>

namespace protobuf_mutator {

// Picks one item from a sequence of weighted items in a single pass, without
// knowing the number of items or the total weight in advance. After all items
// have been offered, each one is the selection with probability
// weight / (sum of all weights).
// https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
//
// Example:
//   std::default_random_engine random;
//   WeightedReservoirSampler<int> sampler(&random);
//   for (int i = 0; i < size; ++i)
//     sampler.Try(weight[i], i);
//   return sampler.selected();
template <class T, class RandomEngine = std::default_random_engine>
class WeightedReservoirSampler {
 public:
  // `random` is stored as a raw pointer and dereferenced by Try(); it is not
  // owned, so it must be non-null and outlive the sampler.
  explicit WeightedReservoirSampler(RandomEngine* random) : random_(random) {}

  // Offers `item` with the given `weight`. Items with zero weight are ignored
  // and can never become the selection.
  void Try(uint64_t weight, const T& item) {
    if (Pick(weight)) selected_ = item;
  }

  // Returns the current selection, or a value-initialized T when no item with
  // non-zero weight has been offered yet (see IsEmpty()).
  const T& selected() const { return selected_; }

  // True until the first item with non-zero weight has been offered.
  bool IsEmpty() const { return total_weight_ == 0; }

 private:
  // Adds `weight` to the running total and decides whether the item being
  // offered replaces the current selection (probability weight / new total).
  bool Pick(uint64_t weight) {
    if (weight == 0) return false;
    const uint64_t new_total = total_weight_ + weight;
    // Unsigned wrap-around would silently skew the distribution and, if the
    // total wrapped to 0, violate uniform_int_distribution's a <= b
    // precondition. Callers must keep the sum of weights within uint64_t.
    assert(new_total > total_weight_ && "total weight overflows uint64_t");
    total_weight_ = new_total;
    // The first non-zero-weight item is always taken, without drawing from
    // the engine.
    return weight == total_weight_ ||
           std::uniform_int_distribution<uint64_t>(1, total_weight_)(*random_) <=
               weight;
  }

  T selected_ = {};
  uint64_t total_weight_ = 0;
  RandomEngine* random_;
};

}  // namespace protobuf_mutator

#endif  // SRC_WEIGHTED_RESERVOIR_SAMPLER_H_