# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that keys are consistently accepted or rejected in all languages."""

import itertools
from typing import Iterable, Tuple

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

import tink
from tink import aead
from tink import daead
from tink import hybrid
from tink import mac
from tink import prf
from tink import signature
from tink.proto import common_pb2
from tink.proto import ecdsa_pb2
from tink.proto import jwt_hmac_pb2
from tink.proto import tink_pb2

import tink_config
from util import testing_servers

# Known discrepancies: (test case name, language) pairs for which key
# generation succeeds in that language although it should fail.
SUCCEEDS_BUT_SHOULD_FAIL = [
    # TODO(b/160130470): In CC and Python Hybrid templates are not checked for
    # valid AEAD params. (These params *are* checked when the key is used.)
    ('EciesAeadHkdfPrivateKey(NIST_P256,UNCOMPRESSED,SHA256,AesEaxKey(15,11))',
     'cc'),
    ('EciesAeadHkdfPrivateKey(NIST_P256,UNCOMPRESSED,SHA256,AesEaxKey(15,11))',
     'python'),
]

# Known discrepancies: (test case name, language) pairs for which key
# generation fails in that language although it should succeed.
FAILS_BUT_SHOULD_SUCCEED = [
    # TODO(b/160134058) Java and Go do not accept templates with CURVE25519.
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA1,AesGcmKey(16))',
     'java'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA1,AesGcmKey(16))',
     'go'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA224,AesGcmKey(16))',
     'java'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA224,AesGcmKey(16))',
     'go'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA256,AesGcmKey(16))',
     'java'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA256,AesGcmKey(16))',
     'go'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA384,AesGcmKey(16))',
     'java'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA384,AesGcmKey(16))',
     'go'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA512,AesGcmKey(16))',
     'java'),
    ('EciesAeadHkdfPrivateKey(CURVE25519,UNCOMPRESSED,SHA512,AesGcmKey(16))',
     'go'),
]

# Hash types swept by the generators below (UNKNOWN_HASH included on purpose).
HASH_TYPES = [
    common_pb2.UNKNOWN_HASH,
    common_pb2.SHA1,
    common_pb2.SHA224,
    common_pb2.SHA256,
    common_pb2.SHA384,
    common_pb2.SHA512,
]

# Elliptic curves swept by the generators below.
CURVE_TYPES = [
    common_pb2.UNKNOWN_CURVE,
    common_pb2.NIST_P256,
    common_pb2.NIST_P384,
    common_pb2.NIST_P521,
    common_pb2.CURVE25519,
]

# EC point encodings swept by the generators below.
EC_POINT_FORMATS = [
    common_pb2.UNKNOWN_FORMAT,
    common_pb2.UNCOMPRESSED,
    common_pb2.COMPRESSED,
    common_pb2.DO_NOT_USE_CRUNCHY_UNCOMPRESSED,
]

# ECDSA signature encodings swept by the generators below.
SIGNATURE_ENCODINGS = [
    ecdsa_pb2.UNKNOWN_ENCODING,
    ecdsa_pb2.IEEE_P1363,
    ecdsa_pb2.DER,
]

# Every generator yields (human-readable test case name, key template) pairs.
TestCasesType = Iterable[Tuple[str, tink_pb2.KeyTemplate]]


def aes_eax_test_cases() -> TestCasesType:
  """Yields AES-EAX templates for every (key size, IV size) combination."""
  create = aead.aead_key_templates.create_aes_eax_key_template
  key_sizes = (15, 16, 24, 32, 64, 96)
  iv_sizes = (11, 12, 16, 17, 24, 32)
  # product() iterates key sizes in the outer position and IV sizes in the
  # inner one, i.e. the same order as two nested loops.
  for key_size, iv_size in itertools.product(key_sizes, iv_sizes):
    yield 'AesEaxKey(%d,%d)' % (key_size, iv_size), create(key_size, iv_size)


