• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ecdh_test
6
7import (
8	"bytes"
9	"crypto"
10	"crypto/cipher"
11	"crypto/ecdh"
12	"crypto/rand"
13	"crypto/sha256"
14	"encoding/hex"
15	"fmt"
16	"internal/testenv"
17	"io"
18	"os"
19	"os/exec"
20	"path/filepath"
21	"regexp"
22	"strings"
23	"testing"
24
25	"golang.org/x/crypto/chacha20"
26)
27
// Compile-time assertions that *ecdh.PublicKey and *ecdh.PrivateKey provide
// the methods described by the interfaces documented in crypto.PublicKey and
// crypto.PrivateKey. A typed nil pointer is enough; nothing is allocated.
var (
	_ interface {
		Equal(x crypto.PublicKey) bool
	} = (*ecdh.PublicKey)(nil)

	_ interface {
		Public() crypto.PublicKey
		Equal(x crypto.PrivateKey) bool
	} = (*ecdh.PrivateKey)(nil)
)
37
38func TestECDH(t *testing.T) {
39	testAllCurves(t, func(t *testing.T, curve ecdh.Curve) {
40		aliceKey, err := curve.GenerateKey(rand.Reader)
41		if err != nil {
42			t.Fatal(err)
43		}
44		bobKey, err := curve.GenerateKey(rand.Reader)
45		if err != nil {
46			t.Fatal(err)
47		}
48
49		alicePubKey, err := curve.NewPublicKey(aliceKey.PublicKey().Bytes())
50		if err != nil {
51			t.Error(err)
52		}
53		if !bytes.Equal(aliceKey.PublicKey().Bytes(), alicePubKey.Bytes()) {
54			t.Error("encoded and decoded public keys are different")
55		}
56		if !aliceKey.PublicKey().Equal(alicePubKey) {
57			t.Error("encoded and decoded public keys are different")
58		}
59
60		alicePrivKey, err := curve.NewPrivateKey(aliceKey.Bytes())
61		if err != nil {
62			t.Error(err)
63		}
64		if !bytes.Equal(aliceKey.Bytes(), alicePrivKey.Bytes()) {
65			t.Error("encoded and decoded private keys are different")
66		}
67		if !aliceKey.Equal(alicePrivKey) {
68			t.Error("encoded and decoded private keys are different")
69		}
70
71		bobSecret, err := bobKey.ECDH(aliceKey.PublicKey())
72		if err != nil {
73			t.Fatal(err)
74		}
75		aliceSecret, err := aliceKey.ECDH(bobKey.PublicKey())
76		if err != nil {
77			t.Fatal(err)
78		}
79
80		if !bytes.Equal(bobSecret, aliceSecret) {
81			t.Error("two ECDH computations came out different")
82		}
83	})
84}
85
// countingReader wraps an io.Reader and keeps a running total of the bytes
// read through it.
type countingReader struct {
	r io.Reader // underlying source
	n int       // total bytes returned so far
}

