• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package nistec_test
6
7import (
8	"bytes"
9	"crypto/elliptic"
10	"crypto/internal/nistec"
11	"fmt"
12	"internal/testenv"
13	"math/big"
14	"math/rand"
15	"testing"
16)
17
18func TestAllocations(t *testing.T) {
19	testenv.SkipIfOptimizationOff(t)
20
21	t.Run("P224", func(t *testing.T) {
22		if allocs := testing.AllocsPerRun(10, func() {
23			p := nistec.NewP224Point().SetGenerator()
24			scalar := make([]byte, 28)
25			rand.Read(scalar)
26			p.ScalarBaseMult(scalar)
27			p.ScalarMult(p, scalar)
28			out := p.Bytes()
29			if _, err := nistec.NewP224Point().SetBytes(out); err != nil {
30				t.Fatal(err)
31			}
32			out = p.BytesCompressed()
33			if _, err := p.SetBytes(out); err != nil {
34				t.Fatal(err)
35			}
36		}); allocs > 0 {
37			t.Errorf("expected zero allocations, got %0.1f", allocs)
38		}
39	})
40	t.Run("P256", func(t *testing.T) {
41		if allocs := testing.AllocsPerRun(10, func() {
42			p := nistec.NewP256Point().SetGenerator()
43			scalar := make([]byte, 32)
44			rand.Read(scalar)
45			p.ScalarBaseMult(scalar)
46			p.ScalarMult(p, scalar)
47			out := p.Bytes()
48			if _, err := nistec.NewP256Point().SetBytes(out); err != nil {
49				t.Fatal(err)
50			}
51			out = p.BytesCompressed()
52			if _, err := p.SetBytes(out); err != nil {
53				t.Fatal(err)
54			}
55		}); allocs > 0 {
56			t.Errorf("expected zero allocations, got %0.1f", allocs)
57		}
58	})
59	t.Run("P384", func(t *testing.T) {
60		if allocs := testing.AllocsPerRun(10, func() {
61			p := nistec.NewP384Point().SetGenerator()
62			scalar := make([]byte, 48)
63			rand.Read(scalar)
64			p.ScalarBaseMult(scalar)
65			p.ScalarMult(p, scalar)
66			out := p.Bytes()
67			if _, err := nistec.NewP384Point().SetBytes(out); err != nil {
68				t.Fatal(err)
69			}
70			out = p.BytesCompressed()
71			if _, err := p.SetBytes(out); err != nil {
72				t.Fatal(err)
73			}
74		}); allocs > 0 {
75			t.Errorf("expected zero allocations, got %0.1f", allocs)
76		}
77	})
78	t.Run("P521", func(t *testing.T) {
79		if allocs := testing.AllocsPerRun(10, func() {
80			p := nistec.NewP521Point().SetGenerator()
81			scalar := make([]byte, 66)
82			rand.Read(scalar)
83			p.ScalarBaseMult(scalar)
84			p.ScalarMult(p, scalar)
85			out := p.Bytes()
86			if _, err := nistec.NewP521Point().SetBytes(out); err != nil {
87				t.Fatal(err)
88			}
89			out = p.BytesCompressed()
90			if _, err := p.SetBytes(out); err != nil {
91				t.Fatal(err)
92			}
93		}); allocs > 0 {
94			t.Errorf("expected zero allocations, got %0.1f", allocs)
95		}
96	})
97}
98
// nistPoint is the method set of the nistec point types that the generic
// helpers in this file rely on. T is the point type itself (for example the
// type returned by nistec.NewP256Point), so every method takes and returns
// values of that same type.
type nistPoint[T any] interface {
	// Encoding and decoding.
	Bytes() []byte
	SetBytes([]byte) (T, error)

	// Setting to the generator and group operations.
	SetGenerator() T
	Add(T, T) T
	Double(T) T

	// Scalar multiplication by a big-endian byte-string scalar.
	ScalarMult(T, []byte) (T, error)
	ScalarBaseMult([]byte) (T, error)
}
108
109func TestEquivalents(t *testing.T) {
110	t.Run("P224", func(t *testing.T) {
111		testEquivalents(t, nistec.NewP224Point, elliptic.P224())
112	})
113	t.Run("P256", func(t *testing.T) {
114		testEquivalents(t, nistec.NewP256Point, elliptic.P256())
115	})
116	t.Run("P384", func(t *testing.T) {
117		testEquivalents(t, nistec.NewP384Point, elliptic.P384())
118	})
119	t.Run("P521", func(t *testing.T) {
120		testEquivalents(t, nistec.NewP521Point, elliptic.P521())
121	})
122}
123
124func testEquivalents[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
125	p := newPoint().SetGenerator()
126
127	elementSize := (c.Params().BitSize + 7) / 8
128	two := make([]byte, elementSize)
129	two[len(two)-1] = 2
130	nPlusTwo := make([]byte, elementSize)
131	new(big.Int).Add(c.Params().N, big.NewInt(2)).FillBytes(nPlusTwo)
132
133	p1 := newPoint().Double(p)
134	p2 := newPoint().Add(p, p)
135	p3, err := newPoint().ScalarMult(p, two)
136	fatalIfErr(t, err)
137	p4, err := newPoint().ScalarBaseMult(two)
138	fatalIfErr(t, err)
139	p5, err := newPoint().ScalarMult(p, nPlusTwo)
140	fatalIfErr(t, err)
141	p6, err := newPoint().ScalarBaseMult(nPlusTwo)
142	fatalIfErr(t, err)
143
144	if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
145		t.Error("P+P != 2*P")
146	}
147	if !bytes.Equal(p1.Bytes(), p3.Bytes()) {
148		t.Error("P+P != [2]P")
149	}
150	if !bytes.Equal(p1.Bytes(), p4.Bytes()) {
151		t.Error("G+G != [2]G")
152	}
153	if !bytes.Equal(p1.Bytes(), p5.Bytes()) {
154		t.Error("P+P != [N+2]P")
155	}
156	if !bytes.Equal(p1.Bytes(), p6.Bytes()) {
157		t.Error("G+G != [N+2]G")
158	}
159}
160
161func TestScalarMult(t *testing.T) {
162	t.Run("P224", func(t *testing.T) {
163		testScalarMult(t, nistec.NewP224Point, elliptic.P224())
164	})
165	t.Run("P256", func(t *testing.T) {
166		testScalarMult(t, nistec.NewP256Point, elliptic.P256())
167	})
168	t.Run("P384", func(t *testing.T) {
169		testScalarMult(t, nistec.NewP384Point, elliptic.P384())
170	})
171	t.Run("P521", func(t *testing.T) {
172		testScalarMult(t, nistec.NewP521Point, elliptic.P521())
173	})
174}
175
176func testScalarMult[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
177	G := newPoint().SetGenerator()
178	checkScalar := func(t *testing.T, scalar []byte) {
179		p1, err := newPoint().ScalarBaseMult(scalar)
180		fatalIfErr(t, err)
181		p2, err := newPoint().ScalarMult(G, scalar)
182		fatalIfErr(t, err)
183		if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
184			t.Error("[k]G != ScalarBaseMult(k)")
185		}
186
187		expectInfinity := new(big.Int).Mod(new(big.Int).SetBytes(scalar), c.Params().N).Sign() == 0
188		if expectInfinity {
189			if !bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
190				t.Error("ScalarBaseMult(k) != ∞")
191			}
192			if !bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
193				t.Error("[k]G != ∞")
194			}
195		} else {
196			if bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
197				t.Error("ScalarBaseMult(k) == ∞")
198			}
199			if bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
200				t.Error("[k]G == ∞")
201			}
202		}
203
204		d := new(big.Int).SetBytes(scalar)
205		d.Sub(c.Params().N, d)
206		d.Mod(d, c.Params().N)
207		g1, err := newPoint().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
208		fatalIfErr(t, err)
209		g1.Add(g1, p1)
210		if !bytes.Equal(g1.Bytes(), newPoint().Bytes()) {
211			t.Error("[N - k]G + [k]G != ∞")
212		}
213	}
214
215	byteLen := len(c.Params().N.Bytes())
216	bitLen := c.Params().N.BitLen()
217	t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
218	t.Run("1", func(t *testing.T) {
219		checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
220	})
221	t.Run("N-1", func(t *testing.T) {
222		checkScalar(t, new(big.Int).Sub(c.Params().N, big.NewInt(1)).Bytes())
223	})
224	t.Run("N", func(t *testing.T) { checkScalar(t, c.Params().N.Bytes()) })
225	t.Run("N+1", func(t *testing.T) {
226		checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(1)).Bytes())
227	})
228	t.Run("all1s", func(t *testing.T) {
229		s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
230		s.Sub(s, big.NewInt(1))
231		checkScalar(t, s.Bytes())
232	})
233	if testing.Short() {
234		return
235	}
236	for i := 0; i < bitLen; i++ {
237		t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
238			s := new(big.Int).Lsh(big.NewInt(1), uint(i))
239			checkScalar(t, s.FillBytes(make([]byte, byteLen)))
240		})
241	}
242	for i := 0; i <= 64; i++ {
243		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
244			checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
245		})
246	}
247	// Test N-64...N+64 since they risk overlapping with precomputed table values
248	// in the final additions.
249	for i := int64(-64); i <= 64; i++ {
250		t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
251			checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(i)).Bytes())
252		})
253	}
254}
255
// fatalIfErr stops the current test with t.Fatal when err is non-nil and
// does nothing otherwise. It is marked as a helper, so the failure is
// attributed to the caller's line.
func fatalIfErr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}
262
263func BenchmarkScalarMult(b *testing.B) {
264	b.Run("P224", func(b *testing.B) {
265		benchmarkScalarMult(b, nistec.NewP224Point().SetGenerator(), 28)
266	})
267	b.Run("P256", func(b *testing.B) {
268		benchmarkScalarMult(b, nistec.NewP256Point().SetGenerator(), 32)
269	})
270	b.Run("P384", func(b *testing.B) {
271		benchmarkScalarMult(b, nistec.NewP384Point().SetGenerator(), 48)
272	})
273	b.Run("P521", func(b *testing.B) {
274		benchmarkScalarMult(b, nistec.NewP521Point().SetGenerator(), 66)
275	})
276}
277
278func benchmarkScalarMult[P nistPoint[P]](b *testing.B, p P, scalarSize int) {
279	scalar := make([]byte, scalarSize)
280	rand.Read(scalar)
281	b.ReportAllocs()
282	b.ResetTimer()
283	for i := 0; i < b.N; i++ {
284		p.ScalarMult(p, scalar)
285	}
286}
287
288func BenchmarkScalarBaseMult(b *testing.B) {
289	b.Run("P224", func(b *testing.B) {
290		benchmarkScalarBaseMult(b, nistec.NewP224Point().SetGenerator(), 28)
291	})
292	b.Run("P256", func(b *testing.B) {
293		benchmarkScalarBaseMult(b, nistec.NewP256Point().SetGenerator(), 32)
294	})
295	b.Run("P384", func(b *testing.B) {
296		benchmarkScalarBaseMult(b, nistec.NewP384Point().SetGenerator(), 48)
297	})
298	b.Run("P521", func(b *testing.B) {
299		benchmarkScalarBaseMult(b, nistec.NewP521Point().SetGenerator(), 66)
300	})
301}
302
303func benchmarkScalarBaseMult[P nistPoint[P]](b *testing.B, p P, scalarSize int) {
304	scalar := make([]byte, scalarSize)
305	rand.Read(scalar)
306	b.ReportAllocs()
307	b.ResetTimer()
308	for i := 0; i < b.N; i++ {
309		p.ScalarBaseMult(scalar)
310	}
311}
312