1# Authors: 2# Trevor Perrin 3# Martin von Loewis - python 3 port 4# 5# See the LICENSE file for legal information regarding use of this file. 6 7"""cryptomath module 8 9This module has basic math/crypto code.""" 10from __future__ import print_function 11import os 12import math 13import base64 14import binascii 15 16from .compat import * 17 18 19# ************************************************************************** 20# Load Optional Modules 21# ************************************************************************** 22 23# Try to load M2Crypto/OpenSSL 24try: 25 from M2Crypto import m2 26 m2cryptoLoaded = True 27 28except ImportError: 29 m2cryptoLoaded = False 30 31#Try to load GMPY 32try: 33 import gmpy 34 gmpyLoaded = True 35except ImportError: 36 gmpyLoaded = False 37 38#Try to load pycrypto 39try: 40 import Crypto.Cipher.AES 41 pycryptoLoaded = True 42except ImportError: 43 pycryptoLoaded = False 44 45 46# ************************************************************************** 47# PRNG Functions 48# ************************************************************************** 49 50# Check that os.urandom works 51import zlib 52length = len(zlib.compress(os.urandom(1000))) 53assert(length > 900) 54 55def getRandomBytes(howMany): 56 b = bytearray(os.urandom(howMany)) 57 assert(len(b) == howMany) 58 return b 59 60prngName = "os.urandom" 61 62# ************************************************************************** 63# Simple hash functions 64# ************************************************************************** 65 66import hmac 67import hashlib 68 69def MD5(b): 70 return bytearray(hashlib.md5(compat26Str(b)).digest()) 71 72def SHA1(b): 73 return bytearray(hashlib.sha1(compat26Str(b)).digest()) 74 75def HMAC_MD5(k, b): 76 k = compatHMAC(k) 77 b = compatHMAC(b) 78 return bytearray(hmac.new(k, b, hashlib.md5).digest()) 79 80def HMAC_SHA1(k, b): 81 k = compatHMAC(k) 82 b = compatHMAC(b) 83 return bytearray(hmac.new(k, b, hashlib.sha1).digest()) 84 85 
# **************************************************************************
# Converter Functions
# **************************************************************************

def bytesToNumber(b):
    """Convert a big-endian sequence of byte values into an integer.

    b is iterated from its first (most significant) byte to its last;
    an empty sequence yields 0.
    """
    total = 0
    for byte in b:
        total = (total << 8) | byte
    # Force-cast to long to appease PyCrypto.
    # https://github.com/trevp/tlslite/issues/15
    return long(total)

def numberToByteArray(n, howManyBytes=None):
    """Convert an integer into a big-endian bytearray of howManyBytes bytes.

    When howManyBytes is omitted, numBytes(n) is used, i.e. the minimal
    length.  The result always has exactly howManyBytes bytes: it is
    zero-padded on the left when n is short, and the high-order bytes of
    n are silently dropped when n does not fit.
    """
    if howManyBytes is None:
        howManyBytes = numBytes(n)
    b = bytearray(howManyBytes)
    # Fill from the least significant (last) byte backwards.
    for count in range(howManyBytes-1, -1, -1):
        b[count] = int(n % 256)
        n >>= 8
    return b

def mpiToNumber(mpi): #mpi is an openssl-format bignum string
    """Convert an OpenSSL-format MPI (4-byte length + magnitude) to an integer.

    Raises AssertionError when the sign bit of the first magnitude byte
    is set, i.e. when the MPI encodes a negative number.
    """
    # Going through bytearray yields integer elements on both Python 2
    # (str input) and Python 3 (bytes input); calling ord() on an indexed
    # bytes object fails on Python 3.
    raw = bytearray(mpi)
    if (raw[4] & 0x80) != 0: #Make sure this is a positive number
        raise AssertionError()
    return bytesToNumber(raw[4:])

