# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Test class for ssl_util module."""

import os
import unittest
from unittest import mock
from unittest.mock import call
from unittest.mock import patch

from private_join_and_compute.py.crypto_util import converters
from private_join_and_compute.py.crypto_util import ssl_util
from private_join_and_compute.py.crypto_util.ssl_util import PRNG
from private_join_and_compute.py.crypto_util.ssl_util import TempBNs


class SSLUtilTest(unittest.TestCase):
  """Tests for the helpers, TempBNs, BigNum and BigNumCache in ssl_util."""

  def setUp(self):
    # NOTE(review): no test in this module reads test_path; it is kept so that
    # the fixture stays unchanged for anything that may rely on it.
    self.test_path = os.path.join(
        os.getcwd(), 'privacy/blinders/testing/data/random_oracle'
    )

  def testRandomOracleRaisesValueErrorForVeryLargeDomains(self):
    self.assertRaises(ValueError, ssl_util.RandomOracle, 1, 1 << 130048)

  def _GenericRandomTestForCasesThatShouldReturnOneNum(
      self, expected_value, rand_func, *args
  ):
    """Asserts that rand_func(*args) returns expected_value on every call.

    Args:
      expected_value: the only value rand_func is allowed to return.
      rand_func: the random generator under test.
      *args: positional arguments forwarded to rand_func.
    """
    # If rand_func can also return a value other than expected_value, a single
    # call has at least a 50% chance of exposing it. Repeating the check 20
    # times raises the detection probability to about 99.9999% in the worst
    # case (i.e., rand_func can return exactly one other value besides the
    # expected one).
    for _ in range(20):
      actual_value = rand_func(*args)
      self.assertEqual(
          actual_value,
          expected_value,
          'The generated rand is {} but should be {} instead.'.format(
              actual_value, expected_value
          ),
      )

  def _AssertBNsFreed(self, mocked_ssl, addrs):
    """Asserts BN_free was called len(addrs) times and once for each address.

    Args:
      mocked_ssl: the mock wrapping ssl_util.ssl that recorded the calls.
      addrs: the BIGNUM handles that were captured inside the TempBNs context.
    """
    self.assertEqual(len(addrs), mocked_ssl.BN_free.call_count)
    for addr in addrs:
      mocked_ssl.BN_free.assert_any_call(addr)

  def testGetRandomInRangeSingleNumber(self):
    self._GenericRandomTestForCasesThatShouldReturnOneNum(
        2**30 - 1, ssl_util.GetRandomInRange, 2**30 - 1, 2**30
    )

  def testGetRandomInRangeMultipleNumbers(self):
    rand = ssl_util.GetRandomInRange(11111111111, 11111111111111111111111)
    self.assertLessEqual(11111111111, rand)
    self.assertLess(rand, 11111111111111111111111)

  def testModExp(self):
    # 3**4 == 81, and 81 % 80 == 1.
    self.assertEqual(1, ssl_util.ModExp(3, 4, 80))

  def testModInverse(self):
    # 2 * 5 == 10, and 10 % 9 == 1.
    self.assertEqual(5, ssl_util.ModInverse(2, 9))

  def testGetRandomInRangeReturnOnlyOneValueWhenIntervalIsOne(self):
    rand = ssl_util.GetRandomInRange(99999999999999998, 99999999999999999)
    self.assertEqual(99999999999999998, rand)

  def testGetRandomInRangeReturnsAValueInRange(self):
    rand = ssl_util.GetRandomInRange(99999999999999998, 100000000000000000000)
    self.assertLessEqual(99999999999999998, rand)
    self.assertLess(rand, 100000000000000000000)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForValues(self, mocked_ssl):
    with TempBNs(x=10, y=20) as bn:
      self.assertEqual(10, ssl_util.BnToLong(bn.x))
      self.assertEqual(20, ssl_util.BnToLong(bn.y))
      # Capture the handles inside the context; BN_free is checked after exit.
      addrs = [bn.x, bn.y]
    self._AssertBNsFreed(mocked_ssl, addrs)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForLists(self, mocked_ssl):
    with TempBNs(x=10, y=[20, 30], z=40) as bn:
      self.assertEqual(10, ssl_util.BnToLong(bn.x))
      self.assertEqual(20, ssl_util.BnToLong(bn.y[0]))
      self.assertEqual(30, ssl_util.BnToLong(bn.y[1]))
      self.assertEqual(40, ssl_util.BnToLong(bn.z))
      addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
    self._AssertBNsFreed(mocked_ssl, addrs)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForBytes(self, mocked_ssl):
    with TempBNs(x='\001', y=['\002', '\003'], z='\004') as bn:
      self.assertEqual(1, ssl_util.BnToLong(bn.x))
      self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
      self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
      self.assertEqual(4, ssl_util.BnToLong(bn.z))
      addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
    self._AssertBNsFreed(mocked_ssl, addrs)

  @patch(
      'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
  )
  def testTempBNsForBytesOrLong(self, mocked_ssl):
    with TempBNs(x=1, y=['\002', 3], z='\004') as bn:
      self.assertEqual(1, ssl_util.BnToLong(bn.x))
      self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
      self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
      self.assertEqual(4, ssl_util.BnToLong(bn.z))
      addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
    self._AssertBNsFreed(mocked_ssl, addrs)

  def testTempBNsRaisesAssertionErrorWhenAListIsEmpty(self):
    self.assertRaises(AssertionError, TempBNs, x=10, y=[20, 30], z=[])

  def testTempBNsRaisesAssertionErrorWhenAlreadySetKeyUsed(self):
    self.assertRaises(AssertionError, TempBNs, _args=10)

  def testBigNumInitializes(self):
    big_num = ssl_util.BigNum.FromLongNumber(1)
    self.assertEqual(1, big_num.GetAsLong())

  def testOpenSSLHelperIsSingleton(self):
    helper1 = ssl_util.OpenSSLHelper()
    helper2 = ssl_util.OpenSSLHelper()
    self.assertIs(helper1, helper2)

  def testBigNumGeneratesSafePrime(self):
    big_prime = ssl_util.BigNum.GenerateSafePrime(100)
    # Checked separately so a failure shows which of the two conditions broke:
    # the number itself must be prime, and so must (p - 1) / 2.
    self.assertTrue(big_prime.IsPrime())
    half = big_prime.SubtractOne() / ssl_util.BigNum.FromLongNumber(2)
    self.assertTrue(half.IsPrime())
    self.assertEqual(100, big_prime.BitLength())

  def testBigNumIsSafePrime(self):
    # 23 == 2 * 11 + 1 with 11 prime; 29 == 2 * 14 + 1 with 14 composite.
    prime = ssl_util.BigNum.FromLongNumber(23)
    self.assertTrue(prime.IsSafePrime())
    prime = ssl_util.BigNum.FromLongNumber(29)
    self.assertFalse(prime.IsSafePrime())

  def testBigNumGeneratesPrime(self):
    big_prime = ssl_util.BigNum.GeneratePrime(100)
    self.assertTrue(big_prime.IsPrime())
    self.assertEqual(100, big_prime.BitLength())

  def testBigNumGeneratesPrimeForSubGroup(self):
    prime = ssl_util.BigNum.GeneratePrime(50)
    big_prime = prime.GeneratePrimeForSubGroup(100)
    self.assertTrue(big_prime.IsPrime())
    # The generated prime must be congruent to 1 modulo the subgroup prime.
    self.assertEqual(ssl_util.BigNum.One(), big_prime % prime)
    self.assertEqual(100, big_prime.BitLength())

  def testBigNumBitLength(self):
    big_prime = ssl_util.BigNum.FromLongNumber(15)
    self.assertEqual(4, big_prime.BitLength())
    big_prime = ssl_util.BigNum.FromLongNumber(16)
    self.assertEqual(5, big_prime.BitLength())

  def testBigNumAdds(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 + big_num2
    # The operands must be left untouched by the non-in-place operator.
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(5, big_num3.GetAsLong())

  def testBigNumAddsInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 += big_num2
    self.assertEqual(5, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumSubtracts(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(4)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 - big_num2
    self.assertEqual(4, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(1, big_num3.GetAsLong())

  def testBigNumSubtractsInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(4).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 -= big_num2
    self.assertEqual(1, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumOperationsInPlaceRaisesValueErrorOnImmutableBigNums(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    self.assertRaises(ValueError, big_num1.__iadd__, big_num2)

  def testBigNumMultiplies(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 * big_num2
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(6, big_num3.GetAsLong())

  def testBigNumMultipliesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 *= big_num2
    self.assertEqual(6, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumMods(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(5)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1 % big_num2
    self.assertEqual(5, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(2, big_num3.GetAsLong())

  def testBigNumModsInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(5).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 %= big_num2
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumExponentiates(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num3 = big_num1**big_num2
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(8, big_num3.GetAsLong())

  def testBigNumExponentiatesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    big_num1 **= big_num2
    self.assertEqual(8, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())

  def testBigNumRShifts(self):
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num1 = big_num >> 1
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(4, big_num.GetAsLong())

  def testBigNumRShiftsInPlace(self):
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num >>= 1
    self.assertEqual(2, big_num.GetAsLong())

  def testBigNumLShifts(self):
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num1 = big_num << 1
    self.assertEqual(8, big_num1.GetAsLong())
    self.assertEqual(4, big_num.GetAsLong())

  def testBigNumLShiftsInPlace(self):
    big_num = ssl_util.BigNum.FromLongNumber(4)
    big_num <<= 1
    self.assertEqual(8, big_num.GetAsLong())

  def testBigNumDivides(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(6)
    big_num2 = ssl_util.BigNum.FromLongNumber(2)
    self.assertEqual(3, (big_num1 / big_num2).GetAsLong())
    self.assertEqual(6, big_num1.GetAsLong())
    self.assertEqual(2, big_num2.GetAsLong())

  def testBigNumDividesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(6)
    big_num2 = ssl_util.BigNum.FromLongNumber(2)
    big_num1 /= big_num2
    self.assertEqual(3, big_num1.GetAsLong())
    self.assertEqual(2, big_num2.GetAsLong())

  def testBigNumDivisionByZeroRaisesAssertionError(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(6)
    big_num2 = ssl_util.BigNum.FromLongNumber(0)
    # Use the `/` operator instead of looking up `__div__`: Python 3 routes
    # division through `__truediv__`, so this exercises what callers hit.
    with self.assertRaises(AssertionError):
      _ = big_num1 / big_num2

  def testBigNumDivisionRaisesValueErrorWhenThereIsARemainder(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(5)
    big_num2 = ssl_util.BigNum.FromLongNumber(2)
    # See testBigNumDivisionByZeroRaisesAssertionError for why `/` is used.
    with self.assertRaises(ValueError):
      _ = big_num1 / big_num2

  def testBigNumModMultiplies(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(5)
    big_num3 = big_num1.ModMul(big_num2, mod_big_num)
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(5, mod_big_num.GetAsLong())
    # 2 * 3 == 6, and 6 % 5 == 1.
    self.assertEqual(1, big_num3.GetAsLong())

  def testBigNumModMultipliesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(5)
    big_num1.IModMul(big_num2, mod_big_num)
    self.assertEqual(1, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(5, mod_big_num.GetAsLong())

  def testBigNumModExponentiates(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2)
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(7)
    big_num3 = big_num1.ModExp(big_num2, mod_big_num)
    self.assertEqual(2, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(7, mod_big_num.GetAsLong())
    # 2**3 == 8, and 8 % 7 == 1.
    self.assertEqual(1, big_num3.GetAsLong())

  def testBigNumModExponentiatesInPlace(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
    big_num2 = ssl_util.BigNum.FromLongNumber(3)
    mod_big_num = ssl_util.BigNum.FromLongNumber(7)
    big_num1.IModExp(big_num2, mod_big_num)
    self.assertEqual(1, big_num1.GetAsLong())
    self.assertEqual(3, big_num2.GetAsLong())
    self.assertEqual(7, mod_big_num.GetAsLong())

  def testBigNumGCD(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(20)
    big_num3 = ssl_util.BigNum.FromLongNumber(15)
    big_num4 = big_num2.GCD(big_num1)
    big_num5 = big_num2.GCD(big_num3)
    self.assertEqual(11, big_num1.GetAsLong())
    self.assertEqual(20, big_num2.GetAsLong())
    self.assertEqual(15, big_num3.GetAsLong())
    self.assertEqual(1, big_num4.GetAsLong())
    self.assertEqual(5, big_num5.GetAsLong())

  def testBigNumModInverse(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num_mod = ssl_util.BigNum.FromLongNumber(20)
    big_num_result = big_num1.ModInverse(big_num_mod)
    self.assertEqual(11, big_num1.GetAsLong())
    self.assertEqual(20, big_num_mod.GetAsLong())
    # 11 * 11 == 121, and 121 % 20 == 1.
    self.assertEqual(11, big_num_result.GetAsLong())

  def testBigNumModSqrt(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num_mod = ssl_util.BigNum.FromLongNumber(19)
    big_num_result = big_num1.ModSqrt(big_num_mod)
    self.assertEqual(11, big_num1.GetAsLong())
    self.assertEqual(19, big_num_mod.GetAsLong())
    # 7 * 7 == 49, and 49 % 19 == 11.
    self.assertEqual(7, big_num_result.GetAsLong())

  def testBigNumModInverseInvalidForNotRelativelyPrimes(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(10)
    big_num_mod = ssl_util.BigNum.FromLongNumber(20)
    self.assertRaises(ValueError, big_num1.ModInverse, big_num_mod)
    self.assertEqual(10, big_num1.GetAsLong())
    self.assertEqual(20, big_num_mod.GetAsLong())

  def testBigNumNegates(self):
    big_num = ssl_util.BigNum.FromLongNumber(10)
    big_num = big_num.ModNegate(ssl_util.BigNum.FromLongNumber(6))
    # -10 is congruent to 2 modulo 6.
    self.assertEqual(2, big_num.GetAsLong())

  def testBigNumAddsOne(self):
    big_num = ssl_util.BigNum.FromLongNumber(10)
    self.assertEqual(11, big_num.AddOne().GetAsLong())

  def testBigNumSubtractOne(self):
    big_num = ssl_util.BigNum.FromLongNumber(10)
    self.assertEqual(9, big_num.SubtractOne().GetAsLong())

  def testBigNumGeneratesRandsBetweenZeroAndGivenBigNum(self):
    big_num = ssl_util.BigNum.FromLongNumber(3)
    rand_value = big_num.GenerateRand().GetAsLong()
    self.assertLessEqual(0, rand_value)
    self.assertLess(rand_value, 3)

  def testBigNumGeneratesZeroForRandWhenTheUpperBoundIsOne(self):
    big_num = ssl_util.BigNum.FromLongNumber(1)
    self._GenericRandomTestForCasesThatShouldReturnOneNum(
        ssl_util.BigNum.Zero(), big_num.GenerateRand
    )

  def testBigNumGeneratesRandsBetweenStartAndGivenBigNum(self):
    big_num = ssl_util.BigNum.FromLongNumber(3)
    big_rand = big_num.GenerateRandWithStart(ssl_util.BigNum.FromLongNumber(1))
    rand_value = big_rand.GetAsLong()
    self.assertLessEqual(1, rand_value)
    self.assertLess(rand_value, 3)

  def testBigNumGeneratesSingleRandWhenIntervalIsOne(self):
    start = ssl_util.BigNum.FromLongNumber(2**30 - 1)
    end = ssl_util.BigNum.FromLongNumber(2**30)
    self._GenericRandomTestForCasesThatShouldReturnOneNum(
        start, end.GenerateRandWithStart, start
    )

  def testBigNumIsBitSet(self):
    # 11 is 0b1011: bits 0, 1 and 3 are set, bit 2 is clear.
    big_num = ssl_util.BigNum.FromLongNumber(11)
    self.assertTrue(big_num.IsBitSet(0))
    self.assertTrue(big_num.IsBitSet(1))
    self.assertFalse(big_num.IsBitSet(2))
    self.assertTrue(big_num.IsBitSet(3))

  def testBigNumEq(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(11)
    self.assertEqual(big_num1, big_num2)

  def testBigNumNeq(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(12)
    self.assertNotEqual(big_num1, big_num2)

  def testBigNumGt(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(12)
    self.assertGreater(big_num2, big_num1)

  def testBigNumGtEq(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    big_num2 = ssl_util.BigNum.FromLongNumber(11)
    big_num3 = ssl_util.BigNum.FromLongNumber(12)
    self.assertGreaterEqual(big_num2, big_num1)
    self.assertGreaterEqual(big_num3, big_num2)

  def testBigNumComparisonWithOtherTypesRaisesValueError(self):
    big_num1 = ssl_util.BigNum.FromLongNumber(11)
    self.assertRaises(ValueError, big_num1.__lt__, 11)

  def testClonesCreatesANewBigNum(self):
    big_num = ssl_util.BigNum.FromLongNumber(0).Mutable()
    clone_big_num = big_num.Clone()
    # Mutating the source after cloning must not be visible through the clone.
    big_num += ssl_util.BigNum.One()
    self.assertEqual(ssl_util.BigNum.Zero(), clone_big_num)
    self.assertEqual(ssl_util.BigNum.One(), big_num)

  def testBigNumCacheIsSingleton(self):
    cache1 = ssl_util.BigNumCache(10)
    cache2 = ssl_util.BigNumCache(11)
    self.assertIs(cache1, cache2)

  def testBigNumCacheReturnsTheSameCachedBigNum(self):
    cache = ssl_util.BigNumCache(10)
    self.assertIs(cache.Get(1), cache.Get(1))

  def testBigNumCacheReturnsDifferentBigNumWhenCacheIsFull(self):
    cache = ssl_util.BigNumCache(10)
    # Request ten distinct keys first, then check that a further key yields a
    # fresh object on every lookup.
    for i in range(10):
      cache.Get(i)
    self.assertIsNot(cache.Get(11), cache.Get(11))

  def testStringRepresentation(self):
    big_num = ssl_util.BigNum.FromLongNumber(11)
    self.assertEqual('11', '{}'.format(big_num))


class _HashMock(object):
  """Context manager that replaces hashlib.sha512 with mocks.

  Entering the context yields a (sha512_mock, hashlib_mock) pair: hashlib_mock
  stands in for the hashlib.sha512 callable, and sha512_mock is the MagicMock
  that hashlib_mock returns when called.
  """

  def __init__(self):
    self.with_patch = patch('hashlib.sha512')

  def __enter__(self):
    hashlib_mock = self.with_patch.__enter__()
    sha512_mock = mock.MagicMock()
    hashlib_mock.return_value = sha512_mock
    return sha512_mock, hashlib_mock

  def __exit__(self, t, value, traceback):
    # Propagate the patcher's result instead of silently discarding it.
    return self.with_patch.__exit__(t, value, traceback)


class PRNGTest(unittest.TestCase):
  """Tests for ssl_util.PRNG, mostly driven through a mocked hashlib.sha512."""

  def testPRNG(self):
    with _HashMock() as (hash_mock, hashlib_mock):
      hash_mock.digest.return_value = b'\x7f' + b'\x01' * 64
      prng = PRNG(b'\x02' * 32)
      self.assertEqual(0, prng.GetRand(2))
      self.assertEqual(1, prng.GetRand(256))
      self.assertEqual(2, prng.GetRand(257))
      self.assertEqual(128, prng.GetRand(32768))
      self.assertEqual(257, prng.GetRand(65536))
      # All five draws must have been served from a single digest.
      hash_mock.digest.assert_called_once_with()
      hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x02' * 32)

  def testGetNBitRandReturnsAtLeastUpperLimit(self):
    with _HashMock() as (hash_mock, hashlib_mock):
      hash_mock.digest.return_value = b'\x81\x82\xff\x05' + b'\x00' * 60
      prng = PRNG(b'\x00' * 32)
      self.assertEqual(5, prng.GetRand(129))
      hash_mock.digest.assert_called_once_with()
      hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x00' * 32)

  def testRaisesValueErrorWhenSeedIsNotAtLeastFourBytes(self):
    # NOTE(review): the test name mentions four bytes, but the seed exercised
    # here is 31 bytes long; confirm the intended minimum seed length.
    self.assertRaises(ValueError, PRNG, b'\x00' * 31)

  def testRaisesValueErrorWhenMaxNumberOfHashingIsDone(self):
    # NOTE(review): the test name says ValueError, but the exception asserted
    # below is AssertionError; confirm which one is intended.
    prng = PRNG(b'\x00' * 32, 1)
    upper_limit = 1 << 512
    for _ in range(256):
      prng.GetRand(upper_limit)
    self.assertRaises(AssertionError, prng.GetRand, 2)
    self.assertEqual(0, prng.GetRand(1))

  def testGetsMoreBytesWithHashingUntilSufficientBytesArePresent(self):
    with _HashMock() as (hash_mock, _):
      hash_mock.digest.side_effect = [
          b'\x80' + b'\x00' * 63,
          b'\x00' * 64,
          b'\x00' * 64,
      ]
      prng = PRNG(b'\x00' * 32, 1)
      upper_limit = 1 << 1025
      self.assertEqual(1 << 1024, prng.GetRand(upper_limit))
      hash_mock.digest.assert_has_calls([call(), call(), call()])


if __name__ == '__main__':
  unittest.main()