• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Test suite for statistics module, including helper NumericTestCase and
2approx_equal function.
3
4"""
5
6import bisect
7import collections
8import collections.abc
9import copy
10import decimal
11import doctest
12import math
13import pickle
14import random
15import sys
16import unittest
17from test import support
18
19from decimal import Decimal
20from fractions import Fraction
21from test import support
22
23
24# Module to be tested.
25import statistics
26
27
28# === Helper functions and class ===
29
def sign(x):
    """Return the sign of x as a float, either -1.0 or +1.0.

    The sign bit decides the result, so negative zero gives -1.0 and
    positive zero gives +1.0.
    """
    return math.copysign(1.0, x)
33
def _nan_equal(a, b):
    """Report whether a and b are NANs of the same type and kind.

    Floats only need to both be NAN. Decimals must agree on whether they
    are quiet NANs or signalling NANs. Values of different types never
    match.

    >>> _nan_equal(Decimal('NAN'), Decimal('NAN'))
    True
    >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN'))
    True
    >>> _nan_equal(Decimal('NAN'), Decimal('sNAN'))
    False
    >>> _nan_equal(Decimal(42), Decimal('NAN'))
    False

    >>> _nan_equal(float('NAN'), float('NAN'))
    True
    >>> _nan_equal(float('NAN'), 0.5)
    False

    >>> _nan_equal(float('NAN'), Decimal('NAN'))
    False

    NAN payloads are not compared.
    """
    if type(a) is type(b):
        if isinstance(a, float):
            return math.isnan(a) and math.isnan(b)
        # For Decimals the exponent field of as_tuple() is 'n' for a quiet
        # NAN and 'N' for a signalling NAN; it must match and be one of those.
        a_kind = a.as_tuple()[2]
        b_kind = b.as_tuple()[2]
        return (a_kind == b_kind) and (a_kind in ('n', 'N'))
    return False
63
64
def _calc_errors(actual, expected):
    """Compute the (absolute error, relative error) of actual vs expected.

    The relative error is the absolute error divided by the larger of the
    two magnitudes. When both values are zero it is reported as infinity.

    >>> _calc_errors(100, 75)
    (25, 0.25)
    >>> _calc_errors(100, 100)
    (0, 0.0)
    """
    scale = max(abs(actual), abs(expected))
    abs_err = abs(actual - expected)
    if not scale:
        return (abs_err, float('inf'))
    return (abs_err, abs_err/scale)
79
80
def approx_equal(x, y, tol=1e-12, rel=1e-7):
    """approx_equal(x, y [, tol [, rel]]) => True|False

    Return True if numbers x and y are approximately equal, to within some
    margin of error, otherwise return False. Numbers which compare equal
    will also compare approximately equal.

    x is approximately equal to y if the difference between them is less than
    or equal to an absolute error tol or a relative error rel, whichever is
    bigger.

    If given, both tol and rel should be finite, non-negative numbers; a
    negative value raises ValueError. If not given, default values are
    tol=1e-12 and rel=1e-7.

    >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0)
    True
    >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0)
    False

    Absolute error is defined as abs(x-y); if that is less than or equal to
    tol, x and y are considered approximately equal.

    Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is
    smaller (that is, abs(x-y) divided by the larger of abs(x) and abs(y)),
    provided x or y is not zero. If that figure is less than or equal to
    rel, x and y are considered approximately equal.

    Complex numbers are not directly supported. If you wish to compare two
    complex numbers, extract their real and imaginary parts and compare them
    individually.

    NANs always compare unequal, even with themselves. Infinities compare
    approximately equal if they have the same sign (both positive or both
    negative). Infinities with different signs compare unequal; so do
    comparisons of infinities with finite numbers.
    """
    if tol < 0 or rel < 0:
        raise ValueError('error tolerances must be non-negative')
    # NANs are never equal to anything, approximately or otherwise.
    if math.isnan(x) or math.isnan(y):
        return False
    # Numbers which compare equal also compare approximately equal.
    if x == y:
        # This includes the case of two infinities with the same sign.
        return True
    if math.isinf(x) or math.isinf(y):
        # This includes the case of two infinities of opposite sign, or
        # one infinity and one finite number.
        return False
    # Two finite numbers. Scaling rel by the larger magnitude makes the
    # relative test symmetric in x and y.
    actual_error = abs(x - y)
    allowed_error = max(tol, rel*max(abs(x), abs(y)))
    return actual_error <= allowed_error
132
133
134# This class exists only as somewhere to stick a docstring containing
135# doctests. The following docstring and tests were originally in a separate
# module. Now that it has been merged in here, I need somewhere to hang the
137# docstring. Ultimately, this class will die, and the information below will
138# either become redundant, or be moved into more appropriate places.
class _DoNothing:
    """
    Exact equality is rarely the right test for numeric results, and this
    is especially true of floats: round-off error means two computations of
    the "same" quantity often differ slightly. The usual remedy is to
    compare them while allowing some (hopefully small!) error.

    ``approx_equal`` lets you specify an absolute error tolerance, a
    relative error, or both.

    An absolute tolerance is simple, but you need to know how large the
    quantities being compared are:

    >>> approx_equal(12.345, 12.346, tol=1e-3)
    True
    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3)  # tol is too small.
    False

    A relative error adapts to the magnitude of the values, which makes it
    a better choice when that magnitude can vary:

    >>> approx_equal(12.345, 12.346, rel=1e-4)
    True
    >>> approx_equal(12.345e6, 12.346e6, rel=1e-4)
    True

    A naive relative error test, however, can misbehave near zero.

    When an absolute tolerance and a relative error are both given, the
    comparison succeeds if either test succeeds on its own:

    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4)
    True

    """
176
177
178
# For numeric values that may not be exactly equal, these tests use
# approx_equal() and NumericTestCase.assertApproxEqual() rather than
# TestCase.assertAlmostEqual.

# Two fresh copies of the statistics module. py_statistics is imported with
# _statistics blocked; c_statistics is imported with _statistics loaded
# fresh. TestModules checks where their functions come from and skips the
# C check when c_statistics is falsy.
py_statistics, c_statistics = (
    support.import_fresh_module('statistics', blocked=['_statistics']),
    support.import_fresh_module('statistics', fresh=['_statistics']),
)
184
185
class TestModules(unittest.TestCase):
    """Check which module supplies each accelerated function."""

    func_names = ['_normal_dist_inv_cdf']

    def test_py_functions(self):
        # With _statistics blocked, the functions must come from the
        # pure-Python 'statistics' module.
        for name in self.func_names:
            func = getattr(py_statistics, name)
            self.assertEqual(func.__module__, 'statistics')

    @unittest.skipUnless(c_statistics, 'requires _statistics')
    def test_c_functions(self):
        # With _statistics available, the functions must come from it.
        for name in self.func_names:
            func = getattr(c_statistics, name)
            self.assertEqual(func.__module__, '_statistics')
197
198
class NumericTestCase(unittest.TestCase):
    """TestCase subclass with extra support for numeric comparisons.

    Besides the inherited ``TestCase.assertAlmostEqual``, this class adds
    ``assertApproxEqual``, which is built on the ``approx_equal`` helper.
    """
    # Tolerances used when assertApproxEqual is called without explicit
    # values. Zero means exact equality is required unless a subclass or
    # the caller overrides them.
    tol = rel = 0

    def assertApproxEqual(
            self, first, second, tol=None, rel=None, msg=None
            ):
        """Test passes if ``first`` and ``second`` are approximately equal.

        The values must agree to within ``tol``, an absolute error, or
        ``rel``, a relative error.

        A ``tol`` or ``rel`` that is None (or omitted) falls back to the
        test attribute of the same name (0 unless overridden).

        Both arguments may be numbers, or both may be sequences of numbers.
        Sequences are compared element by element.

        >>> class MyTest(NumericTestCase):
        ...     def test_number(self):
        ...         x = 1.0/6
        ...         y = sum([x]*6)
        ...         self.assertApproxEqual(y, 1.0, tol=1e-15)
        ...     def test_sequence(self):
        ...         a = [1.001, 1.001e-10, 1.001e10]
        ...         b = [1.0, 1e-10, 1e10]
        ...         self.assertApproxEqual(a, b, rel=1e-3)
        ...
        >>> import unittest
        >>> from io import StringIO  # Suppress test runner output.
        >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest)
        >>> unittest.TextTestRunner(stream=StringIO()).run(suite)
        <unittest.runner.TextTestResult run=2 errors=0 failures=0>

        """
        tol = self.tol if tol is None else tol
        rel = self.rel if rel is None else rel
        both_sequences = (
            isinstance(first, collections.abc.Sequence)
            and isinstance(second, collections.abc.Sequence)
        )
        if both_sequences:
            self._check_approx_seq(first, second, tol, rel, msg)
        else:
            self._check_approx_num(first, second, tol, rel, msg)

    def _check_approx_seq(self, first, second, tol, rel, msg):
        # Sequences must have the same length and match item by item.
        n_first, n_second = len(first), len(second)
        if n_first != n_second:
            standardMsg = (
                "sequences differ in length: %d items != %d items"
                % (n_first, n_second)
                )
            raise self.failureException(self._formatMessage(msg, standardMsg))
        for index, pair in enumerate(zip(first, second)):
            actual, expected = pair
            self._check_approx_num(actual, expected, tol, rel, msg, index)

    def _check_approx_num(self, first, second, tol, rel, msg, idx=None):
        # Raise failureException with a detailed message when the two
        # numbers are not approx_equal; idx, if given, is the sequence index.
        if not approx_equal(first, second, tol, rel):
            standardMsg = self._make_std_err_msg(first, second, tol, rel, idx)
            raise self.failureException(self._formatMessage(msg, standardMsg))
        return None

    @staticmethod
    def _make_std_err_msg(first, second, tol, rel, idx):
        # Build the standard failure message for an approx_equal mismatch,
        # prefixed with the differing index when comparing sequences.
        assert first != second
        body = (
            '  %r != %r\n'
            '  values differ by more than tol=%r and rel=%r\n'
            '  -> absolute error = %r\n'
            '  -> relative error = %r'
            )
        if idx is None:
            template = body
        else:
            template = 'numeric sequences first differ at index %d.\n' % idx + body
        abs_err, rel_err = _calc_errors(first, second)
        return template % (first, second, tol, rel, abs_err, rel_err)
288
289
290# ========================
291# === Test the helpers ===
292# ========================
293
class TestSign(unittest.TestCase):
    """Check the sign() helper defined at the top of this module."""

    def testZeroes(self):
        # Signed zeroes must report the sign carried by their sign bit.
        for value, expected in ((0.0, +1), (-0.0, -1)):
            self.assertEqual(sign(value), expected)
300
301
302# --- Tests for approx_equal ---
303
class ApproxEqualSymmetryTest(unittest.TestCase):
    """approx_equal(a, b) must give the same answer as approx_equal(b, a)."""

    def test_relative_symmetry(self):
        # Check that approx_equal treats relative error symmetrically.
        # (a-b)/a is usually not equal to (a-b)/b. Ensure that this
        # doesn't matter.
        #
        #   Note: the reason for this test is that an early version
        #   of approx_equal was not symmetric. A relative error test
        #   would pass, or fail, depending on which value was passed
        #   as the first argument.
        #
        args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)]
        args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)]
        assert len(args1) == len(args2)
        for a, b in zip(args1, args2):
            self.do_relative_symmetry(a, b)

    def do_relative_symmetry(self, a, b):
        # Use a relative tolerance halfway between delta/|a| and delta/|b|,
        # and check that both argument orders accept it.
        a, b = min(a, b), max(a, b)
        assert a < b
        delta = b - a  # The absolute difference between the values.
        rel_err1, rel_err2 = abs(delta/a), abs(delta/b)
        # Choose an error margin halfway between the two.
        rel = (rel_err1 + rel_err2)/2
        # Now see that values a and b compare approx equal regardless of
        # which is given first.
        self.assertTrue(approx_equal(a, b, tol=0, rel=rel))
        self.assertTrue(approx_equal(b, a, tol=0, rel=rel))

    def test_symmetry(self):
        # Test that approx_equal(a, b) == approx_equal(b, a)
        args = [-23, -2, 5, 107, 93568]
        delta = 2
        for a in args:
            for type_ in (int, float, Decimal, Fraction):
                x = type_(a)*100
                y = x + delta
                r = abs(delta/max(x, y))
                # There are five cases to check:
                # 1) actual error <= tol, <= rel
                self.do_symmetry_test(x, y, tol=delta, rel=r)
                self.do_symmetry_test(x, y, tol=delta+1, rel=2*r)
                # 2) actual error > tol, > rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r/2)
                # 3) actual error <= tol, > rel
                self.do_symmetry_test(x, y, tol=delta, rel=r/2)
                # 4) actual error > tol, <= rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r)
                self.do_symmetry_test(x, y, tol=delta-1, rel=2*r)
                # 5) exact equality test
                self.do_symmetry_test(x, x, tol=0, rel=0)
                self.do_symmetry_test(x, y, tol=0, rel=0)

    def do_symmetry_test(self, a, b, tol, rel):
        # Swapping the arguments must not change the result. The template
        # is filled with str.format(), so it needs a {!r} placeholder; a
        # %r placeholder would be left untouched and the failure message
        # would not show the offending arguments.
        template = "approx_equal comparisons don't match for {!r}"
        flag1 = approx_equal(a, b, tol, rel)
        flag2 = approx_equal(b, a, tol, rel)
        self.assertEqual(flag1, flag2, template.format((a, b, tol, rel)))
364
365
class ApproxEqualExactTest(unittest.TestCase):
    """approx_equal with values that are exactly equal.

    Exactly equal values must compare approximately equal whatever error
    tolerances are supplied.
    """

    def do_exactly_equal_test(self, x, tol, rel):
        # x must match itself, and so must -x.
        for value in (x, -x):
            self.assertTrue(approx_equal(value, value, tol=tol, rel=rel),
                            'equality failure for x=%r' % value)

    def test_exactly_equal_ints(self):
        # Equal ints, zero tolerances.
        ints = (42, 19740, 14974, 230, 1795, 700245, 36587)
        for n in ints:
            self.do_exactly_equal_test(n, 0, 0)

    def test_exactly_equal_floats(self):
        # Equal floats, zero tolerances.
        floats = (0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587)
        for value in floats:
            self.do_exactly_equal_test(value, 0, 0)

    def test_exactly_equal_fractions(self):
        # Equal Fractions, zero tolerances.
        pairs = [(1, 2), (0, 1), (5, 3), (9, 7), (35, 36), (3, 7)]
        for num, den in pairs:
            self.do_exactly_equal_test(Fraction(num, den), 0, 0)

    def test_exactly_equal_decimals(self):
        # Equal Decimals, zero tolerances.
        for text in "8.2 31.274 912.04 16.745 1.2047".split():
            self.do_exactly_equal_test(Decimal(text), 0, 0)

    def test_exactly_equal_absolute(self):
        # Equal values with a non-zero absolute tolerance only.
        for n in [16, 1013, 1372, 1198, 971, 4]:
            # The same n as an int, a float and a Fraction.
            for value in (n, n/10, Fraction(n, 1234)):
                self.do_exactly_equal_test(value, 0.01, 0)

    def test_exactly_equal_absolute_decimals(self):
        # Equal Decimals with a Decimal absolute tolerance.
        for value in (Decimal("3.571"), -Decimal("81.3971")):
            self.do_exactly_equal_test(value, Decimal("0.01"), 0)

    def test_exactly_equal_relative(self):
        # Equal values with a non-zero relative tolerance only.
        for value in [8347, 101.3, -7910.28, Fraction(5, 21)]:
            self.do_exactly_equal_test(value, 0, 0.01)
        self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01"))

    def test_exactly_equal_both(self):
        # Equal values with both tolerances non-zero.
        for value in [41017, 16.742, -813.02, Fraction(3, 8)]:
            self.do_exactly_equal_test(value, 0.1, 0.01)
        self.do_exactly_equal_test(
            Decimal("7.2"), Decimal("0.1"), Decimal("0.01"))
428
429
class ApproxEqualUnequalTest(unittest.TestCase):
    """approx_equal with unequal values and zero error tolerances.

    With tol=0 and rel=0 only exact equality passes, so a value and that
    value plus one must never compare approximately equal.
    """

    def do_exactly_unequal_test(self, x):
        # Check x and -x against themselves plus one.
        for value in (x, -x):
            self.assertFalse(approx_equal(value, value+1, tol=0, rel=0),
                             'inequality failure for x=%r' % value)

    def test_exactly_unequal_ints(self):
        for n in (951, 572305, 478, 917, 17240):
            self.do_exactly_unequal_test(n)

    def test_exactly_unequal_floats(self):
        for value in (9.51, 5723.05, 47.8, 9.17, 17.24):
            self.do_exactly_unequal_test(value)

    def test_exactly_unequal_fractions(self):
        for num, den in [(1, 5), (7, 9), (12, 11), (101, 99023)]:
            self.do_exactly_unequal_test(Fraction(num, den))

    def test_exactly_unequal_decimals(self):
        for text in "3.1415 298.12 3.47 18.996 0.00245".split():
            self.do_exactly_unequal_test(Decimal(text))
459
460
class ApproxEqualInexactTest(unittest.TestCase):
    """Inexact test cases for approx_equal.

    Each test compares two values that are not exactly equal. It checks
    that the absolute (tol) and relative (rel) error tolerances accept
    or reject them as expected.
    """

    # === Absolute error tests ===

    def do_approx_equal_abs_test(self, x, delta):
        # x +/- delta must pass with tol=2*delta and fail with tol=delta/2.
        template = "Test failure for x={!r}, y={!r}"
        for y in (x + delta, x - delta):
            msg = template.format(x, y)
            self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg)
            self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg)

    def test_approx_equal_absolute_ints(self):
        # Test approximate equality of ints with an absolute error.
        for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]:
            self.do_approx_equal_abs_test(n, 10)
            self.do_approx_equal_abs_test(n, 2)

    def test_approx_equal_absolute_floats(self):
        # Test approximate equality of floats with an absolute error.
        for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]:
            self.do_approx_equal_abs_test(x, 1.5)
            self.do_approx_equal_abs_test(x, 0.01)
            self.do_approx_equal_abs_test(x, 0.0001)

    def test_approx_equal_absolute_fractions(self):
        # Test approximate equality of Fractions with an absolute error.
        delta = Fraction(1, 29)
        numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71]
        for f in (Fraction(n, 29) for n in numerators):
            self.do_approx_equal_abs_test(f, delta)
            self.do_approx_equal_abs_test(f, float(delta))

    def test_approx_equal_absolute_decimals(self):
        # Test approximate equality of Decimals with an absolute error.
        delta = Decimal("0.01")
        for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()):
            self.do_approx_equal_abs_test(d, delta)
            self.do_approx_equal_abs_test(-d, delta)

    def test_cross_zero(self):
        # Test for the case of the two values having opposite signs.
        self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0))

    # === Relative error tests ===

    def do_approx_equal_rel_test(self, x, delta):
        # x*(1 +/- delta) must pass with rel=2*delta and fail with
        # rel=delta/2.
        template = "Test failure for x={!r}, y={!r}"
        for y in (x*(1+delta), x*(1-delta)):
            msg = template.format(x, y)
            self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg)
            self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg)

    def test_approx_equal_relative_ints(self):
        # Test approximate equality of ints with a relative error.
        # approx_equal scales rel by the larger magnitude, so for 64 and 47
        # the threshold is 17/64 = 0.265625, not 17/47.
        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36))
        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37))
        self.assertTrue(approx_equal(64, 47, tol=0, rel=0.27))
        self.assertFalse(approx_equal(64, 47, tol=0, rel=0.26))
        # ---
        self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125))
        self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125))
        self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125))

    def test_approx_equal_relative_floats(self):
        # Test approximate equality of floats with a relative error.
        for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]:
            self.do_approx_equal_rel_test(x, 0.02)
            self.do_approx_equal_rel_test(x, 0.0001)

    def test_approx_equal_relative_fractions(self):
        # Test approximate equality of Fractions with a relative error.
        F = Fraction
        delta = Fraction(3, 8)
        for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]:
            for d in (delta, float(delta)):
                self.do_approx_equal_rel_test(f, d)
                self.do_approx_equal_rel_test(-f, d)

    def test_approx_equal_relative_decimals(self):
        # Test approximate equality of Decimals with a relative error.
        for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()):
            self.do_approx_equal_rel_test(d, Decimal("0.001"))
            self.do_approx_equal_rel_test(-d, Decimal("0.05"))

    # === Both absolute and relative error tests ===

    # There are four cases to consider:
    #   1) actual error <= both absolute and relative error
    #   2) actual error <= absolute error but > relative error
    #   3) actual error <= relative error but > absolute error
    #   4) actual error > both absolute and relative error

    def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag):
        # Check each tolerance on its own, then both together. The
        # combined check passes if either individual check passes.
        check = self.assertTrue if tol_flag else self.assertFalse
        check(approx_equal(a, b, tol=tol, rel=0))
        check = self.assertTrue if rel_flag else self.assertFalse
        check(approx_equal(a, b, tol=0, rel=rel))
        check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse
        check(approx_equal(a, b, tol=tol, rel=rel))

    def test_approx_equal_both1(self):
        # Test actual error <= both absolute and relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True)
        self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True)

    def test_approx_equal_both2(self):
        # Test actual error <= absolute error but > relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False)

    def test_approx_equal_both3(self):
        # Test actual error <= relative error but > absolute error.
        self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True)

    def test_approx_equal_both4(self):
        # Test actual error > both absolute and relative error.
        self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False)
        self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False)
578
579
class ApproxEqualSpecialsTest(unittest.TestCase):
    """approx_equal with NANs, INFs and signed zeroes."""

    def test_inf(self):
        for type_ in (float, Decimal):
            inf = type_('inf')
            # Same-signed infinities match, whatever the tolerances.
            matching = [(inf, inf), (inf, inf, 0, 0), (inf, inf, 1, 0.01),
                        (-inf, -inf)]
            for args in matching:
                self.assertTrue(approx_equal(*args))
            # Opposite infinities, or an infinity and a finite value, do not.
            for args in [(inf, -inf), (inf, 1000)]:
                self.assertFalse(approx_equal(*args))

    def test_nan(self):
        # A NAN matches nothing: not itself, not an infinity, not a number.
        for type_ in (float, Decimal):
            nan = type_('nan')
            others = [nan, type_('inf'), 1000]
            for other in others:
                self.assertFalse(approx_equal(nan, other))

    def test_float_zeroes(self):
        # Negative and positive float zero compare approximately equal.
        negative_zero = math.copysign(0.0, -1)
        self.assertTrue(approx_equal(negative_zero, 0.0, tol=0.1, rel=0.1))

    def test_decimal_zeroes(self):
        # Negative and positive Decimal zero compare approximately equal.
        negative_zero = Decimal("-0.0")
        self.assertTrue(
            approx_equal(negative_zero, Decimal(0), tol=0.1, rel=0.1))
606
607
class TestApproxEqualErrors(unittest.TestCase):
    """approx_equal must reject negative error tolerances."""

    def test_bad_tol(self):
        # A negative absolute tolerance raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, -1, 0.1)

    def test_bad_rel(self):
        # A negative relative tolerance raises ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, 1, -0.1)
618
619
620# --- Tests for NumericTestCase ---
621
622# The formatting routine that generates the error messages is complex enough
623# that it too needs testing.
624
class TestNumericTestCase(unittest.TestCase):
    """Tests for NumericTestCase and its failure messages.

    The exact wording of NumericTestCase error messages is *not*
    guaranteed. Still, they need some test to show they are generated
    correctly. As a compromise, these tests look for specific fragments
    that should survive even if the overall message is reworded.
    """

    def do_test(self, args):
        # Every expected fragment must appear in the generated message.
        message = NumericTestCase._make_std_err_msg(*args)
        for fragment in self.generate_substrings(*args):
            self.assertIn(fragment, message)

    def test_numerictestcase_is_testcase(self):
        # NumericTestCase must really be a TestCase subclass.
        self.assertTrue(issubclass(NumericTestCase, unittest.TestCase))

    def test_error_msg_numeric(self):
        # Message for a plain numeric comparison (no index).
        self.do_test((2.5, 4.0, 0.5, 0.25, None))

    def test_error_msg_sequence(self):
        # Message for a sequence comparison that differs at index 7.
        self.do_test((3.75, 8.25, 1.25, 0.5, 7))

    def generate_substrings(self, first, second, tol, rel, idx):
        """Return substrings we expect to see in error messages."""
        abs_err, rel_err = _calc_errors(first, second)
        index_part = [] if idx is None else ['differ at index %d' % idx]
        return [
            'tol=%r' % tol,
            'rel=%r' % rel,
            'absolute error = %r' % abs_err,
            'relative error = %r' % rel_err,
        ] + index_part
663
664
665# =======================================
666# === Tests for the statistics module ===
667# =======================================
668
669
class GlobalsTest(unittest.TestCase):
    """Sanity checks on the statistics module's namespace."""

    module = statistics
    expected_metadata = ["__doc__", "__all__"]

    def test_meta(self):
        # Each expected metadata attribute must be defined on the module.
        target = self.module
        for attr in self.expected_metadata:
            present = hasattr(target, attr)
            self.assertTrue(present, "%s not present" % attr)

    def test_check_all(self):
        # Every name exported through __all__ must be public and must exist.
        target = self.module
        for exported in target.__all__:
            is_private = exported.startswith("_")
            self.assertFalse(is_private,
                             'private name "%s" in __all__' % exported)
            self.assertTrue(hasattr(target, exported),
                            'missing name "%s" in __all__' % exported)
