• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Cross-language tests for the KMS AEAD and KMS Envelope AEAD primitives.

The tests use keys stored in AWS KMS and GCP KMS.
"""
15from typing import Dict, Iterable, List, Sequence, Tuple
16
17from absl.testing import absltest
18from absl.testing import parameterized
19import tink
20from tink import aead
21
22from tink.proto import tink_pb2
23from util import testing_servers
24from util import utilities
25
26# AWS Key with alias "unit-and-integration-testing"
27_AWS_KEY_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
28                '3ee50705-5a82-4f5b-9753-05c4f473922f')
29_AWS_KEY_ALIAS_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:alias/'
30                      'unit-and-integration-testing')
31
32
33# 2nd AWS Key with alias "unit-and-integration-testing-2"
34_AWS_KEY_2_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
35                  'b3ca2efd-a8fb-47f2-b541-7e20f8c5cd11')
36_AWS_KEY_2_ALIAS_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:alias/'
37                        'unit-and-integration-testing-2')
38
39_AWS_UNKNOWN_KEY_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
40                        '4ee50705-5a82-4f5b-9753-05c4f473922f')
41_AWS_UNKNOWN_KEY_ALIAS_URI = (
42    'aws-kms://arn:aws:kms:us-east-2:235739564943:alias/'
43    'unknown-unit-and-integration-testing')
44
45_GCP_KEY_URI = ('gcp-kms://projects/tink-test-infrastructure/locations/global/'
46                'keyRings/unit-and-integration-testing/cryptoKeys/aead-key')
47_GCP_KEY_2_URI = (
48    'gcp-kms://projects/tink-test-infrastructure/locations/global/'
49    'keyRings/unit-and-integration-testing/cryptoKeys/aead2-key')
50_GCP_UNKNOWN_KEY_URI = (
51    'gcp-kms://projects/tink-test-infrastructure/locations/global/'
52    'keyRings/unit-and-integration-testing/cryptoKeys/unknown')
53
54_KMS_KEY_URI = {
55    'GCP': _GCP_KEY_URI,
56    'AWS': _AWS_KEY_URI,
57}
58
# Key template of the data encryption key (DEK) in every envelope template.
_DEK_TEMPLATE = utilities.KEY_TEMPLATE['AES128_GCM']


def _kms_envelope_aead_templates(
    kms_services: Sequence[str]) -> Dict[str, tink_pb2.KeyTemplate]:
  """Builds one KMS envelope AEAD key template per KMS service.

  Args:
    kms_services: KMS service names; each must be a key of _KMS_KEY_URI,
      otherwise KeyError is raised.

  Returns:
    A dict mapping '<SERVICE>_KMS_ENVELOPE_AEAD' to the key template created
    from that service's KMS key URI and _DEK_TEMPLATE.
  """
  make_template = aead.aead_key_templates.create_kms_envelope_aead_key_template
  return {
      '%s_KMS_ENVELOPE_AEAD' % service: make_template(
          _KMS_KEY_URI[service], _DEK_TEMPLATE)
      for service in kms_services
  }
75
76
_KMS_ENVELOPE_AEAD_KEY_TEMPLATES = _kms_envelope_aead_templates(('GCP', 'AWS'))

# Languages exercised by the KMS envelope AEAD tests.
_SUPPORTED_LANGUAGES_FOR_KMS_ENVELOPE_AEAD = ('python', 'cc', 'go', 'java')

# Languages exercised by the KMS AEAD tests, keyed by KMS service name.
_SUPPORTED_LANGUAGES_FOR_KMS_AEAD = {
    service: ('python', 'cc', 'go', 'java') for service in ('AWS', 'GCP')
}
84
85
def setUpModule():
  """Module fixture: registers Tink AEAD, then starts the 'aead' servers.

  The test runner calls this once, before any test in this module runs.
  """
  aead.register()
  testing_servers.start('aead')
89
90
def tearDownModule():
  """Module fixture: stops the testing servers started by setUpModule."""
  testing_servers.stop()
93
94
95def _get_lang_tuples(langs: List[str]) -> Iterable[Tuple[str, str]]:
96  """Yields language tuples to run cross-language tests.
97
98  Ideally, we would want to the test all possible tuples of languages. But
99  that results in a quadratic number of tuples. It is not really necessary,
100  because if an implementation in one language does something different, then
101  any cross-language test with another language will fail. So it is enough to
102  only use every implementation once for encryption and once for decryption.
103
104  Args:
105    langs: List of language names.
106
107  Yields:
108    Tuples of 2 languages.
109  """
110  for i, _ in enumerate(langs):
111    yield (langs[i], langs[((i + 1) % len(langs))])
112
113
114def _get_plaintext_and_aad(key_template_name: str,
115                           lang: str) -> Tuple[bytes, bytes]:
116  """Creates test plaintext and associated data from a key template and lang."""
117  plaintext = (
118      b'This is some plaintext message to be encrypted using key_template '
119      b'%s using %s for encryption.' %
120      (key_template_name.encode('utf8'), lang.encode('utf8')))
121  associated_data = (b'Some associated data for %s using %s for encryption.' %
122                     (key_template_name.encode('utf8'), lang.encode('utf8')))
123  return (plaintext, associated_data)
124
125
def _kms_aead_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Generates parameters for the KMS AEAD cross-language tests.

  Yields:
    (KMS service, encrypt lang, decrypt lang) tuples. For every KMS service,
    the pairs come from _get_lang_tuples over that service's languages.
  """
  for service in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD:
    lang_pairs = _get_lang_tuples(_SUPPORTED_LANGUAGES_FOR_KMS_AEAD[service])
    yield from ((service, enc, dec) for enc, dec in lang_pairs)
