1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/crypto_utils.h"
16
17 #include <stddef.h>
18 #include <stdint.h>
19
20 #include <cstdint>
21 #include <iterator>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/string_view.h"
30 #include "quiche/blind_sign_auth/anonymous_tokens/cpp/crypto/constants.h"
31 #include "quiche/blind_sign_auth/anonymous_tokens/cpp/shared/status_utils.h"
32 #include "quiche/blind_sign_auth/anonymous_tokens/proto/anonymous_tokens.pb.h"
33 #include "openssl/err.h"
34 #include "openssl/hkdf.h"
35 #include "openssl/rand.h"
36 #include "openssl/rsa.h"
37
38 namespace private_membership {
39 namespace anonymous_tokens {
40
41 namespace internal {
42
// Approximation of sqrt(2) taken from
// //depot/google3/third_party/openssl/boringssl/src/crypto/fipsmodule/rsa/rsa_impl.c;l=997
//
// Layout: 64 32-bit words read as 32 consecutive (hi, lo) pairs. In each pair
// the even-indexed word is the "hi" half of a 64-bit limb and the following
// odd-indexed word is its "lo" half. Pairs are stored least-significant first,
// so the final pair (0xb504f333, 0xf9de6484) carries the most significant
// bits; note 0xb504f333 / 2^31 ~= 1.41421356. GetRsaSqrtTwo() reassembles the
// value in exactly this order and checks that it is 32 * size() bits long.
const std::vector<uint32_t> kBoringSSLRSASqrtTwo = {
    0x4d7c60a5, 0xe633e3e1, 0x5fcf8f7b, 0xca3ea33b, 0xc246785e, 0x92957023,
    0xf9acce41, 0x797f2805, 0xfdfe170f, 0xd3b1f780, 0xd24f4a76, 0x3facb882,
    0x18838a2e, 0xaff5f3b2, 0xc1fcbdde, 0xa2f7dc33, 0xdea06241, 0xf7aa81c2,
    0xf6a1be3f, 0xca221307, 0x332a5e9f, 0x7bda1ebf, 0x0104dc01, 0xfe32352f,
    0xb8cf341b, 0x6f8236c7, 0x4264dabc, 0xd528b651, 0xf4d3a02c, 0xebc93e0c,
    0x81394ab6, 0xd8fd0efd, 0xeaa4a089, 0x9040ca4a, 0xf52f120f, 0x836e582e,
    0xcb2a6343, 0x31f3c84d, 0xc6d5a8a3, 0x8bb7e9dc, 0x460abc72, 0x2f7c4e33,
    0xcab1bc91, 0x1688458a, 0x53059c60, 0x11bc337b, 0xd2202e87, 0x42af1f4e,
    0x78048736, 0x3dfa2768, 0x0f74a85e, 0x439c7b4a, 0xa8b1fe6f, 0xdc83db39,
    0x4afc8304, 0x3ab8a2c3, 0xed17ac85, 0x83339915, 0x1d6f60ba, 0x893ba84c,
    0x597d89b3, 0x754abe9f, 0xb504f333, 0xf9de6484,
};
58
PublicMetadataHashWithHKDF(absl::string_view public_metadata,absl::string_view rsa_modulus_str,size_t out_len_bytes)59 absl::StatusOr<bssl::UniquePtr<BIGNUM>> PublicMetadataHashWithHKDF(
60 absl::string_view public_metadata, absl::string_view rsa_modulus_str,
61 size_t out_len_bytes) {
62 const EVP_MD* evp_md_sha_384 = EVP_sha384();
63 // Prepend "key" to input.
64 std::string modified_input = absl::StrCat("key", public_metadata);
65 std::vector<uint8_t> input_buffer(modified_input.begin(),
66 modified_input.end());
67 // Append 0x00 to input.
68 input_buffer.push_back(0x00);
69 std::string out_e;
70 // We set the out_e size beyond out_len_bytes so that out_e bytes are
71 // indifferentiable from truly random bytes even after truncations.
72 //
73 // Expanding to 16 more bytes is sufficient.
74 // https://cfrg.github.io/draft-irtf-cfrg-hash-to-curve/draft-irtf-cfrg-hash-to-curve.html#name-hashing-to-a-finite-field
75 const size_t hkdf_output_size = out_len_bytes + 16;
76 out_e.resize(hkdf_output_size);
77 // The modulus is used as salt to ensure different outputs for same metadata
78 // and different modulus.
79 if (HKDF(reinterpret_cast<uint8_t*>(out_e.data()), hkdf_output_size,
80 evp_md_sha_384, input_buffer.data(), input_buffer.size(),
81 reinterpret_cast<const uint8_t*>(rsa_modulus_str.data()),
82 rsa_modulus_str.size(),
83 reinterpret_cast<const uint8_t*>(kHkdfPublicMetadataInfo.data()),
84 kHkdfPublicMetadataInfoSizeInBytes) != kBsslSuccess) {
85 return absl::InternalError("HKDF failed in public_metadata_crypto_utils");
86 }
87 // Truncate out_e to out_len_bytes
88 out_e.resize(out_len_bytes);
89 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> out,
90 StringToBignum(out_e));
91 return out;
92 }
93
94 } // namespace internal
95
GetAndStartBigNumCtx()96 absl::StatusOr<BnCtxPtr> GetAndStartBigNumCtx() {
97 // Create context to be used in intermediate computation.
98 BnCtxPtr bn_ctx = BnCtxPtr(BN_CTX_new());
99 if (!bn_ctx.get()) {
100 return absl::InternalError("Error generating bignum context.");
101 }
102 BN_CTX_start(bn_ctx.get());
103
104 return bn_ctx;
105 }
106
NewBigNum()107 absl::StatusOr<bssl::UniquePtr<BIGNUM>> NewBigNum() {
108 bssl::UniquePtr<BIGNUM> bn(BN_new());
109 if (!bn.get()) {
110 return absl::InternalError("Error generating bignum.");
111 }
112 return bn;
113 }
114
BignumToString(const BIGNUM & big_num,const size_t output_len)115 absl::StatusOr<std::string> BignumToString(const BIGNUM& big_num,
116 const size_t output_len) {
117 std::vector<uint8_t> serialization(output_len);
118 if (BN_bn2bin_padded(serialization.data(), serialization.size(), &big_num) !=
119 kBsslSuccess) {
120 return absl::InternalError(
121 absl::StrCat("Function BN_bn2bin_padded failed: ", GetSslErrors()));
122 }
123 return std::string(std::make_move_iterator(serialization.begin()),
124 std::make_move_iterator(serialization.end()));
125 }
126
// Parses `input_str` as an unsigned big-endian integer into a newly allocated
// BIGNUM.
//
// Returns InternalError if allocation or BN_bin2bn fails.
absl::StatusOr<bssl::UniquePtr<BIGNUM>> StringToBignum(
    const absl::string_view input_str) {
  // `output` is guaranteed non-null here; NewBigNum reports allocation
  // failure itself.
  ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> output, NewBigNum());
  // BN_bin2bn fills the caller-supplied BIGNUM in place; a null return signals
  // failure. Ownership of `output` stays with this function either way.
  if (BN_bin2bn(reinterpret_cast<const uint8_t*>(input_str.data()),
                input_str.size(), output.get()) == nullptr) {
    return absl::InternalError(
        absl::StrCat("Function BN_bin2bn failed: ", GetSslErrors()));
  }
  return output;
}
140
GetSslErrors()141 std::string GetSslErrors() {
142 std::string ret;
143 ERR_print_errors_cb(
144 [](const char* str, size_t len, void* ctx) -> int {
145 static_cast<std::string*>(ctx)->append(str, len);
146 return 1;
147 },
148 &ret);
149 return ret;
150 }
151
GenerateMask(const RSABlindSignaturePublicKey & public_key)152 absl::StatusOr<std::string> GenerateMask(
153 const RSABlindSignaturePublicKey& public_key) {
154 std::string mask;
155 if (public_key.message_mask_type() == AT_MESSAGE_MASK_CONCAT &&
156 public_key.message_mask_size() >= kRsaMessageMaskSizeInBytes32) {
157 mask = std::string(public_key.message_mask_size(), '\0');
158 RAND_bytes(reinterpret_cast<uint8_t*>(mask.data()), mask.size());
159 } else {
160 return absl::InvalidArgumentError(
161 "Undefined or unsupported message mask type.");
162 }
163 return mask;
164 }
165
MaskMessageConcat(absl::string_view mask,absl::string_view message)166 std::string MaskMessageConcat(absl::string_view mask,
167 absl::string_view message) {
168 return absl::StrCat(mask, message);
169 }
170
EncodeMessagePublicMetadata(absl::string_view message,absl::string_view public_metadata)171 std::string EncodeMessagePublicMetadata(absl::string_view message,
172 absl::string_view public_metadata) {
173 // Prepend encoding of "msg" followed by 4 bytes representing public metadata
174 // length.
175 std::string tag = "msg";
176 std::vector<uint8_t> buffer(tag.begin(), tag.end());
177 buffer.push_back((public_metadata.size() >> 24) & 0xFF);
178 buffer.push_back((public_metadata.size() >> 16) & 0xFF);
179 buffer.push_back((public_metadata.size() >> 8) & 0xFF);
180 buffer.push_back((public_metadata.size() >> 0) & 0xFF);
181
182 // Finally append public metadata and then the message to the output.
183 std::string encoding(buffer.begin(), buffer.end());
184 return absl::StrCat(encoding, public_metadata, message);
185 }
186
ProtoHashTypeToEVPDigest(const HashType hash_type)187 absl::StatusOr<const EVP_MD*> ProtoHashTypeToEVPDigest(
188 const HashType hash_type) {
189 switch (hash_type) {
190 case AT_HASH_TYPE_SHA256:
191 return EVP_sha256();
192 case AT_HASH_TYPE_SHA384:
193 return EVP_sha384();
194 case AT_HASH_TYPE_UNDEFINED:
195 default:
196 return absl::InvalidArgumentError("Unknown hash type.");
197 }
198 }
199
ProtoMaskGenFunctionToEVPDigest(const MaskGenFunction mgf)200 absl::StatusOr<const EVP_MD*> ProtoMaskGenFunctionToEVPDigest(
201 const MaskGenFunction mgf) {
202 switch (mgf) {
203 case AT_MGF_SHA256:
204 return EVP_sha256();
205 case AT_MGF_SHA384:
206 return EVP_sha384();
207 case AT_MGF_UNDEFINED:
208 default:
209 return absl::InvalidArgumentError(
210 "Unknown hash type for mask generation hash function.");
211 }
212 }
213
GetRsaSqrtTwo(int x)214 absl::StatusOr<bssl::UniquePtr<BIGNUM>> GetRsaSqrtTwo(int x) {
215 // Compute hard-coded sqrt(2).
216 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> sqrt2, NewBigNum());
217 // TODO(b/277606961): simplify RsaSqrtTwo initialization logic
218 for (int i = internal::kBoringSSLRSASqrtTwo.size() - 2; i >= 0; i = i - 2) {
219 // Add the uint32_t values as words directly and shift.
220 // 'i' is the "hi" value of a uint64_t, and 'i+1' is the "lo" value.
221 if (BN_add_word(sqrt2.get(), internal::kBoringSSLRSASqrtTwo[i]) != 1) {
222 return absl::InternalError(absl::StrCat(
223 "Cannot add word to compute RSA sqrt(2): ", GetSslErrors()));
224 }
225 if (BN_lshift(sqrt2.get(), sqrt2.get(), 32) != 1) {
226 return absl::InternalError(absl::StrCat(
227 "Cannot shift to compute RSA sqrt(2): ", GetSslErrors()));
228 }
229 if (BN_add_word(sqrt2.get(), internal::kBoringSSLRSASqrtTwo[i+1]) != 1) {
230 return absl::InternalError(absl::StrCat(
231 "Cannot add word to compute RSA sqrt(2): ", GetSslErrors()));
232 }
233 if (i > 0) {
234 if (BN_lshift(sqrt2.get(), sqrt2.get(), 32) != 1) {
235 return absl::InternalError(absl::StrCat(
236 "Cannot shift to compute RSA sqrt(2): ", GetSslErrors()));
237 }
238 }
239 }
240
241 // Check that hard-coded result is correct length.
242 int sqrt2_bits = 32 * internal::kBoringSSLRSASqrtTwo.size();
243 if (BN_num_bits(sqrt2.get()) != sqrt2_bits) {
244 return absl::InternalError("RSA sqrt(2) is not correct length.");
245 }
246
247 // Either shift left or right depending on value x.
248 if (sqrt2_bits > x) {
249 if (BN_rshift(sqrt2.get(), sqrt2.get(), sqrt2_bits - x) != 1) {
250 return absl::InternalError(
251 absl::StrCat("Cannot rshift to compute 2^(x-1/2): ", GetSslErrors()));
252 }
253 } else {
254 // Round up and be pessimistic about minimium factors.
255 if (BN_add_word(sqrt2.get(), 1) != 1 ||
256 BN_lshift(sqrt2.get(), sqrt2.get(), x - sqrt2_bits) != 1) {
257 return absl::InternalError(absl::StrCat(
258 "Cannot add/lshift to compute 2^(x-1/2): ", GetSslErrors()));
259 }
260 }
261
262 // Check that 2^(x - 1/2) is correct length.
263 if (BN_num_bits(sqrt2.get()) != x) {
264 return absl::InternalError(
265 "2^(x-1/2) is not correct length after shifting.");
266 }
267
268 return std::move(sqrt2);
269 }
270
ComputePowerOfTwo(int x)271 absl::StatusOr<bssl::UniquePtr<BIGNUM>> ComputePowerOfTwo(int x) {
272 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> ret, NewBigNum());
273 if (BN_set_bit(ret.get(), x) != 1) {
274 return absl::InternalError(
275 absl::StrCat("Unable to set bit to compute 2^x: ", GetSslErrors()));
276 }
277 if (!BN_is_pow2(ret.get()) || !BN_is_bit_set(ret.get(), x)) {
278 return absl::InternalError(absl::StrCat("Unable to compute 2^", x, "."));
279 }
280 return ret;
281 }
282
ComputeHash(absl::string_view input,const EVP_MD & hasher)283 absl::StatusOr<std::string> ComputeHash(absl::string_view input,
284 const EVP_MD& hasher) {
285 std::string digest;
286 digest.resize(EVP_MAX_MD_SIZE);
287
288 uint32_t digest_length = 0;
289 if (EVP_Digest(input.data(), input.length(),
290 reinterpret_cast<uint8_t*>(&digest[0]), &digest_length,
291 &hasher, /*impl=*/nullptr) != 1) {
292 return absl::InternalError(absl::StrCat(
293 "Openssl internal error computing hash: ", GetSslErrors()));
294 }
295 digest.resize(digest_length);
296 return digest;
297 }
298
AnonymousTokensRSAPrivateKeyToRSA(const RSAPrivateKey & private_key)299 absl::StatusOr<bssl::UniquePtr<RSA>> AnonymousTokensRSAPrivateKeyToRSA(
300 const RSAPrivateKey& private_key) {
301 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> n,
302 StringToBignum(private_key.n()));
303 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> e,
304 StringToBignum(private_key.e()));
305 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> d,
306 StringToBignum(private_key.d()));
307 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> p,
308 StringToBignum(private_key.p()));
309 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> q,
310 StringToBignum(private_key.q()));
311 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> dp,
312 StringToBignum(private_key.dp()));
313 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> dq,
314 StringToBignum(private_key.dq()));
315 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> crt,
316 StringToBignum(private_key.crt()));
317
318 bssl::UniquePtr<RSA> rsa_private_key(RSA_new());
319 // Populate private key.
320 if (!rsa_private_key.get()) {
321 return absl::InternalError(
322 absl::StrCat("RSA_new failed: ", GetSslErrors()));
323 } else if (RSA_set0_key(rsa_private_key.get(), n.get(), e.get(), d.get()) !=
324 kBsslSuccess) {
325 return absl::InternalError(
326 absl::StrCat("RSA_set0_key failed: ", GetSslErrors()));
327 } else if (RSA_set0_factors(rsa_private_key.get(), p.get(), q.get()) !=
328 kBsslSuccess) {
329 return absl::InternalError(
330 absl::StrCat("RSA_set0_factors failed: ", GetSslErrors()));
331 } else if (RSA_set0_crt_params(rsa_private_key.get(), dp.get(), dq.get(),
332 crt.get()) != kBsslSuccess) {
333 return absl::InternalError(
334 absl::StrCat("RSA_set0_crt_params failed: ", GetSslErrors()));
335 } else {
336 n.release();
337 e.release();
338 d.release();
339 p.release();
340 q.release();
341 dp.release();
342 dq.release();
343 crt.release();
344 }
345 return std::move(rsa_private_key);
346 }
347
AnonymousTokensRSAPublicKeyToRSA(const RSAPublicKey & public_key)348 absl::StatusOr<bssl::UniquePtr<RSA>> AnonymousTokensRSAPublicKeyToRSA(
349 const RSAPublicKey& public_key) {
350 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> rsa_modulus,
351 StringToBignum(public_key.n()));
352 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> rsa_e,
353 StringToBignum(public_key.e()));
354 // Convert to OpenSSL RSA.
355 bssl::UniquePtr<RSA> rsa_public_key(RSA_new());
356 if (!rsa_public_key.get()) {
357 return absl::InternalError(
358 absl::StrCat("RSA_new failed: ", GetSslErrors()));
359 } else if (RSA_set0_key(rsa_public_key.get(), rsa_modulus.get(), rsa_e.get(),
360 nullptr) != kBsslSuccess) {
361 return absl::InternalError(
362 absl::StrCat("RSA_set0_key failed: ", GetSslErrors()));
363 }
364 // RSA_set0_key takes ownership of the pointers under rsa_modulus, new_e on
365 // success.
366 rsa_modulus.release();
367 rsa_e.release();
368 return rsa_public_key;
369 }
370
ComputeCarmichaelLcm(const BIGNUM & phi_p,const BIGNUM & phi_q,BN_CTX & bn_ctx)371 absl::StatusOr<bssl::UniquePtr<BIGNUM>> ComputeCarmichaelLcm(
372 const BIGNUM& phi_p, const BIGNUM& phi_q, BN_CTX& bn_ctx) {
373 // To compute lcm(phi(p), phi(q)), we first compute phi(n) =
374 // (p-1)(q-1). As n is assumed to be a safe RSA modulus (signing_key is
375 // assumed to be part of a strong rsa key pair), phi(n) = (p-1)(q-1) =
376 // (2 phi(p))(2 phi(q)) = 4 * phi(p) * phi(q) where phi(p) and phi(q) are also
377 // primes. So we get the lcm by outputting phi(n) >> 1 = 2 * phi(p) * phi(q).
378 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> phi_n, NewBigNum());
379 if (BN_mul(phi_n.get(), &phi_p, &phi_q, &bn_ctx) != 1) {
380 return absl::InternalError(
381 absl::StrCat("Unable to compute phi(n): ", GetSslErrors()));
382 }
383 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> lcm, NewBigNum());
384 if (BN_rshift1(lcm.get(), phi_n.get()) != 1) {
385 return absl::InternalError(absl::StrCat(
386 "Could not compute LCM(phi(p), phi(q)): ", GetSslErrors()));
387 }
388 return lcm;
389 }
390
PublicMetadataExponent(const BIGNUM & n,absl::string_view public_metadata)391 absl::StatusOr<bssl::UniquePtr<BIGNUM>> PublicMetadataExponent(
392 const BIGNUM& n, absl::string_view public_metadata) {
393 // Check modulus length.
394 if (BN_num_bits(&n) % 2 == 1) {
395 return absl::InvalidArgumentError(
396 "Strong RSA modulus should be even length.");
397 }
398 int modulus_bytes = BN_num_bytes(&n);
399 // The integer modulus_bytes is expected to be a power of 2.
400 int prime_bytes = modulus_bytes / 2;
401
402 ANON_TOKENS_ASSIGN_OR_RETURN(std::string rsa_modulus_str,
403 BignumToString(n, modulus_bytes));
404
405 // Get HKDF output of length prime_bytes.
406 ANON_TOKENS_ASSIGN_OR_RETURN(
407 bssl::UniquePtr<BIGNUM> exponent,
408 internal::PublicMetadataHashWithHKDF(public_metadata, rsa_modulus_str,
409 prime_bytes));
410
411 // We need to generate random odd exponents < 2^(primes_bits - 2) where
412 // prime_bits = prime_bytes * 8. This will guarantee that the resulting
413 // exponent is coprime to phi(N) = 4p'q' as 2^(prime_bits - 2) < p', q' <
414 // 2^(prime_bits - 1).
415 //
416 // To do this, we can truncate the HKDF output (exponent) which is prime_bits
417 // long, to prime_bits - 2, by clearing its top two bits. We then set the
418 // least significant bit to 1. This way the final exponent will be less than
419 // 2^(primes_bits - 2) and will always be odd.
420 if (BN_clear_bit(exponent.get(), (prime_bytes * 8) - 1) != kBsslSuccess ||
421 BN_clear_bit(exponent.get(), (prime_bytes * 8) - 2) != kBsslSuccess ||
422 BN_set_bit(exponent.get(), 0) != kBsslSuccess) {
423 return absl::InvalidArgumentError(absl::StrCat(
424 "Could not clear the two most significant bits and set the least "
425 "significant bit to zero: ",
426 GetSslErrors()));
427 }
428 // Check that exponent is small enough to ensure it is coprime to phi(n).
429 if (BN_num_bits(exponent.get()) >= (8 * prime_bytes - 1)) {
430 return absl::InternalError("Generated exponent is too large.");
431 }
432
433 return exponent;
434 }
435
ComputeFinalExponentUnderPublicMetadata(const BIGNUM & n,const BIGNUM & e,absl::string_view public_metadata)436 absl::StatusOr<bssl::UniquePtr<BIGNUM>> ComputeFinalExponentUnderPublicMetadata(
437 const BIGNUM& n, const BIGNUM& e, absl::string_view public_metadata) {
438 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> md_exp,
439 PublicMetadataExponent(n, public_metadata));
440 ANON_TOKENS_ASSIGN_OR_RETURN(BnCtxPtr bn_ctx, GetAndStartBigNumCtx());
441 // new_e=e*md_exp
442 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> new_e, NewBigNum());
443 if (BN_mul(new_e.get(), md_exp.get(), &e, bn_ctx.get()) != kBsslSuccess) {
444 return absl::InternalError(
445 absl::StrCat("Unable to multiply e with md_exp: ", GetSslErrors()));
446 }
447 return new_e;
448 }
449
RsaBlindSignatureVerify(const int salt_length,const EVP_MD * sig_hash,const EVP_MD * mgf1_hash,RSA * rsa_public_key,const BIGNUM & rsa_modulus,const BIGNUM & augmented_rsa_e,absl::string_view signature,absl::string_view message,std::optional<absl::string_view> public_metadata)450 absl::Status RsaBlindSignatureVerify(
451 const int salt_length, const EVP_MD* sig_hash, const EVP_MD* mgf1_hash,
452 RSA* rsa_public_key, const BIGNUM& rsa_modulus,
453 const BIGNUM& augmented_rsa_e, absl::string_view signature,
454 absl::string_view message,
455 std::optional<absl::string_view> public_metadata) {
456 std::string augmented_message(message);
457 if (public_metadata.has_value()) {
458 augmented_message = EncodeMessagePublicMetadata(message, *public_metadata);
459 }
460 ANON_TOKENS_ASSIGN_OR_RETURN(std::string message_digest,
461 ComputeHash(augmented_message, *sig_hash));
462 const int hash_size = EVP_MD_size(sig_hash);
463 // Make sure the size of the digest is correct.
464 if (message_digest.size() != hash_size) {
465 return absl::InvalidArgumentError(
466 absl::StrCat("Size of the digest doesn't match the one "
467 "of the hashing algorithm; expected ",
468 hash_size, " got ", message_digest.size()));
469 }
470 const int rsa_modulus_size = BN_num_bytes(&rsa_modulus);
471 if (signature.size() != rsa_modulus_size) {
472 return absl::InvalidArgumentError(
473 "Signature size not equal to modulus size.");
474 }
475
476 std::string recovered_message_digest(rsa_modulus_size, 0);
477 if (!public_metadata.has_value()) {
478 int recovered_message_digest_size = RSA_public_decrypt(
479 /*flen=*/signature.size(),
480 /*from=*/reinterpret_cast<const uint8_t*>(signature.data()),
481 /*to=*/
482 reinterpret_cast<uint8_t*>(recovered_message_digest.data()),
483 /*rsa=*/rsa_public_key,
484 /*padding=*/RSA_NO_PADDING);
485 if (recovered_message_digest_size != rsa_modulus_size) {
486 return absl::InvalidArgumentError(
487 absl::StrCat("Invalid signature size (likely an incorrect key is "
488 "used); expected ",
489 rsa_modulus_size, " got ", recovered_message_digest_size,
490 ": ", GetSslErrors()));
491 }
492 } else {
493 ANON_TOKENS_ASSIGN_OR_RETURN(bssl::UniquePtr<BIGNUM> signature_bn,
494 StringToBignum(signature));
495 if (BN_ucmp(signature_bn.get(), &rsa_modulus) >= 0) {
496 return absl::InternalError("Data too large for modulus.");
497 }
498 ANON_TOKENS_ASSIGN_OR_RETURN(BnCtxPtr bn_ctx, GetAndStartBigNumCtx());
499 bssl::UniquePtr<BN_MONT_CTX> bn_mont_ctx(
500 BN_MONT_CTX_new_for_modulus(&rsa_modulus, bn_ctx.get()));
501 if (!bn_mont_ctx) {
502 return absl::InternalError("BN_MONT_CTX_new_for_modulus failed.");
503 }
504 ANON_TOKENS_ASSIGN_OR_RETURN(
505 bssl::UniquePtr<BIGNUM> recovered_message_digest_bn, NewBigNum());
506 if (BN_mod_exp_mont(recovered_message_digest_bn.get(), signature_bn.get(),
507 &augmented_rsa_e, &rsa_modulus, bn_ctx.get(),
508 bn_mont_ctx.get()) != kBsslSuccess) {
509 return absl::InternalError("Exponentiation failed.");
510 }
511 ANON_TOKENS_ASSIGN_OR_RETURN(
512 recovered_message_digest,
513 BignumToString(*recovered_message_digest_bn, rsa_modulus_size));
514 }
515 if (RSA_verify_PKCS1_PSS_mgf1(
516 rsa_public_key, reinterpret_cast<const uint8_t*>(&message_digest[0]),
517 sig_hash, mgf1_hash,
518 reinterpret_cast<const uint8_t*>(recovered_message_digest.data()),
519 salt_length) != kBsslSuccess) {
520 return absl::InvalidArgumentError(
521 absl::StrCat("PSS padding verification failed: ", GetSslErrors()));
522 }
523 return absl::OkStatus();
524 }
525
526 } // namespace anonymous_tokens
527 } // namespace private_membership
528