• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* boost random/binomial_distribution.hpp header file
2  *
3  * Copyright Steven Watanabe 2010
4  * Distributed under the Boost Software License, Version 1.0. (See
5  * accompanying file LICENSE_1_0.txt or copy at
6  * http://www.boost.org/LICENSE_1_0.txt)
7  *
8  * See http://www.boost.org for most recent version including documentation.
9  *
10  * $Id$
11  */
12 
13 #ifndef BOOST_RANDOM_BINOMIAL_DISTRIBUTION_HPP_INCLUDED
14 #define BOOST_RANDOM_BINOMIAL_DISTRIBUTION_HPP_INCLUDED
15 
#include <boost/config/no_tr1/cmath.hpp>
#include <cstdlib>
#include <iosfwd>
#include <limits>
19 
20 #include <boost/random/detail/config.hpp>
21 #include <boost/random/uniform_01.hpp>
22 
23 #include <boost/random/detail/disable_warnings.hpp>
24 
25 namespace boost {
26 namespace random {
27 
namespace detail {

/*
 * Precomputed correction terms for the Stirling approximation of log(k!),
 * consumed by binomial_distribution::fc for small k:
 *
 *   table[k] = log(k!) - ((k + 1/2) * log(k + 1) - (k + 1) + log(sqrt(2*pi)))
 *
 * for k = 0, ..., 9 (e.g. table[0] = 1 - log(sqrt(2*pi))).  Larger k are
 * handled by the asymptotic series in fc, so only ten entries are stored.
 * The literals are double precision; they are converted to RealType on
 * initialization.
 */
template<typename RealType>
struct binomial_table {
    static const RealType table[10];
};

template<typename RealType>
const RealType binomial_table<RealType>::table[10] = {
    0.08106146679532726,    // k = 0
    0.04134069595540929,    // k = 1
    0.02767792568499834,    // k = 2
    0.02079067210376509,    // k = 3
    0.01664469118982119,    // k = 4
    0.01387612882307075,    // k = 5
    0.01189670994589177,    // k = 6
    0.01041126526197209,    // k = 7
    0.009255462182712733,   // k = 8
    0.008330563433362871    // k = 9
};

} // namespace detail
50 
51 /**
52  * The binomial distribution is an integer valued distribution with
53  * two parameters, @c t and @c p.  The values of the distribution
54  * are within the range [0,t].
55  *
56  * The distribution function is
57  * \f$\displaystyle P(k) = {t \choose k}p^k(1-p)^{t-k}\f$.
58  *
59  * The algorithm used is the BTRD algorithm described in
60  *
61  *  @blockquote
62  *  "The generation of binomial random variates", Wolfgang Hormann,
63  *  Journal of Statistical Computation and Simulation, Volume 46,
64  *  Issue 1 & 2 April 1993 , pages 101 - 110
65  *  @endblockquote
66  */
67 template<class IntType = int, class RealType = double>
68 class binomial_distribution {
69 public:
70     typedef IntType result_type;
71     typedef RealType input_type;
72 
73     class param_type {
74     public:
75         typedef binomial_distribution distribution_type;
76         /**
77          * Construct a param_type object.  @c t and @c p
78          * are the parameters of the distribution.
79          *
80          * Requires: t >=0 && 0 <= p <= 1
81          */
param_type(IntType t_arg=1,RealType p_arg=RealType (0.5))82         explicit param_type(IntType t_arg = 1, RealType p_arg = RealType (0.5))
83           : _t(t_arg), _p(p_arg)
84         {}
85         /** Returns the @c t parameter of the distribution. */
t() const86         IntType t() const { return _t; }
87         /** Returns the @c p parameter of the distribution. */
p() const88         RealType p() const { return _p; }
89 #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS
90         /** Writes the parameters of the distribution to a @c std::ostream. */
91         template<class CharT, class Traits>
92         friend std::basic_ostream<CharT,Traits>&
operator <<(std::basic_ostream<CharT,Traits> & os,const param_type & parm)93         operator<<(std::basic_ostream<CharT,Traits>& os,
94                    const param_type& parm)
95         {
96             os << parm._p << " " << parm._t;
97             return os;
98         }
99 
100         /** Reads the parameters of the distribution from a @c std::istream. */
101         template<class CharT, class Traits>
102         friend std::basic_istream<CharT,Traits>&
operator >>(std::basic_istream<CharT,Traits> & is,param_type & parm)103         operator>>(std::basic_istream<CharT,Traits>& is, param_type& parm)
104         {
105             is >> parm._p >> std::ws >> parm._t;
106             return is;
107         }
108 #endif
109         /** Returns true if the parameters have the same values. */
operator ==(const param_type & lhs,const param_type & rhs)110         friend bool operator==(const param_type& lhs, const param_type& rhs)
111         {
112             return lhs._t == rhs._t && lhs._p == rhs._p;
113         }
114         /** Returns true if the parameters have different values. */
operator !=(const param_type & lhs,const param_type & rhs)115         friend bool operator!=(const param_type& lhs, const param_type& rhs)
116         {
117             return !(lhs == rhs);
118         }
119     private:
120         IntType _t;
121         RealType _p;
122     };
123 
124     /**
125      * Construct a @c binomial_distribution object. @c t and @c p
126      * are the parameters of the distribution.
127      *
128      * Requires: t >=0 && 0 <= p <= 1
129      */
binomial_distribution(IntType t_arg=1,RealType p_arg=RealType (0.5))130     explicit binomial_distribution(IntType t_arg = 1,
131                                    RealType p_arg = RealType(0.5))
132       : _t(t_arg), _p(p_arg)
133     {
134         init();
135     }
136 
137     /**
138      * Construct an @c binomial_distribution object from the
139      * parameters.
140      */
binomial_distribution(const param_type & parm)141     explicit binomial_distribution(const param_type& parm)
142       : _t(parm.t()), _p(parm.p())
143     {
144         init();
145     }
146 
147     /**
148      * Returns a random variate distributed according to the
149      * binomial distribution.
150      */
151     template<class URNG>
operator ()(URNG & urng) const152     IntType operator()(URNG& urng) const
153     {
154         if(use_inversion()) {
155             if(0.5 < _p) {
156                 return _t - invert(_t, 1-_p, urng);
157             } else {
158                 return invert(_t, _p, urng);
159             }
160         } else if(0.5 < _p) {
161             return _t - generate(urng);
162         } else {
163             return generate(urng);
164         }
165     }
166 
167     /**
168      * Returns a random variate distributed according to the
169      * binomial distribution with parameters specified by @c param.
170      */
171     template<class URNG>
operator ()(URNG & urng,const param_type & parm) const172     IntType operator()(URNG& urng, const param_type& parm) const
173     {
174         return binomial_distribution(parm)(urng);
175     }
176 
177     /** Returns the @c t parameter of the distribution. */
t() const178     IntType t() const { return _t; }
179     /** Returns the @c p parameter of the distribution. */
p() const180     RealType p() const { return _p; }
181 
182     /** Returns the smallest value that the distribution can produce. */
BOOST_PREVENT_MACRO_SUBSTITUTION() const183     IntType min BOOST_PREVENT_MACRO_SUBSTITUTION() const { return 0; }
184     /** Returns the largest value that the distribution can produce. */
BOOST_PREVENT_MACRO_SUBSTITUTION() const185     IntType max BOOST_PREVENT_MACRO_SUBSTITUTION() const { return _t; }
186 
187     /** Returns the parameters of the distribution. */
param() const188     param_type param() const { return param_type(_t, _p); }
189     /** Sets parameters of the distribution. */
param(const param_type & parm)190     void param(const param_type& parm)
191     {
192         _t = parm.t();
193         _p = parm.p();
194         init();
195     }
196 
197     /**
198      * Effects: Subsequent uses of the distribution do not depend
199      * on values produced by any engine prior to invoking reset.
200      */
reset()201     void reset() { }
202 
203 #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS
204     /** Writes the parameters of the distribution to a @c std::ostream. */
205     template<class CharT, class Traits>
206     friend std::basic_ostream<CharT,Traits>&
operator <<(std::basic_ostream<CharT,Traits> & os,const binomial_distribution & bd)207     operator<<(std::basic_ostream<CharT,Traits>& os,
208                const binomial_distribution& bd)
209     {
210         os << bd.param();
211         return os;
212     }
213 
214     /** Reads the parameters of the distribution from a @c std::istream. */
215     template<class CharT, class Traits>
216     friend std::basic_istream<CharT,Traits>&
operator >>(std::basic_istream<CharT,Traits> & is,binomial_distribution & bd)217     operator>>(std::basic_istream<CharT,Traits>& is, binomial_distribution& bd)
218     {
219         bd.read(is);
220         return is;
221     }
222 #endif
223 
224     /** Returns true if the two distributions will produce the same
225         sequence of values, given equal generators. */
operator ==(const binomial_distribution & lhs,const binomial_distribution & rhs)226     friend bool operator==(const binomial_distribution& lhs,
227                            const binomial_distribution& rhs)
228     {
229         return lhs._t == rhs._t && lhs._p == rhs._p;
230     }
231     /** Returns true if the two distributions could produce different
232         sequences of values, given equal generators. */
operator !=(const binomial_distribution & lhs,const binomial_distribution & rhs)233     friend bool operator!=(const binomial_distribution& lhs,
234                            const binomial_distribution& rhs)
235     {
236         return !(lhs == rhs);
237     }
238 
239 private:
240 
241     /// @cond show_private
242 
243     template<class CharT, class Traits>
read(std::basic_istream<CharT,Traits> & is)244     void read(std::basic_istream<CharT, Traits>& is) {
245         param_type parm;
246         if(is >> parm) {
247             param(parm);
248         }
249     }
250 
use_inversion() const251     bool use_inversion() const
252     {
253         // BTRD is safe when np >= 10
254         return m < 11;
255     }
256 
257     // computes the correction factor for the Stirling approximation
258     // for log(k!)
fc(IntType k)259     static RealType fc(IntType k)
260     {
261         if(k < 10) return detail::binomial_table<RealType>::table[k];
262         else {
263             RealType ikp1 = RealType(1) / (k + 1);
264             return (RealType(1)/12
265                  - (RealType(1)/360
266                  - (RealType(1)/1260)*(ikp1*ikp1))*(ikp1*ikp1))*ikp1;
267         }
268     }
269 
init()270     void init()
271     {
272         using std::sqrt;
273         using std::pow;
274 
275         RealType p = (0.5 < _p)? (1 - _p) : _p;
276         IntType t = _t;
277 
278         m = static_cast<IntType>((t+1)*p);
279 
280         if(use_inversion()) {
281             _u.q_n = pow((1 - p), static_cast<RealType>(t));
282         } else {
283             _u.btrd.r = p/(1-p);
284             _u.btrd.nr = (t+1)*_u.btrd.r;
285             _u.btrd.npq = t*p*(1-p);
286             RealType sqrt_npq = sqrt(_u.btrd.npq);
287             _u.btrd.b = 1.15 + 2.53 * sqrt_npq;
288             _u.btrd.a = -0.0873 + 0.0248*_u.btrd.b + 0.01*p;
289             _u.btrd.c = t*p + 0.5;
290             _u.btrd.alpha = (2.83 + 5.1/_u.btrd.b) * sqrt_npq;
291             _u.btrd.v_r = 0.92 - 4.2/_u.btrd.b;
292             _u.btrd.u_rv_r = 0.86*_u.btrd.v_r;
293         }
294     }
295 
296     template<class URNG>
generate(URNG & urng) const297     result_type generate(URNG& urng) const
298     {
299         using std::floor;
300         using std::abs;
301         using std::log;
302 
303         while(true) {
304             RealType u;
305             RealType v = uniform_01<RealType>()(urng);
306             if(v <= _u.btrd.u_rv_r) {
307                 u = v/_u.btrd.v_r - 0.43;
308                 return static_cast<IntType>(floor(
309                     (2*_u.btrd.a/(0.5 - abs(u)) + _u.btrd.b)*u + _u.btrd.c));
310             }
311 
312             if(v >= _u.btrd.v_r) {
313                 u = uniform_01<RealType>()(urng) - 0.5;
314             } else {
315                 u = v/_u.btrd.v_r - 0.93;
316                 u = ((u < 0)? -0.5 : 0.5) - u;
317                 v = uniform_01<RealType>()(urng) * _u.btrd.v_r;
318             }
319 
320             RealType us = 0.5 - abs(u);
321             IntType k = static_cast<IntType>(floor((2*_u.btrd.a/us + _u.btrd.b)*u + _u.btrd.c));
322             if(k < 0 || k > _t) continue;
323             v = v*_u.btrd.alpha/(_u.btrd.a/(us*us) + _u.btrd.b);
324             RealType km = abs(k - m);
325             if(km <= 15) {
326                 RealType f = 1;
327                 if(m < k) {
328                     IntType i = m;
329                     do {
330                         ++i;
331                         f = f*(_u.btrd.nr/i - _u.btrd.r);
332                     } while(i != k);
333                 } else if(m > k) {
334                     IntType i = k;
335                     do {
336                         ++i;
337                         v = v*(_u.btrd.nr/i - _u.btrd.r);
338                     } while(i != m);
339                 }
340                 if(v <= f) return k;
341                 else continue;
342             } else {
343                 // final acceptance/rejection
344                 v = log(v);
345                 RealType rho =
346                     (km/_u.btrd.npq)*(((km/3. + 0.625)*km + 1./6)/_u.btrd.npq + 0.5);
347                 RealType t = -km*km/(2*_u.btrd.npq);
348                 if(v < t - rho) return k;
349                 if(v > t + rho) continue;
350 
351                 IntType nm = _t - m + 1;
352                 RealType h = (m + 0.5)*log((m + 1)/(_u.btrd.r*nm))
353                            + fc(m) + fc(_t - m);
354 
355                 IntType nk = _t - k + 1;
356                 if(v <= h + (_t+1)*log(static_cast<RealType>(nm)/nk)
357                           + (k + 0.5)*log(nk*_u.btrd.r/(k+1))
358                           - fc(k)
359                           - fc(_t - k))
360                 {
361                     return k;
362                 } else {
363                     continue;
364                 }
365             }
366         }
367     }
368 
369     template<class URNG>
invert(IntType t,RealType p,URNG & urng) const370     IntType invert(IntType t, RealType p, URNG& urng) const
371     {
372         RealType q = 1 - p;
373         RealType s = p / q;
374         RealType a = (t + 1) * s;
375         RealType r = _u.q_n;
376         RealType u = uniform_01<RealType>()(urng);
377         IntType x = 0;
378         while(u > r) {
379             u = u - r;
380             ++x;
381             RealType r1 = ((a/x) - s) * r;
382             // If r gets too small then the round-off error
383             // becomes a problem.  At this point, p(i) is
384             // decreasing exponentially, so if we just call
385             // it 0, it's close enough.  Note that the
386             // minimum value of q_n is about 1e-7, so we
387             // may need to be a little careful to make sure that
388             // we don't terminate the first time through the loop
389             // for float.  (Hence the test that r is decreasing)
390             if(r1 < std::numeric_limits<RealType>::epsilon() && r1 < r) {
391                 break;
392             }
393             r = r1;
394         }
395         return x;
396     }
397 
398     // parameters
399     IntType _t;
400     RealType _p;
401 
402     // common data
403     IntType m;
404 
405     union {
406         // for btrd
407         struct {
408             RealType r;
409             RealType nr;
410             RealType npq;
411             RealType b;
412             RealType a;
413             RealType c;
414             RealType alpha;
415             RealType v_r;
416             RealType u_rv_r;
417         } btrd;
418         // for inversion
419         RealType q_n;
420     } _u;
421 
422     /// @endcond
423 };
424 
425 }
426 
427 // backwards compatibility
428 using random::binomial_distribution;
429 
430 }
431 
432 #include <boost/random/detail/enable_warnings.hpp>
433 
434 #endif
435