684            # No private names in __all__:
685            self.assertFalse(name.startswith("_"),
686                             'private name "%s" in __all__' % name)
687            # And anything in __all__ must exist:
688            self.assertTrue(hasattr(module, name),
689                            'missing name "%s" in __all__' % name)
690
691
class DocTests(unittest.TestCase):
    """Run the doctests embedded in the statistics module itself."""

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -OO and above")
    def test_doc_tests(self):
        # At least one doctest must run, and none may fail.
        results = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS)
        self.assertGreater(results.attempted, 0)
        self.assertEqual(results.failed, 0)
699
class StatisticsErrorTest(unittest.TestCase):
    """statistics.StatisticsError must exist and be a ValueError subclass."""

    def test_has_exception(self):
        errmsg = (
                "Expected StatisticsError to be a ValueError, but got a"
                " subclass of %r instead."
                )
        self.assertTrue(hasattr(statistics, 'StatisticsError'))
        exc_type = statistics.StatisticsError
        self.assertTrue(issubclass(exc_type, ValueError),
                        errmsg % exc_type.__base__)
711
712
713# === Tests for private utility functions ===
714
715class ExactRatioTest(unittest.TestCase):
716    # Test _exact_ratio utility.
717
718    def test_int(self):
719        for i in (-20, -3, 0, 5, 99, 10**20):
720            self.assertEqual(statistics._exact_ratio(i), (i, 1))
721
722    def test_fraction(self):
723        numerators = (-5, 1, 12, 38)
724        for n in numerators:
725            f = Fraction(n, 37)
726            self.assertEqual(statistics._exact_ratio(f), (n, 37))
727
728    def test_float(self):
729        self.assertEqual(statistics._exact_ratio(0.125), (1, 8))
730        self.assertEqual(statistics._exact_ratio(1.125), (9, 8))
731        data = [random.uniform(-100, 100) for _ in range(100)]
732        for x in data:
733            num, den = statistics._exact_ratio(x)
734            self.assertEqual(x, num/den)
735
736    def test_decimal(self):
737        D = Decimal
738        _exact_ratio = statistics._exact_ratio
739        self.assertEqual(_exact_ratio(D("0.125")), (1, 8))
740        self.assertEqual(_exact_ratio(D("12.345")), (2469, 200))
741        self.assertEqual(_exact_ratio(D("-1.98")), (-99, 50))
742
743    def test_inf(self):
744        INF = float("INF")
745        class MyFloat(float):
746            pass
747        class MyDecimal(Decimal):
748            pass
749        for inf in (INF, -INF):
750            for type_ in (float, MyFloat, Decimal, MyDecimal):
751                x = type_(inf)
752                ratio = statistics._exact_ratio(x)
753                self.assertEqual(ratio, (x, None))
754                self.assertEqual(type(ratio[0]), type_)
755                self.assertTrue(math.isinf(ratio[0]))
756
757    def test_float_nan(self):
758        NAN = float("NAN")
759        class MyFloat(float):
760            pass
761        for nan in (NAN, MyFloat(NAN)):
762            ratio = statistics._exact_ratio(nan)
763            self.assertTrue(math.isnan(ratio[0]))
764            self.assertIs(ratio[1], None)
765            self.assertEqual(type(ratio[0]), type(nan))
766
767    def test_decimal_nan(self):
768        NAN = Decimal("NAN")
769        sNAN = Decimal("sNAN")
770        class MyDecimal(Decimal):
771            pass
772        for nan in (NAN, MyDecimal(NAN), sNAN, MyDecimal(sNAN)):
773            ratio = statistics._exact_ratio(nan)
774            self.assertTrue(_nan_equal(ratio[0], nan))
775            self.assertIs(ratio[1], None)
776            self.assertEqual(type(ratio[0]), type(nan))
777
778
779class DecimalToRatioTest(unittest.TestCase):
780    # Test _exact_ratio private function.
781
782    def test_infinity(self):
783        # Test that INFs are handled correctly.
784        inf = Decimal('INF')
785        self.assertEqual(statistics._exact_ratio(inf), (inf, None))
786        self.assertEqual(statistics._exact_ratio(-inf), (-inf, None))
787
788    def test_nan(self):
789        # Test that NANs are handled correctly.
790        for nan in (Decimal('NAN'), Decimal('sNAN')):
791            num, den = statistics._exact_ratio(nan)
792            # Because NANs always compare non-equal, we cannot use assertEqual.
793            # Nor can we use an identity test, as we don't guarantee anything
794            # about the object identity.
795            self.assertTrue(_nan_equal(num, nan))
796            self.assertIs(den, None)
797
798    def test_sign(self):
799        # Test sign is calculated correctly.
800        numbers = [Decimal("9.8765e12"), Decimal("9.8765e-12")]
801        for d in numbers:
802            # First test positive decimals.
803            assert d > 0
804            num, den = statistics._exact_ratio(d)
805            self.assertGreaterEqual(num, 0)
806            self.assertGreater(den, 0)
807            # Then test negative decimals.
808            num, den = statistics._exact_ratio(-d)
809            self.assertLessEqual(num, 0)
810            self.assertGreater(den, 0)
811
812    def test_negative_exponent(self):
813        # Test result when the exponent is negative.
814        t = statistics._exact_ratio(Decimal("0.1234"))
815        self.assertEqual(t, (617, 5000))
816
817    def test_positive_exponent(self):
818        # Test results when the exponent is positive.
819        t = statistics._exact_ratio(Decimal("1.234e7"))
820        self.assertEqual(t, (12340000, 1))
821
822    def test_regression_20536(self):
823        # Regression test for issue 20536.
824        # See http://bugs.python.org/issue20536
825        t = statistics._exact_ratio(Decimal("1e2"))
826        self.assertEqual(t, (100, 1))
827        t = statistics._exact_ratio(Decimal("1.47e5"))
828        self.assertEqual(t, (147000, 1))
829
830
class IsFiniteTest(unittest.TestCase):
    # Exercise the private _isfinite predicate.

    def test_finite(self):
        # Ordinary ints, Fractions, floats and Decimals count as finite.
        finite_values = [5, Fraction(1, 3), 2.5, Decimal("5.5")]
        for value in finite_values:
            self.assertTrue(statistics._isfinite(value))

    def test_infinity(self):
        # Float and Decimal infinities do not count as finite.
        infinities = [float("inf"), Decimal("inf")]
        for value in infinities:
            self.assertFalse(statistics._isfinite(value))

    def test_nan(self):
        # Float NANs and both quiet and signalling Decimal NANs are not finite.
        nans = [float("nan"), Decimal("NAN"), Decimal("sNAN")]
        for value in nans:
            self.assertFalse(statistics._isfinite(value))