// Read forwards to the wrapped reader and adds the byte count it reports to
// the running total, whether or not an error was also returned.
func (cr *countingReader) Read(p []byte) (n int, err error) {
	n, err = cr.r.Read(p)
	cr.n += n
	return n, err
}
96
97func TestGenerateKey(t *testing.T) {
98	testAllCurves(t, func(t *testing.T, curve ecdh.Curve) {
99		r := &countingReader{r: rand.Reader}
100		k, err := curve.GenerateKey(r)
101		if err != nil {
102			t.Fatal(err)
103		}
104
105		// GenerateKey does rejection sampling. If the masking works correctly,
106		// the probability of a rejection is 1-ord(G)/2^ceil(log2(ord(G))),
107		// which for all curves is small enough (at most 2^-32, for P-256) that
108		// a bit flip is more likely to make this test fail than bad luck.
109		// Account for the extra MaybeReadByte byte, too.
110		if got, expected := r.n, len(k.Bytes())+1; got > expected {
111			t.Errorf("expected GenerateKey to consume at most %v bytes, got %v", expected, got)
112		}
113	})
114}
115
// vectors holds one known-answer test case per curve. Each case has a
// hex-encoded private key, the public key derived from it, a peer public key,
// and the expected ECDH shared secret.
var vectors = map[ecdh.Curve]struct {
	PrivateKey    string
	PublicKey     string
	PeerPublicKey string
	SharedSecret  string
}{
	// P-256: NIST CAVS 14.1, ECC CDH Primitive (SP800-56A).
	ecdh.P256(): {
		PrivateKey: "7d7dc5f71eb29ddaf80d6214632eeae03d9058af1fb6d22ed80badb62bc1a534",
		PublicKey: "04ead218590119e8876b29146ff89ca61770c4edbbf97d38ce385ed281d8a6b230" +
			"28af61281fd35e2fa7002523acc85a429cb06ee6648325389f59edfce1405141",
		PeerPublicKey: "04700c48f77f56584c5cc632ca65640db91b6bacce3a4df6b42ce7cc838833d287" +
			"db71e509e3fd9b060ddb20ba5c51dcc5948d46fbf640dfe0441782cab85fa4ac",
		SharedSecret: "46fc62106420ff012e54a434fbdd2d25ccc5852060561e68040dd7778997bd7b",
	},

	// P-384: NIST CAVS 14.1, ECC CDH Primitive (SP800-56A).
	ecdh.P384(): {
		PrivateKey: "3cc3122a68f0d95027ad38c067916ba0eb8c38894d22e1b15618b6818a661774ad463b205da88cf699ab4d43c9cf98a1",
		PublicKey: "049803807f2f6d2fd966cdd0290bd410c0190352fbec7ff6247de1302df86f25d34fe4a97bef60cff548355c015dbb3e5f" +
			"ba26ca69ec2f5b5d9dad20cc9da711383a9dbe34ea3fa5a2af75b46502629ad54dd8b7d73a8abb06a3a3be47d650cc99",
		PeerPublicKey: "04a7c76b970c3b5fe8b05d2838ae04ab47697b9eaf52e764592efda27fe7513272734466b400091adbf2d68c58e0c50066" +
			"ac68f19f2e1cb879aed43a9969b91a0839c4c38a49749b661efedf243451915ed0905a32b060992b468c64766fc8437a",
		SharedSecret: "5f9d29dc5e31a163060356213669c8ce132e22f57c9a04f40ba7fcead493b457e5621e766c40a2e3d4d6a04b25e533f1",
	},

	// P-521: NIST CAVS 14.1. In the source vector every field element (both
	// scalars and base field elements), but not the shared secret, carries
	// two extra leading zero bytes. Those bytes are irrelevant in big-endian
	// form and have been stripped here.
	ecdh.P521(): {
		PrivateKey: "017eecc07ab4b329068fba65e56a1f8890aa935e57134ae0ffcce802735151f4eac6564f6ee9974c5e6887a1fefee5743ae2241bfeb95d5ce31ddcb6f9edb4d6fc47",
		PublicKey: "0400602f9d0cf9e526b29e22381c203c48a886c2b0673033366314f1ffbcba240ba42f4ef38a76174635f91e6b4ed34275eb01c8467d05ca80315bf1a7bbd945f550a5" +
			"01b7c85f26f5d4b2d7355cf6b02117659943762b6d1db5ab4f1dbc44ce7b2946eb6c7de342962893fd387d1b73d7a8672d1f236961170b7eb3579953ee5cdc88cd2d",
		PeerPublicKey: "0400685a48e86c79f0f0875f7bc18d25eb5fc8c0b07e5da4f4370f3a9490340854334b1e1b87fa395464c60626124a4e70d0f785601d37c09870ebf176666877a2046d" +
			"01ba52c56fc8776d9e8f5db4f0cc27636d0b741bbe05400697942e80b739884a83bde99e0f6716939e632bc8986fa18dccd443a348b6c3e522497955a4f3c302f676",
		SharedSecret: "005fc70477c3e63bc3954bd0df3ea0d1f41ee21746ed95fc5e1fdf90930d5e136672d72cc770742d1711c3c3a4c334a0ad9759436a4d3c5bf6e74b9578fac148c831",
	},

	// X25519: RFC 7748, Section 6.1.
	ecdh.X25519(): {
		PrivateKey:    "77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a",
		PublicKey:     "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a",
		PeerPublicKey: "de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f",
		SharedSecret:  "4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742",
	},
}
157
158func TestVectors(t *testing.T) {
159	testAllCurves(t, func(t *testing.T, curve ecdh.Curve) {
160		v := vectors[curve]
161		key, err := curve.NewPrivateKey(hexDecode(t, v.PrivateKey))
162		if err != nil {
163			t.Fatal(err)
164		}
165		if !bytes.Equal(key.PublicKey().Bytes(), hexDecode(t, v.PublicKey)) {
166			t.Error("public key derived from the private key does not match")
167		}
168		peer, err := curve.NewPublicKey(hexDecode(t, v.PeerPublicKey))
169		if err != nil {
170			t.Fatal(err)
171		}
172		secret, err := key.ECDH(peer)
173		if err != nil {
174			t.Fatal(err)
175		}
176		if !bytes.Equal(secret, hexDecode(t, v.SharedSecret)) {
177			t.Errorf("shared secret does not match: %x %x %s %x", secret, sha256.Sum256(secret), v.SharedSecret,
178				sha256.Sum256(hexDecode(t, v.SharedSecret)))
179		}
180	})
181}
182
// hexDecode returns the bytes encoded by the hex string s and stops the test
// with t.Fatal if s is not valid hex. It is marked as a test helper, so a
// failure is reported at the caller's line rather than inside this function.
func hexDecode(t *testing.T, s string) []byte {
	t.Helper()
	b, err := hex.DecodeString(s)
	if err != nil {
		t.Fatal("invalid hex string:", s)
	}
	return b
}
190
191func TestString(t *testing.T) {
192	testAllCurves(t, func(t *testing.T, curve ecdh.Curve) {
193		s := fmt.Sprintf("%s", curve)
194		if s[:1] != "P" && s[:1] != "X" {
195			t.Errorf("unexpected Curve string encoding: %q", s)
196		}
197	})
198}
199
200func TestX25519Failure(t *testing.T) {
201	identity := hexDecode(t, "0000000000000000000000000000000000000000000000000000000000000000")
202	lowOrderPoint := hexDecode(t, "e0eb7a7c3b41b8ae1656e3faf19fc46ada098deb9c32b1fd866205165f49b800")
203	randomScalar := make([]byte, 32)
204	rand.Read(randomScalar)
205
206	t.Run("identity point", func(t *testing.T) { testX25519Failure(t, randomScalar, identity) })
207	t.Run("low order point", func(t *testing.T) { testX25519Failure(t, randomScalar, lowOrderPoint) })
208}
209
// testX25519Failure parses the raw X25519 private scalar and peer public key,
// then requires ECDH between them to return an error and a nil secret.
func testX25519Failure(t *testing.T, private, public []byte) {
	x25519 := ecdh.X25519()

	priv, err := x25519.NewPrivateKey(private)
	if err != nil {
		t.Fatal(err)
	}
	pub, err := x25519.NewPublicKey(public)
	if err != nil {
		t.Fatal(err)
	}

	out, ecdhErr := priv.ECDH(pub)
	if ecdhErr == nil {
		t.Error("expected ECDH error")
	}
	if out != nil {
		t.Errorf("unexpected ECDH output: %x", out)
	}
}
227
// invalidPrivateKeys lists, per curve, hex-encoded private keys that
// NewPrivateKey is expected to reject. For the NIST curves these are
// encodings of the wrong length, the zero scalar, and scalars at or above the
// group order. For X25519 only wrong lengths are listed.
var invalidPrivateKeys = map[ecdh.Curve][]string{
	ecdh.P256(): {
		// Wrong lengths: empty, one byte, one byte short, one byte long, far too long.
		"",
		"01",
		"01010101010101010101010101010101010101010101010101010101010101",
		"000101010101010101010101010101010101010101010101010101010101010101",
		strings.Repeat("01", 200),

		// The zero scalar.
		"0000000000000000000000000000000000000000000000000000000000000000",

		// The group order, order+1, and all-ones.
		"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551",
		"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632552",
		"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
	},
	ecdh.P384(): {
		// Wrong lengths.
		"",
		"01",
		"0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101",
		"00010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101",
		strings.Repeat("01", 200),

		// The zero scalar.
		"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",

		// The group order, order+1, and all-ones.
		"ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52973",
		"ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52974",
		"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
	},
	ecdh.P521(): {
		// Wrong lengths.
		"",
		"01",
		"0101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101",
		"00010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101",
		strings.Repeat("01", 200),

		// The zero scalar.
		"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",

		// The group order and values above it.
		"01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409",
		"01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e9138640a",
		"11fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409",
		"03fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff4a30d0f077e5f2cd6ff980291ee134ba0776b937113388f5d76df6e3d2270c812",
	},
	ecdh.X25519(): {
		// X25519 only rejects wrong lengths.
		"",
		"01",
		"01010101010101010101010101010101010101010101010101010101010101",
		"000101010101010101010101010101010101010101010101010101010101010101",
		strings.Repeat("01", 200),
	},
}
281
282func TestNewPrivateKey(t *testing.T) {
283	testAllCurves(t, func(t *testing.T, curve ecdh.Curve) {
284		for _, input := range invalidPrivateKeys[curve] {
285			k, err := curve.NewPrivateKey(hexDecode(t, input))
286			if err == nil {
287				t.Errorf("unexpectedly accepted %q", input)
288			} else if k != nil {
289				t.Error("PrivateKey was not nil on error")
290			} else if strings.Contains(err.Error(), "boringcrypto") {
291				t.Errorf("boringcrypto error leaked out: %v", err)
292			}
293		}
294	})
295}
296
// invalidPublicKeys lists, per curve, hex-encoded public keys that
// NewPublicKey is expected to reject. For the NIST curves these are
// encodings of the wrong length, the point at infinity, compressed
// encodings, and uncompressed points that are not on the curve. No X25519
// entries are listed.
var invalidPublicKeys = map[ecdh.Curve][]string{
	ecdh.P256(): {
		// Wrong lengths.
		"",
		"04",
		strings.Repeat("04", 200),

		// The point at infinity.
		"00",

		// Compressed encodings.
		"036b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296",
		"02e2534a3532d08fbba02dde659ee62bd0031fe2db785596ef509302446b030852",

		// Uncompressed points that are not on the curve.
		"046b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c2964fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f6",
		"0400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
	},
	ecdh.P384(): {
		// Wrong lengths.
		"",
		"04",
		strings.Repeat("04", 200),

		// The point at infinity.
		"00",

		// Compressed encodings.
		"03aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab7",
		"0208d999057ba3d2d969260045c55b97f089025959a6f434d651d207d19fb96e9e4fe0e86ebe0e64f85b96a9c75295df61",

		// Uncompressed points that are not on the curve.
		"04aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab73617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e60",
		"04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
	},
	ecdh.P521(): {
		// Wrong lengths.
		"",
		"04",
		strings.Repeat("04", 200),

		// The point at infinity.
		"00",

		// Compressed encodings.
		"030035b5df64ae2ac204c354b483487c9070cdc61c891c5ff39afc06c5d55541d3ceac8659e24afe3d0750e8b88e9f078af066a1d5025b08e5a5e2fbc87412871902f3",
		"0200c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66",

		// Uncompressed points that are not on the curve.
		"0400c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16651",
		"04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
	},
	// No invalid X25519 public keys are listed.
	ecdh.X25519(): {},
}
342
343func TestNewPublicKey(t *testing.T) {
344	testAllCurves(t, func(t *testing.T, curve ecdh.Curve) {
345		for _, input := range invalidPublicKeys[curve] {
346			k, err := curve.NewPublicKey(hexDecode(t, input))
347			if err == nil {
348				t.Errorf("unexpectedly accepted %q", input)
349			} else if k != nil {
350				t.Error("PublicKey was not nil on error")
351			} else if strings.Contains(err.Error(), "boringcrypto") {
352				t.Errorf("boringcrypto error leaked out: %v", err)
353			}
354		}
355	})
356}
357
// testAllCurves runs f as a subtest once per supported curve, in the order
// P256, P384, P521, X25519. Each subtest is named after its curve.
func testAllCurves(t *testing.T, f func(t *testing.T, curve ecdh.Curve)) {
	for _, tc := range []struct {
		name  string
		curve ecdh.Curve
	}{
		{"P256", ecdh.P256()},
		{"P384", ecdh.P384()},
		{"P521", ecdh.P521()},
		{"X25519", ecdh.X25519()},
	} {
		curve := tc.curve // per-iteration copy captured by the closure
		t.Run(tc.name, func(t *testing.T) { f(t, curve) })
	}
}
364
365func BenchmarkECDH(b *testing.B) {
366	benchmarkAllCurves(b, func(b *testing.B, curve ecdh.Curve) {
367		c, err := chacha20.NewUnauthenticatedCipher(make([]byte, 32), make([]byte, 12))
368		if err != nil {
369			b.Fatal(err)
370		}
371		rand := cipher.StreamReader{
372			S: c, R: zeroReader,
373		}
374
375		peerKey, err := curve.GenerateKey(rand)
376		if err != nil {
377			b.Fatal(err)
378		}
379		peerShare := peerKey.PublicKey().Bytes()
380		b.ResetTimer()
381		b.ReportAllocs()
382
383		var allocationsSink byte
384
385		for i := 0; i < b.N; i++ {
386			key, err := curve.GenerateKey(rand)
387			if err != nil {
388				b.Fatal(err)
389			}
390			share := key.PublicKey().Bytes()
391			peerPubKey, err := curve.NewPublicKey(peerShare)
392			if err != nil {
393				b.Fatal(err)
394			}
395			secret, err := key.ECDH(peerPubKey)
396			if err != nil {
397				b.Fatal(err)
398			}
399			allocationsSink ^= secret[0] ^ share[0]
400		}
401	})
402}
403
// benchmarkAllCurves runs f as a sub-benchmark once per supported curve, in
// the order P256, P384, P521, X25519. Each sub-benchmark is named after its
// curve.
func benchmarkAllCurves(b *testing.B, f func(b *testing.B, curve ecdh.Curve)) {
	for _, bc := range []struct {
		name  string
		curve ecdh.Curve
	}{
		{"P256", ecdh.P256()},
		{"P384", ecdh.P384()},
		{"P521", ecdh.P521()},
		{"X25519", ecdh.X25519()},
	} {
		curve := bc.curve // per-iteration copy captured by the closure
		b.Run(bc.name, func(b *testing.B) { f(b, curve) })
	}
}
410
// zr is a stateless io.Reader that yields an endless stream of zero bytes.
type zr struct{}

