"""cryptomath module

This module has basic math/crypto code: detection of optional crypto
backends, a random byte source (getRandomBytes), and conversions between
numbers, byte arrays, strings, base64 and OpenSSL MPI strings."""

import os
import math
import base64
import binascii

# The sha module is deprecated in Python 2.6
# NOTE: when the legacy import fails, `sha` is bound to hashlib.sha1 itself
# (a constructor function), not to a module.
try:
    import sha
except ImportError:
    from hashlib import sha1 as sha

# The md5 module is deprecated in Python 2.6
# NOTE: likewise, the fallback binds `md5` to hashlib.md5 (a function).
try:
    import md5
except ImportError:
    from hashlib import md5

# Provides the byte-array helpers used below (stringToBytes, bytesToString,
# createByteArrayZeros, concatArrays, numBits) -- presumably; verify in compat.
from compat import *


# **************************************************************************
# Load Optional Modules
# **************************************************************************

# Try to load M2Crypto/OpenSSL
try:
    from M2Crypto import m2
    m2cryptoLoaded = True
except ImportError:
    m2cryptoLoaded = False


# Try to load cryptlib
try:
    import cryptlib_py
    try:
        cryptlib_py.cryptInit()
    except cryptlib_py.CryptException as e:
        # If tlslite and cryptoIDlib are both present, they might each try
        # to re-initialize this, so we're tolerant of that.
        if e.args[0] != cryptlib_py.CRYPT_ERROR_INITED:
            raise
    cryptlibpyLoaded = True
except ImportError:
    cryptlibpyLoaded = False

# Try to load GMPY
try:
    import gmpy
    gmpyLoaded = True
except ImportError:
    gmpyLoaded = False

# Try to load pycrypto
try:
    import Crypto.Cipher.AES
    pycryptoLoaded = True
except ImportError:
    pycryptoLoaded = False


# **************************************************************************
# PRNG Functions
# **************************************************************************

# Exactly one getRandomBytes(howMany) is defined below, taken from the first
# source that works; prngName records which one was chosen.

# Get os.urandom PRNG
try:
    # Probe once. os.urandom raises NotImplementedError when the OS has no
    # randomness source (and does not exist at all on very old Pythons).
    # Catch Exception, not everything, so KeyboardInterrupt/SystemExit
    # still propagate.
    os.urandom(1)

    def getRandomBytes(howMany):
        """Return a byte array of howMany bytes from os.urandom."""
        return stringToBytes(os.urandom(howMany))
    prngName = "os.urandom"

except Exception:
    # Else get cryptlib PRNG
    if cryptlibpyLoaded:
        def getRandomBytes(howMany):
            """Return a byte array of howMany bytes of AES-OFB keystream
            generated under a fresh random cryptlib key."""
            randomKey = cryptlib_py.cryptCreateContext(
                cryptlib_py.CRYPT_UNUSED, cryptlib_py.CRYPT_ALGO_AES)
            cryptlib_py.cryptSetAttribute(randomKey,
                                          cryptlib_py.CRYPT_CTXINFO_MODE,
                                          cryptlib_py.CRYPT_MODE_OFB)
            cryptlib_py.cryptGenerateKey(randomKey)
            # Encrypting zeros in OFB mode yields the raw keystream.
            randomBytes = createByteArrayZeros(howMany)
            cryptlib_py.cryptEncrypt(randomKey, randomBytes)
            # NOTE(review): randomKey is never destroyed; cryptlib contexts
            # normally need cryptDestroyContext -- confirm the binding name.
            return randomBytes
        prngName = "cryptlib"

    else:
        # Else get UNIX /dev/urandom PRNG
        try:
            # Deliberately kept open for the lifetime of the module.
            devRandomFile = open("/dev/urandom", "rb")

            def getRandomBytes(howMany):
                """Return a byte array of howMany bytes read from
                /dev/urandom."""
                return stringToBytes(devRandomFile.read(howMany))
            prngName = "/dev/urandom"
        except IOError:
            # Else get Win32 CryptoAPI PRNG
            try:
                import win32prng

                def getRandomBytes(howMany):
                    """Return a byte array of howMany bytes from the Win32
                    CryptoAPI; AssertionError on a short read."""
                    s = win32prng.getRandomBytes(howMany)
                    if len(s) != howMany:
                        raise AssertionError()
                    return stringToBytes(s)
                prngName = "CryptoAPI"
            except ImportError:
                # Else no PRNG :-(
                def getRandomBytes(howMany):
                    """Always raise: no random source is available."""
                    raise NotImplementedError("No Random Number Generator "\
                                              "available.")
                prngName = "None"


