• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Basic statistics module.
3
4This module provides functions for calculating statistics of data, including
5averages, variance, and standard deviation.
6
7Calculating averages
8--------------------
9
10==================  ==================================================
11Function            Description
12==================  ==================================================
13mean                Arithmetic mean (average) of data.
14fmean               Fast, floating point arithmetic mean.
15geometric_mean      Geometric mean of data.
16harmonic_mean       Harmonic mean of data.
17median              Median (middle value) of data.
18median_low          Low median of data.
19median_high         High median of data.
20median_grouped      Median, or 50th percentile, of grouped data.
21mode                Mode (most common value) of data.
22multimode           List of modes (most common values of data).
23quantiles           Divide data into intervals with equal probability.
24==================  ==================================================
25
26Calculate the arithmetic mean ("the average") of data:
27
28>>> mean([-1.0, 2.5, 3.25, 5.75])
292.625
30
31
32Calculate the standard median of discrete data:
33
34>>> median([2, 3, 4, 5])
353.5
36
37
38Calculate the median, or 50th percentile, of data grouped into class intervals
39centred on the data values provided. E.g. if your data points are rounded to
40the nearest whole number:
41
42>>> median_grouped([2, 2, 3, 3, 3, 4])  #doctest: +ELLIPSIS
432.8333333333...
44
45This should be interpreted in this way: you have two data points in the class
46interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
47the class interval 3.5-4.5. The median of these data points is 2.8333...
48
49
50Calculating variability or spread
51---------------------------------
52
53==================  =============================================
54Function            Description
55==================  =============================================
56pvariance           Population variance of data.
57variance            Sample variance of data.
58pstdev              Population standard deviation of data.
59stdev               Sample standard deviation of data.
60==================  =============================================
61
62Calculate the standard deviation of sample data:
63
64>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75])  #doctest: +ELLIPSIS
654.38961843444...
66
67If you have previously calculated the mean, you can pass it as the optional
68second argument to the four "spread" functions to avoid recalculating it:
69
70>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
71>>> mu = mean(data)
72>>> pvariance(data, mu)
732.5
74
75
76Statistics for relations between two inputs
77-------------------------------------------
78
79==================  ====================================================
80Function            Description
81==================  ====================================================
82covariance          Sample covariance for two variables.
83correlation         Pearson's correlation coefficient for two variables.
84linear_regression   Intercept and slope for simple linear regression.
85==================  ====================================================
86
87Calculate covariance, Pearson's correlation, and simple linear regression
88for two inputs:
89
90>>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
91>>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
92>>> covariance(x, y)
930.75
94>>> correlation(x, y)  #doctest: +ELLIPSIS
950.31622776601...
>>> linear_regression(x, y)
97LinearRegression(slope=0.1, intercept=1.5)
98
99
100Exceptions
101----------
102
103A single exception is defined: StatisticsError is a subclass of ValueError.
104
105"""
106
# Public API: the names exported by a star-import of this module.
# NOTE(review): 'NormalDist' is not defined in the portion of the module
# visible here; presumably it is defined further down -- confirm.
__all__ = [
    'NormalDist',
    'StatisticsError',
    'correlation',
    'covariance',
    'fmean',
    'geometric_mean',
    'harmonic_mean',
    'linear_regression',
    'mean',
    'median',
    'median_grouped',
    'median_high',
    'median_low',
    'mode',
    'multimode',
    'pstdev',
    'pvariance',
    'quantiles',
    'stdev',
    'variance',
]
129
130import math
131import numbers
132import random
133
134from fractions import Fraction
135from decimal import Decimal
136from itertools import groupby, repeat
137from bisect import bisect_left, bisect_right
138from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
139from operator import itemgetter
140from collections import Counter, namedtuple
141
142# === Exceptions ===
143
class StatisticsError(ValueError):
    """Raised for statistically invalid input, such as empty data.

    Being a ValueError subclass, it can be caught as a plain ValueError.
    """
146
147
148# === Private utilities ===
149
def _sum(data):
    """_sum(data) -> (type, sum, count)

    Return a high-precision sum of the given numeric data as a fraction,
    together with the type to be converted to and the count of items.

    Every value is split into an exact (numerator, denominator) pair and
    the numerators are accumulated per denominator, so no rounding occurs
    until a caller converts the result back to the target type.

    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 0.25])
    (<class 'float'>, Fraction(19, 2), 5)

    Some sources of round-off error will be avoided:

    # Built-in sum returns zero.
    >>> _sum([1e50, 1, -1e50] * 1000)
    (<class 'float'>, Fraction(1000, 1), 3000)

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    (<class 'fractions.Fraction'>, Fraction(63, 20), 4)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    (<class 'decimal.Decimal'>, Fraction(6963, 10000), 4)

    Mixed types are currently treated as an error, except that int is
    allowed.

    If any value is a NAN or INF, the returned sum is that non-finite
    value (or the non-finite values added together) instead of a Fraction.
    """
    seen = 0
    numerators = Counter()  # denominator -> running numerator total
    target = int
    for kind, run in groupby(data, type):
        target = _coerce(target, kind)  # raises TypeError on incompatible types
        for num, den in map(_exact_ratio, run):
            seen += 1
            numerators[den] += num
    if None not in numerators:
        # All values finite: combine the per-denominator partials exactly.
        total = sum(Fraction(num, den) for den, num in numerators.items())
    else:
        # _exact_ratio() files NAN/INF under a None denominator; that sum
        # dominates the result, so the finite partials can be ignored.
        total = numerators[None]
        assert not _isfinite(total)
    return (target, total, seen)
200
201
202def _isfinite(x):
203    try:
204        return x.is_finite()  # Likely a Decimal.
205    except AttributeError:
206        return math.isfinite(x)  # Coerces to float first.
207
208
209def _coerce(T, S):
210    """Coerce types T and S to a common type, or raise TypeError.
211
212    Coercion rules are currently an implementation detail. See the CoerceTest
213    test class in test_statistics for details.
214    """
215    # See http://bugs.python.org/issue24068.
216    assert T is not bool, "initial type T is bool"
217    # If the types are the same, no need to coerce anything. Put this
218    # first, so that the usual case (no coercion needed) happens as soon
219    # as possible.
220    if T is S:  return T
221    # Mixed int & other coerce to the other type.
222    if S is int or S is bool:  return T
223    if T is int:  return S
224    # If one is a (strict) subclass of the other, coerce to the subclass.
225    if issubclass(S, T):  return S
226    if issubclass(T, S):  return T
227    # Ints coerce to the other type.
228    if issubclass(T, int):  return S
229    if issubclass(S, int):  return T
230    # Mixed fraction & float coerces to float (or float subclass).
231    if issubclass(T, Fraction) and issubclass(S, float):
232        return S
233    if issubclass(T, float) and issubclass(S, Fraction):
234        return T
235    # Any other combination is disallowed.
236    msg = "don't know how to coerce %s and %s"
237    raise TypeError(msg % (T.__name__, S.__name__))
238
239
240def _exact_ratio(x):
241    """Return Real number x to exact (numerator, denominator) pair.
242
243    >>> _exact_ratio(0.25)
244    (1, 4)
245
246    x is expected to be an int, Fraction, Decimal or float.
247    """
248    try:
249        return x.as_integer_ratio()
250    except AttributeError:
251        pass
252    except (OverflowError, ValueError):
253        # float NAN or INF.
254        assert not _isfinite(x)
255        return (x, None)
256    try:
257        # x may be an Integral ABC.
258        return (x.numerator, x.denominator)
259    except AttributeError:
260        msg = f"can't convert type '{type(x).__name__}' to numerator/denominator"
261        raise TypeError(msg)
262
263
264def _convert(value, T):
265    """Convert value to given numeric type T."""
266    if type(value) is T:
267        # This covers the cases where T is Fraction, or where value is
268        # a NAN or INF (Decimal or float).
269        return value
270    if issubclass(T, int) and value.denominator != 1:
271        T = float
272    try:
273        # FIXME: what do we do if this overflows?
274        return T(value)
275    except TypeError:
276        if issubclass(T, Decimal):
277            return T(value.numerator) / T(value.denominator)
278        else:
279            raise
280
281
282def _find_lteq(a, x):
283    'Locate the leftmost value exactly equal to x'
284    i = bisect_left(a, x)
285    if i != len(a) and a[i] == x:
286        return i
287    raise ValueError
288
289
290def _find_rteq(a, l, x):
291    'Locate the rightmost value exactly equal to x'
292    i = bisect_right(a, x, lo=l)
293    if i != (len(a) + 1) and a[i - 1] == x:
294        return i - 1
295    raise ValueError
296
297
298def _fail_neg(values, errmsg='negative value'):
299    """Iterate over values, failing if any are less than zero."""
300    for x in values:
301        if x < 0:
302            raise StatisticsError(errmsg)
303        yield x
304
305
306# === Measures of central tendency (averages) ===
307
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    If ``data`` is empty, StatisticsError will be raised.

    The total is computed exactly by _sum() and only the final quotient is
    converted back to the data's type.
    """
    # One-shot iterators are materialized so len() can be taken.
    points = list(data) if iter(data) is data else data
    size = len(points)
    if size < 1:
        raise StatisticsError('mean requires at least one data point')
    kind, exact_total, tallied = _sum(points)
    assert tallied == size
    return _convert(exact_total / size, kind)
