• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package rand_test
6
7import (
8	"errors"
9	"fmt"
10	"internal/testenv"
11	"math"
12	. "math/rand/v2"
13	"os"
14	"runtime"
15	"sync"
16	"sync/atomic"
17	"testing"
18)
19
// numTestSamples is the number of samples generated for each
// distribution check in this file.
const numTestSamples = 10000
23
24var rn, kn, wn, fn = GetNormalDistributionParameters()
25var re, ke, we, fe = GetExponentialDistributionParameters()
26
// statsResults holds the mean and standard deviation of a sample set,
// together with the tolerances applied when another statsResults is compared
// against it. The field order matters: some callers build values with
// positional composite literals.
type statsResults struct {
	// mean and stddev are the statistics themselves.
	mean, stddev float64
	// closeEnough is the absolute and maxError the relative tolerance
	// handed to nearEqual.
	closeEnough, maxError float64
}
33
// nearEqual reports whether a and b are approximately equal.
//
// The values match when any of the following holds:
//   - a == b exactly;
//   - |a-b| < closeEnough (absolute tolerance, needed when one value is zero
//     and the other is merely close to zero);
//   - |a-b| / max(|a|, |b|) < maxError (relative tolerance).
func nearEqual(a, b, closeEnough, maxError float64) bool {
	if a == b {
		// Exact equality must short-circuit: for a == b == 0 the relative
		// test below would evaluate 0/0 = NaN and report a mismatch whenever
		// closeEnough is not positive.
		return true
	}
	absDiff := math.Abs(a - b)
	if absDiff < closeEnough {
		return true
	}
	return absDiff/max(math.Abs(a), math.Abs(b)) < maxError
}
41
// testSeeds lists the seeds that each distribution test is repeated with.
var testSeeds = []uint64{
	1,
	1754801282,
	1698661970,
	1550503961,
}
43
44// checkSimilarDistribution returns success if the mean and stddev of the
45// two statsResults are similar.
46func (sr *statsResults) checkSimilarDistribution(expected *statsResults) error {
47	if !nearEqual(sr.mean, expected.mean, expected.closeEnough, expected.maxError) {
48		s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", sr.mean, expected.mean, expected.closeEnough, expected.maxError)
49		fmt.Println(s)
50		return errors.New(s)
51	}
52	if !nearEqual(sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) {
53		s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", sr.stddev, expected.stddev, expected.closeEnough, expected.maxError)
54		fmt.Println(s)
55		return errors.New(s)
56	}
57	return nil
58}
59
60func getStatsResults(samples []float64) *statsResults {
61	res := new(statsResults)
62	var sum, squaresum float64
63	for _, s := range samples {
64		sum += s
65		squaresum += s * s
66	}
67	res.mean = sum / float64(len(samples))
68	res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean)
69	return res
70}
71
72func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) {
73	t.Helper()
74	actual := getStatsResults(samples)
75	err := actual.checkSimilarDistribution(expected)
76	if err != nil {
77		t.Error(err)
78	}
79}
80
81func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) {
82	t.Helper()
83	chunk := len(samples) / nslices
84	for i := 0; i < nslices; i++ {
85		low := i * chunk
86		var high int
87		if i == nslices-1 {
88			high = len(samples) - 1
89		} else {
90			high = (i + 1) * chunk
91		}
92		checkSampleDistribution(t, samples[low:high], expected)
93	}
94}
95
96//
97// Normal distribution tests
98//
99
100func generateNormalSamples(nsamples int, mean, stddev float64, seed uint64) []float64 {
101	r := New(NewPCG(seed, seed))
102	samples := make([]float64, nsamples)
103	for i := range samples {
104		samples[i] = r.NormFloat64()*stddev + mean
105	}
106	return samples
107}
108
109func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed uint64) {
110	//fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed);
111
112	samples := generateNormalSamples(nsamples, mean, stddev, seed)
113	errorScale := max(1.0, stddev) // Error scales with stddev
114	expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale}
115
116	// Make sure that the entire set matches the expected distribution.
117	checkSampleDistribution(t, samples, expected)
118
119	// Make sure that each half of the set matches the expected distribution.
120	checkSampleSliceDistributions(t, samples, 2, expected)
121
122	// Make sure that each 7th of the set matches the expected distribution.
123	checkSampleSliceDistributions(t, samples, 7, expected)
124}
125
126// Actual tests
127
128func TestStandardNormalValues(t *testing.T) {
129	for _, seed := range testSeeds {
130		testNormalDistribution(t, numTestSamples, 0, 1, seed)
131	}
132}
133
134func TestNonStandardNormalValues(t *testing.T) {
135	sdmax := 1000.0
136	mmax := 1000.0
137	if testing.Short() {
138		sdmax = 5
139		mmax = 5
140	}
141	for sd := 0.5; sd < sdmax; sd *= 2 {
142		for m := 0.5; m < mmax; m *= 2 {
143			for _, seed := range testSeeds {
144				testNormalDistribution(t, numTestSamples, m, sd, seed)
145				if testing.Short() {
146					break
147				}
148			}
149		}
150	}
151}
152
153//
154// Exponential distribution tests
155//
156
157func generateExponentialSamples(nsamples int, rate float64, seed uint64) []float64 {
158	r := New(NewPCG(seed, seed))
159	samples := make([]float64, nsamples)
160	for i := range samples {
161		samples[i] = r.ExpFloat64() / rate
162	}
163	return samples
164}
165
166func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed uint64) {
167	//fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed);
168
169	mean := 1 / rate
170	stddev := mean
171
172	samples := generateExponentialSamples(nsamples, rate, seed)
173	errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate
174	expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale}
175
176	// Make sure that the entire set matches the expected distribution.
177	checkSampleDistribution(t, samples, expected)
178
179	// Make sure that each half of the set matches the expected distribution.
180	checkSampleSliceDistributions(t, samples, 2, expected)
181
182	// Make sure that each 7th of the set matches the expected distribution.
183	checkSampleSliceDistributions(t, samples, 7, expected)
184}
185
186// Actual tests
187
188func TestStandardExponentialValues(t *testing.T) {
189	for _, seed := range testSeeds {
190		testExponentialDistribution(t, numTestSamples, 1, seed)
191	}
192}
193
194func TestNonStandardExponentialValues(t *testing.T) {
195	for rate := 0.05; rate < 10; rate *= 2 {
196		for _, seed := range testSeeds {
197			testExponentialDistribution(t, numTestSamples, rate, seed)
198			if testing.Short() {
199				break
200			}
201		}
202	}
203}
204
205//
206// Table generation tests
207//
208
209func initNorm() (testKn []uint32, testWn, testFn []float32) {
210	const m1 = 1 << 31
211	var (
212		dn float64 = rn
213		tn         = dn
214		vn float64 = 9.91256303526217e-3
215	)
216
217	testKn = make([]uint32, 128)
218	testWn = make([]float32, 128)
219	testFn = make([]float32, 128)
220
221	q := vn / math.Exp(-0.5*dn*dn)
222	testKn[0] = uint32((dn / q) * m1)
223	testKn[1] = 0
224	testWn[0] = float32(q / m1)
225	testWn[127] = float32(dn / m1)
226	testFn[0] = 1.0
227	testFn[127] = float32(math.Exp(-0.5 * dn * dn))
228	for i := 126; i >= 1; i-- {
229		dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn)))
230		testKn[i+1] = uint32((dn / tn) * m1)
231		tn = dn
232		testFn[i] = float32(math.Exp(-0.5 * dn * dn))
233		testWn[i] = float32(dn / m1)
234	}
235	return
236}
237
238func initExp() (testKe []uint32, testWe, testFe []float32) {
239	const m2 = 1 << 32
240	var (
241		de float64 = re
242		te         = de
243		ve float64 = 3.9496598225815571993e-3
244	)
245
246	testKe = make([]uint32, 256)
247	testWe = make([]float32, 256)
248	testFe = make([]float32, 256)
249
250	q := ve / math.Exp(-de)
251	testKe[0] = uint32((de / q) * m2)
252	testKe[1] = 0
253	testWe[0] = float32(q / m2)
254	testWe[255] = float32(de / m2)
255	testFe[0] = 1.0
256	testFe[255] = float32(math.Exp(-de))
257	for i := 254; i >= 1; i-- {
258		de = -math.Log(ve/de + math.Exp(-de))
259		testKe[i+1] = uint32((de / te) * m2)
260		te = de
261		testFe[i] = float32(math.Exp(-de))
262		testWe[i] = float32(de / m2)
263	}
264	return
265}
266
// compareUint32Slices returns the first index where the two slices
// disagree, or -1 if the lengths are the same and all elements are
// identical.
//
// The common prefix is always compared first. If it is identical but the
// lengths differ, the result is the length of the shorter slice, i.e. the
// first index that exists in only one of the two slices.
func compareUint32Slices(s1, s2 []uint32) int {
	common := min(len(s1), len(s2))
	for i := 0; i < common; i++ {
		if s1[i] != s2[i] {
			return i
		}
	}
	if len(s1) != len(s2) {
		return common
	}
	return -1
}
284
285// compareFloat32Slices returns the first index where the two slices
286// disagree, or <0 if the lengths are the same and all elements
287// are identical.
288func compareFloat32Slices(s1, s2 []float32) int {
289	if len(s1) != len(s2) {
290		if len(s1) > len(s2) {
291			return len(s2) + 1
292		}
293		return len(s1) + 1
294	}
295	for i := range s1 {
296		if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-7) {
297			return i
298		}
299	}
300	return -1
301}
302
303func TestNormTables(t *testing.T) {
304	testKn, testWn, testFn := initNorm()
305	if i := compareUint32Slices(kn[0:], testKn); i >= 0 {
306		t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i])
307	}
308	if i := compareFloat32Slices(wn[0:], testWn); i >= 0 {
309		t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i])
310	}
311	if i := compareFloat32Slices(fn[0:], testFn); i >= 0 {
312		t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i])
313	}
314}
315
316func TestExpTables(t *testing.T) {
317	testKe, testWe, testFe := initExp()
318	if i := compareUint32Slices(ke[0:], testKe); i >= 0 {
319		t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i])
320	}
321	if i := compareFloat32Slices(we[0:], testWe); i >= 0 {
322		t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i])
323	}
324	if i := compareFloat32Slices(fe[0:], testFe); i >= 0 {
325		t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i])
326	}
327}
328
// hasSlowFloatingPoint reports whether the current platform is assumed to
// have slow floating point: on arm only when the GOARM environment
// variable is "5", and always on the mips family.
func hasSlowFloatingPoint() bool {
	arch := runtime.GOARCH
	if arch == "arm" {
		return os.Getenv("GOARM") == "5"
	}
	if arch == "mips" || arch == "mipsle" || arch == "mips64" || arch == "mips64le" {
		// Be conservative and assume that every mips board emulates
		// floating point.
		// TODO: detect what it actually has.
		return true
	}
	return false
}
341
342func TestFloat32(t *testing.T) {
343	// For issue 6721, the problem came after 7533753 calls, so check 10e6.
344	num := int(10e6)
345	// But do the full amount only on builders (not locally).
346	// But ARM5 floating point emulation is slow (Issue 10749), so
347	// do less for that builder:
348	if testing.Short() && (testenv.Builder() == "" || hasSlowFloatingPoint()) {
349		num /= 100 // 1.72 seconds instead of 172 seconds
350	}
351
352	r := testRand()
353	for ct := 0; ct < num; ct++ {
354		f := r.Float32()
355		if f >= 1 {
356			t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f)
357		}
358	}
359}
360
361func TestShuffleSmall(t *testing.T) {
362	// Check that Shuffle allows n=0 and n=1, but that swap is never called for them.
363	r := testRand()
364	for n := 0; n <= 1; n++ {
365		r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) })
366	}
367}
368
// encodePerm maps a permutation of 0..n-1 held in s (such as Perm returns)
// to a distinct int in [0, n!), by way of its Lehmer code; see
// https://en.wikipedia.org/wiki/Lehmer_code.
// The slice s is overwritten with the Lehmer code.
func encodePerm(s []int) int {
	n := len(s)

	// Lehmer code: every later element larger than s[i] drops by one.
	for i := 0; i < n; i++ {
		for j := i + 1; j < n; j++ {
			if s[j] > s[i] {
				s[j]--
			}
		}
	}

	// Read the code as a factorial-base number, last digit first: the
	// digit k places from the end carries weight (k-1)!.
	code, weight := 0, 1
	for k := 1; k <= n; k++ {
		code += s[n-k] * weight
		weight *= k
	}
	return code
}
391
392// TestUniformFactorial tests several ways of generating a uniform value in [0, n!).
393func TestUniformFactorial(t *testing.T) {
394	r := New(NewPCG(1, 2))
395	top := 6
396	if testing.Short() {
397		top = 3
398	}
399	for n := 3; n <= top; n++ {
400		t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) {
401			// Calculate n!.
402			nfact := 1
403			for i := 2; i <= n; i++ {
404				nfact *= i
405			}
406
407			// Test a few different ways to generate a uniform distribution.
408			p := make([]int, n) // re-usable slice for Shuffle generator
409			tests := [...]struct {
410				name string
411				fn   func() int
412			}{
413				{name: "Int32N", fn: func() int { return int(r.Int32N(int32(nfact))) }},
414				{name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }},
415				{name: "Shuffle", fn: func() int {
416					// Generate permutation using Shuffle.
417					for i := range p {
418						p[i] = i
419					}
420					r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] })
421					return encodePerm(p)
422				}},
423			}
424
425			for _, test := range tests {
426				t.Run(test.name, func(t *testing.T) {
427					// Gather chi-squared values and check that they follow
428					// the expected normal distribution given n!-1 degrees of freedom.
429					// See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and
430					// https://www.johndcook.com/Beautiful_Testing_ch10.pdf.
431					nsamples := 10 * nfact
432					if nsamples < 1000 {
433						nsamples = 1000
434					}
435					samples := make([]float64, nsamples)
436					for i := range samples {
437						// Generate some uniformly distributed values and count their occurrences.
438						const iters = 1000
439						counts := make([]int, nfact)
440						for i := 0; i < iters; i++ {
441							counts[test.fn()]++
442						}
443						// Calculate chi-squared and add to samples.
444						want := iters / float64(nfact)
445						var χ2 float64
446						for _, have := range counts {
447							err := float64(have) - want
448							χ2 += err * err
449						}
450						χ2 /= want
451						samples[i] = χ2
452					}
453
454					// Check that our samples approximate the appropriate normal distribution.
455					dof := float64(nfact - 1)
456					expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)}
457					errorScale := max(1.0, expected.stddev)
458					expected.closeEnough = 0.10 * errorScale
459					expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211.
460					checkSampleDistribution(t, samples, expected)
461				})
462			}
463		})
464	}
465}
466
467// Benchmarks
468
// Sink receives the accumulated result of each benchmark loop, so that the
// benchmarked calls produce a value that is actually stored somewhere. The
// parallel benchmarks update it with atomic.AddUint64.
var Sink uint64
470
471func testRand() *Rand {
472	return New(NewPCG(1, 2))
473}
474
475func BenchmarkSourceUint64(b *testing.B) {
476	s := NewPCG(1, 2)
477	var t uint64
478	for n := b.N; n > 0; n-- {
479		t += s.Uint64()
480	}
481	Sink = uint64(t)
482}
483
484func BenchmarkGlobalInt64(b *testing.B) {
485	var t int64
486	for n := b.N; n > 0; n-- {
487		t += Int64()
488	}
489	Sink = uint64(t)
490}
491
492func BenchmarkGlobalInt64Parallel(b *testing.B) {
493	b.RunParallel(func(pb *testing.PB) {
494		var t int64
495		for pb.Next() {
496			t += Int64()
497		}
498		atomic.AddUint64(&Sink, uint64(t))
499	})
500}
501
502func BenchmarkGlobalUint64(b *testing.B) {
503	var t uint64
504	for n := b.N; n > 0; n-- {
505		t += Uint64()
506	}
507	Sink = t
508}
509
510func BenchmarkGlobalUint64Parallel(b *testing.B) {
511	b.RunParallel(func(pb *testing.PB) {
512		var t uint64
513		for pb.Next() {
514			t += Uint64()
515		}
516		atomic.AddUint64(&Sink, t)
517	})
518}
519
520func BenchmarkInt64(b *testing.B) {
521	r := testRand()
522	var t int64
523	for n := b.N; n > 0; n-- {
524		t += r.Int64()
525	}
526	Sink = uint64(t)
527}
528
// AlwaysFalse is never set to true in this file; keep branches on it.
var AlwaysFalse bool
530
531func keep[T int | uint | int32 | uint32 | int64 | uint64](x T) T {
532	if AlwaysFalse {
533		return -x
534	}
535	return x
536}
537
538func BenchmarkUint64(b *testing.B) {
539	r := testRand()
540	var t uint64
541	for n := b.N; n > 0; n-- {
542		t += r.Uint64()
543	}
544	Sink = t
545}
546
547func BenchmarkGlobalIntN1000(b *testing.B) {
548	var t int
549	arg := keep(1000)
550	for n := b.N; n > 0; n-- {
551		t += IntN(arg)
552	}
553	Sink = uint64(t)
554}
555
556func BenchmarkIntN1000(b *testing.B) {
557	r := testRand()
558	var t int
559	arg := keep(1000)
560	for n := b.N; n > 0; n-- {
561		t += r.IntN(arg)
562	}
563	Sink = uint64(t)
564}
565
566func BenchmarkInt64N1000(b *testing.B) {
567	r := testRand()
568	var t int64
569	arg := keep(int64(1000))
570	for n := b.N; n > 0; n-- {
571		t += r.Int64N(arg)
572	}
573	Sink = uint64(t)
574}
575
576func BenchmarkInt64N1e8(b *testing.B) {
577	r := testRand()
578	var t int64
579	arg := keep(int64(1e8))
580	for n := b.N; n > 0; n-- {
581		t += r.Int64N(arg)
582	}
583	Sink = uint64(t)
584}
585
586func BenchmarkInt64N1e9(b *testing.B) {
587	r := testRand()
588	var t int64
589	arg := keep(int64(1e9))
590	for n := b.N; n > 0; n-- {
591		t += r.Int64N(arg)
592	}
593	Sink = uint64(t)
594}
595
596func BenchmarkInt64N2e9(b *testing.B) {
597	r := testRand()
598	var t int64
599	arg := keep(int64(2e9))
600	for n := b.N; n > 0; n-- {
601		t += r.Int64N(arg)
602	}
603	Sink = uint64(t)
604}
605
606func BenchmarkInt64N1e18(b *testing.B) {
607	r := testRand()
608	var t int64
609	arg := keep(int64(1e18))
610	for n := b.N; n > 0; n-- {
611		t += r.Int64N(arg)
612	}
613	Sink = uint64(t)
614}
615
616func BenchmarkInt64N2e18(b *testing.B) {
617	r := testRand()
618	var t int64
619	arg := keep(int64(2e18))
620	for n := b.N; n > 0; n-- {
621		t += r.Int64N(arg)
622	}
623	Sink = uint64(t)
624}
625
626func BenchmarkInt64N4e18(b *testing.B) {
627	r := testRand()
628	var t int64
629	arg := keep(int64(4e18))
630	for n := b.N; n > 0; n-- {
631		t += r.Int64N(arg)
632	}
633	Sink = uint64(t)
634}
635
636func BenchmarkInt32N1000(b *testing.B) {
637	r := testRand()
638	var t int32
639	arg := keep(int32(1000))
640	for n := b.N; n > 0; n-- {
641		t += r.Int32N(arg)
642	}
643	Sink = uint64(t)
644}
645
646func BenchmarkInt32N1e8(b *testing.B) {
647	r := testRand()
648	var t int32
649	arg := keep(int32(1e8))
650	for n := b.N; n > 0; n-- {
651		t += r.Int32N(arg)
652	}
653	Sink = uint64(t)
654}
655
656func BenchmarkInt32N1e9(b *testing.B) {
657	r := testRand()
658	var t int32
659	arg := keep(int32(1e9))
660	for n := b.N; n > 0; n-- {
661		t += r.Int32N(arg)
662	}
663	Sink = uint64(t)
664}
665
666func BenchmarkInt32N2e9(b *testing.B) {
667	r := testRand()
668	var t int32
669	arg := keep(int32(2e9))
670	for n := b.N; n > 0; n-- {
671		t += r.Int32N(arg)
672	}
673	Sink = uint64(t)
674}
675
676func BenchmarkFloat32(b *testing.B) {
677	r := testRand()
678	var t float32
679	for n := b.N; n > 0; n-- {
680		t += r.Float32()
681	}
682	Sink = uint64(t)
683}
684
685func BenchmarkFloat64(b *testing.B) {
686	r := testRand()
687	var t float64
688	for n := b.N; n > 0; n-- {
689		t += r.Float64()
690	}
691	Sink = uint64(t)
692}
693
694func BenchmarkExpFloat64(b *testing.B) {
695	r := testRand()
696	var t float64
697	for n := b.N; n > 0; n-- {
698		t += r.ExpFloat64()
699	}
700	Sink = uint64(t)
701}
702
703func BenchmarkNormFloat64(b *testing.B) {
704	r := testRand()
705	var t float64
706	for n := b.N; n > 0; n-- {
707		t += r.NormFloat64()
708	}
709	Sink = uint64(t)
710}
711
712func BenchmarkPerm3(b *testing.B) {
713	r := testRand()
714	var t int
715	for n := b.N; n > 0; n-- {
716		t += r.Perm(3)[0]
717	}
718	Sink = uint64(t)
719
720}
721
722func BenchmarkPerm30(b *testing.B) {
723	r := testRand()
724	var t int
725	for n := b.N; n > 0; n-- {
726		t += r.Perm(30)[0]
727	}
728	Sink = uint64(t)
729}
730
731func BenchmarkPerm30ViaShuffle(b *testing.B) {
732	r := testRand()
733	var t int
734	for n := b.N; n > 0; n-- {
735		p := make([]int, 30)
736		for i := range p {
737			p[i] = i
738		}
739		r.Shuffle(30, func(i, j int) { p[i], p[j] = p[j], p[i] })
740		t += p[0]
741	}
742	Sink = uint64(t)
743}
744
745// BenchmarkShuffleOverhead uses a minimal swap function
746// to measure just the shuffling overhead.
747func BenchmarkShuffleOverhead(b *testing.B) {
748	r := testRand()
749	for n := b.N; n > 0; n-- {
750		r.Shuffle(30, func(i, j int) {
751			if i < 0 || i >= 30 || j < 0 || j >= 30 {
752				b.Fatalf("bad swap(%d, %d)", i, j)
753			}
754		})
755	}
756}
757
758func BenchmarkConcurrent(b *testing.B) {
759	const goroutines = 4
760	var wg sync.WaitGroup
761	wg.Add(goroutines)
762	for i := 0; i < goroutines; i++ {
763		go func() {
764			defer wg.Done()
765			for n := b.N; n > 0; n-- {
766				Int64()
767			}
768		}()
769	}
770	wg.Wait()
771}
772
773func TestN(t *testing.T) {
774	for i := 0; i < 1000; i++ {
775		v := N(10)
776		if v < 0 || v >= 10 {
777			t.Fatalf("N(10) returned %d", v)
778		}
779	}
780}
781