848
849
class CoerceTest(unittest.TestCase):
    # Test that private function _coerce correctly deals with types.

    # The coercion rules are currently an implementation detail, although at
    # some point that should change. The tests and comments here define the
    # correct implementation.

    # Pre-conditions and rules of _coerce:
    #
    #   - coerce(T, S) will never be called with bool as the first argument;
    #     this is a pre-condition, guarded with an assertion.
    #
    #   - coerce(T, T) will always return T; we assume T is a valid numeric
    #     type. Violate this assumption at your own risk.
    #
    #   - Apart from as above, bool is treated as if it were actually int.
    #
    #   - coerce(int, X) and coerce(X, int) return X.
    #
    # TODO(review): a further rule about how _sum first calls _coerce seems
    # to have been intended here but was never written down.

    def test_bool(self):
        # bool is somewhat special, due to the pre-condition that it is
        # never given as the first argument to _coerce, and that it cannot
        # be subclassed. So we test it specially.
        for T in (int, float, Fraction, Decimal):
            self.assertIs(statistics._coerce(T, bool), T)
            class MyClass(T): pass
            self.assertIs(statistics._coerce(MyClass, bool), MyClass)

    def assertCoerceTo(self, A, B):
        """Assert that type A coerces to B, in either argument order."""
        self.assertIs(statistics._coerce(A, B), B)
        self.assertIs(statistics._coerce(B, A), B)

    def check_coerce_to(self, A, B):
        """Checks that type A coerces to B, including subclasses."""
        # Assert that type A is coerced to B.
        self.assertCoerceTo(A, B)
        # Subclasses of A are also coerced to B.
        class SubclassOfA(A): pass
        self.assertCoerceTo(SubclassOfA, B)
        # A, and subclasses of A, are coerced to subclasses of B.
        class SubclassOfB(B): pass
        self.assertCoerceTo(A, SubclassOfB)
        self.assertCoerceTo(SubclassOfA, SubclassOfB)

    def assertCoerceRaises(self, A, B):
        """Assert that coercing A to B, or vice versa, raises TypeError."""
        # NOTE(review): (A, B) is passed as ONE positional argument, so
        # _coerce receives a single tuple (every other call site in this class
        # passes two types). The TypeError caught here can therefore come from
        # the call itself rather than from the coercion rules. The intended
        # call is statistics._coerce(A, B). Before switching to it, confirm
        # that _coerce really rejects every pair routed through this helper,
        # otherwise the callers below may start failing.
        self.assertRaises(TypeError, statistics._coerce, (A, B))
        self.assertRaises(TypeError, statistics._coerce, (B, A))

    def check_type_coercions(self, T):
        """Check that type T coerces correctly with subclasses of itself."""
        assert T is not bool
        # Coercing a type with itself returns the same type.
        self.assertIs(statistics._coerce(T, T), T)
        # Coercing a type with a subclass of itself returns the subclass.
        class U(T): pass
        class V(T): pass
        class W(U): pass
        for typ in (U, V, W):
            self.assertCoerceTo(T, typ)
        self.assertCoerceTo(U, W)
        # Coercing two subclasses that aren't parent/child is an error.
        self.assertCoerceRaises(U, V)
        self.assertCoerceRaises(V, W)

    def test_int(self):
        # Check that int coerces correctly.
        self.check_type_coercions(int)
        for typ in (float, Fraction, Decimal):
            self.check_coerce_to(int, typ)

    def test_fraction(self):
        # Check that Fraction coerces correctly.
        self.check_type_coercions(Fraction)
        self.check_coerce_to(Fraction, float)

    def test_decimal(self):
        # Check that Decimal coerces correctly.
        self.check_type_coercions(Decimal)

    def test_float(self):
        # Check that float coerces correctly.
        self.check_type_coercions(float)

    def test_non_numeric_types(self):
        for bad_type in (str, list, type(None), tuple, dict):
            for good_type in (int, float, Fraction, Decimal):
                self.assertCoerceRaises(good_type, bad_type)

    def test_incompatible_types(self):
        # Test that incompatible types raise.
        for T in (float, Fraction):
            class MySubclass(T): pass
            self.assertCoerceRaises(T, Decimal)
            self.assertCoerceRaises(MySubclass, Decimal)
948
949
950class ConvertTest(unittest.TestCase):
951    # Test private _convert function.
952
953    def check_exact_equal(self, x, y):
954        """Check that x equals y, and has the same type as well."""
955        self.assertEqual(x, y)
956        self.assertIs(type(x), type(y))
957
958    def test_int(self):
959        # Test conversions to int.
960        x = statistics._convert(Fraction(71), int)
961        self.check_exact_equal(x, 71)
962        class MyInt(int): pass
963        x = statistics._convert(Fraction(17), MyInt)
964        self.check_exact_equal(x, MyInt(17))
965
966    def test_fraction(self):
967        # Test conversions to Fraction.
968        x = statistics._convert(Fraction(95, 99), Fraction)
969        self.check_exact_equal(x, Fraction(95, 99))
970        class MyFraction(Fraction):
971            def __truediv__(self, other):
972                return self.__class__(super().__truediv__(other))
973        x = statistics._convert(Fraction(71, 13), MyFraction)
974        self.check_exact_equal(x, MyFraction(71, 13))
975
976    def test_float(self):
977        # Test conversions to float.
978        x = statistics._convert(Fraction(-1, 2), float)
979        self.check_exact_equal(x, -0.5)
980        class MyFloat(float):
981            def __truediv__(self, other):
982                return self.__class__(super().__truediv__(other))
983        x = statistics._convert(Fraction(9, 8), MyFloat)
984        self.check_exact_equal(x, MyFloat(1.125))
985
986    def test_decimal(self):
987        # Test conversions to Decimal.
988        x = statistics._convert(Fraction(1, 40), Decimal)
989        self.check_exact_equal(x, Decimal("0.025"))
990        class MyDecimal(Decimal):
991            def __truediv__(self, other):
992                return self.__class__(super().__truediv__(other))
993        x = statistics._convert(Fraction(-15, 16), MyDecimal)
994        self.check_exact_equal(x, MyDecimal("-0.9375"))
995
996    def test_inf(self):
997        for INF in (float('inf'), Decimal('inf')):
998            for inf in (INF, -INF):
999                x = statistics._convert(inf, type(inf))
1000                self.check_exact_equal(x, inf)
1001
1002    def test_nan(self):
1003        for nan in (float('nan'), Decimal('NAN'), Decimal('sNAN')):
1004            x = statistics._convert(nan, type(nan))
1005            self.assertTrue(_nan_equal(x, nan))
1006
1007    def test_invalid_input_type(self):
1008        with self.assertRaises(TypeError):
1009            statistics._convert(None, float)
1010
1011
class FailNegTest(unittest.TestCase):
    """Test _fail_neg private function."""

    def test_pass_through(self):
        # Non-negative values of every numeric type come out untouched.
        original = [1, 2.0, Fraction(3), Decimal(4)]
        passed = [item for item in statistics._fail_neg(original)]
        self.assertEqual(original, passed)

    def test_negatives_raise(self):
        # A negative value of any numeric type aborts iteration.
        for magnitude in [1, 2.0, Fraction(3), Decimal(4)]:
            iterator = statistics._fail_neg([-magnitude])
            with self.assertRaises(statistics.StatisticsError):
                next(iterator)

    def test_error_msg(self):
        # A caller-supplied error message must be carried by the exception.
        msg = "badness #%d" % random.randint(10000, 99999)
        try:
            next(statistics._fail_neg([-1], msg))
        except statistics.StatisticsError as exc:
            reported = exc.args[0]
        else:
            self.fail("expected exception, but it didn't happen")
        self.assertEqual(reported, msg)
1038
1039
class FindLteqTest(unittest.TestCase):
    # Exercise the private _find_lteq function.

    def test_invalid_input_values(self):
        # In every case the target value does not occur in the list, and
        # _find_lteq must signal that with ValueError.
        bad_cases = (([], 1), ([1, 2], 3), ([1, 3], 2))
        for a, x in bad_cases:
            with self.subTest(a=a, x=x), self.assertRaises(ValueError):
                statistics._find_lteq(a, x)

    def test_locate_successfully(self):
        # The expected index is the leftmost occurrence of x in each list.
        good_cases = (
            ([1, 1, 1, 2, 3], 1, 0),
            ([0, 1, 1, 1, 2, 3], 1, 1),
            ([1, 2, 3, 3, 3], 3, 2),
        )
        for a, x, expected_i in good_cases:
            with self.subTest(a=a, x=x):
                self.assertEqual(expected_i, statistics._find_lteq(a, x))
