• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "IdentityCredentialSupport"
18 
19 #include <android/hardware/identity/support/IdentityCredentialSupport.h>
20 
21 #define _POSIX_C_SOURCE 199309L
22 
23 #include <ctype.h>
24 #include <stdarg.h>
25 #include <stdio.h>
26 #include <time.h>
27 #include <chrono>
28 #include <iomanip>
29 
30 #include <openssl/aes.h>
31 #include <openssl/bn.h>
32 #include <openssl/crypto.h>
33 #include <openssl/ec.h>
34 #include <openssl/err.h>
35 #include <openssl/evp.h>
36 #include <openssl/hkdf.h>
37 #include <openssl/hmac.h>
38 #include <openssl/objects.h>
39 #include <openssl/pem.h>
40 #include <openssl/pkcs12.h>
41 #include <openssl/rand.h>
42 #include <openssl/x509.h>
43 #include <openssl/x509_vfy.h>
44 
45 #include <android-base/logging.h>
46 #include <android-base/stringprintf.h>
47 #include <charconv>
48 
49 #include <cppbor.h>
50 #include <cppbor_parse.h>
51 
52 #include <android/hardware/keymaster/4.0/types.h>
53 #include <keymaster/authorization_set.h>
54 #include <keymaster/contexts/pure_soft_keymaster_context.h>
55 #include <keymaster/contexts/soft_attestation_cert.h>
56 #include <keymaster/keymaster_tags.h>
57 #include <keymaster/km_openssl/attestation_utils.h>
58 
59 namespace android {
60 namespace hardware {
61 namespace identity {
62 namespace support {
63 
64 using ::std::pair;
65 using ::std::unique_ptr;
66 
67 // ---------------------------------------------------------------------------
68 // Miscellaneous utilities.
69 // ---------------------------------------------------------------------------
70 
// Writes a hex dump of |data| to stderr.
//
// The dump starts with a header line containing |name| and the byte count.
// Each following row shows the offset, up to 16 bytes in hex, and the same
// bytes rendered as characters ('.' for bytes that are not printable).
// The dump ends with a blank line.
void hexdump(const string& name, const vector<uint8_t>& data) {
    // size_t is unsigned, so it must be printed with %zu (%zd is for ssize_t).
    fprintf(stderr, "%s: dumping %zu bytes\n", name.c_str(), data.size());
    for (size_t n = 0; n < data.size(); n += 16) {
        fprintf(stderr, "%04zx  ", n);
        size_t m;
        for (m = 0; m < 16 && n + m < data.size(); m++) {
            // %x expects an unsigned int; convert the byte explicitly.
            fprintf(stderr, "%02x ", static_cast<unsigned>(data[n + m]));
        }
        // Pad a short final row so the character column stays aligned.
        for (size_t o = m; o < 16; o++) {
            fprintf(stderr, "   ");
        }
        fprintf(stderr, " ");
        for (m = 0; m < 16 && n + m < data.size(); m++) {
            int c = data[n + m];
            fprintf(stderr, "%c", isprint(c) ? c : '.');
        }
        fprintf(stderr, "\n");
    }
    fprintf(stderr, "\n");
}
91 
// Returns the lowercase hexadecimal encoding of the |dataLen| bytes starting
// at |data|. Each byte becomes two characters, most significant nibble first
// (e.g. {0x01, 0xab} -> "01ab").
string encodeHex(const uint8_t* data, size_t dataLen) {
    static const char kDigits[] = "0123456789abcdef";

    string hex;
    hex.reserve(dataLen * 2);
    for (size_t i = 0; i < dataLen; ++i) {
        const uint8_t value = data[i];
        hex.push_back(kDigits[value >> 4]);
        hex.push_back(kDigits[value & 0x0f]);
    }
    return hex;
}
106 
encodeHex(const string & str)107 string encodeHex(const string& str) {
108     return encodeHex(reinterpret_cast<const uint8_t*>(str.data()), str.size());
109 }
110 
encodeHex(const vector<uint8_t> & data)111 string encodeHex(const vector<uint8_t>& data) {
112     return encodeHex(data.data(), data.size());
113 }
114 
// Converts one hexadecimal character ('0'-'9', 'a'-'f' or 'A'-'F') to its
// numeric value.
//
// Returns -1 on error, otherwise an integer in the range 0 through 15, both
// inclusive.
int parseHexDigit(char hexDigit) {
    if ('0' <= hexDigit && hexDigit <= '9') {
        return hexDigit - '0';
    }
    if ('a' <= hexDigit && hexDigit <= 'f') {
        return 10 + (hexDigit - 'a');
    }
    if ('A' <= hexDigit && hexDigit <= 'F') {
        return 10 + (hexDigit - 'A');
    }
    return -1;
}
126 
decodeHex(const string & hexEncoded)127 optional<vector<uint8_t>> decodeHex(const string& hexEncoded) {
128     vector<uint8_t> out;
129     size_t hexSize = hexEncoded.size();
130     if ((hexSize & 1) != 0) {
131         LOG(ERROR) << "Size of data cannot be odd";
132         return {};
133     }
134 
135     out.resize(hexSize / 2);
136     for (size_t n = 0; n < hexSize / 2; n++) {
137         int upperNibble = parseHexDigit(hexEncoded[n * 2]);
138         int lowerNibble = parseHexDigit(hexEncoded[n * 2 + 1]);
139         if (upperNibble == -1 || lowerNibble == -1) {
140             LOG(ERROR) << "Invalid hex digit at position " << n;
141             return {};
142         }
143         out[n] = (upperNibble << 4) + lowerNibble;
144     }
145 
146     return out;
147 }
148 
149 // ---------------------------------------------------------------------------
150 // CBOR utilities.
151 // ---------------------------------------------------------------------------
152 
cborAreAllElementsNonCompound(const cppbor::CompoundItem * compoundItem)153 static bool cborAreAllElementsNonCompound(const cppbor::CompoundItem* compoundItem) {
154     if (compoundItem->type() == cppbor::ARRAY) {
155         const cppbor::Array* array = compoundItem->asArray();
156         for (size_t n = 0; n < array->size(); n++) {
157             const cppbor::Item* entry = (*array)[n].get();
158             switch (entry->type()) {
159                 case cppbor::ARRAY:
160                 case cppbor::MAP:
161                     return false;
162                 default:
163                     break;
164             }
165         }
166     } else {
167         const cppbor::Map* map = compoundItem->asMap();
168         for (size_t n = 0; n < map->size(); n++) {
169             auto [keyEntry, valueEntry] = (*map)[n];
170             switch (keyEntry->type()) {
171                 case cppbor::ARRAY:
172                 case cppbor::MAP:
173                     return false;
174                 default:
175                     break;
176             }
177             switch (valueEntry->type()) {
178                 case cppbor::ARRAY:
179                 case cppbor::MAP:
180                     return false;
181                 default:
182                     break;
183             }
184         }
185     }
186     return true;
187 }
188 
cborPrettyPrintInternal(const cppbor::Item * item,string & out,size_t indent,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)189 static bool cborPrettyPrintInternal(const cppbor::Item* item, string& out, size_t indent,
190                                     size_t maxBStrSize, const vector<string>& mapKeysToNotPrint) {
191     char buf[80];
192 
193     string indentString(indent, ' ');
194 
195     switch (item->type()) {
196         case cppbor::UINT:
197             snprintf(buf, sizeof(buf), "%" PRIu64, item->asUint()->unsignedValue());
198             out.append(buf);
199             break;
200 
201         case cppbor::NINT:
202             snprintf(buf, sizeof(buf), "%" PRId64, item->asNint()->value());
203             out.append(buf);
204             break;
205 
206         case cppbor::BSTR: {
207             const cppbor::Bstr* bstr = item->asBstr();
208             const vector<uint8_t>& value = bstr->value();
209             if (value.size() > maxBStrSize) {
210                 unsigned char digest[SHA_DIGEST_LENGTH];
211                 SHA_CTX ctx;
212                 SHA1_Init(&ctx);
213                 SHA1_Update(&ctx, value.data(), value.size());
214                 SHA1_Final(digest, &ctx);
215                 char buf2[SHA_DIGEST_LENGTH * 2 + 1];
216                 for (size_t n = 0; n < SHA_DIGEST_LENGTH; n++) {
217                     snprintf(buf2 + n * 2, 3, "%02x", digest[n]);
218                 }
219                 snprintf(buf, sizeof(buf), "<bstr size=%zd sha1=%s>", value.size(), buf2);
220                 out.append(buf);
221             } else {
222                 out.append("{");
223                 for (size_t n = 0; n < value.size(); n++) {
224                     if (n > 0) {
225                         out.append(", ");
226                     }
227                     snprintf(buf, sizeof(buf), "0x%02x", value[n]);
228                     out.append(buf);
229                 }
230                 out.append("}");
231             }
232         } break;
233 
234         case cppbor::TSTR:
235             out.append("'");
236             {
237                 // TODO: escape "'" characters
238                 out.append(item->asTstr()->value().c_str());
239             }
240             out.append("'");
241             break;
242 
243         case cppbor::ARRAY: {
244             const cppbor::Array* array = item->asArray();
245             if (array->size() == 0) {
246                 out.append("[]");
247             } else if (cborAreAllElementsNonCompound(array)) {
248                 out.append("[");
249                 for (size_t n = 0; n < array->size(); n++) {
250                     if (!cborPrettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
251                                                  mapKeysToNotPrint)) {
252                         return false;
253                     }
254                     out.append(", ");
255                 }
256                 out.append("]");
257             } else {
258                 out.append("[\n" + indentString);
259                 for (size_t n = 0; n < array->size(); n++) {
260                     out.append("  ");
261                     if (!cborPrettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
262                                                  mapKeysToNotPrint)) {
263                         return false;
264                     }
265                     out.append(",\n" + indentString);
266                 }
267                 out.append("]");
268             }
269         } break;
270 
271         case cppbor::MAP: {
272             const cppbor::Map* map = item->asMap();
273 
274             if (map->size() == 0) {
275                 out.append("{}");
276             } else {
277                 out.append("{\n" + indentString);
278                 for (size_t n = 0; n < map->size(); n++) {
279                     out.append("  ");
280 
281                     auto [map_key, map_value] = (*map)[n];
282 
283                     if (!cborPrettyPrintInternal(map_key.get(), out, indent + 2, maxBStrSize,
284                                                  mapKeysToNotPrint)) {
285                         return false;
286                     }
287                     out.append(" : ");
288                     if (map_key->type() == cppbor::TSTR &&
289                         std::find(mapKeysToNotPrint.begin(), mapKeysToNotPrint.end(),
290                                   map_key->asTstr()->value()) != mapKeysToNotPrint.end()) {
291                         out.append("<not printed>");
292                     } else {
293                         if (!cborPrettyPrintInternal(map_value.get(), out, indent + 2, maxBStrSize,
294                                                      mapKeysToNotPrint)) {
295                             return false;
296                         }
297                     }
298                     out.append(",\n" + indentString);
299                 }
300                 out.append("}");
301             }
302         } break;
303 
304         case cppbor::SEMANTIC: {
305             const cppbor::Semantic* semantic = item->asSemantic();
306             snprintf(buf, sizeof(buf), "tag %" PRIu64 " ", semantic->value());
307             out.append(buf);
308             cborPrettyPrintInternal(semantic->child().get(), out, indent, maxBStrSize,
309                                     mapKeysToNotPrint);
310         } break;
311 
312         case cppbor::SIMPLE:
313             const cppbor::Bool* asBool = item->asSimple()->asBool();
314             const cppbor::Null* asNull = item->asSimple()->asNull();
315             if (asBool != nullptr) {
316                 out.append(asBool->value() ? "true" : "false");
317             } else if (asNull != nullptr) {
318                 out.append("null");
319             } else {
320                 LOG(ERROR) << "Only boolean/null is implemented for SIMPLE";
321                 return false;
322             }
323             break;
324     }
325 
326     return true;
327 }
328 
cborPrettyPrint(const vector<uint8_t> & encodedCbor,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)329 string cborPrettyPrint(const vector<uint8_t>& encodedCbor, size_t maxBStrSize,
330                        const vector<string>& mapKeysToNotPrint) {
331     auto [item, _, message] = cppbor::parse(encodedCbor);
332     if (item == nullptr) {
333         LOG(ERROR) << "Data to pretty print is not valid CBOR: " << message;
334         return "";
335     }
336 
337     string out;
338     cborPrettyPrintInternal(item.get(), out, 0, maxBStrSize, mapKeysToNotPrint);
339     return out;
340 }
341 
342 // ---------------------------------------------------------------------------
343 // Crypto functionality / abstraction.
344 // ---------------------------------------------------------------------------
345 
346 struct EVP_CIPHER_CTX_Deleter {
operator ()android::hardware::identity::support::EVP_CIPHER_CTX_Deleter347     void operator()(EVP_CIPHER_CTX* ctx) const {
348         if (ctx != nullptr) {
349             EVP_CIPHER_CTX_free(ctx);
350         }
351     }
352 };
353 
354 using EvpCipherCtxPtr = unique_ptr<EVP_CIPHER_CTX, EVP_CIPHER_CTX_Deleter>;
355 
356 // bool getRandom(size_t numBytes, vector<uint8_t>& output) {
getRandom(size_t numBytes)357 optional<vector<uint8_t>> getRandom(size_t numBytes) {
358     vector<uint8_t> output;
359     output.resize(numBytes);
360     if (RAND_bytes(output.data(), numBytes) != 1) {
361         LOG(ERROR) << "RAND_bytes: failed getting " << numBytes << " random";
362         return {};
363     }
364     return output;
365 }
366 
decryptAes128Gcm(const vector<uint8_t> & key,const vector<uint8_t> & encryptedData,const vector<uint8_t> & additionalAuthenticatedData)367 optional<vector<uint8_t>> decryptAes128Gcm(const vector<uint8_t>& key,
368                                            const vector<uint8_t>& encryptedData,
369                                            const vector<uint8_t>& additionalAuthenticatedData) {
370     int cipherTextSize = int(encryptedData.size()) - kAesGcmIvSize - kAesGcmTagSize;
371     if (cipherTextSize < 0) {
372         LOG(ERROR) << "encryptedData too small";
373         return {};
374     }
375     unsigned char* nonce = (unsigned char*)encryptedData.data();
376     unsigned char* cipherText = nonce + kAesGcmIvSize;
377     unsigned char* tag = cipherText + cipherTextSize;
378 
379     vector<uint8_t> plainText;
380     plainText.resize(cipherTextSize);
381 
382     auto ctx = EvpCipherCtxPtr(EVP_CIPHER_CTX_new());
383     if (ctx.get() == nullptr) {
384         LOG(ERROR) << "EVP_CIPHER_CTX_new: failed";
385         return {};
386     }
387 
388     if (EVP_DecryptInit_ex(ctx.get(), EVP_aes_128_gcm(), NULL, NULL, NULL) != 1) {
389         LOG(ERROR) << "EVP_DecryptInit_ex: failed";
390         return {};
391     }
392 
393     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize, NULL) != 1) {
394         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting nonce length";
395         return {};
396     }
397 
398     if (EVP_DecryptInit_ex(ctx.get(), NULL, NULL, (unsigned char*)key.data(), nonce) != 1) {
399         LOG(ERROR) << "EVP_DecryptInit_ex: failed";
400         return {};
401     }
402 
403     int numWritten;
404     if (additionalAuthenticatedData.size() > 0) {
405         if (EVP_DecryptUpdate(ctx.get(), NULL, &numWritten,
406                               (unsigned char*)additionalAuthenticatedData.data(),
407                               additionalAuthenticatedData.size()) != 1) {
408             LOG(ERROR) << "EVP_DecryptUpdate: failed for additionalAuthenticatedData";
409             return {};
410         }
411         if ((size_t)numWritten != additionalAuthenticatedData.size()) {
412             LOG(ERROR) << "EVP_DecryptUpdate: Unexpected outl=" << numWritten << " (expected "
413                        << additionalAuthenticatedData.size() << ") for additionalAuthenticatedData";
414             return {};
415         }
416     }
417 
418     if (EVP_DecryptUpdate(ctx.get(), (unsigned char*)plainText.data(), &numWritten, cipherText,
419                           cipherTextSize) != 1) {
420         LOG(ERROR) << "EVP_DecryptUpdate: failed";
421         return {};
422     }
423     if (numWritten != cipherTextSize) {
424         LOG(ERROR) << "EVP_DecryptUpdate: Unexpected outl=" << numWritten << " (expected "
425                    << cipherTextSize << ")";
426         return {};
427     }
428 
429     if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, kAesGcmTagSize, tag)) {
430         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting expected tag";
431         return {};
432     }
433 
434     int ret = EVP_DecryptFinal_ex(ctx.get(), (unsigned char*)plainText.data() + numWritten,
435                                   &numWritten);
436     if (ret != 1) {
437         LOG(ERROR) << "EVP_DecryptFinal_ex: failed";
438         return {};
439     }
440     if (numWritten != 0) {
441         LOG(ERROR) << "EVP_DecryptFinal_ex: Unexpected non-zero outl=" << numWritten;
442         return {};
443     }
444 
445     return plainText;
446 }
447 
encryptAes128Gcm(const vector<uint8_t> & key,const vector<uint8_t> & nonce,const vector<uint8_t> & data,const vector<uint8_t> & additionalAuthenticatedData)448 optional<vector<uint8_t>> encryptAes128Gcm(const vector<uint8_t>& key, const vector<uint8_t>& nonce,
449                                            const vector<uint8_t>& data,
450                                            const vector<uint8_t>& additionalAuthenticatedData) {
451     if (key.size() != kAes128GcmKeySize) {
452         LOG(ERROR) << "key is not kAes128GcmKeySize bytes";
453         return {};
454     }
455     if (nonce.size() != kAesGcmIvSize) {
456         LOG(ERROR) << "nonce is not kAesGcmIvSize bytes";
457         return {};
458     }
459 
460     // The result is the nonce (kAesGcmIvSize bytes), the ciphertext, and
461     // finally the tag (kAesGcmTagSize bytes).
462     vector<uint8_t> encryptedData;
463     encryptedData.resize(data.size() + kAesGcmIvSize + kAesGcmTagSize);
464     unsigned char* noncePtr = (unsigned char*)encryptedData.data();
465     unsigned char* cipherText = noncePtr + kAesGcmIvSize;
466     unsigned char* tag = cipherText + data.size();
467     memcpy(noncePtr, nonce.data(), kAesGcmIvSize);
468 
469     auto ctx = EvpCipherCtxPtr(EVP_CIPHER_CTX_new());
470     if (ctx.get() == nullptr) {
471         LOG(ERROR) << "EVP_CIPHER_CTX_new: failed";
472         return {};
473     }
474 
475     if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_gcm(), NULL, NULL, NULL) != 1) {
476         LOG(ERROR) << "EVP_EncryptInit_ex: failed";
477         return {};
478     }
479 
480     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize, NULL) != 1) {
481         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed setting nonce length";
482         return {};
483     }
484 
485     if (EVP_EncryptInit_ex(ctx.get(), NULL, NULL, (unsigned char*)key.data(),
486                            (unsigned char*)nonce.data()) != 1) {
487         LOG(ERROR) << "EVP_EncryptInit_ex: failed";
488         return {};
489     }
490 
491     int numWritten;
492     if (additionalAuthenticatedData.size() > 0) {
493         if (EVP_EncryptUpdate(ctx.get(), NULL, &numWritten,
494                               (unsigned char*)additionalAuthenticatedData.data(),
495                               additionalAuthenticatedData.size()) != 1) {
496             LOG(ERROR) << "EVP_EncryptUpdate: failed for additionalAuthenticatedData";
497             return {};
498         }
499         if ((size_t)numWritten != additionalAuthenticatedData.size()) {
500             LOG(ERROR) << "EVP_EncryptUpdate: Unexpected outl=" << numWritten << " (expected "
501                        << additionalAuthenticatedData.size() << ") for additionalAuthenticatedData";
502             return {};
503         }
504     }
505 
506     if (data.size() > 0) {
507         if (EVP_EncryptUpdate(ctx.get(), cipherText, &numWritten, (unsigned char*)data.data(),
508                               data.size()) != 1) {
509             LOG(ERROR) << "EVP_EncryptUpdate: failed";
510             return {};
511         }
512         if ((size_t)numWritten != data.size()) {
513             LOG(ERROR) << "EVP_EncryptUpdate: Unexpected outl=" << numWritten << " (expected "
514                        << data.size() << ")";
515             return {};
516         }
517     }
518 
519     if (EVP_EncryptFinal_ex(ctx.get(), cipherText + numWritten, &numWritten) != 1) {
520         LOG(ERROR) << "EVP_EncryptFinal_ex: failed";
521         return {};
522     }
523     if (numWritten != 0) {
524         LOG(ERROR) << "EVP_EncryptFinal_ex: Unexpected non-zero outl=" << numWritten;
525         return {};
526     }
527 
528     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, kAesGcmTagSize, tag) != 1) {
529         LOG(ERROR) << "EVP_CIPHER_CTX_ctrl: failed getting tag";
530         return {};
531     }
532 
533     return encryptedData;
534 }
535 
536 struct EC_KEY_Deleter {
operator ()android::hardware::identity::support::EC_KEY_Deleter537     void operator()(EC_KEY* key) const {
538         if (key != nullptr) {
539             EC_KEY_free(key);
540         }
541     }
542 };
543 using EC_KEY_Ptr = unique_ptr<EC_KEY, EC_KEY_Deleter>;
544 
545 struct EVP_PKEY_Deleter {
operator ()android::hardware::identity::support::EVP_PKEY_Deleter546     void operator()(EVP_PKEY* key) const {
547         if (key != nullptr) {
548             EVP_PKEY_free(key);
549         }
550     }
551 };
552 using EVP_PKEY_Ptr = unique_ptr<EVP_PKEY, EVP_PKEY_Deleter>;
553 
554 struct EVP_PKEY_CTX_Deleter {
operator ()android::hardware::identity::support::EVP_PKEY_CTX_Deleter555     void operator()(EVP_PKEY_CTX* ctx) const {
556         if (ctx != nullptr) {
557             EVP_PKEY_CTX_free(ctx);
558         }
559     }
560 };
561 using EVP_PKEY_CTX_Ptr = unique_ptr<EVP_PKEY_CTX, EVP_PKEY_CTX_Deleter>;
562 
563 struct EC_GROUP_Deleter {
operator ()android::hardware::identity::support::EC_GROUP_Deleter564     void operator()(EC_GROUP* group) const {
565         if (group != nullptr) {
566             EC_GROUP_free(group);
567         }
568     }
569 };
570 using EC_GROUP_Ptr = unique_ptr<EC_GROUP, EC_GROUP_Deleter>;
571 
572 struct EC_POINT_Deleter {
operator ()android::hardware::identity::support::EC_POINT_Deleter573     void operator()(EC_POINT* point) const {
574         if (point != nullptr) {
575             EC_POINT_free(point);
576         }
577     }
578 };
579 
580 using EC_POINT_Ptr = unique_ptr<EC_POINT, EC_POINT_Deleter>;
581 
582 struct ECDSA_SIG_Deleter {
operator ()android::hardware::identity::support::ECDSA_SIG_Deleter583     void operator()(ECDSA_SIG* sig) const {
584         if (sig != nullptr) {
585             ECDSA_SIG_free(sig);
586         }
587     }
588 };
589 using ECDSA_SIG_Ptr = unique_ptr<ECDSA_SIG, ECDSA_SIG_Deleter>;
590 
591 struct X509_Deleter {
operator ()android::hardware::identity::support::X509_Deleter592     void operator()(X509* x509) const {
593         if (x509 != nullptr) {
594             X509_free(x509);
595         }
596     }
597 };
598 using X509_Ptr = unique_ptr<X509, X509_Deleter>;
599 
600 struct PKCS12_Deleter {
operator ()android::hardware::identity::support::PKCS12_Deleter601     void operator()(PKCS12* pkcs12) const {
602         if (pkcs12 != nullptr) {
603             PKCS12_free(pkcs12);
604         }
605     }
606 };
607 using PKCS12_Ptr = unique_ptr<PKCS12, PKCS12_Deleter>;
608 
609 struct BIGNUM_Deleter {
operator ()android::hardware::identity::support::BIGNUM_Deleter610     void operator()(BIGNUM* bignum) const {
611         if (bignum != nullptr) {
612             BN_free(bignum);
613         }
614     }
615 };
616 using BIGNUM_Ptr = unique_ptr<BIGNUM, BIGNUM_Deleter>;
617 
618 struct ASN1_INTEGER_Deleter {
operator ()android::hardware::identity::support::ASN1_INTEGER_Deleter619     void operator()(ASN1_INTEGER* value) const {
620         if (value != nullptr) {
621             ASN1_INTEGER_free(value);
622         }
623     }
624 };
625 using ASN1_INTEGER_Ptr = unique_ptr<ASN1_INTEGER, ASN1_INTEGER_Deleter>;
626 
627 struct ASN1_TIME_Deleter {
operator ()android::hardware::identity::support::ASN1_TIME_Deleter628     void operator()(ASN1_TIME* value) const {
629         if (value != nullptr) {
630             ASN1_TIME_free(value);
631         }
632     }
633 };
634 using ASN1_TIME_Ptr = unique_ptr<ASN1_TIME, ASN1_TIME_Deleter>;
635 
636 struct X509_NAME_Deleter {
operator ()android::hardware::identity::support::X509_NAME_Deleter637     void operator()(X509_NAME* value) const {
638         if (value != nullptr) {
639             X509_NAME_free(value);
640         }
641     }
642 };
643 using X509_NAME_Ptr = unique_ptr<X509_NAME, X509_NAME_Deleter>;
644 
// Concatenates the byte strings in |certificateChain|, in order, into a
// single vector.
vector<uint8_t> certificateChainJoin(const vector<vector<uint8_t>>& certificateChain) {
    size_t totalSize = 0;
    for (const vector<uint8_t>& cert : certificateChain) {
        totalSize += cert.size();
    }

    vector<uint8_t> joined;
    joined.reserve(totalSize);
    for (const vector<uint8_t>& cert : certificateChain) {
        joined.insert(joined.end(), cert.begin(), cert.end());
    }
    return joined;
}
652 
certificateChainSplit(const vector<uint8_t> & certificateChain)653 optional<vector<vector<uint8_t>>> certificateChainSplit(const vector<uint8_t>& certificateChain) {
654     const unsigned char* pStart = (unsigned char*)certificateChain.data();
655     const unsigned char* p = pStart;
656     const unsigned char* pEnd = p + certificateChain.size();
657     vector<vector<uint8_t>> certificates;
658     while (p < pEnd) {
659         size_t begin = p - pStart;
660         auto x509 = X509_Ptr(d2i_X509(nullptr, &p, pEnd - p));
661         size_t next = p - pStart;
662         if (x509 == nullptr) {
663             LOG(ERROR) << "Error parsing X509 certificate";
664             return {};
665         }
666         vector<uint8_t> cert =
667                 vector<uint8_t>(certificateChain.begin() + begin, certificateChain.begin() + next);
668         certificates.push_back(std::move(cert));
669     }
670     return certificates;
671 }
672 
parseX509Certificates(const vector<uint8_t> & certificateChain,vector<X509_Ptr> & parsedCertificates)673 static bool parseX509Certificates(const vector<uint8_t>& certificateChain,
674                                   vector<X509_Ptr>& parsedCertificates) {
675     const unsigned char* p = (unsigned char*)certificateChain.data();
676     const unsigned char* pEnd = p + certificateChain.size();
677     parsedCertificates.resize(0);
678     while (p < pEnd) {
679         auto x509 = X509_Ptr(d2i_X509(nullptr, &p, pEnd - p));
680         if (x509 == nullptr) {
681             LOG(ERROR) << "Error parsing X509 certificate";
682             return false;
683         }
684         parsedCertificates.push_back(std::move(x509));
685     }
686     return true;
687 }
688 
certificateSignedByPublicKey(const vector<uint8_t> & certificate,const vector<uint8_t> & publicKey)689 bool certificateSignedByPublicKey(const vector<uint8_t>& certificate,
690                                   const vector<uint8_t>& publicKey) {
691     const unsigned char* p = certificate.data();
692     auto x509 = X509_Ptr(d2i_X509(nullptr, &p, certificate.size()));
693     if (x509 == nullptr) {
694         LOG(ERROR) << "Error parsing X509 certificate";
695         return false;
696     }
697 
698     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
699     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
700     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
701         1) {
702         LOG(ERROR) << "Error decoding publicKey";
703         return false;
704     }
705     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
706     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
707     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
708         LOG(ERROR) << "Memory allocation failed";
709         return false;
710     }
711     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
712         LOG(ERROR) << "Error setting group";
713         return false;
714     }
715     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
716         LOG(ERROR) << "Error setting point";
717         return false;
718     }
719     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
720         LOG(ERROR) << "Error setting key";
721         return false;
722     }
723 
724     if (X509_verify(x509.get(), pkey.get()) != 1) {
725         return false;
726     }
727 
728     return true;
729 }
730 
731 // TODO: Right now the only check we perform is to check that each certificate
732 //       is signed by its successor. We should - but currently don't - also check
733 //       things like valid dates etc.
734 //
735 //       It would be nice to use X509_verify_cert() instead of doing our own thing.
736 //
certificateChainValidate(const vector<uint8_t> & certificateChain)737 bool certificateChainValidate(const vector<uint8_t>& certificateChain) {
738     vector<X509_Ptr> certs;
739 
740     if (!parseX509Certificates(certificateChain, certs)) {
741         LOG(ERROR) << "Error parsing X509 certificates";
742         return false;
743     }
744 
745     if (certs.size() == 1) {
746         return true;
747     }
748 
749     for (size_t n = 1; n < certs.size(); n++) {
750         const X509_Ptr& keyCert = certs[n - 1];
751         const X509_Ptr& signingCert = certs[n];
752         EVP_PKEY_Ptr signingPubkey(X509_get_pubkey(signingCert.get()));
753         if (X509_verify(keyCert.get(), signingPubkey.get()) != 1) {
754             LOG(ERROR) << "Error validating cert at index " << n - 1
755                        << " is signed by its successor";
756             return false;
757         }
758     }
759 
760     return true;
761 }
762 
checkEcDsaSignature(const vector<uint8_t> & digest,const vector<uint8_t> & signature,const vector<uint8_t> & publicKey)763 bool checkEcDsaSignature(const vector<uint8_t>& digest, const vector<uint8_t>& signature,
764                          const vector<uint8_t>& publicKey) {
765     const unsigned char* p = (unsigned char*)signature.data();
766     auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, signature.size()));
767     if (sig.get() == nullptr) {
768         LOG(ERROR) << "Error decoding DER encoded signature";
769         return false;
770     }
771 
772     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
773     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
774     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
775         1) {
776         LOG(ERROR) << "Error decoding publicKey";
777         return false;
778     }
779     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
780     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
781     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
782         LOG(ERROR) << "Memory allocation failed";
783         return false;
784     }
785     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
786         LOG(ERROR) << "Error setting group";
787         return false;
788     }
789     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
790         LOG(ERROR) << "Error setting point";
791         return false;
792     }
793     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
794         LOG(ERROR) << "Error setting key";
795         return false;
796     }
797 
798     int rc = ECDSA_do_verify(digest.data(), digest.size(), sig.get(), ecKey.get());
799     if (rc != 1) {
800         LOG(ERROR) << "Error verifying signature (rc=" << rc << ")";
801         return false;
802     }
803 
804     return true;
805 }
806 
// Computes the SHA-256 digest of |data|. The result is always
// SHA256_DIGEST_LENGTH bytes long.
vector<uint8_t> sha256(const vector<uint8_t>& data) {
    SHA256_CTX shaCtx;
    SHA256_Init(&shaCtx);
    SHA256_Update(&shaCtx, data.data(), data.size());

    vector<uint8_t> digest(SHA256_DIGEST_LENGTH);
    SHA256_Final(digest.data(), &shaCtx);
    return digest;
}
816 
signEcDsaDigest(const vector<uint8_t> & key,const vector<uint8_t> & dataDigest)817 optional<vector<uint8_t>> signEcDsaDigest(const vector<uint8_t>& key,
818                                           const vector<uint8_t>& dataDigest) {
819     auto bn = BIGNUM_Ptr(BN_bin2bn(key.data(), key.size(), nullptr));
820     if (bn.get() == nullptr) {
821         LOG(ERROR) << "Error creating BIGNUM";
822         return {};
823     }
824 
825     auto ec_key = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
826     if (EC_KEY_set_private_key(ec_key.get(), bn.get()) != 1) {
827         LOG(ERROR) << "Error setting private key from BIGNUM";
828         return {};
829     }
830 
831     ECDSA_SIG* sig = ECDSA_do_sign(dataDigest.data(), dataDigest.size(), ec_key.get());
832     if (sig == nullptr) {
833         LOG(ERROR) << "Error signing digest";
834         return {};
835     }
836     size_t len = i2d_ECDSA_SIG(sig, nullptr);
837     vector<uint8_t> signature;
838     signature.resize(len);
839     unsigned char* p = (unsigned char*)signature.data();
840     i2d_ECDSA_SIG(sig, &p);
841     ECDSA_SIG_free(sig);
842     return signature;
843 }
844 
signEcDsa(const vector<uint8_t> & key,const vector<uint8_t> & data)845 optional<vector<uint8_t>> signEcDsa(const vector<uint8_t>& key, const vector<uint8_t>& data) {
846     return signEcDsaDigest(key, sha256(data));
847 }
848 
hmacSha256(const vector<uint8_t> & key,const vector<uint8_t> & data)849 optional<vector<uint8_t>> hmacSha256(const vector<uint8_t>& key, const vector<uint8_t>& data) {
850     HMAC_CTX ctx;
851     HMAC_CTX_init(&ctx);
852     if (HMAC_Init_ex(&ctx, key.data(), key.size(), EVP_sha256(), nullptr /* impl */) != 1) {
853         LOG(ERROR) << "Error initializing HMAC_CTX";
854         return {};
855     }
856     if (HMAC_Update(&ctx, data.data(), data.size()) != 1) {
857         LOG(ERROR) << "Error updating HMAC_CTX";
858         return {};
859     }
860     vector<uint8_t> hmac;
861     hmac.resize(32);
862     unsigned int size = 0;
863     if (HMAC_Final(&ctx, hmac.data(), &size) != 1) {
864         LOG(ERROR) << "Error finalizing HMAC_CTX";
865         return {};
866     }
867     if (size != 32) {
868         LOG(ERROR) << "Expected 32 bytes from HMAC_Final, got " << size;
869         return {};
870     }
871     return hmac;
872 }
873 
parseDigits(const char ** s,int numDigits)874 int parseDigits(const char** s, int numDigits) {
875     int result;
876     auto [_, ec] = std::from_chars(*s, *s + numDigits, result);
877     if (ec != std::errc()) {
878         LOG(ERROR) << "Error parsing " << numDigits << " digits "
879                    << " from " << s;
880         return 0;
881     }
882     *s += numDigits;
883     return result;
884 }
885 
parseAsn1Time(const ASN1_TIME * asn1Time,time_t * outTime)886 bool parseAsn1Time(const ASN1_TIME* asn1Time, time_t* outTime) {
887     struct tm tm;
888 
889     memset(&tm, '\0', sizeof(tm));
890     const char* timeStr = (const char*)asn1Time->data;
891     const char* s = timeStr;
892     if (asn1Time->type == V_ASN1_UTCTIME) {
893         tm.tm_year = parseDigits(&s, 2);
894         if (tm.tm_year < 70) {
895             tm.tm_year += 100;
896         }
897     } else if (asn1Time->type == V_ASN1_GENERALIZEDTIME) {
898         tm.tm_year = parseDigits(&s, 4) - 1900;
899         tm.tm_year -= 1900;
900     } else {
901         LOG(ERROR) << "Unsupported ASN1_TIME type " << asn1Time->type;
902         return false;
903     }
904     tm.tm_mon = parseDigits(&s, 2) - 1;
905     tm.tm_mday = parseDigits(&s, 2);
906     tm.tm_hour = parseDigits(&s, 2);
907     tm.tm_min = parseDigits(&s, 2);
908     tm.tm_sec = parseDigits(&s, 2);
909     // This may need to be updated if someone create certificates using +/- instead of Z.
910     //
911     if (*s != 'Z') {
912         LOG(ERROR) << "Expected Z in string '" << timeStr << "' at offset " << (s - timeStr);
913         return false;
914     }
915 
916     time_t t = timegm(&tm);
917     if (t == -1) {
918         LOG(ERROR) << "Error converting broken-down time to time_t";
919         return false;
920     }
921     *outTime = t;
922     return true;
923 }
924 
925 // Generates the attestation certificate with the parameters passed in.  Note
926 // that the passed in |activeTimeMilliSeconds| |expireTimeMilliSeconds| are in
927 // milli seconds since epoch.  We are setting them to milliseconds due to
928 // requirement in AuthorizationSet KM_DATE fields.  The certificate created is
929 // actually in seconds.
930 //
931 // If 0 is passed for expiration time, the expiration time from batch
932 // certificate will be used.
933 //
createAttestation(const EVP_PKEY * key,const vector<uint8_t> & applicationId,const vector<uint8_t> & challenge,uint64_t activeTimeMilliSeconds,uint64_t expireTimeMilliSeconds,bool isTestCredential)934 optional<vector<vector<uint8_t>>> createAttestation(
935         const EVP_PKEY* key, const vector<uint8_t>& applicationId, const vector<uint8_t>& challenge,
936         uint64_t activeTimeMilliSeconds, uint64_t expireTimeMilliSeconds, bool isTestCredential) {
937     const keymaster_cert_chain_t* attestation_chain =
938             ::keymaster::getAttestationChain(KM_ALGORITHM_EC, nullptr);
939     if (attestation_chain == nullptr) {
940         LOG(ERROR) << "Error getting attestation chain";
941         return {};
942     }
943     if (expireTimeMilliSeconds == 0) {
944         if (attestation_chain->entry_count < 1) {
945             LOG(ERROR) << "Expected at least one entry in attestation chain";
946             return {};
947         }
948         keymaster_blob_t* bcBlob = &(attestation_chain->entries[0]);
949         const uint8_t* bcData = bcBlob->data;
950         auto bc = X509_Ptr(d2i_X509(nullptr, &bcData, bcBlob->data_length));
951         time_t bcNotAfter;
952         if (!parseAsn1Time(X509_get0_notAfter(bc.get()), &bcNotAfter)) {
953             LOG(ERROR) << "Error getting notAfter from batch certificate";
954             return {};
955         }
956         expireTimeMilliSeconds = bcNotAfter * 1000;
957     }
958     const keymaster_key_blob_t* attestation_signing_key =
959             ::keymaster::getAttestationKey(KM_ALGORITHM_EC, nullptr);
960     if (attestation_signing_key == nullptr) {
961         LOG(ERROR) << "Error getting attestation key";
962         return {};
963     }
964 
965     ::keymaster::AuthorizationSet auth_set(
966             ::keymaster::AuthorizationSetBuilder()
967                     .Authorization(::keymaster::TAG_ATTESTATION_CHALLENGE, challenge.data(),
968                                    challenge.size())
969                     .Authorization(::keymaster::TAG_ACTIVE_DATETIME, activeTimeMilliSeconds)
970                     // Even though identity attestation hal said the application
971                     // id should be in software enforced authentication set,
972                     // keymaster portable lib expect the input in this
973                     // parameter because the software enforced in input to keymaster
974                     // refers to the key software enforced properties. And this
975                     // parameter refers to properties of the attestation which
976                     // includes app id.
977                     .Authorization(::keymaster::TAG_ATTESTATION_APPLICATION_ID,
978                                    applicationId.data(), applicationId.size())
979                     .Authorization(::keymaster::TAG_USAGE_EXPIRE_DATETIME, expireTimeMilliSeconds));
980 
981     // Unique id and device id is not applicable for identity credential attestation,
982     // so we don't need to set those or application id.
983     ::keymaster::AuthorizationSet swEnforced(::keymaster::AuthorizationSetBuilder().Authorization(
984             ::keymaster::TAG_CREATION_DATETIME, activeTimeMilliSeconds));
985 
986     ::keymaster::AuthorizationSetBuilder hwEnforcedBuilder =
987             ::keymaster::AuthorizationSetBuilder()
988                     .Authorization(::keymaster::TAG_PURPOSE, KM_PURPOSE_SIGN)
989                     .Authorization(::keymaster::TAG_KEY_SIZE, 256)
990                     .Authorization(::keymaster::TAG_ALGORITHM, KM_ALGORITHM_EC)
991                     .Authorization(::keymaster::TAG_NO_AUTH_REQUIRED)
992                     .Authorization(::keymaster::TAG_DIGEST, KM_DIGEST_SHA_2_256)
993                     .Authorization(::keymaster::TAG_EC_CURVE, KM_EC_CURVE_P_256)
994                     .Authorization(::keymaster::TAG_OS_VERSION, 42)
995                     .Authorization(::keymaster::TAG_OS_PATCHLEVEL, 43);
996 
997     // Only include TAG_IDENTITY_CREDENTIAL_KEY if it's not a test credential
998     if (!isTestCredential) {
999         hwEnforcedBuilder.Authorization(::keymaster::TAG_IDENTITY_CREDENTIAL_KEY);
1000     }
1001     ::keymaster::AuthorizationSet hwEnforced(hwEnforcedBuilder);
1002 
1003     keymaster_error_t error;
1004     ::keymaster::CertChainPtr cert_chain_out;
1005 
1006     // Pretend to be implemented in a trusted environment just so we can pass
1007     // the VTS tests. Of course, this is a pretend-only game since hopefully no
1008     // relying party is ever going to trust our batch key and those keys above
1009     // it.
1010     //
1011     ::keymaster::PureSoftKeymasterContext context(KM_SECURITY_LEVEL_TRUSTED_ENVIRONMENT);
1012 
1013     error = generate_attestation_from_EVP(key, swEnforced, hwEnforced, auth_set, context,
1014                                           ::keymaster::kCurrentKeymasterVersion, *attestation_chain,
1015                                           *attestation_signing_key,
1016                                           "Android Identity Credential Key", &cert_chain_out);
1017 
1018     if (KM_ERROR_OK != error || !cert_chain_out) {
1019         LOG(ERROR) << "Error generate attestation from EVP key" << error;
1020         return {};
1021     }
1022 
1023     // translate certificate format from keymaster_cert_chain_t to vector<uint8_t>.
1024     vector<vector<uint8_t>> attestationCertificate;
1025     for (int i = 0; i < cert_chain_out->entry_count; i++) {
1026         attestationCertificate.insert(
1027                 attestationCertificate.end(),
1028                 vector<uint8_t>(
1029                         cert_chain_out->entries[i].data,
1030                         cert_chain_out->entries[i].data + cert_chain_out->entries[i].data_length));
1031     }
1032 
1033     return attestationCertificate;
1034 }
1035 
createEcKeyPairAndAttestation(const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId,bool isTestCredential)1036 optional<std::pair<vector<uint8_t>, vector<vector<uint8_t>>>> createEcKeyPairAndAttestation(
1037         const vector<uint8_t>& challenge, const vector<uint8_t>& applicationId,
1038         bool isTestCredential) {
1039     auto ec_key = ::keymaster::EC_KEY_Ptr(EC_KEY_new());
1040     auto pkey = ::keymaster::EVP_PKEY_Ptr(EVP_PKEY_new());
1041     auto group = ::keymaster::EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1042 
1043     if (ec_key.get() == nullptr || pkey.get() == nullptr) {
1044         LOG(ERROR) << "Memory allocation failed";
1045         return {};
1046     }
1047 
1048     if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
1049         EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
1050         LOG(ERROR) << "Error generating key";
1051         return {};
1052     }
1053 
1054     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ec_key.get()) != 1) {
1055         LOG(ERROR) << "Error getting private key";
1056         return {};
1057     }
1058 
1059     uint64_t nowMs = time(nullptr) * 1000;
1060     uint64_t expireTimeMs = 0;  // Set to same as batch certificate
1061 
1062     optional<vector<vector<uint8_t>>> attestationCert = createAttestation(
1063             pkey.get(), applicationId, challenge, nowMs, expireTimeMs, isTestCredential);
1064     if (!attestationCert) {
1065         LOG(ERROR) << "Error create attestation from key and challenge";
1066         return {};
1067     }
1068 
1069     int size = i2d_PrivateKey(pkey.get(), nullptr);
1070     if (size == 0) {
1071         LOG(ERROR) << "Error generating public key encoding";
1072         return {};
1073     }
1074 
1075     vector<uint8_t> keyPair(size);
1076     unsigned char* p = keyPair.data();
1077     i2d_PrivateKey(pkey.get(), &p);
1078 
1079     return make_pair(keyPair, attestationCert.value());
1080 }
1081 
createAttestationForEcPublicKey(const vector<uint8_t> & publicKey,const vector<uint8_t> & challenge,const vector<uint8_t> & applicationId)1082 optional<vector<vector<uint8_t>>> createAttestationForEcPublicKey(
1083         const vector<uint8_t>& publicKey, const vector<uint8_t>& challenge,
1084         const vector<uint8_t>& applicationId) {
1085     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1086     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
1087     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
1088         1) {
1089         LOG(ERROR) << "Error decoding publicKey";
1090         return {};
1091     }
1092     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
1093     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1094     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
1095         LOG(ERROR) << "Memory allocation failed";
1096         return {};
1097     }
1098     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
1099         LOG(ERROR) << "Error setting group";
1100         return {};
1101     }
1102     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
1103         LOG(ERROR) << "Error setting point";
1104         return {};
1105     }
1106     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1107         LOG(ERROR) << "Error setting key";
1108         return {};
1109     }
1110 
1111     uint64_t nowMs = time(nullptr) * 1000;
1112     uint64_t expireTimeMs = 0;  // Set to same as batch certificate
1113 
1114     optional<vector<vector<uint8_t>>> attestationCert =
1115             createAttestation(pkey.get(), applicationId, challenge, nowMs, expireTimeMs,
1116                               false /* isTestCredential */);
1117     if (!attestationCert) {
1118         LOG(ERROR) << "Error create attestation from key and challenge";
1119         return {};
1120     }
1121 
1122     return attestationCert.value();
1123 }
1124 
createEcKeyPair()1125 optional<vector<uint8_t>> createEcKeyPair() {
1126     auto ec_key = EC_KEY_Ptr(EC_KEY_new());
1127     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1128     if (ec_key.get() == nullptr || pkey.get() == nullptr) {
1129         LOG(ERROR) << "Memory allocation failed";
1130         return {};
1131     }
1132 
1133     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1134     if (group.get() == nullptr) {
1135         LOG(ERROR) << "Error creating EC group by curve name";
1136         return {};
1137     }
1138 
1139     if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
1140         EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
1141         LOG(ERROR) << "Error generating key";
1142         return {};
1143     }
1144 
1145     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ec_key.get()) != 1) {
1146         LOG(ERROR) << "Error getting private key";
1147         return {};
1148     }
1149 
1150     int size = i2d_PrivateKey(pkey.get(), nullptr);
1151     if (size == 0) {
1152         LOG(ERROR) << "Error generating public key encoding";
1153         return {};
1154     }
1155     vector<uint8_t> keyPair;
1156     keyPair.resize(size);
1157     unsigned char* p = keyPair.data();
1158     i2d_PrivateKey(pkey.get(), &p);
1159     return keyPair;
1160 }
1161 
ecKeyPairGetPublicKey(const vector<uint8_t> & keyPair)1162 optional<vector<uint8_t>> ecKeyPairGetPublicKey(const vector<uint8_t>& keyPair) {
1163     const unsigned char* p = (const unsigned char*)keyPair.data();
1164     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
1165     if (pkey.get() == nullptr) {
1166         LOG(ERROR) << "Error parsing keyPair";
1167         return {};
1168     }
1169 
1170     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1171     if (ecKey.get() == nullptr) {
1172         LOG(ERROR) << "Failed getting EC key";
1173         return {};
1174     }
1175 
1176     auto ecGroup = EC_KEY_get0_group(ecKey.get());
1177     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
1178     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
1179                                   nullptr);
1180     if (size == 0) {
1181         LOG(ERROR) << "Error generating public key encoding";
1182         return {};
1183     }
1184 
1185     vector<uint8_t> publicKey;
1186     publicKey.resize(size);
1187     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1188                        publicKey.size(), nullptr);
1189     return publicKey;
1190 }
1191 
ecKeyPairGetPrivateKey(const vector<uint8_t> & keyPair)1192 optional<vector<uint8_t>> ecKeyPairGetPrivateKey(const vector<uint8_t>& keyPair) {
1193     const unsigned char* p = (const unsigned char*)keyPair.data();
1194     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
1195     if (pkey.get() == nullptr) {
1196         LOG(ERROR) << "Error parsing keyPair";
1197         return {};
1198     }
1199 
1200     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1201     if (ecKey.get() == nullptr) {
1202         LOG(ERROR) << "Failed getting EC key";
1203         return {};
1204     }
1205 
1206     const BIGNUM* bignum = EC_KEY_get0_private_key(ecKey.get());
1207     if (bignum == nullptr) {
1208         LOG(ERROR) << "Error getting bignum from private key";
1209         return {};
1210     }
1211     vector<uint8_t> privateKey;
1212     privateKey.resize(BN_num_bytes(bignum));
1213     BN_bn2bin(bignum, privateKey.data());
1214     return privateKey;
1215 }
1216 
ecPrivateKeyToKeyPair(const vector<uint8_t> & privateKey)1217 optional<vector<uint8_t>> ecPrivateKeyToKeyPair(const vector<uint8_t>& privateKey) {
1218     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
1219     if (bn.get() == nullptr) {
1220         LOG(ERROR) << "Error creating BIGNUM";
1221         return {};
1222     }
1223 
1224     auto ecKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1225     if (EC_KEY_set_private_key(ecKey.get(), bn.get()) != 1) {
1226         LOG(ERROR) << "Error setting private key from BIGNUM";
1227         return {};
1228     }
1229 
1230     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1231     if (pkey.get() == nullptr) {
1232         LOG(ERROR) << "Memory allocation failed";
1233         return {};
1234     }
1235 
1236     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1237         LOG(ERROR) << "Error getting private key";
1238         return {};
1239     }
1240 
1241     int size = i2d_PrivateKey(pkey.get(), nullptr);
1242     if (size == 0) {
1243         LOG(ERROR) << "Error generating public key encoding";
1244         return {};
1245     }
1246     vector<uint8_t> keyPair;
1247     keyPair.resize(size);
1248     unsigned char* p = keyPair.data();
1249     i2d_PrivateKey(pkey.get(), &p);
1250     return keyPair;
1251 }
1252 
ecKeyPairGetPkcs12(const vector<uint8_t> & keyPair,const string & name,const string & serialDecimal,const string & issuer,const string & subject,time_t validityNotBefore,time_t validityNotAfter)1253 optional<vector<uint8_t>> ecKeyPairGetPkcs12(const vector<uint8_t>& keyPair, const string& name,
1254                                              const string& serialDecimal, const string& issuer,
1255                                              const string& subject, time_t validityNotBefore,
1256                                              time_t validityNotAfter) {
1257     const unsigned char* p = (const unsigned char*)keyPair.data();
1258     auto pkey = EVP_PKEY_Ptr(d2i_PrivateKey(EVP_PKEY_EC, nullptr, &p, keyPair.size()));
1259     if (pkey.get() == nullptr) {
1260         LOG(ERROR) << "Error parsing keyPair";
1261         return {};
1262     }
1263 
1264     auto x509 = X509_Ptr(X509_new());
1265     if (!x509.get()) {
1266         LOG(ERROR) << "Error creating X509 certificate";
1267         return {};
1268     }
1269 
1270     if (!X509_set_version(x509.get(), 2 /* version 3, but zero-based */)) {
1271         LOG(ERROR) << "Error setting version to 3";
1272         return {};
1273     }
1274 
1275     if (X509_set_pubkey(x509.get(), pkey.get()) != 1) {
1276         LOG(ERROR) << "Error setting public key";
1277         return {};
1278     }
1279 
1280     BIGNUM* bignumSerial = nullptr;
1281     if (BN_dec2bn(&bignumSerial, serialDecimal.c_str()) == 0) {
1282         LOG(ERROR) << "Error parsing serial";
1283         return {};
1284     }
1285     auto bignumSerialPtr = BIGNUM_Ptr(bignumSerial);
1286     auto asnSerial = ASN1_INTEGER_Ptr(BN_to_ASN1_INTEGER(bignumSerial, nullptr));
1287     if (X509_set_serialNumber(x509.get(), asnSerial.get()) != 1) {
1288         LOG(ERROR) << "Error setting serial";
1289         return {};
1290     }
1291 
1292     auto x509Issuer = X509_NAME_Ptr(X509_NAME_new());
1293     if (x509Issuer.get() == nullptr ||
1294         X509_NAME_add_entry_by_txt(x509Issuer.get(), "CN", MBSTRING_ASC,
1295                                    (const uint8_t*)issuer.c_str(), issuer.size(), -1 /* loc */,
1296                                    0 /* set */) != 1 ||
1297         X509_set_issuer_name(x509.get(), x509Issuer.get()) != 1) {
1298         LOG(ERROR) << "Error setting issuer";
1299         return {};
1300     }
1301 
1302     auto x509Subject = X509_NAME_Ptr(X509_NAME_new());
1303     if (x509Subject.get() == nullptr ||
1304         X509_NAME_add_entry_by_txt(x509Subject.get(), "CN", MBSTRING_ASC,
1305                                    (const uint8_t*)subject.c_str(), subject.size(), -1 /* loc */,
1306                                    0 /* set */) != 1 ||
1307         X509_set_subject_name(x509.get(), x509Subject.get()) != 1) {
1308         LOG(ERROR) << "Error setting subject";
1309         return {};
1310     }
1311 
1312     auto asnNotBefore = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotBefore));
1313     if (asnNotBefore.get() == nullptr || X509_set_notBefore(x509.get(), asnNotBefore.get()) != 1) {
1314         LOG(ERROR) << "Error setting notBefore";
1315         return {};
1316     }
1317 
1318     auto asnNotAfter = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotAfter));
1319     if (asnNotAfter.get() == nullptr || X509_set_notAfter(x509.get(), asnNotAfter.get()) != 1) {
1320         LOG(ERROR) << "Error setting notAfter";
1321         return {};
1322     }
1323 
1324     if (X509_sign(x509.get(), pkey.get(), EVP_sha256()) == 0) {
1325         LOG(ERROR) << "Error signing X509 certificate";
1326         return {};
1327     }
1328 
1329     // Ideally we wouldn't encrypt it (we're only using this function for
1330     // sending a key-pair over binder to the Android app) but BoringSSL does not
1331     // support this: from pkcs8_x509.c in BoringSSL: "In OpenSSL, -1 here means
1332     // to use no encryption, which we do not currently support."
1333     //
1334     // Passing nullptr as |pass|, though, means "no password". So we'll do that.
1335     // Compare with the receiving side - CredstoreIdentityCredential.java - where
1336     // an empty char[] is passed as the password.
1337     //
1338     auto pkcs12 = PKCS12_Ptr(PKCS12_create(nullptr, name.c_str(), pkey.get(), x509.get(),
1339                                            nullptr,  // ca
1340                                            0,        // nid_key
1341                                            0,        // nid_cert
1342                                            0,        // iter,
1343                                            0,        // mac_iter,
1344                                            0));      // keytype
1345     if (pkcs12.get() == nullptr) {
1346         char buf[128];
1347         long errCode = ERR_get_error();
1348         ERR_error_string_n(errCode, buf, sizeof buf);
1349         LOG(ERROR) << "Error creating PKCS12, code " << errCode << ": " << buf;
1350         return {};
1351     }
1352 
1353     unsigned char* buffer = nullptr;
1354     int length = i2d_PKCS12(pkcs12.get(), &buffer);
1355     if (length < 0) {
1356         LOG(ERROR) << "Error encoding PKCS12";
1357         return {};
1358     }
1359     vector<uint8_t> pkcs12Bytes;
1360     pkcs12Bytes.resize(length);
1361     memcpy(pkcs12Bytes.data(), buffer, length);
1362     OPENSSL_free(buffer);
1363 
1364     return pkcs12Bytes;
1365 }
1366 
ecPublicKeyGenerateCertificate(const vector<uint8_t> & publicKey,const vector<uint8_t> & signingKey,const string & serialDecimal,const string & issuer,const string & subject,time_t validityNotBefore,time_t validityNotAfter)1367 optional<vector<uint8_t>> ecPublicKeyGenerateCertificate(
1368         const vector<uint8_t>& publicKey, const vector<uint8_t>& signingKey,
1369         const string& serialDecimal, const string& issuer, const string& subject,
1370         time_t validityNotBefore, time_t validityNotAfter) {
1371     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1372     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
1373     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
1374         1) {
1375         LOG(ERROR) << "Error decoding publicKey";
1376         return {};
1377     }
1378     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
1379     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1380     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
1381         LOG(ERROR) << "Memory allocation failed";
1382         return {};
1383     }
1384     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
1385         LOG(ERROR) << "Error setting group";
1386         return {};
1387     }
1388     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
1389         LOG(ERROR) << "Error setting point";
1390         return {};
1391     }
1392     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1393         LOG(ERROR) << "Error setting key";
1394         return {};
1395     }
1396 
1397     auto bn = BIGNUM_Ptr(BN_bin2bn(signingKey.data(), signingKey.size(), nullptr));
1398     if (bn.get() == nullptr) {
1399         LOG(ERROR) << "Error creating BIGNUM for private key";
1400         return {};
1401     }
1402     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1403     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
1404         LOG(ERROR) << "Error setting private key from BIGNUM";
1405         return {};
1406     }
1407     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1408     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
1409         LOG(ERROR) << "Error setting private key";
1410         return {};
1411     }
1412 
1413     auto x509 = X509_Ptr(X509_new());
1414     if (!x509.get()) {
1415         LOG(ERROR) << "Error creating X509 certificate";
1416         return {};
1417     }
1418 
1419     if (!X509_set_version(x509.get(), 2 /* version 3, but zero-based */)) {
1420         LOG(ERROR) << "Error setting version to 3";
1421         return {};
1422     }
1423 
1424     if (X509_set_pubkey(x509.get(), pkey.get()) != 1) {
1425         LOG(ERROR) << "Error setting public key";
1426         return {};
1427     }
1428 
1429     BIGNUM* bignumSerial = nullptr;
1430     if (BN_dec2bn(&bignumSerial, serialDecimal.c_str()) == 0) {
1431         LOG(ERROR) << "Error parsing serial";
1432         return {};
1433     }
1434     auto bignumSerialPtr = BIGNUM_Ptr(bignumSerial);
1435     auto asnSerial = ASN1_INTEGER_Ptr(BN_to_ASN1_INTEGER(bignumSerial, nullptr));
1436     if (X509_set_serialNumber(x509.get(), asnSerial.get()) != 1) {
1437         LOG(ERROR) << "Error setting serial";
1438         return {};
1439     }
1440 
1441     auto x509Issuer = X509_NAME_Ptr(X509_NAME_new());
1442     if (x509Issuer.get() == nullptr ||
1443         X509_NAME_add_entry_by_txt(x509Issuer.get(), "CN", MBSTRING_ASC,
1444                                    (const uint8_t*)issuer.c_str(), issuer.size(), -1 /* loc */,
1445                                    0 /* set */) != 1 ||
1446         X509_set_issuer_name(x509.get(), x509Issuer.get()) != 1) {
1447         LOG(ERROR) << "Error setting issuer";
1448         return {};
1449     }
1450 
1451     auto x509Subject = X509_NAME_Ptr(X509_NAME_new());
1452     if (x509Subject.get() == nullptr ||
1453         X509_NAME_add_entry_by_txt(x509Subject.get(), "CN", MBSTRING_ASC,
1454                                    (const uint8_t*)subject.c_str(), subject.size(), -1 /* loc */,
1455                                    0 /* set */) != 1 ||
1456         X509_set_subject_name(x509.get(), x509Subject.get()) != 1) {
1457         LOG(ERROR) << "Error setting subject";
1458         return {};
1459     }
1460 
1461     auto asnNotBefore = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotBefore));
1462     if (asnNotBefore.get() == nullptr || X509_set_notBefore(x509.get(), asnNotBefore.get()) != 1) {
1463         LOG(ERROR) << "Error setting notBefore";
1464         return {};
1465     }
1466 
1467     auto asnNotAfter = ASN1_TIME_Ptr(ASN1_TIME_set(nullptr, validityNotAfter));
1468     if (asnNotAfter.get() == nullptr || X509_set_notAfter(x509.get(), asnNotAfter.get()) != 1) {
1469         LOG(ERROR) << "Error setting notAfter";
1470         return {};
1471     }
1472 
1473     if (X509_sign(x509.get(), privPkey.get(), EVP_sha256()) == 0) {
1474         LOG(ERROR) << "Error signing X509 certificate";
1475         return {};
1476     }
1477 
1478     unsigned char* buffer = nullptr;
1479     int length = i2d_X509(x509.get(), &buffer);
1480     if (length < 0) {
1481         LOG(ERROR) << "Error DER encoding X509 certificate";
1482         return {};
1483     }
1484 
1485     vector<uint8_t> certificate;
1486     certificate.resize(length);
1487     memcpy(certificate.data(), buffer, length);
1488     OPENSSL_free(buffer);
1489     return certificate;
1490 }
1491 
ecdh(const vector<uint8_t> & publicKey,const vector<uint8_t> & privateKey)1492 optional<vector<uint8_t>> ecdh(const vector<uint8_t>& publicKey,
1493                                const vector<uint8_t>& privateKey) {
1494     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
1495     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
1496     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
1497         1) {
1498         LOG(ERROR) << "Error decoding publicKey";
1499         return {};
1500     }
1501     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
1502     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1503     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
1504         LOG(ERROR) << "Memory allocation failed";
1505         return {};
1506     }
1507     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
1508         LOG(ERROR) << "Error setting group";
1509         return {};
1510     }
1511     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
1512         LOG(ERROR) << "Error setting point";
1513         return {};
1514     }
1515     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
1516         LOG(ERROR) << "Error setting key";
1517         return {};
1518     }
1519 
1520     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
1521     if (bn.get() == nullptr) {
1522         LOG(ERROR) << "Error creating BIGNUM for private key";
1523         return {};
1524     }
1525     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
1526     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
1527         LOG(ERROR) << "Error setting private key from BIGNUM";
1528         return {};
1529     }
1530     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
1531     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
1532         LOG(ERROR) << "Error setting private key";
1533         return {};
1534     }
1535 
1536     auto ctx = EVP_PKEY_CTX_Ptr(EVP_PKEY_CTX_new(privPkey.get(), NULL));
1537     if (ctx.get() == nullptr) {
1538         LOG(ERROR) << "Error creating context";
1539         return {};
1540     }
1541 
1542     if (EVP_PKEY_derive_init(ctx.get()) != 1) {
1543         LOG(ERROR) << "Error initializing context";
1544         return {};
1545     }
1546 
1547     if (EVP_PKEY_derive_set_peer(ctx.get(), pkey.get()) != 1) {
1548         LOG(ERROR) << "Error setting peer";
1549         return {};
1550     }
1551 
1552     /* Determine buffer length for shared secret */
1553     size_t secretLen = 0;
1554     if (EVP_PKEY_derive(ctx.get(), NULL, &secretLen) != 1) {
1555         LOG(ERROR) << "Error determing length of shared secret";
1556         return {};
1557     }
1558     vector<uint8_t> sharedSecret;
1559     sharedSecret.resize(secretLen);
1560 
1561     if (EVP_PKEY_derive(ctx.get(), sharedSecret.data(), &secretLen) != 1) {
1562         LOG(ERROR) << "Error deriving shared secret";
1563         return {};
1564     }
1565     return sharedSecret;
1566 }
1567 
hkdf(const vector<uint8_t> & sharedSecret,const vector<uint8_t> & salt,const vector<uint8_t> & info,size_t size)1568 optional<vector<uint8_t>> hkdf(const vector<uint8_t>& sharedSecret, const vector<uint8_t>& salt,
1569                                const vector<uint8_t>& info, size_t size) {
1570     vector<uint8_t> derivedKey;
1571     derivedKey.resize(size);
1572     if (HKDF(derivedKey.data(), derivedKey.size(), EVP_sha256(), sharedSecret.data(),
1573              sharedSecret.size(), salt.data(), salt.size(), info.data(), info.size()) != 1) {
1574         LOG(ERROR) << "Error deriving key";
1575         return {};
1576     }
1577     return derivedKey;
1578 }
1579 
// Strips all leading 0x00 bytes from |vec| in place. A vector consisting only of
// zero bytes becomes empty; an empty vector is left unchanged.
//
// The zero bytes are counted first and removed with a single range erase, so the cost
// is linear in vec.size() (erasing the front element repeatedly would be quadratic).
void removeLeadingZeroes(vector<uint8_t>& vec) {
    size_t numZeroes = 0;
    while (numZeroes < vec.size() && vec[numZeroes] == 0x00) {
        numZeroes++;
    }
    vec.erase(vec.begin(), vec.begin() + numZeroes);
}
1585 
ecPublicKeyGetXandY(const vector<uint8_t> & publicKey)1586 tuple<bool, vector<uint8_t>, vector<uint8_t>> ecPublicKeyGetXandY(
1587         const vector<uint8_t>& publicKey) {
1588     if (publicKey.size() != 65 || publicKey[0] != 0x04) {
1589         LOG(ERROR) << "publicKey is not in the expected format";
1590         return std::make_tuple(false, vector<uint8_t>(), vector<uint8_t>());
1591     }
1592     vector<uint8_t> x, y;
1593     x.resize(32);
1594     y.resize(32);
1595     memcpy(x.data(), publicKey.data() + 1, 32);
1596     memcpy(y.data(), publicKey.data() + 33, 32);
1597 
1598     removeLeadingZeroes(x);
1599     removeLeadingZeroes(y);
1600 
1601     return std::make_tuple(true, x, y);
1602 }
1603 
certificateChainGetTopMostKey(const vector<uint8_t> & certificateChain)1604 optional<vector<uint8_t>> certificateChainGetTopMostKey(const vector<uint8_t>& certificateChain) {
1605     vector<X509_Ptr> certs;
1606     if (!parseX509Certificates(certificateChain, certs)) {
1607         return {};
1608     }
1609     if (certs.size() < 1) {
1610         LOG(ERROR) << "No certificates in chain";
1611         return {};
1612     }
1613 
1614     int algoId = OBJ_obj2nid(certs[0]->cert_info->key->algor->algorithm);
1615     if (algoId != NID_X9_62_id_ecPublicKey) {
1616         LOG(ERROR) << "Expected NID_X9_62_id_ecPublicKey, got " << OBJ_nid2ln(algoId);
1617         return {};
1618     }
1619 
1620     auto pkey = EVP_PKEY_Ptr(X509_get_pubkey(certs[0].get()));
1621     if (pkey.get() == nullptr) {
1622         LOG(ERROR) << "No public key";
1623         return {};
1624     }
1625 
1626     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1627     if (ecKey.get() == nullptr) {
1628         LOG(ERROR) << "Failed getting EC key";
1629         return {};
1630     }
1631 
1632     auto ecGroup = EC_KEY_get0_group(ecKey.get());
1633     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
1634     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
1635                                   nullptr);
1636     if (size == 0) {
1637         LOG(ERROR) << "Error generating public key encoding";
1638         return {};
1639     }
1640     vector<uint8_t> publicKey;
1641     publicKey.resize(size);
1642     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1643                        publicKey.size(), nullptr);
1644     return publicKey;
1645 }
1646 
certificateFindPublicKey(const vector<uint8_t> & x509Certificate)1647 optional<pair<size_t, size_t>> certificateFindPublicKey(const vector<uint8_t>& x509Certificate) {
1648     vector<X509_Ptr> certs;
1649     if (!parseX509Certificates(x509Certificate, certs)) {
1650         return {};
1651     }
1652     if (certs.size() < 1) {
1653         LOG(ERROR) << "No certificates in chain";
1654         return {};
1655     }
1656 
1657     auto pkey = EVP_PKEY_Ptr(X509_get_pubkey(certs[0].get()));
1658     if (pkey.get() == nullptr) {
1659         LOG(ERROR) << "No public key";
1660         return {};
1661     }
1662 
1663     auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pkey.get()));
1664     if (ecKey.get() == nullptr) {
1665         LOG(ERROR) << "Failed getting EC key";
1666         return {};
1667     }
1668 
1669     auto ecGroup = EC_KEY_get0_group(ecKey.get());
1670     auto ecPoint = EC_KEY_get0_public_key(ecKey.get());
1671     int size = EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, nullptr, 0,
1672                                   nullptr);
1673     if (size == 0) {
1674         LOG(ERROR) << "Error generating public key encoding";
1675         return {};
1676     }
1677     vector<uint8_t> publicKey;
1678     publicKey.resize(size);
1679     EC_POINT_point2oct(ecGroup, ecPoint, POINT_CONVERSION_UNCOMPRESSED, publicKey.data(),
1680                        publicKey.size(), nullptr);
1681 
1682     size_t publicKeyOffset = 0;
1683     size_t publicKeySize = (size_t)size;
1684     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1685                             (const void*)publicKey.data(), publicKey.size());
1686 
1687     if (location == NULL) {
1688         LOG(ERROR) << "Error finding publicKey from x509Certificate";
1689         return {};
1690     }
1691     publicKeyOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1692 
1693     return std::make_pair(publicKeyOffset, publicKeySize);
1694 }
1695 
certificateTbsCertificate(const vector<uint8_t> & x509Certificate)1696 optional<pair<size_t, size_t>> certificateTbsCertificate(const vector<uint8_t>& x509Certificate) {
1697     vector<X509_Ptr> certs;
1698     if (!parseX509Certificates(x509Certificate, certs)) {
1699         return {};
1700     }
1701     if (certs.size() < 1) {
1702         LOG(ERROR) << "No certificates in chain";
1703         return {};
1704     }
1705 
1706     unsigned char* buf = NULL;
1707     int len = i2d_re_X509_tbs(certs[0].get(), &buf);
1708     if ((len < 0) || (buf == NULL)) {
1709         LOG(ERROR) << "fail to extract tbsCertificate in x509Certificate";
1710         return {};
1711     }
1712 
1713     vector<uint8_t> tbsCertificate(len);
1714     memcpy(tbsCertificate.data(), buf, len);
1715 
1716     size_t tbsCertificateOffset = 0;
1717     size_t tbsCertificateSize = (size_t)len;
1718     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1719                             (const void*)tbsCertificate.data(), tbsCertificate.size());
1720 
1721     if (location == NULL) {
1722         LOG(ERROR) << "Error finding tbsCertificate from x509Certificate";
1723         return {};
1724     }
1725     tbsCertificateOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1726 
1727     return std::make_pair(tbsCertificateOffset, tbsCertificateSize);
1728 }
1729 
certificateGetValidity(const vector<uint8_t> & x509Certificate)1730 optional<pair<time_t, time_t>> certificateGetValidity(const vector<uint8_t>& x509Certificate) {
1731     vector<X509_Ptr> certs;
1732     if (!parseX509Certificates(x509Certificate, certs)) {
1733         LOG(ERROR) << "Error parsing certificates";
1734         return {};
1735     }
1736     if (certs.size() < 1) {
1737         LOG(ERROR) << "No certificates in chain";
1738         return {};
1739     }
1740 
1741     time_t notBefore;
1742     time_t notAfter;
1743     if (!parseAsn1Time(X509_get0_notBefore(certs[0].get()), &notBefore)) {
1744         LOG(ERROR) << "Error parsing notBefore";
1745         return {};
1746     }
1747 
1748     if (!parseAsn1Time(X509_get0_notAfter(certs[0].get()), &notAfter)) {
1749         LOG(ERROR) << "Error parsing notAfter";
1750         return {};
1751     }
1752 
1753     return std::make_pair(notBefore, notAfter);
1754 }
1755 
certificateFindSignature(const vector<uint8_t> & x509Certificate)1756 optional<pair<size_t, size_t>> certificateFindSignature(const vector<uint8_t>& x509Certificate) {
1757     vector<X509_Ptr> certs;
1758     if (!parseX509Certificates(x509Certificate, certs)) {
1759         return {};
1760     }
1761     if (certs.size() < 1) {
1762         LOG(ERROR) << "No certificates in chain";
1763         return {};
1764     }
1765 
1766     ASN1_BIT_STRING* psig;
1767     X509_ALGOR* palg;
1768     X509_get0_signature((const ASN1_BIT_STRING**)&psig, (const X509_ALGOR**)&palg, certs[0].get());
1769 
1770     vector<char> signature(psig->length);
1771     memcpy(signature.data(), psig->data, psig->length);
1772 
1773     size_t signatureOffset = 0;
1774     size_t signatureSize = (size_t)psig->length;
1775     void* location = memmem((const void*)x509Certificate.data(), x509Certificate.size(),
1776                             (const void*)signature.data(), signature.size());
1777 
1778     if (location == NULL) {
1779         LOG(ERROR) << "Error finding signature from x509Certificate";
1780         return {};
1781     }
1782     signatureOffset = (size_t)((const char*)location - (const char*)x509Certificate.data());
1783 
1784     return std::make_pair(signatureOffset, signatureSize);
1785 }
1786 
1787 // ---------------------------------------------------------------------------
1788 // COSE Utility Functions
1789 // ---------------------------------------------------------------------------
1790 
coseBuildToBeSigned(const vector<uint8_t> & encodedProtectedHeaders,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)1791 vector<uint8_t> coseBuildToBeSigned(const vector<uint8_t>& encodedProtectedHeaders,
1792                                     const vector<uint8_t>& data,
1793                                     const vector<uint8_t>& detachedContent) {
1794     cppbor::Array sigStructure;
1795     sigStructure.add("Signature1");
1796     sigStructure.add(encodedProtectedHeaders);
1797 
1798     // We currently don't support Externally Supplied Data (RFC 8152 section 4.3)
1799     // so external_aad is the empty bstr
1800     vector<uint8_t> emptyExternalAad;
1801     sigStructure.add(emptyExternalAad);
1802 
1803     // Next field is the payload, independently of how it's transported (RFC
1804     // 8152 section 4.4). Since our API specifies only one of |data| and
1805     // |detachedContent| can be non-empty, it's simply just the non-empty one.
1806     if (data.size() > 0) {
1807         sigStructure.add(data);
1808     } else {
1809         sigStructure.add(detachedContent);
1810     }
1811     return sigStructure.encode();
1812 }
1813 
coseEncodeHeaders(const cppbor::Map & protectedHeaders)1814 vector<uint8_t> coseEncodeHeaders(const cppbor::Map& protectedHeaders) {
1815     if (protectedHeaders.size() == 0) {
1816         cppbor::Bstr emptyBstr(vector<uint8_t>({}));
1817         return emptyBstr.encode();
1818     }
1819     return protectedHeaders.encode();
1820 }
1821 
// COSE header labels, from https://tools.ietf.org/html/rfc8152
constexpr int COSE_LABEL_ALG = 1;
constexpr int COSE_LABEL_X5CHAIN = 33;  // temporary identifier

