• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* test_discrete.cpp
 *
 * Copyright Steven Watanabe 2010
 * Distributed under the Boost Software License, Version 1.0. (See
 * accompanying file LICENSE_1_0.txt or copy at
 * http://www.boost.org/LICENSE_1_0.txt)
 *
 * $Id$
 *
 * Chi-squared goodness-of-fit test for boost::random::discrete_distribution.
 */
11 
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <vector>

#include <boost/exception/diagnostic_information.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/random/discrete_distribution.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_int.hpp>

#include "chi_squared_test.hpp"
22 
do_test(int n,long long max)23 bool do_test(int n, long long max) {
24     std::cout << "running discrete(p0, p1, ..., p" << n-1 << ")" << " " << max << " times: " << std::flush;
25 
26     std::vector<double> expected;
27     {
28         boost::mt19937 egen;
29         for(int i = 0; i < n; ++i) {
30             expected.push_back(egen());
31         }
32         double sum = std::accumulate(expected.begin(), expected.end(), 0.0);
33         for(std::vector<double>::iterator iter = expected.begin(), end = expected.end(); iter != end; ++iter) {
34             *iter /= sum;
35         }
36     }
37 
38     boost::random::discrete_distribution<> dist(expected);
39     boost::mt19937 gen;
40     std::vector<long long> results(expected.size());
41     for(long long i = 0; i < max; ++i) {
42         ++results[dist(gen)];
43     }
44 
45     long long sum = std::accumulate(results.begin(), results.end(), 0ll);
46     if(sum != max) {
47         std::cout << "*** Failed: incorrect total: " << sum << " ***" << std::endl;
48         return false;
49     }
50     double chsqr = chi_squared_test(results, expected, max);
51 
52     bool result = chsqr < 0.99;
53     const char* err = result? "" : "*";
54     std::cout << std::setprecision(17) << chsqr << err << std::endl;
55 
56     std::cout << std::setprecision(6);
57 
58     return result;
59 }
60 
do_tests(int repeat,int max_n,long long trials)61 bool do_tests(int repeat, int max_n, long long trials) {
62     boost::mt19937 gen;
63     boost::uniform_int<> idist(1, max_n);
64     int errors = 0;
65     for(int i = 0; i < repeat; ++i) {
66         if(!do_test(idist(gen), trials)) {
67             ++errors;
68         }
69     }
70     if(errors != 0) {
71         std::cout << "*** " << errors << " errors detected ***" << std::endl;
72     }
73     return errors == 0;
74 }
75 
/**
 * Prints the command-line synopsis to std::cerr.
 *
 * @return 2, the exit status main() uses for bad command lines
 */
int usage() {
    static const char synopsis[] =
        "Usage: test_discrete -r <repeat> -n <max n> -t <trials>";
    std::cerr << synopsis << std::endl;
    return 2;
}
80 
81 template<class T>
handle_option(int & argc,char ** & argv,char opt,T & value)82 bool handle_option(int& argc, char**& argv, char opt, T& value) {
83     if(argv[0][1] == opt && argc > 1) {
84         --argc;
85         ++argv;
86         value = boost::lexical_cast<T>(argv[0]);
87         return true;
88     } else {
89         return false;
90     }
91 }
92 
main(int argc,char ** argv)93 int main(int argc, char** argv) {
94     int repeat = 10;
95     int max_n = 10000;
96     long long trials = 1000000ll;
97 
98     if(argc > 0) {
99         --argc;
100         ++argv;
101     }
102     while(argc > 0) {
103         if(argv[0][0] != '-') return usage();
104         else if(!handle_option(argc, argv, 'r', repeat)
105              && !handle_option(argc, argv, 'n', max_n)
106              && !handle_option(argc, argv, 't', trials)) {
107             return usage();
108         }
109         --argc;
110         ++argv;
111     }
112 
113     try {
114         if(do_tests(repeat, max_n, trials)) {
115             return 0;
116         } else {
117             return EXIT_FAILURE;
118         }
119     } catch(...) {
120         std::cerr << boost::current_exception_diagnostic_information() << std::endl;
121         return EXIT_FAILURE;
122     }
123 }
124