• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""cryptomath module
2
3This module has basic math/crypto code."""
4
5import os
6import math
7import base64
8import binascii
9
10# The sha module is deprecated in Python 2.6
11try:
12    import sha
13except ImportError:
14    from hashlib import sha1 as sha
15
16# The md5 module is deprecated in Python 2.6
17try:
18    import md5
19except ImportError:
20    from hashlib import md5
21
22from compat import *
23
24
25# **************************************************************************
26# Load Optional Modules
27# **************************************************************************
28
# Try to load M2Crypto/OpenSSL
#
# Each optional backend below is probed with an import.  The outcome is
# recorded in a module-level flag (m2cryptoLoaded, cryptlibpyLoaded,
# gmpyLoaded, pycryptoLoaded) that later code checks to pick an
# implementation -- e.g. gmpyLoaded chooses the powMod variant and
# cryptlibpyLoaded is one of the PRNG fallbacks.
try:
    from M2Crypto import m2
    m2cryptoLoaded = True

except ImportError:
    m2cryptoLoaded = False


# Try to load cryptlib
try:
    import cryptlib_py
    try:
        # Initialize the cryptlib library; a repeated initialization is
        # tolerated by the handler below.
        cryptlib_py.cryptInit()
    except cryptlib_py.CryptException, e:
        #If tlslite and cryptoIDlib are both present,
        #they might each try to re-initialize this,
        #so we're tolerant of that.
        # NOTE(review): e[0] is presumably the cryptlib status code; only
        # CRYPT_ERROR_INITED ("already initialized") is swallowed.
        if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
            raise
    cryptlibpyLoaded = True

except ImportError:
    cryptlibpyLoaded = False

#Try to load GMPY (used for fast big-number modular exponentiation)
try:
    import gmpy
    gmpyLoaded = True
except ImportError:
    gmpyLoaded = False

#Try to load pycrypto (probed via its AES cipher module)
try:
    import Crypto.Cipher.AES
    pycryptoLoaded = True
except ImportError:
    pycryptoLoaded = False
67
68
69# **************************************************************************
70# PRNG Functions
71# **************************************************************************
72
# Get os.urandom PRNG
#
# Pick the best available source of random bytes, in order of preference:
# os.urandom, cryptlib, /dev/urandom, then the Win32 CryptoAPI (win32prng).
# The chosen source defines getRandomBytes(howMany) and records its name in
# prngName; if none is available, getRandomBytes raises NotImplementedError.
try:
    os.urandom(1)
    def getRandomBytes(howMany):
        return stringToBytes(os.urandom(howMany))
    prngName = "os.urandom"

except Exception:
    # os.urandom is unavailable or failed; fall back to the other PRNGs.
    # (Not a bare except, so KeyboardInterrupt/SystemExit still propagate.)
    # Else get cryptlib PRNG
    if cryptlibpyLoaded:
        def getRandomBytes(howMany):
            #Random bytes are produced by encrypting a zeroed buffer in OFB
            #mode under a freshly generated AES key.
            randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED,
                                                       cryptlib_py.CRYPT_ALGO_AES)
            try:
                cryptlib_py.cryptSetAttribute(randomKey,
                                              cryptlib_py.CRYPT_CTXINFO_MODE,
                                              cryptlib_py.CRYPT_MODE_OFB)
                cryptlib_py.cryptGenerateKey(randomKey)
                randomBytes = createByteArrayZeros(howMany)
                cryptlib_py.cryptEncrypt(randomKey, randomBytes)
            finally:
                #Release the per-call context; otherwise every call leaks
                #one cryptlib context.
                cryptlib_py.cryptDestroyContext(randomKey)
            return randomBytes
        prngName = "cryptlib"

    else:
        #Else get UNIX /dev/urandom PRNG
        try:
            #Deliberately kept open for the life of the module so every
            #getRandomBytes call can read from it.
            devRandomFile = open("/dev/urandom", "rb")
            def getRandomBytes(howMany):
                return stringToBytes(devRandomFile.read(howMany))
            prngName = "/dev/urandom"
        except IOError:
            #Else get Win32 CryptoAPI PRNG
            try:
                import win32prng
                def getRandomBytes(howMany):
                    s = win32prng.getRandomBytes(howMany)
                    if len(s) != howMany:
                        raise AssertionError()
                    return stringToBytes(s)
                prngName ="CryptoAPI"
            except ImportError:
                #Else no PRNG :-(
                def getRandomBytes(howMany):
                    raise NotImplementedError("No Random Number Generator "\
                                              "available.")
                #Set only on this path, so a successful CryptoAPI setup
                #keeps prngName == "CryptoAPI".
                prngName = "None"
118
119# **************************************************************************
120# Converter Functions
121# **************************************************************************
122
123def bytesToNumber(bytes):
124    total = 0L
125    multiplier = 1L
126    for count in range(len(bytes)-1, -1, -1):
127        byte = bytes[count]
128        total += multiplier * byte
129        multiplier *= 256
130    return total
131
def numberToBytes(n):
    """Convert integer n into a big-endian byte array of numBytes(n) entries.

    Element 0 holds the most significant byte; n == 0 yields an empty array.
    Intended for non-negative n.
    """
    length = numBytes(n)
    result = createByteArrayZeros(length)
    #Fill from the least significant end, peeling off one byte at a time
    index = length - 1
    while index >= 0:
        result[index] = int(n & 0xFF)
        n = n >> 8
        index = index - 1
    return result
139
def bytesToBase64(bytes):
    """Return the single-line base64 encoding of a byte array."""
    return stringToBase64(bytesToString(bytes))
143
def base64ToBytes(s):
    """Decode base64 text into a byte array.

    Malformed input surfaces as SyntaxError (raised by base64ToString).
    """
    return stringToBytes(base64ToString(s))
147
def numberToBase64(n):
    """Encode integer n as base64 of its big-endian byte representation."""
    return bytesToBase64(numberToBytes(n))
151
def base64ToNumber(s):
    """Decode base64 text and read the resulting bytes as a big-endian integer."""
    return bytesToNumber(base64ToBytes(s))
155
def stringToNumber(s):
    """Read the characters of string s as big-endian bytes and return the integer."""
    return bytesToNumber(stringToBytes(s))
159
def numberToString(s):
    """Return the big-endian byte string for the integer s.

    Despite its name, the parameter is a number, not a string.
    """
    return bytesToString(numberToBytes(s))
163
164def base64ToString(s):
165    try:
166        return base64.decodestring(s)
167    except binascii.Error, e:
168        raise SyntaxError(e)
169    except binascii.Incomplete, e:
170        raise SyntaxError(e)
171
172def stringToBase64(s):
173    return base64.encodestring(s).replace("\n", "")
174
def mpiToNumber(mpi): #mpi is an openssl-format bignum string
    """Convert an OpenSSL-format MPI string into a non-negative integer.

    The first 4 bytes (the length header) are skipped without being
    validated; the remaining bytes are read as a big-endian magnitude.

    Raises AssertionError if mpi is shorter than its 4-byte header, or if
    the magnitude's sign bit is set (a negative number).  An MPI with no
    magnitude bytes after the header (a zero-length encoding) decodes to 0.
    """
    if len(mpi) < 4:
        raise AssertionError()
    magnitude = mpi[4:]
    #Make sure this is a positive number.  An empty magnitude has no sign
    #byte to inspect (indexing mpi[4] would raise IndexError) and means 0.
    if magnitude and (ord(magnitude[0]) & 0x80) != 0:
        raise AssertionError()
    bytes = stringToBytes(magnitude)
    return bytesToNumber(bytes)
180
def numberToMPI(n):
    """Encode integer n as an OpenSSL-format MPI string.

    Layout: a 4-byte big-endian length, then the big-endian magnitude.  When
    the magnitude's top bit would be set, one 0x00 byte is prepended (and
    counted in the length) so the value does not read as negative.

    NOTE(review): the pad decision uses numBits(n) & 7 == 0; if numBits(0)
    is 0, zero is encoded with length 1 and a single 0x00 byte rather than
    an empty magnitude -- confirm consumers accept this.
    """
    magnitude = numberToBytes(n)
    #numBits(n) being a multiple of 8 means the leading byte's top bit is set
    if (numBits(n) & 0x7) == 0:
        padding = 1
    else:
        padding = 0
    length = numBytes(n) + padding
    result = concatArrays(createByteArrayZeros(4 + padding), magnitude)
    #Write the length header, most significant byte first
    for index in range(4):
        shift = 8 * (3 - index)
        result[index] = (length >> shift) & 0xFF
    return bytesToString(result)
195
196
197
198# **************************************************************************
199# Misc. Utility Functions
200# **************************************************************************
201
def numBytes(n):
    """Return how many whole bytes are needed to hold the bits of n.

    Zero needs no bytes; otherwise numBits(n) is rounded up to a multiple
    of 8 and divided by 8.
    """
    if n == 0:
        return 0
    #(bits + 7) >> 3 is ceil(bits / 8) computed exactly in integers
    return int((numBits(n) + 7) >> 3)
207
def hashAndBase64(s):
    """Return the SHA-1 digest of string s, base64-encoded on a single line.

    The module-level name ``sha`` is bound at import time either to the
    legacy ``sha`` module (whose constructor is ``sha.sha``) or, when that
    import fails, to ``hashlib.sha1`` itself, which has no ``.sha``
    attribute.  Resolve the constructor so both bindings work.
    """
    shaConstructor = getattr(sha, "sha", sha)
    return stringToBase64(shaConstructor(s).digest())
210
def getBase64Nonce(numChars=22):
    """Return a random base64 nonce string of numChars characters.

    Each base64 character carries 6 bits, so the default of 22 characters
    gives a 132-bit nonce.  numChars random bytes are drawn, which always
    encode to at least numChars characters; the encoding is then truncated.
    """
    randomBytes = getRandomBytes(numChars)
    rawString = "".join(map(chr, randomBytes))
    encoded = stringToBase64(rawString)
    return encoded[:numChars]
215
216
217# **************************************************************************
218# Big Number Math
219# **************************************************************************
220
def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Uses rejection sampling: draw numBits(high) random bits and retry until
    the value lands in range.  Raises AssertionError unless low < high.
    """
    if high <= low:
        raise AssertionError()
    byteCount = numBytes(high)
    spareBits = numBits(high) % 8
    while True:
        candidate = getRandomBytes(byteCount)
        if spareBits:
            #Keep only the low spareBits bits of the leading byte so the
            #draw has exactly numBits(high) significant bits
            candidate[0] = candidate[0] & ((1 << spareBits) - 1)
        value = bytesToNumber(candidate)
        if low <= value < high:
            return value
