//===----------------------------------------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is dual licensed under the MIT and the University of Illinois Open
// Source Licenses. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// REQUIRES: long_tests

// <random>

// template<class RealType = double>
// class gamma_distribution

// template<class _URNG> result_type operator()(_URNG& g);

#include <random>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>
#include <numeric>

template <class T>
inline
T
sqr(T x)
{
    return x * x;
}

// Draws 10^6 samples from gamma_distribution<>(alpha, beta) using a
// default-constructed mt19937 and asserts that the sample mean, variance,
// skewness and excess kurtosis are each within 1% (relative) of the
// closed-form values computed from d.alpha() and d.beta():
//   mean = alpha*beta, var = alpha*beta^2,
//   skew = 2/sqrt(alpha), excess kurtosis = 6/alpha.
static void
test_moments(double alpha, double beta)
{
    typedef std::gamma_distribution<> D;
    typedef std::mt19937 G;
    G g;  // default seed: every call starts from the same engine state
    D d(alpha, beta);
    const int N = 1000000;
    std::vector<D::result_type> u;
    u.reserve(N);  // one allocation instead of repeated regrowth
    for (int i = 0; i < N; ++i)
    {
        D::result_type v = d(g);
        assert(d.min() < v);  // every sample must exceed d.min()
        u.push_back(v);
    }
    double mean = std::accumulate(u.begin(), u.end(), 0.0) / u.size();

    // Second pass: accumulate the 2nd, 3rd and 4th central moments together.
    double var = 0;
    double skew = 0;
    double kurtosis = 0;
    for (std::size_t i = 0; i < u.size(); ++i)
    {
        double dbl = (u[i] - mean);
        double d2 = sqr(dbl);
        var += d2;
        skew += dbl * d2;
        kurtosis += d2 * d2;
    }
    var /= u.size();
    double dev = std::sqrt(var);
    skew /= u.size() * dev * var;
    kurtosis /= u.size() * var * var;
    kurtosis -= 3;  // convert to excess kurtosis

    double x_mean = d.alpha() * d.beta();
    double x_var = d.alpha() * sqr(d.beta());
    double x_skew = 2 / std::sqrt(d.alpha());
    double x_kurtosis = 6 / d.alpha();
    assert(std::abs((mean - x_mean) / x_mean) < 0.01);
    assert(std::abs((var - x_var) / x_var) < 0.01);
    assert(std::abs((skew - x_skew) / x_skew) < 0.01);
    assert(std::abs((kurtosis - x_kurtosis) / x_kurtosis) < 0.01);
}

int main()
{
    test_moments(0.5, 2);
    test_moments(1, .5);
    test_moments(2, 3);
}