• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package hpke
6
7import (
8	"bytes"
9	"encoding/hex"
10	"encoding/json"
11	"os"
12	"strconv"
13	"strings"
14	"testing"
15
16	"crypto/ecdh"
17	_ "crypto/sha256"
18	_ "crypto/sha512"
19)
20
// mustDecodeHex decodes the hexadecimal string in and returns the raw bytes.
// If in is not valid hex, the calling test is stopped via t.Fatal.
func mustDecodeHex(t *testing.T, in string) []byte {
	// Mark this as a helper so a decode failure is reported at the caller's
	// line (which identifies the offending vector field) rather than here.
	t.Helper()
	b, err := hex.DecodeString(in)
	if err != nil {
		t.Fatal(err)
	}
	return b
}
28
// parseVectorSetup parses the setup section of a test vector. The section is
// newline-separated "key: value" lines, and the result maps each key to its
// value; if a key repeats, the last occurrence wins.
//
// Each line is split at the first ": " only, so a value that itself contains
// ": " is kept intact. Lines without a ": " separator are ignored. This
// includes the empty line left by a trailing newline, which previously caused
// an index-out-of-range panic.
func parseVectorSetup(vector string) map[string]string {
	vals := map[string]string{}
	for _, l := range strings.Split(vector, "\n") {
		key, val, ok := strings.Cut(l, ": ")
		if !ok {
			continue
		}
		vals[key] = val
	}
	return vals
}
37
// parseVectorEncryptions parses the encryptions section of a test vector. The
// section is a series of blocks separated by a blank line ("\n\n"), and each
// block consists of newline-separated "key: value" lines. It returns one map
// per non-empty block, in input order.
//
// Each line is split at the first ": " only, so values containing ": " are
// kept intact. Lines without a ": " separator are ignored instead of
// triggering an index-out-of-range panic. Blocks that yield no entries at all
// (e.g. from trailing blank lines) are dropped, so the caller does not run a
// subtest for an empty encryption.
func parseVectorEncryptions(vector string) []map[string]string {
	sections := strings.Split(vector, "\n\n")
	vals := make([]map[string]string, 0, len(sections))
	for _, section := range sections {
		e := map[string]string{}
		for _, l := range strings.Split(section, "\n") {
			key, val, ok := strings.Cut(l, ": ")
			if !ok {
				continue
			}
			e[key] = val
		}
		if len(e) == 0 {
			continue
		}
		vals = append(vals, e)
	}
	return vals
}
50
51func TestRFC9180Vectors(t *testing.T) {
52	vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
53	if err != nil {
54		t.Fatal(err)
55	}
56
57	var vectors []struct {
58		Name        string
59		Setup       string
60		Encryptions string
61	}
62	if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
63		t.Fatal(err)
64	}
65
66	for _, vector := range vectors {
67		t.Run(vector.Name, func(t *testing.T) {
68			setup := parseVectorSetup(vector.Setup)
69
70			kemID, err := strconv.Atoi(setup["kem_id"])
71			if err != nil {
72				t.Fatal(err)
73			}
74			if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
75				t.Skip("unsupported KEM")
76			}
77			kdfID, err := strconv.Atoi(setup["kdf_id"])
78			if err != nil {
79				t.Fatal(err)
80			}
81			if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
82				t.Skip("unsupported KDF")
83			}
84			aeadID, err := strconv.Atoi(setup["aead_id"])
85			if err != nil {
86				t.Fatal(err)
87			}
88			if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
89				t.Skip("unsupported AEAD")
90			}
91
92			info := mustDecodeHex(t, setup["info"])
93			pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
94			pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
95			if err != nil {
96				t.Fatal(err)
97			}
98
99			ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
100
101			testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
102				return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
103			}
104			t.Cleanup(func() { testingOnlyGenerateKey = nil })
105
106			encap, context, err := SetupSender(
107				uint16(kemID),
108				uint16(kdfID),
109				uint16(aeadID),
110				pub,
111				info,
112			)
113			if err != nil {
114				t.Fatal(err)
115			}
116
117			expectedEncap := mustDecodeHex(t, setup["enc"])
118			if !bytes.Equal(encap, expectedEncap) {
119				t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
120			}
121			expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
122			if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
123				t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
124			}
125			expectedKey := mustDecodeHex(t, setup["key"])
126			if !bytes.Equal(context.key, expectedKey) {
127				t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
128			}
129			expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
130			if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
131				t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
132			}
133			expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
134			if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
135				t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
136			}
137
138			for _, enc := range parseVectorEncryptions(vector.Encryptions) {
139				t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
140					seqNum, err := strconv.Atoi(enc["sequence number"])
141					if err != nil {
142						t.Fatal(err)
143					}
144					context.seqNum = uint128{lo: uint64(seqNum)}
145					expectedNonce := mustDecodeHex(t, enc["nonce"])
146					// We can't call nextNonce, because it increments the sequence number,
147					// so just compute it directly.
148					computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
149					for i := range context.baseNonce {
150						computedNonce[i] ^= context.baseNonce[i]
151					}
152					if !bytes.Equal(computedNonce, expectedNonce) {
153						t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
154					}
155
156					expectedCiphertext := mustDecodeHex(t, enc["ct"])
157					ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
158					if err != nil {
159						t.Fatal(err)
160					}
161					if !bytes.Equal(ciphertext, expectedCiphertext) {
162						t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
163					}
164				})
165			}
166		})
167	}
168}
169