1// Copyright 2024 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package hpke 6 7import ( 8 "bytes" 9 "encoding/hex" 10 "encoding/json" 11 "os" 12 "strconv" 13 "strings" 14 "testing" 15 16 "crypto/ecdh" 17 _ "crypto/sha256" 18 _ "crypto/sha512" 19) 20 21func mustDecodeHex(t *testing.T, in string) []byte { 22 b, err := hex.DecodeString(in) 23 if err != nil { 24 t.Fatal(err) 25 } 26 return b 27} 28 29func parseVectorSetup(vector string) map[string]string { 30 vals := map[string]string{} 31 for _, l := range strings.Split(vector, "\n") { 32 fields := strings.Split(l, ": ") 33 vals[fields[0]] = fields[1] 34 } 35 return vals 36} 37 38func parseVectorEncryptions(vector string) []map[string]string { 39 vals := []map[string]string{} 40 for _, section := range strings.Split(vector, "\n\n") { 41 e := map[string]string{} 42 for _, l := range strings.Split(section, "\n") { 43 fields := strings.Split(l, ": ") 44 e[fields[0]] = fields[1] 45 } 46 vals = append(vals, e) 47 } 48 return vals 49} 50 51func TestRFC9180Vectors(t *testing.T) { 52 vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json") 53 if err != nil { 54 t.Fatal(err) 55 } 56 57 var vectors []struct { 58 Name string 59 Setup string 60 Encryptions string 61 } 62 if err := json.Unmarshal(vectorsJSON, &vectors); err != nil { 63 t.Fatal(err) 64 } 65 66 for _, vector := range vectors { 67 t.Run(vector.Name, func(t *testing.T) { 68 setup := parseVectorSetup(vector.Setup) 69 70 kemID, err := strconv.Atoi(setup["kem_id"]) 71 if err != nil { 72 t.Fatal(err) 73 } 74 if _, ok := SupportedKEMs[uint16(kemID)]; !ok { 75 t.Skip("unsupported KEM") 76 } 77 kdfID, err := strconv.Atoi(setup["kdf_id"]) 78 if err != nil { 79 t.Fatal(err) 80 } 81 if _, ok := SupportedKDFs[uint16(kdfID)]; !ok { 82 t.Skip("unsupported KDF") 83 } 84 aeadID, err := strconv.Atoi(setup["aead_id"]) 85 if err != nil { 86 t.Fatal(err) 87 } 88 if _, ok := SupportedAEADs[uint16(aeadID)]; !ok { 89 t.Skip("unsupported AEAD") 90 } 91 92 info := mustDecodeHex(t, setup["info"]) 93 pubKeyBytes := mustDecodeHex(t, setup["pkRm"]) 94 pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes) 95 if err != nil { 96 t.Fatal(err) 97 } 98 99 ephemeralPrivKey := mustDecodeHex(t, setup["skEm"]) 100 101 testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) { 102 return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey) 103 } 104 t.Cleanup(func() { testingOnlyGenerateKey = nil }) 105 106 encap, context, err := SetupSender( 107 uint16(kemID), 108 uint16(kdfID), 109 uint16(aeadID), 110 pub, 111 info, 112 ) 113 if err != nil { 114 t.Fatal(err) 115 } 116 117 expectedEncap := mustDecodeHex(t, setup["enc"]) 118 if !bytes.Equal(encap, expectedEncap) { 119 t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap) 120 } 121 expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"]) 122 if !bytes.Equal(context.sharedSecret, expectedSharedSecret) { 123 t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret) 124 } 125 expectedKey := mustDecodeHex(t, setup["key"]) 126 if !bytes.Equal(context.key, expectedKey) { 127 t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey) 128 } 129 expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"]) 130 if !bytes.Equal(context.baseNonce, expectedBaseNonce) { 131 t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce) 132 } 133 expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"]) 134 if !bytes.Equal(context.exporterSecret, expectedExporterSecret) { 135 t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret) 136 } 137 138 for _, enc := range parseVectorEncryptions(vector.Encryptions) { 139 t.Run("seq num "+enc["sequence number"], func(t *testing.T) { 140 seqNum, err := strconv.Atoi(enc["sequence number"]) 141 if err != nil { 142 t.Fatal(err) 143 } 144 context.seqNum = uint128{lo: uint64(seqNum)} 145 expectedNonce := mustDecodeHex(t, enc["nonce"]) 146 // We can't call nextNonce, because it increments the sequence number, 147 // so just compute it directly. 148 computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():] 149 for i := range context.baseNonce { 150 computedNonce[i] ^= context.baseNonce[i] 151 } 152 if !bytes.Equal(computedNonce, expectedNonce) { 153 t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce) 154 } 155 156 expectedCiphertext := mustDecodeHex(t, enc["ct"]) 157 ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"])) 158 if err != nil { 159 t.Fatal(err) 160 } 161 if !bytes.Equal(ciphertext, expectedCiphertext) { 162 t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext) 163 } 164 }) 165 } 166 }) 167 } 168} 169