def aes_gcm_test_cases() -> TestCasesType:
  """Yields AES-GCM templates for a range of valid and invalid key sizes."""
  create = aead.aead_key_templates.create_aes_gcm_key_template
  for size in (15, 16, 24, 32, 64, 96):
    yield 'AesGcmKey(%d)' % size, create(size)


def aes_gcm_siv_test_cases() -> TestCasesType:
  """Yields AES-GCM-SIV templates for a range of valid and invalid key sizes."""
  create = aead.aead_key_templates.create_aes_gcm_siv_key_template
  for size in (15, 16, 24, 32, 64, 96):
    yield 'AesGcmSivKey(%d)' % size, create(size)
def aes_ctr_hmac_aead_test_cases() -> TestCasesType:
  """Yields AES-CTR-HMAC-AEAD templates.

  The first sweep varies the AES key size and IV size, keeping the HMAC
  parameters at the `_test_case` defaults. The second sweep varies the HMAC
  key size, tag size and hash type, keeping the AES parameters at their
  defaults.
  """

  def _test_case(aes_key_size=16,
                 iv_size=16,
                 hmac_key_size=16,
                 tag_size=16,
                 hash_type=common_pb2.SHA256):
    return ('AesCtrHmacAeadKey(%d,%d,%d,%d,%s)' %
            (aes_key_size, iv_size, hmac_key_size, tag_size,
             common_pb2.HashType.Name(hash_type)),
            aead.aead_key_templates.create_aes_ctr_hmac_aead_key_template(
                aes_key_size=aes_key_size,
                iv_size=iv_size,
                hmac_key_size=hmac_key_size,
                tag_size=tag_size,
                hash_type=hash_type))

  for aes_key_size in [15, 16, 24, 32, 64, 96]:
    for iv_size in [11, 12, 16, 17, 24, 32]:
      yield _test_case(aes_key_size=aes_key_size, iv_size=iv_size)
  for hmac_key_size in [15, 16, 24, 32, 64, 96]:
    for tag_size in [9, 10, 16, 20, 21, 24, 32, 33, 64, 65]:
      for hash_type in HASH_TYPES:
        yield _test_case(
            hmac_key_size=hmac_key_size, tag_size=tag_size, hash_type=hash_type)


def hmac_test_cases() -> TestCasesType:
  """Yields HMAC templates: first varying key size, then tag size x hash."""

  def _test_case(key_size=32, tag_size=16, hash_type=common_pb2.SHA256):
    return ('HmacKey(%d,%d,%s)' %
            (key_size, tag_size, common_pb2.HashType.Name(hash_type)),
            mac.mac_key_templates.create_hmac_key_template(
                key_size, tag_size, hash_type))

  for key_size in [15, 16, 24, 32, 64, 96]:
    yield _test_case(key_size=key_size)
  for tag_size in [9, 10, 16, 20, 21, 24, 32, 33, 64, 65]:
    for hash_type in HASH_TYPES:
      yield _test_case(tag_size=tag_size, hash_type=hash_type)


def jwt_hmac_test_cases() -> TestCasesType:
  """Yields JwtHmacKey templates built directly from JwtHmacKeyFormat protos.

  For HS256/HS384/HS512 this tries key sizes one byte below, equal to, and
  one byte above 32/48/64 bytes respectively, all with RAW output prefix. It
  also tries the exact size with TINK output prefix.
  """

  def _test_case(
      algorithm: jwt_hmac_pb2.JwtHmacAlgorithm, key_size: int,
      output_prefix_type: tink_pb2.OutputPrefixType
  ) -> Tuple[str, tink_pb2.KeyTemplate]:
    key_format = jwt_hmac_pb2.JwtHmacKeyFormat(
        algorithm=algorithm, key_size=key_size)
    template = tink_pb2.KeyTemplate(
        type_url='type.googleapis.com/google.crypto.tink.JwtHmacKey',
        value=key_format.SerializeToString(),
        output_prefix_type=output_prefix_type)
    return ('JwtHmacKey(%d,%s,%s)' %
            (key_size, jwt_hmac_pb2.JwtHmacAlgorithm.Name(algorithm),
             tink_pb2.OutputPrefixType.Name(output_prefix_type)), template)

  yield _test_case(jwt_hmac_pb2.HS256, 31, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS256, 32, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS256, 32, tink_pb2.TINK)
  yield _test_case(jwt_hmac_pb2.HS256, 33, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS384, 47, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS384, 48, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS384, 48, tink_pb2.TINK)
  yield _test_case(jwt_hmac_pb2.HS384, 49, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS512, 63, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS512, 64, tink_pb2.RAW)
  yield _test_case(jwt_hmac_pb2.HS512, 64, tink_pb2.TINK)
  yield _test_case(jwt_hmac_pb2.HS512, 65, tink_pb2.RAW)


