• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright (c) 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package field
6
7import (
8	"bytes"
9	"crypto/rand"
10	"encoding/hex"
11	"io"
12	"math/big"
13	"math/bits"
14	mathrand "math/rand"
15	"reflect"
16	"testing"
17	"testing/quick"
18)
19
20func (v Element) String() string {
21	return hex.EncodeToString(v.Bytes())
22}
23
// quickCheckConfig builds the quick.Config used by the property tests. When
// the -short flag is set the zero Config is returned; otherwise MaxCountScale
// is set to slowScale so that full runs try proportionally more inputs.
func quickCheckConfig(slowScale int) *quick.Config {
	if testing.Short() {
		return &quick.Config{}
	}
	return &quick.Config{MaxCountScale: float64(slowScale)}
}
33
34func generateFieldElement(rand *mathrand.Rand) Element {
35	const maskLow52Bits = (1 << 52) - 1
36	return Element{
37		rand.Uint64() & maskLow52Bits,
38		rand.Uint64() & maskLow52Bits,
39		rand.Uint64() & maskLow52Bits,
40		rand.Uint64() & maskLow52Bits,
41		rand.Uint64() & maskLow52Bits,
42	}
43}
44
// weirdLimbs51 and weirdLimbs52 are pools of edge-case limb values that can
// be combined to build unusual field elements. Entries are repeated on
// purpose: 0 and 2^51-1 ("-1") are weighted more heavily because they combine
// well. Every value in weirdLimbs51 fits in 51 bits, and every value in
// weirdLimbs52 fits in 52 bits.
var (
	weirdLimbs51 = []uint64{
		// Zero, weighted x4.
		0,
		0,
		0,
		0,
		1,
		19 - 1,
		19,
		0x2aaaaaaaaaaaa,
		0x5555555555555,
		1<<51 - 20,
		1<<51 - 19,
		// 2^51 - 1, weighted x4.
		1<<51 - 1,
		1<<51 - 1,
		1<<51 - 1,
		1<<51 - 1,
	}
	weirdLimbs52 = []uint64{
		// Zero, weighted x6.
		0,
		0,
		0,
		0,
		0,
		0,
		1,
		19 - 1,
		19,
		0x2aaaaaaaaaaaa,
		0x5555555555555,
		1<<51 - 20,
		1<<51 - 19,
		// 2^51 - 1, weighted x6.
		1<<51 - 1,
		1<<51 - 1,
		1<<51 - 1,
		1<<51 - 1,
		1<<51 - 1,
		1<<51 - 1,
		// Values that need the 52nd bit.
		1 << 51,
		1<<51 + 1,
		1<<52 - 19,
		1<<52 - 1,
	}
)
78
79func generateWeirdFieldElement(rand *mathrand.Rand) Element {
80	return Element{
81		weirdLimbs52[rand.Intn(len(weirdLimbs52))],
82		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
83		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
84		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
85		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
86	}
87}
88
89func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
90	if rand.Intn(2) == 0 {
91		return reflect.ValueOf(generateWeirdFieldElement(rand))
92	}
93	return reflect.ValueOf(generateFieldElement(rand))
94}
95
96// isInBounds returns whether the element is within the expected bit size bounds
97// after a light reduction.
98func isInBounds(x *Element) bool {
99	return bits.Len64(x.l0) <= 52 &&
100		bits.Len64(x.l1) <= 52 &&
101		bits.Len64(x.l2) <= 52 &&
102		bits.Len64(x.l3) <= 52 &&
103		bits.Len64(x.l4) <= 52
104}
105
106func TestMultiplyDistributesOverAdd(t *testing.T) {
107	multiplyDistributesOverAdd := func(x, y, z Element) bool {
108		// Compute t1 = (x+y)*z
109		t1 := new(Element)
110		t1.Add(&x, &y)
111		t1.Multiply(t1, &z)
112
113		// Compute t2 = x*z + y*z
114		t2 := new(Element)
115		t3 := new(Element)
116		t2.Multiply(&x, &z)
117		t3.Multiply(&y, &z)
118		t2.Add(t2, t3)
119
120		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
121	}
122
123	if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig(1024)); err != nil {
124		t.Error(err)
125	}
126}
127
128func TestMul64to128(t *testing.T) {
129	a := uint64(5)
130	b := uint64(5)
131	r := mul64(a, b)
132	if r.lo != 0x19 || r.hi != 0 {
133		t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
134	}
135
136	a = uint64(18014398509481983) // 2^54 - 1
137	b = uint64(18014398509481983) // 2^54 - 1
138	r = mul64(a, b)
139	if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
140		t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
141	}
142
143	a = uint64(1125899906842661)
144	b = uint64(2097155)
145	r = mul64(a, b)
146	r = addMul64(r, a, b)
147	r = addMul64(r, a, b)
148	r = addMul64(r, a, b)
149	r = addMul64(r, a, b)
150	if r.lo != 16888498990613035 || r.hi != 640 {
151		t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
152	}
153}
154
155func TestSetBytesRoundTrip(t *testing.T) {
156	f1 := func(in [32]byte, fe Element) bool {
157		fe.SetBytes(in[:])
158
159		// Mask the most significant bit as it's ignored by SetBytes. (Now
160		// instead of earlier so we check the masking in SetBytes is working.)
161		in[len(in)-1] &= (1 << 7) - 1
162
163		return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
164	}
165	if err := quick.Check(f1, nil); err != nil {
166		t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
167	}
168
169	f2 := func(fe, r Element) bool {
170		r.SetBytes(fe.Bytes())
171
172		// Intentionally not using Equal not to go through Bytes again.
173		// Calling reduce because both Generate and SetBytes can produce
174		// non-canonical representations.
175		fe.reduce()
176		r.reduce()
177		return fe == r
178	}
179	if err := quick.Check(f2, nil); err != nil {
180		t.Errorf("failed FE->bytes->FE round-trip: %v", err)
181	}
182
183	// Check some fixed vectors from dalek
184	type feRTTest struct {
185		fe Element
186		b  []byte
187	}
188	var tests = []feRTTest{
189		{
190			fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
191			b:  []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
192		},
193		{
194			fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
195			b:  []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
196		},
197	}
198
199	for _, tt := range tests {
200		b := tt.fe.Bytes()
201		fe, _ := new(Element).SetBytes(tt.b)
202		if !bytes.Equal(b, tt.b) || fe.Equal(&tt.fe) != 1 {
203			t.Errorf("Failed fixed roundtrip: %v", tt)
204		}
205	}
206}
207
// swapEndianness reverses the order of the bytes of buf in place and returns
// buf for convenience.
func swapEndianness(buf []byte) []byte {
	for lo, hi := 0, len(buf)-1; lo < hi; lo, hi = lo+1, hi-1 {
		buf[lo], buf[hi] = buf[hi], buf[lo]
	}
	return buf
}
214
215func TestBytesBigEquivalence(t *testing.T) {
216	f1 := func(in [32]byte, fe, fe1 Element) bool {
217		fe.SetBytes(in[:])
218
219		in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
220		b := new(big.Int).SetBytes(swapEndianness(in[:]))
221		fe1.fromBig(b)
222
223		if fe != fe1 {
224			return false
225		}
226
227		buf := make([]byte, 32)
228		buf = swapEndianness(fe1.toBig().FillBytes(buf))
229
230		return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
231	}
232	if err := quick.Check(f1, nil); err != nil {
233		t.Error(err)
234	}
235}
236
237// fromBig sets v = n, and returns v. The bit length of n must not exceed 256.
238func (v *Element) fromBig(n *big.Int) *Element {
239	if n.BitLen() > 32*8 {
240		panic("edwards25519: invalid field element input size")
241	}
242
243	buf := make([]byte, 0, 32)
244	for _, word := range n.Bits() {
245		for i := 0; i < bits.UintSize; i += 8 {
246			if len(buf) >= cap(buf) {
247				break
248			}
249			buf = append(buf, byte(word))
250			word >>= 8
251		}
252	}
253
254	v.SetBytes(buf[:32])
255	return v
256}
257
258func (v *Element) fromDecimal(s string) *Element {
259	n, ok := new(big.Int).SetString(s, 10)
260	if !ok {
261		panic("not a valid decimal: " + s)
262	}
263	return v.fromBig(n)
264}
265
266// toBig returns v as a big.Int.
267func (v *Element) toBig() *big.Int {
268	buf := v.Bytes()
269
270	words := make([]big.Word, 32*8/bits.UintSize)
271	for n := range words {
272		for i := 0; i < bits.UintSize; i += 8 {
273			if len(buf) == 0 {
274				break
275			}
276			words[n] |= big.Word(buf[0]) << big.Word(i)
277			buf = buf[1:]
278		}
279	}
280
281	return new(big.Int).SetBits(words)
282}
283
284func TestDecimalConstants(t *testing.T) {
285	sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
286	if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
287		t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
288	}
289	// d is in the parent package, and we don't want to expose d or fromDecimal.
290	// dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
291	// if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 {
292	// 	t.Errorf("d is %v, expected %v", d, exp)
293	// }
294}
295
// TestSetBytesRoundTripEdgeCases is a placeholder for edge-case coverage of
// SetBytes and Bytes that has not been written yet. It is reported as skipped
// instead of silently passing, so the gap is visible in verbose test output.
func TestSetBytesRoundTripEdgeCases(t *testing.T) {
	// TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1,
	// and between 2^255 and 2^256-1. Test both the documented SetBytes
	// behavior, and that Bytes reduces them.
	t.Skip("TODO: SetBytes/Bytes edge-case round-trip tests are not implemented")
}
301
302// Tests self-consistency between Multiply and Square.
303func TestConsistency(t *testing.T) {
304	var x Element
305	var x2, x2sq Element
306
307	x = Element{1, 1, 1, 1, 1}
308	x2.Multiply(&x, &x)
309	x2sq.Square(&x)
310
311	if x2 != x2sq {
312		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
313	}
314
315	var bytes [32]byte
316
317	_, err := io.ReadFull(rand.Reader, bytes[:])
318	if err != nil {
319		t.Fatal(err)
320	}
321	x.SetBytes(bytes[:])
322
323	x2.Multiply(&x, &x)
324	x2sq.Square(&x)
325
326	if x2 != x2sq {
327		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
328	}
329}
330
331func TestEqual(t *testing.T) {
332	x := Element{1, 1, 1, 1, 1}
333	y := Element{5, 4, 3, 2, 1}
334
335	eq := x.Equal(&x)
336	if eq != 1 {
337		t.Errorf("wrong about equality")
338	}
339
340	eq = x.Equal(&y)
341	if eq != 0 {
342		t.Errorf("wrong about inequality")
343	}
344}
345
346func TestInvert(t *testing.T) {
347	x := Element{1, 1, 1, 1, 1}
348	one := Element{1, 0, 0, 0, 0}
349	var xinv, r Element
350
351	xinv.Invert(&x)
352	r.Multiply(&x, &xinv)
353	r.reduce()
354
355	if one != r {
356		t.Errorf("inversion identity failed, got: %x", r)
357	}
358
359	var bytes [32]byte
360
361	_, err := io.ReadFull(rand.Reader, bytes[:])
362	if err != nil {
363		t.Fatal(err)
364	}
365	x.SetBytes(bytes[:])
366
367	xinv.Invert(&x)
368	r.Multiply(&x, &xinv)
369	r.reduce()
370
371	if one != r {
372		t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
373	}
374
375	zero := Element{}
376	x.Set(&zero)
377	if xx := xinv.Invert(&x); xx != &xinv {
378		t.Errorf("inverting zero did not return the receiver")
379	} else if xinv.Equal(&zero) != 1 {
380		t.Errorf("inverting zero did not return zero")
381	}
382}
383
384func TestSelectSwap(t *testing.T) {
385	a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
386	b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
387
388	var c, d Element
389
390	c.Select(&a, &b, 1)
391	d.Select(&a, &b, 0)
392
393	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
394		t.Errorf("Select failed")
395	}
396
397	c.Swap(&d, 0)
398
399	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
400		t.Errorf("Swap failed")
401	}
402
403	c.Swap(&d, 1)
404
405	if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
406		t.Errorf("Swap failed")
407	}
408}
409
410func TestMult32(t *testing.T) {
411	mult32EquivalentToMul := func(x Element, y uint32) bool {
412		t1 := new(Element)
413		for i := 0; i < 100; i++ {
414			t1.Mult32(&x, y)
415		}
416
417		ty := new(Element)
418		ty.l0 = uint64(y)
419
420		t2 := new(Element)
421		for i := 0; i < 100; i++ {
422			t2.Multiply(&x, ty)
423		}
424
425		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
426	}
427
428	if err := quick.Check(mult32EquivalentToMul, quickCheckConfig(1024)); err != nil {
429		t.Error(err)
430	}
431}
432
433func TestSqrtRatio(t *testing.T) {
434	// From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4.
435	type test struct {
436		u, v      string
437		wasSquare int
438		r         string
439	}
440	var tests = []test{
441		// If u is 0, the function is defined to return (0, TRUE), even if v
442		// is zero. Note that where used in this package, the denominator v
443		// is never zero.
444		{
445			"0000000000000000000000000000000000000000000000000000000000000000",
446			"0000000000000000000000000000000000000000000000000000000000000000",
447			1, "0000000000000000000000000000000000000000000000000000000000000000",
448		},
449		// 0/1 == 0²
450		{
451			"0000000000000000000000000000000000000000000000000000000000000000",
452			"0100000000000000000000000000000000000000000000000000000000000000",
453			1, "0000000000000000000000000000000000000000000000000000000000000000",
454		},
455		// If u is non-zero and v is zero, defined to return (0, FALSE).
456		{
457			"0100000000000000000000000000000000000000000000000000000000000000",
458			"0000000000000000000000000000000000000000000000000000000000000000",
459			0, "0000000000000000000000000000000000000000000000000000000000000000",
460		},
461		// 2/1 is not square in this field.
462		{
463			"0200000000000000000000000000000000000000000000000000000000000000",
464			"0100000000000000000000000000000000000000000000000000000000000000",
465			0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54",
466		},
467		// 4/1 == 2²
468		{
469			"0400000000000000000000000000000000000000000000000000000000000000",
470			"0100000000000000000000000000000000000000000000000000000000000000",
471			1, "0200000000000000000000000000000000000000000000000000000000000000",
472		},
473		// 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem
474		{
475			"0100000000000000000000000000000000000000000000000000000000000000",
476			"0400000000000000000000000000000000000000000000000000000000000000",
477			1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f",
478		},
479	}
480
481	for i, tt := range tests {
482		u, _ := new(Element).SetBytes(decodeHex(tt.u))
483		v, _ := new(Element).SetBytes(decodeHex(tt.v))
484		want, _ := new(Element).SetBytes(decodeHex(tt.r))
485		got, wasSquare := new(Element).SqrtRatio(u, v)
486		if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
487			t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
488		}
489	}
490}
491
492func TestCarryPropagate(t *testing.T) {
493	asmLikeGeneric := func(a [5]uint64) bool {
494		t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
495		t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
496
497		t1.carryPropagate()
498		t2.carryPropagateGeneric()
499
500		if *t1 != *t2 {
501			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
502		}
503
504		return *t1 == *t2 && isInBounds(t2)
505	}
506
507	if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
508		t.Error(err)
509	}
510
511	if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
512		t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
513	}
514}
515
516func TestFeSquare(t *testing.T) {
517	asmLikeGeneric := func(a Element) bool {
518		t1 := a
519		t2 := a
520
521		feSquareGeneric(&t1, &t1)
522		feSquare(&t2, &t2)
523
524		if t1 != t2 {
525			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
526		}
527
528		return t1 == t2 && isInBounds(&t2)
529	}
530
531	if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
532		t.Error(err)
533	}
534}
535
536func TestFeMul(t *testing.T) {
537	asmLikeGeneric := func(a, b Element) bool {
538		a1 := a
539		a2 := a
540		b1 := b
541		b2 := b
542
543		feMulGeneric(&a1, &a1, &b1)
544		feMul(&a2, &a2, &b2)
545
546		if a1 != a2 || b1 != b2 {
547			t.Logf("got: %#v,\nexpected: %#v", a1, a2)
548			t.Logf("got: %#v,\nexpected: %#v", b1, b2)
549		}
550
551		return a1 == a2 && isInBounds(&a2) &&
552			b1 == b2 && isInBounds(&b2)
553	}
554
555	if err := quick.Check(asmLikeGeneric, quickCheckConfig(1024)); err != nil {
556		t.Error(err)
557	}
558}
559
// decodeHex returns the bytes encoded by the hexadecimal string s. It panics
// if s is not valid hexadecimal, which suits fixed test vectors.
func decodeHex(s string) []byte {
	out := make([]byte, hex.DecodedLen(len(s)))
	if _, err := hex.Decode(out, []byte(s)); err != nil {
		panic(err)
	}
	return out
}
567