1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
16 #define ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
17
18 #include <algorithm>
19 #include <cassert>
20 #include <cmath>
21 #include <istream>
22 #include <limits>
23 #include <ostream>
24 #include <type_traits>
25
26 #include "absl/numeric/bits.h"
27 #include "absl/random/internal/fastmath.h"
28 #include "absl/random/internal/generate_real.h"
29 #include "absl/random/internal/iostream_state_saver.h"
30 #include "absl/random/internal/traits.h"
31 #include "absl/random/uniform_int_distribution.h"
32
33 namespace absl {
34 ABSL_NAMESPACE_BEGIN
35
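// log_uniform_int_distribution:
//
// Returns a random variate R in the range [min, max] such that
// floor(log(R - min, base)) is (approximately) uniformly distributed, with
// R == min forming a bucket of its own.
// Uniformity is obtained by discretizing the offsets [0, max - min] using the
// boundary set [0, 1, base, base^2, ..., min(base^n, max - min)]: a bucket is
// chosen with equal probability and a value is then drawn uniformly from it.
//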
43 template <typename IntType = int>
44 class log_uniform_int_distribution {
45 private:
46 using unsigned_type =
47 typename random_internal::make_unsigned_bits<IntType>::type;
48
49 public:
50 using result_type = IntType;
51
52 class param_type {
53 public:
54 using distribution_type = log_uniform_int_distribution;
55
56 explicit param_type(
57 result_type min = 0,
58 result_type max = (std::numeric_limits<result_type>::max)(),
59 result_type base = 2)
min_(min)60 : min_(min),
61 max_(max),
62 base_(base),
63 range_(static_cast<unsigned_type>(max_) -
64 static_cast<unsigned_type>(min_)),
65 log_range_(0) {
66 assert(max_ >= min_);
67 assert(base_ > 1);
68
69 if (base_ == 2) {
70 // Determine where the first set bit is on range(), giving a log2(range)
71 // value which can be used to construct bounds.
72 log_range_ =
73 (std::min)(bit_width(range()),
74 static_cast<unsigned_type>(
75 std::numeric_limits<unsigned_type>::digits));
76 } else {
77 // NOTE: Computing the logN(x) introduces error from 2 sources:
78 // 1. Conversion of int to double loses precision for values >=
79 // 2^53, which may cause some log() computations to operate on
80 // different values.
81 // 2. The error introduced by the division will cause the result
82 // to differ from the expected value.
83 //
84 // Thus a result which should equal K may equal K +/- epsilon,
85 // which can eliminate some values depending on where the bounds fall.
86 const double inv_log_base = 1.0 / std::log(base_);
87 const double log_range = std::log(static_cast<double>(range()) + 0.5);
88 log_range_ = static_cast<int>(std::ceil(inv_log_base * log_range));
89 }
90 }
91
result_type(min)92 result_type(min)() const { return min_; }
result_type(max)93 result_type(max)() const { return max_; }
base()94 result_type base() const { return base_; }
95
96 friend bool operator==(const param_type& a, const param_type& b) {
97 return a.min_ == b.min_ && a.max_ == b.max_ && a.base_ == b.base_;
98 }
99
100 friend bool operator!=(const param_type& a, const param_type& b) {
101 return !(a == b);
102 }
103
104 private:
105 friend class log_uniform_int_distribution;
106
log_range()107 int log_range() const { return log_range_; }
range()108 unsigned_type range() const { return range_; }
109
110 result_type min_;
111 result_type max_;
112 result_type base_;
113 unsigned_type range_; // max - min
114 int log_range_; // ceil(logN(range_))
115
116 static_assert(std::is_integral<IntType>::value,
117 "Class-template absl::log_uniform_int_distribution<> must be "
118 "parameterized using an integral type.");
119 };
120
log_uniform_int_distribution()121 log_uniform_int_distribution() : log_uniform_int_distribution(0) {}
122
123 explicit log_uniform_int_distribution(
124 result_type min,
125 result_type max = (std::numeric_limits<result_type>::max)(),
126 result_type base = 2)
param_(min,max,base)127 : param_(min, max, base) {}
128
log_uniform_int_distribution(const param_type & p)129 explicit log_uniform_int_distribution(const param_type& p) : param_(p) {}
130
reset()131 void reset() {}
132
133 // generating functions
134 template <typename URBG>
operator()135 result_type operator()(URBG& g) { // NOLINT(runtime/references)
136 return (*this)(g, param_);
137 }
138
139 template <typename URBG>
operator()140 result_type operator()(URBG& g, // NOLINT(runtime/references)
141 const param_type& p) {
142 return (p.min)() + Generate(g, p);
143 }
144
result_type(min)145 result_type(min)() const { return (param_.min)(); }
result_type(max)146 result_type(max)() const { return (param_.max)(); }
base()147 result_type base() const { return param_.base(); }
148
param()149 param_type param() const { return param_; }
param(const param_type & p)150 void param(const param_type& p) { param_ = p; }
151
152 friend bool operator==(const log_uniform_int_distribution& a,
153 const log_uniform_int_distribution& b) {
154 return a.param_ == b.param_;
155 }
156 friend bool operator!=(const log_uniform_int_distribution& a,
157 const log_uniform_int_distribution& b) {
158 return a.param_ != b.param_;
159 }
160
161 private:
162 // Returns a log-uniform variate in the range [0, p.range()]. The caller
163 // should add min() to shift the result to the correct range.
164 template <typename URNG>
165 unsigned_type Generate(URNG& g, // NOLINT(runtime/references)
166 const param_type& p);
167
168 param_type param_;
169 };
170
171 template <typename IntType>
172 template <typename URBG>
173 typename log_uniform_int_distribution<IntType>::unsigned_type
Generate(URBG & g,const param_type & p)174 log_uniform_int_distribution<IntType>::Generate(
175 URBG& g, // NOLINT(runtime/references)
176 const param_type& p) {
177 // sample e over [0, log_range]. Map the results of e to this:
178 // 0 => 0
179 // 1 => [1, b-1]
180 // 2 => [b, (b^2)-1]
181 // n => [b^(n-1)..(b^n)-1]
182 const int e = absl::uniform_int_distribution<int>(0, p.log_range())(g);
183 if (e == 0) {
184 return 0;
185 }
186 const int d = e - 1;
187
188 unsigned_type base_e, top_e;
189 if (p.base() == 2) {
190 base_e = static_cast<unsigned_type>(1) << d;
191
192 top_e = (e >= std::numeric_limits<unsigned_type>::digits)
193 ? (std::numeric_limits<unsigned_type>::max)()
194 : (static_cast<unsigned_type>(1) << e) - 1;
195 } else {
196 const double r = std::pow(p.base(), d);
197 const double s = (r * p.base()) - 1.0;
198
199 base_e =
200 (r > static_cast<double>((std::numeric_limits<unsigned_type>::max)()))
201 ? (std::numeric_limits<unsigned_type>::max)()
202 : static_cast<unsigned_type>(r);
203
204 top_e =
205 (s > static_cast<double>((std::numeric_limits<unsigned_type>::max)()))
206 ? (std::numeric_limits<unsigned_type>::max)()
207 : static_cast<unsigned_type>(s);
208 }
209
210 const unsigned_type lo = (base_e >= p.range()) ? p.range() : base_e;
211 const unsigned_type hi = (top_e >= p.range()) ? p.range() : top_e;
212
213 // choose uniformly over [lo, hi]
214 return absl::uniform_int_distribution<result_type>(lo, hi)(g);
215 }
216
217 template <typename CharT, typename Traits, typename IntType>
218 std::basic_ostream<CharT, Traits>& operator<<(
219 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references)
220 const log_uniform_int_distribution<IntType>& x) {
221 using stream_type =
222 typename random_internal::stream_format_type<IntType>::type;
223 auto saver = random_internal::make_ostream_state_saver(os);
224 os << static_cast<stream_type>((x.min)()) << os.fill()
225 << static_cast<stream_type>((x.max)()) << os.fill()
226 << static_cast<stream_type>(x.base());
227 return os;
228 }
229
230 template <typename CharT, typename Traits, typename IntType>
231 std::basic_istream<CharT, Traits>& operator>>(
232 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references)
233 log_uniform_int_distribution<IntType>& x) { // NOLINT(runtime/references)
234 using param_type = typename log_uniform_int_distribution<IntType>::param_type;
235 using result_type =
236 typename log_uniform_int_distribution<IntType>::result_type;
237 using stream_type =
238 typename random_internal::stream_format_type<IntType>::type;
239
240 stream_type min;
241 stream_type max;
242 stream_type base;
243
244 auto saver = random_internal::make_istream_state_saver(is);
245 is >> min >> max >> base;
246 if (!is.fail()) {
247 x.param(param_type(static_cast<result_type>(min),
248 static_cast<result_type>(max),
249 static_cast<result_type>(base)));
250 }
251 return is;
252 }
253
254 ABSL_NAMESPACE_END
255 } // namespace absl
256
257 #endif // ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
258