• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
"""Tests for the elliptic-curve commutative cipher, ec_cipher.EcCipher."""
16
17import unittest
18from private_join_and_compute.py.ciphers import ec_cipher
19from private_join_and_compute.py.crypto_util import supported_curves
20from private_join_and_compute.py.crypto_util import supported_hashes
21
22
class EcCommutativeCipherTest(unittest.TestCase):
  """Tests commutativity and per-party layer removal for ec_cipher.EcCipher."""

  # Curve id passed positionally to ec_cipher.EcCipher in setUp().
  # NOTE(review): 713 matches OpenSSL's NID_secp224r1. If this is the same id
  # space as supported_curves.SupportedCurve.<NAME>.id, prefer the named enum
  # member here -- confirm before switching.
  _DEFAULT_CURVE_ID = 713

  # Fixed plaintext identifiers shared by the tests below.
  _USER_ID_1 = b'3274646578436540569872403985702934875092834502'
  _USER_ID_2 = b'7402039857096829483572943875209348524958235824'

  def setUp(self):
    """Creates client and server ciphers, each with a freshly sampled key."""
    super(EcCommutativeCipherTest, self).setUp()
    self.client_cipher = ec_cipher.EcCipher(self._DEFAULT_CURVE_ID)
    self.server_cipher = ec_cipher.EcCipher(self._DEFAULT_CURVE_ID)

  def ReEncryptionSameId(self, cipher1, cipher2, user_id=_USER_ID_1):
    """Asserts that layering the two ciphers is order-independent.

    Encrypts `user_id` under cipher1 then cipher2, and under cipher2 then
    cipher1, and checks that both orders yield the same ciphertext.

    Args:
      cipher1: First cipher; must provide Encrypt and ReEncrypt.
      cipher2: Second cipher; must provide Encrypt and ReEncrypt.
      user_id: Plaintext identifier (bytes) to encrypt. Defaults to the
        fixed test id used by the original version of this helper.
    """
    enc_by_1 = cipher1.Encrypt(user_id)
    enc_by_2 = cipher2.Encrypt(user_id)
    # Add the other party's layer on top of each single encryption.
    enc_by_1_then_2 = cipher2.ReEncrypt(enc_by_1)
    enc_by_2_then_1 = cipher1.ReEncrypt(enc_by_2)
    self.assertEqual(enc_by_1_then_2, enc_by_2_then_1)

  def testReEncryptionSameId(self):
    """Client and server encryptions of one id commute."""
    self.ReEncryptionSameId(self.client_cipher, self.server_cipher)

  def testReEncryptionDifferentId(self):
    """Doubly-encrypted values of two different ids do not collide."""
    client_enc = self.client_cipher.Encrypt(self._USER_ID_1)
    server_enc = self.server_cipher.Encrypt(self._USER_ID_2)
    client_then_server = self.server_cipher.ReEncrypt(client_enc)
    server_then_client = self.client_cipher.ReEncrypt(server_enc)
    self.assertNotEqual(client_then_server, server_then_client)

  def testDecode(self):
    """Each party's DecryptReEncryptedId strips exactly its own layer."""
    user_id = self._USER_ID_2
    client_enc = self.client_cipher.Encrypt(user_id)
    server_enc = self.server_cipher.Encrypt(user_id)
    client_then_server = self.server_cipher.ReEncrypt(client_enc)
    # Removing the client's layer should leave the server-only encryption,
    # and removing the server's layer should leave the client-only one.
    client_layer_removed = self.client_cipher.DecryptReEncryptedId(
        client_then_server)
    server_layer_removed = self.server_cipher.DecryptReEncryptedId(
        client_then_server)
    self.assertEqual(client_enc, server_layer_removed)
    self.assertEqual(server_enc, client_layer_removed)

  def testDifferentHashFunctions(self):
    """The same private key with different hash types gives different output."""
    # The SHA-256 cipher samples a fresh key; the SHA-512 cipher reuses its
    # private key bytes so that only the hash type differs between the two.
    sha256_cipher = ec_cipher.EcCipher(
        curve_id=supported_curves.SupportedCurve.SECP256R1.id,
        hash_type=supported_hashes.HashType.SHA256,
    )
    sha512_cipher = ec_cipher.EcCipher(
        curve_id=supported_curves.SupportedCurve.SECP256R1.id,
        hash_type=supported_hashes.HashType.SHA512,
        private_key_bytes=sha256_cipher.ec_key.priv_key_bytes,
    )
    user_id = self._USER_ID_2
    sha256_enc = sha256_cipher.Encrypt(user_id)
    sha512_enc = sha512_cipher.Encrypt(user_id)
    self.assertNotEqual(sha256_enc, sha512_enc)
75
76
if __name__ == '__main__':
  # Run this module's TestCase classes when executed as a script.
  unittest.main(module='__main__')
79