# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions and fake primitives for testing."""

import os
from typing import Mapping

from tink.proto import tink_pb2
from tink import aead
from tink import core
from tink import daead
from tink import hybrid
from tink import mac
from tink import prf
from tink import signature as pk_signature

_RELATIVE_TESTDATA_PATH = 'tink_py/testdata'


def tink_py_testdata_path() -> str:
  """Returns the path to the test data directory to be used for testing.

  Candidates are built from the environment variables TINK_PYTHON_ROOT_PATH
  (joined with 'testdata') and TEST_SRCDIR (joined with 'tink_py/testdata'),
  in that order. The candidate of the first variable that is set is returned.

  Raises:
    FileNotFoundError: if the first set variable yields a path that does not
      exist. Later candidates are not tried in that case.
    ValueError: if neither environment variable is set.
  """
  # List of (environment variable name, testdata path) pairs, in priority
  # order.
  testdata_paths = []
  if 'TINK_PYTHON_ROOT_PATH' in os.environ:
    testdata_paths.append(('TINK_PYTHON_ROOT_PATH',
                           os.path.join(os.environ['TINK_PYTHON_ROOT_PATH'],
                                        'testdata')))
  if 'TEST_SRCDIR' in os.environ:
    testdata_paths.append(('TEST_SRCDIR',
                           os.path.join(os.environ['TEST_SRCDIR'],
                                        _RELATIVE_TESTDATA_PATH)))
  for env_variable, testdata_path in testdata_paths:
    # Return the first path that is encountered.
    if not os.path.exists(testdata_path):
      raise FileNotFoundError(f'Variable {env_variable} is set but has an ' +
                              f'invalid path {testdata_path}')
    return testdata_path
  raise ValueError('No path environment variable set among ' +
                   'TINK_PYTHON_ROOT_PATH, TEST_SRCDIR')


def fake_key(
    value: bytes = b'fakevalue',
    type_url: str = 'fakeurl',
    key_material_type: tink_pb2.KeyData.KeyMaterialType = (
        tink_pb2.KeyData.SYMMETRIC),
    key_id: int = 1234,
    status: tink_pb2.KeyStatusType = tink_pb2.ENABLED,
    output_prefix_type: tink_pb2.OutputPrefixType = tink_pb2.TINK
) -> tink_pb2.Keyset.Key:
  """Returns a fake but valid key.

  Args:
    value: bytes stored in key_data.value.
    type_url: string stored in key_data.type_url.
    key_material_type: stored in key_data.key_material_type.
    key_id: the key's id.
    status: the key's status.
    output_prefix_type: the key's output prefix type.
  """
  key = tink_pb2.Keyset.Key(
      key_id=key_id, status=status, output_prefix_type=output_prefix_type)
  key.key_data.type_url = type_url
  key.key_data.value = value
  key.key_data.key_material_type = key_material_type
  return key


def _bytes_for_message(data: bytes) -> str:
  """Renders data for an error message without raising on non-UTF-8 bytes.

  Valid UTF-8 is decoded unchanged. Invalid bytes become backslash escapes
  instead of raising UnicodeDecodeError. This way the fakes below always
  report failures as core.TinkError, even for arbitrary binary input.
  """
  return data.decode('utf-8', 'backslashreplace')


def _fake_ciphertext(plaintext: bytes, associated_data: bytes,
                     name: str) -> bytes:
  """Returns plaintext + b'|' + associated_data + b'|' + name (UTF-8)."""
  return plaintext + b'|' + associated_data + b'|' + name.encode()


def _fake_plaintext(ciphertext: bytes, associated_data: bytes,
                    name: str) -> bytes:
  """Inverts _fake_ciphertext for the given associated_data and name.

  The ciphertext is matched by its exact suffix instead of being split on
  b'|'. Plaintexts and associated data that contain b'|' therefore
  round-trip correctly, and ciphertexts with trailing extra fields are
  rejected.

  Raises:
    core.TinkError: if ciphertext does not end with
      b'|' + associated_data + b'|' + name.
  """
  suffix = b'|' + associated_data + b'|' + name.encode()
  if not ciphertext.endswith(suffix):
    raise core.TinkError('failed to decrypt ciphertext ' +
                         _bytes_for_message(ciphertext))
  # suffix is never empty (it holds two b'|' bytes), so this slice is safe.
  return ciphertext[:-len(suffix)]