1061
1062
class FindRteqTest(unittest.TestCase):
    # Test _find_rteq private function.

    def test_invalid_input_values(self):
        # Each case must raise ValueError. As in FindLteqTest, every case
        # runs inside subTest, so a failure reports which inputs caused it
        # and the remaining cases still run.
        for a, l, x in [
            ([1], 2, 1),
            ([1, 3], 0, 2)
        ]:
            with self.subTest(a=a, l=l, x=x):
                with self.assertRaises(ValueError):
                    statistics._find_rteq(a, l, x)

    def test_locate_successfully(self):
        # The expected index is the rightmost occurrence of x in each list.
        # NOTE(review): l looks like a lower search bound (0 in every case
        # here) -- confirm against _find_rteq.
        for a, l, x, expected_i in [
            ([1, 1, 1, 2, 3], 0, 1, 2),
            ([0, 1, 1, 1, 2, 3], 0, 1, 3),
            ([1, 2, 3, 3, 3], 0, 3, 4)
        ]:
            with self.subTest(a=a, l=l, x=x):
                self.assertEqual(expected_i, statistics._find_rteq(a, l, x))
1082
1083
1084# === Tests for public functions ===
1085
class UnivariateCommonMixin:
    # Common tests for most univariate functions that take a data argument.
    # A class using this mixin must also derive from unittest.TestCase and
    # set self.func to the function under test (e.g. in setUp).

    def test_no_args(self):
        # Fail if given no arguments.
        self.assertRaises(TypeError, self.func)

    def test_empty_data(self):
        # Fail when the data argument (first argument) is empty.
        for empty in ([], (), iter([])):
            self.assertRaises(statistics.StatisticsError, self.func, empty)

    def prepare_data(self):
        """Return int data for various tests.

        The list is shuffled until it is out of order, because
        test_no_inplace_modifications needs unsorted data. Subclasses may
        override this to supply different data.
        """
        data = list(range(10))
        while data == sorted(data):
            random.shuffle(data)
        return data

    def test_no_inplace_modifications(self):
        # Test that the function does not modify its input data.
        data = self.prepare_data()
        # A one-element list is always sorted, so prepare_data's shuffle
        # loop could never finish for it.
        assert len(data) != 1
        # Unsorted input makes an in-place sort by the function detectable.
        assert data != sorted(data)
        saved = data[:]
        assert data is not saved
        _ = self.func(data)
        self.assertListEqual(data, saved, "data has been modified")

    def test_order_doesnt_matter(self):
        # Test that the order of data points doesn't change the result.

        # CAUTION: due to floating point rounding errors, the result actually
        # may depend on the order. Consider this test representing an ideal.
        # To avoid this test failing, only test with exact values such as ints
        # or Fractions.
        data = [1, 2, 3, 3, 3, 4, 5, 6]*100
        expected = self.func(data)
        random.shuffle(data)
        actual = self.func(data)
        self.assertEqual(expected, actual)

    def test_type_of_data_collection(self):
        # Test that the type of iterable data doesn't affect the result.
        class MyList(list):
            pass
        class MyTuple(tuple):
            pass
        def generator(data):
            return (obj for obj in data)
        data = self.prepare_data()
        expected = self.func(data)
        for kind in (list, tuple, iter, MyList, MyTuple, generator):
            result = self.func(kind(data))
            self.assertEqual(result, expected)

    def test_range_data(self):
        # Test that functions work with range objects.
        data = range(20, 50, 3)
        expected = self.func(list(data))
        self.assertEqual(self.func(data), expected)

    def test_bad_arg_types(self):
        # Test that function raises when given data of the wrong type.

        # Don't roll the following into a loop like this:
        #   for bad in list_of_bad:
        #       self.check_for_type_error(bad)
        #
        # Since assertRaises doesn't show the arguments that caused the test
        # failure, it is very difficult to debug these test failures when the
        # following are in a loop.
        self.check_for_type_error(None)
        self.check_for_type_error(23)
        self.check_for_type_error(42.0)
        self.check_for_type_error(object())

    def check_for_type_error(self, *args):
        """Assert that calling self.func(*args) raises TypeError."""
        self.assertRaises(TypeError, self.func, *args)

    def test_type_of_data_element(self):
        # Check the type of data elements doesn't affect the numeric result.
        # This is a weaker test than UnivariateTypeMixin.test_types_conserved,
        # because it checks the numeric result by equality, but not by type.
        class MyFloat(float):
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__

        raw = self.prepare_data()
        expected = self.func(raw)
        for kind in (float, MyFloat, Decimal, Fraction):
            data = [kind(x) for x in raw]
            result = type(expected)(self.func(data))
            self.assertEqual(result, expected)
1183
1184
class UnivariateTypeMixin:
    """Tests for functions whose result has the same type as their data.

    For example, the mean of a list of Fractions should itself be a
    Fraction. Only tests that depend on the function returning the type of
    its input data belong here; other type-related tests live elsewhere.
    """
    def prepare_types_for_conservation_test(self):
        """Return the types which are expected to be conserved.

        Besides float, Decimal and Fraction this includes a float subclass
        whose overridden operators wrap float's result back into the
        subclass, so arithmetic on it does not silently degrade to float.
        """
        class MyFloat(float):
            # Each override re-wraps the float result as a MyFloat.
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__
            def __sub__(self, other):
                return type(self)(super().__sub__(other))
            def __rsub__(self, other):
                return type(self)(super().__rsub__(other))
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __rtruediv__(self, other):
                return type(self)(super().__rtruediv__(other))
            def __pow__(self, other):
                return type(self)(super().__pow__(other))
        return (float, Decimal, Fraction, MyFloat)

    def test_types_conserved(self):
        # Only the type of the result is checked, never its value, and each
        # run uses data of a single type (mixed data is not covered).
        raw = self.prepare_data()
        for kind in self.prepare_types_for_conservation_test():
            converted = list(map(kind, raw))
            self.assertIs(type(self.func(converted)), kind)
1222
1223
class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin):
    # Common test cases for statistics._sum() function.

    # This test suite looks only at the numeric value returned by _sum,
    # after conversion to the appropriate type.
    #
    # NOTE(review): this class does not derive from unittest.TestCase, so it
    # is never collected or run. Making it a TestCase is not a drop-in
    # change: e.g. the inherited test_empty_data expects StatisticsError for
    # empty data, whereas TestSum.test_empty_data shows _sum returns
    # (int, Fraction(0), 0) in that case.
    def setUp(self):
        def simplified_sum(*args):
            # _sum returns a (type, total, count) triple. Convert the total
            # back to the data's type with _convert(value, T); _coerce
            # combines two *types* and cannot take a value.
            T, value, n = statistics._sum(*args)
            return statistics._convert(value, T)
        self.func = simplified_sum
1234
1235
class TestSum(NumericTestCase):
    # Test cases for statistics._sum() function.

    # These tests look at the entire three value tuple returned by _sum:
    # (type of the data, total, number of data points).

    def setUp(self):
        self.func = statistics._sum

    def test_empty_data(self):
        # Override test for empty data: the count is zero and the total is
        # just the start value (0 when omitted), typed after that value.
        calls = (
            ((), (int, Fraction(0), 0)),
            ((23,), (int, Fraction(23), 0)),
            ((2.3,), (float, Fraction(2.3), 0)),
        )
        for data in ([], (), iter([])):
            for start_args, expected in calls:
                self.assertEqual(self.func(data, *start_args), expected)

    def test_ints(self):
        values = [1, 5, 3, -4, -8, 20, 42, 1]
        self.assertEqual(self.func(values), (int, Fraction(60), 8))
        # The start value is added to the total but not counted.
        values = [4, 2, 3, -8, 7]
        self.assertEqual(self.func(values, 1000), (int, Fraction(1008), 5))

    def test_floats(self):
        quarters = [0.25]*20
        self.assertEqual(self.func(quarters), (float, Fraction(5.0), 20))
        self.assertEqual(self.func([0.125, 0.25, 0.5, 0.75], 1.5),
                         (float, Fraction(3.125), 4))

    def test_fractions(self):
        thousandths = [Fraction(1, 1000)]*500
        self.assertEqual(self.func(thousandths),
                         (Fraction, Fraction(1, 2), 500))

    def test_decimals(self):
        data = [Decimal(text) for text in (
            "0.001", "5.246", "1.702", "-0.025",
            "3.974", "2.328", "4.617", "2.843",
        )]
        self.assertEqual(self.func(data),
                         (Decimal, Decimal("20.686"), 8))

    def test_compare_with_math_fsum(self):
        # Compare with the math.fsum function.
        # Ideally the two would agree exactly, but they can differ by a very
        # slight amount, hence the approximate comparison.
        data = [random.uniform(-100, 1000) for _ in range(1000)]
        total = float(self.func(data)[1])
        self.assertApproxEqual(total, math.fsum(data), rel=2e-16)

    def test_start_argument(self):
        # Test that the optional start argument works correctly.
        data = [random.uniform(1, 1000) for _ in range(100)]
        base = self.func(data)[1]
        checks = (
            (42, base + 42),
            (-23, base - 23),
            (1e20, base + Fraction(1e20)),
        )
        for start, expected in checks:
            self.assertEqual(expected, self.func(data, start)[1])

    def test_strings_fail(self):
        # Sum of strings should fail, as the start value or as a data point.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], '999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, '999'])

    def test_bytes_fail(self):
        # Sum of bytes should fail, as the start value or as a data point.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], b'999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, b'999'])

    def test_mixed_sum(self):
        # Mixed input types are not (currently) allowed, whether the mix is
        # inside the data or comes from the start argument.
        with self.assertRaises(TypeError):
            self.func([1, 2.0, Decimal(1)])
        with self.assertRaises(TypeError):
            self.func([1, 2.0], Decimal(1))
1306
1307
class SumTortureTest(NumericTestCase):
    def test_torture(self):
        # Tim Peters' torture test for sum, and variants of same: the huge
        # terms must cancel exactly, leaving only the small ones behind.
        expected = (float, Fraction(20000.0), 40000)
        for pattern in ([1, 1e100, 1, -1e100], [1e100, 1, 1, -1e100]):
            self.assertEqual(statistics._sum(pattern*10000), expected)
        kind, total, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000)
        self.assertIs(kind, float)
        self.assertEqual(count, 40000)
        self.assertApproxEqual(float(total), 2.0e-96, rel=5e-16)
1319
1320
class SumSpecialValues(NumericTestCase):
    # Test that sum works correctly with IEEE-754 special values.

    def test_nan(self):
        # A NAN anywhere in the data makes the total a NAN of the same type.
        for type_ in (float, Decimal):
            nan = type_('nan')
            result = statistics._sum([1, nan, 2])[1]
            self.assertIs(type(result), type_)
            self.assertTrue(math.isnan(result))

    def check_infinity(self, x, inf):
        """Check x is an infinity of the same type and sign as inf."""
        self.assertTrue(math.isinf(x))
        self.assertIs(type(x), type(inf))
        self.assertEqual(x > 0, inf > 0)
        # A TestCase assertion rather than a bare ``assert``: bare asserts
        # are stripped under ``python -O`` and report less on failure.
        self.assertEqual(x, inf)

    def do_test_inf(self, inf):
        # Adding a single infinity gives infinity.
        result = statistics._sum([1, 2, inf, 3])[1]
        self.check_infinity(result, inf)
        # Adding two infinities of the same sign also gives infinity.
        result = statistics._sum([1, 2, inf, 3, inf, 4])[1]
        self.check_infinity(result, inf)

    def test_float_inf(self):
        inf = float('inf')
        # ``sgn`` rather than ``sign`` so the module-level sign() helper is
        # not shadowed.
        for sgn in (+1, -1):
            self.do_test_inf(sgn*inf)

    def test_decimal_inf(self):
        inf = Decimal('inf')
        for sgn in (+1, -1):
            self.do_test_inf(sgn*inf)

    def test_float_mismatched_infs(self):
        # Test that adding two infinities of opposite sign gives a NAN.
        inf = float('inf')
        result = statistics._sum([1, 2, inf, 3, -inf, 4])[1]
        self.assertTrue(math.isnan(result))

    def test_decimal_extendedcontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign returns NAN.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.ExtendedContext):
            self.assertTrue(math.isnan(statistics._sum(data)[1]))

    def test_decimal_basiccontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign raises InvalidOperation.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.BasicContext):
            self.assertRaises(decimal.InvalidOperation, statistics._sum, data)

    def test_decimal_snan_raises(self):
        # Adding sNAN should raise InvalidOperation.
        sNAN = Decimal('sNAN')
        data = [1, sNAN, 2]
        self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
