1 // Copyright 2018 Developers of the Rand project. 2 // 3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license 5 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your 6 // option. This file may not be copied, modified, or distributed 7 // except according to those terms. 8 9 //! Weighted index sampling 10 11 use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; 12 use crate::distributions::Distribution; 13 use crate::Rng; 14 use core::cmp::PartialOrd; 15 use core::fmt; 16 17 // Note that this whole module is only imported if feature="alloc" is enabled. 18 use alloc::vec::Vec; 19 20 #[cfg(feature = "serde1")] 21 use serde::{Serialize, Deserialize}; 22 23 /// A distribution using weighted sampling of discrete items 24 /// 25 /// Sampling a `WeightedIndex` distribution returns the index of a randomly 26 /// selected element from the iterator used when the `WeightedIndex` was 27 /// created. The chance of a given element being picked is proportional to the 28 /// value of the element. The weights can use any type `X` for which an 29 /// implementation of [`Uniform<X>`] exists. 30 /// 31 /// # Performance 32 /// 33 /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where 34 /// `N` is the number of weights. As an alternative, 35 /// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) 36 /// supports `O(1)` sampling, but with much higher initialisation cost. 37 /// 38 /// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its 39 /// size is the sum of the size of those objects, possibly plus some alignment. 40 /// 41 /// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1` 42 /// weights of type `X`, where `N` is the number of weights. However, since 43 /// `Vec` doesn't guarantee a particular growth strategy, additional memory 44 /// might be allocated but not used. Since the `WeightedIndex` object also 45 /// contains, this might cause additional allocations, though for primitive 46 /// types, [`Uniform<X>`] doesn't allocate any memory. 47 /// 48 /// Sampling from `WeightedIndex` will result in a single call to 49 /// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically 50 /// will request a single value from the underlying [`RngCore`], though the 51 /// exact number depends on the implementation of `Uniform<X>::sample`. 52 /// 53 /// # Example 54 /// 55 /// ``` 56 /// use rand::prelude::*; 57 /// use rand::distributions::WeightedIndex; 58 /// 59 /// let choices = ['a', 'b', 'c']; 60 /// let weights = [2, 1, 1]; 61 /// let dist = WeightedIndex::new(&weights).unwrap(); 62 /// let mut rng = thread_rng(); 63 /// for _ in 0..100 { 64 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' 65 /// println!("{}", choices[dist.sample(&mut rng)]); 66 /// } 67 /// 68 /// let items = [('a', 0), ('b', 3), ('c', 7)]; 69 /// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); 70 /// for _ in 0..100 { 71 /// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' 72 /// println!("{}", items[dist2.sample(&mut rng)].0); 73 /// } 74 /// ``` 75 /// 76 /// [`Uniform<X>`]: crate::distributions::Uniform 77 /// [`RngCore`]: crate::RngCore 78 #[derive(Debug, Clone)] 79 #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] 80 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] 81 pub struct WeightedIndex<X: SampleUniform + PartialOrd> { 82 cumulative_weights: Vec<X>, 83 total_weight: X, 84 weight_distribution: X::Sampler, 85 } 86 87 impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { 88 /// Creates a new a `WeightedIndex` [`Distribution`] using the values 89 /// in `weights`. The weights can use any type `X` for which an 90 /// implementation of [`Uniform<X>`] exists. 91 /// 92 /// Returns an error if the iterator is empty, if any weight is `< 0`, or 93 /// if its total value is 0. 94 /// 95 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> where I: IntoIterator, I::Item: SampleBorrow<X>, X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,96 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> 97 where 98 I: IntoIterator, 99 I::Item: SampleBorrow<X>, 100 X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, 101 { 102 let mut iter = weights.into_iter(); 103 let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); 104 105 let zero = <X as Default>::default(); 106 if !(total_weight >= zero) { 107 return Err(WeightedError::InvalidWeight); 108 } 109 110 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0); 111 for w in iter { 112 // Note that `!(w >= x)` is not equivalent to `w < x` for partially 113 // ordered types due to NaNs which are equal to nothing. 114 if !(w.borrow() >= &zero) { 115 return Err(WeightedError::InvalidWeight); 116 } 117 weights.push(total_weight.clone()); 118 total_weight += w.borrow(); 119 } 120 121 if total_weight == zero { 122 return Err(WeightedError::AllWeightsZero); 123 } 124 let distr = X::Sampler::new(zero, total_weight.clone()); 125 126 Ok(WeightedIndex { 127 cumulative_weights: weights, 128 total_weight, 129 weight_distribution: distr, 130 }) 131 } 132 133 /// Update a subset of weights, without changing the number of weights. 134 /// 135 /// `new_weights` must be sorted by the index. 136 /// 137 /// Using this method instead of `new` might be more efficient if only a small number of 138 /// weights is modified. No allocations are performed, unless the weight type `X` uses 139 /// allocation internally. 140 /// 141 /// In case of error, `self` is not modified. update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> where X: for<'a> ::core::ops::AddAssign<&'a X> + for<'a> ::core::ops::SubAssign<&'a X> + Clone + Default142 pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> 143 where X: for<'a> ::core::ops::AddAssign<&'a X> 144 + for<'a> ::core::ops::SubAssign<&'a X> 145 + Clone 146 + Default { 147 if new_weights.is_empty() { 148 return Ok(()); 149 } 150 151 let zero = <X as Default>::default(); 152 153 let mut total_weight = self.total_weight.clone(); 154 155 // Check for errors first, so we don't modify `self` in case something 156 // goes wrong. 157 let mut prev_i = None; 158 for &(i, w) in new_weights { 159 if let Some(old_i) = prev_i { 160 if old_i >= i { 161 return Err(WeightedError::InvalidWeight); 162 } 163 } 164 if !(*w >= zero) { 165 return Err(WeightedError::InvalidWeight); 166 } 167 if i > self.cumulative_weights.len() { 168 return Err(WeightedError::TooMany); 169 } 170 171 let mut old_w = if i < self.cumulative_weights.len() { 172 self.cumulative_weights[i].clone() 173 } else { 174 self.total_weight.clone() 175 }; 176 if i > 0 { 177 old_w -= &self.cumulative_weights[i - 1]; 178 } 179 180 total_weight -= &old_w; 181 total_weight += w; 182 prev_i = Some(i); 183 } 184 if total_weight <= zero { 185 return Err(WeightedError::AllWeightsZero); 186 } 187 188 // Update the weights. Because we checked all the preconditions in the 189 // previous loop, this should never panic. 190 let mut iter = new_weights.iter(); 191 192 let mut prev_weight = zero.clone(); 193 let mut next_new_weight = iter.next(); 194 let &(first_new_index, _) = next_new_weight.unwrap(); 195 let mut cumulative_weight = if first_new_index > 0 { 196 self.cumulative_weights[first_new_index - 1].clone() 197 } else { 198 zero.clone() 199 }; 200 for i in first_new_index..self.cumulative_weights.len() { 201 match next_new_weight { 202 Some(&(j, w)) if i == j => { 203 cumulative_weight += w; 204 next_new_weight = iter.next(); 205 } 206 _ => { 207 let mut tmp = self.cumulative_weights[i].clone(); 208 tmp -= &prev_weight; // We know this is positive. 209 cumulative_weight += &tmp; 210 } 211 } 212 prev_weight = cumulative_weight.clone(); 213 core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); 214 } 215 216 self.total_weight = total_weight; 217 self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); 218 219 Ok(()) 220 } 221 } 222 223 impl<X> Distribution<usize> for WeightedIndex<X> 224 where X: SampleUniform + PartialOrd 225 { sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize226 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { 227 use ::core::cmp::Ordering; 228 let chosen_weight = self.weight_distribution.sample(rng); 229 // Find the first item which has a weight *higher* than the chosen weight. 230 self.cumulative_weights 231 .binary_search_by(|w| { 232 if *w <= chosen_weight { 233 Ordering::Less 234 } else { 235 Ordering::Greater 236 } 237 }) 238 .unwrap_err() 239 } 240 } 241 242 #[cfg(test)] 243 mod test { 244 use super::*; 245 246 #[cfg(feature = "serde1")] 247 #[test] test_weightedindex_serde1()248 fn test_weightedindex_serde1() { 249 let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); 250 251 let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); 252 let de_weighted_index: WeightedIndex<i32> = 253 bincode::deserialize(&ser_weighted_index).unwrap(); 254 255 assert_eq!( 256 de_weighted_index.cumulative_weights, 257 weighted_index.cumulative_weights 258 ); 259 assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); 260 } 261 262 #[test] test_accepting_nan()263 fn test_accepting_nan(){ 264 assert_eq!( 265 WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), 266 WeightedError::InvalidWeight, 267 ); 268 assert_eq!( 269 WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), 270 WeightedError::InvalidWeight, 271 ); 272 assert_eq!( 273 WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), 274 WeightedError::InvalidWeight, 275 ); 276 277 assert_eq!( 278 WeightedIndex::new(&[0.5, 7.0]) 279 .unwrap() 280 .update_weights(&[(0, &core::f32::NAN)]) 281 .unwrap_err(), 282 WeightedError::InvalidWeight, 283 ) 284 } 285 286 287 #[test] 288 #[cfg_attr(miri, ignore)] // Miri is too slow test_weightedindex()289 fn test_weightedindex() { 290 let mut r = crate::test::rng(700); 291 const N_REPS: u32 = 5000; 292 let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; 293 let total_weight = weights.iter().sum::<u32>() as f32; 294 295 let verify = |result: [i32; 14]| { 296 for (i, count) in result.iter().enumerate() { 297 let exp = (weights[i] * N_REPS) as f32 / total_weight; 298 let mut err = (*count as f32 - exp).abs(); 299 if err != 0.0 { 300 err /= exp; 301 } 302 assert!(err <= 0.25); 303 } 304 }; 305 306 // WeightedIndex from vec 307 let mut chosen = [0i32; 14]; 308 let distr = WeightedIndex::new(weights.to_vec()).unwrap(); 309 for _ in 0..N_REPS { 310 chosen[distr.sample(&mut r)] += 1; 311 } 312 verify(chosen); 313 314 // WeightedIndex from slice 315 chosen = [0i32; 14]; 316 let distr = WeightedIndex::new(&weights[..]).unwrap(); 317 for _ in 0..N_REPS { 318 chosen[distr.sample(&mut r)] += 1; 319 } 320 verify(chosen); 321 322 // WeightedIndex from iterator 323 chosen = [0i32; 14]; 324 let distr = WeightedIndex::new(weights.iter()).unwrap(); 325 for _ in 0..N_REPS { 326 chosen[distr.sample(&mut r)] += 1; 327 } 328 verify(chosen); 329 330 for _ in 0..5 { 331 assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); 332 assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); 333 assert_eq!( 334 WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) 335 .unwrap() 336 .sample(&mut r), 337 4 338 ); 339 } 340 341 assert_eq!( 342 WeightedIndex::new(&[10][0..0]).unwrap_err(), 343 WeightedError::NoItem 344 ); 345 assert_eq!( 346 WeightedIndex::new(&[0]).unwrap_err(), 347 WeightedError::AllWeightsZero 348 ); 349 assert_eq!( 350 WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), 351 WeightedError::InvalidWeight 352 ); 353 assert_eq!( 354 WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), 355 WeightedError::InvalidWeight 356 ); 357 assert_eq!( 358 WeightedIndex::new(&[-10]).unwrap_err(), 359 WeightedError::InvalidWeight 360 ); 361 } 362 363 #[test] test_update_weights()364 fn test_update_weights() { 365 let data = [ 366 ( 367 &[10u32, 2, 3, 4][..], 368 &[(1, &100), (2, &4)][..], // positive change 369 &[10, 100, 4, 4][..], 370 ), 371 ( 372 &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], 373 &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element 374 &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], 375 ), 376 ]; 377 378 for (weights, update, expected_weights) in data.iter() { 379 let total_weight = weights.iter().sum::<u32>(); 380 let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); 381 assert_eq!(distr.total_weight, total_weight); 382 383 distr.update_weights(update).unwrap(); 384 let expected_total_weight = expected_weights.iter().sum::<u32>(); 385 let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); 386 assert_eq!(distr.total_weight, expected_total_weight); 387 assert_eq!(distr.total_weight, expected_distr.total_weight); 388 assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); 389 } 390 } 391 392 #[test] value_stability()393 fn value_stability() { 394 fn test_samples<X: SampleUniform + PartialOrd, I>( 395 weights: I, buf: &mut [usize], expected: &[usize], 396 ) where 397 I: IntoIterator, 398 I::Item: SampleBorrow<X>, 399 X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, 400 { 401 assert_eq!(buf.len(), expected.len()); 402 let distr = WeightedIndex::new(weights).unwrap(); 403 let mut rng = crate::test::rng(701); 404 for r in buf.iter_mut() { 405 *r = rng.sample(&distr); 406 } 407 assert_eq!(buf, expected); 408 } 409 410 let mut buf = [0; 10]; 411 test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ 412 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, 413 ]); 414 test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ 415 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, 416 ]); 417 test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ 418 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, 419 ]); 420 } 421 } 422 423 /// Error type returned from `WeightedIndex::new`. 424 #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] 425 #[derive(Debug, Clone, Copy, PartialEq, Eq)] 426 pub enum WeightedError { 427 /// The provided weight collection contains no items. 428 NoItem, 429 430 /// A weight is either less than zero, greater than the supported maximum, 431 /// NaN, or otherwise invalid. 432 InvalidWeight, 433 434 /// All items in the provided weight collection are zero. 435 AllWeightsZero, 436 437 /// Too many weights are provided (length greater than `u32::MAX`) 438 TooMany, 439 } 440 441 #[cfg(feature = "std")] 442 impl ::std::error::Error for WeightedError {} 443 444 impl fmt::Display for WeightedError { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result445 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 446 match *self { 447 WeightedError::NoItem => write!(f, "No weights provided."), 448 WeightedError::InvalidWeight => write!(f, "A weight is invalid."), 449 WeightedError::AllWeightsZero => write!(f, "All weights are zero."), 450 WeightedError::TooMany => write!(f, "Too many weights (hit u32::MAX)"), 451 } 452 } 453 } 454