def aes_cmac_test_cases() -> TestCasesType:
  """Yields AES-CMAC (MAC) templates: first varying key size, then tag size."""

  def _test_case(key_size=32, tag_size=16):
    return ('AesCmacKey(%d,%d)' % (key_size, tag_size),
            mac.mac_key_templates.create_aes_cmac_key_template(
                key_size, tag_size))

  for key_size in [15, 16, 24, 32, 64, 96]:
    yield _test_case(key_size=key_size)
  for tag_size in [9, 10, 16, 20, 21, 24, 32, 33, 64, 65]:
    yield _test_case(tag_size=tag_size)


def aes_cmac_prf_test_cases() -> TestCasesType:
  """Yields AES-CMAC PRF templates for a range of key sizes."""
  # NOTE: relies on a private template factory of tink.prf.
  for key_size in [15, 16, 24, 32, 64, 96]:
    yield ('AesCmacPrfKey(%d)' % key_size,
           prf.prf_key_templates._create_aes_cmac_key_template(key_size))


def hmac_prf_test_cases() -> TestCasesType:
  """Yields HMAC PRF templates: first varying key size, then hash type."""

  def _test_case(key_size=32, hash_type=common_pb2.SHA256):
    # NOTE: relies on a private template factory of tink.prf.
    return ('HmacPrfKey(%d,%s)' % (key_size, common_pb2.HashType.Name(hash_type)),
            prf.prf_key_templates._create_hmac_key_template(
                key_size, hash_type))

  for key_size in [15, 16, 24, 32, 64, 96]:
    yield _test_case(key_size=key_size)
  for hash_type in HASH_TYPES:
    yield _test_case(hash_type=hash_type)


def hkdf_prf_test_cases() -> TestCasesType:
  """Yields HKDF PRF templates: first varying key size, then hash type."""

  def _test_case(key_size=32, hash_type=common_pb2.SHA256):
    # NOTE: relies on a private template factory of tink.prf.
    return ('HkdfPrfKey(%d,%s)' % (key_size, common_pb2.HashType.Name(hash_type)),
            prf.prf_key_templates._create_hkdf_key_template(
                key_size, hash_type))

  for key_size in [15, 16, 24, 32, 64, 96]:
    yield _test_case(key_size=key_size)
  for hash_type in HASH_TYPES:
    yield _test_case(hash_type=hash_type)


def aes_siv_test_cases() -> TestCasesType:
  """Yields deterministic AES-SIV templates for a range of key sizes."""
  for key_size in [15, 16, 24, 32, 64, 96]:
    yield ('AesSivKey(%d)' % key_size,
           daead.deterministic_aead_key_templates.create_aes_siv_key_template(
               key_size))