234
def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    The larger argument is used as the starting dividend.
    """
    larger, smaller = max(a, b), min(a, b)
    while smaller != 0:
        larger, smaller = smaller, larger % smaller
    return larger
240
def lcm(a, b):
    """Return the least common multiple of a and b (0 if either is 0)."""
    #lcm(0, x) is 0 by convention; this also avoids dividing by gcd(0, 0) == 0
    if a == 0 or b == 0:
        return 0
    #divmod() floors in every Python version, so unlike "/" this does not
    #turn into true division when Python's division changes, and unlike
    #"//" it is available under Jython.
    return divmod(a * b, gcd(a, b))[0]
245
#Returns inverse of a mod b, zero if none
#Uses Extended Euclidean Algorithm
def invMod(a, b):
    """Return x with (a * x) % b == 1 and 0 <= x < b, or 0 if no inverse exists.

    Uses the Extended Euclidean Algorithm.
    """
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        #divmod() floors in every Python version, unlike "/" (true division
        #once Python's division changes) and "//" (unavailable in Jython).
        #remainder equals d - q*c, exactly as before.
        q, remainder = divmod(d, c)
        c, d = remainder, c
        uc, ud = ud - (q * uc), uc
    if d == 1:
        return ud % b
    return 0
260
261
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power mod modulus, computed with gmpy integers.

        The gmpy result is converted back to a Python long.
        """
        base = gmpy.mpz(base)
        power = gmpy.mpz(power)
        modulus = gmpy.mpz(modulus)
        result = pow(base, power, modulus)
        return long(result)

