• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
14
15"""Common functionality shared by several modules."""
16
17import typing
18
19
class NotRelativePrimeError(ValueError):
    """Raised when two numbers that were required to be coprime are not.

    The offending values are stored on the exception so callers can
    inspect them: ``a`` and ``b`` are the two numbers and ``d`` is the
    common divider that was found.
    """

    def __init__(self, a: int, b: int, d: int, msg: str = '') -> None:
        # Fall back to a generated message only when no (non-empty) custom
        # message was supplied.
        if not msg:
            msg = "%d and %d are not relatively prime, divider=%i" % (a, b, d)
        super().__init__(msg)
        self.a = a
        self.b = b
        self.d = d
26
27
def bit_size(num: int) -> int:
    """
    Return how many bits are needed to represent an integer, not counting
    any leading zero bits.

    Usage::

        >>> bit_size(1023)
        10
        >>> bit_size(1024)
        11
        >>> bit_size(1025)
        11

    :param num:
        Integer value. Zero yields 0. The sign is ignored: a negative number
        has the same bit size as its absolute value.
    :returns:
        The number of significant bits in ``abs(num)``.
    :raises TypeError:
        If ``num`` has no ``bit_length()`` method (i.e. is not an integer).
    """
    try:
        size = num.bit_length()
    except AttributeError as missing:
        # Report a type problem rather than leaking the attribute lookup.
        raise TypeError('bit_size(num) only supports integers, not %r' % type(num)) from missing
    return size
54
55
def byte_size(number: int) -> int:
    """
    Return how many whole bytes are needed to hold ``number``.

    Partial bytes are rounded up, and zero is reported as needing one byte.

    Usage::

        >>> byte_size(1 << 1023)
        128
        >>> byte_size((1 << 1024) - 1)
        128
        >>> byte_size(1 << 1024)
        129

    :param number:
        An unsigned integer
    :returns:
        The number of bytes required to hold ``number``.
    """
    # Zero has no significant bits but still occupies a byte.
    if number != 0:
        return ceil_div(bit_size(number), 8)
    return 1
79
80
def ceil_div(num: int, div: int) -> int:
    """
    Divide ``num`` by ``div`` and round the quotient up.

    Usage::

        >>> ceil_div(100, 7)
        15
        >>> ceil_div(100, 10)
        10
        >>> ceil_div(1, 4)
        1

    :param num: Division's numerator, a number
    :param div: Division's divisor, a number

    :return: The ceiling of ``num / div``.
    """
    quotient, remainder = divmod(num, div)
    # divmod floors; any leftover means the true quotient lies one step higher.
    return quotient + 1 if remainder else quotient
103
104
def extended_gcd(a: int, b: int) -> typing.Tuple[int, int, int]:
    """Return a tuple ``(r, i, j)`` where ``r = gcd(a, b)``.

    The coefficients come from the iterative extended Euclidean algorithm,
    which produces Bézout coefficients with ``r == i*a + j*b``.  A negative
    ``i`` is then shifted up once by the original ``b``, and a negative
    ``j`` once by the original ``a`` (meant to make them non-negative).
    After that shift the exact identity ``r == i*a + j*b`` no longer holds
    in general.  What is preserved is:

    * ``i*a ≡ r (mod b)`` -- so ``i`` is the inverse of ``a`` modulo ``b``
      when ``r == 1``;
    * ``j*b ≡ r (mod a)`` -- so ``j`` is the inverse of ``b`` modulo ``a``
      when ``r == 1``.

    >>> extended_gcd(12, 18)
    (6, 17, 1)
    >>> extended_gcd(3, 5)
    (1, 2, 2)
    """
    # Iterative rather than recursive: faster and uses far less stack.
    # (x, lx) track the current/previous coefficient of the original a,
    # (y, ly) the current/previous coefficient of the original b.
    x, lx = 0, 1
    y, ly = 1, 0
    oa, ob = a, b  # remember the original a/b to wrap negative results
    while b != 0:
        q = a // b
        a, b = b, a % b
        x, lx = lx - q * x, x
        y, ly = ly - q * y, y
    if lx < 0:
        lx += ob  # negative: wrap by the original b (keeps lx*a ≡ r mod b)
    if ly < 0:
        ly += oa  # negative: wrap by the original a (keeps ly*b ≡ r mod a)
    return a, lx, ly
128
129
def inverse(x: int, n: int) -> int:
    """Return the multiplicative inverse of ``x`` modulo ``n``, i.e. x^-1 (mod n).

    >>> inverse(7, 4)
    3
    >>> (inverse(143, 4) * 143) % 4
    1

    :raises NotRelativePrimeError: when the gcd reported by
        ``extended_gcd(x, n)`` is not 1, so ``x`` has no inverse modulo ``n``.
    """
    gcd_value, coefficient, _ = extended_gcd(x, n)

    # The coefficient of x is only an inverse when x and n are coprime.
    if gcd_value == 1:
        return coefficient
    raise NotRelativePrimeError(x, n, gcd_value)
145
146
def crt(a_values: typing.Iterable[int], modulo_values: typing.Iterable[int]) -> int:
    """Chinese Remainder Theorem.

    Calculates x such that x = a[i] (mod m[i]) for each i.

    :param a_values: the a-values of the above equation
    :param modulo_values: the m-values of the above equation. Any iterable is
        accepted, including a one-shot iterator or generator; it is read
        exactly once. The moduli must be pairwise relatively prime.
    :returns: x such that x = a[i] (mod m[i]) for each i. For positive
        moduli the result lies in ``[0, M)``, where ``M`` is the product of
        all moduli; with no moduli the result is 0.
    :raises NotRelativePrimeError: (from ``inverse``) when a modulus shares a
        divider with the product of the other moduli.

    >>> crt([2, 3], [3, 5])
    8

    >>> crt([2, 3, 2], [3, 5, 7])
    23

    >>> crt([2, 3, 0], [7, 11, 15])
    135

    >>> crt([2, 3], iter([3, 5]))
    8
    """
    # The moduli are traversed twice below (product, then zip). Materialize
    # them first: a plain iterator would be exhausted by the first loop and
    # the function would silently return 0.
    moduli = list(modulo_values)

    m = 1
    for modulo in moduli:
        m *= modulo

    x = 0
    for m_i, a_i in zip(moduli, a_values):
        M_i = m // m_i
        inv = inverse(M_i, m_i)

        x = (x + a_i * M_i * inv) % m

    return x
180
181
if __name__ == '__main__':
    # When run as a script, execute the docstring examples in this module
    # as doctests.
    from doctest import testmod

    testmod()
186