def numberToMPI(n):
    """Encode integer n as an OpenSSL-format MPI byte string.

    The output is a 4-byte big-endian length followed by the big-endian
    magnitude, with a leading zero byte inserted when needed to keep the
    sign bit clear.
    """
    b = numberToByteArray(n)
    ext = 0
    #If the high-order bit is going to be set,
    #add an extra byte of zeros
    if (numBits(n) & 0x7) == 0:
        ext = 1
    length = numBytes(n) + ext
    # Prepend room for the 4-byte length header plus the optional pad byte.
    b = bytearray(4+ext) + b
    b[0] = (length >> 24) & 0xFF
    b[1] = (length >> 16) & 0xFF
    b[2] = (length >> 8) & 0xFF
    b[3] = length & 0xFF
    return bytes(b)


# **************************************************************************
# Misc.
# Utility Functions
# **************************************************************************

def numBits(n):
    """Return the number of bits needed to represent n (0 when n is 0).

    The count comes from the hexadecimal rendering of n: four bits for
    every digit after the first, plus the bit length of the leading digit.
    A negative n renders with a leading '-', which is not in the lookup
    table, so it raises KeyError.
    """
    if n == 0:
        return 0
    s = "%x" % n
    leadingBits = {'0':0, '1':1, '2':2, '3':2,
                   '4':3, '5':3, '6':3, '7':3,
                   '8':4, '9':4, 'a':4, 'b':4,
                   'c':4, 'd':4, 'e':4, 'f':4,
                   }[s[0]]
    return ((len(s)-1)*4) + leadingBits

def numBytes(n):
    """Return the number of bytes needed to represent n (0 when n is 0)."""
    if n == 0:
        return 0
    # Integer ceiling division; avoids a round trip through floating point.
    return (numBits(n) + 7) // 8

# **************************************************************************
# Big Number Math
# **************************************************************************

def getRandomNumber(low, high):
    """Return a random integer n with low <= n < high.

    Candidates are drawn with getRandomBytes and rejected until one falls
    inside the range.  Raises AssertionError when low >= high.
    """
    if low >= high:
        raise AssertionError()
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    while True:
        candidate = getRandomBytes(howManyBytes)
        if lastBits:
            # Trim the top byte so the candidate has at most howManyBits
            # bits, which keeps the rejection rate down.
            candidate[0] = candidate[0] % (1 << lastBits)
        n = bytesToNumber(candidate)
        if low <= n < high:
            return n

def gcd(a,b):
    """Return the greatest common divisor of a and b (Euclid's algorithm)."""
    a, b = max(a,b), min(a,b)
    while b:
        a, b = b, a % b
    return a

def lcm(a, b):
    """Return the least common multiple of a and b."""
    return (a * b) // gcd(a, b)

#Returns inverse of a mod b, zero if none
#Uses Extended Euclidean Algorithm
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 when no inverse exists."""
    c, d = a, b
    uc, ud = 1, 0
    while c != 0:
        q = d // c
        c, d = d-(q*c), c
        uc, ud = ud - (q * uc), uc
    # d now holds gcd(a, b); an inverse exists only when it is 1.
    if d == 1:
        return ud % b
    return 0


if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return (base ** power) % modulus, computed with gmpy."""
        base = gmpy.mpz(base)
        power = gmpy.mpz(power)
        modulus = gmpy.mpz(modulus)
        result = pow(base, power, modulus)
        return long(result)

else:
    def powMod(base, power, modulus):
        """Return (base ** power) % modulus.

        A negative power is handled by raising to -power and inverting
        the result with invMod (which yields 0 when no inverse exists).
        """
        if power < 0:
            result = pow(base, -power, modulus)
            return invMod(result, modulus)
        return pow(base, power, modulus)

#Pre-calculate a sieve of the 168 primes < 1000:
def makeSieve(n):
    """Return the ascending list of primes below n (Sieve of Eratosthenes)."""
    sieve = list(range(n))
    # The upper bound must include int(sqrt(n)) itself; otherwise the
    # square of the largest prime <= sqrt(n) (961 = 31*31 for n = 1000)
    # is never struck out.
    for count in range(2, int(math.sqrt(n)) + 1):
        if sieve[count] == 0:
            continue
        x = sieve[count] * 2
        while x < len(sieve):
            sieve[x] = 0
            x += sieve[count]
    # Drop 0 and 1 along with every struck-out (zeroed) entry.
    return [x for x in sieve[2:] if x]