131
132
def _two_key_uris_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (lang, key URI, second key URI), first for AWS and then for GCP."""
  key_uri_pairs_by_service = (
      ('AWS', _AWS_KEY_URI, _AWS_KEY_2_URI),
      ('GCP', _GCP_KEY_URI, _GCP_KEY_2_URI),
  )
  for service, key_uri, key_uri_2 in key_uri_pairs_by_service:
    for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get(service, []):
      yield (lang, key_uri, key_uri_2)
138
139
def _key_uris_with_alias_test_cases() -> Iterable[Tuple[str, str]]:
  """Yields (lang, AWS alias key URI) for every language supporting AWS."""
  aws_langs = _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('AWS', [])
  yield from ((lang, _AWS_KEY_ALIAS_URI) for lang in aws_langs)
143
144
def _two_key_uris_with_alias_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (lang, AWS alias URI, second AWS alias URI) per AWS language."""
  aws_langs = _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('AWS', [])
  yield from (
      (lang, _AWS_KEY_ALIAS_URI, _AWS_KEY_2_ALIAS_URI) for lang in aws_langs)
148
149
def _unknown_key_uris_test_cases() -> Iterable[Tuple[str, str]]:
  """Yields (lang, unknown key URI) pairs for AWS, then for GCP.

  Each AWS language is paired with the unknown key-ID URI and then with the
  unknown alias URI. Each GCP language is paired with the unknown GCP URI.
  """
  aws_unknown_uris = (_AWS_UNKNOWN_KEY_URI, _AWS_UNKNOWN_KEY_ALIAS_URI)
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('AWS', []):
    for unknown_uri in aws_unknown_uris:
      yield (lang, unknown_uri)
  for lang in _SUPPORTED_LANGUAGES_FOR_KMS_AEAD.get('GCP', []):
    yield (lang, _GCP_UNKNOWN_KEY_URI)