# **************************************************************************
# Converter Functions
# **************************************************************************

def bytesToNumber(bytes):
    """Return the non-negative integer encoded big-endian by `bytes`.

    `bytes` is a sequence of byte values (ints 0-255); empty gives 0.
    """
    total = 0
    for byte in bytes:
        # Horner's rule: shift the accumulated value up one byte, add the next.
        total = total * 256 + byte
    return total


def numberToBytes(n):
    """Return the minimal big-endian byte array encoding the integer n >= 0.

    n == 0 yields an empty array.
    """
    howManyBytes = numBytes(n)
    result = createByteArrayZeros(howManyBytes)
    # Fill from the least-significant end.
    for index in range(howManyBytes - 1, -1, -1):
        result[index] = int(n % 256)
        n >>= 8
    return result


def bytesToBase64(bytes):
    """Return the base64 text (no newlines) for a byte array."""
    return stringToBase64(bytesToString(bytes))


def base64ToBytes(s):
    """Decode base64 text into a byte array; SyntaxError on bad input."""
    return stringToBytes(base64ToString(s))


def numberToBase64(n):
    """Return the base64 text of n's minimal big-endian encoding."""
    return bytesToBase64(numberToBytes(n))


def base64ToNumber(s):
    """Decode base64 text holding a big-endian unsigned integer."""
    return bytesToNumber(base64ToBytes(s))


def stringToNumber(s):
    """Interpret the byte string s as a big-endian unsigned integer."""
    return bytesToNumber(stringToBytes(s))


def numberToString(s):
    """Return the minimal big-endian byte string for the integer `s`.

    (The parameter is a number despite its name; kept for callers.)
    """
    return bytesToString(numberToBytes(s))


def base64ToString(s):
    """Decode base64 text to a byte string.

    Raises SyntaxError (wrapping the binascii error) on malformed input.
    """
    try:
        return base64.decodestring(s)
    except (binascii.Error, binascii.Incomplete) as e:
        raise SyntaxError(e)


def stringToBase64(s):
    """Base64-encode the byte string s with all newlines removed."""
    return base64.encodestring(s).replace("\n", "")


def mpiToNumber(mpi):  # mpi is an openssl-format bignum string
    """Parse an OpenSSL MPI string into a non-negative integer.

    The format is a 4-byte length prefix followed by the big-endian
    magnitude; the prefix is skipped, not validated. Raises AssertionError
    if the magnitude's high bit is set (i.e. the number would be negative).
    NOTE(review): an MPI with an empty magnitude (exactly 4 chars) raises
    IndexError at mpi[4].
    """
    if (ord(mpi[4]) & 0x80) != 0:  # Make sure this is a positive number
        raise AssertionError()
    return bytesToNumber(stringToBytes(mpi[4:]))


def numberToMPI(n):
    """Encode the integer n >= 0 as an OpenSSL MPI string: a 4-byte
    big-endian length followed by the big-endian magnitude."""
    magnitude = numberToBytes(n)
    # If the high-order bit of the leading byte is going to be set (bit
    # length a multiple of 8, assuming numBits returns the bit length), add
    # an extra zero byte so mpiToNumber does not see it as negative.
    ext = 0
    if (numBits(n) & 0x7) == 0:
        ext = 1
    length = numBytes(n) + ext
    result = concatArrays(createByteArrayZeros(4 + ext), magnitude)
    result[0] = (length >> 24) & 0xFF
    result[1] = (length >> 16) & 0xFF
    result[2] = (length >> 8) & 0xFF
    result[3] = length & 0xFF
    return bytesToString(result)
# **************************************************************************
# Misc. Utility Functions
# **************************************************************************

def numBytes(n):
    """Return how many bytes are needed to hold the integer n >= 0.

    Returns 0 for n == 0.
    """
    if n == 0:
        return 0
    bits = numBits(n)
    # Integer ceiling division; no float rounding for huge bit counts.
    return (bits + 7) // 8