332
333
def fmean(data):
    """Convert data to floats and compute the arithmetic mean.

    This runs faster than the mean() function and it always returns a float.
    If the input dataset is empty, it raises a StatisticsError.

    >>> fmean([3.5, 4.0, 5.25])
    4.25
    """
    try:
        n = len(data)
    except TypeError:
        # No __len__() (e.g. a generator): count the items as fsum()
        # consumes them, so the data is still traversed only once.
        tally = [0]

        def counted(items):
            for item in items:
                tally[0] += 1
                yield item

        total = fsum(counted(data))
        n = tally[0]
    else:
        total = fsum(data)
    try:
        return total / n
    except ZeroDivisionError:
        raise StatisticsError('fmean requires at least one data point') from None
359
360
def geometric_mean(data):
    """Convert data to floats and compute the geometric mean.

    Raises a StatisticsError if the input dataset is empty,
    if it contains a zero, or if it contains a negative value.

    No special efforts are made to achieve exact results.
    (However, this may change in the future.)

    >>> round(geometric_mean([54, 24, 36]), 9)
    36.0
    """
    try:
        # Geometric mean == exp(arithmetic mean of the logs).  log() raises
        # ValueError for zero/negative inputs and fmean() raises
        # StatisticsError (a ValueError) for empty data.
        mean_of_logs = fmean(map(log, data))
    except ValueError:
        raise StatisticsError('geometric mean requires a non-empty dataset '
                              'containing positive numbers') from None
    return exp(mean_of_logs)
