# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cross-language tests for the KMS AEAD and KMS Envelope AEAD primitives.

Both primitives are exercised with AWS KMS and GCP KMS key URIs.
"""
from typing import Dict, Iterable, List, Sequence, Tuple

from absl.testing import absltest
from absl.testing import parameterized
import tink
from tink import aead

from tink.proto import tink_pb2
from util import testing_servers
from util import utilities

# AWS Key with alias "unit-and-integration-testing"
_AWS_KEY_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
                '3ee50705-5a82-4f5b-9753-05c4f473922f')
_AWS_KEY_ALIAS_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:alias/'
                      'unit-and-integration-testing')


# 2nd AWS Key with alias "unit-and-integration-testing-2"
_AWS_KEY_2_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
                  'b3ca2efd-a8fb-47f2-b541-7e20f8c5cd11')
_AWS_KEY_2_ALIAS_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:alias/'
                        'unit-and-integration-testing-2')

# AWS key URIs that encryption is expected to fail with (used by the
# unknown-key-URI test cases below).
_AWS_UNKNOWN_KEY_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
                        '4ee50705-5a82-4f5b-9753-05c4f473922f')
_AWS_UNKNOWN_KEY_ALIAS_URI = (
    'aws-kms://arn:aws:kms:us-east-2:235739564943:alias/'
    'unknown-unit-and-integration-testing')

_GCP_KEY_URI = ('gcp-kms://projects/tink-test-infrastructure/locations/global/'
                'keyRings/unit-and-integration-testing/cryptoKeys/aead-key')
_GCP_KEY_2_URI = (
    'gcp-kms://projects/tink-test-infrastructure/locations/global/'
    'keyRings/unit-and-integration-testing/cryptoKeys/aead2-key')
# GCP key URI that encryption is expected to fail with.
_GCP_UNKNOWN_KEY_URI = (
    'gcp-kms://projects/tink-test-infrastructure/locations/global/'
    'keyRings/unit-and-integration-testing/cryptoKeys/unknown')

# Default key URI per KMS service, used by the parameterized tests.
_KMS_KEY_URI = {
    'GCP': _GCP_KEY_URI,
    'AWS': _AWS_KEY_URI,
}

# Template of the data encryption key (DEK) wrapped by KMS Envelope AEAD.
_DEK_TEMPLATE = utilities.KEY_TEMPLATE['AES128_GCM']


def _kms_envelope_aead_templates(
    kms_services: Sequence[str]) -> Dict[str, tink_pb2.KeyTemplate]:
  """Generates a map from KMS envelope AEAD template name to key template.

  Args:
    kms_services: KMS service names; each must be a key of `_KMS_KEY_URI`.

  Returns:
    A dict mapping '<SERVICE>_KMS_ENVELOPE_AEAD' to a KMS envelope AEAD key
    template that uses the service's default key URI and `_DEK_TEMPLATE`.
  """
  kms_key_templates = {}
  for kms_service in kms_services:
    key_uri = _KMS_KEY_URI[kms_service]
    kms_envelope_aead_key_template = (
        aead.aead_key_templates.create_kms_envelope_aead_key_template(
            key_uri, _DEK_TEMPLATE))
    kms_envelope_aead_template_name = '%s_KMS_ENVELOPE_AEAD' % kms_service
    kms_key_templates[kms_envelope_aead_template_name] = (
        kms_envelope_aead_key_template)
  return kms_key_templates


_KMS_ENVELOPE_AEAD_KEY_TEMPLATES = _kms_envelope_aead_templates(['GCP', 'AWS'])
_SUPPORTED_LANGUAGES_FOR_KMS_ENVELOPE_AEAD = ('python', 'cc', 'go', 'java')

_SUPPORTED_LANGUAGES_FOR_KMS_AEAD = {
    'AWS': ('python', 'cc', 'go', 'java'),
    'GCP': ('python', 'cc', 'go', 'java'),
}


def setUpModule():
  aead.register()
  testing_servers.start('aead')


def tearDownModule():
  testing_servers.stop()


def _get_lang_tuples(langs: Sequence[str]) -> Iterable[Tuple[str, str]]:
  """Yields language tuples to run cross-language tests.

  Ideally, we would want to test all possible tuples of languages. But
  that results in a quadratic number of tuples. It is not really necessary,
  because if an implementation in one language does something different, then
  any cross-language test with another language will fail. So it is enough to
  only use every implementation once for encryption and once for decryption.

  Args:
    langs: Sequence of language names (a list or a tuple).

  Yields:
    Tuples of 2 languages: each language paired with its successor, the last
    one wrapping around to the first. Nothing is yielded for an empty input.
  """
  for i, lang in enumerate(langs):
    yield (lang, langs[(i + 1) % len(langs)])


def _get_plaintext_and_aad(key_template_name: str,
                           lang: str) -> Tuple[bytes, bytes]:
  """Creates test plaintext and associated data from a key template and lang."""
  plaintext = (
      b'This is some plaintext message to be encrypted using key_template '
      b'%s using %s for encryption.' %
      (key_template_name.encode('utf8'), lang.encode('utf8')))
  associated_data = (b'Some associated data for %s using %s for encryption.' %
                     (key_template_name.encode('utf8'), lang.encode('utf8')))
  return (plaintext, associated_data)


def _kms_aead_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (KMS service, encrypt lang, decrypt lang)."""
  for kms_service, supported_langs in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.items():
    for encrypt_lang, decrypt_lang in _get_lang_tuples(supported_langs):
      yield (kms_service, encrypt_lang, decrypt_lang)