def hashAndBase64(s):
    """Return the base64 encoding of the SHA-1 digest of the string s."""
    # `sha` is either the legacy sha module (which has a `sha` constructor)
    # or, if that import failed, hashlib.sha1 itself; resolve either case.
    shaConstructor = getattr(sha, "sha", sha)
    return stringToBase64(shaConstructor(s).digest())


def getBase64Nonce(numChars=22):  # defaults to a 132-bit nonce (22 * 6 bits)
    """Return a random nonce of numChars base64 characters."""
    # Draws numChars random bytes -- more than the 6*numChars bits needed --
    # and keeps only the first numChars characters of their base64 text.
    randomBytes = getRandomBytes(numChars)
    bytesStr = "".join([chr(b) for b in randomBytes])
    return stringToBase64(bytesStr)[:numChars]


# **************************************************************************
# Big Number Math
# **************************************************************************

def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Uses rejection sampling: draw numBits(high) random bits and retry until
    the value lands in range (uniform if getRandomBytes is).
    Raises AssertionError if low >= high.
    """
    if low >= high:
        raise AssertionError()
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    while True:
        randomBytes = getRandomBytes(howManyBytes)
        if lastBits:
            # Mask the leading byte so the candidate has at most howManyBits bits.
            randomBytes[0] = randomBytes[0] % (1 << lastBits)
        n = bytesToNumber(randomBytes)
        if low <= n < high:
            return n


def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid)."""
    a, b = max(a, b), min(a, b)
    while b:
        a, b = b, a % b
    return a


def lcm(a, b):
    """Return the least common multiple of a and b."""
    # Floor division keeps the result an integer; plain "/" would produce a
    # float under Python 3 or "from __future__ import division".
    return (a * b) // gcd(a, b)


def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if none exists.

    Uses the extended Euclidean algorithm.
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        # Floor division: "/" would become float division on Python 3.
        q = d // c
        c, d = d - (q * c), c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0


