# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ec_cipher.EcCipher, a commutative elliptic-curve cipher."""

import unittest
from private_join_and_compute.py.ciphers import ec_cipher
from private_join_and_compute.py.crypto_util import supported_curves
from private_join_and_compute.py.crypto_util import supported_hashes


class EcCommutativeCipherTest(unittest.TestCase):
  """Checks that layers added by two EcCipher instances commute."""

  def setUp(self):
    super(EcCommutativeCipherTest, self).setUp()
    # One cipher per protocol party, both on curve id 713 (presumably
    # OpenSSL's NID_secp224r1 -- TODO confirm against supported_curves).
    # Neither is given key bytes, so each gets its own key.
    self.client_cipher = ec_cipher.EcCipher(713)
    self.server_cipher = ec_cipher.EcCipher(713)

  def ReEncryptionSameId(self, cipher1, cipher2):
    """Asserts that double encryption of one id is order-independent.

    The same plaintext is encrypted by cipher1 and then re-encrypted by
    cipher2, and separately encrypted by cipher2 and then re-encrypted by
    cipher1; both paths must yield the same ciphertext.

    Args:
      cipher1: an EcCipher instance.
      cipher2: another EcCipher instance on the same curve.
    """
    plaintext = b'3274646578436540569872403985702934875092834502'
    first_then_second = cipher2.ReEncrypt(cipher1.Encrypt(plaintext))
    second_then_first = cipher1.ReEncrypt(cipher2.Encrypt(plaintext))
    self.assertEqual(first_then_second, second_then_first)

  def testReEncryptionSameId(self):
    self.ReEncryptionSameId(self.client_cipher, self.server_cipher)

  def testReEncryptionDifferentId(self):
    # Distinct ids must not collide after both parties' layers are applied.
    client_id = b'3274646578436540569872403985702934875092834502'
    server_id = b'7402039857096829483572943875209348524958235824'
    client_enc = self.client_cipher.Encrypt(client_id)
    server_enc = self.server_cipher.Encrypt(server_id)
    self.assertNotEqual(
        self.server_cipher.ReEncrypt(client_enc),
        self.client_cipher.ReEncrypt(server_enc),
    )

  def testDecode(self):
    user_id = b'7402039857096829483572943875209348524958235824'
    client_layer = self.client_cipher.Encrypt(user_id)
    server_layer = self.server_cipher.Encrypt(user_id)
    # Client layer first, server layer on top.
    double_layer = self.server_cipher.ReEncrypt(client_layer)
    # Removing one party's layer must leave exactly the other party's
    # single-layer ciphertext.
    from_client = self.client_cipher.DecryptReEncryptedId(double_layer)
    from_server = self.server_cipher.DecryptReEncryptedId(double_layer)
    self.assertEqual(client_layer, from_server)
    self.assertEqual(server_layer, from_client)

  def testDifferentHashFunctions(self):
    curve_id = supported_curves.SupportedCurve.SECP256R1.id
    # The SHA-256 cipher is built without key bytes (freshly sampled key);
    # the SHA-512 cipher reuses that key, so the two differ only in the
    # hash function applied to the id.
    sha256_cipher = ec_cipher.EcCipher(
        curve_id=curve_id,
        hash_type=supported_hashes.HashType.SHA256,
    )
    sha512_cipher = ec_cipher.EcCipher(
        curve_id=curve_id,
        hash_type=supported_hashes.HashType.SHA512,
        private_key_bytes=sha256_cipher.ec_key.priv_key_bytes,
    )
    user_id = b'7402039857096829483572943875209348524958235824'
    self.assertNotEqual(
        sha256_cipher.Encrypt(user_id),
        sha512_cipher.Encrypt(user_id),
    )


if __name__ == '__main__':
  unittest.main()