// Read overwrites all of dst with zeros and reports len(dst) bytes read with
// no error. It keeps no state, so it is safe for concurrent use.
func (zr) Read(dst []byte) (int, error) {
	clear(dst)
	return len(dst), nil
}

// zeroReader is the shared zr value used as a source of zero bytes.
var zeroReader = zr{}
420
// linkerTestProgram is the source of a minimal main package that uses only
// ecdh.P384. It generates a key, parses the public and private key
// encodings, performs ECDH, and prints "OK". TestLinker writes it to
// hello.go, builds it, and inspects the binary's symbol table. The text must
// stay byte-exact, because TestLinker compares the program's output against
// "OK\n".
const linkerTestProgram = `
package main
import "crypto/ecdh"
import "crypto/rand"
func main() {
	curve := ecdh.P384()
	key, err := curve.GenerateKey(rand.Reader)
	if err != nil { panic(err) }
	_, err = curve.NewPublicKey(key.PublicKey().Bytes())
	if err != nil { panic(err) }
	_, err = curve.NewPrivateKey(key.Bytes())
	if err != nil { panic(err) }
	_, err = key.ECDH(key.PublicKey())
	if err != nil { panic(err) }
	println("OK")
}
`
438
439// TestLinker ensures that using one curve does not bring all other
440// implementations into the binary. This also guarantees that govulncheck can
441// avoid warning about a curve-specific vulnerability if that curve is not used.
442func TestLinker(t *testing.T) {
443	if testing.Short() {
444		t.Skip("test requires running 'go build'")
445	}
446	testenv.MustHaveGoBuild(t)
447
448	dir := t.TempDir()
449	hello := filepath.Join(dir, "hello.go")
450	err := os.WriteFile(hello, []byte(linkerTestProgram), 0664)
451	if err != nil {
452		t.Fatal(err)
453	}
454
455	run := func(args ...string) string {
456		cmd := exec.Command(args[0], args[1:]...)
457		cmd.Dir = dir
458		out, err := cmd.CombinedOutput()
459		if err != nil {
460			t.Fatalf("%v: %v\n%s", args, err, string(out))
461		}
462		return string(out)
463	}
464
465	goBin := testenv.GoToolPath(t)
466	run(goBin, "build", "-o", "hello.exe", "hello.go")
467	if out := run("./hello.exe"); out != "OK\n" {
468		t.Error("unexpected output:", out)
469	}
470
471	// List all text symbols under crypto/... and make sure there are some for
472	// P384, but none for the other curves.
473	var consistent bool
474	nm := run(goBin, "tool", "nm", "hello.exe")
475	for _, match := range regexp.MustCompile(`(?m)T (crypto/.*)$`).FindAllStringSubmatch(nm, -1) {
476		symbol := strings.ToLower(match[1])
477		if strings.Contains(symbol, "p384") {
478			consistent = true
479		}
480		if strings.Contains(symbol, "p224") || strings.Contains(symbol, "p256") || strings.Contains(symbol, "p521") {
481			t.Errorf("unexpected symbol in program using only ecdh.P384: %s", match[1])
482		}
483	}
484	if !consistent {
485		t.Error("no P384 symbols found in program using ecdh.P384, test is broken")
486	}
487}
488
// TestMismatchedCurves checks that, for every ordered pair of distinct
// curves, ECDH between a private key on the first and a public key on the
// second fails with the package's curve-mismatch error.
func TestMismatchedCurves(t *testing.T) {
	curves := []struct {
		name  string
		curve ecdh.Curve
	}{
		{"P256", ecdh.P256()},
		{"P384", ecdh.P384()},
		{"P521", ecdh.P521()},
		{"X25519", ecdh.X25519()},
	}

	for _, privCurve := range curves {
		priv, err := privCurve.curve.GenerateKey(rand.Reader)
		if err != nil {
			t.Fatalf("failed to generate test key: %s", err)
		}

		for _, pubCurve := range curves {
			if privCurve == pubCurve {
				continue
			}
			t.Run(fmt.Sprintf("%s/%s", privCurve.name, pubCurve.name), func(t *testing.T) {
				pub, err := pubCurve.curve.GenerateKey(rand.Reader)
				if err != nil {
					t.Fatalf("failed to generate test key: %s", err)
				}
				expected := "crypto/ecdh: private key and public key curves do not match"
				_, err = priv.ECDH(pub.PublicKey())
				// Check for a nil error first: calling err.Error() on a nil
				// error would panic instead of reporting the failing pair.
				if err == nil {
					t.Fatalf("expected error %q, got nil", expected)
				}
				if err.Error() != expected {
					t.Fatalf("unexpected error: want %q, got %q", expected, err)
				}
			})
		}
	}
}
524