// Algorithm identifiers, from the "COSE Algorithms" registry
constexpr int COSE_ALG_ECDSA_256 = -7;
constexpr int COSE_ALG_HMAC_256_256 = 5;
1829 
ecdsaSignatureCoseToDer(const vector<uint8_t> & ecdsaCoseSignature,vector<uint8_t> & ecdsaDerSignature)1830 bool ecdsaSignatureCoseToDer(const vector<uint8_t>& ecdsaCoseSignature,
1831                              vector<uint8_t>& ecdsaDerSignature) {
1832     if (ecdsaCoseSignature.size() != 64) {
1833         LOG(ERROR) << "COSE signature length is " << ecdsaCoseSignature.size() << ", expected 64";
1834         return false;
1835     }
1836 
1837     auto rBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data(), 32, nullptr));
1838     if (rBn.get() == nullptr) {
1839         LOG(ERROR) << "Error creating BIGNUM for r";
1840         return false;
1841     }
1842 
1843     auto sBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data() + 32, 32, nullptr));
1844     if (sBn.get() == nullptr) {
1845         LOG(ERROR) << "Error creating BIGNUM for s";
1846         return false;
1847     }
1848 
1849     ECDSA_SIG sig;
1850     sig.r = rBn.get();
1851     sig.s = sBn.get();
1852 
1853     size_t len = i2d_ECDSA_SIG(&sig, nullptr);
1854     ecdsaDerSignature.resize(len);
1855     unsigned char* p = (unsigned char*)ecdsaDerSignature.data();
1856     i2d_ECDSA_SIG(&sig, &p);
1857 
1858     return true;
1859 }
1860 
ecdsaSignatureDerToCose(const vector<uint8_t> & ecdsaDerSignature,vector<uint8_t> & ecdsaCoseSignature)1861 bool ecdsaSignatureDerToCose(const vector<uint8_t>& ecdsaDerSignature,
1862                              vector<uint8_t>& ecdsaCoseSignature) {
1863     ECDSA_SIG* sig;
1864     const unsigned char* p = ecdsaDerSignature.data();
1865     sig = d2i_ECDSA_SIG(nullptr, &p, ecdsaDerSignature.size());
1866     if (sig == nullptr) {
1867         LOG(ERROR) << "Error decoding DER signature";
1868         return false;
1869     }
1870 
1871     ecdsaCoseSignature.clear();
1872     ecdsaCoseSignature.resize(64);
1873     if (BN_bn2binpad(sig->r, ecdsaCoseSignature.data(), 32) != 32) {
1874         LOG(ERROR) << "Error encoding r";
1875         return false;
1876     }
1877     if (BN_bn2binpad(sig->s, ecdsaCoseSignature.data() + 32, 32) != 32) {
1878         LOG(ERROR) << "Error encoding s";
1879         return false;
1880     }
1881     return true;
1882 }
1883 
coseSignEcDsaWithSignature(const vector<uint8_t> & signatureToBeSigned,const vector<uint8_t> & data,const vector<uint8_t> & certificateChain)1884 optional<vector<uint8_t>> coseSignEcDsaWithSignature(const vector<uint8_t>& signatureToBeSigned,
1885                                                      const vector<uint8_t>& data,
1886                                                      const vector<uint8_t>& certificateChain) {
1887     if (signatureToBeSigned.size() != 64) {
1888         LOG(ERROR) << "Invalid size for signatureToBeSigned, expected 64 got "
1889                    << signatureToBeSigned.size();
1890         return {};
1891     }
1892 
1893     cppbor::Map unprotectedHeaders;
1894     cppbor::Map protectedHeaders;
1895 
1896     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_ECDSA_256);
1897 
1898     if (certificateChain.size() != 0) {
1899         optional<vector<vector<uint8_t>>> certs = support::certificateChainSplit(certificateChain);
1900         if (!certs) {
1901             LOG(ERROR) << "Error splitting certificate chain";
1902             return {};
1903         }
1904         if (certs.value().size() == 1) {
1905             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, certs.value()[0]);
1906         } else {
1907             cppbor::Array certArray;
1908             for (const vector<uint8_t>& cert : certs.value()) {
1909                 certArray.add(cert);
1910             }
1911             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, std::move(certArray));
1912         }
1913     }
1914 
1915     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
1916 
1917     cppbor::Array coseSign1;
1918     coseSign1.add(encodedProtectedHeaders);
1919     coseSign1.add(std::move(unprotectedHeaders));
1920     if (data.size() == 0) {
1921         cppbor::Null nullValue;
1922         coseSign1.add(std::move(nullValue));
1923     } else {
1924         coseSign1.add(data);
1925     }
1926     coseSign1.add(signatureToBeSigned);
1927     vector<uint8_t> signatureCoseSign1;
1928     signatureCoseSign1 = coseSign1.encode();
1929 
1930     return signatureCoseSign1;
1931 }
1932 
coseSignEcDsa(const vector<uint8_t> & key,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent,const vector<uint8_t> & certificateChain)1933 optional<vector<uint8_t>> coseSignEcDsa(const vector<uint8_t>& key, const vector<uint8_t>& data,
1934                                         const vector<uint8_t>& detachedContent,
1935                                         const vector<uint8_t>& certificateChain) {
1936     cppbor::Map unprotectedHeaders;
1937     cppbor::Map protectedHeaders;
1938 
1939     if (data.size() > 0 && detachedContent.size() > 0) {
1940         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
1941         return {};
1942     }
1943 
1944     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_ECDSA_256);
1945 
1946     if (certificateChain.size() != 0) {
1947         optional<vector<vector<uint8_t>>> certs = support::certificateChainSplit(certificateChain);
1948         if (!certs) {
1949             LOG(ERROR) << "Error splitting certificate chain";
1950             return {};
1951         }
1952         if (certs.value().size() == 1) {
1953             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, certs.value()[0]);
1954         } else {
1955             cppbor::Array certArray;
1956             for (const vector<uint8_t>& cert : certs.value()) {
1957                 certArray.add(cert);
1958             }
1959             unprotectedHeaders.add(COSE_LABEL_X5CHAIN, std::move(certArray));
1960         }
1961     }
1962 
1963     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
1964     vector<uint8_t> toBeSigned =
1965             coseBuildToBeSigned(encodedProtectedHeaders, data, detachedContent);
1966 
1967     optional<vector<uint8_t>> derSignature = signEcDsa(key, toBeSigned);
1968     if (!derSignature) {
1969         LOG(ERROR) << "Error signing toBeSigned data";
1970         return {};
1971     }
1972     vector<uint8_t> coseSignature;
1973     if (!ecdsaSignatureDerToCose(derSignature.value(), coseSignature)) {
1974         LOG(ERROR) << "Error converting ECDSA signature from DER to COSE format";
1975         return {};
1976     }
1977 
1978     cppbor::Array coseSign1;
1979     coseSign1.add(encodedProtectedHeaders);
1980     coseSign1.add(std::move(unprotectedHeaders));
1981     if (data.size() == 0) {
1982         cppbor::Null nullValue;
1983         coseSign1.add(std::move(nullValue));
1984     } else {
1985         coseSign1.add(data);
1986     }
1987     coseSign1.add(coseSignature);
1988     vector<uint8_t> signatureCoseSign1;
1989     signatureCoseSign1 = coseSign1.encode();
1990     return signatureCoseSign1;
1991 }
1992 
coseCheckEcDsaSignature(const vector<uint8_t> & signatureCoseSign1,const vector<uint8_t> & detachedContent,const vector<uint8_t> & publicKey)1993 bool coseCheckEcDsaSignature(const vector<uint8_t>& signatureCoseSign1,
1994                              const vector<uint8_t>& detachedContent,
1995                              const vector<uint8_t>& publicKey) {
1996     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
1997     if (item == nullptr) {
1998         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
1999         return false;
2000     }
2001     const cppbor::Array* array = item->asArray();
2002     if (array == nullptr) {
2003         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
2004         return false;
2005     }
2006     if (array->size() != 4) {
2007         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2008         return false;
2009     }
2010 
2011     const cppbor::Bstr* encodedProtectedHeadersBstr = (*array)[0]->asBstr();
2012     ;
2013     if (encodedProtectedHeadersBstr == nullptr) {
2014         LOG(ERROR) << "Value for encodedProtectedHeaders is not a bstr";
2015         return false;
2016     }
2017     const vector<uint8_t> encodedProtectedHeaders = encodedProtectedHeadersBstr->value();
2018 
2019     const cppbor::Map* unprotectedHeaders = (*array)[1]->asMap();
2020     if (unprotectedHeaders == nullptr) {
2021         LOG(ERROR) << "Value for unprotectedHeaders is not a map";
2022         return false;
2023     }
2024 
2025     vector<uint8_t> data;
2026     const cppbor::Simple* payloadAsSimple = (*array)[2]->asSimple();
2027     if (payloadAsSimple != nullptr) {
2028         if (payloadAsSimple->asNull() == nullptr) {
2029             LOG(ERROR) << "Value for payload is not null or a bstr";
2030             return false;
2031         }
2032     } else {
2033         const cppbor::Bstr* payloadAsBstr = (*array)[2]->asBstr();
2034         if (payloadAsBstr == nullptr) {
2035             LOG(ERROR) << "Value for payload is not null or a bstr";
2036             return false;
2037         }
2038         data = payloadAsBstr->value();  // TODO: avoid copy
2039     }
2040 
2041     if (data.size() > 0 && detachedContent.size() > 0) {
2042         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
2043         return false;
2044     }
2045 
2046     const cppbor::Bstr* signatureBstr = (*array)[3]->asBstr();
2047     if (signatureBstr == nullptr) {
2048         LOG(ERROR) << "Value for signature is a bstr";
2049         return false;
2050     }
2051     const vector<uint8_t>& coseSignature = signatureBstr->value();
2052 
2053     vector<uint8_t> derSignature;
2054     if (!ecdsaSignatureCoseToDer(coseSignature, derSignature)) {
2055         LOG(ERROR) << "Error converting ECDSA signature from COSE to DER format";
2056         return false;
2057     }
2058 
2059     vector<uint8_t> toBeSigned =
2060             coseBuildToBeSigned(encodedProtectedHeaders, data, detachedContent);
2061     if (!checkEcDsaSignature(support::sha256(toBeSigned), derSignature, publicKey)) {
2062         LOG(ERROR) << "Signature check failed";
2063         return false;
2064     }
2065     return true;
2066 }
2067 
2068 // Extracts the signature (of the ToBeSigned CBOR) from a COSE_Sign1.
coseSignGetSignature(const vector<uint8_t> & signatureCoseSign1)2069 optional<vector<uint8_t>> coseSignGetSignature(const vector<uint8_t>& signatureCoseSign1) {
2070     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
2071     if (item == nullptr) {
2072         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
2073         return {};
2074     }
2075     const cppbor::Array* array = item->asArray();
2076     if (array == nullptr) {
2077         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
2078         return {};
2079     }
2080     if (array->size() != 4) {
2081         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2082         return {};
2083     }
2084 
2085     vector<uint8_t> signature;
2086     const cppbor::Bstr* signatureAsBstr = (*array)[3]->asBstr();
2087     if (signatureAsBstr == nullptr) {
2088         LOG(ERROR) << "Value for signature is not a bstr";
2089         return {};
2090     }
2091     // Copy payload into |data|
2092     signature = signatureAsBstr->value();
2093 
2094     return signature;
2095 }
2096 
coseSignGetPayload(const vector<uint8_t> & signatureCoseSign1)2097 optional<vector<uint8_t>> coseSignGetPayload(const vector<uint8_t>& signatureCoseSign1) {
2098     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
2099     if (item == nullptr) {
2100         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
2101         return {};
2102     }
2103     const cppbor::Array* array = item->asArray();
2104     if (array == nullptr) {
2105         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
2106         return {};
2107     }
2108     if (array->size() != 4) {
2109         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2110         return {};
2111     }
2112 
2113     vector<uint8_t> data;
2114     const cppbor::Simple* payloadAsSimple = (*array)[2]->asSimple();
2115     if (payloadAsSimple != nullptr) {
2116         if (payloadAsSimple->asNull() == nullptr) {
2117             LOG(ERROR) << "Value for payload is not null or a bstr";
2118             return {};
2119         }
2120         // payload is null, so |data| should be empty (as it is)
2121     } else {
2122         const cppbor::Bstr* payloadAsBstr = (*array)[2]->asBstr();
2123         if (payloadAsBstr == nullptr) {
2124             LOG(ERROR) << "Value for payload is not null or a bstr";
2125             return {};
2126         }
2127         // Copy payload into |data|
2128         data = payloadAsBstr->value();
2129     }
2130 
2131     return data;
2132 }
2133 
coseSignGetAlg(const vector<uint8_t> & signatureCoseSign1)2134 optional<int> coseSignGetAlg(const vector<uint8_t>& signatureCoseSign1) {
2135     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
2136     if (item == nullptr) {
2137         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
2138         return {};
2139     }
2140     const cppbor::Array* array = item->asArray();
2141     if (array == nullptr) {
2142         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
2143         return {};
2144     }
2145     if (array->size() != 4) {
2146         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2147         return {};
2148     }
2149 
2150     const cppbor::Bstr* protectedHeadersBytes = (*array)[0]->asBstr();
2151     if (protectedHeadersBytes == nullptr) {
2152         LOG(ERROR) << "Value for protectedHeaders is not a bstr";
2153         return {};
2154     }
2155     auto [item2, _2, message2] = cppbor::parse(protectedHeadersBytes->value());
2156     if (item2 == nullptr) {
2157         LOG(ERROR) << "Error parsing protectedHeaders: " << message2;
2158         return {};
2159     }
2160     const cppbor::Map* protectedHeaders = item2->asMap();
2161     if (protectedHeaders == nullptr) {
2162         LOG(ERROR) << "Decoded CBOR for protectedHeaders is not a map";
2163         return {};
2164     }
2165 
2166     for (size_t n = 0; n < protectedHeaders->size(); n++) {
2167         auto [keyItem, valueItem] = (*protectedHeaders)[n];
2168         const cppbor::Int* number = keyItem->asInt();
2169         if (number == nullptr) {
2170             LOG(ERROR) << "Key item in top-level map is not a number";
2171             return {};
2172         }
2173         int label = number->value();
2174         if (label == COSE_LABEL_ALG) {
2175             const cppbor::Int* number = valueItem->asInt();
2176             if (number != nullptr) {
2177                 return number->value();
2178             }
2179             LOG(ERROR) << "Value for COSE_LABEL_ALG label is not a number";
2180             return {};
2181         }
2182     }
2183     LOG(ERROR) << "Did not find COSE_LABEL_ALG label in protected headers";
2184     return {};
2185 }
2186 
coseSignGetX5Chain(const vector<uint8_t> & signatureCoseSign1)2187 optional<vector<uint8_t>> coseSignGetX5Chain(const vector<uint8_t>& signatureCoseSign1) {
2188     auto [item, _, message] = cppbor::parse(signatureCoseSign1);
2189     if (item == nullptr) {
2190         LOG(ERROR) << "Passed-in COSE_Sign1 is not valid CBOR: " << message;
2191         return {};
2192     }
2193     const cppbor::Array* array = item->asArray();
2194     if (array == nullptr) {
2195         LOG(ERROR) << "Value for COSE_Sign1 is not an array";
2196         return {};
2197     }
2198     if (array->size() != 4) {
2199         LOG(ERROR) << "Value for COSE_Sign1 is not an array of size 4";
2200         return {};
2201     }
2202 
2203     const cppbor::Map* unprotectedHeaders = (*array)[1]->asMap();
2204     if (unprotectedHeaders == nullptr) {
2205         LOG(ERROR) << "Value for unprotectedHeaders is not a map";
2206         return {};
2207     }
2208 
2209     for (size_t n = 0; n < unprotectedHeaders->size(); n++) {
2210         auto [keyItem, valueItem] = (*unprotectedHeaders)[n];
2211         const cppbor::Int* number = keyItem->asInt();
2212         if (number == nullptr) {
2213             LOG(ERROR) << "Key item in top-level map is not a number";
2214             return {};
2215         }
2216         int label = number->value();
2217         if (label == COSE_LABEL_X5CHAIN) {
2218             const cppbor::Bstr* bstr = valueItem->asBstr();
2219             if (bstr != nullptr) {
2220                 return bstr->value();
2221             }
2222             const cppbor::Array* array = valueItem->asArray();
2223             if (array != nullptr) {
2224                 vector<uint8_t> certs;
2225                 for (size_t m = 0; m < array->size(); m++) {
2226                     const cppbor::Bstr* bstr = ((*array)[m])->asBstr();
2227                     if (bstr == nullptr) {
2228                         LOG(ERROR) << "Item in x5chain array is not a bstr";
2229                         return {};
2230                     }
2231                     const vector<uint8_t>& certValue = bstr->value();
2232                     certs.insert(certs.end(), certValue.begin(), certValue.end());
2233                 }
2234                 return certs;
2235             }
2236             LOG(ERROR) << "Value for x5chain label is not a bstr or array";
2237             return {};
2238         }
2239     }
2240     LOG(ERROR) << "Did not find x5chain label in unprotected headers";
2241     return {};
2242 }
2243 
coseBuildToBeMACed(const vector<uint8_t> & encodedProtectedHeaders,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)2244 vector<uint8_t> coseBuildToBeMACed(const vector<uint8_t>& encodedProtectedHeaders,
2245                                    const vector<uint8_t>& data,
2246                                    const vector<uint8_t>& detachedContent) {
2247     cppbor::Array macStructure;
2248     macStructure.add("MAC0");
2249     macStructure.add(encodedProtectedHeaders);
2250 
2251     // We currently don't support Externally Supplied Data (RFC 8152 section 4.3)
2252     // so external_aad is the empty bstr
2253     vector<uint8_t> emptyExternalAad;
2254     macStructure.add(emptyExternalAad);
2255 
2256     // Next field is the payload, independently of how it's transported (RFC
2257     // 8152 section 4.4). Since our API specifies only one of |data| and
2258     // |detachedContent| can be non-empty, it's simply just the non-empty one.
2259     if (data.size() > 0) {
2260         macStructure.add(data);
2261     } else {
2262         macStructure.add(detachedContent);
2263     }
2264 
2265     return macStructure.encode();
2266 }
2267 
coseMac0(const vector<uint8_t> & key,const vector<uint8_t> & data,const vector<uint8_t> & detachedContent)2268 optional<vector<uint8_t>> coseMac0(const vector<uint8_t>& key, const vector<uint8_t>& data,
2269                                    const vector<uint8_t>& detachedContent) {
2270     cppbor::Map unprotectedHeaders;
2271     cppbor::Map protectedHeaders;
2272 
2273     if (data.size() > 0 && detachedContent.size() > 0) {
2274         LOG(ERROR) << "data and detachedContent cannot both be non-empty";
2275         return {};
2276     }
2277 
2278     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_HMAC_256_256);
2279 
2280     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
2281     vector<uint8_t> toBeMACed = coseBuildToBeMACed(encodedProtectedHeaders, data, detachedContent);
2282 
2283     optional<vector<uint8_t>> mac = hmacSha256(key, toBeMACed);
2284     if (!mac) {
2285         LOG(ERROR) << "Error MACing toBeMACed data";
2286         return {};
2287     }
2288 
2289     cppbor::Array array;
2290     array.add(encodedProtectedHeaders);
2291     array.add(std::move(unprotectedHeaders));
2292     if (data.size() == 0) {
2293         cppbor::Null nullValue;
2294         array.add(std::move(nullValue));
2295     } else {
2296         array.add(data);
2297     }
2298     array.add(mac.value());
2299     return array.encode();
2300 }
2301 
coseMacWithDigest(const vector<uint8_t> & digestToBeMaced,const vector<uint8_t> & data)2302 optional<vector<uint8_t>> coseMacWithDigest(const vector<uint8_t>& digestToBeMaced,
2303                                             const vector<uint8_t>& data) {
2304     cppbor::Map unprotectedHeaders;
2305     cppbor::Map protectedHeaders;
2306 
2307     protectedHeaders.add(COSE_LABEL_ALG, COSE_ALG_HMAC_256_256);
2308 
2309     vector<uint8_t> encodedProtectedHeaders = coseEncodeHeaders(protectedHeaders);
2310 
2311     cppbor::Array array;
2312     array.add(encodedProtectedHeaders);
2313     array.add(std::move(unprotectedHeaders));
2314     if (data.size() == 0) {
2315         cppbor::Null nullValue;
2316         array.add(std::move(nullValue));
2317     } else {
2318         array.add(data);
2319     }
2320     array.add(digestToBeMaced);
2321     return array.encode();
2322 }
2323 
2324 // ---------------------------------------------------------------------------
2325 // Utility functions specific to IdentityCredential.
2326 // ---------------------------------------------------------------------------
2327 
calcEMacKey(const vector<uint8_t> & privateKey,const vector<uint8_t> & publicKey,const vector<uint8_t> & sessionTranscriptBytes)2328 optional<vector<uint8_t>> calcEMacKey(const vector<uint8_t>& privateKey,
2329                                       const vector<uint8_t>& publicKey,
2330                                       const vector<uint8_t>& sessionTranscriptBytes) {
2331     optional<vector<uint8_t>> sharedSecret = support::ecdh(publicKey, privateKey);
2332     if (!sharedSecret) {
2333         LOG(ERROR) << "Error performing ECDH";
2334         return {};
2335     }
2336     vector<uint8_t> salt = support::sha256(sessionTranscriptBytes);
2337     vector<uint8_t> info = {'E', 'M', 'a', 'c', 'K', 'e', 'y'};
2338     optional<vector<uint8_t>> derivedKey = support::hkdf(sharedSecret.value(), salt, info, 32);
2339     if (!derivedKey) {
2340         LOG(ERROR) << "Error performing HKDF";
2341         return {};
2342     }
2343     return derivedKey.value();
2344 }
2345 
calcMac(const vector<uint8_t> & sessionTranscriptEncoded,const string & docType,const vector<uint8_t> & deviceNameSpacesEncoded,const vector<uint8_t> & eMacKey)2346 optional<vector<uint8_t>> calcMac(const vector<uint8_t>& sessionTranscriptEncoded,
2347                                   const string& docType,
2348                                   const vector<uint8_t>& deviceNameSpacesEncoded,
2349                                   const vector<uint8_t>& eMacKey) {
2350     auto [sessionTranscriptItem, _, errMsg] = cppbor::parse(sessionTranscriptEncoded);
2351     if (sessionTranscriptItem == nullptr) {
2352         LOG(ERROR) << "Error parsing sessionTranscriptEncoded: " << errMsg;
2353         return {};
2354     }
2355     // The data that is MACed is ["DeviceAuthentication", sessionTranscript, docType,
2356     // deviceNameSpacesBytes] so build up that structure
2357     cppbor::Array deviceAuthentication =
2358             cppbor::Array()
2359                     .add("DeviceAuthentication")
2360                     .add(std::move(sessionTranscriptItem))
2361                     .add(docType)
2362                     .add(cppbor::Semantic(kSemanticTagEncodedCbor, deviceNameSpacesEncoded));
2363     vector<uint8_t> deviceAuthenticationBytes =
2364             cppbor::Semantic(kSemanticTagEncodedCbor, deviceAuthentication.encode()).encode();
2365     optional<vector<uint8_t>> calculatedMac =
2366             support::coseMac0(eMacKey, {},                 // payload
2367                               deviceAuthenticationBytes);  // detached content
2368     return calculatedMac;
2369 }
2370 
// Splits |content| into consecutive chunks of at most |maxChunkSize| bytes,
// preserving byte order. Only the last chunk may be shorter than
// |maxChunkSize|.
//
// Edge cases:
// - If |content| fits in one chunk, including when it is empty, the result
//   is a single chunk holding all of |content|.
// - A |maxChunkSize| of zero cannot be honored for non-empty content. The
//   old code divided by zero (undefined behavior) in that case. Now the whole
//   content is returned as a single chunk, so no data is dropped.
vector<vector<uint8_t>> chunkVector(const vector<uint8_t>& content, size_t maxChunkSize) {
    const size_t contentSize = content.size();
    if (contentSize <= maxChunkSize || maxChunkSize == 0) {
        return {content};
    }

    vector<vector<uint8_t>> chunks;
    // Ceiling division written so it cannot overflow for large maxChunkSize.
    chunks.reserve(contentSize / maxChunkSize + (contentSize % maxChunkSize != 0 ? 1 : 0));

    // Here 0 < maxChunkSize < contentSize, so |offset| never wraps around.
    for (size_t offset = 0; offset < contentSize; offset += maxChunkSize) {
        const size_t remaining = contentSize - offset;
        const size_t chunkSize = remaining < maxChunkSize ? remaining : maxChunkSize;
        chunks.emplace_back(content.begin() + offset, content.begin() + offset + chunkSize);
    }
    return chunks;
}
2396 
2397 
// Fixed 16-byte all-zero key exposed through getTestHardwareBoundKey().
// Judging by the name it is a test stand-in rather than a real hardware-bound
// key (confirm with callers). NOTE(review): the variable is mutable, so code
// with direct access to it could change the value the getter returns.
vector<uint8_t> testHardwareBoundKey(16, 0);

// Returns a reference to testHardwareBoundKey (16 zero bytes unless modified).
const vector<uint8_t>& getTestHardwareBoundKey() {
    return testHardwareBoundKey;
}
2403 
2404 }  // namespace support
2405 }  // namespace identity
2406 }  // namespace hardware
2407 }  // namespace android
2408