def ecies_aead_hkdf_test_cases() -> TestCasesType:
  """Yields ECIES-AEAD-HKDF templates.

  First every curve x hash combination, then every EC point format, each
  with AES128_GCM as DEM. Finally one template whose DEM is an invalid
  AES-EAX template.
  """
  create = hybrid.hybrid_key_templates.create_ecies_aead_hkdf_key_template

  def _test_case(curve_type=common_pb2.NIST_P256,
                 ec_point_format=common_pb2.UNCOMPRESSED,
                 hash_type=common_pb2.SHA256):
    return ('EciesAeadHkdfPrivateKey(%s,%s,%s,AesGcmKey(16))' %
            (common_pb2.EllipticCurveType.Name(curve_type),
             common_pb2.EcPointFormat.Name(ec_point_format),
             common_pb2.HashType.Name(hash_type)),
            create(curve_type, ec_point_format, hash_type,
                   aead.aead_key_templates.AES128_GCM))

  for curve_type in CURVE_TYPES:
    for hash_type in HASH_TYPES:
      yield _test_case(curve_type=curve_type, hash_type=hash_type)
  for ec_point_format in EC_POINT_FORMATS:
    yield _test_case(ec_point_format=ec_point_format)

  # Use invalid AEAD key template as DEM
  # TODO(juerg): Once b/160130470 is fixed, increase test coverage to all
  # aead templates.
  curve_type = common_pb2.NIST_P256
  ec_point_format = common_pb2.UNCOMPRESSED
  hash_type = common_pb2.SHA256
  dem_key_template = aead.aead_key_templates.create_aes_eax_key_template(15, 11)
  yield ('EciesAeadHkdfPrivateKey(%s,%s,%s,AesEaxKey(15,11))' %
         (common_pb2.EllipticCurveType.Name(curve_type),
          common_pb2.EcPointFormat.Name(ec_point_format),
          common_pb2.HashType.Name(hash_type)),
         create(curve_type, ec_point_format, hash_type, dem_key_template))


def ecdsa_test_cases() -> TestCasesType:
  """Yields ECDSA templates for every hash x curve x signature encoding."""
  # product() keeps hash outermost and encoding innermost, matching the
  # nesting order of the equivalent three loops.
  for hash_type, curve_type, signature_encoding in itertools.product(
      HASH_TYPES, CURVE_TYPES, SIGNATURE_ENCODINGS):
    yield ('EcdsaPrivateKey(%s,%s,%s)' %
           (common_pb2.HashType.Name(hash_type),
            common_pb2.EllipticCurveType.Name(curve_type),
            ecdsa_pb2.EcdsaSignatureEncoding.Name(signature_encoding)),
           signature.signature_key_templates.create_ecdsa_key_template(
               hash_type, curve_type, signature_encoding))


def rsa_ssa_pkcs1_test_cases() -> TestCasesType:
  """Yields RSA-SSA-PKCS1 templates.

  Varies one parameter at a time (hash type, modulus size, public exponent).
  The others stay at SHA256 / 2048 / 65537.
  """
  gen = signature.signature_key_templates.create_rsa_ssa_pkcs1_key_template

  def _test_case(hash_type=common_pb2.SHA256,
                 modulus_size=2048,
                 public_exponent=65537):
    return ('RsaSsaPkcs1PrivateKey(%s,%d,%d)' %
            (common_pb2.HashType.Name(hash_type), modulus_size,
             public_exponent),
            gen(hash_type, modulus_size, public_exponent))

  for hash_type in HASH_TYPES:
    yield _test_case(hash_type=hash_type)
  for modulus_size in [0, 2000, 3072, 4096]:
    yield _test_case(modulus_size=modulus_size)
  for public_exponent in [0, 1, 2, 3, 65536, 65537, 65538]:
    yield _test_case(public_exponent=public_exponent)


