• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2017 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
#include "absl/random/bernoulli_distribution.h"

#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <utility>

#include "gtest/gtest.h"
#include "absl/random/internal/pcg_engine.h"
#include "absl/random/internal/sequence_urbg.h"
#include "absl/random/random.h"
27 
28 namespace {
29 
30 class BernoulliTest : public testing::TestWithParam<std::pair<double, size_t>> {
31 };
32 
TEST_P(BernoulliTest, Serialize) {
  // Checks construction from param_type, then round-trips the distribution
  // through a stream and verifies the restored copy compares equal.
  const double prob = GetParam().first;
  absl::bernoulli_distribution original(prob);

  // Building from an explicit param_type must match building from the raw
  // probability.
  {
    const absl::bernoulli_distribution::param_type params(prob);
    absl::bernoulli_distribution from_params{params};
    EXPECT_EQ(from_params, original);
  }

  std::stringstream stream;
  stream << original;
  absl::bernoulli_distribution restored(0.6789);

  // Before reading back, every observable view must differ.
  EXPECT_NE(original.p(), restored.p());
  EXPECT_NE(original.param(), restored.param());
  EXPECT_NE(original, restored);

  stream >> restored;

  // After reading back, every observable view must agree.
  EXPECT_EQ(original.p(), restored.p());
  EXPECT_EQ(original.param(), restored.param());
  EXPECT_EQ(original, restored);
}
57 
TEST_P(BernoulliTest, Accuracy) {
  // Sadly, the claim to fame for this implementation is precise accuracy, which
  // is very, very hard to measure: the improvements only show up as trials
  // approach the limit of double precision, so the outcome differs from
  // std::bernoulli_distribution with a probability of approximately 1 in 2^53.
  const std::pair<double, size_t> para = GetParam();
  const size_t trials = para.second;
  const double p = para.first;

  // We use a fixed bit generator for distribution accuracy tests.  This allows
  // these tests to be deterministic, while still testing the quality of the
  // implementation.
  absl::random_internal::pcg64_2018_engine rng(0x2B7E151628AED2A6);

  size_t yes = 0;
  absl::bernoulli_distribution dist(p);
  for (size_t i = 0; i < trials; ++i) {
    if (dist(rng)) yes++;
  }

  // Compute the distribution parameters for a binomial test, using a normal
  // approximation for the confidence interval.
  //
  // Corner cases:
  // - For the large trial counts, the central limit theorem makes this
  //   approximation reasonable.
  // - When p is exactly 0 or 1, stddev is 0, so the check below requires
  //   `yes` to equal `expected` exactly.
  // - For the single-trial boundary cases, the normal approximation does not
  //   really hold. There the 5-sigma window is far narrower than 1, so the
  //   check effectively requires the one draw to land on the more likely
  //   outcome.
  const double stddev_p = std::sqrt((p * (1.0 - p)) / trials);
  const double expected = trials * p;
  const double stddev = trials * stddev_p;

  // 5 sigma, approved by Richard Feynman
  EXPECT_NEAR(yes, expected, 5 * stddev)
      << "@" << p << ", "
      << std::abs(static_cast<double>(yes) - expected) / stddev << " stddev";
}
90 
91 // There must be many more trials to make the mean approximately normal for `p`
92 // closes to 0 or 1.
93 INSTANTIATE_TEST_SUITE_P(
94     All, BernoulliTest,
95     ::testing::Values(
96         // Typical values.
97         std::make_pair(0, 30000), std::make_pair(1e-3, 30000000),
98         std::make_pair(0.1, 3000000), std::make_pair(0.5, 3000000),
99         std::make_pair(0.9, 30000000), std::make_pair(0.999, 30000000),
100         std::make_pair(1, 30000),
101         // Boundary cases.
102         std::make_pair(std::nextafter(1.0, 0.0), 1),  // ~1 - epsilon
103         std::make_pair(std::numeric_limits<double>::epsilon(), 1),
104         std::make_pair(std::nextafter(std::numeric_limits<double>::min(),
105                                       1.0),  // min + epsilon
106                        1),
107         std::make_pair(std::numeric_limits<double>::min(),  // smallest normal
108                        1),
109         std::make_pair(
110             std::numeric_limits<double>::denorm_min(),  // smallest denorm
111             1),
112         std::make_pair(std::numeric_limits<double>::min() / 2, 1),  // denorm
113         std::make_pair(std::nextafter(std::numeric_limits<double>::min(),
114                                       0.0),  // denorm_max
115                        1)));
116 
117 // NOTE: absl::bernoulli_distribution is not guaranteed to be stable.
TEST(BernoulliTest,StabilityTest)118 TEST(BernoulliTest, StabilityTest) {
119   // absl::bernoulli_distribution stability relies on FastUniformBits and
120   // integer arithmetic.
121   absl::random_internal::sequence_urbg urbg({
122       0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull,
123       0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull,
124       0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull,
125       0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull,
126       0x4864f22c059bf29eull, 0x247856d8b862665cull, 0xe46e86e9a1337e10ull,
127       0xd8c8541f3519b133ull, 0xe75b5162c567b9e4ull, 0xf732e5ded7009c5bull,
128       0xb170b98353121eacull, 0x1ec2e8986d2362caull, 0x814c8e35fe9a961aull,
129       0x0c3cd59c9b638a02ull, 0xcb3bb6478a07715cull, 0x1224e62c978bbc7full,
130       0x671ef2cb04e81f6eull, 0x3c1cbd811eaf1808ull, 0x1bbc23cfa8fac721ull,
131       0xa4c2cda65e596a51ull, 0xb77216fad37adf91ull, 0x836d794457c08849ull,
132       0xe083df03475f49d7ull, 0xbc9feb512e6b0d6cull, 0xb12d74fdd718c8c5ull,
133       0x12ff09653bfbe4caull, 0x8dd03a105bc4ee7eull, 0x5738341045ba0d85ull,
134       0xe3fd722dc65ad09eull, 0x5a14fd21ea2a5705ull, 0x14e6ea4d6edb0c73ull,
135       0x275b0dc7e0a18acfull, 0x36cebe0d2653682eull, 0x0361e9b23861596bull,
136   });
137 
138   // Generate a string of '0' and '1' for the distribution output.
139   auto generate = [&urbg](absl::bernoulli_distribution& dist) {
140     std::string output;
141     output.reserve(36);
142     urbg.reset();
143     for (int i = 0; i < 35; i++) {
144       output.append(dist(urbg) ? "1" : "0");
145     }
146     return output;
147   };
148 
149   const double kP = 0.0331289862362;
150   {
151     absl::bernoulli_distribution dist(kP);
152     auto v = generate(dist);
153     EXPECT_EQ(35, urbg.invocations());
154     EXPECT_EQ(v, "00000000000010000000000010000000000") << dist;
155   }
156   {
157     absl::bernoulli_distribution dist(kP * 10.0);
158     auto v = generate(dist);
159     EXPECT_EQ(35, urbg.invocations());
160     EXPECT_EQ(v, "00000100010010010010000011000011010") << dist;
161   }
162   {
163     absl::bernoulli_distribution dist(kP * 20.0);
164     auto v = generate(dist);
165     EXPECT_EQ(35, urbg.invocations());
166     EXPECT_EQ(v, "00011110010110110011011111110111011") << dist;
167   }
168   {
169     absl::bernoulli_distribution dist(1.0 - kP);
170     auto v = generate(dist);
171     EXPECT_EQ(35, urbg.invocations());
172     EXPECT_EQ(v, "11111111111111111111011111111111111") << dist;
173   }
174 }
175 
TEST(BernoulliTest, StabilityTest2) {
  absl::random_internal::sequence_urbg urbg(
      {0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull, 0xC332DDEFBE6C5AA5ull,
       0x6558218568AB9702ull, 0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull,
       0xECDD4775619F1510ull, 0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull,
       0xB5735C904C70A239ull, 0xD59E9E0BCBAADE14ull, 0xEECC86BC60622CA7ull});

  // Resets `urbg`, then records 12 draws of `dist` as a string of '0'/'1'.
  auto generate = [&urbg](absl::bernoulli_distribution& dist) {
    constexpr int kDraws = 12;
    std::string bits;
    bits.reserve(kDraws + 1);
    urbg.reset();
    for (int draw = 0; draw < kDraws; ++draw) {
      bits.push_back(dist(urbg) ? '1' : '0');
    }
    return bits;
  };

  // Probabilities are kept as the exact same expressions so that the
  // resulting doubles, and therefore the golden strings, are unchanged.
  constexpr double b0 = 1.0 / 13.0 / 0.2;
  constexpr double b1 = 2.0 / 13.0 / 0.2;
  constexpr double b3 = (5.0 / 13.0 / 0.2) - ((1 - b0) + (1 - b1) + (1 - b1));

  struct GoldenCase {
    double p;
    const char* expected;
  };
  const GoldenCase kCases[] = {
      {b0, "000011100101"},
      {b1, "001111101101"},
      {b3, "001111101111"},
  };

  for (const GoldenCase& golden : kCases) {
    absl::bernoulli_distribution dist(golden.p);
    auto v = generate(dist);
    // 12 draws are expected to consume exactly 12 URBG invocations.
    EXPECT_EQ(12, urbg.invocations());
    EXPECT_EQ(v, golden.expected) << dist;
  }
}
216 
217 }  // namespace
218