else:
    #Copied from Bryan G. Olson's post to comp.lang.python
    #Does left-to-right instead of pow()'s right-to-left,
    #thus about 30% faster than the python built-in with small bases
    def powMod(base, power, modulus):
        """ Return base**power mod modulus, using multi bit scanning
        with nBitScan bits at a time.

        A negative power returns the modular inverse of base**(-power);
        AssertionError is raised if that inverse does not exist."""
        nBitScan = 5

        #TREV - Added support for negative exponents
        negativeResult = False
        if (power < 0):
            power *= -1
            negativeResult = True

        #base**0 is 1.  Without this early exit the nibble list below stays
        #None and unpacking it raises TypeError.
        if power == 0:
            return 1 % modulus

        exp2 = 2**nBitScan
        mask = exp2 - 1

        # Break power into a list of digits of nBitScan bits.
        # The list is recursive so easy to read in reverse direction.
        nibbles = None
        while power:
            nibbles = int(power & mask), nibbles
            power = power >> nBitScan

        # Make a table of powers of base up to 2**nBitScan - 1
        lowPowers = [1]
        for i in xrange(1, exp2):
            lowPowers.append((lowPowers[i-1] * base) % modulus)

        # To exponentiate by the first nibble, look it up in the table
        nib, nibbles = nibbles
        prod = lowPowers[nib]

        # For the rest, square nBitScan times, then multiply by
        # base^nibble
        while nibbles:
            nib, nibbles = nibbles
            for i in xrange(nBitScan):
                prod = (prod * prod) % modulus
            if nib: prod = (prod * lowPowers[nib]) % modulus

        #TREV - Added support for negative exponents
        if negativeResult:
            prodInv = invMod(prod, modulus)
            #Check to make sure the inverse is correct
            if (prod * prodInv) % modulus != 1:
                raise AssertionError()
            return prodInv
        return prod
