• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "anonymous_tokens/cpp/crypto/crypto_utils.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "anonymous_tokens/cpp/crypto/constants.h"
#include "anonymous_tokens/cpp/shared/status_utils.h"
#include <openssl/bytestring.h>
#include <openssl/err.h>
#include <openssl/hkdf.h>
#include <openssl/mem.h>
#include <openssl/rand.h>
#include <openssl/rsa.h>
39 
40 
41 namespace anonymous_tokens {
42 
43 namespace internal {
44 
// Fixed-point approximation of sqrt(2), copied from BoringSSL's
// kBoringSSLRSASqrtTwo in crypto/fipsmodule/rsa/rsa_impl.c (there documented
// as floor(2^2047 * sqrt(2)), a 2048-bit value).
//
// Layout: 64-bit limbs stored least-significant limb first. Each limb is split
// into two consecutive uint32_t entries, high half first: entries [2k] and
// [2k + 1] are the high and low 32 bits of limb k. GetRsaSqrtTwo() depends on
// this layout when reassembling the value into a BIGNUM.
constexpr uint32_t kBoringSSLRSASqrtTwo[] = {
    0x4d7c60a5, 0xe633e3e1, 0x5fcf8f7b, 0xca3ea33b, 0xc246785e, 0x92957023,
    0xf9acce41, 0x797f2805, 0xfdfe170f, 0xd3b1f780, 0xd24f4a76, 0x3facb882,
    0x18838a2e, 0xaff5f3b2, 0xc1fcbdde, 0xa2f7dc33, 0xdea06241, 0xf7aa81c2,
    0xf6a1be3f, 0xca221307, 0x332a5e9f, 0x7bda1ebf, 0x0104dc01, 0xfe32352f,
    0xb8cf341b, 0x6f8236c7, 0x4264dabc, 0xd528b651, 0xf4d3a02c, 0xebc93e0c,
    0x81394ab6, 0xd8fd0efd, 0xeaa4a089, 0x9040ca4a, 0xf52f120f, 0x836e582e,
    0xcb2a6343, 0x31f3c84d, 0xc6d5a8a3, 0x8bb7e9dc, 0x460abc72, 0x2f7c4e33,
    0xcab1bc91, 0x1688458a, 0x53059c60, 0x11bc337b, 0xd2202e87, 0x42af1f4e,
    0x78048736, 0x3dfa2768, 0x0f74a85e, 0x439c7b4a, 0xa8b1fe6f, 0xdc83db39,
    0x4afc8304, 0x3ab8a2c3, 0xed17ac85, 0x83339915, 0x1d6f60ba, 0x893ba84c,
    0x597d89b3, 0x754abe9f, 0xb504f333, 0xf9de6484,
};
// The table must describe exactly 2048 bits (32 limbs of 64 bits); the
// reassembly code pairs entries two at a time.
static_assert(sizeof(kBoringSSLRSASqrtTwo) / sizeof(kBoringSSLRSASqrtTwo[0]) ==
                  64,
              "kBoringSSLRSASqrtTwo must hold 64 32-bit words (2048 bits)");
60 
PublicMetadataHashWithHKDF(absl::string_view public_metadata,absl::string_view rsa_modulus_str,size_t out_len_bytes)61 absl::StatusOr<bssl::UniquePtr<BIGNUM>> PublicMetadataHashWithHKDF(
62     absl::string_view public_metadata, absl::string_view rsa_modulus_str,
63     size_t out_len_bytes) {
64   const EVP_MD* evp_md_sha_384 = EVP_sha384();
65   // Prepend "key" to input.
66   std::string modified_input = absl::StrCat("key", public_metadata);
67   std::vector<uint8_t> input_buffer(modified_input.begin(),
68                                     modified_input.end());
69   // Append 0x00 to input.
70   input_buffer.push_back(0x00);
71   std::string out_e;
72   // We set the out_e size beyond out_len_bytes so that out_e bytes are
73   // indifferentiable from truly random bytes even after truncations.
74   //
75   // Expanding to 16 more bytes is sufficient.
76   // https://cfrg.github.io/draft-irtf-cfrg-hash-to-curve/draft-irtf-cfrg-hash-to-curve.html#name-hashing-to-a-finite-field
77   const size_t hkdf_output_size = out_len_bytes + 16;
78   out_e.resize(hkdf_output_size);
79   // The modulus is used as salt to ensure different outputs for same metadata
80   // and different modulus.
81   if (HKDF(reinterpret_cast<uint8_t*>(out_e.data()), hkdf_output_size,
82            evp_md_sha_384, input_buffer.data(), input_buffer.size(),
83            reinterpret_cast<const uint8_t*>(rsa_modulus_str.data()),
84            rsa_modulus_str.size(),
85            reinterpret_cast<const uint8_t*>(kHkdfPublicMetadataInfo.data()),
86            kHkdfPublicMetadataInfoSizeInBytes) != kBsslSuccess) {
87     return absl::InternalError("HKDF failed in public_metadata_crypto_utils");
88   }
89   // Truncate out_e to out_len_bytes
90   out_e.resize(out_len_bytes);
91   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> out,
92                                StringToBignum(out_e));
93   return out;
94 }
95 
96 }  // namespace internal
97 
98 namespace {
99 
100 // Marshals an RSA public key in the DER format.
MarshalRsaPublicKey(const RSA * rsa)101 absl::StatusOr<std::string> MarshalRsaPublicKey(const RSA* rsa) {
102   uint8_t* rsa_public_key_bytes;
103   size_t rsa_public_key_bytes_len = 0;
104   if (!RSA_public_key_to_bytes(&rsa_public_key_bytes, &rsa_public_key_bytes_len,
105                                rsa)) {
106     return absl::InvalidArgumentError(absl::StrCat(
107         "Failed to marshall rsa public key to a DER encoded RSAPublicKey "
108         "structure (RFC 8017): ",
109         GetSslErrors()));
110   }
111   std::string rsa_public_key_str(reinterpret_cast<char*>(rsa_public_key_bytes),
112                                  rsa_public_key_bytes_len);
113   OPENSSL_free(rsa_public_key_bytes);
114   return rsa_public_key_str;
115 }
116 
117 }  // namespace
118 
GetAndStartBigNumCtx()119 absl::StatusOr<BnCtxPtr> GetAndStartBigNumCtx() {
120   // Create context to be used in intermediate computation.
121   BnCtxPtr bn_ctx = BnCtxPtr(BN_CTX_new());
122   if (!bn_ctx.get()) {
123     return absl::InternalError("Error generating bignum context.");
124   }
125   BN_CTX_start(bn_ctx.get());
126 
127   return bn_ctx;
128 }
129 
NewBigNum()130 absl::StatusOr<bssl::UniquePtr<BIGNUM>> NewBigNum() {
131   bssl::UniquePtr<BIGNUM> bn(BN_new());
132   if (!bn.get()) {
133     return absl::InternalError("Error generating bignum.");
134   }
135   return bn;
136 }
137 
BignumToString(const BIGNUM & big_num,const size_t output_len)138 absl::StatusOr<std::string> BignumToString(const BIGNUM& big_num,
139                                            const size_t output_len) {
140   std::vector<uint8_t> serialization(output_len);
141   if (BN_bn2bin_padded(serialization.data(), serialization.size(), &big_num) !=
142       kBsslSuccess) {
143     return absl::InternalError(
144         absl::StrCat("Function BN_bn2bin_padded failed: ", GetSslErrors()));
145   }
146   return std::string(std::make_move_iterator(serialization.begin()),
147                      std::make_move_iterator(serialization.end()));
148 }
149 
// Parses `input_str` as an unsigned big-endian integer into a newly allocated
// BIGNUM. Returns InternalError (with the BoringSSL error queue appended) if
// BN_bin2bn fails; allocation errors from NewBigNum propagate.
absl::StatusOr<bssl::UniquePtr<BIGNUM>> StringToBignum(
    const absl::string_view input_str) {
  // NewBigNum already guarantees `output` is non-null, so the only failure to
  // check below is BN_bin2bn's own nullptr return.
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> output, NewBigNum());
  if (BN_bin2bn(reinterpret_cast<const uint8_t*>(input_str.data()),
                input_str.size(), output.get()) == nullptr) {
    return absl::InternalError(
        absl::StrCat("Function BN_bin2bn failed: ", GetSslErrors()));
  }
  return output;
}
163 
GetSslErrors()164 std::string GetSslErrors() {
165   std::string ret;
166   ERR_print_errors_cb(
167       [](const char* str, size_t len, void* ctx) -> int {
168         static_cast<std::string*>(ctx)->append(str, len);
169         return 1;
170       },
171       &ret);
172   return ret;
173 }
174 
MaskMessageConcat(absl::string_view mask,absl::string_view message)175 std::string MaskMessageConcat(absl::string_view mask,
176                               absl::string_view message) {
177   return absl::StrCat(mask, message);
178 }
179 
EncodeMessagePublicMetadata(absl::string_view message,absl::string_view public_metadata)180 std::string EncodeMessagePublicMetadata(absl::string_view message,
181                                         absl::string_view public_metadata) {
182   // Prepend encoding of "msg" followed by 4 bytes representing public metadata
183   // length.
184   std::string tag = "msg";
185   std::vector<uint8_t> buffer(tag.begin(), tag.end());
186   buffer.push_back((public_metadata.size() >> 24) & 0xFF);
187   buffer.push_back((public_metadata.size() >> 16) & 0xFF);
188   buffer.push_back((public_metadata.size() >> 8) & 0xFF);
189   buffer.push_back((public_metadata.size() >> 0) & 0xFF);
190 
191   // Finally append public metadata and then the message to the output.
192   std::string encoding(buffer.begin(), buffer.end());
193   return absl::StrCat(encoding, public_metadata, message);
194 }
195 
GetRsaSqrtTwo(int x)196 absl::StatusOr<bssl::UniquePtr<BIGNUM>> GetRsaSqrtTwo(int x) {
197   // Compute hard-coded sqrt(2).
198   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> sqrt2, NewBigNum());
199   const int sqrt2_size = sizeof(internal::kBoringSSLRSASqrtTwo) /
200                          sizeof(*internal::kBoringSSLRSASqrtTwo);
201   for (int i = sqrt2_size - 2; i >= 0; i = i - 2) {
202     // Add the uint32_t values as words directly and shift.
203     // 'i' is the "hi" value of a uint64_t, and 'i+1' is the "lo" value.
204     if (BN_add_word(sqrt2.get(), internal::kBoringSSLRSASqrtTwo[i]) != 1) {
205       return absl::InternalError(absl::StrCat(
206           "Cannot add word to compute RSA sqrt(2): ", GetSslErrors()));
207     }
208     if (BN_lshift(sqrt2.get(), sqrt2.get(), 32) != 1) {
209       return absl::InternalError(absl::StrCat(
210           "Cannot shift to compute RSA sqrt(2): ", GetSslErrors()));
211     }
212     if (BN_add_word(sqrt2.get(), internal::kBoringSSLRSASqrtTwo[i + 1]) != 1) {
213       return absl::InternalError(absl::StrCat(
214           "Cannot add word to compute RSA sqrt(2): ", GetSslErrors()));
215     }
216     if (i > 0) {
217       if (BN_lshift(sqrt2.get(), sqrt2.get(), 32) != 1) {
218         return absl::InternalError(absl::StrCat(
219             "Cannot shift to compute RSA sqrt(2): ", GetSslErrors()));
220       }
221     }
222   }
223 
224   // Check that hard-coded result is correct length.
225   int sqrt2_bits = 32 * sqrt2_size;
226   if (BN_num_bits(sqrt2.get()) != sqrt2_bits) {
227     return absl::InternalError("RSA sqrt(2) is not correct length.");
228   }
229 
230   // Either shift left or right depending on value x.
231   if (sqrt2_bits > x) {
232     if (BN_rshift(sqrt2.get(), sqrt2.get(), sqrt2_bits - x) != 1) {
233       return absl::InternalError(
234           absl::StrCat("Cannot rshift to compute 2^(x-1/2): ", GetSslErrors()));
235     }
236   } else {
237     // Round up and be pessimistic about minimium factors.
238     if (BN_add_word(sqrt2.get(), 1) != 1 ||
239         BN_lshift(sqrt2.get(), sqrt2.get(), x - sqrt2_bits) != 1) {
240       return absl::InternalError(absl::StrCat(
241           "Cannot add/lshift to compute 2^(x-1/2): ", GetSslErrors()));
242     }
243   }
244 
245   // Check that 2^(x - 1/2) is correct length.
246   if (BN_num_bits(sqrt2.get()) != x) {
247     return absl::InternalError(
248         "2^(x-1/2) is not correct length after shifting.");
249   }
250 
251   return std::move(sqrt2);
252 }
253 
ComputePowerOfTwo(int x)254 absl::StatusOr<bssl::UniquePtr<BIGNUM>> ComputePowerOfTwo(int x) {
255   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> ret, NewBigNum());
256   if (BN_set_bit(ret.get(), x) != 1) {
257     return absl::InternalError(
258         absl::StrCat("Unable to set bit to compute 2^x: ", GetSslErrors()));
259   }
260   if (!BN_is_pow2(ret.get()) || !BN_is_bit_set(ret.get(), x)) {
261     return absl::InternalError(absl::StrCat("Unable to compute 2^", x, "."));
262   }
263   return ret;
264 }
265 
ComputeHash(absl::string_view input,const EVP_MD & hasher)266 absl::StatusOr<std::string> ComputeHash(absl::string_view input,
267                                         const EVP_MD& hasher) {
268   std::string digest;
269   digest.resize(EVP_MAX_MD_SIZE);
270 
271   uint32_t digest_length = 0;
272   if (EVP_Digest(input.data(), input.length(),
273                  reinterpret_cast<uint8_t*>(&digest[0]), &digest_length,
274                  &hasher, /*impl=*/nullptr) != 1) {
275     return absl::InternalError(absl::StrCat(
276         "Openssl internal error computing hash: ", GetSslErrors()));
277   }
278   digest.resize(digest_length);
279   return digest;
280 }
281 
// Builds an RSA private key from big-endian byte encodings of its components.
//
// Arguments map onto RSA_new_private_key as: rsa_modulus -> n,
// public_exponent -> e, private_exponent -> d, p -> p, q -> q, dp -> dmp1,
// dq -> dmq1, crt -> iqmp. Parsing errors from StringToBignum propagate, and a
// null key from RSA_new_private_key is reported as InternalError.
absl::StatusOr<bssl::UniquePtr<RSA>> CreatePrivateKeyRSA(
    const absl::string_view rsa_modulus,
    const absl::string_view public_exponent,
    const absl::string_view private_exponent, const absl::string_view p,
    const absl::string_view q, const absl::string_view dp,
    const absl::string_view dq, const absl::string_view crt) {
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_n,
                               StringToBignum(rsa_modulus));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_e,
                               StringToBignum(public_exponent));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_d,
                               StringToBignum(private_exponent));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_p,
                               StringToBignum(p));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_q,
                               StringToBignum(q));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_dmp1,
                               StringToBignum(dp));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_dmq1,
                               StringToBignum(dq));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> bn_iqmp,
                               StringToBignum(crt));

  RSA* const raw_key = RSA_new_private_key(
      bn_n.get(), bn_e.get(), bn_d.get(), bn_p.get(), bn_q.get(),
      bn_dmp1.get(), bn_dmq1.get(), bn_iqmp.get());
  if (raw_key == nullptr) {
    return absl::InternalError(
        absl::StrCat("RSA_new_private_key failed: ", GetSslErrors()));
  }
  return bssl::UniquePtr<RSA>(raw_key);
}
312 
// Builds an RSA public key from the big-endian byte encodings of the modulus
// and public exponent. Parsing errors propagate; a null key from
// RSA_new_public_key is reported as InternalError.
absl::StatusOr<bssl::UniquePtr<RSA>> CreatePublicKeyRSA(
    const absl::string_view rsa_modulus,
    const absl::string_view public_exponent) {
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> modulus_bn,
                               StringToBignum(rsa_modulus));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> exponent_bn,
                               StringToBignum(public_exponent));
  RSA* const raw_key = RSA_new_public_key(modulus_bn.get(), exponent_bn.get());
  if (raw_key == nullptr) {
    return absl::InternalError(
        absl::StrCat("RSA_new_public_key failed: ", GetSslErrors()));
  }
  return bssl::UniquePtr<RSA>(raw_key);
}
329 
CreatePublicKeyRSAWithPublicMetadata(const BIGNUM & rsa_modulus,const BIGNUM & public_exponent,absl::string_view public_metadata,const bool use_rsa_public_exponent)330 absl::StatusOr<bssl::UniquePtr<RSA>> CreatePublicKeyRSAWithPublicMetadata(
331     const BIGNUM& rsa_modulus, const BIGNUM& public_exponent,
332     absl::string_view public_metadata, const bool use_rsa_public_exponent) {
333   bssl::UniquePtr<BIGNUM> derived_rsa_e;
334   if (use_rsa_public_exponent) {
335     ANON_TOKENS_ASSIGN_OR_RETURN(
336         derived_rsa_e, ComputeExponentWithPublicMetadataAndPublicExponent(
337                            rsa_modulus, public_exponent, public_metadata));
338   } else {
339     ANON_TOKENS_ASSIGN_OR_RETURN(
340         derived_rsa_e,
341         ComputeExponentWithPublicMetadata(rsa_modulus, public_metadata));
342   }
343   bssl::UniquePtr<RSA> rsa_public_key = bssl::UniquePtr<RSA>(
344       RSA_new_public_key_large_e(&rsa_modulus, derived_rsa_e.get()));
345   if (!rsa_public_key.get()) {
346     return absl::InternalError(
347         absl::StrCat("RSA_new_public_key_large_e failed: ", GetSslErrors()));
348   }
349   return rsa_public_key;
350 }
351 
// Overload of CreatePublicKeyRSAWithPublicMetadata that takes the modulus and
// public exponent as big-endian byte strings. It parses them and then defers
// to the BIGNUM overload.
absl::StatusOr<bssl::UniquePtr<RSA>> CreatePublicKeyRSAWithPublicMetadata(
    const absl::string_view rsa_modulus,
    const absl::string_view public_exponent,
    const absl::string_view public_metadata,
    const bool use_rsa_public_exponent) {
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> modulus_bn,
                               StringToBignum(rsa_modulus));
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> exponent_bn,
                               StringToBignum(public_exponent));
  const BIGNUM& n = *modulus_bn;
  const BIGNUM& e = *exponent_bn;
  return CreatePublicKeyRSAWithPublicMetadata(n, e, public_metadata,
                                              use_rsa_public_exponent);
}
364 
ComputeCarmichaelLcm(const BIGNUM & phi_p,const BIGNUM & phi_q,BN_CTX & bn_ctx)365 absl::StatusOr<bssl::UniquePtr<BIGNUM>> ComputeCarmichaelLcm(
366     const BIGNUM& phi_p, const BIGNUM& phi_q, BN_CTX& bn_ctx) {
367   // To compute lcm(phi(p), phi(q)), we first compute phi(n) =
368   // (p-1)(q-1). As n is assumed to be a safe RSA modulus (signing_key is
369   // assumed to be part of a strong rsa key pair), phi(n) = (p-1)(q-1) =
370   // (2 phi(p))(2 phi(q)) = 4 * phi(p) * phi(q) where phi(p) and phi(q) are also
371   // primes. So we get the lcm by outputting phi(n) >> 1 = 2 * phi(p) * phi(q).
372   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> phi_n, NewBigNum());
373   if (BN_mul(phi_n.get(), &phi_p, &phi_q, &bn_ctx) != 1) {
374     return absl::InternalError(
375         absl::StrCat("Unable to compute phi(n): ", GetSslErrors()));
376   }
377   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> lcm, NewBigNum());
378   if (BN_rshift1(lcm.get(), phi_n.get()) != 1) {
379     return absl::InternalError(absl::StrCat(
380         "Could not compute LCM(phi(p), phi(q)): ", GetSslErrors()));
381   }
382   return lcm;
383 }
384 
ComputeExponentWithPublicMetadata(const BIGNUM & n,absl::string_view public_metadata)385 absl::StatusOr<bssl::UniquePtr<BIGNUM>> ComputeExponentWithPublicMetadata(
386     const BIGNUM& n, absl::string_view public_metadata) {
387   // Check modulus length.
388   if (BN_num_bits(&n) % 2 == 1) {
389     return absl::InvalidArgumentError(
390         "Strong RSA modulus should be even length.");
391   }
392   int modulus_bytes = BN_num_bytes(&n);
393   // The integer modulus_bytes is expected to be a power of 2.
394   int prime_bytes = modulus_bytes / 2;
395 
396   ANON_TOKENS_ASSIGN_OR_RETURN(std::string rsa_modulus_str,
397                                BignumToString(n, modulus_bytes));
398 
399   // Get HKDF output of length prime_bytes.
400   ANON_TOKENS_ASSIGN_OR_RETURN(
401       bssl::UniquePtr<BIGNUM> exponent,
402       internal::PublicMetadataHashWithHKDF(public_metadata, rsa_modulus_str,
403                                            prime_bytes));
404 
405   // We need to generate random odd exponents < 2^(primes_bits - 2) where
406   // prime_bits = prime_bytes * 8. This will guarantee that the resulting
407   // exponent is coprime to phi(N) = 4p'q' as 2^(prime_bits - 2) < p', q' <
408   // 2^(prime_bits - 1).
409   //
410   // To do this, we can truncate the HKDF output (exponent) which is prime_bits
411   // long, to prime_bits - 2, by clearing its top two bits. We then set the
412   // least significant bit to 1. This way the final exponent will be less than
413   // 2^(primes_bits - 2) and will always be odd.
414   if (BN_clear_bit(exponent.get(), (prime_bytes * 8) - 1) != kBsslSuccess ||
415       BN_clear_bit(exponent.get(), (prime_bytes * 8) - 2) != kBsslSuccess ||
416       BN_set_bit(exponent.get(), 0) != kBsslSuccess) {
417     return absl::InvalidArgumentError(absl::StrCat(
418         "Could not clear the two most significant bits and set the least "
419         "significant bit to zero: ",
420         GetSslErrors()));
421   }
422   // Check that exponent is small enough to ensure it is coprime to phi(n).
423   if (BN_num_bits(exponent.get()) >= (8 * prime_bytes - 1)) {
424     return absl::InternalError("Generated exponent is too large.");
425   }
426 
427   return exponent;
428 }
429 
430 absl::StatusOr<bssl::UniquePtr<BIGNUM>>
ComputeExponentWithPublicMetadataAndPublicExponent(const BIGNUM & n,const BIGNUM & e,absl::string_view public_metadata)431 ComputeExponentWithPublicMetadataAndPublicExponent(
432     const BIGNUM& n, const BIGNUM& e, absl::string_view public_metadata) {
433   ANON_TOKENS_ASSIGN_OR_RETURN(
434       bssl::UniquePtr<BIGNUM> md_exp,
435       ComputeExponentWithPublicMetadata(n, public_metadata));
436   ANON_TOKENS_ASSIGN_OR_RETURN(BnCtxPtr bn_ctx, GetAndStartBigNumCtx());
437   // new_e=e*md_exp
438   ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> new_e, NewBigNum());
439   if (BN_mul(new_e.get(), md_exp.get(), &e, bn_ctx.get()) != kBsslSuccess) {
440     return absl::InternalError(
441         absl::StrCat("Unable to multiply e with md_exp: ", GetSslErrors()));
442   }
443   return new_e;
444 }
445 
RsaBlindSignatureVerify(const int salt_length,const EVP_MD * sig_hash,const EVP_MD * mgf1_hash,const absl::string_view signature,const absl::string_view message,RSA * rsa_public_key)446 absl::Status RsaBlindSignatureVerify(const int salt_length,
447                                      const EVP_MD* sig_hash,
448                                      const EVP_MD* mgf1_hash,
449                                      const absl::string_view signature,
450                                      const absl::string_view message,
451                                      RSA* rsa_public_key) {
452   ANON_TOKENS_ASSIGN_OR_RETURN(std::string message_digest,
453                                ComputeHash(message, *sig_hash));
454   const int hash_size = EVP_MD_size(sig_hash);
455   // Make sure the size of the digest is correct.
456   if (message_digest.size() != hash_size) {
457     return absl::InvalidArgumentError(
458         absl::StrCat("Size of the digest doesn't match the one "
459                      "of the hashing algorithm; expected ",
460                      hash_size, " got ", message_digest.size()));
461   }
462   // Make sure the size of the signature is correct.
463   const int rsa_modulus_size = BN_num_bytes(RSA_get0_n(rsa_public_key));
464   if (signature.size() != rsa_modulus_size) {
465     return absl::InvalidArgumentError(
466         "Signature size not equal to modulus size.");
467   }
468 
469   std::string recovered_message_digest(rsa_modulus_size, 0);
470   int recovered_message_digest_size = RSA_public_decrypt(
471       /*flen=*/signature.size(),
472       /*from=*/reinterpret_cast<const uint8_t*>(signature.data()),
473       /*to=*/
474       reinterpret_cast<uint8_t*>(recovered_message_digest.data()),
475       /*rsa=*/rsa_public_key,
476       /*padding=*/RSA_NO_PADDING);
477   if (recovered_message_digest_size != rsa_modulus_size) {
478     return absl::InvalidArgumentError(
479         absl::StrCat("Invalid signature size (likely an incorrect key is "
480                      "used); expected ",
481                      rsa_modulus_size, " got ", recovered_message_digest_size,
482                      ": ", GetSslErrors()));
483   }
484   if (RSA_verify_PKCS1_PSS_mgf1(
485           rsa_public_key, reinterpret_cast<const uint8_t*>(&message_digest[0]),
486           sig_hash, mgf1_hash,
487           reinterpret_cast<const uint8_t*>(recovered_message_digest.data()),
488           salt_length) != kBsslSuccess) {
489     return absl::InvalidArgumentError(
490         absl::StrCat("PSS padding verification failed: ", GetSslErrors()));
491   }
492   return absl::OkStatus();
493 }
494 
RsaSsaPssPublicKeyToDerEncoding(const RSA * rsa)495 absl::StatusOr<std::string> RsaSsaPssPublicKeyToDerEncoding(const RSA* rsa) {
496   if (rsa == NULL) {
497     return absl::InvalidArgumentError("Public Key rsa is null.");
498   }
499   // Create DER encoded RSA public key string.
500   ANON_TOKENS_ASSIGN_OR_RETURN(std::string rsa_public_key_str,
501                                MarshalRsaPublicKey(rsa));
502   // Main CRYPTO ByteBuilder object cbb which will be passed to CBB_finish to
503   // finalize and output the DER encoding of the RsaSsaPssPublicKey.
504   bssl::ScopedCBB cbb;
505   // initial_capacity only serves as a hint.
506   if (!CBB_init(cbb.get(), /*initial_capacity=*/2 * RSA_size(rsa))) {
507     return absl::InternalError("CBB_init() failed.");
508   }
509 
510   // Temporary CBB objects to write ASN1 sequences and object identifiers into.
511   CBB outer_seq, inner_seq, param_seq, sha384_seq, mgf1_seq, mgf1_sha384_seq;
512   CBB param0_tag, param1_tag, param2_tag;
513   CBB rsassa_pss_oid, sha384_oid, mgf1_oid, mgf1_sha384_oid;
514   CBB public_key_bit_str_cbb;
515   // RsaSsaPssPublicKey ASN.1 structure example:
516   //
517   //  SEQUENCE {                                               # outer_seq
518   //    SEQUENCE {                                             # inner_seq
519   //      OBJECT_IDENTIFIER{1.2.840.113549.1.1.10}             # rsassa_pss_oid
520   //      SEQUENCE {                                           # param_seq
521   //        [0] {                                              # param0_tag
522   //              {                                            # sha384_seq
523   //                OBJECT_IDENTIFIER{2.16.840.1.101.3.4.2.2}  # sha384_oid
524   //              }
525   //            }
526   //        [1] {                                              # param1_tag
527   //              {                                            # mgf1_seq
528   //                OBJECT_IDENTIFIER{1.2.840.113549.1.1.8}    # mgf1_oid
529   //                {                                          # mgf1_sha384_seq
530   //                  OBJECT_IDENTIFIER{2.16.840.1.101.3.4.2.2}# mgf1_sha384_oid
531   //                }
532   //              }
533   //            }
534   //        [2] {                                              # param2_tag
535   //              INTEGER { 48 }                               # salt length
536   //            }
537   //      }
538   //    }
539   //    BIT STRING {                                    # public_key_bit_str_cbb
540   //      0                                             # unused bits
541   //      der_encoded_rsa_public_key_structure
542   //    }
543   //  }
544   //
545   // Start with the outer sequence.
546   if (!CBB_add_asn1(cbb.get(), &outer_seq, CBS_ASN1_SEQUENCE) ||
547       // The outer sequence consists of two parts; the inner sequence and the
548       // encoded rsa public key.
549       //
550       // Add the inner sequence to the outer sequence.
551       !CBB_add_asn1(&outer_seq, &inner_seq, CBS_ASN1_SEQUENCE) ||
552       // Add object identifier for RSASSA-PSS algorithm to the inner sequence.
553       !CBB_add_asn1(&inner_seq, &rsassa_pss_oid, CBS_ASN1_OBJECT) ||
554       !CBB_add_asn1_oid_from_text(&rsassa_pss_oid, kRsaSsaPssOid,
555                                   strlen(kRsaSsaPssOid)) ||
556       // Add a parameter sequence to the inner sequence.
557       !CBB_add_asn1(&inner_seq, &param_seq, CBS_ASN1_SEQUENCE) ||
558       // SHA384 hash function algorithm identifier will be parameter 0 in the
559       // parameter sequence.
560       !CBB_add_asn1(&param_seq, &param0_tag,
561                     CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 0) ||
562       !CBB_add_asn1(&param0_tag, &sha384_seq, CBS_ASN1_SEQUENCE) ||
563       // Add SHA384 object identifier to finish the SHA384 algorithm identifier
564       // and parameter 0.
565       !CBB_add_asn1(&sha384_seq, &sha384_oid, CBS_ASN1_OBJECT) ||
566       !CBB_add_asn1_oid_from_text(&sha384_oid, kSha384Oid,
567                                   strlen(kSha384Oid)) ||
568       // mgf1-SHA384 algorithm identifier as parameter 1 to the parameter
569       // sequence.
570       !CBB_add_asn1(&param_seq, &param1_tag,
571                     CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 1) ||
572       !CBB_add_asn1(&param1_tag, &mgf1_seq, CBS_ASN1_SEQUENCE) ||
573       // Add mgf1 object identifier to the mgf1-SHA384 algorithm identifier.
574       !CBB_add_asn1(&mgf1_seq, &mgf1_oid, CBS_ASN1_OBJECT) ||
575       !CBB_add_asn1_oid_from_text(&mgf1_oid, kRsaSsaPssMgf1Oid,
576                                   strlen(kRsaSsaPssMgf1Oid)) ||
577       // Add SHA384 algorithm identifier to the mgf1-SHA384 algorithm
578       // identifier.
579       !CBB_add_asn1(&mgf1_seq, &mgf1_sha384_seq, CBS_ASN1_SEQUENCE) ||
580       // Add SHA384 object identifier to finish SHA384 algorithm identifier,
581       // mgf1-SHA384 algorithm identifier and parameter 1.
582       !CBB_add_asn1(&mgf1_sha384_seq, &mgf1_sha384_oid, CBS_ASN1_OBJECT) ||
583       !CBB_add_asn1_oid_from_text(&mgf1_sha384_oid, kSha384Oid,
584                                   strlen(kSha384Oid)) ||
585       // Add salt length as parameter 2 to the parameter sequence to finish the
586       // parameter sequence and the inner sequence.
587       !CBB_add_asn1(&param_seq, &param2_tag,
588                     CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 2) ||
589       !CBB_add_asn1_int64(&param2_tag, kSaltLengthInBytes48) ||
590       // Add public key to the outer sequence as an ASN1 bitstring.
591       !CBB_add_asn1(&outer_seq, &public_key_bit_str_cbb, CBS_ASN1_BITSTRING) ||
592       !CBB_add_u8(&public_key_bit_str_cbb, 0 /* no unused bits */) ||
593       !CBB_add_bytes(
594           &public_key_bit_str_cbb,
595           reinterpret_cast<const uint8_t*>(rsa_public_key_str.data()),
596           rsa_public_key_str.size())) {
597     return absl::InvalidArgumentError(
598         "Failed to set the crypto byte builder object.");
599   }
600   // Finish creating the DER-encoding of RsaSsaPssPublicKey.
601   uint8_t* rsa_ssa_pss_public_key_der;
602   size_t rsa_ssa_pss_public_key_der_len;
603   if (!CBB_finish(cbb.get(), &rsa_ssa_pss_public_key_der,
604                   &rsa_ssa_pss_public_key_der_len)) {
605     return absl::InternalError("CBB_finish() failed.");
606   }
607   std::string rsa_ssa_pss_public_key_der_str(
608       reinterpret_cast<const char*>(rsa_ssa_pss_public_key_der),
609       rsa_ssa_pss_public_key_der_len);
610   // Free memory.
611   OPENSSL_free(rsa_ssa_pss_public_key_der);
612   // Return the DER encoding as string.
613   return rsa_ssa_pss_public_key_der_str;
614 }
615 
616 }  // namespace anonymous_tokens
617 
618