if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power mod modulus, computed with gmpy."""
        base = gmpy.mpz(base)
        power = gmpy.mpz(power)
        modulus = gmpy.mpz(modulus)
        result = pow(base, power, modulus)
        # int() of an mpz gives a Python int (a long on Python 2 when large).
        return int(result)

else:
    # Copied from Bryan G. Olson's post to comp.lang.python
    # Does left-to-right instead of pow()'s right-to-left,
    # thus about 30% faster than the python built-in with small bases
    def powMod(base, power, modulus):
        """Return base**power mod modulus, using multi-bit scanning
        with nBitScan bits at a time.

        A negative power returns the modular inverse of base**(-power);
        AssertionError is raised if that inverse does not exist.
        """
        nBitScan = 5

        # x**0 == 1; reduce it like the built-in pow() does. Without this
        # early return the digit list below is empty and unpacking fails.
        if power == 0:
            return 1 % modulus

        # TREV - Added support for negative exponents
        negativeResult = False
        if power < 0:
            power *= -1
            negativeResult = True

        exp2 = 2 ** nBitScan
        mask = exp2 - 1

        # Break power into a list of digits of nBitScan bits.
        # The list is recursive (nested pairs), most significant digit first.
        nibbles = None
        while power:
            nibbles = int(power & mask), nibbles
            power = power >> nBitScan

        # Make a table of powers of base up to 2**nBitScan - 1
        lowPowers = [1]
        for i in range(1, exp2):
            lowPowers.append((lowPowers[i - 1] * base) % modulus)

        # To exponentiate by the first nibble, look it up in the table
        nib, nibbles = nibbles
        prod = lowPowers[nib]

        # For the rest, square nBitScan times, then multiply by base^nibble
        while nibbles:
            nib, nibbles = nibbles
            for i in range(nBitScan):
                prod = (prod * prod) % modulus
            if nib:
                prod = (prod * lowPowers[nib]) % modulus

        # TREV - Added support for negative exponents
        if negativeResult:
            prodInv = invMod(prod, modulus)
            # Check to make sure the inverse is correct (invMod returns 0
            # when there is none).
            if (prod * prodInv) % modulus != 1:
                raise AssertionError()
            return prodInv
        return prod


def makeSieve(n):
    """Return the sorted list of primes below n (sieve of Eratosthenes)."""
    sieve = list(range(n))
    # Every composite below n has a prime factor <= sqrt(n), so the bound
    # must include int(sqrt(n)) itself (e.g. 961 == 31 * 31 when n == 1000).
    for count in range(2, int(math.sqrt(n)) + 1):
        if sieve[count] == 0:
            continue
        x = sieve[count] * 2
        while x < len(sieve):
            sieve[x] = 0
            x += sieve[count]
    sieve = [x for x in sieve[2:] if x]
    return sieve

# Pre-calculate a sieve of the 168 primes < 1000:
sieve = makeSieve(1000)


def isPrime(n, iterations=5, display=False):
    """Return True if n is (probably) prime.

    n < 2 is not prime. Numbers no larger than the biggest sieve prime are
    decided exactly by trial division. Larger numbers that survive trial
    division get `iterations` rounds of Rabin-Miller (per Ferguson &
    Schneier): base 2 first, then random bases; iterations=0 means trial
    division only. If display is true, "* " is written to stdout before
    Rabin-Miller starts.
    """
    if n < 2:
        return False
    # Trial division with sieve
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False
    # Passed trial division, proceed to Rabin-Miller
    if display:
        import sys
        sys.stdout.write("* ")
    # Write n - 1 as s * 2**t with s odd.
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1
    for count in range(iterations):
        # Use 2 as the base for the first iteration (a speedup, per HAC);
        # pick a fresh random base for every later round, including after
        # rounds that passed with v == 1.
        a = 2 if count == 0 else getRandomNumber(2, n)
        v = powMod(a, s, n)
        if v == 1:
            continue
        i = 0
        while v != n - 1:
            if i == t - 1:
                return False
            v, i = powMod(v, 2, n), i + 1
    return True


def getRandomPrime(bits, display=False):
    """Return a random probable prime with exactly `bits` bits, the top two
    of which are set. Raises AssertionError if bits < 10.

    If display is true, ". " is written to stdout for each candidate.
    """
    if bits < 10:
        raise AssertionError()
    import sys
    # The 1.5 ensures the 2 MSBs are set
    # Thus, when used for p,q in RSA, n will have its MSB set
    #
    # Since 30 is lcm(2,3,5), we'll set our test numbers to
    # 29 % 30 and keep them there
    low = (2 ** (bits - 1)) * 3 // 2
    # The -30 margin keeps p + 29 below 2**bits after realignment.
    high = 2 ** bits - 30
    p = getRandomNumber(low, high)
    p += 29 - (p % 30)
    while True:
        if display:
            sys.stdout.write(". ")
        p += 30
        if p >= high:
            p = getRandomNumber(low, high)
            p += 29 - (p % 30)
        if isPrime(p, display=display):
            return p

# Unused at the moment...
def getRandomSafePrime(bits, display=False):
    """Return a random safe prime p = 2*q + 1 (q also probably prime).

    p has exactly `bits` bits, the top two of which are set.
    Raises AssertionError if bits < 10. If display is true, ". " is written
    to stdout for each candidate q.
    """
    if bits < 10:
        raise AssertionError()
    import sys
    # The 1.5 ensures the 2 MSBs are set
    # Thus, when used for p,q in RSA, n will have its MSB set
    #
    # Since 30 is lcm(2,3,5), we'll set our test numbers to
    # 29 % 30 and keep them there (then p = 2q + 1 is also 29 % 30)
    # Floor division keeps these integers ("/" is float division on Py3).
    low = (2 ** (bits - 2)) * 3 // 2
    high = (2 ** (bits - 1)) - 30
    q = getRandomNumber(low, high)
    q += 29 - (q % 30)
    while True:
        if display:
            sys.stdout.write(". ")
        q += 30
        if q >= high:
            q = getRandomNumber(low, high)
            q += 29 - (q % 30)
        # Ideas from Tom Wu's SRP code
        # Do trial division on p and q before Rabin-Miller
        # (iterations=0 makes isPrime do trial division only).
        if isPrime(q, 0, display=display):
            p = (2 * q) + 1
            if isPrime(p, display=display):
                if isPrime(q, display=display):
                    return p