378
379
def harmonic_mean(data, weights=None):
    """Return the harmonic mean of data.

    The harmonic mean is the reciprocal of the arithmetic mean of the
    reciprocals of the data.  It can be used for averaging ratios or
    rates, for example speeds.

    Suppose a car travels 40 km/hr for 5 km and then speeds-up to
    60 km/hr for another 5 km. What is the average speed?

        >>> harmonic_mean([40, 60])
        48.0

    Suppose a car travels 40 km/hr for 5 km, and when traffic clears,
    speeds-up to 60 km/hr for the remaining 30 km of the journey. What
    is the average speed?

        >>> harmonic_mean([40, 60], weights=[5, 30])
        56.0

    If ``data`` is empty, or any element is less than zero,
    ``harmonic_mean`` will raise ``StatisticsError``.

    If given, ``weights`` must have the same length as ``data`` and must
    not contain negative values, otherwise ``StatisticsError`` is raised.
    ``StatisticsError`` is also raised when the weighted sum of reciprocals
    is not positive (for example, when every weight is zero).

    A data value of zero with a non-zero weight makes the result the int
    ``0``, provided no negative value precedes it.  A lone unweighted value
    that is neither a ``numbers.Real`` nor a ``Decimal`` raises ``TypeError``.
    """
    # Materialize one-shot iterators: the data is measured with len() and
    # then walked again below.
    if iter(data) is data:
        data = list(data)
    errmsg = 'harmonic mean does not support negative values'
    n = len(data)
    if n < 1:
        raise StatisticsError('harmonic_mean requires at least one data point')
    elif n == 1 and weights is None:
        # Shortcut: the harmonic mean of one value is that value, returned
        # without any type conversion.
        x = data[0]
        if isinstance(x, (numbers.Real, Decimal)):
            if x < 0:
                raise StatisticsError(errmsg)
            return x
        else:
            raise TypeError('unsupported type')
    if weights is None:
        # Unweighted: every value counts once, so the weights sum to n.
        weights = repeat(1, n)
        sum_weights = n
    else:
        if iter(weights) is weights:
            weights = list(weights)
        if len(weights) != n:
            raise StatisticsError('Number of weights does not match data size')
        # Exact sum of the weights, rejecting any negative weight.
        _, sum_weights, _ = _sum(w for w in _fail_neg(weights, errmsg))
    try:
        # _fail_neg() is a generator, so each value is checked for
        # negativity only when _sum() pulls it.  A zero value with a
        # non-zero weight raises ZeroDivisionError at that point, so any
        # negative value *after* it is never examined and 0 is returned.
        # Zero-weight items skip the division and contribute 0.
        data = _fail_neg(data, errmsg)
        T, total, count = _sum(w / x if w else 0 for w, x in zip(weights, data))
    except ZeroDivisionError:
        return 0
    if total <= 0:
        raise StatisticsError('Weighted sum must be positive')
    # sum(w) / sum(w / x), converted back to the data's numeric type.
    return _convert(sum_weights / total, T)
434
435# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    When the number of data points is odd, return the middle data point.
    When the number of data points is even, the median is interpolated by
    taking the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    Raises StatisticsError if *data* is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if not size:
        raise StatisticsError("no median for empty data")
    mid = size // 2
    if size % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2
