• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <keymaster/cppcose/cppcose.h>
18 
19 #include <iostream>
20 #include <stdio.h>
21 
22 #include <cppbor.h>
23 #include <cppbor_parse.h>
24 #include <openssl/ecdsa.h>
25 
26 #include <openssl/err.h>
27 
28 namespace cppcose {
29 constexpr int kP256AffinePointSize = 32;
30 
31 using EVP_PKEY_Ptr = bssl::UniquePtr<EVP_PKEY>;
32 using EVP_PKEY_CTX_Ptr = bssl::UniquePtr<EVP_PKEY_CTX>;
33 using ECDSA_SIG_Ptr = bssl::UniquePtr<ECDSA_SIG>;
34 using EC_KEY_Ptr = bssl::UniquePtr<EC_KEY>;
35 
36 namespace {
37 
aesGcmInitAndProcessAad(const bytevec & key,const bytevec & nonce,const bytevec & aad,bool encrypt)38 ErrMsgOr<bssl::UniquePtr<EVP_CIPHER_CTX>> aesGcmInitAndProcessAad(const bytevec& key,
39                                                                   const bytevec& nonce,
40                                                                   const bytevec& aad,
41                                                                   bool encrypt) {
42     if (key.size() != kAesGcmKeySize) return "Invalid key size";
43 
44     bssl::UniquePtr<EVP_CIPHER_CTX> ctx(EVP_CIPHER_CTX_new());
45     if (!ctx) return "Failed to allocate cipher context";
46 
47     if (!EVP_CipherInit_ex(ctx.get(), EVP_aes_256_gcm(), nullptr /* engine */, key.data(),
48                            nonce.data(), encrypt ? 1 : 0)) {
49         return "Failed to initialize cipher";
50     }
51 
52     int outlen;
53     if (!aad.empty() && !EVP_CipherUpdate(ctx.get(), nullptr /* out; null means AAD */, &outlen,
54                                           aad.data(), aad.size())) {
55         return "Failed to process AAD";
56     }
57 
58     return std::move(ctx);
59 }
60 
signEcdsaDigest(const bytevec & key,const bytevec & data)61 ErrMsgOr<bytevec> signEcdsaDigest(const bytevec& key, const bytevec& data) {
62     auto bn = BIGNUM_Ptr(BN_bin2bn(key.data(), key.size(), nullptr));
63     if (bn.get() == nullptr) {
64         return "Error creating BIGNUM";
65     }
66 
67     auto ec_key = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
68     if (EC_KEY_set_private_key(ec_key.get(), bn.get()) != 1) {
69         return "Error setting private key from BIGNUM";
70     }
71 
72     auto sig = ECDSA_SIG_Ptr(ECDSA_do_sign(data.data(), data.size(), ec_key.get()));
73     if (sig == nullptr) {
74         return "Error signing digest";
75     }
76     size_t len = i2d_ECDSA_SIG(sig.get(), nullptr);
77     bytevec signature(len);
78     unsigned char* p = (unsigned char*)signature.data();
79     i2d_ECDSA_SIG(sig.get(), &p);
80     return signature;
81 }
82 
ecdh(const bytevec & publicKey,const bytevec & privateKey)83 ErrMsgOr<bytevec> ecdh(const bytevec& publicKey, const bytevec& privateKey) {
84     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
85     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
86     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
87         1) {
88         return "Error decoding publicKey";
89     }
90     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
91     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
92     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
93         return "Memory allocation failed";
94     }
95     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
96         return "Error setting group";
97     }
98     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
99         return "Error setting point";
100     }
101     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
102         return "Error setting key";
103     }
104 
105     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
106     if (bn.get() == nullptr) {
107         return "Error creating BIGNUM for private key";
108     }
109     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
110     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
111         return "Error setting private key from BIGNUM";
112     }
113     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
114     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
115         return "Error setting private key";
116     }
117 
118     auto ctx = EVP_PKEY_CTX_Ptr(EVP_PKEY_CTX_new(privPkey.get(), NULL));
119     if (ctx.get() == nullptr) {
120         return "Error creating context";
121     }
122 
123     if (EVP_PKEY_derive_init(ctx.get()) != 1) {
124         return "Error initializing context";
125     }
126 
127     if (EVP_PKEY_derive_set_peer(ctx.get(), pkey.get()) != 1) {
128         return "Error setting peer";
129     }
130 
131     /* Determine buffer length for shared secret */
132     size_t secretLen = 0;
133     if (EVP_PKEY_derive(ctx.get(), NULL, &secretLen) != 1) {
134         return "Error determing length of shared secret";
135     }
136     bytevec sharedSecret(secretLen);
137 
138     if (EVP_PKEY_derive(ctx.get(), sharedSecret.data(), &secretLen) != 1) {
139         return "Error deriving shared secret";
140     }
141     return sharedSecret;
142 }
143 
144 }  // namespace
145 
ecdsaCoseSignatureToDer(const bytevec & ecdsaCoseSignature)146 ErrMsgOr<bytevec> ecdsaCoseSignatureToDer(const bytevec& ecdsaCoseSignature) {
147     if (ecdsaCoseSignature.size() != 64) {
148         return "COSE signature wrong length";
149     }
150 
151     auto rBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data(), 32, nullptr));
152     if (rBn.get() == nullptr) {
153         return "Error creating BIGNUM for r";
154     }
155 
156     auto sBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data() + 32, 32, nullptr));
157     if (sBn.get() == nullptr) {
158         return "Error creating BIGNUM for s";
159     }
160 
161     ECDSA_SIG sig;
162     sig.r = rBn.get();
163     sig.s = sBn.get();
164 
165     size_t len = i2d_ECDSA_SIG(&sig, nullptr);
166     bytevec derSignature(len);
167     unsigned char* p = (unsigned char*)derSignature.data();
168     i2d_ECDSA_SIG(&sig, &p);
169     return derSignature;
170 }
171 
ecdsaDerSignatureToCose(const bytevec & ecdsaSignature)172 ErrMsgOr<bytevec> ecdsaDerSignatureToCose(const bytevec& ecdsaSignature) {
173     const unsigned char* p = ecdsaSignature.data();
174     auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, ecdsaSignature.size()));
175     if (sig == nullptr) {
176         return "Error decoding DER signature";
177     }
178 
179     bytevec ecdsaCoseSignature(64, 0);
180     if (BN_bn2binpad(ECDSA_SIG_get0_r(sig.get()), ecdsaCoseSignature.data(), 32) != 32) {
181         return "Error encoding r";
182     }
183     if (BN_bn2binpad(ECDSA_SIG_get0_s(sig.get()), ecdsaCoseSignature.data() + 32, 32) != 32) {
184         return "Error encoding s";
185     }
186     return ecdsaCoseSignature;
187 }
188 
generateHmacSha256(const bytevec & key,const bytevec & data)189 ErrMsgOr<HmacSha256> generateHmacSha256(const bytevec& key, const bytevec& data) {
190     HmacSha256 digest;
191     unsigned int outLen;
192     uint8_t* out = HMAC(EVP_sha256(),              //
193                         key.data(), key.size(),    //
194                         data.data(), data.size(),  //
195                         digest.data(), &outLen);
196 
197     if (out == nullptr || outLen != digest.size()) {
198         return "Error generating HMAC";
199     }
200     return digest;
201 }
202 
generateCoseMac0Mac(HmacSha256Function macFunction,const bytevec & externalAad,const bytevec & payload)203 ErrMsgOr<HmacSha256> generateCoseMac0Mac(HmacSha256Function macFunction, const bytevec& externalAad,
204                                          const bytevec& payload) {
205     auto macStructure = cppbor::Array()
206                             .add("MAC0")
207                             .add(cppbor::Map().add(ALGORITHM, HMAC_256).canonicalize().encode())
208                             .add(externalAad)
209                             .add(payload)
210                             .encode();
211 
212     auto macTag = macFunction(macStructure);
213     if (!macTag) {
214         return "Error computing public key MAC";
215     }
216 
217     return *macTag;
218 }
219 
constructCoseMac0(HmacSha256Function macFunction,const bytevec & externalAad,const bytevec & payload)220 ErrMsgOr<cppbor::Array> constructCoseMac0(HmacSha256Function macFunction,
221                                           const bytevec& externalAad, const bytevec& payload) {
222     auto tag = generateCoseMac0Mac(macFunction, externalAad, payload);
223     if (!tag) return tag.moveMessage();
224 
225     return cppbor::Array()
226         .add(cppbor::Map().add(ALGORITHM, HMAC_256).canonicalize().encode())
227         .add(cppbor::Map() /* unprotected */)
228         .add(payload)
229         .add(std::pair(tag->begin(), tag->end()));
230 }
231 
verifyAndParseCoseMac0(const cppbor::Item * macItem,const bytevec & macKey)232 ErrMsgOr<bytevec /* payload */> verifyAndParseCoseMac0(const cppbor::Item* macItem,
233                                                        const bytevec& macKey) {
234     auto mac = macItem ? macItem->asArray() : nullptr;
235     if (!mac || mac->size() != kCoseMac0EntryCount) {
236         return "Invalid COSE_Mac0";
237     }
238 
239     auto protectedParms = mac->get(kCoseMac0ProtectedParams)->asBstr();
240     auto unprotectedParms = mac->get(kCoseMac0UnprotectedParams)->asMap();
241     auto payload = mac->get(kCoseMac0Payload)->asBstr();
242     auto tag = mac->get(kCoseMac0Tag)->asBstr();
243     if (!protectedParms || !unprotectedParms || !payload || !tag) {
244         return "Invalid COSE_Mac0 contents";
245     }
246 
247     auto [protectedMap, _, errMsg] = cppbor::parse(protectedParms);
248     if (!protectedMap || !protectedMap->asMap()) {
249         return "Invalid Mac0 protected: " + errMsg;
250     }
251     auto& algo = protectedMap->asMap()->get(ALGORITHM);
252     if (!algo || !algo->asInt() || algo->asInt()->value() != HMAC_256) {
253         return "Unsupported Mac0 algorithm";
254     }
255 
256     auto macFunction = [&macKey](const bytevec& input) {
257         return generateHmacSha256(macKey, input);
258     };
259     auto macTag = generateCoseMac0Mac(macFunction, {} /* external_aad */, payload->value());
260     if (!macTag) return macTag.moveMessage();
261 
262     if (macTag->size() != tag->value().size() ||
263         CRYPTO_memcmp(macTag->data(), tag->value().data(), macTag->size()) != 0) {
264         return "MAC tag mismatch";
265     }
266 
267     return payload->value();
268 }
269 
createECDSACoseSign1Signature(const bytevec & key,const bytevec & protectedParams,const bytevec & payload,const bytevec & aad)270 ErrMsgOr<bytevec> createECDSACoseSign1Signature(const bytevec& key, const bytevec& protectedParams,
271                                                 const bytevec& payload, const bytevec& aad) {
272     bytevec signatureInput = cppbor::Array()
273                                  .add("Signature1")  //
274                                  .add(protectedParams)
275                                  .add(aad)
276                                  .add(payload)
277                                  .encode();
278     auto ecdsaSignature = signEcdsaDigest(key, sha256(signatureInput));
279     if (!ecdsaSignature) return ecdsaSignature.moveMessage();
280 
281     return ecdsaDerSignatureToCose(*ecdsaSignature);
282 }
283 
createCoseSign1Signature(const bytevec & key,const bytevec & protectedParams,const bytevec & payload,const bytevec & aad)284 ErrMsgOr<bytevec> createCoseSign1Signature(const bytevec& key, const bytevec& protectedParams,
285                                            const bytevec& payload, const bytevec& aad) {
286     bytevec signatureInput = cppbor::Array()
287                                  .add("Signature1")  //
288                                  .add(protectedParams)
289                                  .add(aad)
290                                  .add(payload)
291                                  .encode();
292 
293     if (key.size() != ED25519_PRIVATE_KEY_LEN) return "Invalid signing key";
294     bytevec signature(ED25519_SIGNATURE_LEN);
295     if (!ED25519_sign(signature.data(), signatureInput.data(), signatureInput.size(), key.data())) {
296         return "Signing failed";
297     }
298 
299     return signature;
300 }
301 
constructECDSACoseSign1(const bytevec & key,cppbor::Map protectedParams,const bytevec & payload,const bytevec & aad)302 ErrMsgOr<cppbor::Array> constructECDSACoseSign1(const bytevec& key, cppbor::Map protectedParams,
303                                                 const bytevec& payload, const bytevec& aad) {
304     bytevec protParms = protectedParams.add(ALGORITHM, ES256).canonicalize().encode();
305     auto signature = createECDSACoseSign1Signature(key, protParms, payload, aad);
306     if (!signature) return signature.moveMessage();
307 
308     return cppbor::Array()
309         .add(std::move(protParms))
310         .add(cppbor::Map() /* unprotected parameters */)
311         .add(std::move(payload))
312         .add(std::move(*signature));
313 }
314 
constructCoseSign1(const bytevec & key,cppbor::Map protectedParams,const bytevec & payload,const bytevec & aad)315 ErrMsgOr<cppbor::Array> constructCoseSign1(const bytevec& key, cppbor::Map protectedParams,
316                                            const bytevec& payload, const bytevec& aad) {
317     bytevec protParms = protectedParams.add(ALGORITHM, EDDSA).canonicalize().encode();
318     auto signature = createCoseSign1Signature(key, protParms, payload, aad);
319     if (!signature) return signature.moveMessage();
320 
321     return cppbor::Array()
322         .add(std::move(protParms))
323         .add(cppbor::Map() /* unprotected parameters */)
324         .add(std::move(payload))
325         .add(std::move(*signature));
326 }
327 
constructCoseSign1(const bytevec & key,const bytevec & payload,const bytevec & aad)328 ErrMsgOr<cppbor::Array> constructCoseSign1(const bytevec& key, const bytevec& payload,
329                                            const bytevec& aad) {
330     return constructCoseSign1(key, {} /* protectedParams */, payload, aad);
331 }
332 
verifyAndParseCoseSign1(const cppbor::Array * coseSign1,const bytevec & signingCoseKey,const bytevec & aad)333 ErrMsgOr<bytevec> verifyAndParseCoseSign1(const cppbor::Array* coseSign1,
334                                           const bytevec& signingCoseKey, const bytevec& aad) {
335     if (!coseSign1 || coseSign1->size() != kCoseSign1EntryCount) {
336         return "Invalid COSE_Sign1";
337     }
338 
339     const cppbor::Bstr* protectedParams = coseSign1->get(kCoseSign1ProtectedParams)->asBstr();
340     const cppbor::Map* unprotectedParams = coseSign1->get(kCoseSign1UnprotectedParams)->asMap();
341     const cppbor::Bstr* payload = coseSign1->get(kCoseSign1Payload)->asBstr();
342 
343     if (!protectedParams || !unprotectedParams || !payload) {
344         return "Missing input parameters";
345     }
346 
347     auto [parsedProtParams, _, errMsg] = cppbor::parse(protectedParams);
348     if (!parsedProtParams) {
349         return errMsg + " when parsing protected params.";
350     }
351     if (!parsedProtParams->asMap()) {
352         return "Protected params must be a map";
353     }
354 
355     auto& algorithm = parsedProtParams->asMap()->get(ALGORITHM);
356     if (!algorithm || !algorithm->asInt() ||
357         !(algorithm->asInt()->value() == EDDSA || algorithm->asInt()->value() == ES256)) {
358         return "Unsupported signature algorithm";
359     }
360 
361     const cppbor::Bstr* signature = coseSign1->get(kCoseSign1Signature)->asBstr();
362     if (!signature || signature->value().empty()) {
363         return "Missing signature input";
364     }
365 
366     bool selfSigned = signingCoseKey.empty();
367     bytevec signatureInput =
368         cppbor::Array().add("Signature1").add(*protectedParams).add(aad).add(*payload).encode();
369     if (algorithm->asInt()->value() == EDDSA) {
370         auto key = CoseKey::parseEd25519(selfSigned ? payload->value() : signingCoseKey);
371         if (!key || key->getBstrValue(CoseKey::PUBKEY_X)->empty()) {
372             return "Bad signing key: " + key.moveMessage();
373         }
374 
375         if (!ED25519_verify(signatureInput.data(), signatureInput.size(), signature->value().data(),
376                             key->getBstrValue(CoseKey::PUBKEY_X)->data())) {
377             return "Signature verification failed";
378         }
379     } else {  // P256
380         auto key = CoseKey::parseP256(selfSigned ? payload->value() : signingCoseKey);
381         if (!key || key->getBstrValue(CoseKey::PUBKEY_X)->empty() ||
382             key->getBstrValue(CoseKey::PUBKEY_Y)->empty()) {
383             return "Bad signing key: " + key.moveMessage();
384         }
385         auto publicKey = key->getEcPublicKey();
386         if (!publicKey) return publicKey.moveMessage();
387 
388         auto ecdsaDerSignature = ecdsaCoseSignatureToDer(signature->value());
389         if (!ecdsaDerSignature) return ecdsaDerSignature.moveMessage();
390 
391         // convert public key to uncompressed form by prepending 0x04 at begin.
392         publicKey->insert(publicKey->begin(), 0x04);
393 
394         if (!verifyEcdsaDigest(publicKey.moveValue(), sha256(signatureInput), *ecdsaDerSignature)) {
395             return "Signature verification failed";
396         }
397     }
398 
399     return payload->value();
400 }
401 
createCoseEncryptCiphertext(const bytevec & key,const bytevec & nonce,const bytevec & protectedParams,const bytevec & plaintextPayload,const bytevec & aad)402 ErrMsgOr<bytevec> createCoseEncryptCiphertext(const bytevec& key, const bytevec& nonce,
403                                               const bytevec& protectedParams,
404                                               const bytevec& plaintextPayload, const bytevec& aad) {
405     auto ciphertext = aesGcmEncrypt(key, nonce,
406                                     cppbor::Array()            // Enc strucure as AAD
407                                         .add("Encrypt")        // Context
408                                         .add(protectedParams)  // Protected
409                                         .add(aad)              // External AAD
410                                         .encode(),
411                                     plaintextPayload);
412 
413     if (!ciphertext) return ciphertext.moveMessage();
414     return ciphertext.moveValue();
415 }
416 
constructCoseEncrypt(const bytevec & key,const bytevec & nonce,const bytevec & plaintextPayload,const bytevec & aad,cppbor::Array recipients)417 ErrMsgOr<cppbor::Array> constructCoseEncrypt(const bytevec& key, const bytevec& nonce,
418                                              const bytevec& plaintextPayload, const bytevec& aad,
419                                              cppbor::Array recipients) {
420     auto encryptProtectedHeader = cppbor::Map()  //
421                                       .add(ALGORITHM, AES_GCM_256)
422                                       .canonicalize()
423                                       .encode();
424 
425     auto ciphertext =
426         createCoseEncryptCiphertext(key, nonce, encryptProtectedHeader, plaintextPayload, aad);
427     if (!ciphertext) return ciphertext.moveMessage();
428 
429     return cppbor::Array()
430         .add(encryptProtectedHeader)                       // Protected
431         .add(cppbor::Map().add(IV, nonce).canonicalize())  // Unprotected
432         .add(*ciphertext)                                  // Payload
433         .add(std::move(recipients));
434 }
435 
436 ErrMsgOr<std::pair<bytevec /* pubkey */, bytevec /* key ID */>>
getSenderPubKeyFromCoseEncrypt(const cppbor::Item * coseEncrypt)437 getSenderPubKeyFromCoseEncrypt(const cppbor::Item* coseEncrypt) {
438     if (!coseEncrypt || !coseEncrypt->asArray() ||
439         coseEncrypt->asArray()->size() != kCoseEncryptEntryCount) {
440         return "Invalid COSE_Encrypt";
441     }
442 
443     auto& recipients = coseEncrypt->asArray()->get(kCoseEncryptRecipients);
444     if (!recipients || !recipients->asArray() || recipients->asArray()->size() != 1) {
445         return "Invalid recipients list";
446     }
447 
448     auto& recipient = recipients->asArray()->get(0);
449     if (!recipient || !recipient->asArray() || recipient->asArray()->size() != 3) {
450         return "Invalid COSE_recipient";
451     }
452 
453     auto& ciphertext = recipient->asArray()->get(2);
454     if (!ciphertext->asSimple() || !ciphertext->asSimple()->asNull()) {
455         return "Unexpected value in recipients ciphertext field " +
456                cppbor::prettyPrint(ciphertext.get());
457     }
458 
459     auto& protParms = recipient->asArray()->get(0);
460     if (!protParms || !protParms->asBstr()) return "Invalid protected params";
461     auto [parsedProtParms, _, errMsg] = cppbor::parse(protParms->asBstr());
462     if (!parsedProtParms) return "Failed to parse protected params: " + errMsg;
463     if (!parsedProtParms->asMap()) return "Invalid protected params";
464 
465     auto& algorithm = parsedProtParms->asMap()->get(ALGORITHM);
466     if (!algorithm || !algorithm->asInt() || algorithm->asInt()->value() != ECDH_ES_HKDF_256) {
467         return "Invalid algorithm";
468     }
469 
470     auto& unprotParms = recipient->asArray()->get(1);
471     if (!unprotParms || !unprotParms->asMap()) return "Invalid unprotected params";
472 
473     auto& senderCoseKey = unprotParms->asMap()->get(COSE_KEY);
474     if (!senderCoseKey || !senderCoseKey->asMap()) return "Invalid sender COSE_Key";
475 
476     auto& keyType = senderCoseKey->asMap()->get(CoseKey::KEY_TYPE);
477     if (!keyType || !keyType->asInt() ||
478         (keyType->asInt()->value() != OCTET_KEY_PAIR && keyType->asInt()->value() != EC2)) {
479         return "Invalid key type";
480     }
481 
482     auto& curve = senderCoseKey->asMap()->get(CoseKey::CURVE);
483     if (!curve || !curve->asInt() ||
484         (keyType->asInt()->value() == OCTET_KEY_PAIR && curve->asInt()->value() != X25519) ||
485         (keyType->asInt()->value() == EC2 && curve->asInt()->value() != P256)) {
486         return "Unsupported curve";
487     }
488 
489     bytevec publicKey;
490     if (keyType->asInt()->value() == EC2) {
491         auto& pubX = senderCoseKey->asMap()->get(CoseKey::PUBKEY_X);
492         if (!pubX || !pubX->asBstr() || pubX->asBstr()->value().size() != kP256AffinePointSize) {
493             return "Invalid EC public key";
494         }
495         auto& pubY = senderCoseKey->asMap()->get(CoseKey::PUBKEY_Y);
496         if (!pubY || !pubY->asBstr() || pubY->asBstr()->value().size() != kP256AffinePointSize) {
497             return "Invalid EC public key";
498         }
499         auto key = CoseKey::getEcPublicKey(pubX->asBstr()->value(), pubY->asBstr()->value());
500         if (!key) return key.moveMessage();
501         publicKey = key.moveValue();
502     } else {
503         auto& pubkey = senderCoseKey->asMap()->get(CoseKey::PUBKEY_X);
504         if (!pubkey || !pubkey->asBstr() ||
505             pubkey->asBstr()->value().size() != X25519_PUBLIC_VALUE_LEN) {
506             return "Invalid X25519 public key";
507         }
508         publicKey = pubkey->asBstr()->value();
509     }
510 
511     auto& key_id = unprotParms->asMap()->get(KEY_ID);
512     if (key_id && key_id->asBstr()) {
513         return std::make_pair(publicKey, key_id->asBstr()->value());
514     }
515 
516     // If no key ID, just return an empty vector.
517     return std::make_pair(publicKey, bytevec{});
518 }
519 
decryptCoseEncrypt(const bytevec & key,const cppbor::Item * coseEncrypt,const bytevec & external_aad)520 ErrMsgOr<bytevec> decryptCoseEncrypt(const bytevec& key, const cppbor::Item* coseEncrypt,
521                                      const bytevec& external_aad) {
522     if (!coseEncrypt || !coseEncrypt->asArray() ||
523         coseEncrypt->asArray()->size() != kCoseEncryptEntryCount) {
524         return "Invalid COSE_Encrypt";
525     }
526 
527     auto& protParms = coseEncrypt->asArray()->get(kCoseEncryptProtectedParams);
528     auto& unprotParms = coseEncrypt->asArray()->get(kCoseEncryptUnprotectedParams);
529     auto& ciphertext = coseEncrypt->asArray()->get(kCoseEncryptPayload);
530     auto& recipients = coseEncrypt->asArray()->get(kCoseEncryptRecipients);
531 
532     if (!protParms || !protParms->asBstr() || !unprotParms || !ciphertext || !recipients) {
533         return "Invalid COSE_Encrypt";
534     }
535 
536     auto [parsedProtParams, _, errMsg] = cppbor::parse(protParms->asBstr()->value());
537     if (!parsedProtParams) {
538         return errMsg + " when parsing protected params.";
539     }
540     if (!parsedProtParams->asMap()) {
541         return "Protected params must be a map";
542     }
543 
544     auto& algorithm = parsedProtParams->asMap()->get(ALGORITHM);
545     if (!algorithm || !algorithm->asInt() || algorithm->asInt()->value() != AES_GCM_256) {
546         return "Unsupported encryption algorithm";
547     }
548 
549     if (!unprotParms->asMap() || unprotParms->asMap()->size() != 1) {
550         return "Invalid unprotected params";
551     }
552 
553     auto& nonce = unprotParms->asMap()->get(IV);
554     if (!nonce || !nonce->asBstr() || nonce->asBstr()->value().size() != kAesGcmNonceLength) {
555         return "Invalid nonce";
556     }
557 
558     if (!ciphertext->asBstr()) return "Invalid ciphertext";
559 
560     auto aad = cppbor::Array()                         // Enc strucure as AAD
561                    .add("Encrypt")                     // Context
562                    .add(protParms->asBstr()->value())  // Protected
563                    .add(external_aad)                  // External AAD
564                    .encode();
565 
566     return aesGcmDecrypt(key, nonce->asBstr()->value(), aad, ciphertext->asBstr()->value());
567 }
568 
consructKdfContext(const bytevec & pubKeyA,const bytevec & privKeyA,const bytevec & pubKeyB,bool senderIsA)569 ErrMsgOr<bytevec> consructKdfContext(const bytevec& pubKeyA, const bytevec& privKeyA,
570                                      const bytevec& pubKeyB, bool senderIsA) {
571     if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
572         return "Missing input key parameters";
573     }
574 
575     bytevec kdfContext = cppbor::Array()
576                              .add(AES_GCM_256)
577                              .add(cppbor::Array()  // Sender Info
578                                       .add(cppbor::Bstr("client"))
579                                       .add(bytevec{} /* nonce */)
580                                       .add(senderIsA ? pubKeyA : pubKeyB))
581                              .add(cppbor::Array()  // Recipient Info
582                                       .add(cppbor::Bstr("server"))
583                                       .add(bytevec{} /* nonce */)
584                                       .add(senderIsA ? pubKeyB : pubKeyA))
585                              .add(cppbor::Array()               // SuppPubInfo
586                                       .add(kAesGcmKeySizeBits)  // output key length
587                                       .add(bytevec{}))          // protected
588                              .encode();
589     return kdfContext;
590 }
591 
ECDH_HKDF_DeriveKey(const bytevec & pubKeyA,const bytevec & privKeyA,const bytevec & pubKeyB,bool senderIsA)592 ErrMsgOr<bytevec> ECDH_HKDF_DeriveKey(const bytevec& pubKeyA, const bytevec& privKeyA,
593                                       const bytevec& pubKeyB, bool senderIsA) {
594     if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
595         return "Missing input key parameters";
596     }
597 
598     // convert public key to uncompressed form by prepending 0x04 at begin
599     bytevec publicKey;
600     publicKey.insert(publicKey.begin(), 0x04);
601     publicKey.insert(publicKey.end(), pubKeyB.begin(), pubKeyB.end());
602     auto rawSharedKey = ecdh(publicKey, privKeyA);
603     if (!rawSharedKey) return rawSharedKey.moveMessage();
604 
605     auto kdfContext = consructKdfContext(pubKeyA, privKeyA, pubKeyB, senderIsA);
606     if (!kdfContext) return kdfContext.moveMessage();
607 
608     bytevec retval(SHA256_DIGEST_LENGTH);
609     bytevec salt{};
610     if (!HKDF(retval.data(), retval.size(),                //
611               EVP_sha256(),                                //
612               rawSharedKey->data(), rawSharedKey->size(),  //
613               salt.data(), salt.size(),                    //
614               kdfContext->data(), kdfContext->size())) {
615         return "ECDH HKDF failed";
616     }
617 
618     return retval;
619 }
620 
x25519_HKDF_DeriveKey(const bytevec & pubKeyA,const bytevec & privKeyA,const bytevec & pubKeyB,bool senderIsA)621 ErrMsgOr<bytevec> x25519_HKDF_DeriveKey(const bytevec& pubKeyA, const bytevec& privKeyA,
622                                         const bytevec& pubKeyB, bool senderIsA) {
623     if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
624         return "Missing input key parameters";
625     }
626 
627     bytevec rawSharedKey(X25519_SHARED_KEY_LEN);
628     if (!::X25519(rawSharedKey.data(), privKeyA.data(), pubKeyB.data())) {
629         return "ECDH operation failed";
630     }
631 
632     auto kdfContext = consructKdfContext(pubKeyA, privKeyA, pubKeyB, senderIsA);
633     if (!kdfContext) return kdfContext.moveMessage();
634 
635     bytevec retval(SHA256_DIGEST_LENGTH);
636     bytevec salt{};
637     if (!HKDF(retval.data(), retval.size(),              //
638               EVP_sha256(),                              //
639               rawSharedKey.data(), rawSharedKey.size(),  //
640               salt.data(), salt.size(),                  //
641               kdfContext->data(), kdfContext->size())) {
642         return "ECDH HKDF failed";
643     }
644 
645     return retval;
646 }
647 
aesGcmEncrypt(const bytevec & key,const bytevec & nonce,const bytevec & aad,const bytevec & plaintext)648 ErrMsgOr<bytevec> aesGcmEncrypt(const bytevec& key, const bytevec& nonce, const bytevec& aad,
649                                 const bytevec& plaintext) {
650     auto ctx = aesGcmInitAndProcessAad(key, nonce, aad, true /* encrypt */);
651     if (!ctx) return ctx.moveMessage();
652 
653     bytevec ciphertext(plaintext.size() + kAesGcmTagSize);
654     int outlen;
655     if (!EVP_CipherUpdate(ctx->get(), ciphertext.data(), &outlen, plaintext.data(),
656                           plaintext.size())) {
657         return "Failed to encrypt plaintext";
658     }
659     assert(plaintext.size() == static_cast<uint64_t>(outlen));
660 
661     if (!EVP_CipherFinal_ex(ctx->get(), ciphertext.data() + outlen, &outlen)) {
662         return "Failed to finalize encryption";
663     }
664     assert(outlen == 0);
665 
666     if (!EVP_CIPHER_CTX_ctrl(ctx->get(), EVP_CTRL_GCM_GET_TAG, kAesGcmTagSize,
667                              ciphertext.data() + plaintext.size())) {
668         return "Failed to retrieve tag";
669     }
670 
671     return ciphertext;
672 }
673 
aesGcmDecrypt(const bytevec & key,const bytevec & nonce,const bytevec & aad,const bytevec & ciphertextWithTag)674 ErrMsgOr<bytevec> aesGcmDecrypt(const bytevec& key, const bytevec& nonce, const bytevec& aad,
675                                 const bytevec& ciphertextWithTag) {
676     auto ctx = aesGcmInitAndProcessAad(key, nonce, aad, false /* encrypt */);
677     if (!ctx) return ctx.moveMessage();
678 
679     if (ciphertextWithTag.size() < kAesGcmTagSize) return "Missing tag";
680 
681     bytevec plaintext(ciphertextWithTag.size() - kAesGcmTagSize);
682     int outlen;
683     if (!EVP_CipherUpdate(ctx->get(), plaintext.data(), &outlen, ciphertextWithTag.data(),
684                           ciphertextWithTag.size() - kAesGcmTagSize)) {
685         return "Failed to decrypt plaintext";
686     }
687     assert(plaintext.size() == static_cast<uint64_t>(outlen));
688 
689     bytevec tag(ciphertextWithTag.end() - kAesGcmTagSize, ciphertextWithTag.end());
690     if (!EVP_CIPHER_CTX_ctrl(ctx->get(), EVP_CTRL_GCM_SET_TAG, kAesGcmTagSize, tag.data())) {
691         return "Failed to set tag: " + std::to_string(ERR_peek_last_error());
692     }
693 
694     if (!EVP_CipherFinal_ex(ctx->get(), nullptr, &outlen)) {
695         return "Failed to finalize encryption";
696     }
697     assert(outlen == 0);
698 
699     return plaintext;
700 }
701 
// Returns the SHA-256 digest (SHA256_DIGEST_LENGTH bytes) of `data`.
bytevec sha256(const bytevec& data) {
    bytevec digest(SHA256_DIGEST_LENGTH);
    ::SHA256(data.data(), data.size(), digest.data());
    return digest;
}
710 
verifyEcdsaDigest(const bytevec & key,const bytevec & digest,const bytevec & signature)711 bool verifyEcdsaDigest(const bytevec& key, const bytevec& digest, const bytevec& signature) {
712     const unsigned char* p = (unsigned char*)signature.data();
713     auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, signature.size()));
714     if (sig.get() == nullptr) {
715         return false;
716     }
717 
718     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
719     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
720     if (EC_POINT_oct2point(group.get(), point.get(), key.data(), key.size(), nullptr) != 1) {
721         return false;
722     }
723     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
724     if (ecKey.get() == nullptr) {
725         return false;
726     }
727     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
728         return false;
729     }
730     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
731         return false;
732     }
733 
734     int rc = ECDSA_do_verify(digest.data(), digest.size(), sig.get(), ecKey.get());
735     if (rc != 1) {
736         return false;
737     }
738     return true;
739 }
740 
741 }  // namespace cppcose
742