• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// This program generates a test to verify that the standard comparison
// operators properly handle one const operand. The test file should be
// generated with a known working version of Go.
// Launch with `go run cmpConstGen.go`; a file called cmpConst_test.go
// containing the tests will be written into the parent directory.
10
11package main
12
import (
	"bytes"
	"fmt"
	"go/format"
	"log"
	"math/big"
	"os"
	"sort"
)
21
// Bounds of the sized Go integer types. They are untyped constants, so
// expressions such as maxI64+1 or minI8-1 are evaluated exactly at compile
// time. Each signed bound is derived from the unsigned bound of the same width.
const (
	maxU8  = 1<<8 - 1
	maxU16 = 1<<16 - 1
	maxU32 = 1<<32 - 1
	maxU64 = 1<<64 - 1

	maxI8  = maxU8 >> 1
	maxI16 = maxU16 >> 1
	maxI32 = maxU32 >> 1
	maxI64 = maxU64 >> 1

	minI8  = -maxI8 - 1
	minI16 = -maxI16 - 1
	minI32 = -maxI32 - 1
	minI64 = -maxI64 - 1
)
38
// cmp reports whether "left op right" holds, where op is one of the Go
// comparison operators "<", "<=", ">", ">=", "==" or "!=". Any other op
// yields false. It panics if big.Int.Cmp returns something other than -1, 0 or 1.
func cmp(left *big.Int, op string, right *big.Int) bool {
	var less, equal, greater bool
	switch sign := left.Cmp(right); sign {
	case -1:
		less = true
	case 0:
		equal = true
	case 1:
		greater = true
	default:
		panic("unexpected comparison value")
	}

	switch op {
	case "<":
		return less
	case "<=":
		return less || equal
	case ">":
		return greater
	case ">=":
		return greater || equal
	case "==":
		return equal
	case "!=":
		return !equal
	}
	return false
}
50
51func inRange(typ string, val *big.Int) bool {
52	min, max := &big.Int{}, &big.Int{}
53	switch typ {
54	case "uint64":
55		max = max.SetUint64(maxU64)
56	case "uint32":
57		max = max.SetUint64(maxU32)
58	case "uint16":
59		max = max.SetUint64(maxU16)
60	case "uint8":
61		max = max.SetUint64(maxU8)
62	case "int64":
63		min = min.SetInt64(minI64)
64		max = max.SetInt64(maxI64)
65	case "int32":
66		min = min.SetInt64(minI32)
67		max = max.SetInt64(maxI32)
68	case "int16":
69		min = min.SetInt64(minI16)
70		max = max.SetInt64(maxI16)
71	case "int8":
72		min = min.SetInt64(minI8)
73		max = max.SetInt64(maxI8)
74	default:
75		panic("unexpected type")
76	}
77	return cmp(min, "<=", val) && cmp(val, "<=", max)
78}
79
80func getValues(typ string) []*big.Int {
81	Uint := func(v uint64) *big.Int { return big.NewInt(0).SetUint64(v) }
82	Int := func(v int64) *big.Int { return big.NewInt(0).SetInt64(v) }
83	values := []*big.Int{
84		// limits
85		Uint(maxU64),
86		Uint(maxU64 - 1),
87		Uint(maxI64 + 1),
88		Uint(maxI64),
89		Uint(maxI64 - 1),
90		Uint(maxU32 + 1),
91		Uint(maxU32),
92		Uint(maxU32 - 1),
93		Uint(maxI32 + 1),
94		Uint(maxI32),
95		Uint(maxI32 - 1),
96		Uint(maxU16 + 1),
97		Uint(maxU16),
98		Uint(maxU16 - 1),
99		Uint(maxI16 + 1),
100		Uint(maxI16),
101		Uint(maxI16 - 1),
102		Uint(maxU8 + 1),
103		Uint(maxU8),
104		Uint(maxU8 - 1),
105		Uint(maxI8 + 1),
106		Uint(maxI8),
107		Uint(maxI8 - 1),
108		Uint(0),
109		Int(minI8 + 1),
110		Int(minI8),
111		Int(minI8 - 1),
112		Int(minI16 + 1),
113		Int(minI16),
114		Int(minI16 - 1),
115		Int(minI32 + 1),
116		Int(minI32),
117		Int(minI32 - 1),
118		Int(minI64 + 1),
119		Int(minI64),
120
121		// other possibly interesting values
122		Uint(1),
123		Int(-1),
124		Uint(0xff << 56),
125		Uint(0xff << 32),
126		Uint(0xff << 24),
127	}
128	sort.Slice(values, func(i, j int) bool { return values[i].Cmp(values[j]) == -1 })
129	var ret []*big.Int
130	for _, val := range values {
131		if !inRange(typ, val) {
132			continue
133		}
134		ret = append(ret, val)
135	}
136	return ret
137}
138
// sigString renders v as a fragment usable inside a Go identifier: the
// decimal magnitude of v, prefixed with "neg" when v is negative.
// v itself is not modified.
func sigString(v *big.Int) string {
	magnitude := new(big.Int).Abs(v).String()
	if v.Sign() < 0 {
		return "neg" + magnitude
	}
	return magnitude
}
147
148func main() {
149	types := []string{
150		"uint64", "uint32", "uint16", "uint8",
151		"int64", "int32", "int16", "int8",
152	}
153
154	w := new(bytes.Buffer)
155	fmt.Fprintf(w, "// Code generated by gen/cmpConstGen.go. DO NOT EDIT.\n\n")
156	fmt.Fprintf(w, "package main;\n")
157	fmt.Fprintf(w, "import (\"testing\"; \"reflect\"; \"runtime\";)\n")
158	fmt.Fprintf(w, "// results show the expected result for the elements left of, equal to and right of the index.\n")
159	fmt.Fprintf(w, "type result struct{l, e, r bool}\n")
160	fmt.Fprintf(w, "var (\n")
161	fmt.Fprintf(w, "	eq = result{l: false, e: true, r: false}\n")
162	fmt.Fprintf(w, "	ne = result{l: true, e: false, r: true}\n")
163	fmt.Fprintf(w, "	lt = result{l: true, e: false, r: false}\n")
164	fmt.Fprintf(w, "	le = result{l: true, e: true, r: false}\n")
165	fmt.Fprintf(w, "	gt = result{l: false, e: false, r: true}\n")
166	fmt.Fprintf(w, "	ge = result{l: false, e: true, r: true}\n")
167	fmt.Fprintf(w, ")\n")
168
169	operators := []struct{ op, name string }{
170		{"<", "lt"},
171		{"<=", "le"},
172		{">", "gt"},
173		{">=", "ge"},
174		{"==", "eq"},
175		{"!=", "ne"},
176	}
177
178	for _, typ := range types {
179		// generate a slice containing valid values for this type
180		fmt.Fprintf(w, "\n// %v tests\n", typ)
181		values := getValues(typ)
182		fmt.Fprintf(w, "var %v_vals = []%v{\n", typ, typ)
183		for _, val := range values {
184			fmt.Fprintf(w, "%v,\n", val.String())
185		}
186		fmt.Fprintf(w, "}\n")
187
188		// generate test functions
189		for _, r := range values {
190			// TODO: could also test constant on lhs.
191			sig := sigString(r)
192			for _, op := range operators {
193				// no need for go:noinline because the function is called indirectly
194				fmt.Fprintf(w, "func %v_%v_%v(x %v) bool { return x %v %v; }\n", op.name, sig, typ, typ, op.op, r.String())
195			}
196		}
197
198		// generate a table of test cases
199		fmt.Fprintf(w, "var %v_tests = []struct{\n", typ)
200		fmt.Fprintf(w, "	idx int // index of the constant used\n")
201		fmt.Fprintf(w, "	exp result // expected results\n")
202		fmt.Fprintf(w, "	fn  func(%v) bool\n", typ)
203		fmt.Fprintf(w, "}{\n")
204		for i, r := range values {
205			sig := sigString(r)
206			for _, op := range operators {
207				fmt.Fprintf(w, "{idx: %v,", i)
208				fmt.Fprintf(w, "exp: %v,", op.name)
209				fmt.Fprintf(w, "fn:  %v_%v_%v},\n", op.name, sig, typ)
210			}
211		}
212		fmt.Fprintf(w, "}\n")
213	}
214
215	// emit the main function, looping over all test cases
216	fmt.Fprintf(w, "// TestComparisonsConst tests results for comparison operations against constants.\n")
217	fmt.Fprintf(w, "func TestComparisonsConst(t *testing.T) {\n")
218	for _, typ := range types {
219		fmt.Fprintf(w, "for i, test := range %v_tests {\n", typ)
220		fmt.Fprintf(w, "	for j, x := range %v_vals {\n", typ)
221		fmt.Fprintf(w, "		want := test.exp.l\n")
222		fmt.Fprintf(w, "		if j == test.idx {\nwant = test.exp.e\n}")
223		fmt.Fprintf(w, "		else if j > test.idx {\nwant = test.exp.r\n}\n")
224		fmt.Fprintf(w, "		if test.fn(x) != want {\n")
225		fmt.Fprintf(w, "			fn := runtime.FuncForPC(reflect.ValueOf(test.fn).Pointer()).Name()\n")
226		fmt.Fprintf(w, "			t.Errorf(\"test failed: %%v(%%v) != %%v [type=%v i=%%v j=%%v idx=%%v]\", fn, x, want, i, j, test.idx)\n", typ)
227		fmt.Fprintf(w, "		}\n")
228		fmt.Fprintf(w, "	}\n")
229		fmt.Fprintf(w, "}\n")
230	}
231	fmt.Fprintf(w, "}\n")
232
233	// gofmt result
234	b := w.Bytes()
235	src, err := format.Source(b)
236	if err != nil {
237		fmt.Printf("%s\n", b)
238		panic(err)
239	}
240
241	// write to file
242	err = os.WriteFile("../cmpConst_test.go", src, 0666)
243	if err != nil {
244		log.Fatalf("can't write output: %v\n", err)
245	}
246}
247