458
459
def median_low(data):
    """Return the low median of numeric data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the smaller of the two middle values is returned.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    Raises StatisticsError if *data* is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    # (size - 1) // 2 is the middle index for odd sizes and the lower of
    # the two middle indices for even sizes.
    return ordered[(size - 1) // 2]
480
481
def median_high(data):
    """Return the high median of data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the larger of the two middle values is returned.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    Raises StatisticsError if *data* is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    # size // 2 is the middle index for odd sizes and the upper of the two
    # middle indices for even sizes.
    return ordered[len(ordered) // 2]
499
500
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.  Raises StatisticsError for empty data and
    TypeError if the middle value or ``interval`` is a str or bytes.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    if size == 1:
        return ordered[0]
    # The value at the midpoint is the centre of the median's class interval.
    midpoint = ordered[size // 2]
    for candidate in (midpoint, interval):
        if isinstance(candidate, (str, bytes)):
            raise TypeError('expected number but got %r' % candidate)
    try:
        lower_limit = midpoint - interval / 2
    except TypeError:
        # Mixed types: fall back to float arithmetic.
        lower_limit = float(midpoint) - float(interval) / 2

    # Two O(log n) bisection searches: everything before `below` is smaller
    # than the midpoint (the cumulative frequency), and positions
    # below..top hold the run of values equal to it (its frequency).
    below = _find_lteq(ordered, midpoint)
    top = _find_rteq(ordered, below, midpoint)
    frequency = top - below + 1
    return lower_limit + interval * (size / 2 - below) / frequency
554
555
def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` assumes discrete data, and returns a single value. This is the
    standard treatment of the mode as commonly taught in schools:

        >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
        3

    This also works with nominal (non-numeric) data:

        >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
        'red'

    If there are multiple modes with same frequency, return the first one
    encountered:

        >>> mode(['red', 'red', 'green', 'blue', 'blue'])
        'red'

    If *data* is empty, ``mode`` raises StatisticsError.

    """
    # iter() makes a mapping count its keys instead of being read as
    # ready-made counts by Counter.
    ranked = Counter(iter(data)).most_common(1)
    try:
        winner, _frequency = ranked[0]
    except IndexError:
        raise StatisticsError('no mode for empty data') from None
    return winner
584
585
def multimode(data):
    """Return a list of the most frequently occurring values.

    Will return more than one result if there are multiple modes
    or an empty list if *data* is empty.  Tied modes appear in the order
    they were first encountered.

    >>> multimode('aabbbbbbbbcc')
    ['b']
    >>> multimode('aabbbbccddddeeffffgg')
    ['b', 'd', 'f']
    >>> multimode('')
    []
    """
    tallies = Counter(iter(data))
    if not tallies:
        return []
    top = max(tallies.values())
    # Counter preserves first-seen order, so ties keep encounter order.
    return [value for value, count in tallies.items() if count == top]
602
603
604# Notes on methods for computing quantiles
605# ----------------------------------------
606#
607# There is no one perfect way to compute quantiles.  Here we offer
608# two methods that serve common needs.  Most other packages
609# surveyed offered at least one or both of these two, making them
610# "standard" in the sense of "widely-adopted and reproducible".
611# They are also easy to explain, easy to compute manually, and have
612# straight-forward interpretations that aren't surprising.
613
614# The default method is known as "R6", "PERCENTILE.EXC", or "expected
615# value of rank order statistics". The alternative method is known as
616# "R7", "PERCENTILE.INC", or "mode of rank order statistics".
617
618# For sample data where there is a positive probability for values
619# beyond the range of the data, the R6 exclusive method is a
620# reasonable choice.  Consider a random sample of nine values from a
621# population with a uniform distribution from 0.0 to 1.0.  The
622# distribution of the third ranked sample point is described by
623# betavariate(alpha=3, beta=7) which has mode=0.250, median=0.286, and
624# mean=0.300.  Only the latter (which corresponds with R6) gives the
625# desired cut point with 30% of the population falling below that
626# value, making it comparable to a result from an inv_cdf() function.
627# The R6 exclusive method is also idempotent.
628
629# For describing population data where the end points are known to
630# be included in the data, the R7 inclusive method is a reasonable
631# choice.  Instead of the mean, it uses the mode of the beta
632# distribution for the interior points.  Per Hyndman & Fan, "One nice
633# property is that the vertices of Q7(p) divide the range into n - 1
634# intervals, and exactly 100p% of the intervals lie to the left of
635# Q7(p) and 100(1 - p)% of the intervals lie to the right of Q7(p)."
636
637# If needed, other methods could be added.  However, for now, the
638# position is that fewer options make for easier choices and that
639# external packages can be used for anything more advanced.
640
def quantiles(data, *, n=4, method='exclusive'):
    """Divide *data* into *n* continuous intervals with equal probability.

    Returns a list of (n - 1) cut points separating the intervals.

    Set *n* to 4 for quartiles (the default).  Set *n* to 10 for deciles.
    Set *n* to 100 for percentiles, which gives the 99 cut points that
    separate *data* into 100 equal-sized groups.

    The *data* can be any iterable containing sample data.
    The cut points are linearly interpolated between data points.

    If *method* is set to *inclusive*, *data* is treated as population
    data.  The minimum value is treated as the 0th percentile and the
    maximum value is treated as the 100th percentile.

    Raises StatisticsError if *n* is less than 1 or if *data* has fewer
    than two points, and ValueError if *method* is neither 'exclusive'
    nor 'inclusive'.
    """
    if n < 1:
        raise StatisticsError('n must be at least 1')
    data = sorted(data)
    ld = len(data)
    if ld < 2:
        raise StatisticsError('must have at least two data points')
    if method == 'inclusive':
        # Positions are spread over the ld - 1 gaps between sorted points,
        # so the minimum and maximum are the 0th and 100th percentiles.
        m = ld - 1
        result = []
        for i in range(1, n):
            # Position i*m/n splits into an index j and a remainder delta
            # (in units of 1/n) that weights the linear interpolation.
            j, delta = divmod(i * m, n)
            interpolated = (data[j] * (n - delta) + data[j + 1] * delta) / n
            result.append(interpolated)
        return result
    if method == 'exclusive':
        # Positions are spread over ld + 1 slots.  After clamping j, delta
        # may fall outside 0..n, which extrapolates linearly past the
        # smallest or largest data point.
        m = ld + 1
        result = []
        for i in range(1, n):
            j = i * m // n                               # rescale i to m/n
            j = 1 if j < 1 else ld-1 if j > ld-1 else j  # clamp to 1 .. ld-1
            delta = i*m - j*n                            # exact integer math
            interpolated = (data[j - 1] * (n - delta) + data[j] * delta) / n
            result.append(interpolated)
        return result
    raise ValueError(f'Unknown method: {method!r}')
682
683
684# === Measures of spread ===
685
686# See http://mathworld.wolfram.com/Variance.html
687#     http://mathworld.wolfram.com/SampleVariance.html
688#     http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
689#
690# Under no circumstances use the so-called "computational formula for
691# variance", as that is only suitable for hand calculations with a small
692# amount of low-precision data. It has terrible numeric properties.
693#
694# See a comparison of three computational methods here:
695# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
696
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    Returns a ``(T, total)`` pair: ``T`` is the numeric type chosen by
    _sum() for converting results back, and ``total`` is the sum of squared
    deviations, normally as an exact Fraction.

    If ``c`` is None, the mean is calculated in one pass, and the deviations
    from the mean are calculated in a second pass. Otherwise, deviations are
    calculated from ``c`` as given. Use the second case with care, as it can
    lead to garbage results.

    If the data contain a NAN or an infinity (and ``c`` is None), ``total``
    is a non-finite value of the data's own float/Decimal type:
    - NAN when the non-finite values sum to NAN;
    - positive infinity otherwise.
    """
    if c is not None:
        T, total, count = _sum((x-c)**2 for x in data)
        return (T, total)
    T, total, count = _sum(data)
    try:
        # Exact mean as an integer ratio, so that every deviation below is
        # computed exactly in integer arithmetic.
        mean_n, mean_d = (total / count).as_integer_ratio()
    except (OverflowError, ValueError):
        # _sum() returned a float/Decimal INF or NAN instead of a Fraction;
        # such values have no integer ratio, so exact deviations cannot be
        # formed.  Squaring keeps NAN a NAN and maps an infinity of either
        # sign to +INF, so the sum of squares is never negative.
        assert not _isfinite(total)
        return (T, total * total)
    partials = Counter()
    for n, d in map(_exact_ratio, data):
        diff_n = n * mean_d - d * mean_n
        diff_d = d * mean_d
        partials[diff_d * diff_d] += diff_n * diff_n
    total = sum(Fraction(n, d) for d, n in partials.items())
    return (T, total)
723
724
def variance(data, xbar=None):
    """Return the sample variance of data.

    data should be an iterable of Real-valued numbers, with at least two
    values. The optional argument xbar, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function when your data is a sample from a population. To
    calculate the variance from the entire population, see ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    If you have already calculated the mean of your data, you can pass it as
    the optional second argument ``xbar`` to avoid recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    This function does not check that ``xbar`` is actually the mean of
    ``data``. Giving arbitrary values for ``xbar`` may lead to invalid or
    impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    """
    # One-shot iterators are materialized: len() is needed and _ss() walks
    # the data more than once.
    points = list(data) if iter(data) is data else data
    size = len(points)
    if size < 2:
        raise StatisticsError('variance requires at least two data points')
    kind, sum_sq = _ss(points, xbar)
    # Sample variance divides by n - 1 (Bessel's correction).
    return _convert(sum_sq / (size - 1), kind)
770
771
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    data should be a sequence or iterable of Real-valued numbers, with at least one
    value. The optional argument mu, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function to calculate the variance from the entire population.
    To estimate the variance from a sample, the ``variance`` function is
    usually a better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    If you have already calculated the mean of the data, you can pass it as
    the optional second argument to avoid recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    """
    # One-shot iterators are materialized: len() is needed and _ss() walks
    # the data more than once.
    points = list(data) if iter(data) is data else data
    size = len(points)
    if size < 1:
        raise StatisticsError('pvariance requires at least one data point')
    kind, sum_sq = _ss(points, mu)
    # Population variance divides by n itself.
    return _convert(sum_sq / size, kind)
814
815
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    See ``variance`` for arguments and other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    # Fixme: Despite the exact sum of squared deviations, some inaccuracy
    # remains because there are two rounding steps.  The first occurs in
    # the _convert() step for variance(), the second occurs in math.sqrt().
    sample_var = variance(data, xbar)
    try:
        # Types with their own sqrt() method (e.g. Decimal) use it.
        return sample_var.sqrt()
    except AttributeError:
        return math.sqrt(sample_var)
833
834
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    See ``pvariance`` for arguments and other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    # Fixme: Despite the exact sum of squared deviations, some inaccuracy
    # remains because there are two rounding steps.  The first occurs in
    # the _convert() step for pvariance(), the second occurs in math.sqrt().
    population_var = pvariance(data, mu)
    try:
        # Types with their own sqrt() method (e.g. Decimal) use it.
        return population_var.sqrt()
    except AttributeError:
        return math.sqrt(population_var)
852
853
854# === Statistics for relations between two inputs ===
855
856# See https://en.wikipedia.org/wiki/Covariance
857#     https://en.wikipedia.org/wiki/Pearson_correlation_coefficient
858#     https://en.wikipedia.org/wiki/Simple_linear_regression
859
860
def covariance(x, y, /):
    """Covariance

    Return the sample covariance of two inputs *x* and *y*. Covariance
    is a measure of the joint variability of two inputs.

    >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> y = [1, 2, 3, 1, 2, 3, 1, 2, 3]
    >>> covariance(x, y)
    0.75
    >>> z = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    >>> covariance(x, z)
    -7.5
    >>> covariance(z, x)
    -7.5

    Both inputs must be sequences of equal length with at least two points,
    otherwise StatisticsError is raised.  The result is always a float.
    """
    size = len(x)
    if len(y) != size:
        raise StatisticsError('covariance requires that both inputs have same number of data points')
    if size < 2:
        raise StatisticsError('covariance requires at least two data points')
    x_mean = fsum(x) / size
    y_mean = fsum(y) / size
    # Cross-products of deviations, summed with fsum() to limit round-off;
    # dividing by n - 1 gives the sample (not population) covariance.
    cross_products = (
        (x_val - x_mean) * (y_val - y_mean) for x_val, y_val in zip(x, y)
    )
    return fsum(cross_products) / (size - 1)
887
888
def correlation(x, y, /):
    """Pearson's correlation coefficient

    Return the Pearson's correlation coefficient for two inputs. Pearson's
    correlation coefficient *r* takes values between -1 and +1. It measures the
    strength and direction of the linear relationship, where +1 means very
    strong, positive linear relationship, -1 very strong, negative linear
    relationship, and 0 no linear relationship.

    >>> x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    >>> y = [9, 8, 7, 6, 5, 4, 3, 2, 1]
    >>> correlation(x, x)
    1.0
    >>> correlation(x, y)
    -1.0

    Both inputs must be sequences of the same length with at least two
    points; otherwise StatisticsError is raised.  StatisticsError is also
    raised if either input is constant.
    """
    n = len(x)
    if len(y) != n:
        raise StatisticsError('correlation requires that both inputs have same number of data points')
    if n < 2:
        raise StatisticsError('correlation requires at least two data points')
    xbar = fsum(x) / n
    ybar = fsum(y) / n
    sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
    sxx = fsum((xi - xbar) ** 2.0 for xi in x)
    syy = fsum((yi - ybar) ** 2.0 for yi in y)
    # sxx * syy can overflow to INF (silently giving r == 0.0) or underflow
    # to zero/subnormal (a spurious "constant" error or lost precision) even
    # when sxx and syy are ordinary floats, e.g. for data around 1e100 or
    # 1e-100.  Use the single product while it is a normal float; otherwise
    # take the two square roots separately.
    product = sxx * syy
    if math.isfinite(product) and product >= 2.0 ** -1022:  # sys.float_info.min
        denominator = sqrt(product)
    else:
        denominator = sqrt(sxx) * sqrt(syy)
    try:
        return sxy / denominator
    except ZeroDivisionError:
        raise StatisticsError('at least one of the inputs is constant') from None
920
921
# Result type returned by linear_regression(): fields are (slope, intercept).
LinearRegression = namedtuple('LinearRegression', 'slope intercept')
923
924
925def linear_regression(x, y, /):
926    """Slope and intercept for simple linear regression.
927
928    Return the slope and intercept of simple linear regression
929    parameters estimated using ordinary least squares. Simple linear
930    regression describes relationship between an independent variable
931    *x* and a dependent variable *y* in terms of linear function:
932
933        y = slope * x + intercept + noise
934
935    where *slope* and *intercept* are the regression parameters that are
936    estimated, and noise represents the variability of the data that was
937    not explained by the linear regression (it is equal to the
938    difference between predicted and actual values of the dependent
939    variable).
940
941    The parameters are returned as a named tuple.
942
943    >>> x = [1, 2, 3, 4, 5]
944    >>> noise = NormalDist().samples(5, seed=42)
945    >>> y = [3 * x[i] + 2 + noise[i] for i in range(5)]
946    >>> linear_regression(x, y)  #doctest: +ELLIPSIS
947    LinearRegression(slope=3.09078914170..., intercept=1.75684970486...)
948
949    """
950    n = len(x)
951    if len(y) != n:
952        raise StatisticsError('linear regression requires that both inputs have same number of data points')
953    if n < 2:
954        raise StatisticsError('linear regression requires at least two data points')
955    xbar = fsum(x) / n
956    ybar = fsum(y) / n
957    sxy = fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
958    sxx = fsum((xi - xbar) ** 2.0 for xi in x)
959    try:
960        slope = sxy / sxx   # equivalent to:  covariance(x, y) / variance(x)
961    except ZeroDivisionError:
962        raise StatisticsError('x is constant')
963    intercept = ybar - slope * xbar
964    return LinearRegression(slope=slope, intercept=intercept)
965
966
967## Normal Distribution #####################################################
968
969
def _normal_dist_inv_cdf(p, mu, sigma):
    """Return x such that P(X <= x) == p for X ~ Normal(mu, sigma).

    Pure-Python implementation; later in this module it may be rebound to
    a C implementation imported from ``_statistics`` when that is available.

    No argument checking happens here.  NormalDist.inv_cdf() checks that
    0.0 < p < 1.0 and that sigma is positive before calling this.  Outside
    that range, log() below would fail (p == 0 or p == 1).
    """
    # There is no closed-form solution to the inverse CDF for the normal
    # distribution, so we use a rational approximation instead:
    # Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
    # Normal Distribution".  Applied Statistics. Blackwell Publishing. 37
    # (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
    q = p - 0.5
    # Central region, |p - 0.5| <= 0.425: a rational function in
    # r = 0.425**2 - q**2 (0.180625 == 0.425 * 0.425), scaled by q.
    if fabs(q) <= 0.425:
        r = 0.180625 - q * q
        # NOTE(review): the "Hash sum" lines look like checksums for
        # verifying the transcribed coefficients against AS241 — confirm.
        # Hash sum: 55.88319_28806_14901_4439
        num = (((((((2.50908_09287_30122_6727e+3 * r +
                     3.34305_75583_58812_8105e+4) * r +
                     6.72657_70927_00870_0853e+4) * r +
                     4.59219_53931_54987_1457e+4) * r +
                     1.37316_93765_50946_1125e+4) * r +
                     1.97159_09503_06551_4427e+3) * r +
                     1.33141_66789_17843_7745e+2) * r +
                     3.38713_28727_96366_6080e+0) * q
        den = (((((((5.22649_52788_52854_5610e+3 * r +
                     2.87290_85735_72194_2674e+4) * r +
                     3.93078_95800_09271_0610e+4) * r +
                     2.12137_94301_58659_5867e+4) * r +
                     5.39419_60214_24751_1077e+3) * r +
                     6.87187_00749_20579_0830e+2) * r +
                     4.23133_30701_60091_1252e+1) * r +
                     1.0)
        x = num / den
        return mu + (x * sigma)
    # Tails: work with the smaller tail probability min(p, 1 - p) and
    # transform it to r = sqrt(-log(tail probability)).
    r = p if q <= 0.0 else 1.0 - p
    r = sqrt(-log(r))
    if r <= 5.0:
        # Intermediate tail: polynomial variable shifted by 1.6.
        r = r - 1.6
        # Hash sum: 49.33206_50330_16102_89036
        num = (((((((7.74545_01427_83414_07640e-4 * r +
                     2.27238_44989_26918_45833e-2) * r +
                     2.41780_72517_74506_11770e-1) * r +
                     1.27045_82524_52368_38258e+0) * r +
                     3.64784_83247_63204_60504e+0) * r +
                     5.76949_72214_60691_40550e+0) * r +
                     4.63033_78461_56545_29590e+0) * r +
                     1.42343_71107_49683_57734e+0)
        den = (((((((1.05075_00716_44416_84324e-9 * r +
                     5.47593_80849_95344_94600e-4) * r +
                     1.51986_66563_61645_71966e-2) * r +
                     1.48103_97642_74800_74590e-1) * r +
                     6.89767_33498_51000_04550e-1) * r +
                     1.67638_48301_83803_84940e+0) * r +
                     2.05319_16266_37758_82187e+0) * r +
                     1.0)
    else:
        # Far tail: polynomial variable shifted by 5.0.
        r = r - 5.0
        # Hash sum: 47.52583_31754_92896_71629
        num = (((((((2.01033_43992_92288_13265e-7 * r +
                     2.71155_55687_43487_57815e-5) * r +
                     1.24266_09473_88078_43860e-3) * r +
                     2.65321_89526_57612_30930e-2) * r +
                     2.96560_57182_85048_91230e-1) * r +
                     1.78482_65399_17291_33580e+0) * r +
                     5.46378_49111_64114_36990e+0) * r +
                     6.65790_46435_01103_77720e+0)
        den = (((((((2.04426_31033_89939_78564e-15 * r +
                     1.42151_17583_16445_88870e-7) * r +
                     1.84631_83175_10054_68180e-5) * r +
                     7.86869_13114_56132_59100e-4) * r +
                     1.48753_61290_85061_48525e-2) * r +
                     1.36929_88092_27358_05310e-1) * r +
                     5.99832_20655_58879_37690e-1) * r +
                     1.0)
    # The tail approximation yields a non-negative deviate; mirror it for
    # the lower tail (p < 0.5), then scale and shift to Normal(mu, sigma).
    x = num / den
    if q < 0.0:
        x = -x
    return mu + (x * sigma)
1042
1043
# If available, use C implementation: a successful import rebinds the
# module-level name _normal_dist_inv_cdf, replacing the pure-Python version
# defined above.  If the _statistics extension cannot be imported, the
# pure-Python version stays in place.
try:
    from _statistics import _normal_dist_inv_cdf
except ImportError:
    pass  # no C accelerator: keep the pure-Python fallback
1049
1050
class NormalDist:
    """Normal distribution of a random variable.

    A distribution is described by its arithmetic mean *mu* and its
    standard deviation *sigma*.  Instances support arithmetic with
    constants (translation and rescaling) and addition/subtraction with
    other NormalDist instances (valid for independent variables).
    """
    # https://en.wikipedia.org/wiki/Normal_distribution
    # https://en.wikipedia.org/wiki/Variance#Properties

    # Mapping form of __slots__: each value becomes the docstring of the
    # corresponding slot descriptor.
    __slots__ = {
        '_mu': 'Arithmetic mean of a normal distribution',
        '_sigma': 'Standard deviation of a normal distribution',
    }

    def __init__(self, mu=0.0, sigma=1.0):
        """NormalDist where mu is the mean and sigma is the standard deviation.

        Both parameters are stored as floats.  Raises StatisticsError if
        *sigma* is negative.
        """
        if sigma < 0.0:
            raise StatisticsError('sigma must be non-negative')
        self._mu = float(mu)
        self._sigma = float(sigma)

    @classmethod
    def from_samples(cls, data):
        """Make a normal distribution instance from sample data.

        The mean is estimated with fmean() and the standard deviation with
        stdev().  *data* may be any iterable.
        """
        # fmean() and stdev() both traverse the data, so one-shot
        # iterables are materialized into a list first.
        if not isinstance(data, (list, tuple)):
            data = list(data)
        xbar = fmean(data)
        # Pass the already computed mean on to stdev().
        return cls(xbar, stdev(data, xbar))

    def samples(self, n, *, seed=None):
        """Generate *n* samples for a given mean and standard deviation.

        With *seed* left as None, the shared generator of the random module
        is used.  Otherwise a dedicated random.Random(seed) instance is
        created, so equal seeds give equal samples.
        """
        gauss = random.gauss if seed is None else random.Random(seed).gauss
        mu, sigma = self._mu, self._sigma
        return [gauss(mu, sigma) for _ in range(n)]

    def pdf(self, x):
        """Probability density function.  P(x <= X < x+dx) / dx

        Raises StatisticsError when sigma is zero.
        """
        # Square by multiplication rather than ``** 2.0``.  Float
        # exponentiation raises OverflowError on overflow, whereas a product
        # saturates to inf, so densities far out in the tails become 0.0.
        variance = self._sigma * self._sigma
        if not variance:
            raise StatisticsError('pdf() not defined when sigma is zero')
        diff = x - self._mu
        return exp(diff * diff / (-2.0 * variance)) / sqrt(tau * variance)

    def cdf(self, x):
        """Cumulative distribution function.  P(X <= x)

        Raises StatisticsError when sigma is zero.
        """
        if not self._sigma:
            raise StatisticsError('cdf() not defined when sigma is zero')
        return 0.5 * (1.0 + erf((x - self._mu) / (self._sigma * sqrt(2.0))))

    def inv_cdf(self, p):
        """Inverse cumulative distribution function.  x : P(X <= x) = p

        Finds the value of the random variable such that the probability of
        the variable being less than or equal to that value equals the given
        probability.

        This function is also called the percent point function or quantile
        function.

        Raises StatisticsError unless 0.0 < p < 1.0, and when sigma is zero.
        """
        if p <= 0.0 or p >= 1.0:
            raise StatisticsError('p must be in the range 0.0 < p < 1.0')
        # __init__ rejects negative sigma, so this only catches sigma == 0.
        if self._sigma <= 0.0:
            raise StatisticsError('inv_cdf() not defined when sigma is zero')
        return _normal_dist_inv_cdf(p, self._mu, self._sigma)

    def quantiles(self, n=4):
        """Divide into *n* continuous intervals with equal probability.

        Returns a list of (n - 1) cut points separating the intervals.

        Set *n* to 4 for quartiles (the default).  Set *n* to 10 for deciles.
        Set *n* to 100 for percentiles, which gives the 99 cut points that
        separate the normal distribution into 100 equal sized groups.
        """
        return [self.inv_cdf(i / n) for i in range(1, n)]

    def overlap(self, other):
        """Compute the overlapping coefficient (OVL) between two normal distributions.

        Measures the agreement between two normal probability distributions.
        Returns a value between 0.0 and 1.0 giving the overlapping area in
        the two underlying probability density functions.

            >>> N1 = NormalDist(2.4, 1.6)
            >>> N2 = NormalDist(3.2, 2.0)
            >>> N1.overlap(N2)
            0.8035050657330205

        Raises TypeError if *other* is not a NormalDist, and StatisticsError
        if either sigma is zero.
        """
        # See: "The overlapping coefficient as a measure of agreement between
        # probability distributions and point estimation of the overlap of two
        # normal densities" -- Henry F. Inman and Edwin L. Bradley Jr
        # http://dx.doi.org/10.1080/03610928908830127
        if not isinstance(other, NormalDist):
            raise TypeError('Expected another NormalDist instance')
        X, Y = self, other
        if (Y._sigma, Y._mu) < (X._sigma, X._mu):  # sort to assure commutativity
            X, Y = Y, X
        X_var, Y_var = X.variance, Y.variance
        if not X_var or not Y_var:
            raise StatisticsError('overlap() not defined when sigma is zero')
        dv = Y_var - X_var
        dm = fabs(Y._mu - X._mu)
        if not dv:
            # Equal variances: OVL = 1 - erf(z) = erfc(z).  erfc() keeps full
            # precision for well-separated means, where 1 - erf(z) rounds
            # to exactly 0.0.
            from math import erfc
            return erfc(dm / (2.0 * X._sigma * sqrt(2.0)))
        a = X._mu * Y_var - Y._mu * X_var
        b = X._sigma * Y._sigma * sqrt(dm * dm + dv * log(Y_var / X_var))
        # x1 and x2 are the points where the two densities intersect.
        x1 = (a + b) / dv
        x2 = (a - b) / dv
        return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))

    def zscore(self, x):
        """Compute the Standard Score.  (x - mean) / stdev

        Describes *x* in terms of the number of standard deviations
        above or below the mean of the normal distribution.
        Raises StatisticsError when sigma is zero.
        """
        # https://www.statisticshowto.com/probability-and-statistics/z-score/
        if not self._sigma:
            raise StatisticsError('zscore() not defined when sigma is zero')
        return (x - self._mu) / self._sigma

    @property
    def mean(self):
        "Arithmetic mean of the normal distribution."
        return self._mu

    @property
    def median(self):
        "Return the median of the normal distribution"
        return self._mu

    @property
    def mode(self):
        """Return the mode of the normal distribution

        The mode is the value x at which the probability density
        function (pdf) takes its maximum value.
        """
        return self._mu

    @property
    def stdev(self):
        "Standard deviation of the normal distribution."
        return self._sigma

    @property
    def variance(self):
        "Square of the standard deviation."
        # A product (not ``** 2.0``) saturates to inf for huge sigma instead
        # of raising OverflowError.
        return self._sigma * self._sigma

    def __add__(x1, x2):
        """Add a constant or another NormalDist instance.

        If *x2* is a constant, translate mu by the constant,
        leaving sigma unchanged.

        If *x2* is a NormalDist, add both the means and the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        """
        if isinstance(x2, NormalDist):
            # hypot() gives sqrt(var1 + var2) without intermediate overflow.
            return NormalDist(x1._mu + x2._mu, hypot(x1._sigma, x2._sigma))
        return NormalDist(x1._mu + x2, x1._sigma)

    def __sub__(x1, x2):
        """Subtract a constant or another NormalDist instance.

        If *x2* is a constant, translate mu by the constant,
        leaving sigma unchanged.

        If *x2* is a NormalDist, subtract the means and add the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        """
        if isinstance(x2, NormalDist):
            return NormalDist(x1._mu - x2._mu, hypot(x1._sigma, x2._sigma))
        return NormalDist(x1._mu - x2, x1._sigma)

    def __mul__(x1, x2):
        """Multiply both mu and sigma by a constant.

        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        """
        return NormalDist(x1._mu * x2, x1._sigma * fabs(x2))

    def __truediv__(x1, x2):
        """Divide both mu and sigma by a constant.

        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        """
        return NormalDist(x1._mu / x2, x1._sigma / fabs(x2))

    def __pos__(x1):
        "Return a copy of the instance."
        return NormalDist(x1._mu, x1._sigma)

    def __neg__(x1):
        "Negates mu while keeping sigma the same."
        return NormalDist(-x1._mu, x1._sigma)

    __radd__ = __add__

    def __rsub__(x1, x2):
        "Subtract a NormalDist from a constant or another NormalDist."
        # x2 - x1 == -(x1 - x2)
        return -(x1 - x2)

    __rmul__ = __mul__

    def __eq__(x1, x2):
        "Two NormalDist objects are equal if their mu and sigma are both equal."
        if not isinstance(x2, NormalDist):
            return NotImplemented
        return x1._mu == x2._mu and x1._sigma == x2._sigma

    def __hash__(self):
        "NormalDist objects hash equal if their mu and sigma are both equal."
        return hash((self._mu, self._sigma))

    def __repr__(self):
        return f'{type(self).__name__}(mu={self._mu!r}, sigma={self._sigma!r})'
1268