156
157
class KmsAeadTest(parameterized.TestCase):
  """Cross-language tests for the KMS AEAD primitive with AWS and GCP keys."""

  def _assert_key_uris_not_interchangeable(self, lang, key_uri, key_uri_2):
    """Asserts that ciphertexts of one KMS key do not decrypt under the other.

    A KMS AEAD keyset is created in `lang` for each key URI. Each primitive
    must decrypt its own ciphertext and raise tink.TinkError on the other's.

    Args:
      lang: Language whose testing server creates and uses both keysets.
      key_uri: KMS key URI of the first keyset.
      key_uri_2: KMS key URI of the second keyset; expected to refer to a
        different KMS key than `key_uri`.
    """
    keyset = testing_servers.new_keyset(
        lang, aead.aead_key_templates.create_kms_aead_key_template(key_uri))
    keyset_2 = testing_servers.new_keyset(
        lang, aead.aead_key_templates.create_kms_aead_key_template(key_uri_2))

    primitive = testing_servers.remote_primitive(
        lang=lang, keyset=keyset, primitive_class=aead.Aead)
    primitive_2 = testing_servers.remote_primitive(
        lang=lang, keyset=keyset_2, primitive_class=aead.Aead)

    plaintext = b'plaintext'
    associated_data = b'associated_data'

    ciphertext = primitive.encrypt(plaintext, associated_data)
    ciphertext_2 = primitive_2.encrypt(plaintext, associated_data)

    # Can be decrypted by the primitive that created the ciphertext.
    self.assertEqual(primitive.decrypt(ciphertext, associated_data), plaintext)
    self.assertEqual(
        primitive_2.decrypt(ciphertext_2, associated_data), plaintext)

    # Cannot be decrypted by the other primitive.
    with self.assertRaises(tink.TinkError):
      primitive.decrypt(ciphertext_2, associated_data)
    with self.assertRaises(tink.TinkError):
      primitive_2.decrypt(ciphertext, associated_data)

  def test_get_lang_tuples(self):
    self.assertEqual(
        list(_get_lang_tuples(['cc', 'java', 'go', 'python'])),
        [('cc', 'java'), ('java', 'go'), ('go', 'python'), ('python', 'cc')],
    )
    self.assertEqual(list(_get_lang_tuples([])), [])

  @parameterized.parameters(_kms_aead_test_cases())
  def test_encrypt_decrypt_with_associated_data(
      self, kms_service, encrypt_lang, decrypt_lang
  ):
    """Encrypts in encrypt_lang and decrypts in decrypt_lang with one keyset."""
    kms_key_uri = _KMS_KEY_URI[kms_service]
    kms_aead_template_name = '%s_KMS_AEAD' % kms_service
    key_template = aead.aead_key_templates.create_kms_aead_key_template(
        kms_key_uri)
    # The encryption language generates the keyset shared by both languages.
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        lang=encrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext, associated_data = _get_plaintext_and_aad(kms_aead_template_name,
                                                        encrypt_primitive.lang)
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        decrypt_lang, keyset, aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_kms_aead_test_cases())
  def test_encrypt_decrypt_with_empty_associated_data(
      self, kms_service, encrypt_lang, decrypt_lang
  ):
    """Same round trip as above, but with empty associated data."""
    kms_key_uri = _KMS_KEY_URI[kms_service]
    key_template = aead.aead_key_templates.create_kms_aead_key_template(
        kms_key_uri)
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        lang=encrypt_lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext = b'plaintext'
    associated_data = b''
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        decrypt_lang, keyset, aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_two_key_uris_test_cases())
  def test_cannot_decrypt_ciphertext_of_other_key_uri(self, lang, key_uri,
                                                      key_uri_2):
    """Ciphertexts are bound to the KMS key (addressed by key ID/name)."""
    self._assert_key_uris_not_interchangeable(lang, key_uri, key_uri_2)

  @parameterized.parameters(_key_uris_with_alias_test_cases())
  def test_encrypt_decrypt_with_key_aliases(self, lang, alias_key_uri):
    """A KMS AEAD keyset addressed by an alias URI round-trips in one lang."""
    keyset = testing_servers.new_keyset(
        lang,
        aead.aead_key_templates.create_kms_aead_key_template(alias_key_uri))
    primitive = testing_servers.remote_primitive(
        lang=lang, keyset=keyset, primitive_class=aead.Aead)
    plaintext = b'plaintext'
    associated_data = b'associated_data'
    ciphertext = primitive.encrypt(plaintext, associated_data)
    self.assertEqual(
        primitive.decrypt(ciphertext, associated_data), plaintext)

  @parameterized.parameters(_two_key_uris_with_alias_test_cases())
  def test_cannot_decrypt_ciphertext_of_other_alias_key_uri(
      self, lang, alias_key_uri, alias_key_uri_2):
    """Ciphertexts are bound to the KMS key also when addressed by alias."""
    self._assert_key_uris_not_interchangeable(
        lang, alias_key_uri, alias_key_uri_2)

  @parameterized.parameters(_unknown_key_uris_test_cases())
  def test_encrypt_fails_with_unknown_key_uri(self, lang, unknown_key_uri):
    """Encrypting with a keyset for a non-existent KMS key raises TinkError."""
    key_template = aead.aead_key_templates.create_kms_aead_key_template(
        unknown_key_uri)
    keyset = testing_servers.new_keyset(lang, key_template)
    primitive = testing_servers.remote_primitive(
        lang=lang, keyset=keyset, primitive_class=aead.Aead)

    plaintext = b'plaintext'
    associated_data = b'associated_data'

    with self.assertRaises(tink.TinkError):
      primitive.encrypt(plaintext, associated_data)
