• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Python implementations of some algorithms for use by longobject.c.
2The goal is to provide asymptotically faster algorithms that can be
3used for operations on integers with many digits.  In those cases, the
4performance overhead of the Python implementation is not significant
5since the asymptotic behavior is what dominates runtime. Functions
6provided by this module should be considered private and not part of any
7public API.
8
9Note: for ease of maintainability, please prefer clear code and avoid
10"micro-optimizations".  This module will only be imported and used for
11integers with a huge number of digits.  Saving a few microseconds with
12tricky or non-obvious code is not worth it.  For people looking for
13maximum performance, they should use something like gmpy2."""
14
15import re
16import decimal
17try:
18    import _decimal
19except ImportError:
20    _decimal = None
21
22# A number of functions have this form, where `w` is a desired number of
23# digits in base `base`:
24#
25#    def inner(...w...):
26#        if w <= LIMIT:
27#            return something
28#        lo = w >> 1
29#        hi = w - lo
30#        something involving base**lo, inner(...lo...), j, and inner(...hi...)
31#    figure out largest w needed
32#    result = inner(w)
33#
34# They all had some on-the-fly scheme to cache `base**lo` results for reuse.
35# Power is costly.
36#
# This routine aims to compute all and only the needed powers in advance, as
# efficiently as reasonably possible. This isn't trivial, and all the
# on-the-fly methods did needless work in many cases. The driving code above
# changes to:
41#
42#    figure out largest w needed
43#    mycache = compute_powers(w, base, LIMIT)
44#    result = inner(w)
45#
46# and `mycache[lo]` replaces `base**lo` in the inner function.
47#
48# While this does give minor speedups (a few percent at best), the primary
49# intent is to simplify the functions using this, by eliminating the need for
50# them to craft their own ad-hoc caching schemes.
def compute_powers(w, base, more_than, show=False):
    """Precompute the powers of `base` needed by a halving recursion.

    The recursion being served starts at width `w` and, for every width
    `x > more_than`, splits `x` into `lo = x >> 1` and `hi = x - lo`,
    looks up `base**lo`, and recurses on both `lo` and `hi`.

    Inputs:
      w is the largest width the recursion will be started with
      base is the value to raise to powers (it only needs to support
        `*` and `**` with an int exponent)
      more_than is the cutoff: widths <= more_than are never split
      show, if true, prints which strategy computes each power

    Output:
      A dict mapping every needed exponent `lo` to `base**lo`.  The dict
      is empty when `w <= more_than`.
    """
    # Phase 1: replay the recursion on the (small) widths alone to collect
    # every exponent that will be looked up.
    seen = set()
    need = set()
    pending = {w}
    while pending:
        width = pending.pop()  # any element is fine to use next
        if width in seen or width <= more_than:
            continue
        seen.add(width)
        lo = width >> 1
        # Only `lo` is _needed_ here; some other path may, or may not,
        # need `hi`.
        need.add(lo)
        pending.add(lo)
        if width & 1:
            pending.add(lo + 1)

    # Phase 2: build the powers in increasing order of exponent, deriving
    # each one from a smaller, already computed power whenever possible.
    powers = {}
    for exp in sorted(need):
        half = exp >> 1
        if exp - 1 in powers:
            if show:
                print("* base at", exp)
            result = powers[exp - 1] * base  # cheap
        elif half in powers:
            if show:
                print("square at", exp)
            # Multiplying a bigint by itself (same object!) is about twice
            # as fast in CPython.
            result = powers[half] * powers[half]
            if exp & 1:
                if show:
                    print("    and * base")
                result *= base
        else:
            # Always taken for the smallest exponent.  It can also be hit
            # later when `more_than` is odd: compute_powers(11, base, 5)
            # needs exponents 3 and 5, and 5 is reachable neither from 4
            # nor from 2.  A direct power keeps the result correct there
            # instead of failing.
            if show:
                print("pow at", exp)
            result = base ** exp
        powers[exp] = result
    return powers
96
# Decimal context used by int_to_decimal(): a copy of the context that is
# current at import time, widened to the largest exponent range and
# precision the decimal module allows.
_unbounded_dec_context = decimal.getcontext().copy()
_unbounded_dec_context.Emin = decimal.MIN_EMIN
_unbounded_dec_context.Emax = decimal.MAX_EMAX
_unbounded_dec_context.prec = decimal.MAX_PREC
# Trap Inexact so that any operation that would have to round raises
# instead of silently returning an approximate result.
_unbounded_dec_context.traps[decimal.Inexact] = 1 # sanity check
102
def int_to_decimal(n):
    """Asymptotically fast conversion of an 'int' to Decimal."""

    # Function due to Tim Peters.  See GH issue #90716 for details.
    # https://github.com/python/cpython/issues/90716
    #
    # The base conversion algorithms in longobject.c between power-of-2
    # and non-power-of-2 bases take quadratic time.  This function uses
    # divide-and-conquer instead: the int is split into bit halves, each
    # half is converted to an equal decimal.Decimal, and the halves are
    # recombined with Decimal arithmetic.  A caller that wants a string
    # applies str() to the returned Decimal.

    from decimal import Decimal as D
    BITLIM = 200

    def build(value, nbits):
        # `value` is non-negative and is treated as having `nbits` bits.
        if nbits <= BITLIM:
            return D(value)
        split = nbits >> 1
        high_part = value >> split
        # The low mask is not cached; computing it is tiny compared to
        # the Decimal multiply below.
        low_part = value & ((1 << split) - 1)
        low_dec = build(low_part, split)
        high_dec = build(high_part, nbits - split)
        return low_dec + high_dec * pow2[split]

    with decimal.localcontext(_unbounded_dec_context):
        total_bits = n.bit_length()
        # pow2[k] == Decimal(2)**k for every split width build() will use.
        pow2 = compute_powers(total_bits, D(2), BITLIM)
        is_negative = n < 0
        magnitude = -n if is_negative else n
        result = build(magnitude, total_bits)
        return -result if is_negative else result
141
def int_to_decimal_string(n):
    """Asymptotically fast conversion of an 'int' to a decimal string."""
    nbits = n.bit_length()
    if _decimal is not None and nbits > 450_000:
        # Only usable with the C decimal implementation.  _pydecimal.py
        # calls str() on very large integers, which in turn calls
        # int_to_decimal_string(), causing very deep recursion.
        return str(int_to_decimal(n))

    # Fallback for when the C decimal module isn't available (or the
    # number is below the threshold above).  Asymptotically worse than the
    # decimal-based algorithm, but better than the quadratic time
    # implementation in longobject.c.

    DIGLIM = 1000

    def render(value, ndigits):
        # Convert non-negative `value`, assumed to have about `ndigits`
        # decimal digits, by splitting it around a power of 10.
        if ndigits <= DIGLIM:
            return str(value)
        low_digits = ndigits >> 1
        high, low = divmod(value, pow10[low_digits])
        high_text = render(high, ndigits - low_digits)
        # The low half must keep its leading zeros to stay aligned.
        low_text = render(low, low_digits).zfill(low_digits)
        return high_text + low_text

    # Estimate the number of decimal digits.  A small error is harmless:
    # an overestimate only produces leading 0's that are stripped below,
    # and an underestimate only leaves more digits for str() to convert
    # at the top of the recursion.  The latter is noticeable only for
    # numbers with way more than 10**15 digits, which exceeds the 52-bit
    # physical address limit of both Intel64 and AMD64.
    ndigits = int(nbits * 0.3010299956639812 + 1)  # log10(2)
    # 5**k << k == 5**k * 2**k == 10**k
    pow10 = {exp: five_pow << exp
             for exp, five_pow in compute_powers(ndigits, 5, DIGLIM).items()}

    if n < 0:
        sign = '-'
        magnitude = -n
    else:
        sign = ''
        magnitude = n
    text = render(magnitude, ndigits)
    if magnitude and text[0] == '0':
        # The digit estimate was too large; drop the padding zeros.
        text = text.lstrip('0')
    return sign + text
186
def _str_to_int_inner(s):
    """Asymptotically fast conversion of a 'str' to an 'int'."""

    # Function due to Bjorn Martinsson.  See GH issue #90716 for details.
    # https://github.com/python/cpython/issues/90716
    #
    # The base conversion algorithms in longobject.c between power-of-2
    # and non-power-of-2 bases take quadratic time.  This function is a
    # divide-and-conquer algorithm built on Python's big int
    # multiplication.  Since Python uses the Karatsuba algorithm for
    # multiplication, the time complexity of this function is
    # O(len(s)**1.58).

    DIGLIM = 2048

    def convert(start, stop):
        # Value of the digit substring s[start:stop].
        if stop - start <= DIGLIM:
            return int(s[start:stop])
        mid = (start + stop + 1) >> 1
        low_len = stop - mid
        low = convert(mid, stop)
        high = convert(start, mid)
        # high * 10**low_len, written as (high * 5**low_len) << low_len.
        return low + ((high * pow5[low_len]) << low_len)

    total = len(s)
    pow5 = compute_powers(total, 5, DIGLIM)
    return convert(0, total)
212
213
def int_from_string(s):
    """Asymptotically fast version of PyLong_FromString(), conversion
    of a string of decimal digits into an 'int'."""
    # By the time this is called, PyLong_FromString() has removed any
    # leading +/-, stripped leading whitespace, verified that only digits
    # and underscores remain, and rejected invalid underscore placement.
    # What can still be present is trailing whitespace and (valid)
    # underscores, so drop both before converting.
    without_trailing_space = s.rstrip()
    digits_only = without_trailing_space.replace('_', '')
    return _str_to_int_inner(digits_only)
223
def str_to_int(s):
    """Asymptotically fast version of decimal string to 'int' conversion.

    Accepts optional surrounding whitespace, an optional sign, and ASCII
    decimal digits that may be separated by single underscores.  Raises
    ValueError for any other input.
    """
    # FIXME: this doesn't support the full syntax that int() supports.
    # fullmatch() rather than match(): the whole string has to be a
    # number, so trailing junk such as "12abc" or "1 2" is rejected
    # instead of being silently ignored.
    m = re.fullmatch(r'\s*([+-]?)([0-9_]+)\s*', s)
    if not m:
        raise ValueError('invalid literal for int() with base 10')
    digits = m.group(2)
    # int_from_string() assumes underscore placement was already checked,
    # so do it here: like int(), allow only single underscores between
    # digits.
    if digits[0] == '_' or digits[-1] == '_' or '__' in digits:
        raise ValueError('invalid literal for int() with base 10')
    v = int_from_string(digits)
    if m.group(1) == '-':
        v = -v
    return v
234
235
236# Fast integer division, based on code from Mark Dickinson, fast_div.py
237# GH-47701. Additional refinements and optimizations by Bjorn Martinsson.  The
238# algorithm is due to Burnikel and Ziegler, in their paper "Fast Recursive
239# Division".
240
# Recursion cutoff for _div2n1n(): when the dividend has at most this many
# bits more than n, the built-in divmod() is used instead of recursing.
_DIV_LIMIT = 4000
242
243
def _div2n1n(a, b, n):
    """Divide a 2n-bit nonnegative integer a by an n-bit positive integer
    b, using a recursive divide-and-conquer algorithm.

    Inputs:
      n is a positive integer
      b is a positive integer with exactly n bits
      a is a nonnegative integer such that a < 2**n * b

    Output:
      (q, r) such that a = b*q+r and 0 <= r < b.

    """
    # Small enough: let the built-in division handle it.
    if a.bit_length() - n <= _DIV_LIMIT:
        return divmod(a, b)

    # The split below needs an even n.  If n is odd, scale both operands
    # by 2; the quotient is unaffected and the remainder is halved back
    # at the end.
    scaled = n & 1
    if scaled:
        a <<= 1
        b <<= 1
        n += 1

    half = n >> 1
    low_mask = (1 << half) - 1
    b_high = b >> half
    b_low = b & low_mask

    # Two 3-by-2 half-size divisions: the first consumes the top three
    # quarters of a, the second the first remainder plus the last quarter.
    a_top = a >> n
    a_mid = (a >> half) & low_mask
    a_bottom = a & low_mask
    q_high, rem = _div3n2n(a_top, a_mid, b, b_high, b_low, half)
    q_low, rem = _div3n2n(rem, a_bottom, b, b_high, b_low, half)

    if scaled:
        rem >>= 1
    return (q_high << half) | q_low, rem
272
273
def _div3n2n(a12, a3, b, b1, b2, n):
    """Helper function for _div2n1n; not intended to be called directly.

    Divides the value (a12 << n | a3) by b, where the caller passes b
    pre-split as b1 = b >> n and b2 = b & (2**n - 1).  Returns (q, r)
    with (a12 << n | a3) == q*b + r and r >= 0.
    """
    if a12 >> n != b1:
        # Estimate the quotient by dividing the top part by b1 alone.
        quot, rem = _div2n1n(a12, b1, n)
    else:
        # The top n bits of a12 equal b1: use the largest n-bit quotient
        # and the remainder that goes with it (a12 - quot*b1).
        quot = (1 << n) - 1
        rem = a12 - (b1 << n) + b1
    # Bring down a3 and account for the ignored low half of b.
    rem = ((rem << n) | a3) - quot * b2
    # The estimate can only be too large; step it down until the
    # remainder is nonnegative.
    while rem < 0:
        quot -= 1
        rem += b
    return quot, rem
285
286
def _int2digits(a, n):
    """Decompose non-negative int a into base 2**n

    Input:
      a is a non-negative integer

    Output:
      List of the digits of a in base 2**n in little-endian order,
      meaning the most significant digit is last. The most
      significant digit is guaranteed to be non-zero.
      If a is 0 then the output is an empty list.

    """
    count = (a.bit_length() + n - 1) // n
    digits = [0] * count
    if not a:
        return digits

    # Each work item is (value, first, last): `value` still has to be
    # spread over the slots digits[first:last].  Halving the slot range
    # each time keeps the shifts balanced.
    work = [(a, 0, count)]
    while work:
        value, first, last = work.pop()
        if last - first == 1:
            digits[first] = value
            continue
        mid = (first + last) >> 1
        shift = (mid - first) * n
        upper = value >> shift
        lower = value & ((1 << shift) - 1)
        work.append((upper, mid, last))
        work.append((lower, first, mid))
    return digits
316
317
def _digits2int(digits, n):
    """Combine base-2**n digits into an int. This function is the
    inverse of `_int2digits`. For more details, see _int2digits.
    """

    def combine(first, last):
        # Value represented by the little-endian slice digits[first:last].
        if last - first == 1:
            return digits[first]
        mid = (first + last) >> 1
        low = combine(first, mid)
        high = combine(mid, last)
        # The high half sits above (mid - first) digits of n bits each.
        return (high << ((mid - first) * n)) + low

    if not digits:
        return 0
    return combine(0, len(digits))
331
332
def _divmod_pos(a, b):
    """Divide a non-negative integer a by a positive integer b, giving
    quotient and remainder."""
    # Grade-school long division in base 2**width, width = nbits(b): each
    # step divides (remainder, next digit) by b with _div2n1n().
    width = b.bit_length()
    remainder = 0
    quotient_digits = []
    # _int2digits() is little-endian, but long division consumes the most
    # significant digit first.
    for digit in reversed(_int2digits(a, width)):
        q_digit, remainder = _div2n1n((remainder << width) + digit, b, width)
        quotient_digits.append(q_digit)
    # The quotient digits were produced most significant first; flip them
    # back to little-endian before recombining.
    quotient = _digits2int(quotient_digits[::-1], width)
    return quotient, remainder
348
349
def int_divmod(a, b):
    """Asymptotically fast replacement for divmod, for 'int'.
    Its time complexity is O(n**1.58), where n = #bits(a) + #bits(b).
    """
    if b == 0:
        raise ZeroDivisionError
    if b < 0:
        # a == q*(-b) + r with the signs flipped gives a == q*b - r.
        quotient, remainder = int_divmod(-a, -b)
        return quotient, -remainder
    if a < 0:
        # ~a == -a - 1 is non-negative.  If ~a == q*b + r, then
        # a == (~q)*b + (b + ~r), and 0 <= b + ~r < b.
        quotient, remainder = int_divmod(~a, b)
        return ~quotient, b + ~remainder
    return _divmod_pos(a, b)
364