def _two_key_uris_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (lang, key URI, second key URI) for AWS and GCP."""
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('AWS', []):
    yield (lang, _AWS_KEY_URI, _AWS_KEY_2_URI)
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('GCP', []):
    yield (lang, _GCP_KEY_URI, _GCP_KEY_2_URI)


def _key_uris_with_alias_test_cases() -> Iterable[Tuple[str, str]]:
  """Yields (lang, alias key URI); only AWS key aliases are covered."""
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('AWS', []):
    yield (lang, _AWS_KEY_ALIAS_URI)


def _two_key_uris_with_alias_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (lang, alias key URI, second alias key URI) for AWS."""
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('AWS', []):
    yield (lang, _AWS_KEY_ALIAS_URI, _AWS_KEY_2_ALIAS_URI)


def _unknown_key_uris_test_cases() -> Iterable[Tuple[str, str]]:
  """Yields (lang, unknown key URI) for AWS (key and alias URIs) and GCP."""
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('AWS', []):
    yield (lang, _AWS_UNKNOWN_KEY_URI)
    yield (lang, _AWS_UNKNOWN_KEY_ALIAS_URI)
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('GCP', []):
    yield (lang, _GCP_UNKNOWN_KEY_URI)
class KmsAeadTest(parameterized.TestCase):
  """Cross-language tests for the KMS AEAD primitive."""

  def _assert_ciphertexts_not_interchangeable(self, lang, key_uri, key_uri_2):
    """Asserts ciphertexts made with one KMS key URI fail under the other.

    Creates one KMS AEAD keyset per key URI in `lang` and encrypts the same
    plaintext with both. It then checks that each ciphertext decrypts with the
    primitive that produced it, and that decryption with the other primitive
    raises tink.TinkError.

    Args:
      lang: Language of the testing server used for both primitives.
      key_uri: KMS key URI of the first keyset.
      key_uri_2: KMS key URI of the second keyset.
    """
    keyset = testing_servers.new_keyset(
        lang, aead.aead_key_templates.create_kms_aead_key_template(key_uri))
    keyset_2 = testing_servers.new_keyset(
        lang, aead.aead_key_templates.create_kms_aead_key_template(key_uri_2))

    primitive = testing_servers.remote_primitive(
        lang=lang, keyset=keyset, primitive_class=aead.Aead)
    primitive_2 = testing_servers.remote_primitive(
        lang=lang, keyset=keyset_2, primitive_class=aead.Aead)

    plaintext = b'plaintext'
    associated_data = b'associated_data'

    ciphertext = primitive.encrypt(plaintext, associated_data)
    ciphertext_2 = primitive_2.encrypt(plaintext, associated_data)

    # Can be decrypted by the primitive that created the ciphertext.
    self.assertEqual(primitive.decrypt(ciphertext, associated_data), plaintext)
    self.assertEqual(
        primitive_2.decrypt(ciphertext_2, associated_data), plaintext)

    # Cannot be decrypted by the other primitive.
    with self.assertRaises(tink.TinkError):
      primitive.decrypt(ciphertext_2, associated_data)
    with self.assertRaises(tink.TinkError):
      primitive_2.decrypt(ciphertext, associated_data)

  def test_get_lang_tuples(self):
    self.assertEqual(
        list(_get_lang_tuples(['cc', 'java', 'go', 'python'])),
        [('cc', 'java'), ('java', 'go'), ('go', 'python'), ('python', 'cc')],
    )
    self.assertEqual(list(_get_lang_tuples([])), [])

  @parameterized.parameters(_kms_aead_test_cases())
  def test_encrypt_decrypt_with_associated_data(
      self, kms_service, encrypt_lang, decrypt_lang
  ):
    kms_key_uri = _KMS_KEY_URI[kms_service]
    kms_aead_template_name = '%s_KMS_AEAD' % kms_service
    key_template = aead.aead_key_templates.create_kms_aead_key_template(
        kms_key_uri)
    # The keyset is generated by the encryption language and then shared with
    # the decryption language.
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        lang=encrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext, associated_data = _get_plaintext_and_aad(kms_aead_template_name,
                                                        encrypt_primitive.lang)
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        lang=decrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_kms_aead_test_cases())
  def test_encrypt_decrypt_with_empty_associated_data(
      self, kms_service, encrypt_lang, decrypt_lang
  ):
    kms_key_uri = _KMS_KEY_URI[kms_service]
    key_template = aead.aead_key_templates.create_kms_aead_key_template(
        kms_key_uri)
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        lang=encrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext = b'plaintext'
    associated_data = b''
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        lang=decrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_two_key_uris_test_cases())
  def test_cannot_decrypt_ciphertext_of_other_key_uri(self, lang, key_uri,
                                                      key_uri_2):
    self._assert_ciphertexts_not_interchangeable(lang, key_uri, key_uri_2)

  @parameterized.parameters(_key_uris_with_alias_test_cases())
  def test_encrypt_decrypt_with_key_aliases(self, lang, alias_key_uri):
    keyset = testing_servers.new_keyset(
        lang,
        aead.aead_key_templates.create_kms_aead_key_template(alias_key_uri))
    primitive = testing_servers.remote_primitive(
        lang=lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext = b'plaintext'
    associated_data = b'associated_data'
    ciphertext = primitive.encrypt(plaintext, associated_data)
    self.assertEqual(
        primitive.decrypt(ciphertext, associated_data), plaintext)

  @parameterized.parameters(_two_key_uris_with_alias_test_cases())
  def test_cannot_decrypt_ciphertext_of_other_alias_key_uri(
      self, lang, alias_key_uri, alias_key_uri_2):
    self._assert_ciphertexts_not_interchangeable(
        lang, alias_key_uri, alias_key_uri_2)

  @parameterized.parameters(_unknown_key_uris_test_cases())
  def test_encrypt_fails_with_unknown_key_uri(self, lang, unknown_key_uri):
    key_template = aead.aead_key_templates.create_kms_aead_key_template(
        unknown_key_uri)
    # Keyset creation is expected to succeed; only the KMS call made by
    # encrypt() is expected to fail.
    keyset = testing_servers.new_keyset(lang, key_template)
    primitive = testing_servers.remote_primitive(
        lang=lang, keyset=keyset, primitive_class=aead.Aead)

    plaintext = b'plaintext'
    associated_data = b'associated_data'

    with self.assertRaises(tink.TinkError):
      primitive.encrypt(plaintext, associated_data)


def _kms_envelope_aead_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (KMS Envelope AEAD template names, encrypt lang, decrypt lang)."""
  for key_template_name in _KMS_ENVELOPE_AEAD_KEY_TEMPLATES:
    # Make sure to test languages that support the primitive used for DEK.
    supported_langs = _SUPPORTED_LANGUAGES_FOR_KMS_ENVELOPE_AEAD
    for encrypt_lang, decrypt_lang in _get_lang_tuples(supported_langs):
      yield (key_template_name, encrypt_lang, decrypt_lang)


class KmsEnvelopeAeadTest(parameterized.TestCase):
  """Cross-language tests for the KMS Envelope AEAD primitive."""

  @parameterized.parameters(_kms_envelope_aead_test_cases())
  def test_encrypt_decrypt_with_associated_data(
      self, key_template_name, encrypt_lang, decrypt_lang
  ):
    key_template = _KMS_ENVELOPE_AEAD_KEY_TEMPLATES[key_template_name]
    # Use the encryption language to generate the keyset proto.
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        lang=encrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext, associated_data = _get_plaintext_and_aad(key_template_name,
                                                        encrypt_primitive.lang)
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)

    # Decrypt.
    decrypt_primitive = testing_servers.remote_primitive(
        lang=decrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_kms_envelope_aead_test_cases())
  def test_encrypt_decrypt_with_empty_associated_data(
      self, key_template_name, encrypt_lang, decrypt_lang
  ):
    key_template = _KMS_ENVELOPE_AEAD_KEY_TEMPLATES[key_template_name]
    # Use the encryption language to generate the keyset proto.
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        lang=encrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext = b'plaintext'
    associated_data = b''
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        lang=decrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_kms_envelope_aead_test_cases())
  def test_decryption_fails_with_wrong_aad(self, key_template_name,
                                           encrypt_lang, decrypt_lang):
    key_template = _KMS_ENVELOPE_AEAD_KEY_TEMPLATES[key_template_name]
    # Use the encryption language to generate the keyset proto.
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        lang=encrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext, associated_data = _get_plaintext_and_aad(key_template_name,
                                                        encrypt_primitive.lang)
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        lang=decrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    with self.assertRaises(tink.TinkError, msg='decryption failed'):
      decrypt_primitive.decrypt(ciphertext, b'wrong aad')


if __name__ == '__main__':
  absltest.main()