class FakeMac(mac.Mac):
  """A fake MAC implementation: the tag is data + b'|' + name."""

  def __init__(self, name: str = 'FakeMac'):
    self._name = name

  def compute_mac(self, data: bytes) -> bytes:
    return data + b'|' + self._name.encode()

  def verify_mac(self, mac_value: bytes, data: bytes) -> None:
    if mac_value != data + b'|' + self._name.encode():
      raise core.TinkError('invalid mac ' + _bytes_for_message(mac_value))


class FakeAead(aead.Aead):
  """A fake AEAD implementation.

  Ciphertexts are plaintext + b'|' + associated_data + b'|' + name.
  """

  def __init__(self, name: str = 'FakeAead'):
    self._name = name

  def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    return _fake_ciphertext(plaintext, associated_data, self._name)

  def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    return _fake_plaintext(ciphertext, associated_data, self._name)


class FakeDeterministicAead(daead.DeterministicAead):
  """A fake Deterministic AEAD implementation.

  Ciphertexts are plaintext + b'|' + associated_data + b'|' + name.
  """

  def __init__(self, name: str = 'FakeDeterministicAead'):
    self._name = name

  def encrypt_deterministically(self, plaintext: bytes,
                                associated_data: bytes) -> bytes:
    return _fake_ciphertext(plaintext, associated_data, self._name)

  def decrypt_deterministically(self, ciphertext: bytes,
                                associated_data: bytes) -> bytes:
    return _fake_plaintext(ciphertext, associated_data, self._name)


class FakeHybridDecrypt(hybrid.HybridDecrypt):
  """A fake HybridDecrypt implementation.

  Accepts ciphertexts of the form plaintext + b'|' + context_info + b'|' +
  name, as produced by FakeHybridEncrypt with the same name.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
    return _fake_plaintext(ciphertext, context_info, self._name)


class FakeHybridEncrypt(hybrid.HybridEncrypt):
  """A fake HybridEncrypt implementation.

  Ciphertexts are plaintext + b'|' + context_info + b'|' + name.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
    return _fake_ciphertext(plaintext, context_info, self._name)


class FakePublicKeySign(pk_signature.PublicKeySign):
  """A fake PublicKeySign implementation: signatures are data + b'|' + name."""

  def __init__(self, name: str = 'FakePublicKeySign'):
    self._name = name

  def sign(self, data: bytes) -> bytes:
    return data + b'|' + self._name.encode()


class FakePublicKeyVerify(pk_signature.PublicKeyVerify):
  """A fake PublicKeyVerify implementation.

  Accepts exactly the signatures data + b'|' + name.
  """

  def __init__(self, name: str = 'FakePublicKeyVerify'):
    self._name = name

  def verify(self, signature: bytes, data: bytes) -> None:
    if signature != data + b'|' + self._name.encode():
      raise core.TinkError('invalid signature ' +
                           _bytes_for_message(signature))


class FakePrf(prf.Prf):
  """A fake Prf implementation.

  The output is input_data + b'|' + name + b'|', padded with b'*' and then
  truncated to exactly output_length bytes.
  """

  def __init__(self, name: str = 'FakePrf'):
    self._name = name

  def compute(self, input_data: bytes, output_length: int) -> bytes:
    if output_length < 0 or output_length > 32:
      raise core.TinkError('invalid output_length')
    # Padding with output_length stars guarantees the buffer is long enough
    # to be truncated to exactly output_length bytes.
    output = (input_data + b'|' + self._name.encode() + b'|' +
              b'*' * output_length)
    return output[:output_length]


class FakePrfSet(prf.PrfSet):
  """A fake PrfSet implementation that contains exactly one Prf (id 0)."""

  def __init__(self, name: str = 'FakePrf'):
    self._prf = FakePrf(name)

  def primary_id(self) -> int:
    return 0

  def all(self) -> Mapping[int, prf.Prf]:
    return {0: self._prf}

  def primary(self) -> prf.Prf:
    return self._prf