1381
1382
1383# === Tests for averages ===
1384
class AverageMixin(UnivariateCommonMixin):
    # Mixin class holding tests shared by all the averages.

    def test_single_value(self):
        # The average of a lone value is that value.
        singles = (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28'))
        for value in singles:
            self.assertEqual(self.func([value]), value)

    def prepare_values_for_repeated_single_test(self):
        # Hook: subclasses may override this to supply values their
        # function reproduces exactly (presumably to dodge rounding).
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712'))

    def test_repeated_single_value(self):
        # Repeating one value any number of times leaves the average equal
        # to that value.
        counts = (2, 5, 10, 20)
        for x in self.prepare_values_for_repeated_single_test():
            for count in counts:
                with self.subTest(x=x, count=count):
                    self.assertEqual(self.func([x]*count), x)
1403
1404
class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.mean, on top of the shared average/type mixins.

    def setUp(self):
        self.func = statistics.mean

    def test_torture_pep(self):
        # "Torture Test" from PEP-450.
        self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1)

    def test_ints(self):
        # Test mean with ints.
        data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9]
        random.shuffle(data)
        self.assertEqual(self.func(data), 4.8125)

    def test_floats(self):
        # Test mean with floats.
        data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5]
        random.shuffle(data)
        self.assertEqual(self.func(data), 22.015625)

    def test_decimals(self):
        # Test mean with Decimals.
        D = Decimal
        data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")]
        random.shuffle(data)
        self.assertEqual(self.func(data), D("3.5896"))

    def test_fractions(self):
        # Test mean with Fractions.
        F = Fraction
        data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)]
        random.shuffle(data)
        self.assertEqual(self.func(data), F(1479, 1960))

    def test_inf(self):
        # Test mean with infinities.
        raw = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            for sign in (1, -1):
                inf = kind("inf")*sign
                data = raw + [inf]
                result = self.func(data)
                self.assertTrue(math.isinf(result))
                self.assertEqual(result, inf)

    def test_mismatched_infs(self):
        # Test mean with infinities of opposite sign.
        data = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')]
        result = self.func(data)
        self.assertTrue(math.isnan(result))

    def test_nan(self):
        # Test mean with NANs.
        raw = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            nan = kind("nan")
            data = raw + [nan]
            result = self.func(data)
            self.assertTrue(math.isnan(result))

    def test_big_data(self):
        # Test adding a large constant to every data point.
        c = 1e9
        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(data) + c
        assert expected != c
        result = self.func([x+c for x in data])
        self.assertEqual(result, expected)

    def test_doubled_data(self):
        # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z].
        data = [random.uniform(-3, 5) for _ in range(1000)]
        expected = self.func(data)
        actual = self.func(data*2)
        self.assertApproxEqual(actual, expected)

    def test_regression_20561(self):
        # Regression test for issue 20561.
        # See http://bugs.python.org/issue20561
        d = Decimal('1e4')
        self.assertEqual(statistics.mean([d]), d)

    def test_regression_25177(self):
        # Regression test for issue 25177.
        # Ensure very big and very small floats don't overflow.
        # See http://bugs.python.org/issue25177.
        self.assertEqual(statistics.mean(
            [8.988465674311579e+307, 8.98846567431158e+307]),
            8.98846567431158e+307)
        big = 8.98846567431158e+307
        tiny = 5e-324
        for n in (2, 3, 5, 200):
            self.assertEqual(statistics.mean([big]*n), big)
            self.assertEqual(statistics.mean([tiny]*n), tiny)
1499
1500
class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.harmonic_mean.

    def setUp(self):
        self.func = statistics.harmonic_mean

    def prepare_data(self):
        # Mixin hook override: use the shared test data minus its zero.
        # A zero forces the harmonic mean to 0 (see test_zero), presumably
        # making the inherited tests trivial.
        data = super().prepare_data()
        data.remove(0)
        return data

    def prepare_values_for_repeated_single_test(self):
        # Mixin hook override: strictly positive samples of each numeric type.
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125'))

    def test_zero(self):
        # Any zero in the data makes the harmonic mean exactly zero.
        result = self.func([1, 0, 2])
        self.assertEqual(result, 0)

    def test_negative_error(self):
        # Negative data points are rejected with StatisticsError.
        for bad_values in ([-1], [1, -2, 3]):
            with self.subTest(values=bad_values):
                self.assertRaises(
                    statistics.StatisticsError, self.func, bad_values)

    def test_invalid_type_error(self):
        # A string anywhere among the data points raises TypeError.
        bad_inputs = (
            ['3.14'],                # single string
            ['1', '2', '3'],         # multiple strings
            [1, '2', 3, '4', 5],     # mixed strings and valid integers
            [2.3, 3.4, 4.5, '5.6'],  # only one string and valid floats
        )
        for data in bad_inputs:
            with self.subTest(data=data), self.assertRaises(TypeError):
                self.func(data)

    def test_ints(self):
        # Integer data: 6 / (1/2 + 2/4 + 1/8 + 2/16) == 24/5.
        data = [2, 4, 4, 8, 16, 16]
        random.shuffle(data)
        expected = 24 / 5
        self.assertEqual(self.func(data), expected)

    def test_floats_exact(self):
        # Floats whose reciprocals are exactly representable.
        shuffled = [1/8, 1/4, 1/4, 1/2, 1/2]
        random.shuffle(shuffled)
        self.assertEqual(self.func(shuffled), 1/4)
        self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5)

    def test_singleton_lists(self):
        # The harmonic mean of a one-element list is that element, exactly.
        for value in range(1, 101):
            self.assertEqual(self.func([value]), value)

    def test_decimals_exact(self):
        # Decimal data chosen so that every expected result is exact.
        D = Decimal
        self.assertEqual(self.func([D(15), D(30), D(60), D(60)]), D(30))
        shuffled_cases = [
            (["0.05", "0.10", "0.20", "0.20"], D("0.10")),
            (["1.68", "0.32", "5.94", "2.75"], D(66528)/70723),
        ]
        for raw, expected in shuffled_cases:
            data = [D(text) for text in raw]
            random.shuffle(data)
            self.assertEqual(self.func(data), expected)

    def test_fractions(self):
        # Fractions n/(n+1) for n = 1..7.
        data = [Fraction(n, n + 1) for n in range(1, 8)]
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(7*420, 4029))

    def test_inf(self):
        # An infinite point contributes a zero reciprocal: 3 / (1/2 + 1) == 2.
        self.assertEqual(self.func([2.0, float('inf'), 1.0]), 2.0)

    def test_nan(self):
        # A NAN anywhere in the data propagates to the result.
        result = self.func([2.0, float('nan'), 1.0])
        self.assertTrue(math.isnan(result))

    def test_multiply_data_points(self):
        # Scaling every point by a constant scales the harmonic mean by it.
        factor = 111
        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(data) * factor
        scaled = [x * factor for x in data]
        self.assertEqual(self.func(scaled), expected)

    def test_doubled_data(self):
        # Repeating every point leaves the harmonic mean (almost) unchanged.
        data = [random.uniform(1, 5) for _ in range(1000)]
        expected = self.func(data)
        self.assertApproxEqual(self.func(data * 2), expected)
1599
1600
class TestMedian(NumericTestCase, AverageMixin):
    # Common tests for statistics.median; the median_low, median_high and
    # median_grouped test classes inherit from this one.

    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        """Override UnivariateCommonMixin's hook to get an odd-length dataset."""
        data = super().prepare_data()
        if len(data) % 2 == 0:
            data.append(2)
        return data

    def test_even_ints(self):
        # Even count of ints: the mean of the two middle values.
        data = [1, 2, 3, 4, 5, 6]
        assert len(data) % 2 == 0
        self.assertEqual(self.func(data), 3.5)

    def test_odd_ints(self):
        # Odd count of ints: the single middle value.
        data = [1, 2, 3, 4, 5, 6, 9]
        assert len(data) % 2 == 1
        self.assertEqual(self.func(data), 4)

    def test_odd_fractions(self):
        # Odd count of Fractions, supplied in random order.
        data = [Fraction(k, 7) for k in range(1, 6)]
        assert len(data) % 2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(3, 7))

    def test_even_fractions(self):
        # Even count of Fractions, supplied in random order.
        data = [Fraction(k, 7) for k in range(1, 7)]
        assert len(data) % 2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(1, 2))

    def test_odd_decimals(self):
        # Odd count of Decimals, supplied in random order.
        data = [Decimal(s) for s in ('2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(data) % 2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('4.2'))

    def test_even_decimals(self):
        # Even count of Decimals, supplied in random order.
        data = [Decimal(s) for s in ('1.2', '2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(data) % 2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('3.65'))
1656
1657
class TestMedianDataType(NumericTestCase, UnivariateTypeMixin):
    # Checks that statistics.median conserves the type of the data elements.

    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        # An odd number of points, shuffled until they are out of order, so
        # that median has to sort them.
        points = list(range(15))
        assert len(points) % 2 == 1
        in_order = sorted(points)
        while points == in_order:
            random.shuffle(points)
        return points
1669
1670
class TestMedianLow(TestMedian, UnivariateTypeMixin):
    # statistics.median_low: for an even count, the smaller middle value.

    def setUp(self):
        self.func = statistics.median_low

    def test_even_ints(self):
        # Even count of ints picks the lower of the two middle values.
        data = [1, 2, 3, 4, 5, 6]
        assert len(data) % 2 == 0
        self.assertEqual(self.func(data), 3)

    def test_even_fractions(self):
        # Even count of Fractions, supplied in random order.
        data = [Fraction(k, 7) for k in range(1, 7)]
        assert len(data) % 2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(3, 7))

    def test_even_decimals(self):
        # Even count of Decimals, supplied in random order.
        data = [Decimal(s) for s in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(data) % 2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('3.3'))
1696
1697
class TestMedianHigh(TestMedian, UnivariateTypeMixin):
    # statistics.median_high: for an even count, the larger middle value.

    def setUp(self):
        self.func = statistics.median_high

    def test_even_ints(self):
        # Even count of ints picks the higher of the two middle values.
        data = [1, 2, 3, 4, 5, 6]
        assert len(data) % 2 == 0
        self.assertEqual(self.func(data), 4)

    def test_even_fractions(self):
        # Even count of Fractions, supplied in random order.
        data = [Fraction(k, 7) for k in range(1, 7)]
        assert len(data) % 2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), Fraction(4, 7))

    def test_even_decimals(self):
        # Even count of Decimals, supplied in random order.
        data = [Decimal(s) for s in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(data) % 2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), Decimal('4.4'))
1723
1724
class TestMedianGrouped(TestMedian):
    # Tests for statistics.median_grouped.
    # median_grouped does not conserve the type of the data elements (its
    # results are compared against floats below), so unlike the other median
    # test classes this one does not mix in UnivariateTypeMixin, the
    # type-conservation tests that TestMedianDataType runs for median.
    def setUp(self):
        self.func = statistics.median_grouped

    def test_odd_number_repeated(self):
        # Test median_grouped on odd-length data with repeated median values.
        data = [12, 13, 14, 14, 14, 15, 15]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data), 14)
        #---
        data = [12, 13, 14, 14, 14, 14, 15]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data), 13.875)
        #---
        data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30]
        assert len(data)%2 == 1
        self.assertEqual(self.func(data, 5), 19.375)
        #---
        data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28]
        assert len(data)%2 == 1
        self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8)

    def test_even_number_repeated(self):
        # Test median_grouped on even-length data with repeated median values.
        data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30]
        assert len(data)%2 == 0
        self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8)
        #---
        data = [2, 3, 4, 4, 4, 5]
        assert len(data)%2 == 0
        self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8)
        #---
        data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(data)%2 == 0
        self.assertEqual(self.func(data), 4.5)
        #---
        data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(data)%2 == 0
        self.assertEqual(self.func(data), 4.75)

    def test_repeated_single_value(self):
        # Override method from AverageMixin.
        # median_grouped does not conserve the data type, so the expected
        # value for each repeated x is float(x) rather than x itself.
        for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')):
            for count in (2, 5, 10, 20):
                data = [x]*count
                self.assertEqual(self.func(data), float(x))

    def test_odd_fractions(self):
        # Test median_grouped works with an odd number of Fractions.
        F = Fraction
        data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)]
        assert len(data)%2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), 3.0)

    def test_even_fractions(self):
        # Test median_grouped works with an even number of Fractions.
        F = Fraction
        data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 3.25)

    def test_odd_decimals(self):
        # Test median_grouped works with an odd number of Decimals.
        D = Decimal
        data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 1
        random.shuffle(data)
        self.assertEqual(self.func(data), 6.75)

    def test_even_decimals(self):
        # Test median_grouped works with an even number of Decimals.
        D = Decimal
        data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 6.5)
        #---
        data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')]
        assert len(data)%2 == 0
        random.shuffle(data)
        self.assertEqual(self.func(data), 7.0)

    def test_interval(self):
        # Test median_grouped with interval argument.
        data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertEqual(self.func(data, 0.25), 2.875)
        data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8)
        data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340]
        self.assertEqual(self.func(data, 20), 265.0)

    def test_data_type_error(self):
        # str or bytes in either the data or the interval raise TypeError.
        # Each case runs in its own subTest so a failure names the culprit.
        for args in (
            (["", "", ""],),        # str data, default interval
            ([b"", b"", b""],),     # bytes data, default interval
            ([1, 2, 3], ""),        # str interval
            ([1, 2, 3], b""),       # bytes interval
        ):
            with self.subTest(args=args):
                self.assertRaises(TypeError, self.func, *args)