def rsa_ssa_pss_test_cases() -> TestCasesType:
  """Yields RSA-SSA-PSS templates.

  Varies one parameter at a time: hash type (used for both hash arguments),
  salt length, modulus size, one mixed SHA256/SHA512 hash pair, then public
  exponent. The others stay at SHA256 / 32 / 2048 / 65537.
  """
  gen = signature.signature_key_templates.create_rsa_ssa_pss_key_template

  def _test_case(hash_type1=common_pb2.SHA256,
                 hash_type2=None,
                 salt_length=32,
                 modulus_size=2048,
                 public_exponent=65537):
    # Unless given explicitly, the second hash equals the first. Compare with
    # None because UNKNOWN_HASH is the falsy enum value 0.
    if hash_type2 is None:
      hash_type2 = hash_type1
    return ('RsaSsaPssPrivateKey(%s,%s,%d,%d,%d)' %
            (common_pb2.HashType.Name(hash_type1),
             common_pb2.HashType.Name(hash_type2), salt_length, modulus_size,
             public_exponent),
            gen(hash_type1, hash_type2, salt_length, modulus_size,
                public_exponent))

  for hash_type in HASH_TYPES:
    yield _test_case(hash_type1=hash_type)
  for salt_length in [-3, 0, 1, 16, 64]:
    yield _test_case(salt_length=salt_length)
  for modulus_size in [0, 2000, 3072, 4096]:
    yield _test_case(modulus_size=modulus_size)
  yield _test_case(hash_type1=common_pb2.SHA256, hash_type2=common_pb2.SHA512)
  for public_exponent in [0, 1, 2, 3, 65536, 65537, 65538]:
    yield _test_case(public_exponent=public_exponent)


def setUpModule():
  aead.register()
  daead.register()
  mac.register()
  hybrid.register()
  signature.register()
  testing_servers.start('key_generation_consistency')


def tearDownModule():
  testing_servers.stop()


class KeyGenerationConsistencyTest(parameterized.TestCase):
  """Checks that key generation agrees across all supported languages."""

  @parameterized.parameters(
      itertools.chain(aes_eax_test_cases(),
                      aes_gcm_test_cases(),
                      aes_gcm_siv_test_cases(),
                      aes_ctr_hmac_aead_test_cases(),
                      hmac_test_cases(),
                      jwt_hmac_test_cases(),
                      aes_cmac_test_cases(),
                      aes_cmac_prf_test_cases(),
                      hmac_prf_test_cases(),
                      hkdf_prf_test_cases(),
                      aes_siv_test_cases(),
                      ecies_aead_hkdf_test_cases(),
                      ecdsa_test_cases(),
                      rsa_ssa_pkcs1_test_cases(),
                      rsa_ssa_pss_test_cases()))
  def test_key_generation_consistency(self, name, template):
    """Generates a keyset from `template` in every supporting language.

    Known discrepancies listed in SUCCEEDS_BUT_SHOULD_FAIL and
    FAILS_BUT_SHOULD_SUCCEED are counted as the "should" outcome. The test
    fails immediately if a listed discrepancy no longer reproduces. It then
    fails if some languages accept the template while others reject it.

    Args:
      name: human-readable test case name, used to look up known discrepancies.
      template: the tink_pb2.KeyTemplate to generate a keyset from.
    """
    supported_langs = tink_config.supported_languages_for_key_type(
        tink_config.key_type_from_type_url(template.type_url))
    failures = 0
    results = {}
    for lang in supported_langs:
      succeeds_but_should_fail = (name, lang) in SUCCEEDS_BUT_SHOULD_FAIL
      fails_but_should_succeed = (name, lang) in FAILS_BUT_SHOULD_SUCCEED
      # Only the keyset generation itself is expected to raise TinkError.
      try:
        testing_servers.new_keyset(lang, template)
      except tink.TinkError as e:
        # A listed false rejection counts as a success.
        if not fails_but_should_succeed:
          failures += 1
        if succeeds_but_should_fail:
          self.fail(
              '(%s, %s) is in SUCCEEDS_BUT_SHOULD_FAIL, but failed with %s' %
              (name, lang, e))
        results[lang] = e
      else:
        # A listed false acceptance counts as a failure.
        if succeeds_but_should_fail:
          failures += 1
        if fails_but_should_succeed:
          self.fail('(%s, %s) succeeded, but is in FAILS_BUT_SHOULD_SUCCEED' %
                    (name, lang))
        results[lang] = 'success'
    # Test that either all supported langs accept the template, or all reject.
    if failures not in [0, len(supported_langs)]:
      self.fail('key generation for template %s is inconsistent: %s' %
                (name, results))
    logging.info('Key generation status: %s',
                 'Success' if failures == 0 else 'Fail')


if __name__ == '__main__':
  absltest.main()