321
322
#Pre-calculate a sieve of the 168 primes < 1000:
def makeSieve(n):
    """Return the sorted list of all primes p with 2 <= p < n.

    Sieve of Eratosthenes.  Every composite below n has a prime factor p
    with p * p < n, so candidate factors run up to and INCLUDING
    floor(sqrt(n - 1)); stopping earlier would leave squares of primes
    (e.g. 961 = 31 * 31 for n = 1000) in the result.
    """
    #list() so the candidates are mutable on every Python version
    candidates = list(range(n))
    factor = 2
    while factor * factor < n:
        if candidates[factor]:
            #Smaller multiples were already cleared by smaller prime factors
            for multiple in range(factor * factor, n, factor):
                candidates[multiple] = 0
        factor = factor + 1
    return [x for x in candidates[2:] if x]

sieve = makeSieve(1000)
337
338def isPrime(n, iterations=5, display=False):
339    #Trial division with sieve
340    for x in sieve:
341        if x >= n: return True
342        if n % x == 0: return False
343    #Passed trial division, proceed to Rabin-Miller
344    #Rabin-Miller implemented per Ferguson & Schneier
345    #Compute s, t for Rabin-Miller
346    if display: print "*",
347    s, t = n-1, 0
348    while s % 2 == 0:
349        s, t = s/2, t+1
350    #Repeat Rabin-Miller x times
351    a = 2 #Use 2 as a base for first iteration speedup, per HAC
352    for count in range(iterations):
353        v = powMod(a, s, n)
354        if v==1:
355            continue
356        i = 0
357        while v != n-1:
358            if i == t-1:
359                return False
360            else:
361                v, i = powMod(v, 2, n), i+1
362        a = getRandomNumber(2, n)
363    return True
364
365def getRandomPrime(bits, display=False):
366    if bits < 10:
367        raise AssertionError()
368    #The 1.5 ensures the 2 MSBs are set
369    #Thus, when used for p,q in RSA, n will have its MSB set
370    #
371    #Since 30 is lcm(2,3,5), we'll set our test numbers to
372    #29 % 30 and keep them there
373    low = (2L ** (bits-1)) * 3/2
374    high = 2L ** bits - 30
375    p = getRandomNumber(low, high)
376    p += 29 - (p % 30)
377    while 1:
378        if display: print ".",
379        p += 30
380        if p >= high:
381            p = getRandomNumber(low, high)
382            p += 29 - (p % 30)
383        if isPrime(p, display=display):
384            return p
385
386#Unused at the moment...
387def getRandomSafePrime(bits, display=False):
388    if bits < 10:
389        raise AssertionError()
390    #The 1.5 ensures the 2 MSBs are set
391    #Thus, when used for p,q in RSA, n will have its MSB set
392    #
393    #Since 30 is lcm(2,3,5), we'll set our test numbers to
394    #29 % 30 and keep them there
395    low = (2 ** (bits-2)) * 3/2
396    high = (2 ** (bits-1)) - 30
397    q = getRandomNumber(low, high)
398    q += 29 - (q % 30)
399    while 1:
400        if display: print ".",
401        q += 30
402        if (q >= high):
403            q = getRandomNumber(low, high)
404            q += 29 - (q % 30)
405        #Ideas from Tom Wu's SRP code
406        #Do trial division on p and q before Rabin-Miller
407        if isPrime(q, 0, display=display):
408            p = (2 * q) + 1
409            if isPrime(p, display=display):
410                if isPrime(q, display=display):
411                    return p
412