1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package rand_test 6 7import ( 8 "errors" 9 "fmt" 10 "internal/testenv" 11 "math" 12 . "math/rand/v2" 13 "os" 14 "runtime" 15 "sync" 16 "sync/atomic" 17 "testing" 18) 19 20const ( 21 numTestSamples = 10000 22) 23 24var rn, kn, wn, fn = GetNormalDistributionParameters() 25var re, ke, we, fe = GetExponentialDistributionParameters() 26 27type statsResults struct { 28 mean float64 29 stddev float64 30 closeEnough float64 31 maxError float64 32} 33 34func nearEqual(a, b, closeEnough, maxError float64) bool { 35 absDiff := math.Abs(a - b) 36 if absDiff < closeEnough { // Necessary when one value is zero and one value is close to zero. 37 return true 38 } 39 return absDiff/max(math.Abs(a), math.Abs(b)) < maxError 40} 41 42var testSeeds = []uint64{1, 1754801282, 1698661970, 1550503961} 43 44// checkSimilarDistribution returns success if the mean and stddev of the 45// two statsResults are similar. 46func (sr *statsResults) checkSimilarDistribution(expected *statsResults) error { 47 if !nearEqual(sr.mean, expected.mean, expected.closeEnough, expected.maxError) { 48 s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", sr.mean, expected.mean, expected.closeEnough, expected.maxError) 49 fmt.Println(s) 50 return errors.New(s) 51 } 52 if !nearEqual(sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) { 53 s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) 54 fmt.Println(s) 55 return errors.New(s) 56 } 57 return nil 58} 59 60func getStatsResults(samples []float64) *statsResults { 61 res := new(statsResults) 62 var sum, squaresum float64 63 for _, s := range samples { 64 sum += s 65 squaresum += s * s 66 } 67 res.mean = sum / float64(len(samples)) 68 res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean) 69 return res 70} 71 72func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) { 73 t.Helper() 74 actual := getStatsResults(samples) 75 err := actual.checkSimilarDistribution(expected) 76 if err != nil { 77 t.Error(err) 78 } 79} 80 81func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) { 82 t.Helper() 83 chunk := len(samples) / nslices 84 for i := 0; i < nslices; i++ { 85 low := i * chunk 86 var high int 87 if i == nslices-1 { 88 high = len(samples) - 1 89 } else { 90 high = (i + 1) * chunk 91 } 92 checkSampleDistribution(t, samples[low:high], expected) 93 } 94} 95 96// 97// Normal distribution tests 98// 99 100func generateNormalSamples(nsamples int, mean, stddev float64, seed uint64) []float64 { 101 r := New(NewPCG(seed, seed)) 102 samples := make([]float64, nsamples) 103 for i := range samples { 104 samples[i] = r.NormFloat64()*stddev + mean 105 } 106 return samples 107} 108 109func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed uint64) { 110 //fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed); 111 112 samples := generateNormalSamples(nsamples, mean, stddev, seed) 113 errorScale := max(1.0, stddev) // Error scales with stddev 114 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 115 116 // Make sure that the entire set matches the expected distribution. 117 checkSampleDistribution(t, samples, expected) 118 119 // Make sure that each half of the set matches the expected distribution. 120 checkSampleSliceDistributions(t, samples, 2, expected) 121 122 // Make sure that each 7th of the set matches the expected distribution. 123 checkSampleSliceDistributions(t, samples, 7, expected) 124} 125 126// Actual tests 127 128func TestStandardNormalValues(t *testing.T) { 129 for _, seed := range testSeeds { 130 testNormalDistribution(t, numTestSamples, 0, 1, seed) 131 } 132} 133 134func TestNonStandardNormalValues(t *testing.T) { 135 sdmax := 1000.0 136 mmax := 1000.0 137 if testing.Short() { 138 sdmax = 5 139 mmax = 5 140 } 141 for sd := 0.5; sd < sdmax; sd *= 2 { 142 for m := 0.5; m < mmax; m *= 2 { 143 for _, seed := range testSeeds { 144 testNormalDistribution(t, numTestSamples, m, sd, seed) 145 if testing.Short() { 146 break 147 } 148 } 149 } 150 } 151} 152 153// 154// Exponential distribution tests 155// 156 157func generateExponentialSamples(nsamples int, rate float64, seed uint64) []float64 { 158 r := New(NewPCG(seed, seed)) 159 samples := make([]float64, nsamples) 160 for i := range samples { 161 samples[i] = r.ExpFloat64() / rate 162 } 163 return samples 164} 165 166func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed uint64) { 167 //fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed); 168 169 mean := 1 / rate 170 stddev := mean 171 172 samples := generateExponentialSamples(nsamples, rate, seed) 173 errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate 174 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale} 175 176 // Make sure that the entire set matches the expected distribution. 177 checkSampleDistribution(t, samples, expected) 178 179 // Make sure that each half of the set matches the expected distribution. 180 checkSampleSliceDistributions(t, samples, 2, expected) 181 182 // Make sure that each 7th of the set matches the expected distribution. 183 checkSampleSliceDistributions(t, samples, 7, expected) 184} 185 186// Actual tests 187 188func TestStandardExponentialValues(t *testing.T) { 189 for _, seed := range testSeeds { 190 testExponentialDistribution(t, numTestSamples, 1, seed) 191 } 192} 193 194func TestNonStandardExponentialValues(t *testing.T) { 195 for rate := 0.05; rate < 10; rate *= 2 { 196 for _, seed := range testSeeds { 197 testExponentialDistribution(t, numTestSamples, rate, seed) 198 if testing.Short() { 199 break 200 } 201 } 202 } 203} 204 205// 206// Table generation tests 207// 208 209func initNorm() (testKn []uint32, testWn, testFn []float32) { 210 const m1 = 1 << 31 211 var ( 212 dn float64 = rn 213 tn = dn 214 vn float64 = 9.91256303526217e-3 215 ) 216 217 testKn = make([]uint32, 128) 218 testWn = make([]float32, 128) 219 testFn = make([]float32, 128) 220 221 q := vn / math.Exp(-0.5*dn*dn) 222 testKn[0] = uint32((dn / q) * m1) 223 testKn[1] = 0 224 testWn[0] = float32(q / m1) 225 testWn[127] = float32(dn / m1) 226 testFn[0] = 1.0 227 testFn[127] = float32(math.Exp(-0.5 * dn * dn)) 228 for i := 126; i >= 1; i-- { 229 dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn))) 230 testKn[i+1] = uint32((dn / tn) * m1) 231 tn = dn 232 testFn[i] = float32(math.Exp(-0.5 * dn * dn)) 233 testWn[i] = float32(dn / m1) 234 } 235 return 236} 237 238func initExp() (testKe []uint32, testWe, testFe []float32) { 239 const m2 = 1 << 32 240 var ( 241 de float64 = re 242 te = de 243 ve float64 = 3.9496598225815571993e-3 244 ) 245 246 testKe = make([]uint32, 256) 247 testWe = make([]float32, 256) 248 testFe = make([]float32, 256) 249 250 q := ve / math.Exp(-de) 251 testKe[0] = uint32((de / q) * m2) 252 testKe[1] = 0 253 testWe[0] = float32(q / m2) 254 testWe[255] = float32(de / m2) 255 testFe[0] = 1.0 256 testFe[255] = float32(math.Exp(-de)) 257 for i := 254; i >= 1; i-- { 258 de = -math.Log(ve/de + math.Exp(-de)) 259 testKe[i+1] = uint32((de / te) * m2) 260 te = de 261 testFe[i] = float32(math.Exp(-de)) 262 testWe[i] = float32(de / m2) 263 } 264 return 265} 266 267// compareUint32Slices returns the first index where the two slices 268// disagree, or <0 if the lengths are the same and all elements 269// are identical. 270func compareUint32Slices(s1, s2 []uint32) int { 271 if len(s1) != len(s2) { 272 if len(s1) > len(s2) { 273 return len(s2) + 1 274 } 275 return len(s1) + 1 276 } 277 for i := range s1 { 278 if s1[i] != s2[i] { 279 return i 280 } 281 } 282 return -1 283} 284 285// compareFloat32Slices returns the first index where the two slices 286// disagree, or <0 if the lengths are the same and all elements 287// are identical. 288func compareFloat32Slices(s1, s2 []float32) int { 289 if len(s1) != len(s2) { 290 if len(s1) > len(s2) { 291 return len(s2) + 1 292 } 293 return len(s1) + 1 294 } 295 for i := range s1 { 296 if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-7) { 297 return i 298 } 299 } 300 return -1 301} 302 303func TestNormTables(t *testing.T) { 304 testKn, testWn, testFn := initNorm() 305 if i := compareUint32Slices(kn[0:], testKn); i >= 0 { 306 t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i]) 307 } 308 if i := compareFloat32Slices(wn[0:], testWn); i >= 0 { 309 t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i]) 310 } 311 if i := compareFloat32Slices(fn[0:], testFn); i >= 0 { 312 t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i]) 313 } 314} 315 316func TestExpTables(t *testing.T) { 317 testKe, testWe, testFe := initExp() 318 if i := compareUint32Slices(ke[0:], testKe); i >= 0 { 319 t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i]) 320 } 321 if i := compareFloat32Slices(we[0:], testWe); i >= 0 { 322 t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i]) 323 } 324 if i := compareFloat32Slices(fe[0:], testFe); i >= 0 { 325 t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i]) 326 } 327} 328 329func hasSlowFloatingPoint() bool { 330 switch runtime.GOARCH { 331 case "arm": 332 return os.Getenv("GOARM") == "5" 333 case "mips", "mipsle", "mips64", "mips64le": 334 // Be conservative and assume that all mips boards 335 // have emulated floating point. 336 // TODO: detect what it actually has. 337 return true 338 } 339 return false 340} 341 342func TestFloat32(t *testing.T) { 343 // For issue 6721, the problem came after 7533753 calls, so check 10e6. 344 num := int(10e6) 345 // But do the full amount only on builders (not locally). 346 // But ARM5 floating point emulation is slow (Issue 10749), so 347 // do less for that builder: 348 if testing.Short() && (testenv.Builder() == "" || hasSlowFloatingPoint()) { 349 num /= 100 // 1.72 seconds instead of 172 seconds 350 } 351 352 r := testRand() 353 for ct := 0; ct < num; ct++ { 354 f := r.Float32() 355 if f >= 1 { 356 t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f) 357 } 358 } 359} 360 361func TestShuffleSmall(t *testing.T) { 362 // Check that Shuffle allows n=0 and n=1, but that swap is never called for them. 363 r := testRand() 364 for n := 0; n <= 1; n++ { 365 r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) }) 366 } 367} 368 369// encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!). 370// See https://en.wikipedia.org/wiki/Lehmer_code. 371// encodePerm modifies the input slice. 372func encodePerm(s []int) int { 373 // Convert to Lehmer code. 374 for i, x := range s { 375 r := s[i+1:] 376 for j, y := range r { 377 if y > x { 378 r[j]-- 379 } 380 } 381 } 382 // Convert to int in [0, n!). 383 m := 0 384 fact := 1 385 for i := len(s) - 1; i >= 0; i-- { 386 m += s[i] * fact 387 fact *= len(s) - i 388 } 389 return m 390} 391 392// TestUniformFactorial tests several ways of generating a uniform value in [0, n!). 393func TestUniformFactorial(t *testing.T) { 394 r := New(NewPCG(1, 2)) 395 top := 6 396 if testing.Short() { 397 top = 3 398 } 399 for n := 3; n <= top; n++ { 400 t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { 401 // Calculate n!. 402 nfact := 1 403 for i := 2; i <= n; i++ { 404 nfact *= i 405 } 406 407 // Test a few different ways to generate a uniform distribution. 408 p := make([]int, n) // re-usable slice for Shuffle generator 409 tests := [...]struct { 410 name string 411 fn func() int 412 }{ 413 {name: "Int32N", fn: func() int { return int(r.Int32N(int32(nfact))) }}, 414 {name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }}, 415 {name: "Shuffle", fn: func() int { 416 // Generate permutation using Shuffle. 417 for i := range p { 418 p[i] = i 419 } 420 r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] }) 421 return encodePerm(p) 422 }}, 423 } 424 425 for _, test := range tests { 426 t.Run(test.name, func(t *testing.T) { 427 // Gather chi-squared values and check that they follow 428 // the expected normal distribution given n!-1 degrees of freedom. 429 // See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and 430 // https://www.johndcook.com/Beautiful_Testing_ch10.pdf. 431 nsamples := 10 * nfact 432 if nsamples < 1000 { 433 nsamples = 1000 434 } 435 samples := make([]float64, nsamples) 436 for i := range samples { 437 // Generate some uniformly distributed values and count their occurrences. 438 const iters = 1000 439 counts := make([]int, nfact) 440 for i := 0; i < iters; i++ { 441 counts[test.fn()]++ 442 } 443 // Calculate chi-squared and add to samples. 444 want := iters / float64(nfact) 445 var χ2 float64 446 for _, have := range counts { 447 err := float64(have) - want 448 χ2 += err * err 449 } 450 χ2 /= want 451 samples[i] = χ2 452 } 453 454 // Check that our samples approximate the appropriate normal distribution. 455 dof := float64(nfact - 1) 456 expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)} 457 errorScale := max(1.0, expected.stddev) 458 expected.closeEnough = 0.10 * errorScale 459 expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211. 460 checkSampleDistribution(t, samples, expected) 461 }) 462 } 463 }) 464 } 465} 466 467// Benchmarks 468 469var Sink uint64 470 471func testRand() *Rand { 472 return New(NewPCG(1, 2)) 473} 474 475func BenchmarkSourceUint64(b *testing.B) { 476 s := NewPCG(1, 2) 477 var t uint64 478 for n := b.N; n > 0; n-- { 479 t += s.Uint64() 480 } 481 Sink = uint64(t) 482} 483 484func BenchmarkGlobalInt64(b *testing.B) { 485 var t int64 486 for n := b.N; n > 0; n-- { 487 t += Int64() 488 } 489 Sink = uint64(t) 490} 491 492func BenchmarkGlobalInt64Parallel(b *testing.B) { 493 b.RunParallel(func(pb *testing.PB) { 494 var t int64 495 for pb.Next() { 496 t += Int64() 497 } 498 atomic.AddUint64(&Sink, uint64(t)) 499 }) 500} 501 502func BenchmarkGlobalUint64(b *testing.B) { 503 var t uint64 504 for n := b.N; n > 0; n-- { 505 t += Uint64() 506 } 507 Sink = t 508} 509 510func BenchmarkGlobalUint64Parallel(b *testing.B) { 511 b.RunParallel(func(pb *testing.PB) { 512 var t uint64 513 for pb.Next() { 514 t += Uint64() 515 } 516 atomic.AddUint64(&Sink, t) 517 }) 518} 519 520func BenchmarkInt64(b *testing.B) { 521 r := testRand() 522 var t int64 523 for n := b.N; n > 0; n-- { 524 t += r.Int64() 525 } 526 Sink = uint64(t) 527} 528 529var AlwaysFalse = false 530 531func keep[T int | uint | int32 | uint32 | int64 | uint64](x T) T { 532 if AlwaysFalse { 533 return -x 534 } 535 return x 536} 537 538func BenchmarkUint64(b *testing.B) { 539 r := testRand() 540 var t uint64 541 for n := b.N; n > 0; n-- { 542 t += r.Uint64() 543 } 544 Sink = t 545} 546 547func BenchmarkGlobalIntN1000(b *testing.B) { 548 var t int 549 arg := keep(1000) 550 for n := b.N; n > 0; n-- { 551 t += IntN(arg) 552 } 553 Sink = uint64(t) 554} 555 556func BenchmarkIntN1000(b *testing.B) { 557 r := testRand() 558 var t int 559 arg := keep(1000) 560 for n := b.N; n > 0; n-- { 561 t += r.IntN(arg) 562 } 563 Sink = uint64(t) 564} 565 566func BenchmarkInt64N1000(b *testing.B) { 567 r := testRand() 568 var t int64 569 arg := keep(int64(1000)) 570 for n := b.N; n > 0; n-- { 571 t += r.Int64N(arg) 572 } 573 Sink = uint64(t) 574} 575 576func BenchmarkInt64N1e8(b *testing.B) { 577 r := testRand() 578 var t int64 579 arg := keep(int64(1e8)) 580 for n := b.N; n > 0; n-- { 581 t += r.Int64N(arg) 582 } 583 Sink = uint64(t) 584} 585 586func BenchmarkInt64N1e9(b *testing.B) { 587 r := testRand() 588 var t int64 589 arg := keep(int64(1e9)) 590 for n := b.N; n > 0; n-- { 591 t += r.Int64N(arg) 592 } 593 Sink = uint64(t) 594} 595 596func BenchmarkInt64N2e9(b *testing.B) { 597 r := testRand() 598 var t int64 599 arg := keep(int64(2e9)) 600 for n := b.N; n > 0; n-- { 601 t += r.Int64N(arg) 602 } 603 Sink = uint64(t) 604} 605 606func BenchmarkInt64N1e18(b *testing.B) { 607 r := testRand() 608 var t int64 609 arg := keep(int64(1e18)) 610 for n := b.N; n > 0; n-- { 611 t += r.Int64N(arg) 612 } 613 Sink = uint64(t) 614} 615 616func BenchmarkInt64N2e18(b *testing.B) { 617 r := testRand() 618 var t int64 619 arg := keep(int64(2e18)) 620 for n := b.N; n > 0; n-- { 621 t += r.Int64N(arg) 622 } 623 Sink = uint64(t) 624} 625 626func BenchmarkInt64N4e18(b *testing.B) { 627 r := testRand() 628 var t int64 629 arg := keep(int64(4e18)) 630 for n := b.N; n > 0; n-- { 631 t += r.Int64N(arg) 632 } 633 Sink = uint64(t) 634} 635 636func BenchmarkInt32N1000(b *testing.B) { 637 r := testRand() 638 var t int32 639 arg := keep(int32(1000)) 640 for n := b.N; n > 0; n-- { 641 t += r.Int32N(arg) 642 } 643 Sink = uint64(t) 644} 645 646func BenchmarkInt32N1e8(b *testing.B) { 647 r := testRand() 648 var t int32 649 arg := keep(int32(1e8)) 650 for n := b.N; n > 0; n-- { 651 t += r.Int32N(arg) 652 } 653 Sink = uint64(t) 654} 655 656func BenchmarkInt32N1e9(b *testing.B) { 657 r := testRand() 658 var t int32 659 arg := keep(int32(1e9)) 660 for n := b.N; n > 0; n-- { 661 t += r.Int32N(arg) 662 } 663 Sink = uint64(t) 664} 665 666func BenchmarkInt32N2e9(b *testing.B) { 667 r := testRand() 668 var t int32 669 arg := keep(int32(2e9)) 670 for n := b.N; n > 0; n-- { 671 t += r.Int32N(arg) 672 } 673 Sink = uint64(t) 674} 675 676func BenchmarkFloat32(b *testing.B) { 677 r := testRand() 678 var t float32 679 for n := b.N; n > 0; n-- { 680 t += r.Float32() 681 } 682 Sink = uint64(t) 683} 684 685func BenchmarkFloat64(b *testing.B) { 686 r := testRand() 687 var t float64 688 for n := b.N; n > 0; n-- { 689 t += r.Float64() 690 } 691 Sink = uint64(t) 692} 693 694func BenchmarkExpFloat64(b *testing.B) { 695 r := testRand() 696 var t float64 697 for n := b.N; n > 0; n-- { 698 t += r.ExpFloat64() 699 } 700 Sink = uint64(t) 701} 702 703func BenchmarkNormFloat64(b *testing.B) { 704 r := testRand() 705 var t float64 706 for n := b.N; n > 0; n-- { 707 t += r.NormFloat64() 708 } 709 Sink = uint64(t) 710} 711 712func BenchmarkPerm3(b *testing.B) { 713 r := testRand() 714 var t int 715 for n := b.N; n > 0; n-- { 716 t += r.Perm(3)[0] 717 } 718 Sink = uint64(t) 719 720} 721 722func BenchmarkPerm30(b *testing.B) { 723 r := testRand() 724 var t int 725 for n := b.N; n > 0; n-- { 726 t += r.Perm(30)[0] 727 } 728 Sink = uint64(t) 729} 730 731func BenchmarkPerm30ViaShuffle(b *testing.B) { 732 r := testRand() 733 var t int 734 for n := b.N; n > 0; n-- { 735 p := make([]int, 30) 736 for i := range p { 737 p[i] = i 738 } 739 r.Shuffle(30, func(i, j int) { p[i], p[j] = p[j], p[i] }) 740 t += p[0] 741 } 742 Sink = uint64(t) 743} 744 745// BenchmarkShuffleOverhead uses a minimal swap function 746// to measure just the shuffling overhead. 747func BenchmarkShuffleOverhead(b *testing.B) { 748 r := testRand() 749 for n := b.N; n > 0; n-- { 750 r.Shuffle(30, func(i, j int) { 751 if i < 0 || i >= 30 || j < 0 || j >= 30 { 752 b.Fatalf("bad swap(%d, %d)", i, j) 753 } 754 }) 755 } 756} 757 758func BenchmarkConcurrent(b *testing.B) { 759 const goroutines = 4 760 var wg sync.WaitGroup 761 wg.Add(goroutines) 762 for i := 0; i < goroutines; i++ { 763 go func() { 764 defer wg.Done() 765 for n := b.N; n > 0; n-- { 766 Int64() 767 } 768 }() 769 } 770 wg.Wait() 771} 772 773func TestN(t *testing.T) { 774 for i := 0; i < 1000; i++ { 775 v := N(10) 776 if v < 0 || v >= 10 { 777 t.Fatalf("N(10) returned %d", v) 778 } 779 } 780} 781