sieve = makeSieve(1000)

def isPrime(n, iterations=5, display=False):
    """Return True when n is (probably) prime.

    n is first trial-divided by the precomputed sieve, then subjected to
    `iterations` rounds of Rabin-Miller.  With iterations=0 only trial
    division is performed.  When display is true a "*" is printed each
    time a candidate survives trial division.
    """
    if n < 2:
        return False
    #Trial division with sieve
    for x in sieve:
        if x >= n: return True
        if n % x == 0: return False
    #Passed trial division, proceed to Rabin-Miller
    #Rabin-Miller implemented per Ferguson & Schneier
    #Compute s, t for Rabin-Miller
    if display: print("*", end=' ')
    s, t = n-1, 0
    while s % 2 == 0:
        s, t = s//2, t+1
    #Repeat Rabin-Miller x times
    for count in range(iterations):
        # Use 2 as a base for first iteration speedup, per HAC; every
        # later round gets a fresh random base.  Choosing the base at the
        # top of the loop guarantees that a round ending early (v == 1)
        # cannot cause the same base to be reused.
        if count == 0:
            a = 2
        else:
            a = getRandomNumber(2, n)
        v = powMod(a, s, n)
        if v == 1:
            continue
        i = 0
        while v != n-1:
            if i == t-1:
                return False
            v, i = powMod(v, 2, n), i+1
    return True

def getRandomPrime(bits, display=False):
    """Return a random probable prime of `bits` bits with its top two bits set.

    Raises AssertionError when bits < 10.  When display is true a "." is
    printed for every candidate examined.
    """
    if bits < 10:
        raise AssertionError()
    #The 1.5 ensures the 2 MSBs are set
    #Thus, when used for p,q in RSA, n will have its MSB set
    #
    #Since 30 is lcm(2,3,5), we'll set our test numbers to
    #29 % 30 and keep them there
    low = ((2 ** (bits-1)) * 3) // 2
    high = 2 ** bits - 30
    p = getRandomNumber(low, high)
    p += 29 - (p % 30)
    while True:
        if display: print(".", end=' ')
        p += 30
        if p >= high:
            # Walked off the top of the range; restart from a new point.
            p = getRandomNumber(low, high)
            p += 29 - (p % 30)
        if isPrime(p, display=display):
            return p

#Unused at the moment...
def getRandomSafePrime(bits, display=False):
    """Return a random safe prime p = 2*q + 1 where q is also prime.

    q is searched in [1.5 * 2**(bits-2), 2**(bits-1) - 30), so p = 2*q + 1
    has `bits` bits.  Raises AssertionError when bits < 10.  When display
    is true a "." is printed for every candidate q examined.
    """
    if bits < 10:
        raise AssertionError()
    #The 1.5 factor sets the two most significant bits of q
    #
    #30 is lcm(2,3,5): candidates are pinned to 29 mod 30 and stepped
    #by 30 so they are never divisible by 2, 3 or 5
    low = (2 ** (bits-2)) * 3//2
    high = (2 ** (bits-1)) - 30
    q = getRandomNumber(low, high)
    q += 29 - (q % 30)
    while True:
        if display:
            print(".", end=' ')
        q += 30
        if q >= high:
            # Ran past the top of the range; restart from a new point.
            q = getRandomNumber(low, high)
            q += 29 - (q % 30)
        #Ideas from Tom Wu's SRP code
        #Do trial division on p and q before Rabin-Miller
        if not isPrime(q, 0, display=display):
            continue
        p = (2 * q) + 1
        # Test p before the full test on q, as in the original ordering.
        if isPrime(p, display=display) and isPrime(q, display=display):
            return p