• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Google LLC.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16"""Test class for ssl_util module."""
17
18import os
19import unittest
20from unittest import mock
21from unittest.mock import call
22from unittest.mock import patch
23
24from private_join_and_compute.py.crypto_util import converters
25from private_join_and_compute.py.crypto_util import ssl_util
26from private_join_and_compute.py.crypto_util.ssl_util import PRNG
27from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
28
29
class SSLUtilTest(unittest.TestCase):
  """Tests for ssl_util's helpers, TempBNs, BigNum and the singleton caches."""

  def setUp(self):
    # NOTE(review): test_path is not read by any test in this file; it looks
    # like a leftover from data-driven RandomOracle tests -- confirm or remove.
    self.test_path = os.path.join(
        os.getcwd(), 'privacy/blinders/testing/data/random_oracle'
    )

  def _AssertAllFreed(self, mocked_ssl, addrs):
    """Asserts that every BIGNUM in `addrs` was released through BN_free.

    Checks that BN_free was called exactly len(addrs) times in total and that
    each address appears as the argument of at least one of those calls.

    Args:
      mocked_ssl: the patched `ssl` module, wrapping the real one.
      addrs: BIGNUM handles captured inside a TempBNs block.
    """
    self.assertEqual(len(addrs), mocked_ssl.BN_free.call_count)
    for addr in addrs:
      mocked_ssl.BN_free.assert_any_call(addr)

  def testRandomOracleRaisesValueErrorForVeryLargeDomains(self):
    self.assertRaises(ValueError, ssl_util.RandomOracle, 1, 1 << 130048)

  def _GenericRandomTestForCasesThatShouldReturnOneNum(
      self, expected_value, rand_func, *args
  ):
    """Asserts that `rand_func(*args)` keeps returning `expected_value`.

    Used for random generators whose valid output interval contains exactly
    one value.

    Args:
      expected_value: the only value rand_func may legitimately return.
      rand_func: the random generator under test.
      *args: positional arguments forwarded to rand_func on every call.
    """
    # There is at least a 50% chance that one iteration catches the error if
    # rand_func can also return something outside the interval. Repeating the
    # check 20 times raises the overall chance to about 99.9999% in the worst
    # case (i.e., rand_func may return only one other element besides the
    # expected value).
    for _ in range(20):
      actual_value = rand_func(*args)
      self.assertEqual(
          actual_value,
          expected_value,
          'The generated rand is {} but should be {} instead.'.format(
              actual_value, expected_value
          ),
      )

  def testGetRandomInRangeSingleNumber(self):
    self._GenericRandomTestForCasesThatShouldReturnOneNum(
        2**30 - 1, ssl_util.GetRandomInRange, 2**30 - 1, 2**30
    )

  def testGetRandomInRangeMultipleNumbers(self):
    low, high = 11111111111, 11111111111111111111111
    rand = ssl_util.GetRandomInRange(low, high)
    self.assertLessEqual(low, rand)
    self.assertLess(rand, high)

  def testModExp(self):
    self.assertEqual(1, ssl_util.ModExp(3, 4, 80))

  def testModInverse(self):
    self.assertEqual(5, ssl_util.ModInverse(2, 9))

  def testGetRandomInRangeReturnOnlyOneValueWhenIntervalIsOne(self):
    random = ssl_util.GetRandomInRange(99999999999999998, 99999999999999999)
    self.assertEqual(99999999999999998, random)

  def testGetRandomInRangeReturnsAValueInRange(self):
    random = ssl_util.GetRandomInRange(99999999999999998, 100000000000000000000)
    self.assertLessEqual(99999999999999998, random)
    self.assertLess(random, 100000000000000000000)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForValues(self, mocked_ssl):
    with TempBNs(x=10, y=20) as bn:
      self.assertEqual(10, ssl_util.BnToLong(bn.x))
      self.assertEqual(20, ssl_util.BnToLong(bn.y))
      addrs = [bn.x, bn.y]
    self._AssertAllFreed(mocked_ssl, addrs)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForLists(self, mocked_ssl):
    with TempBNs(x=10, y=[20, 30], z=40) as bn:
      self.assertEqual(10, ssl_util.BnToLong(bn.x))
      self.assertEqual(20, ssl_util.BnToLong(bn.y[0]))
      self.assertEqual(30, ssl_util.BnToLong(bn.y[1]))
      self.assertEqual(40, ssl_util.BnToLong(bn.z))
      addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
    self._AssertAllFreed(mocked_ssl, addrs)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForBytes(self, mocked_ssl):
    with TempBNs(x='\001', y=['\002', '\003'], z='\004') as bn:
      self.assertEqual(1, ssl_util.BnToLong(bn.x))
      self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
      self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
      self.assertEqual(4, ssl_util.BnToLong(bn.z))
      addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
    self._AssertAllFreed(mocked_ssl, addrs)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForBytesOrLong(self, mocked_ssl):
    with TempBNs(x=1, y=['\002', 3], z='\004') as bn:
      self.assertEqual(1, ssl_util.BnToLong(bn.x))
      self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
      self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
      self.assertEqual(4, ssl_util.BnToLong(bn.z))
      addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
    self._AssertAllFreed(mocked_ssl, addrs)

  def testTempBNsRaisesAssertionErrorWhenAListIsEmpty(self):
    self.assertRaises(AssertionError, TempBNs, x=10, y=[20, 30], z=[])

  def testTempBNsRaisesAssertionErrorWhenAlreadySetKeyUsed(self):
    self.assertRaises(AssertionError, TempBNs, _args=10)

  def testBigNumInitializes(self):
    big_num = ssl_util.BigNum.FromLongNumber(1)
    self.assertEqual(1, big_num.GetAsLong())

  def testOpenSSLHelperIsSingleton(self):
    helper1 = ssl_util.OpenSSLHelper()
    helper2 = ssl_util.OpenSSLHelper()
    self.assertIs(helper1, helper2)

  def testBigNumGeneratesSafePrime(self):
    big_prime = ssl_util.BigNum.GenerateSafePrime(100)
    # A safe prime p is a prime for which (p - 1) / 2 is also prime. Check the
    # two conditions separately so a failure reports which one broke.
    self.assertTrue(big_prime.IsPrime())
    half = big_prime.SubtractOne() / ssl_util.BigNum.FromLongNumber(2)
    self.assertTrue(half.IsPrime())
    self.assertEqual(100, big_prime.BitLength())

  def testBigNumIsSafePrime(self):
    prime = ssl_util.BigNum.FromLongNumber(23)
    self.assertTrue(prime.IsSafePrime())
    prime = ssl_util.BigNum.FromLongNumber(29)
    self.assertFalse(prime.IsSafePrime())

  def testBigNumGeneratesPrime(self):
    big_prime = ssl_util.BigNum.GeneratePrime(100)
    self.assertTrue(big_prime.IsPrime())
    self.assertEqual(100, big_prime.BitLength())

  def testBigNumGeneratesPrimeForSubGroup(self):
    prime = ssl_util.BigNum.GeneratePrime(50)
    big_prime = prime.GeneratePrimeForSubGroup(100)
    self.assertTrue(big_prime.IsPrime())
    self.assertEqual(ssl_util.BigNum.One(), big_prime % prime)
    self.assertEqual(100, big_prime.BitLength())

  def testBigNumBitLength(self):
    big_num = ssl_util.BigNum.FromLongNumber(15)
    self.assertEqual(4, big_num.BitLength())
    big_num = ssl_util.BigNum.FromLongNumber(16)
    self.assertEqual(5, big_num.BitLength())

  def testBigNumAdds(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 + big_num2
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(5, big_num3.GetAsLong())

  def testBigNumAddsInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 += big_num2
    self.assertEqual(5, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumSubtracts(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(4)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 - big_num2
    self.assertEqual(4, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(1, big_num3.GetAsLong())

  def testBigNumSubtractsInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(4).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 -= big_num2
    self.assertEqual(1, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumOperationsInPlaceRaisesValueErrorOnImmutableBigNums(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    self.assertRaises(ValueError, big_num1.__iadd__, big_num2)

  def testBigNumMultiplies(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 * big_num2
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(6, big_num3.GetAsLong())

  def testBigNumMultipliesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 *= big_num2
    self.assertEqual(6, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumMods(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(5)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 % big_num2
    self.assertEqual(5, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(2, big_num3.GetAsLong())

  def testBigNumModsInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(5).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 %= big_num2
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumExponentiates(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1**big_num2
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(8, big_num3.GetAsLong())

  def testBigNumExponentiatesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 **= big_num2
    self.assertEqual(8, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumRShifts(self):
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num1 = big_num >> 1
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(4, big_num.GetAsLong())

  def testBigNumRShiftsInPlace(self):
    # NOTE(review): unlike the other *InPlace tests, big_num is not made
    # Mutable(). If BigNum has no __irshift__, Python falls back to __rshift__
    # and rebinds the name, so this only checks the resulting value, not that
    # the object was mutated. The same applies to the LShift/Divide variants.
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num >>= 1
    self.assertEqual(2, big_num.GetAsLong())

  def testBigNumLShifts(self):
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num1 = big_num << 1
    self.assertEqual(8, big_num1.GetAsLong())
    self.assertEqual(4, big_num.GetAsLong())

  def testBigNumLShiftsInPlace(self):
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num <<= 1
    self.assertEqual(8, big_num.GetAsLong())

  def testBigNumDivides(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(6)
    big_num2 = ssl_util.BigNum.FromLongNumber(2)
    self.assertEqual(3, (big_num1 / big_num2).GetAsLong())
    self.assertEqual(6, big_num1.GetAsLong())
    self.assertEqual(2, big_num2.GetAsLong())

  def testBigNumDividesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(6)
    big_num2 = ssl_util.BigNum.FromLongNumber(2)
    big_num1 /= big_num2
    self.assertEqual(3, big_num1.GetAsLong())
    self.assertEqual(2, big_num2.GetAsLong())

  def testBigNumDivisionByZeroRaisesAssertionError(self):
    # NOTE(review): `__div__` is only the `/` hook in Python 2; in Python 3
    # `/` goes through `__truediv__`. These two tests call `__div__` directly,
    # so they assume BigNum defines it and keeps it in sync with `/` -- confirm.
    big_num1 = ssl_util.BigNum.FromLongNumber(6)
    big_num2 = ssl_util.BigNum.FromLongNumber(0)
    self.assertRaises(AssertionError, big_num1.__div__, big_num2)

  def testBigNumDivisionRaisesValueErrorWhenThereIsARemainder(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(5)
    big_num2 = ssl_util.BigNum.FromLongNumber(2)
    self.assertRaises(ValueError, big_num1.__div__, big_num2)

  def testBigNumModMultiplies(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(5)
    big_num3 = big_num1.ModMul(big_num2, mod_big_num)
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(5, mod_big_num.GetAsLong())
    self.assertEqual(1, big_num3.GetAsLong())

  def testBigNumModMultipliesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(5)
    big_num1.IModMul(big_num2, mod_big_num)
    self.assertEqual(1, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(5, mod_big_num.GetAsLong())

  def testBigNumModExponentiates(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(7)
    big_num3 = big_num1.ModExp(big_num2, mod_big_num)
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(7, mod_big_num.GetAsLong())
    self.assertEqual(1, big_num3.GetAsLong())

  def testBigNumModExponentiatesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(7)
    big_num1.IModExp(big_num2, mod_big_num)
    self.assertEqual(1, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(7, mod_big_num.GetAsLong())

  def testBigNumGCD(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(20)
    big_num3 = ssl_util.BigNum.FromLongNumber(15)
    big_num4 = big_num2.GCD(big_num1)
    big_num5 = big_num2.GCD(big_num3)
    self.assertEqual(11, big_num1.GetAsLong())
    self.assertEqual(20, big_num2.GetAsLong())
    self.assertEqual(15, big_num3.GetAsLong())
    self.assertEqual(1, big_num4.GetAsLong())
    self.assertEqual(5, big_num5.GetAsLong())

  def testBigNumModInverse(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num_mod = ssl_util.BigNum.FromLongNumber(20)
    big_num_result = big_num1.ModInverse(big_num_mod)
    self.assertEqual(11, big_num1.GetAsLong())
    self.assertEqual(20, big_num_mod.GetAsLong())
    self.assertEqual(11, big_num_result.GetAsLong())

  def testBigNumModSqrt(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num_mod = ssl_util.BigNum.FromLongNumber(19)
    big_num_result = big_num1.ModSqrt(big_num_mod)
    self.assertEqual(11, big_num1.GetAsLong())
    self.assertEqual(19, big_num_mod.GetAsLong())
    self.assertEqual(7, big_num_result.GetAsLong())

  def testBigNumModInverseInvalidForNotRelativelyPrimes(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(10)
    big_num_mod = ssl_util.BigNum.FromLongNumber(20)
    self.assertRaises(ValueError, big_num1.ModInverse, big_num_mod)
    self.assertEqual(10, big_num1.GetAsLong())
    self.assertEqual(20, big_num_mod.GetAsLong())

  def testBigNumNegates(self):
    big_num = ssl_util.BigNum.FromLongNumber(10)
    big_num = big_num.ModNegate(ssl_util.BigNum.FromLongNumber(6))
    self.assertEqual(2, big_num.GetAsLong())

  def testBigNumAddsOne(self):
    big_num = ssl_util.BigNum.FromLongNumber(10)
    self.assertEqual(11, big_num.AddOne().GetAsLong())

  def testBigNumSubtractOne(self):
    big_num = ssl_util.BigNum.FromLongNumber(10)
    self.assertEqual(9, big_num.SubtractOne().GetAsLong())

  def testBigNumGeneratesRandsBetweenZeroAndGivenBigNum(self):
    big_num = ssl_util.BigNum.FromLongNumber(3)
    rand_value = big_num.GenerateRand().GetAsLong()
    self.assertLessEqual(0, rand_value)
    self.assertLess(rand_value, 3)

  def testBigNumGeneratesZeroForRandWhenTheUpperBoundIsOne(self):
    big_num = ssl_util.BigNum.FromLongNumber(1)
    self._GenericRandomTestForCasesThatShouldReturnOneNum(
        ssl_util.BigNum.Zero(), big_num.GenerateRand
    )

  def testBigNumGeneratesRandsBetweenStartAndGivenBigNum(self):
    big_num = ssl_util.BigNum.FromLongNumber(3)
    big_rand = big_num.GenerateRandWithStart(ssl_util.BigNum.FromLongNumber(1))
    rand_value = big_rand.GetAsLong()
    self.assertLessEqual(1, rand_value)
    self.assertLess(rand_value, 3)

  def testBigNumGeneratesSingleRandWhenIntervalIsOne(self):
    start = ssl_util.BigNum.FromLongNumber(2**30 - 1)
    end = ssl_util.BigNum.FromLongNumber(2**30)
    self._GenericRandomTestForCasesThatShouldReturnOneNum(
        start, end.GenerateRandWithStart, start
    )

  def testBigNumIsBitSet(self):
    big_num = ssl_util.BigNum.FromLongNumber(11)
    self.assertTrue(big_num.IsBitSet(0))
    self.assertTrue(big_num.IsBitSet(1))
    self.assertFalse(big_num.IsBitSet(2))
    self.assertTrue(big_num.IsBitSet(3))

  def testBigNumEq(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(11)
    self.assertEqual(big_num1, big_num2)

  def testBigNumNeq(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(12)
    self.assertNotEqual(big_num1, big_num2)

  def testBigNumGt(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(12)
    self.assertGreater(big_num2, big_num1)

  def testBigNumGtEq(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(11)
    big_num3 = ssl_util.BigNum.FromLongNumber(12)
    self.assertGreaterEqual(big_num2, big_num1)
    self.assertGreaterEqual(big_num3, big_num2)

  def testBigNumComparisonWithOtherTypesRaisesValueError(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    self.assertRaises(ValueError, big_num1.__lt__, 11)

  def testClonesCreatesANewBigNum(self):
    big_num = ssl_util.BigNum.FromLongNumber(0).Mutable()
    clone_big_num = big_num.Clone()
    big_num += ssl_util.BigNum.One()
    self.assertEqual(ssl_util.BigNum.Zero(), clone_big_num)
    self.assertEqual(ssl_util.BigNum.One(), big_num)

  def testBigNumCacheIsSingleton(self):
    cache1 = ssl_util.BigNumCache(10)
    cache2 = ssl_util.BigNumCache(11)
    self.assertIs(cache1, cache2)

  def testBigNumCacheReturnsTheSameCachedBigNum(self):
    cache = ssl_util.BigNumCache(10)
    self.assertIs(cache.Get(1), cache.Get(1))

  def testBigNumCacheReturnsDifferentBigNumWhenCacheIsFull(self):
    cache = ssl_util.BigNumCache(10)
    for i in range(10):
      cache.Get(i)
    self.assertIsNot(cache.Get(11), cache.Get(11))

  def testStringRepresentation(self):
    big_num = ssl_util.BigNum.FromLongNumber(11)
    self.assertEqual('11', '{}'.format(big_num))
480
481class _HashMock(object):
482
483  def __init__(self):
484    self.with_patch = patch('hashlib.sha512')
485
486  def __enter__(self):
487    hashlib_mock = self.with_patch.__enter__()
488    sha512_mock = mock.MagicMock()
489    hashlib_mock.return_value = sha512_mock
490    return sha512_mock, hashlib_mock
491
492  def __exit__(self, t, value, traceback):
493    self.with_patch.__exit__(t, value, traceback)
494
495
class PRNGTest(unittest.TestCase):
  """Tests for PRNG; sha512 is stubbed via _HashMock where outputs are pinned."""

  def testPRNG(self):
    with _HashMock() as (sha512_obj, sha512_ctor):
      sha512_obj.digest.return_value = b'\x7f' + b'\x01' * 64
      prng = PRNG(b'\x02' * 32)
      # (upper limit, expected result) pairs, checked in call order because
      # successive GetRand calls share the PRNG's state.
      expectations = (
          (2, 0),
          (256, 1),
          (257, 2),
          (32768, 128),
          (65536, 257),
      )
      for upper_limit, expected in expectations:
        self.assertEqual(expected, prng.GetRand(upper_limit))
      sha512_obj.digest.assert_called_once_with()
      sha512_ctor.assert_called_once_with(b'\x00' * 4 + b'\x02' * 32)

  def testGetNBitRandReturnsAtLeastUpperLimit(self):
    # NOTE(review): the digest starts with 0x81, 0x82 and 0xff, all >= the
    # limit 129, and the expected result is the first byte below it (0x05).
    # This presumably checks that candidates >= the upper limit are skipped.
    with _HashMock() as (sha512_obj, sha512_ctor):
      sha512_obj.digest.return_value = b'\x81\x82\xff\x05' + b'\x00' * 60
      prng = PRNG(b'\x00' * 32)
      self.assertEqual(5, prng.GetRand(129))
      sha512_obj.digest.assert_called_once_with()
      sha512_ctor.assert_called_once_with(b'\x00' * 4 + b'\x00' * 32)

  def testRaisesValueErrorWhenSeedIsNotAtLeastFourBytes(self):
    # NOTE(review): the name says "four bytes", but the rejected seed is 31
    # bytes and every accepted seed in this class is 32 bytes -- consider
    # renaming once the real minimum is confirmed.
    with self.assertRaises(ValueError):
      PRNG(b'\x00' * 31)

  def testRaisesValueErrorWhenMaxNumberOfHashingIsDone(self):
    prng = PRNG(b'\x00' * 32, 1)
    big_limit = 1 << 512
    for _ in range(256):
      prng.GetRand(big_limit)
    # NOTE(review): despite the test name, exhaustion is expected to surface
    # as AssertionError, not ValueError.
    with self.assertRaises(AssertionError):
      prng.GetRand(2)
    # A limit of 1 must still succeed (returning 0) after exhaustion.
    self.assertEqual(0, prng.GetRand(1))

  def testGetsMoreBytesWithHashingUntilSufficientBytesArePresent(self):
    # 1 << 1025 needs more bits than one 64-byte digest supplies, so the PRNG
    # must hash repeatedly; exactly three scripted digests are provided.
    with _HashMock() as (sha512_obj, _):
      sha512_obj.digest.side_effect = [
          b'\x80' + b'\x00' * 63,
          b'\x00' * 64,
          b'\x00' * 64,
      ]
      prng = PRNG(b'\x00' * 32, 1)
      self.assertEqual(1 << 1024, prng.GetRand(1 << 1025))
      sha512_obj.digest.assert_has_calls([call()] * 3)
540
541
if __name__ == '__main__':
  # Run every TestCase defined in this module through unittest's CLI runner.
  unittest.main()
544