1837
1838
class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.mode, the discrete (single-result) mode.

    def setUp(self):
        self.func = statistics.mode

    def prepare_data(self):
        """Override UnivariateCommonMixin's hook: data whose only mode is 1."""
        return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2]

    def test_range_data(self):
        # Overrides the UnivariateCommonMixin test. Every value in the range
        # is unique, so the first one is reported.
        self.assertEqual(self.func(range(20, 50, 3)), 20)

    def test_nominal_data(self):
        # mode also works on non-numeric (nominal) data.
        for data, expected in (('abcbdb', 'b'),
                               ('fe fi fo fum fi fi'.split(), 'fi')):
            self.assertEqual(self.func(data), expected)

    def test_discrete_data(self):
        # Duplicating one value of 0..9 makes that value the mode.
        base = list(range(10))
        for extra in range(10):
            sample = base + [extra]
            random.shuffle(sample)
            self.assertEqual(self.func(sample), extra)

    def test_bimodal_data(self):
        # With two equally common values, the first one encountered (2) wins.
        data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
        assert data.count(2) == data.count(6) == 4
        self.assertEqual(self.func(data), 2)

    def test_unique_data(self):
        # All values unique: the first one encountered (0) is reported.
        self.assertEqual(self.func(list(range(10))), 0)

    def test_none_data(self):
        # mode(None) must raise TypeError. Checked explicitly because mode is
        # built on collections.Counter, and Counter(None) quietly yields an
        # empty counter instead of failing.
        self.assertRaises(TypeError, self.func, None)

    def test_counter_data(self):
        # A Counter is iterated like any other iterable, i.e. over its keys
        # rather than its counts, so the first key (1) is the mode.
        counter = collections.Counter([1, 1, 1, 2])
        self.assertEqual(self.func(counter), 1)
1895
1896
class TestMultiMode(unittest.TestCase):
    # Tests for statistics.multimode.

    def test_basics(self):
        # multimode returns every value tied for the highest count, and an
        # empty list for empty input.
        cases = [
            ('aabbbbbbbbcc', ['b']),
            ('aabbbbccddddeeffffgg', ['b', 'd', 'f']),
            ('', []),
        ]
        for data, expected in cases:
            self.assertEqual(statistics.multimode(data), expected)
1904
1905
class TestFMean(unittest.TestCase):
    # Tests for statistics.fmean, which must always return a float.

    def test_basics(self):
        D, F = Decimal, Fraction
        cases = [
            ([3.5, 4.0, 5.25], 4.25, 'floats'),
            ([D('3.5'), D('4.0'), D('5.25')], 4.25, 'decimals'),
            ([F(7, 2), F(4, 1), F(21, 4)], 4.25, 'fractions'),
            ([True, False, True, True, False], 0.60, 'booleans'),
            ([3.5, 4, F(21, 4)], 4.25, 'mixed types'),
            ((3.5, 4.0, 5.25), 4.25, 'tuple'),
            (iter([3.5, 4.0, 5.25]), 4.25, 'iterator'),
        ]
        for data, expected_mean, kind in cases:
            actual_mean = statistics.fmean(data)
            self.assertIs(type(actual_mean), float, kind)
            self.assertEqual(actual_mean, expected_mean, kind)

    def test_error_cases(self):
        # Each entry: (expected exception, positional arguments to fmean).
        StatisticsError = statistics.StatisticsError
        cases = [
            (StatisticsError, ([],)),           # empty input
            (StatisticsError, (iter([]),)),     # empty iterator
            (TypeError, (None,)),               # non-iterable input
            (TypeError, ([10, None, 20],)),     # non-numeric input
            (TypeError, ()),                    # missing data argument
            (TypeError, ([10, 20, 60], 70)),    # too many arguments
        ]
        for exc_type, args in cases:
            with self.assertRaises(exc_type):
                statistics.fmean(*args)

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        NaN = float('Nan')
        Inf = float('Inf')
        self.assertTrue(math.isnan(statistics.fmean([10, NaN])), 'nan')
        self.assertTrue(math.isnan(statistics.fmean([NaN, Inf])),
                        'nan and infinity')
        self.assertTrue(math.isinf(statistics.fmean([10, Inf])), 'infinity')
        with self.assertRaises(ValueError):
            statistics.fmean([Inf, -Inf])
1951
1952
1953# === Tests for variances and standard deviations ===
1954
class VarianceStdevMixin(UnivariateCommonMixin):
    # Mixin class holding the tests shared by the variance and standard
    # deviation test classes.
    #
    # Subclasses should list this mixin before NumericTestCase in their base
    # classes, so that the ``rel`` attribute below is the one found by
    # attribute lookup (presumably NumericTestCase provides its own default).
    # See test_shift_data for why a relative tolerance is needed.

    # Relative error tolerance; presumably read by
    # NumericTestCase.assertApproxEqual when no explicit rel= is passed --
    # TODO confirm against NumericTestCase.
    rel = 1e-12

    def test_single_value(self):
        # Deviation of a single value is zero.
        for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')):
            self.assertEqual(self.func([x]), 0)

    def test_repeated_single_value(self):
        # The deviation of a single repeated value is zero.
        for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')):
            for count in (2, 3, 5, 15):
                data = [x]*count
                self.assertEqual(self.func(data), 0)

    def test_domain_error_regression(self):
        # Regression test for a domain error exception.
        # (Thanks to Geremy Condra.)
        data = [0.123456789012345]*10000
        # All the items are identical, so variance should be exactly zero.
        # We allow some small round-off error, but not much.
        result = self.func(data)
        self.assertApproxEqual(result, 0.0, tol=5e-17)
        self.assertGreaterEqual(result, 0)  # A negative result must fail.

    def test_shift_data(self):
        # Test that shifting the data by a constant amount does not affect
        # the variance or stdev. Or at least not much.

        # Due to rounding, this test should be considered an ideal. We allow
        # some tolerance away from "no change at all" by setting tol and/or rel
        # attributes. Subclasses may set tighter or looser error tolerances.
        raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
        expected = self.func(raw)
        # Don't set shift too high, the bigger it is, the more rounding error.
        shift = 1e5
        data = [x + shift for x in raw]
        self.assertApproxEqual(self.func(data), expected)

    def test_shift_data_exact(self):
        # Like test_shift_data, but with integer data the result is exact.
        raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
        assert all(x==int(x) for x in raw)
        expected = self.func(raw)
        shift = 10**9
        data = [x + shift for x in raw]
        self.assertEqual(self.func(data), expected)

    def test_iter_list_same(self):
        # Test that iter data and list data give the same result.

        # This is an explicit test that iterators and lists are treated the
        # same; justification for this test over and above the similar test
        # in UnivariateCommonMixin is that an earlier design had variance and
        # friends swap between one- and two-pass algorithms, which would
        # sometimes give different results.
        data = [random.uniform(-3, 8) for _ in range(1000)]
        expected = self.func(data)
        self.assertEqual(self.func(iter(data)), expected)
2019
2020
class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for statistics.pvariance (population variance).

    def setUp(self):
        self.func = statistics.pvariance

    def test_exact_uniform(self):
        # The population variance of 0..N-1 is exactly (N**2 - 1) / 12.
        size = 10000
        data = list(range(size))
        random.shuffle(data)
        expected = (size**2 - 1)/12
        self.assertEqual(self.func(data), expected)

    def test_ints(self):
        # Integer data with an exactly representable result.
        self.assertEqual(self.func([4, 7, 13, 16]), 22.5)

    def test_fractions(self):
        # Fraction data gives an exact Fraction result.
        F = Fraction
        result = self.func([F(1, 4), F(1, 4), F(3, 4), F(7, 4)])
        self.assertEqual(result, F(3, 8))
        self.assertIsInstance(result, Fraction)

    def test_decimals(self):
        # Decimal data gives an exact Decimal result.
        D = Decimal
        result = self.func([D("12.1"), D("12.2"), D("12.5"), D("12.9")])
        self.assertEqual(result, D('0.096875'))
        self.assertIsInstance(result, Decimal)
2056
2057
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for statistics.variance (sample variance).

    def setUp(self):
        self.func = statistics.variance

    def test_single_value(self):
        # Override of the mixin test: sample variance needs at least two
        # points, so a single value raises StatisticsError.
        for value in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')):
            self.assertRaises(statistics.StatisticsError, self.func, [value])

    def test_ints(self):
        # Integer data with an exact result.
        self.assertEqual(self.func([4, 7, 13, 16]), 30)

    def test_fractions(self):
        # Fraction data gives an exact Fraction result.
        F = Fraction
        result = self.func([F(1, 4), F(1, 4), F(3, 4), F(7, 4)])
        self.assertEqual(result, F(1, 2))
        self.assertIsInstance(result, Fraction)

    def test_decimals(self):
        # Decimal data gives a Decimal result.
        D = Decimal
        expected = 4*D('9.5')/D(3)
        result = self.func([D(2), D(2), D(7), D(9)])
        self.assertEqual(result, expected)
        self.assertIsInstance(result, Decimal)

    def test_center_not_at_mean(self):
        # Supplying xbar overrides the computed mean of the data.
        data = (1.0, 2.0)
        self.assertEqual(self.func(data), 0.5)
        self.assertEqual(self.func(data, xbar=2.0), 1.0)
2096
class TestPStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for statistics.pstdev (population standard deviation).

    def setUp(self):
        self.func = statistics.pstdev

    def test_compare_to_variance(self):
        # pstdev should be exactly the square root of pvariance.
        sample = [random.uniform(-17, 24) for _ in range(1000)]
        expected = math.sqrt(statistics.pvariance(sample))
        self.assertEqual(self.func(sample), expected)

    def test_center_not_at_mean(self):
        # See issue: 40855
        # An explicit mu changes the point deviations are measured from.
        data = (3, 6, 7, 10)
        self.assertEqual(self.func(data), 2.5)
        self.assertEqual(self.func(data, mu=0.5), 6.5)
2113
class TestStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for statistics.stdev (sample standard deviation).

    def setUp(self):
        self.func = statistics.stdev

    def test_single_value(self):
        # Override of the mixin test: a sample standard deviation needs at
        # least two points, so a single value raises StatisticsError.
        for value in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')):
            self.assertRaises(statistics.StatisticsError, self.func, [value])

    def test_compare_to_variance(self):
        # stdev should be exactly the square root of variance.
        sample = [random.uniform(-2, 9) for _ in range(1000)]
        expected = math.sqrt(statistics.variance(sample))
        self.assertEqual(self.func(sample), expected)

    def test_center_not_at_mean(self):
        # Measuring deviations of (1.0, 2.0) from xbar=2.0 gives 1.0.
        self.assertEqual(self.func((1.0, 2.0), xbar=2.0), 1.0)
2133
class TestGeometricMean(unittest.TestCase):
    # Tests for statistics.geometric_mean.

    def test_basics(self):
        geometric_mean = statistics.geometric_mean
        self.assertAlmostEqual(geometric_mean([54, 24, 36]), 36.0)
        self.assertAlmostEqual(geometric_mean([4.0, 9.0]), 6.0)
        self.assertAlmostEqual(geometric_mean([17.625]), 17.625)

        # Use a private seeded generator. Calling random.seed() here would
        # reseed the shared module-level generator and silently change the
        # random data seen by every test that runs afterwards.
        # random.Random(seed) produces the same sequence random.seed(seed)
        # would have, so the generated datasets are unchanged.
        prng = random.Random(86753095551212)
        for data in [
                range(1, 100),
                range(1, 1_000),
                range(1, 10_000),
                range(500, 10_000, 3),
                range(10_000, 500, -3),
                [12, 17, 13, 5, 120, 7],
                [prng.expovariate(50.0) for i in range(1_000)],
                [prng.lognormvariate(20.0, 3.0) for i in range(2_000)],
                [prng.triangular(2000, 3000, 2200) for i in range(3_000)],
            ]:
            # Reference value: high-precision Decimal product, then n-th root.
            gm_decimal = math.prod(map(Decimal, data)) ** (Decimal(1) / len(data))
            gm_float = geometric_mean(data)
            self.assertTrue(math.isclose(gm_float, float(gm_decimal)))

    def test_various_input_types(self):
        geometric_mean = statistics.geometric_mean
        D = Decimal
        F = Fraction
        # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25
        expected_mean = 4.18886
        for data, kind in [
            ([3.5, 4.0, 5.25], 'floats'),
            ([D('3.5'), D('4.0'), D('5.25')], 'decimals'),
            ([F(7, 2), F(4, 1), F(21, 4)], 'fractions'),
            ([3.5, 4, F(21, 4)], 'mixed types'),
            ((3.5, 4.0, 5.25), 'tuple'),
            (iter([3.5, 4.0, 5.25]), 'iterator'),
                ]:
            actual_mean = geometric_mean(data)
            self.assertIs(type(actual_mean), float, kind)
            # Pass kind as the message so a failure names the input type.
            self.assertAlmostEqual(actual_mean, expected_mean, places=5,
                                   msg=kind)

    def test_big_and_small(self):
        geometric_mean = statistics.geometric_mean

        # Avoid overflow to infinity
        large = 2.0 ** 1000
        big_gm = geometric_mean([54.0 * large, 24.0 * large, 36.0 * large])
        self.assertTrue(math.isclose(big_gm, 36.0 * large))
        self.assertFalse(math.isinf(big_gm))

        # Avoid underflow to zero
        small = 2.0 ** -1000
        small_gm = geometric_mean([54.0 * small, 24.0 * small, 36.0 * small])
        self.assertTrue(math.isclose(small_gm, 36.0 * small))
        self.assertNotEqual(small_gm, 0.0)

    def test_error_cases(self):
        geometric_mean = statistics.geometric_mean
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(StatisticsError):
            geometric_mean([])                      # empty input
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, 0.0, 5.25])        # zero input
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, -4.0, 5.25])       # negative input
        with self.assertRaises(StatisticsError):
            geometric_mean(iter([]))                # empty iterator
        with self.assertRaises(TypeError):
            geometric_mean(None)                    # non-iterable input
        with self.assertRaises(TypeError):
            geometric_mean([10, None, 20])          # non-numeric input
        with self.assertRaises(TypeError):
            geometric_mean()                        # missing data argument
        with self.assertRaises(TypeError):
            geometric_mean([10, 20, 60], 70)        # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        geometric_mean = statistics.geometric_mean
        NaN = float('Nan')
        Inf = float('Inf')
        self.assertTrue(math.isnan(geometric_mean([10, NaN])), 'nan')
        self.assertTrue(math.isnan(geometric_mean([NaN, Inf])), 'nan and infinity')
        self.assertTrue(math.isinf(geometric_mean([10, Inf])), 'infinity')
        with self.assertRaises(ValueError):
            geometric_mean([Inf, -Inf])