292
293
def _kms_envelope_aead_test_cases() -> Iterable[Tuple[str, str, str]]:
  """Yields (KMS Envelope AEAD template name, encrypt lang, decrypt lang)."""
  # Only test languages that support the primitive used for the DEK. The
  # language pairs do not depend on the template, so compute them once.
  lang_pairs = list(
      _get_lang_tuples(_SUPPORTED_LANGUAGES_FOR_KMS_ENVELOPE_AEAD))
  for key_template_name in _KMS_ENVELOPE_AEAD_KEY_TEMPLATES:
    for encrypt_lang, decrypt_lang in lang_pairs:
      yield (key_template_name, encrypt_lang, decrypt_lang)
301
302
class KmsEnvelopeAeadTest(parameterized.TestCase):
  """Cross-language tests for the KMS envelope AEAD key templates."""

  def _new_keyset_and_encrypt_primitive(self, key_template_name,
                                        encrypt_lang):
    """Creates a keyset and an encrypting Aead primitive in `encrypt_lang`.

    The encryption language generates the keyset proto. The same keyset is
    then used to create the decrypting primitive in the decryption language.

    Args:
      key_template_name: Key of _KMS_ENVELOPE_AEAD_KEY_TEMPLATES.
      encrypt_lang: Language used to generate the keyset and to encrypt.

    Returns:
      A (keyset, encrypt_primitive) tuple.
    """
    key_template = _KMS_ENVELOPE_AEAD_KEY_TEMPLATES[key_template_name]
    keyset = testing_servers.new_keyset(encrypt_lang, key_template)
    encrypt_primitive = testing_servers.remote_primitive(
        encrypt_lang, keyset, aead.Aead)
    return keyset, encrypt_primitive

  @parameterized.parameters(_kms_envelope_aead_test_cases())
  def test_encrypt_decrypt_with_associated_data(
      self, key_template_name, encrypt_lang, decrypt_lang
  ):
    """Encrypts in encrypt_lang and decrypts in decrypt_lang with AAD."""
    keyset, encrypt_primitive = self._new_keyset_and_encrypt_primitive(
        key_template_name, encrypt_lang)
    plaintext, associated_data = _get_plaintext_and_aad(key_template_name,
                                                        encrypt_primitive.lang)
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)

    # Decrypt.
    decrypt_primitive = testing_servers.remote_primitive(
        decrypt_lang, keyset, aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_kms_envelope_aead_test_cases())
  def test_encrypt_decrypt_with_empty_associated_data(
      self, key_template_name, encrypt_lang, decrypt_lang
  ):
    """Same round trip as above, but with empty associated data."""
    keyset, encrypt_primitive = self._new_keyset_and_encrypt_primitive(
        key_template_name, encrypt_lang)
    plaintext = b'plaintext'
    associated_data = b''
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        decrypt_lang, keyset, aead.Aead)
    output = decrypt_primitive.decrypt(ciphertext, associated_data)
    self.assertEqual(output, plaintext)

  @parameterized.parameters(_kms_envelope_aead_test_cases())
  def test_decryption_fails_with_wrong_aad(self, key_template_name,
                                           encrypt_lang, decrypt_lang):
    """Decrypting with different associated data must raise TinkError."""
    keyset, encrypt_primitive = self._new_keyset_and_encrypt_primitive(
        key_template_name, encrypt_lang)
    plaintext, associated_data = _get_plaintext_and_aad(key_template_name,
                                                        encrypt_primitive.lang)
    ciphertext = encrypt_primitive.encrypt(plaintext, associated_data)
    decrypt_primitive = testing_servers.remote_primitive(
        decrypt_lang, keyset, aead.Aead)
    # `msg` is only reported when no exception is raised, so it must describe
    # that failure, not the expected outcome.
    with self.assertRaises(
        tink.TinkError,
        msg='decryption with wrong associated data unexpectedly succeeded'):
      decrypt_primitive.decrypt(ciphertext, b'wrong aad')


if __name__ == '__main__':
  absltest.main()
359