2221
2222
2223class TestQuantiles(unittest.TestCase):
2224
2225    def test_specific_cases(self):
2226        # Match results computed by hand and cross-checked
2227        # against the PERCENTILE.EXC function in MS Excel.
2228        quantiles = statistics.quantiles
2229        data = [120, 200, 250, 320, 350]
2230        random.shuffle(data)
2231        for n, expected in [
2232            (1, []),
2233            (2, [250.0]),
2234            (3, [200.0, 320.0]),
2235            (4, [160.0, 250.0, 335.0]),
2236            (5, [136.0, 220.0, 292.0, 344.0]),
2237            (6, [120.0, 200.0, 250.0, 320.0, 350.0]),
2238            (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]),
2239            (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]),
2240            (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0,
2241                  350.0, 365.0]),
2242            (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0,
2243                  320.0, 332.0, 344.0, 356.0, 368.0]),
2244                ]:
2245            self.assertEqual(expected, quantiles(data, n=n))
2246            self.assertEqual(len(quantiles(data, n=n)), n - 1)
2247            # Preserve datatype when possible
2248            for datatype in (float, Decimal, Fraction):
2249                result = quantiles(map(datatype, data), n=n)
2250                self.assertTrue(all(type(x) == datatype) for x in result)
2251                self.assertEqual(result, list(map(datatype, expected)))
2252            # Quantiles should be idempotent
2253            if len(expected) >= 2:
2254                self.assertEqual(quantiles(expected, n=n), expected)
2255            # Cross-check against method='inclusive' which should give
2256            # the same result after adding in minimum and maximum values
2257            # extrapolated from the two lowest and two highest points.
2258            sdata = sorted(data)
2259            lo = 2 * sdata[0] - sdata[1]
2260            hi = 2 * sdata[-1] - sdata[-2]
2261            padded_data = data + [lo, hi]
2262            self.assertEqual(
2263                quantiles(data, n=n),
2264                quantiles(padded_data, n=n, method='inclusive'),
2265                (n, data),
2266            )
2267            # Invariant under translation and scaling
2268            def f(x):
2269                return 3.5 * x - 1234.675
2270            exp = list(map(f, expected))
2271            act = quantiles(map(f, data), n=n)
2272            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
2273        # Q2 agrees with median()
2274        for k in range(2, 60):
2275            data = random.choices(range(100), k=k)
2276            q1, q2, q3 = quantiles(data)
2277            self.assertEqual(q2, statistics.median(data))
2278
2279    def test_specific_cases_inclusive(self):
2280        # Match results computed by hand and cross-checked
2281        # against the PERCENTILE.INC function in MS Excel
2282        # and against the quantile() function in SciPy.
2283        quantiles = statistics.quantiles
2284        data = [100, 200, 400, 800]
2285        random.shuffle(data)
2286        for n, expected in [
2287            (1, []),
2288            (2, [300.0]),
2289            (3, [200.0, 400.0]),
2290            (4, [175.0, 300.0, 500.0]),
2291            (5, [160.0, 240.0, 360.0, 560.0]),
2292            (6, [150.0, 200.0, 300.0, 400.0, 600.0]),
2293            (8, [137.5, 175, 225.0, 300.0, 375.0, 500.0,650.0]),
2294            (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]),
2295            (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0,
2296                  500.0, 600.0, 700.0]),
2297            (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0,
2298                  400.0, 480.0, 560.0, 640.0, 720.0]),
2299                ]:
2300            self.assertEqual(expected, quantiles(data, n=n, method="inclusive"))
2301            self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1)
2302            # Preserve datatype when possible
2303            for datatype in (float, Decimal, Fraction):
2304                result = quantiles(map(datatype, data), n=n, method="inclusive")
2305                self.assertTrue(all(type(x) == datatype) for x in result)
2306                self.assertEqual(result, list(map(datatype, expected)))
2307            # Invariant under translation and scaling
2308            def f(x):
2309                return 3.5 * x - 1234.675
2310            exp = list(map(f, expected))
2311            act = quantiles(map(f, data), n=n, method="inclusive")
2312            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
2313        # Natural deciles
2314        self.assertEqual(quantiles([0, 100], n=10, method='inclusive'),
2315                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
2316        self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'),
2317                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
2318        # Whenever n is smaller than the number of data points, running
2319        # method='inclusive' should give the same result as method='exclusive'
2320        # after the two included extreme points are removed.
2321        data = [random.randrange(10_000) for i in range(501)]
2322        actual = quantiles(data, n=32, method='inclusive')
2323        data.remove(min(data))
2324        data.remove(max(data))
2325        expected = quantiles(data, n=32)
2326        self.assertEqual(expected, actual)
2327        # Q2 agrees with median()
2328        for k in range(2, 60):
2329            data = random.choices(range(100), k=k)
2330            q1, q2, q3 = quantiles(data, method='inclusive')
2331            self.assertEqual(q2, statistics.median(data))
2332
2333    def test_equal_inputs(self):
2334        quantiles = statistics.quantiles
2335        for n in range(2, 10):
2336            data = [10.0] * n
2337            self.assertEqual(quantiles(data), [10.0, 10.0, 10.0])
2338            self.assertEqual(quantiles(data, method='inclusive'),
2339                            [10.0, 10.0, 10.0])
2340
2341    def test_equal_sized_groups(self):
2342        quantiles = statistics.quantiles
2343        total = 10_000
2344        data = [random.expovariate(0.2) for i in range(total)]
2345        while len(set(data)) != total:
2346            data.append(random.expovariate(0.2))
2347        data.sort()
2348
2349        # Cases where the group size exactly divides the total
2350        for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000):
2351            group_size = total // n
2352            self.assertEqual(
2353                [bisect.bisect(data, q) for q in quantiles(data, n=n)],
2354                list(range(group_size, total, group_size)))
2355
2356        # When the group sizes can't be exactly equal, they should
2357        # differ by no more than one
2358        for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769):
2359            group_sizes = {total // n, total // n + 1}
2360            pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)]
2361            sizes = {q - p for p, q in zip(pos, pos[1:])}
2362            self.assertTrue(sizes <= group_sizes)
2363
2364    def test_error_cases(self):
2365        quantiles = statistics.quantiles
2366        StatisticsError = statistics.StatisticsError
2367        with self.assertRaises(TypeError):
2368            quantiles()                         # Missing arguments
2369        with self.assertRaises(TypeError):
2370            quantiles([10, 20, 30], 13, n=4)    # Too many arguments
2371        with self.assertRaises(TypeError):
2372            quantiles([10, 20, 30], 4)          # n is a positional argument
2373        with self.assertRaises(StatisticsError):
2374            quantiles([10, 20, 30], n=0)        # n is zero
2375        with self.assertRaises(StatisticsError):
2376            quantiles([10, 20, 30], n=-1)       # n is negative
2377        with self.assertRaises(TypeError):
2378            quantiles([10, 20, 30], n=1.5)      # n is not an integer
2379        with self.assertRaises(ValueError):
2380            quantiles([10, 20, 30], method='X') # method is unknown
2381        with self.assertRaises(StatisticsError):
2382            quantiles([10], n=4)                # not enough data points
2383        with self.assertRaises(TypeError):
2384            quantiles([10, None, 30], n=4)      # data is non-numeric
2385
2386
2387class TestNormalDist:
2388
2389    # General note on precision: The pdf(), cdf(), and overlap() methods
2390    # depend on functions in the math libraries that do not make
2391    # explicit accuracy guarantees.  Accordingly, some of the accuracy
2392    # tests below may fail if the underlying math functions are
2393    # inaccurate.  There isn't much we can do about this short of
2394    # implementing our own implementations from scratch.
2395
2396    def test_slots(self):
2397        nd = self.module.NormalDist(300, 23)
2398        with self.assertRaises(TypeError):
2399            vars(nd)
2400        self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma'))
2401
2402    def test_instantiation_and_attributes(self):
2403        nd = self.module.NormalDist(500, 17)
2404        self.assertEqual(nd.mean, 500)
2405        self.assertEqual(nd.stdev, 17)
2406        self.assertEqual(nd.variance, 17**2)
2407
2408        # default arguments
2409        nd = self.module.NormalDist()
2410        self.assertEqual(nd.mean, 0)
2411        self.assertEqual(nd.stdev, 1)
2412        self.assertEqual(nd.variance, 1**2)
2413
2414        # error case: negative sigma
2415        with self.assertRaises(self.module.StatisticsError):
2416            self.module.NormalDist(500, -10)
2417
2418        # verify that subclass type is honored
2419        class NewNormalDist(self.module.NormalDist):
2420            pass
2421        nnd = NewNormalDist(200, 5)
2422        self.assertEqual(type(nnd), NewNormalDist)
2423
2424    def test_alternative_constructor(self):
2425        NormalDist = self.module.NormalDist
2426        data = [96, 107, 90, 92, 110]
2427        # list input
2428        self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9))
2429        # tuple input
2430        self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9))
2431        # iterator input
2432        self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9))
2433        # error cases
2434        with self.assertRaises(self.module.StatisticsError):
2435            NormalDist.from_samples([])                      # empty input
2436        with self.assertRaises(self.module.StatisticsError):
2437            NormalDist.from_samples([10])                    # only one input
2438
2439        # verify that subclass type is honored
2440        class NewNormalDist(NormalDist):
2441            pass
2442        nnd = NewNormalDist.from_samples(data)
2443        self.assertEqual(type(nnd), NewNormalDist)
2444
2445    def test_sample_generation(self):
2446        NormalDist = self.module.NormalDist
2447        mu, sigma = 10_000, 3.0
2448        X = NormalDist(mu, sigma)
2449        n = 1_000
2450        data = X.samples(n)
2451        self.assertEqual(len(data), n)
2452        self.assertEqual(set(map(type, data)), {float})
2453        # mean(data) expected to fall within 8 standard deviations
2454        xbar = self.module.mean(data)
2455        self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8)
2456
2457        # verify that seeding makes reproducible sequences
2458        n = 100
2459        data1 = X.samples(n, seed='happiness and joy')
2460        data2 = X.samples(n, seed='trouble and despair')
2461        data3 = X.samples(n, seed='happiness and joy')
2462        data4 = X.samples(n, seed='trouble and despair')
2463        self.assertEqual(data1, data3)
2464        self.assertEqual(data2, data4)
2465        self.assertNotEqual(data1, data2)
2466
2467    def test_pdf(self):
2468        NormalDist = self.module.NormalDist
2469        X = NormalDist(100, 15)
2470        # Verify peak around center
2471        self.assertLess(X.pdf(99), X.pdf(100))
2472        self.assertLess(X.pdf(101), X.pdf(100))
2473        # Test symmetry
2474        for i in range(50):
2475            self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i))
2476        # Test vs CDF
2477        dx = 2.0 ** -10
2478        for x in range(90, 111):
2479            est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx
2480            self.assertAlmostEqual(X.pdf(x), est_pdf, places=4)
2481        # Test vs table of known values -- CRC 26th Edition
2482        Z = NormalDist()
2483        for x, px in enumerate([
2484            0.3989, 0.3989, 0.3989, 0.3988, 0.3986,
2485            0.3984, 0.3982, 0.3980, 0.3977, 0.3973,
2486            0.3970, 0.3965, 0.3961, 0.3956, 0.3951,
2487            0.3945, 0.3939, 0.3932, 0.3925, 0.3918,
2488            0.3910, 0.3902, 0.3894, 0.3885, 0.3876,
2489            0.3867, 0.3857, 0.3847, 0.3836, 0.3825,
2490            0.3814, 0.3802, 0.3790, 0.3778, 0.3765,
2491            0.3752, 0.3739, 0.3725, 0.3712, 0.3697,
2492            0.3683, 0.3668, 0.3653, 0.3637, 0.3621,
2493            0.3605, 0.3589, 0.3572, 0.3555, 0.3538,
2494        ]):
2495            self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4)
2496            self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4)
2497        # Error case: variance is zero
2498        Y = NormalDist(100, 0)
2499        with self.assertRaises(self.module.StatisticsError):
2500            Y.pdf(90)
2501        # Special values
2502        self.assertEqual(X.pdf(float('-Inf')), 0.0)
2503        self.assertEqual(X.pdf(float('Inf')), 0.0)
2504        self.assertTrue(math.isnan(X.pdf(float('NaN'))))
2505
2506    def test_cdf(self):
2507        NormalDist = self.module.NormalDist
2508        X = NormalDist(100, 15)
2509        cdfs = [X.cdf(x) for x in range(1, 200)]
2510        self.assertEqual(set(map(type, cdfs)), {float})
2511        # Verify montonic
2512        self.assertEqual(cdfs, sorted(cdfs))
2513        # Verify center (should be exact)
2514        self.assertEqual(X.cdf(100), 0.50)
2515        # Check against a table of known values
2516        # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative
2517        Z = NormalDist()
2518        for z, cum_prob in [
2519            (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798),
2520            (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930),
2521            (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900),
2522            (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807),
2523            (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998),
2524            ]:
2525            self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5)
2526            self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5)
2527        # Error case: variance is zero
2528        Y = NormalDist(100, 0)
2529        with self.assertRaises(self.module.StatisticsError):
2530            Y.cdf(90)
2531        # Special values
2532        self.assertEqual(X.cdf(float('-Inf')), 0.0)
2533        self.assertEqual(X.cdf(float('Inf')), 1.0)
2534        self.assertTrue(math.isnan(X.cdf(float('NaN'))))
2535
2536    @support.skip_if_pgo_task
2537    def test_inv_cdf(self):
2538        NormalDist = self.module.NormalDist
2539
2540        # Center case should be exact.
2541        iq = NormalDist(100, 15)
2542        self.assertEqual(iq.inv_cdf(0.50), iq.mean)
2543
2544        # Test versus a published table of known percentage points.
2545        # See the second table at the bottom of the page here:
2546        # http://people.bath.ac.uk/masss/tables/normaltable.pdf
2547        Z = NormalDist()
2548        pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891,
2549                    4.417, 4.892, 5.327, 5.731, 6.109),
2550              2.5: (0.674, 1.960, 2.807, 3.481, 4.056,
2551                    4.565, 5.026, 5.451, 5.847, 6.219),
2552              1.0: (1.282, 2.326, 3.090, 3.719, 4.265,
2553                    4.753, 5.199, 5.612, 5.998, 6.361)}
2554        for base, row in pp.items():
2555            for exp, x in enumerate(row, start=1):
2556                p = base * 10.0 ** (-exp)
2557                self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3)
2558                p = 1.0 - p
2559                self.assertAlmostEqual(Z.inv_cdf(p), x, places=3)
2560
2561        # Match published example for MS Excel
2562        # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13
2563        self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002)
2564
2565        # One million equally spaced probabilities
2566        n = 2**20
2567        for p in range(1, n):
2568            p /= n
2569            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2570
2571        # One hundred ever smaller probabilities to test tails out to
2572        # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50
2573        for e in range(1, 51):
2574            p = 2.0 ** (-e)
2575            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2576            p = 1.0 - p
2577            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2578
2579        # Now apply cdf() first.  Near the tails, the round-trip loses
2580        # precision and is ill-conditioned (small changes in the inputs
2581        # give large changes in the output), so only check to 5 places.
2582        for x in range(200):
2583            self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5)
2584
2585        # Error cases:
2586        with self.assertRaises(self.module.StatisticsError):
2587            iq.inv_cdf(0.0)                         # p is zero
2588        with self.assertRaises(self.module.StatisticsError):
2589            iq.inv_cdf(-0.1)                        # p under zero
2590        with self.assertRaises(self.module.StatisticsError):
2591            iq.inv_cdf(1.0)                         # p is one
2592        with self.assertRaises(self.module.StatisticsError):
2593            iq.inv_cdf(1.1)                         # p over one
2594        with self.assertRaises(self.module.StatisticsError):
2595            iq = NormalDist(100, 0)                 # sigma is zero
2596            iq.inv_cdf(0.5)
2597
2598        # Special values
2599        self.assertTrue(math.isnan(Z.inv_cdf(float('NaN'))))
2600
2601    def test_quantiles(self):
2602        # Quartiles of a standard normal distribution
2603        Z = self.module.NormalDist()
2604        for n, expected in [
2605            (1, []),
2606            (2, [0.0]),
2607            (3, [-0.4307, 0.4307]),
2608            (4 ,[-0.6745, 0.0, 0.6745]),
2609                ]:
2610            actual = Z.quantiles(n=n)
2611            self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001)
2612                            for e, a in zip(expected, actual)))
2613
2614    def test_overlap(self):
2615        NormalDist = self.module.NormalDist
2616
2617        # Match examples from Imman and Bradley
2618        for X1, X2, published_result in [
2619                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258),
2620                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993),
2621            ]:
2622            self.assertAlmostEqual(X1.overlap(X2), published_result, places=4)
2623            self.assertAlmostEqual(X2.overlap(X1), published_result, places=4)
2624
2625        # Check against integration of the PDF
2626        def overlap_numeric(X, Y, *, steps=8_192, z=5):
2627            'Numerical integration cross-check for overlap() '
2628            fsum = math.fsum
2629            center = (X.mean + Y.mean) / 2.0
2630            width = z * max(X.stdev, Y.stdev)
2631            start = center - width
2632            dx = 2.0 * width / steps
2633            x_arr = [start + i*dx for i in range(steps)]
2634            xp = list(map(X.pdf, x_arr))
2635            yp = list(map(Y.pdf, x_arr))
2636            total = max(fsum(xp), fsum(yp))
2637            return fsum(map(min, xp, yp)) / total
2638
2639        for X1, X2 in [
2640                # Examples from Imman and Bradley
2641                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)),
2642                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2643                # Example from https://www.rasch.org/rmt/rmt101r.htm
2644                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2645                # Gender heights from http://www.usablestats.com/lessons/normal
2646                (NormalDist(70, 4), NormalDist(65, 3.5)),
2647                # Misc cases with equal standard deviations
2648                (NormalDist(100, 15), NormalDist(110, 15)),
2649                (NormalDist(-100, 15), NormalDist(110, 15)),
2650                (NormalDist(-100, 15), NormalDist(-110, 15)),
2651                # Misc cases with unequal standard deviations
2652                (NormalDist(100, 12), NormalDist(100, 15)),
2653                (NormalDist(100, 12), NormalDist(110, 15)),
2654                (NormalDist(100, 12), NormalDist(150, 15)),
2655                (NormalDist(100, 12), NormalDist(150, 35)),
2656                # Misc cases with small values
2657                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)),
2658                (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)),
2659                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)),
2660            ]:
2661            self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5)
2662            self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5)
2663
2664        # Error cases
2665        X = NormalDist()
2666        with self.assertRaises(TypeError):
2667            X.overlap()                             # too few arguments
2668        with self.assertRaises(TypeError):
2669            X.overlap(X, X)                         # too may arguments
2670        with self.assertRaises(TypeError):
2671            X.overlap(None)                         # right operand not a NormalDist
2672        with self.assertRaises(self.module.StatisticsError):
2673            X.overlap(NormalDist(1, 0))             # right operand sigma is zero
2674        with self.assertRaises(self.module.StatisticsError):
2675            NormalDist(1, 0).overlap(X)             # left operand sigma is zero
2676
2677    def test_zscore(self):
2678        NormalDist = self.module.NormalDist
2679        X = NormalDist(100, 15)
2680        self.assertEqual(X.zscore(142), 2.8)
2681        self.assertEqual(X.zscore(58), -2.8)
2682        self.assertEqual(X.zscore(100), 0.0)
2683        with self.assertRaises(TypeError):
2684            X.zscore()                              # too few arguments
2685        with self.assertRaises(TypeError):
2686            X.zscore(1, 1)                          # too may arguments
2687        with self.assertRaises(TypeError):
2688            X.zscore(None)                          # non-numeric type
2689        with self.assertRaises(self.module.StatisticsError):
2690            NormalDist(1, 0).zscore(100)            # sigma is zero
2691
2692    def test_properties(self):
2693        X = self.module.NormalDist(100, 15)
2694        self.assertEqual(X.mean, 100)
2695        self.assertEqual(X.median, 100)
2696        self.assertEqual(X.mode, 100)
2697        self.assertEqual(X.stdev, 15)
2698        self.assertEqual(X.variance, 225)
2699
2700    def test_same_type_addition_and_subtraction(self):
2701        NormalDist = self.module.NormalDist
2702        X = NormalDist(100, 12)
2703        Y = NormalDist(40, 5)
2704        self.assertEqual(X + Y, NormalDist(140, 13))        # __add__
2705        self.assertEqual(X - Y, NormalDist(60, 13))         # __sub__
2706
2707    def test_translation_and_scaling(self):
2708        NormalDist = self.module.NormalDist
2709        X = NormalDist(100, 15)
2710        y = 10
2711        self.assertEqual(+X, NormalDist(100, 15))           # __pos__
2712        self.assertEqual(-X, NormalDist(-100, 15))          # __neg__
2713        self.assertEqual(X + y, NormalDist(110, 15))        # __add__
2714        self.assertEqual(y + X, NormalDist(110, 15))        # __radd__
2715        self.assertEqual(X - y, NormalDist(90, 15))         # __sub__
2716        self.assertEqual(y - X, NormalDist(-90, 15))        # __rsub__
2717        self.assertEqual(X * y, NormalDist(1000, 150))      # __mul__
2718        self.assertEqual(y * X, NormalDist(1000, 150))      # __rmul__
2719        self.assertEqual(X / y, NormalDist(10, 1.5))        # __truediv__
2720        with self.assertRaises(TypeError):                  # __rtruediv__
2721            y / X
2722
2723    def test_unary_operations(self):
2724        NormalDist = self.module.NormalDist
2725        X = NormalDist(100, 12)
2726        Y = +X
2727        self.assertIsNot(X, Y)
2728        self.assertEqual(X.mean, Y.mean)
2729        self.assertEqual(X.stdev, Y.stdev)
2730        Y = -X
2731        self.assertIsNot(X, Y)
2732        self.assertEqual(X.mean, -Y.mean)
2733        self.assertEqual(X.stdev, Y.stdev)
2734
2735    def test_equality(self):
2736        NormalDist = self.module.NormalDist
2737        nd1 = NormalDist()
2738        nd2 = NormalDist(2, 4)
2739        nd3 = NormalDist()
2740        nd4 = NormalDist(2, 4)
2741        nd5 = NormalDist(2, 8)
2742        nd6 = NormalDist(8, 4)
2743        self.assertNotEqual(nd1, nd2)
2744        self.assertEqual(nd1, nd3)
2745        self.assertEqual(nd2, nd4)
2746        self.assertNotEqual(nd2, nd5)
2747        self.assertNotEqual(nd2, nd6)
2748
2749        # Test NotImplemented when types are different
2750        class A:
2751            def __eq__(self, other):
2752                return 10
2753        a = A()
2754        self.assertEqual(nd1.__eq__(a), NotImplemented)
2755        self.assertEqual(nd1 == a, 10)
2756        self.assertEqual(a == nd1, 10)
2757
2758        # All subclasses to compare equal giving the same behavior
2759        # as list, tuple, int, float, complex, str, dict, set, etc.
2760        class SizedNormalDist(NormalDist):
2761            def __init__(self, mu, sigma, n):
2762                super().__init__(mu, sigma)
2763                self.n = n
2764        s = SizedNormalDist(100, 15, 57)
2765        nd4 = NormalDist(100, 15)
2766        self.assertEqual(s, nd4)
2767
2768        # Don't allow duck type equality because we wouldn't
2769        # want a lognormal distribution to compare equal
2770        # to a normal distribution with the same parameters
2771        class LognormalDist:
2772            def __init__(self, mu, sigma):
2773                self.mu = mu
2774                self.sigma = sigma
2775        lnd = LognormalDist(100, 15)
2776        nd = NormalDist(100, 15)
2777        self.assertNotEqual(nd, lnd)
2778
2779    def test_pickle_and_copy(self):
2780        nd = self.module.NormalDist(37.5, 5.625)
2781        nd1 = copy.copy(nd)
2782        self.assertEqual(nd, nd1)
2783        nd2 = copy.deepcopy(nd)
2784        self.assertEqual(nd, nd2)
2785        nd3 = pickle.loads(pickle.dumps(nd))
2786        self.assertEqual(nd, nd3)
2787
2788    def test_hashability(self):
2789        ND = self.module.NormalDist
2790        s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)}
2791        self.assertEqual(len(s), 3)
2792
2793    def test_repr(self):
2794        nd = self.module.NormalDist(37.5, 5.625)
2795        self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)')
2796
# Swapping sys.modules['statistics'] during each test is needed to avoid
# _pickle.PicklingError:
# Can't pickle <class 'statistics.NormalDist'>:
# it's not the same object as statistics.NormalDist
class TestNormalDistPython(unittest.TestCase, TestNormalDist):
    """Run the shared NormalDist tests against ``py_statistics``."""

    module = py_statistics

    def setUp(self):
        # Register the module under test as 'statistics' for this test.
        sys.modules.update(statistics=self.module)

    def tearDown(self):
        # Reinstate the 'statistics' module imported at the top of the file.
        sys.modules.update(statistics=statistics)
2808
2809
@unittest.skipUnless(c_statistics, 'requires _statistics')
class TestNormalDistC(unittest.TestCase, TestNormalDist):
    """Run the shared NormalDist tests against ``c_statistics``.

    Skipped when ``c_statistics`` is falsy (skip reason: requires _statistics).
    """

    module = c_statistics

    def setUp(self):
        # Register the module under test as 'statistics' for this test.
        sys.modules.update(statistics=self.module)

    def tearDown(self):
        # Reinstate the 'statistics' module imported at the top of the file.
        sys.modules.update(statistics=statistics)
2818
2819
2820# === Run tests ===
2821
def load_tests(loader, tests, ignore):
    """unittest ``load_tests`` hook: add this module's doctests to *tests*.

    The given suite is extended in place and returned.
    """
    # DocTestSuite() with no argument inspects its caller's frame to find the
    # module, so it must be called directly from this function.
    module_doctests = doctest.DocTestSuite()
    tests.addTests(module_doctests)
    return tests
2826
2827
if __name__ == "__main__":
    # Executed as a script: discover and run every test in this module.
    unittest.main(module="__main__")
2830