1x = """Test suite for statistics module, including helper NumericTestCase and 2approx_equal function. 3 4""" 5 6import bisect 7import collections 8import collections.abc 9import copy 10import decimal 11import doctest 12import itertools 13import math 14import pickle 15import random 16import sys 17import unittest 18from test import support 19from test.support import import_helper, requires_IEEE_754 20 21from decimal import Decimal 22from fractions import Fraction 23 24 25# Module to be tested. 26import statistics 27 28 29# === Helper functions and class === 30 31# Test copied from Lib/test/test_math.py 32# detect evidence of double-rounding: fsum is not always correctly 33# rounded on machines that suffer from double rounding. 34x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer 35HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4) 36 37def sign(x): 38 """Return -1.0 for negatives, including -0.0, otherwise +1.0.""" 39 return math.copysign(1, x) 40 41def _nan_equal(a, b): 42 """Return True if a and b are both the same kind of NAN. 43 44 >>> _nan_equal(Decimal('NAN'), Decimal('NAN')) 45 True 46 >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN')) 47 True 48 >>> _nan_equal(Decimal('NAN'), Decimal('sNAN')) 49 False 50 >>> _nan_equal(Decimal(42), Decimal('NAN')) 51 False 52 53 >>> _nan_equal(float('NAN'), float('NAN')) 54 True 55 >>> _nan_equal(float('NAN'), 0.5) 56 False 57 58 >>> _nan_equal(float('NAN'), Decimal('NAN')) 59 False 60 61 NAN payloads are not compared. 62 """ 63 if type(a) is not type(b): 64 return False 65 if isinstance(a, float): 66 return math.isnan(a) and math.isnan(b) 67 aexp = a.as_tuple()[2] 68 bexp = b.as_tuple()[2] 69 return (aexp == bexp) and (aexp in ('n', 'N')) # Both NAN or both sNAN. 70 71 72def _calc_errors(actual, expected): 73 """Return the absolute and relative errors between two numbers. 
74 75 >>> _calc_errors(100, 75) 76 (25, 0.25) 77 >>> _calc_errors(100, 100) 78 (0, 0.0) 79 80 Returns the (absolute error, relative error) between the two arguments. 81 """ 82 base = max(abs(actual), abs(expected)) 83 abs_err = abs(actual - expected) 84 rel_err = abs_err/base if base else float('inf') 85 return (abs_err, rel_err) 86 87 88def approx_equal(x, y, tol=1e-12, rel=1e-7): 89 """approx_equal(x, y [, tol [, rel]]) => True|False 90 91 Return True if numbers x and y are approximately equal, to within some 92 margin of error, otherwise return False. Numbers which compare equal 93 will also compare approximately equal. 94 95 x is approximately equal to y if the difference between them is less than 96 an absolute error tol or a relative error rel, whichever is bigger. 97 98 If given, both tol and rel must be finite, non-negative numbers. If not 99 given, default values are tol=1e-12 and rel=1e-7. 100 101 >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0) 102 True 103 >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0) 104 False 105 106 Absolute error is defined as abs(x-y); if that is less than or equal to 107 tol, x and y are considered approximately equal. 108 109 Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is 110 smaller, provided x or y are not zero. If that figure is less than or 111 equal to rel, x and y are considered approximately equal. 112 113 Complex numbers are not directly supported. If you wish to compare to 114 complex numbers, extract their real and imaginary parts and compare them 115 individually. 116 117 NANs always compare unequal, even with themselves. Infinities compare 118 approximately equal if they have the same sign (both positive or both 119 negative). Infinities with different signs compare unequal; so do 120 comparisons of infinities with finite numbers. 
121 """ 122 if tol < 0 or rel < 0: 123 raise ValueError('error tolerances must be non-negative') 124 # NANs are never equal to anything, approximately or otherwise. 125 if math.isnan(x) or math.isnan(y): 126 return False 127 # Numbers which compare equal also compare approximately equal. 128 if x == y: 129 # This includes the case of two infinities with the same sign. 130 return True 131 if math.isinf(x) or math.isinf(y): 132 # This includes the case of two infinities of opposite sign, or 133 # one infinity and one finite number. 134 return False 135 # Two finite numbers. 136 actual_error = abs(x - y) 137 allowed_error = max(tol, rel*max(abs(x), abs(y))) 138 return actual_error <= allowed_error 139 140 141# This class exists only as somewhere to stick a docstring containing 142# doctests. The following docstring and tests were originally in a separate 143# module. Now that it has been merged in here, I need somewhere to hang the. 144# docstring. Ultimately, this class will die, and the information below will 145# either become redundant, or be moved into more appropriate places. 146class _DoNothing: 147 """ 148 When doing numeric work, especially with floats, exact equality is often 149 not what you want. Due to round-off error, it is often a bad idea to try 150 to compare floats with equality. Instead the usual procedure is to test 151 them with some (hopefully small!) allowance for error. 152 153 The ``approx_equal`` function allows you to specify either an absolute 154 error tolerance, or a relative error, or both. 155 156 Absolute error tolerances are simple, but you need to know the magnitude 157 of the quantities being compared: 158 159 >>> approx_equal(12.345, 12.346, tol=1e-3) 160 True 161 >>> approx_equal(12.345e6, 12.346e6, tol=1e-3) # tol is too small. 
162 False 163 164 Relative errors are more suitable when the values you are comparing can 165 vary in magnitude: 166 167 >>> approx_equal(12.345, 12.346, rel=1e-4) 168 True 169 >>> approx_equal(12.345e6, 12.346e6, rel=1e-4) 170 True 171 172 but a naive implementation of relative error testing can run into trouble 173 around zero. 174 175 If you supply both an absolute tolerance and a relative error, the 176 comparison succeeds if either individual test succeeds: 177 178 >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4) 179 True 180 181 """ 182 pass 183 184 185 186# We prefer this for testing numeric values that may not be exactly equal, 187# and avoid using TestCase.assertAlmostEqual, because it sucks :-) 188 189py_statistics = import_helper.import_fresh_module('statistics', 190 blocked=['_statistics']) 191c_statistics = import_helper.import_fresh_module('statistics', 192 fresh=['_statistics']) 193 194 195class TestModules(unittest.TestCase): 196 func_names = ['_normal_dist_inv_cdf'] 197 198 def test_py_functions(self): 199 for fname in self.func_names: 200 self.assertEqual(getattr(py_statistics, fname).__module__, 'statistics') 201 202 @unittest.skipUnless(c_statistics, 'requires _statistics') 203 def test_c_functions(self): 204 for fname in self.func_names: 205 self.assertEqual(getattr(c_statistics, fname).__module__, '_statistics') 206 207 208class NumericTestCase(unittest.TestCase): 209 """Unit test class for numeric work. 210 211 This subclasses TestCase. In addition to the standard method 212 ``TestCase.assertAlmostEqual``, ``assertApproxEqual`` is provided. 213 """ 214 # By default, we expect exact equality, unless overridden. 215 tol = rel = 0 216 217 def assertApproxEqual( 218 self, first, second, tol=None, rel=None, msg=None 219 ): 220 """Test passes if ``first`` and ``second`` are approximately equal. 221 222 This test passes if ``first`` and ``second`` are equal to 223 within ``tol``, an absolute error, or ``rel``, a relative error. 
224 225 If either ``tol`` or ``rel`` are None or not given, they default to 226 test attributes of the same name (by default, 0). 227 228 The objects may be either numbers, or sequences of numbers. Sequences 229 are tested element-by-element. 230 231 >>> class MyTest(NumericTestCase): 232 ... def test_number(self): 233 ... x = 1.0/6 234 ... y = sum([x]*6) 235 ... self.assertApproxEqual(y, 1.0, tol=1e-15) 236 ... def test_sequence(self): 237 ... a = [1.001, 1.001e-10, 1.001e10] 238 ... b = [1.0, 1e-10, 1e10] 239 ... self.assertApproxEqual(a, b, rel=1e-3) 240 ... 241 >>> import unittest 242 >>> from io import StringIO # Suppress test runner output. 243 >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest) 244 >>> unittest.TextTestRunner(stream=StringIO()).run(suite) 245 <unittest.runner.TextTestResult run=2 errors=0 failures=0> 246 247 """ 248 if tol is None: 249 tol = self.tol 250 if rel is None: 251 rel = self.rel 252 if ( 253 isinstance(first, collections.abc.Sequence) and 254 isinstance(second, collections.abc.Sequence) 255 ): 256 check = self._check_approx_seq 257 else: 258 check = self._check_approx_num 259 check(first, second, tol, rel, msg) 260 261 def _check_approx_seq(self, first, second, tol, rel, msg): 262 if len(first) != len(second): 263 standardMsg = ( 264 "sequences differ in length: %d items != %d items" 265 % (len(first), len(second)) 266 ) 267 msg = self._formatMessage(msg, standardMsg) 268 raise self.failureException(msg) 269 for i, (a,e) in enumerate(zip(first, second)): 270 self._check_approx_num(a, e, tol, rel, msg, i) 271 272 def _check_approx_num(self, first, second, tol, rel, msg, idx=None): 273 if approx_equal(first, second, tol, rel): 274 # Test passes. Return early, we are done. 275 return None 276 # Otherwise we failed. 
277 standardMsg = self._make_std_err_msg(first, second, tol, rel, idx) 278 msg = self._formatMessage(msg, standardMsg) 279 raise self.failureException(msg) 280 281 @staticmethod 282 def _make_std_err_msg(first, second, tol, rel, idx): 283 # Create the standard error message for approx_equal failures. 284 assert first != second 285 template = ( 286 ' %r != %r\n' 287 ' values differ by more than tol=%r and rel=%r\n' 288 ' -> absolute error = %r\n' 289 ' -> relative error = %r' 290 ) 291 if idx is not None: 292 header = 'numeric sequences first differ at index %d.\n' % idx 293 template = header + template 294 # Calculate actual errors: 295 abs_err, rel_err = _calc_errors(first, second) 296 return template % (first, second, tol, rel, abs_err, rel_err) 297 298 299# ======================== 300# === Test the helpers === 301# ======================== 302 303class TestSign(unittest.TestCase): 304 """Test that the helper function sign() works correctly.""" 305 def testZeroes(self): 306 # Test that signed zeroes report their sign correctly. 307 self.assertEqual(sign(0.0), +1) 308 self.assertEqual(sign(-0.0), -1) 309 310 311# --- Tests for approx_equal --- 312 313class ApproxEqualSymmetryTest(unittest.TestCase): 314 # Test symmetry of approx_equal. 315 316 def test_relative_symmetry(self): 317 # Check that approx_equal treats relative error symmetrically. 318 # (a-b)/a is usually not equal to (a-b)/b. Ensure that this 319 # doesn't matter. 320 # 321 # Note: the reason for this test is that an early version 322 # of approx_equal was not symmetric. A relative error test 323 # would pass, or fail, depending on which value was passed 324 # as the first argument. 
325 # 326 args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)] 327 args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)] 328 assert len(args1) == len(args2) 329 for a, b in zip(args1, args2): 330 self.do_relative_symmetry(a, b) 331 332 def do_relative_symmetry(self, a, b): 333 a, b = min(a, b), max(a, b) 334 assert a < b 335 delta = b - a # The absolute difference between the values. 336 rel_err1, rel_err2 = abs(delta/a), abs(delta/b) 337 # Choose an error margin halfway between the two. 338 rel = (rel_err1 + rel_err2)/2 339 # Now see that values a and b compare approx equal regardless of 340 # which is given first. 341 self.assertTrue(approx_equal(a, b, tol=0, rel=rel)) 342 self.assertTrue(approx_equal(b, a, tol=0, rel=rel)) 343 344 def test_symmetry(self): 345 # Test that approx_equal(a, b) == approx_equal(b, a) 346 args = [-23, -2, 5, 107, 93568] 347 delta = 2 348 for a in args: 349 for type_ in (int, float, Decimal, Fraction): 350 x = type_(a)*100 351 y = x + delta 352 r = abs(delta/max(x, y)) 353 # There are five cases to check: 354 # 1) actual error <= tol, <= rel 355 self.do_symmetry_test(x, y, tol=delta, rel=r) 356 self.do_symmetry_test(x, y, tol=delta+1, rel=2*r) 357 # 2) actual error > tol, > rel 358 self.do_symmetry_test(x, y, tol=delta-1, rel=r/2) 359 # 3) actual error <= tol, > rel 360 self.do_symmetry_test(x, y, tol=delta, rel=r/2) 361 # 4) actual error > tol, <= rel 362 self.do_symmetry_test(x, y, tol=delta-1, rel=r) 363 self.do_symmetry_test(x, y, tol=delta-1, rel=2*r) 364 # 5) exact equality test 365 self.do_symmetry_test(x, x, tol=0, rel=0) 366 self.do_symmetry_test(x, y, tol=0, rel=0) 367 368 def do_symmetry_test(self, a, b, tol, rel): 369 template = "approx_equal comparisons don't match for %r" 370 flag1 = approx_equal(a, b, tol, rel) 371 flag2 = approx_equal(b, a, tol, rel) 372 self.assertEqual(flag1, flag2, template.format((a, b, tol, rel))) 373 374 375class ApproxEqualExactTest(unittest.TestCase): 376 # Test the 
approx_equal function with exactly equal values. 377 # Equal values should compare as approximately equal. 378 # Test cases for exactly equal values, which should compare approx 379 # equal regardless of the error tolerances given. 380 381 def do_exactly_equal_test(self, x, tol, rel): 382 result = approx_equal(x, x, tol=tol, rel=rel) 383 self.assertTrue(result, 'equality failure for x=%r' % x) 384 result = approx_equal(-x, -x, tol=tol, rel=rel) 385 self.assertTrue(result, 'equality failure for x=%r' % -x) 386 387 def test_exactly_equal_ints(self): 388 # Test that equal int values are exactly equal. 389 for n in [42, 19740, 14974, 230, 1795, 700245, 36587]: 390 self.do_exactly_equal_test(n, 0, 0) 391 392 def test_exactly_equal_floats(self): 393 # Test that equal float values are exactly equal. 394 for x in [0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587]: 395 self.do_exactly_equal_test(x, 0, 0) 396 397 def test_exactly_equal_fractions(self): 398 # Test that equal Fraction values are exactly equal. 399 F = Fraction 400 for f in [F(1, 2), F(0), F(5, 3), F(9, 7), F(35, 36), F(3, 7)]: 401 self.do_exactly_equal_test(f, 0, 0) 402 403 def test_exactly_equal_decimals(self): 404 # Test that equal Decimal values are exactly equal. 405 D = Decimal 406 for d in map(D, "8.2 31.274 912.04 16.745 1.2047".split()): 407 self.do_exactly_equal_test(d, 0, 0) 408 409 def test_exactly_equal_absolute(self): 410 # Test that equal values are exactly equal with an absolute error. 411 for n in [16, 1013, 1372, 1198, 971, 4]: 412 # Test as ints. 413 self.do_exactly_equal_test(n, 0.01, 0) 414 # Test as floats. 415 self.do_exactly_equal_test(n/10, 0.01, 0) 416 # Test as Fractions. 417 f = Fraction(n, 1234) 418 self.do_exactly_equal_test(f, 0.01, 0) 419 420 def test_exactly_equal_absolute_decimals(self): 421 # Test equal Decimal values are exactly equal with an absolute error. 
422 self.do_exactly_equal_test(Decimal("3.571"), Decimal("0.01"), 0) 423 self.do_exactly_equal_test(-Decimal("81.3971"), Decimal("0.01"), 0) 424 425 def test_exactly_equal_relative(self): 426 # Test that equal values are exactly equal with a relative error. 427 for x in [8347, 101.3, -7910.28, Fraction(5, 21)]: 428 self.do_exactly_equal_test(x, 0, 0.01) 429 self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01")) 430 431 def test_exactly_equal_both(self): 432 # Test that equal values are equal when both tol and rel are given. 433 for x in [41017, 16.742, -813.02, Fraction(3, 8)]: 434 self.do_exactly_equal_test(x, 0.1, 0.01) 435 D = Decimal 436 self.do_exactly_equal_test(D("7.2"), D("0.1"), D("0.01")) 437 438 439class ApproxEqualUnequalTest(unittest.TestCase): 440 # Unequal values should compare unequal with zero error tolerances. 441 # Test cases for unequal values, with exact equality test. 442 443 def do_exactly_unequal_test(self, x): 444 for a in (x, -x): 445 result = approx_equal(a, a+1, tol=0, rel=0) 446 self.assertFalse(result, 'inequality failure for x=%r' % a) 447 448 def test_exactly_unequal_ints(self): 449 # Test unequal int values are unequal with zero error tolerance. 450 for n in [951, 572305, 478, 917, 17240]: 451 self.do_exactly_unequal_test(n) 452 453 def test_exactly_unequal_floats(self): 454 # Test unequal float values are unequal with zero error tolerance. 455 for x in [9.51, 5723.05, 47.8, 9.17, 17.24]: 456 self.do_exactly_unequal_test(x) 457 458 def test_exactly_unequal_fractions(self): 459 # Test that unequal Fractions are unequal with zero error tolerance. 460 F = Fraction 461 for f in [F(1, 5), F(7, 9), F(12, 11), F(101, 99023)]: 462 self.do_exactly_unequal_test(f) 463 464 def test_exactly_unequal_decimals(self): 465 # Test that unequal Decimals are unequal with zero error tolerance. 
466 for d in map(Decimal, "3.1415 298.12 3.47 18.996 0.00245".split()): 467 self.do_exactly_unequal_test(d) 468 469 470class ApproxEqualInexactTest(unittest.TestCase): 471 # Inexact test cases for approx_error. 472 # Test cases when comparing two values that are not exactly equal. 473 474 # === Absolute error tests === 475 476 def do_approx_equal_abs_test(self, x, delta): 477 template = "Test failure for x={!r}, y={!r}" 478 for y in (x + delta, x - delta): 479 msg = template.format(x, y) 480 self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg) 481 self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg) 482 483 def test_approx_equal_absolute_ints(self): 484 # Test approximate equality of ints with an absolute error. 485 for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]: 486 self.do_approx_equal_abs_test(n, 10) 487 self.do_approx_equal_abs_test(n, 2) 488 489 def test_approx_equal_absolute_floats(self): 490 # Test approximate equality of floats with an absolute error. 491 for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]: 492 self.do_approx_equal_abs_test(x, 1.5) 493 self.do_approx_equal_abs_test(x, 0.01) 494 self.do_approx_equal_abs_test(x, 0.0001) 495 496 def test_approx_equal_absolute_fractions(self): 497 # Test approximate equality of Fractions with an absolute error. 498 delta = Fraction(1, 29) 499 numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71] 500 for f in (Fraction(n, 29) for n in numerators): 501 self.do_approx_equal_abs_test(f, delta) 502 self.do_approx_equal_abs_test(f, float(delta)) 503 504 def test_approx_equal_absolute_decimals(self): 505 # Test approximate equality of Decimals with an absolute error. 506 delta = Decimal("0.01") 507 for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()): 508 self.do_approx_equal_abs_test(d, delta) 509 self.do_approx_equal_abs_test(-d, delta) 510 511 def test_cross_zero(self): 512 # Test for the case of the two values having opposite signs. 
513 self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0)) 514 515 # === Relative error tests === 516 517 def do_approx_equal_rel_test(self, x, delta): 518 template = "Test failure for x={!r}, y={!r}" 519 for y in (x*(1+delta), x*(1-delta)): 520 msg = template.format(x, y) 521 self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg) 522 self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg) 523 524 def test_approx_equal_relative_ints(self): 525 # Test approximate equality of ints with a relative error. 526 self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36)) 527 self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37)) 528 # --- 529 self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125)) 530 self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125)) 531 self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125)) 532 533 def test_approx_equal_relative_floats(self): 534 # Test approximate equality of floats with a relative error. 535 for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]: 536 self.do_approx_equal_rel_test(x, 0.02) 537 self.do_approx_equal_rel_test(x, 0.0001) 538 539 def test_approx_equal_relative_fractions(self): 540 # Test approximate equality of Fractions with a relative error. 541 F = Fraction 542 delta = Fraction(3, 8) 543 for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]: 544 for d in (delta, float(delta)): 545 self.do_approx_equal_rel_test(f, d) 546 self.do_approx_equal_rel_test(-f, d) 547 548 def test_approx_equal_relative_decimals(self): 549 # Test approximate equality of Decimals with a relative error. 
550 for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()): 551 self.do_approx_equal_rel_test(d, Decimal("0.001")) 552 self.do_approx_equal_rel_test(-d, Decimal("0.05")) 553 554 # === Both absolute and relative error tests === 555 556 # There are four cases to consider: 557 # 1) actual error <= both absolute and relative error 558 # 2) actual error <= absolute error but > relative error 559 # 3) actual error <= relative error but > absolute error 560 # 4) actual error > both absolute and relative error 561 562 def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag): 563 check = self.assertTrue if tol_flag else self.assertFalse 564 check(approx_equal(a, b, tol=tol, rel=0)) 565 check = self.assertTrue if rel_flag else self.assertFalse 566 check(approx_equal(a, b, tol=0, rel=rel)) 567 check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse 568 check(approx_equal(a, b, tol=tol, rel=rel)) 569 570 def test_approx_equal_both1(self): 571 # Test actual error <= both absolute and relative error. 572 self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True) 573 self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True) 574 575 def test_approx_equal_both2(self): 576 # Test actual error <= absolute error but > relative error. 577 self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False) 578 579 def test_approx_equal_both3(self): 580 # Test actual error <= relative error but > absolute error. 581 self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True) 582 583 def test_approx_equal_both4(self): 584 # Test actual error > both absolute and relative error. 585 self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False) 586 self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False) 587 588 589class ApproxEqualSpecialsTest(unittest.TestCase): 590 # Test approx_equal with NANs and INFs and zeroes. 
591 592 def test_inf(self): 593 for type_ in (float, Decimal): 594 inf = type_('inf') 595 self.assertTrue(approx_equal(inf, inf)) 596 self.assertTrue(approx_equal(inf, inf, 0, 0)) 597 self.assertTrue(approx_equal(inf, inf, 1, 0.01)) 598 self.assertTrue(approx_equal(-inf, -inf)) 599 self.assertFalse(approx_equal(inf, -inf)) 600 self.assertFalse(approx_equal(inf, 1000)) 601 602 def test_nan(self): 603 for type_ in (float, Decimal): 604 nan = type_('nan') 605 for other in (nan, type_('inf'), 1000): 606 self.assertFalse(approx_equal(nan, other)) 607 608 def test_float_zeroes(self): 609 nzero = math.copysign(0.0, -1) 610 self.assertTrue(approx_equal(nzero, 0.0, tol=0.1, rel=0.1)) 611 612 def test_decimal_zeroes(self): 613 nzero = Decimal("-0.0") 614 self.assertTrue(approx_equal(nzero, Decimal(0), tol=0.1, rel=0.1)) 615 616 617class TestApproxEqualErrors(unittest.TestCase): 618 # Test error conditions of approx_equal. 619 620 def test_bad_tol(self): 621 # Test negative tol raises. 622 self.assertRaises(ValueError, approx_equal, 100, 100, -1, 0.1) 623 624 def test_bad_rel(self): 625 # Test negative rel raises. 626 self.assertRaises(ValueError, approx_equal, 100, 100, 1, -0.1) 627 628 629# --- Tests for NumericTestCase --- 630 631# The formatting routine that generates the error messages is complex enough 632# that it too needs testing. 633 634class TestNumericTestCase(unittest.TestCase): 635 # The exact wording of NumericTestCase error messages is *not* guaranteed, 636 # but we need to give them some sort of test to ensure that they are 637 # generated correctly. As a compromise, we look for specific substrings 638 # that are expected to be found even if the overall error message changes. 
639 640 def do_test(self, args): 641 actual_msg = NumericTestCase._make_std_err_msg(*args) 642 expected = self.generate_substrings(*args) 643 for substring in expected: 644 self.assertIn(substring, actual_msg) 645 646 def test_numerictestcase_is_testcase(self): 647 # Ensure that NumericTestCase actually is a TestCase. 648 self.assertTrue(issubclass(NumericTestCase, unittest.TestCase)) 649 650 def test_error_msg_numeric(self): 651 # Test the error message generated for numeric comparisons. 652 args = (2.5, 4.0, 0.5, 0.25, None) 653 self.do_test(args) 654 655 def test_error_msg_sequence(self): 656 # Test the error message generated for sequence comparisons. 657 args = (3.75, 8.25, 1.25, 0.5, 7) 658 self.do_test(args) 659 660 def generate_substrings(self, first, second, tol, rel, idx): 661 """Return substrings we expect to see in error messages.""" 662 abs_err, rel_err = _calc_errors(first, second) 663 substrings = [ 664 'tol=%r' % tol, 665 'rel=%r' % rel, 666 'absolute error = %r' % abs_err, 667 'relative error = %r' % rel_err, 668 ] 669 if idx is not None: 670 substrings.append('differ at index %d' % idx) 671 return substrings 672 673 674# ======================================= 675# === Tests for the statistics module === 676# ======================================= 677 678 679class GlobalsTest(unittest.TestCase): 680 module = statistics 681 expected_metadata = ["__doc__", "__all__"] 682 683 def test_meta(self): 684 # Test for the existence of metadata. 685 for meta in self.expected_metadata: 686 self.assertTrue(hasattr(self.module, meta), 687 "%s not present" % meta) 688 689 def test_check_all(self): 690 # Check everything in __all__ exists and is public. 
691 module = self.module 692 for name in module.__all__: 693 # No private names in __all__: 694 self.assertFalse(name.startswith("_"), 695 'private name "%s" in __all__' % name) 696 # And anything in __all__ must exist: 697 self.assertTrue(hasattr(module, name), 698 'missing name "%s" in __all__' % name) 699 700 701class StatisticsErrorTest(unittest.TestCase): 702 def test_has_exception(self): 703 errmsg = ( 704 "Expected StatisticsError to be a ValueError, but got a" 705 " subclass of %r instead." 706 ) 707 self.assertTrue(hasattr(statistics, 'StatisticsError')) 708 self.assertTrue( 709 issubclass(statistics.StatisticsError, ValueError), 710 errmsg % statistics.StatisticsError.__base__ 711 ) 712 713 714# === Tests for private utility functions === 715 716class ExactRatioTest(unittest.TestCase): 717 # Test _exact_ratio utility. 718 719 def test_int(self): 720 for i in (-20, -3, 0, 5, 99, 10**20): 721 self.assertEqual(statistics._exact_ratio(i), (i, 1)) 722 723 def test_fraction(self): 724 numerators = (-5, 1, 12, 38) 725 for n in numerators: 726 f = Fraction(n, 37) 727 self.assertEqual(statistics._exact_ratio(f), (n, 37)) 728 729 def test_float(self): 730 self.assertEqual(statistics._exact_ratio(0.125), (1, 8)) 731 self.assertEqual(statistics._exact_ratio(1.125), (9, 8)) 732 data = [random.uniform(-100, 100) for _ in range(100)] 733 for x in data: 734 num, den = statistics._exact_ratio(x) 735 self.assertEqual(x, num/den) 736 737 def test_decimal(self): 738 D = Decimal 739 _exact_ratio = statistics._exact_ratio 740 self.assertEqual(_exact_ratio(D("0.125")), (1, 8)) 741 self.assertEqual(_exact_ratio(D("12.345")), (2469, 200)) 742 self.assertEqual(_exact_ratio(D("-1.98")), (-99, 50)) 743 744 def test_inf(self): 745 INF = float("INF") 746 class MyFloat(float): 747 pass 748 class MyDecimal(Decimal): 749 pass 750 for inf in (INF, -INF): 751 for type_ in (float, MyFloat, Decimal, MyDecimal): 752 x = type_(inf) 753 ratio = statistics._exact_ratio(x) 754 
self.assertEqual(ratio, (x, None)) 755 self.assertEqual(type(ratio[0]), type_) 756 self.assertTrue(math.isinf(ratio[0])) 757 758 def test_float_nan(self): 759 NAN = float("NAN") 760 class MyFloat(float): 761 pass 762 for nan in (NAN, MyFloat(NAN)): 763 ratio = statistics._exact_ratio(nan) 764 self.assertTrue(math.isnan(ratio[0])) 765 self.assertIs(ratio[1], None) 766 self.assertEqual(type(ratio[0]), type(nan)) 767 768 def test_decimal_nan(self): 769 NAN = Decimal("NAN") 770 sNAN = Decimal("sNAN") 771 class MyDecimal(Decimal): 772 pass 773 for nan in (NAN, MyDecimal(NAN), sNAN, MyDecimal(sNAN)): 774 ratio = statistics._exact_ratio(nan) 775 self.assertTrue(_nan_equal(ratio[0], nan)) 776 self.assertIs(ratio[1], None) 777 self.assertEqual(type(ratio[0]), type(nan)) 778 779 780class DecimalToRatioTest(unittest.TestCase): 781 # Test _exact_ratio private function. 782 783 def test_infinity(self): 784 # Test that INFs are handled correctly. 785 inf = Decimal('INF') 786 self.assertEqual(statistics._exact_ratio(inf), (inf, None)) 787 self.assertEqual(statistics._exact_ratio(-inf), (-inf, None)) 788 789 def test_nan(self): 790 # Test that NANs are handled correctly. 791 for nan in (Decimal('NAN'), Decimal('sNAN')): 792 num, den = statistics._exact_ratio(nan) 793 # Because NANs always compare non-equal, we cannot use assertEqual. 794 # Nor can we use an identity test, as we don't guarantee anything 795 # about the object identity. 796 self.assertTrue(_nan_equal(num, nan)) 797 self.assertIs(den, None) 798 799 def test_sign(self): 800 # Test sign is calculated correctly. 801 numbers = [Decimal("9.8765e12"), Decimal("9.8765e-12")] 802 for d in numbers: 803 # First test positive decimals. 804 assert d > 0 805 num, den = statistics._exact_ratio(d) 806 self.assertGreaterEqual(num, 0) 807 self.assertGreater(den, 0) 808 # Then test negative decimals. 
809 num, den = statistics._exact_ratio(-d) 810 self.assertLessEqual(num, 0) 811 self.assertGreater(den, 0) 812 813 def test_negative_exponent(self): 814 # Test result when the exponent is negative. 815 t = statistics._exact_ratio(Decimal("0.1234")) 816 self.assertEqual(t, (617, 5000)) 817 818 def test_positive_exponent(self): 819 # Test results when the exponent is positive. 820 t = statistics._exact_ratio(Decimal("1.234e7")) 821 self.assertEqual(t, (12340000, 1)) 822 823 def test_regression_20536(self): 824 # Regression test for issue 20536. 825 # See http://bugs.python.org/issue20536 826 t = statistics._exact_ratio(Decimal("1e2")) 827 self.assertEqual(t, (100, 1)) 828 t = statistics._exact_ratio(Decimal("1.47e5")) 829 self.assertEqual(t, (147000, 1)) 830 831 832class IsFiniteTest(unittest.TestCase): 833 # Test _isfinite private function. 834 835 def test_finite(self): 836 # Test that finite numbers are recognised as finite. 837 for x in (5, Fraction(1, 3), 2.5, Decimal("5.5")): 838 self.assertTrue(statistics._isfinite(x)) 839 840 def test_infinity(self): 841 # Test that INFs are not recognised as finite. 842 for x in (float("inf"), Decimal("inf")): 843 self.assertFalse(statistics._isfinite(x)) 844 845 def test_nan(self): 846 # Test that NANs are not recognised as finite. 847 for x in (float("nan"), Decimal("NAN"), Decimal("sNAN")): 848 self.assertFalse(statistics._isfinite(x)) 849 850 851class CoerceTest(unittest.TestCase): 852 # Test that private function _coerce correctly deals with types. 853 854 # The coercion rules are currently an implementation detail, although at 855 # some point that should change. The tests and comments here define the 856 # correct implementation. 857 858 # Pre-conditions of _coerce: 859 # 860 # - The first time _sum calls _coerce, the 861 # - coerce(T, S) will never be called with bool as the first argument; 862 # this is a pre-condition, guarded with an assertion. 
863 864 # 865 # - coerce(T, T) will always return T; we assume T is a valid numeric 866 # type. Violate this assumption at your own risk. 867 # 868 # - Apart from as above, bool is treated as if it were actually int. 869 # 870 # - coerce(int, X) and coerce(X, int) return X. 871 # - 872 def test_bool(self): 873 # bool is somewhat special, due to the pre-condition that it is 874 # never given as the first argument to _coerce, and that it cannot 875 # be subclassed. So we test it specially. 876 for T in (int, float, Fraction, Decimal): 877 self.assertIs(statistics._coerce(T, bool), T) 878 class MyClass(T): pass 879 self.assertIs(statistics._coerce(MyClass, bool), MyClass) 880 881 def assertCoerceTo(self, A, B): 882 """Assert that type A coerces to B.""" 883 self.assertIs(statistics._coerce(A, B), B) 884 self.assertIs(statistics._coerce(B, A), B) 885 886 def check_coerce_to(self, A, B): 887 """Checks that type A coerces to B, including subclasses.""" 888 # Assert that type A is coerced to B. 889 self.assertCoerceTo(A, B) 890 # Subclasses of A are also coerced to B. 891 class SubclassOfA(A): pass 892 self.assertCoerceTo(SubclassOfA, B) 893 # A, and subclasses of A, are coerced to subclasses of B. 894 class SubclassOfB(B): pass 895 self.assertCoerceTo(A, SubclassOfB) 896 self.assertCoerceTo(SubclassOfA, SubclassOfB) 897 898 def assertCoerceRaises(self, A, B): 899 """Assert that coercing A to B, or vice versa, raises TypeError.""" 900 self.assertRaises(TypeError, statistics._coerce, (A, B)) 901 self.assertRaises(TypeError, statistics._coerce, (B, A)) 902 903 def check_type_coercions(self, T): 904 """Check that type T coerces correctly with subclasses of itself.""" 905 assert T is not bool 906 # Coercing a type with itself returns the same type. 907 self.assertIs(statistics._coerce(T, T), T) 908 # Coercing a type with a subclass of itself returns the subclass. 
909 class U(T): pass 910 class V(T): pass 911 class W(U): pass 912 for typ in (U, V, W): 913 self.assertCoerceTo(T, typ) 914 self.assertCoerceTo(U, W) 915 # Coercing two subclasses that aren't parent/child is an error. 916 self.assertCoerceRaises(U, V) 917 self.assertCoerceRaises(V, W) 918 919 def test_int(self): 920 # Check that int coerces correctly. 921 self.check_type_coercions(int) 922 for typ in (float, Fraction, Decimal): 923 self.check_coerce_to(int, typ) 924 925 def test_fraction(self): 926 # Check that Fraction coerces correctly. 927 self.check_type_coercions(Fraction) 928 self.check_coerce_to(Fraction, float) 929 930 def test_decimal(self): 931 # Check that Decimal coerces correctly. 932 self.check_type_coercions(Decimal) 933 934 def test_float(self): 935 # Check that float coerces correctly. 936 self.check_type_coercions(float) 937 938 def test_non_numeric_types(self): 939 for bad_type in (str, list, type(None), tuple, dict): 940 for good_type in (int, float, Fraction, Decimal): 941 self.assertCoerceRaises(good_type, bad_type) 942 943 def test_incompatible_types(self): 944 # Test that incompatible types raise. 945 for T in (float, Fraction): 946 class MySubclass(T): pass 947 self.assertCoerceRaises(T, Decimal) 948 self.assertCoerceRaises(MySubclass, Decimal) 949 950 951class ConvertTest(unittest.TestCase): 952 # Test private _convert function. 953 954 def check_exact_equal(self, x, y): 955 """Check that x equals y, and has the same type as well.""" 956 self.assertEqual(x, y) 957 self.assertIs(type(x), type(y)) 958 959 def test_int(self): 960 # Test conversions to int. 961 x = statistics._convert(Fraction(71), int) 962 self.check_exact_equal(x, 71) 963 class MyInt(int): pass 964 x = statistics._convert(Fraction(17), MyInt) 965 self.check_exact_equal(x, MyInt(17)) 966 967 def test_fraction(self): 968 # Test conversions to Fraction. 
969 x = statistics._convert(Fraction(95, 99), Fraction) 970 self.check_exact_equal(x, Fraction(95, 99)) 971 class MyFraction(Fraction): 972 def __truediv__(self, other): 973 return self.__class__(super().__truediv__(other)) 974 x = statistics._convert(Fraction(71, 13), MyFraction) 975 self.check_exact_equal(x, MyFraction(71, 13)) 976 977 def test_float(self): 978 # Test conversions to float. 979 x = statistics._convert(Fraction(-1, 2), float) 980 self.check_exact_equal(x, -0.5) 981 class MyFloat(float): 982 def __truediv__(self, other): 983 return self.__class__(super().__truediv__(other)) 984 x = statistics._convert(Fraction(9, 8), MyFloat) 985 self.check_exact_equal(x, MyFloat(1.125)) 986 987 def test_decimal(self): 988 # Test conversions to Decimal. 989 x = statistics._convert(Fraction(1, 40), Decimal) 990 self.check_exact_equal(x, Decimal("0.025")) 991 class MyDecimal(Decimal): 992 def __truediv__(self, other): 993 return self.__class__(super().__truediv__(other)) 994 x = statistics._convert(Fraction(-15, 16), MyDecimal) 995 self.check_exact_equal(x, MyDecimal("-0.9375")) 996 997 def test_inf(self): 998 for INF in (float('inf'), Decimal('inf')): 999 for inf in (INF, -INF): 1000 x = statistics._convert(inf, type(inf)) 1001 self.check_exact_equal(x, inf) 1002 1003 def test_nan(self): 1004 for nan in (float('nan'), Decimal('NAN'), Decimal('sNAN')): 1005 x = statistics._convert(nan, type(nan)) 1006 self.assertTrue(_nan_equal(x, nan)) 1007 1008 def test_invalid_input_type(self): 1009 with self.assertRaises(TypeError): 1010 statistics._convert(None, float) 1011 1012 1013class FailNegTest(unittest.TestCase): 1014 """Test _fail_neg private function.""" 1015 1016 def test_pass_through(self): 1017 # Test that values are passed through unchanged. 1018 values = [1, 2.0, Fraction(3), Decimal(4)] 1019 new = list(statistics._fail_neg(values)) 1020 self.assertEqual(values, new) 1021 1022 def test_negatives_raise(self): 1023 # Test that negatives raise an exception. 
1024 for x in [1, 2.0, Fraction(3), Decimal(4)]: 1025 seq = [-x] 1026 it = statistics._fail_neg(seq) 1027 self.assertRaises(statistics.StatisticsError, next, it) 1028 1029 def test_error_msg(self): 1030 # Test that a given error message is used. 1031 msg = "badness #%d" % random.randint(10000, 99999) 1032 try: 1033 next(statistics._fail_neg([-1], msg)) 1034 except statistics.StatisticsError as e: 1035 errmsg = e.args[0] 1036 else: 1037 self.fail("expected exception, but it didn't happen") 1038 self.assertEqual(errmsg, msg) 1039 1040 1041# === Tests for public functions === 1042 1043class UnivariateCommonMixin: 1044 # Common tests for most univariate functions that take a data argument. 1045 1046 def test_no_args(self): 1047 # Fail if given no arguments. 1048 self.assertRaises(TypeError, self.func) 1049 1050 def test_empty_data(self): 1051 # Fail when the data argument (first argument) is empty. 1052 for empty in ([], (), iter([])): 1053 self.assertRaises(statistics.StatisticsError, self.func, empty) 1054 1055 def prepare_data(self): 1056 """Return int data for various tests.""" 1057 data = list(range(10)) 1058 while data == sorted(data): 1059 random.shuffle(data) 1060 return data 1061 1062 def test_no_inplace_modifications(self): 1063 # Test that the function does not modify its input data. 1064 data = self.prepare_data() 1065 assert len(data) != 1 # Necessary to avoid infinite loop. 1066 assert data != sorted(data) 1067 saved = data[:] 1068 assert data is not saved 1069 _ = self.func(data) 1070 self.assertListEqual(data, saved, "data has been modified") 1071 1072 def test_order_doesnt_matter(self): 1073 # Test that the order of data points doesn't change the result. 1074 1075 # CAUTION: due to floating-point rounding errors, the result actually 1076 # may depend on the order. Consider this test representing an ideal. 1077 # To avoid this test failing, only test with exact values such as ints 1078 # or Fractions. 
1079 data = [1, 2, 3, 3, 3, 4, 5, 6]*100 1080 expected = self.func(data) 1081 random.shuffle(data) 1082 actual = self.func(data) 1083 self.assertEqual(expected, actual) 1084 1085 def test_type_of_data_collection(self): 1086 # Test that the type of iterable data doesn't effect the result. 1087 class MyList(list): 1088 pass 1089 class MyTuple(tuple): 1090 pass 1091 def generator(data): 1092 return (obj for obj in data) 1093 data = self.prepare_data() 1094 expected = self.func(data) 1095 for kind in (list, tuple, iter, MyList, MyTuple, generator): 1096 result = self.func(kind(data)) 1097 self.assertEqual(result, expected) 1098 1099 def test_range_data(self): 1100 # Test that functions work with range objects. 1101 data = range(20, 50, 3) 1102 expected = self.func(list(data)) 1103 self.assertEqual(self.func(data), expected) 1104 1105 def test_bad_arg_types(self): 1106 # Test that function raises when given data of the wrong type. 1107 1108 # Don't roll the following into a loop like this: 1109 # for bad in list_of_bad: 1110 # self.check_for_type_error(bad) 1111 # 1112 # Since assertRaises doesn't show the arguments that caused the test 1113 # failure, it is very difficult to debug these test failures when the 1114 # following are in a loop. 1115 self.check_for_type_error(None) 1116 self.check_for_type_error(23) 1117 self.check_for_type_error(42.0) 1118 self.check_for_type_error(object()) 1119 1120 def check_for_type_error(self, *args): 1121 self.assertRaises(TypeError, self.func, *args) 1122 1123 def test_type_of_data_element(self): 1124 # Check the type of data elements doesn't affect the numeric result. 1125 # This is a weaker test than UnivariateTypeMixin.testTypesConserved, 1126 # because it checks the numeric result by equality, but not by type. 
1127 class MyFloat(float): 1128 def __truediv__(self, other): 1129 return type(self)(super().__truediv__(other)) 1130 def __add__(self, other): 1131 return type(self)(super().__add__(other)) 1132 __radd__ = __add__ 1133 1134 raw = self.prepare_data() 1135 expected = self.func(raw) 1136 for kind in (float, MyFloat, Decimal, Fraction): 1137 data = [kind(x) for x in raw] 1138 result = type(expected)(self.func(data)) 1139 self.assertEqual(result, expected) 1140 1141 1142class UnivariateTypeMixin: 1143 """Mixin class for type-conserving functions. 1144 1145 This mixin class holds test(s) for functions which conserve the type of 1146 individual data points. E.g. the mean of a list of Fractions should itself 1147 be a Fraction. 1148 1149 Not all tests to do with types need go in this class. Only those that 1150 rely on the function returning the same type as its input data. 1151 """ 1152 def prepare_types_for_conservation_test(self): 1153 """Return the types which are expected to be conserved.""" 1154 class MyFloat(float): 1155 def __truediv__(self, other): 1156 return type(self)(super().__truediv__(other)) 1157 def __rtruediv__(self, other): 1158 return type(self)(super().__rtruediv__(other)) 1159 def __sub__(self, other): 1160 return type(self)(super().__sub__(other)) 1161 def __rsub__(self, other): 1162 return type(self)(super().__rsub__(other)) 1163 def __pow__(self, other): 1164 return type(self)(super().__pow__(other)) 1165 def __add__(self, other): 1166 return type(self)(super().__add__(other)) 1167 __radd__ = __add__ 1168 def __mul__(self, other): 1169 return type(self)(super().__mul__(other)) 1170 __rmul__ = __mul__ 1171 return (float, Decimal, Fraction, MyFloat) 1172 1173 def test_types_conserved(self): 1174 # Test that functions keeps the same type as their data points. 1175 # (Excludes mixed data types.) This only tests the type of the return 1176 # result, not the value. 
1177 data = self.prepare_data() 1178 for kind in self.prepare_types_for_conservation_test(): 1179 d = [kind(x) for x in data] 1180 result = self.func(d) 1181 self.assertIs(type(result), kind) 1182 1183 1184class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin): 1185 # Common test cases for statistics._sum() function. 1186 1187 # This test suite looks only at the numeric value returned by _sum, 1188 # after conversion to the appropriate type. 1189 def setUp(self): 1190 def simplified_sum(*args): 1191 T, value, n = statistics._sum(*args) 1192 return statistics._coerce(value, T) 1193 self.func = simplified_sum 1194 1195 1196class TestSum(NumericTestCase): 1197 # Test cases for statistics._sum() function. 1198 1199 # These tests look at the entire three value tuple returned by _sum. 1200 1201 def setUp(self): 1202 self.func = statistics._sum 1203 1204 def test_empty_data(self): 1205 # Override test for empty data. 1206 for data in ([], (), iter([])): 1207 self.assertEqual(self.func(data), (int, Fraction(0), 0)) 1208 1209 def test_ints(self): 1210 self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]), 1211 (int, Fraction(60), 8)) 1212 1213 def test_floats(self): 1214 self.assertEqual(self.func([0.25]*20), 1215 (float, Fraction(5.0), 20)) 1216 1217 def test_fractions(self): 1218 self.assertEqual(self.func([Fraction(1, 1000)]*500), 1219 (Fraction, Fraction(1, 2), 500)) 1220 1221 def test_decimals(self): 1222 D = Decimal 1223 data = [D("0.001"), D("5.246"), D("1.702"), D("-0.025"), 1224 D("3.974"), D("2.328"), D("4.617"), D("2.843"), 1225 ] 1226 self.assertEqual(self.func(data), 1227 (Decimal, Decimal("20.686"), 8)) 1228 1229 def test_compare_with_math_fsum(self): 1230 # Compare with the math.fsum function. 
1231 # Ideally we ought to get the exact same result, but sometimes 1232 # we differ by a very slight amount :-( 1233 data = [random.uniform(-100, 1000) for _ in range(1000)] 1234 self.assertApproxEqual(float(self.func(data)[1]), math.fsum(data), rel=2e-16) 1235 1236 def test_strings_fail(self): 1237 # Sum of strings should fail. 1238 self.assertRaises(TypeError, self.func, [1, 2, 3], '999') 1239 self.assertRaises(TypeError, self.func, [1, 2, 3, '999']) 1240 1241 def test_bytes_fail(self): 1242 # Sum of bytes should fail. 1243 self.assertRaises(TypeError, self.func, [1, 2, 3], b'999') 1244 self.assertRaises(TypeError, self.func, [1, 2, 3, b'999']) 1245 1246 def test_mixed_sum(self): 1247 # Mixed input types are not (currently) allowed. 1248 # Check that mixed data types fail. 1249 self.assertRaises(TypeError, self.func, [1, 2.0, Decimal(1)]) 1250 # And so does mixed start argument. 1251 self.assertRaises(TypeError, self.func, [1, 2.0], Decimal(1)) 1252 1253 1254class SumTortureTest(NumericTestCase): 1255 def test_torture(self): 1256 # Tim Peters' torture test for sum, and variants of same. 1257 self.assertEqual(statistics._sum([1, 1e100, 1, -1e100]*10000), 1258 (float, Fraction(20000.0), 40000)) 1259 self.assertEqual(statistics._sum([1e100, 1, 1, -1e100]*10000), 1260 (float, Fraction(20000.0), 40000)) 1261 T, num, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000) 1262 self.assertIs(T, float) 1263 self.assertEqual(count, 40000) 1264 self.assertApproxEqual(float(num), 2.0e-96, rel=5e-16) 1265 1266 1267class SumSpecialValues(NumericTestCase): 1268 # Test that sum works correctly with IEEE-754 special values. 
1269 1270 def test_nan(self): 1271 for type_ in (float, Decimal): 1272 nan = type_('nan') 1273 result = statistics._sum([1, nan, 2])[1] 1274 self.assertIs(type(result), type_) 1275 self.assertTrue(math.isnan(result)) 1276 1277 def check_infinity(self, x, inf): 1278 """Check x is an infinity of the same type and sign as inf.""" 1279 self.assertTrue(math.isinf(x)) 1280 self.assertIs(type(x), type(inf)) 1281 self.assertEqual(x > 0, inf > 0) 1282 assert x == inf 1283 1284 def do_test_inf(self, inf): 1285 # Adding a single infinity gives infinity. 1286 result = statistics._sum([1, 2, inf, 3])[1] 1287 self.check_infinity(result, inf) 1288 # Adding two infinities of the same sign also gives infinity. 1289 result = statistics._sum([1, 2, inf, 3, inf, 4])[1] 1290 self.check_infinity(result, inf) 1291 1292 def test_float_inf(self): 1293 inf = float('inf') 1294 for sign in (+1, -1): 1295 self.do_test_inf(sign*inf) 1296 1297 def test_decimal_inf(self): 1298 inf = Decimal('inf') 1299 for sign in (+1, -1): 1300 self.do_test_inf(sign*inf) 1301 1302 def test_float_mismatched_infs(self): 1303 # Test that adding two infinities of opposite sign gives a NAN. 1304 inf = float('inf') 1305 result = statistics._sum([1, 2, inf, 3, -inf, 4])[1] 1306 self.assertTrue(math.isnan(result)) 1307 1308 def test_decimal_extendedcontext_mismatched_infs_to_nan(self): 1309 # Test adding Decimal INFs with opposite sign returns NAN. 1310 inf = Decimal('inf') 1311 data = [1, 2, inf, 3, -inf, 4] 1312 with decimal.localcontext(decimal.ExtendedContext): 1313 self.assertTrue(math.isnan(statistics._sum(data)[1])) 1314 1315 def test_decimal_basiccontext_mismatched_infs_to_nan(self): 1316 # Test adding Decimal INFs with opposite sign raises InvalidOperation. 
1317 inf = Decimal('inf') 1318 data = [1, 2, inf, 3, -inf, 4] 1319 with decimal.localcontext(decimal.BasicContext): 1320 self.assertRaises(decimal.InvalidOperation, statistics._sum, data) 1321 1322 def test_decimal_snan_raises(self): 1323 # Adding sNAN should raise InvalidOperation. 1324 sNAN = Decimal('sNAN') 1325 data = [1, sNAN, 2] 1326 self.assertRaises(decimal.InvalidOperation, statistics._sum, data) 1327 1328 1329# === Tests for averages === 1330 1331class AverageMixin(UnivariateCommonMixin): 1332 # Mixin class holding common tests for averages. 1333 1334 def test_single_value(self): 1335 # Average of a single value is the value itself. 1336 for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')): 1337 self.assertEqual(self.func([x]), x) 1338 1339 def prepare_values_for_repeated_single_test(self): 1340 return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712')) 1341 1342 def test_repeated_single_value(self): 1343 # The average of a single repeated value is the value itself. 1344 for x in self.prepare_values_for_repeated_single_test(): 1345 for count in (2, 5, 10, 20): 1346 with self.subTest(x=x, count=count): 1347 data = [x]*count 1348 self.assertEqual(self.func(data), x) 1349 1350 1351class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1352 def setUp(self): 1353 self.func = statistics.mean 1354 1355 def test_torture_pep(self): 1356 # "Torture Test" from PEP-450. 1357 self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1) 1358 1359 def test_ints(self): 1360 # Test mean with ints. 1361 data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9] 1362 random.shuffle(data) 1363 self.assertEqual(self.func(data), 4.8125) 1364 1365 def test_floats(self): 1366 # Test mean with floats. 1367 data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5] 1368 random.shuffle(data) 1369 self.assertEqual(self.func(data), 22.015625) 1370 1371 def test_decimals(self): 1372 # Test mean with Decimals. 
1373 D = Decimal 1374 data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")] 1375 random.shuffle(data) 1376 self.assertEqual(self.func(data), D("3.5896")) 1377 1378 def test_fractions(self): 1379 # Test mean with Fractions. 1380 F = Fraction 1381 data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)] 1382 random.shuffle(data) 1383 self.assertEqual(self.func(data), F(1479, 1960)) 1384 1385 def test_inf(self): 1386 # Test mean with infinities. 1387 raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. 1388 for kind in (float, Decimal): 1389 for sign in (1, -1): 1390 inf = kind("inf")*sign 1391 data = raw + [inf] 1392 result = self.func(data) 1393 self.assertTrue(math.isinf(result)) 1394 self.assertEqual(result, inf) 1395 1396 def test_mismatched_infs(self): 1397 # Test mean with infinities of opposite sign. 1398 data = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')] 1399 result = self.func(data) 1400 self.assertTrue(math.isnan(result)) 1401 1402 def test_nan(self): 1403 # Test mean with NANs. 1404 raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. 1405 for kind in (float, Decimal): 1406 inf = kind("nan") 1407 data = raw + [inf] 1408 result = self.func(data) 1409 self.assertTrue(math.isnan(result)) 1410 1411 def test_big_data(self): 1412 # Test adding a large constant to every data point. 1413 c = 1e9 1414 data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4] 1415 expected = self.func(data) + c 1416 assert expected != c 1417 result = self.func([x+c for x in data]) 1418 self.assertEqual(result, expected) 1419 1420 def test_doubled_data(self): 1421 # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z]. 1422 data = [random.uniform(-3, 5) for _ in range(1000)] 1423 expected = self.func(data) 1424 actual = self.func(data*2) 1425 self.assertApproxEqual(actual, expected) 1426 1427 def test_regression_20561(self): 1428 # Regression test for issue 20561. 
1429 # See http://bugs.python.org/issue20561 1430 d = Decimal('1e4') 1431 self.assertEqual(statistics.mean([d]), d) 1432 1433 def test_regression_25177(self): 1434 # Regression test for issue 25177. 1435 # Ensure very big and very small floats don't overflow. 1436 # See http://bugs.python.org/issue25177. 1437 self.assertEqual(statistics.mean( 1438 [8.988465674311579e+307, 8.98846567431158e+307]), 1439 8.98846567431158e+307) 1440 big = 8.98846567431158e+307 1441 tiny = 5e-324 1442 for n in (2, 3, 5, 200): 1443 self.assertEqual(statistics.mean([big]*n), big) 1444 self.assertEqual(statistics.mean([tiny]*n), tiny) 1445 1446 1447class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1448 def setUp(self): 1449 self.func = statistics.harmonic_mean 1450 1451 def prepare_data(self): 1452 # Override mixin method. 1453 values = super().prepare_data() 1454 values.remove(0) 1455 return values 1456 1457 def prepare_values_for_repeated_single_test(self): 1458 # Override mixin method. 1459 return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125')) 1460 1461 def test_zero(self): 1462 # Test that harmonic mean returns zero when given zero. 1463 values = [1, 0, 2] 1464 self.assertEqual(self.func(values), 0) 1465 1466 def test_negative_error(self): 1467 # Test that harmonic mean raises when given a negative value. 1468 exc = statistics.StatisticsError 1469 for values in ([-1], [1, -2, 3]): 1470 with self.subTest(values=values): 1471 self.assertRaises(exc, self.func, values) 1472 1473 def test_invalid_type_error(self): 1474 # Test error is raised when input contains invalid type(s) 1475 for data in [ 1476 ['3.14'], # single string 1477 ['1', '2', '3'], # multiple strings 1478 [1, '2', 3, '4', 5], # mixed strings and valid integers 1479 [2.3, 3.4, 4.5, '5.6'] # only one string and valid floats 1480 ]: 1481 with self.subTest(data=data): 1482 with self.assertRaises(TypeError): 1483 self.func(data) 1484 1485 def test_ints(self): 1486 # Test harmonic mean with ints. 
1487 data = [2, 4, 4, 8, 16, 16] 1488 random.shuffle(data) 1489 self.assertEqual(self.func(data), 6*4/5) 1490 1491 def test_floats_exact(self): 1492 # Test harmonic mean with some carefully chosen floats. 1493 data = [1/8, 1/4, 1/4, 1/2, 1/2] 1494 random.shuffle(data) 1495 self.assertEqual(self.func(data), 1/4) 1496 self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5) 1497 1498 def test_singleton_lists(self): 1499 # Test that harmonic mean([x]) returns (approximately) x. 1500 for x in range(1, 101): 1501 self.assertEqual(self.func([x]), x) 1502 1503 def test_decimals_exact(self): 1504 # Test harmonic mean with some carefully chosen Decimals. 1505 D = Decimal 1506 self.assertEqual(self.func([D(15), D(30), D(60), D(60)]), D(30)) 1507 data = [D("0.05"), D("0.10"), D("0.20"), D("0.20")] 1508 random.shuffle(data) 1509 self.assertEqual(self.func(data), D("0.10")) 1510 data = [D("1.68"), D("0.32"), D("5.94"), D("2.75")] 1511 random.shuffle(data) 1512 self.assertEqual(self.func(data), D(66528)/70723) 1513 1514 def test_fractions(self): 1515 # Test harmonic mean with Fractions. 1516 F = Fraction 1517 data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)] 1518 random.shuffle(data) 1519 self.assertEqual(self.func(data), F(7*420, 4029)) 1520 1521 def test_inf(self): 1522 # Test harmonic mean with infinity. 1523 values = [2.0, float('inf'), 1.0] 1524 self.assertEqual(self.func(values), 2.0) 1525 1526 def test_nan(self): 1527 # Test harmonic mean with NANs. 1528 values = [2.0, float('nan'), 1.0] 1529 self.assertTrue(math.isnan(self.func(values))) 1530 1531 def test_multiply_data_points(self): 1532 # Test multiplying every data point by a constant. 1533 c = 111 1534 data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4] 1535 expected = self.func(data)*c 1536 result = self.func([x*c for x in data]) 1537 self.assertEqual(result, expected) 1538 1539 def test_doubled_data(self): 1540 # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z]. 
1541 data = [random.uniform(1, 5) for _ in range(1000)] 1542 expected = self.func(data) 1543 actual = self.func(data*2) 1544 self.assertApproxEqual(actual, expected) 1545 1546 def test_with_weights(self): 1547 self.assertEqual(self.func([40, 60], [5, 30]), 56.0) # common case 1548 self.assertEqual(self.func([40, 60], 1549 weights=[5, 30]), 56.0) # keyword argument 1550 self.assertEqual(self.func(iter([40, 60]), 1551 iter([5, 30])), 56.0) # iterator inputs 1552 self.assertEqual( 1553 self.func([Fraction(10, 3), Fraction(23, 5), Fraction(7, 2)], [5, 2, 10]), 1554 self.func([Fraction(10, 3)] * 5 + 1555 [Fraction(23, 5)] * 2 + 1556 [Fraction(7, 2)] * 10)) 1557 self.assertEqual(self.func([10], [7]), 10) # n=1 fast path 1558 with self.assertRaises(TypeError): 1559 self.func([1, 2, 3], [1, (), 3]) # non-numeric weight 1560 with self.assertRaises(statistics.StatisticsError): 1561 self.func([1, 2, 3], [1, 2]) # wrong number of weights 1562 with self.assertRaises(statistics.StatisticsError): 1563 self.func([10], [0]) # no non-zero weights 1564 with self.assertRaises(statistics.StatisticsError): 1565 self.func([10, 20], [0, 0]) # no non-zero weights 1566 1567 1568class TestMedian(NumericTestCase, AverageMixin): 1569 # Common tests for median and all median.* functions. 1570 def setUp(self): 1571 self.func = statistics.median 1572 1573 def prepare_data(self): 1574 """Overload method from UnivariateCommonMixin.""" 1575 data = super().prepare_data() 1576 if len(data)%2 != 1: 1577 data.append(2) 1578 return data 1579 1580 def test_even_ints(self): 1581 # Test median with an even number of int data points. 1582 data = [1, 2, 3, 4, 5, 6] 1583 assert len(data)%2 == 0 1584 self.assertEqual(self.func(data), 3.5) 1585 1586 def test_odd_ints(self): 1587 # Test median with an odd number of int data points. 
1588 data = [1, 2, 3, 4, 5, 6, 9] 1589 assert len(data)%2 == 1 1590 self.assertEqual(self.func(data), 4) 1591 1592 def test_odd_fractions(self): 1593 # Test median works with an odd number of Fractions. 1594 F = Fraction 1595 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7)] 1596 assert len(data)%2 == 1 1597 random.shuffle(data) 1598 self.assertEqual(self.func(data), F(3, 7)) 1599 1600 def test_even_fractions(self): 1601 # Test median works with an even number of Fractions. 1602 F = Fraction 1603 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1604 assert len(data)%2 == 0 1605 random.shuffle(data) 1606 self.assertEqual(self.func(data), F(1, 2)) 1607 1608 def test_odd_decimals(self): 1609 # Test median works with an odd number of Decimals. 1610 D = Decimal 1611 data = [D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] 1612 assert len(data)%2 == 1 1613 random.shuffle(data) 1614 self.assertEqual(self.func(data), D('4.2')) 1615 1616 def test_even_decimals(self): 1617 # Test median works with an even number of Decimals. 1618 D = Decimal 1619 data = [D('1.2'), D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] 1620 assert len(data)%2 == 0 1621 random.shuffle(data) 1622 self.assertEqual(self.func(data), D('3.65')) 1623 1624 1625class TestMedianDataType(NumericTestCase, UnivariateTypeMixin): 1626 # Test conservation of data element type for median. 1627 def setUp(self): 1628 self.func = statistics.median 1629 1630 def prepare_data(self): 1631 data = list(range(15)) 1632 assert len(data)%2 == 1 1633 while data == sorted(data): 1634 random.shuffle(data) 1635 return data 1636 1637 1638class TestMedianLow(TestMedian, UnivariateTypeMixin): 1639 def setUp(self): 1640 self.func = statistics.median_low 1641 1642 def test_even_ints(self): 1643 # Test median_low with an even number of ints. 
1644 data = [1, 2, 3, 4, 5, 6] 1645 assert len(data)%2 == 0 1646 self.assertEqual(self.func(data), 3) 1647 1648 def test_even_fractions(self): 1649 # Test median_low works with an even number of Fractions. 1650 F = Fraction 1651 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1652 assert len(data)%2 == 0 1653 random.shuffle(data) 1654 self.assertEqual(self.func(data), F(3, 7)) 1655 1656 def test_even_decimals(self): 1657 # Test median_low works with an even number of Decimals. 1658 D = Decimal 1659 data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] 1660 assert len(data)%2 == 0 1661 random.shuffle(data) 1662 self.assertEqual(self.func(data), D('3.3')) 1663 1664 1665class TestMedianHigh(TestMedian, UnivariateTypeMixin): 1666 def setUp(self): 1667 self.func = statistics.median_high 1668 1669 def test_even_ints(self): 1670 # Test median_high with an even number of ints. 1671 data = [1, 2, 3, 4, 5, 6] 1672 assert len(data)%2 == 0 1673 self.assertEqual(self.func(data), 4) 1674 1675 def test_even_fractions(self): 1676 # Test median_high works with an even number of Fractions. 1677 F = Fraction 1678 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1679 assert len(data)%2 == 0 1680 random.shuffle(data) 1681 self.assertEqual(self.func(data), F(4, 7)) 1682 1683 def test_even_decimals(self): 1684 # Test median_high works with an even number of Decimals. 1685 D = Decimal 1686 data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] 1687 assert len(data)%2 == 0 1688 random.shuffle(data) 1689 self.assertEqual(self.func(data), D('4.4')) 1690 1691 1692class TestMedianGrouped(TestMedian): 1693 # Test median_grouped. 1694 # Doesn't conserve data element types, so don't use TestMedianType. 1695 def setUp(self): 1696 self.func = statistics.median_grouped 1697 1698 def test_odd_number_repeated(self): 1699 # Test median.grouped with repeated median values. 
1700 data = [12, 13, 14, 14, 14, 15, 15] 1701 assert len(data)%2 == 1 1702 self.assertEqual(self.func(data), 14) 1703 #--- 1704 data = [12, 13, 14, 14, 14, 14, 15] 1705 assert len(data)%2 == 1 1706 self.assertEqual(self.func(data), 13.875) 1707 #--- 1708 data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30] 1709 assert len(data)%2 == 1 1710 self.assertEqual(self.func(data, 5), 19.375) 1711 #--- 1712 data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28] 1713 assert len(data)%2 == 1 1714 self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8) 1715 1716 def test_even_number_repeated(self): 1717 # Test median.grouped with repeated median values. 1718 data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30] 1719 assert len(data)%2 == 0 1720 self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8) 1721 #--- 1722 data = [2, 3, 4, 4, 4, 5] 1723 assert len(data)%2 == 0 1724 self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8) 1725 #--- 1726 data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6] 1727 assert len(data)%2 == 0 1728 self.assertEqual(self.func(data), 4.5) 1729 #--- 1730 data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6] 1731 assert len(data)%2 == 0 1732 self.assertEqual(self.func(data), 4.75) 1733 1734 def test_repeated_single_value(self): 1735 # Override method from AverageMixin. 1736 # Yet again, failure of median_grouped to conserve the data type 1737 # causes me headaches :-( 1738 for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')): 1739 for count in (2, 5, 10, 20): 1740 data = [x]*count 1741 self.assertEqual(self.func(data), float(x)) 1742 1743 def test_single_value(self): 1744 # Override method from AverageMixin. 1745 # Average of a single value is the value as a float. 1746 for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')): 1747 self.assertEqual(self.func([x]), float(x)) 1748 1749 def test_odd_fractions(self): 1750 # Test median_grouped works with an odd number of Fractions. 
1751 F = Fraction 1752 data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)] 1753 assert len(data)%2 == 1 1754 random.shuffle(data) 1755 self.assertEqual(self.func(data), 3.0) 1756 1757 def test_even_fractions(self): 1758 # Test median_grouped works with an even number of Fractions. 1759 F = Fraction 1760 data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)] 1761 assert len(data)%2 == 0 1762 random.shuffle(data) 1763 self.assertEqual(self.func(data), 3.25) 1764 1765 def test_odd_decimals(self): 1766 # Test median_grouped works with an odd number of Decimals. 1767 D = Decimal 1768 data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] 1769 assert len(data)%2 == 1 1770 random.shuffle(data) 1771 self.assertEqual(self.func(data), 6.75) 1772 1773 def test_even_decimals(self): 1774 # Test median_grouped works with an even number of Decimals. 1775 D = Decimal 1776 data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] 1777 assert len(data)%2 == 0 1778 random.shuffle(data) 1779 self.assertEqual(self.func(data), 6.5) 1780 #--- 1781 data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')] 1782 assert len(data)%2 == 0 1783 random.shuffle(data) 1784 self.assertEqual(self.func(data), 7.0) 1785 1786 def test_interval(self): 1787 # Test median_grouped with interval argument. 
1788 data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] 1789 self.assertEqual(self.func(data, 0.25), 2.875) 1790 data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] 1791 self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8) 1792 data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340] 1793 self.assertEqual(self.func(data, 20), 265.0) 1794 1795 def test_data_type_error(self): 1796 # Test median_grouped with str, bytes data types for data and interval 1797 data = ["", "", ""] 1798 self.assertRaises(TypeError, self.func, data) 1799 #--- 1800 data = [b"", b"", b""] 1801 self.assertRaises(TypeError, self.func, data) 1802 #--- 1803 data = [1, 2, 3] 1804 interval = "" 1805 self.assertRaises(TypeError, self.func, data, interval) 1806 #--- 1807 data = [1, 2, 3] 1808 interval = b"" 1809 self.assertRaises(TypeError, self.func, data, interval) 1810 1811 1812class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1813 # Test cases for the discrete version of mode. 1814 def setUp(self): 1815 self.func = statistics.mode 1816 1817 def prepare_data(self): 1818 """Overload method from UnivariateCommonMixin.""" 1819 # Make sure test data has exactly one mode. 1820 return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2] 1821 1822 def test_range_data(self): 1823 # Override test from UnivariateCommonMixin. 1824 data = range(20, 50, 3) 1825 self.assertEqual(self.func(data), 20) 1826 1827 def test_nominal_data(self): 1828 # Test mode with nominal data. 1829 data = 'abcbdb' 1830 self.assertEqual(self.func(data), 'b') 1831 data = 'fe fi fo fum fi fi'.split() 1832 self.assertEqual(self.func(data), 'fi') 1833 1834 def test_discrete_data(self): 1835 # Test mode with discrete numeric data. 1836 data = list(range(10)) 1837 for i in range(10): 1838 d = data + [i] 1839 random.shuffle(d) 1840 self.assertEqual(self.func(d), i) 1841 1842 def test_bimodal_data(self): 1843 # Test mode with bimodal data. 
# (Span resumes inside test_bimodal_data; its def line is directly above.)
        data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
        # Guard the fixture: 2 and 6 must be tied at four occurrences each.
        assert data.count(2) == data.count(6) == 4
        # mode() should return 2, the first encountered mode
        self.assertEqual(self.func(data), 2)

    def test_unique_data(self):
        # Test mode when data points are all unique.
        data = list(range(10))
        # mode() should return 0, the first encountered mode
        self.assertEqual(self.func(data), 0)

    def test_none_data(self):
        # Test that mode raises TypeError if given None as data.

        # This test is necessary because the implementation of mode uses
        # collections.Counter, which accepts None and returns an empty dict.
        self.assertRaises(TypeError, self.func, None)

    def test_counter_data(self):
        # Test that a Counter is treated like any other iterable.
        # We're making sure mode() first calls iter() on its input.
        # The concern is that a Counter of a Counter returns the original
        # unchanged rather than counting its keys.
        c = collections.Counter(a=1, b=2)
        # If iter() is called, mode(c) loops over the keys, ['a', 'b'],
        # all the counts will be 1, and the first encountered mode is 'a'.
# (Span resumes inside test_counter_data; its def line is above.)
        self.assertEqual(self.func(c), 'a')


class TestMultiMode(unittest.TestCase):

    def test_basics(self):
        # Covers a single mode, several tied modes, and empty input.
        multimode = statistics.multimode
        self.assertEqual(multimode('aabbbbbbbbcc'), ['b'])
        self.assertEqual(multimode('aabbbbccddddeeffffgg'), ['b', 'd', 'f'])
        self.assertEqual(multimode(''), [])


class TestFMean(unittest.TestCase):

    def test_basics(self):
        # Whatever the input type, the result must be exactly a float.
        fmean = statistics.fmean
        D = Decimal
        F = Fraction
        for data, expected_mean, kind in [
            ([3.5, 4.0, 5.25], 4.25, 'floats'),
            ([D('3.5'), D('4.0'), D('5.25')], 4.25, 'decimals'),
            ([F(7, 2), F(4, 1), F(21, 4)], 4.25, 'fractions'),
            ([True, False, True, True, False], 0.60, 'booleans'),
            ([3.5, 4, F(21, 4)], 4.25, 'mixed types'),
            ((3.5, 4.0, 5.25), 4.25, 'tuple'),
            (iter([3.5, 4.0, 5.25]), 4.25, 'iterator'),
        ]:
            actual_mean = fmean(data)
            # `kind` is passed as the failure message to identify the case.
            self.assertIs(type(actual_mean), float, kind)
            self.assertEqual(actual_mean, expected_mean, kind)

    def test_error_cases(self):
        fmean = statistics.fmean
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(StatisticsError):
            fmean([])  # empty input
        with self.assertRaises(StatisticsError):
            fmean(iter([]))  # empty iterator
        with self.assertRaises(TypeError):
            fmean(None)  # non-iterable input
        with self.assertRaises(TypeError):
            fmean([10, None, 20])  # non-numeric input
        with self.assertRaises(TypeError):
            fmean()  # missing data argument
        with self.assertRaises(TypeError):
            fmean([10, 20, 60], 70)  # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        fmean = statistics.fmean
        NaN = float('Nan')
        Inf = float('Inf')
        self.assertTrue(math.isnan(fmean([10, NaN])), 'nan')
        self.assertTrue(math.isnan(fmean([NaN, Inf])), 'nan and infinity')
        self.assertTrue(math.isinf(fmean([10, Inf])), 'infinity')
        # Infinities of opposite sign are expected to raise, not return NaN.
        with self.assertRaises(ValueError):
            fmean([Inf, -Inf])

    def test_weights(self):
        fmean = statistics.fmean
        StatisticsError = statistics.StatisticsError
        # Equal weights must give the same answer as the unweighted mean.
        self.assertEqual(
            fmean([10, 10, 10, 50], [0.25] * 4),
            fmean([10, 10, 10, 50]))
        # Doubling a weight is equivalent to repeating the data point.
        self.assertEqual(
            fmean([10, 10, 20], [0.25, 0.25, 0.50]),
            fmean([10, 10, 20, 20]))
        self.assertEqual(  # inputs are iterators
            fmean(iter([10, 10, 20]), iter([0.25, 0.25, 0.50])),
            fmean([10, 10, 20, 20]))
        with self.assertRaises(StatisticsError):
            fmean([10, 20, 30], [1, 2])  # unequal lengths
        with self.assertRaises(StatisticsError):
            fmean(iter([10, 20, 30]), iter([1, 2]))  # unequal lengths
        with self.assertRaises(StatisticsError):
            fmean([10, 20], [-1, 1])  # sum of weights is zero
        with self.assertRaises(StatisticsError):
            fmean(iter([10, 20]), iter([-1, 1]))  # sum of weights is zero


# === Tests for variances and standard deviations ===

class VarianceStdevMixin(UnivariateCommonMixin):
    # Mixin class holding common tests for variance and std dev.

    # Subclasses should inherit from this before NumericTestCase, in order
    # to see the rel attribute below. See test_shift_data for an explanation.

    rel = 1e-12

    def test_single_value(self):
        # Deviation of a single value is zero.
        for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')):
            self.assertEqual(self.func([x]), 0)

    def test_repeated_single_value(self):
        # The deviation of a single repeated value is zero.
        for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')):
            for count in (2, 3, 5, 15):
                data = [x]*count
                self.assertEqual(self.func(data), 0)

    def test_domain_error_regression(self):
        # Regression test for a domain error exception.
        # (Thanks to Geremy Condra.)
# (Span resumes inside test_domain_error_regression; its def line is above.)
        data = [0.123456789012345]*10000
        # All the items are identical, so variance should be exactly zero.
        # We allow some small round-off error, but not much.
        result = self.func(data)
        self.assertApproxEqual(result, 0.0, tol=5e-17)
        self.assertGreaterEqual(result, 0)  # A negative result must fail.

    def test_shift_data(self):
        # Test that shifting the data by a constant amount does not affect
        # the variance or stdev. Or at least not much.

        # Due to rounding, this test should be considered an ideal. We allow
        # some tolerance away from "no change at all" by setting tol and/or rel
        # attributes. Subclasses may set tighter or looser error tolerances.
        raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
        expected = self.func(raw)
        # Don't set shift too high, the bigger it is, the more rounding error.
        shift = 1e5
        data = [x + shift for x in raw]
        self.assertApproxEqual(self.func(data), expected)

    def test_shift_data_exact(self):
        # Like test_shift_data, but result is always exact.
        raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
        # Guard the fixture: every raw value must be integral.
        assert all(x==int(x) for x in raw)
        expected = self.func(raw)
        shift = 10**9
        data = [x + shift for x in raw]
        self.assertEqual(self.func(data), expected)

    def test_iter_list_same(self):
        # Test that iter data and list data give the same result.

        # This is an explicit test that iterators and lists are treated the
        # same; justification for this test over and above the similar test
        # in UnivariateCommonMixin is that an earlier design had variance and
        # friends swap between one- and two-pass algorithms, which would
        # sometimes give different results.
# (Span resumes inside test_iter_list_same; its def line is above.)
        data = [random.uniform(-3, 8) for _ in range(1000)]
        expected = self.func(data)
        # Exact equality: list and iterator input must agree exactly.
        self.assertEqual(self.func(iter(data)), expected)


class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for population variance.
    def setUp(self):
        self.func = statistics.pvariance

    def test_exact_uniform(self):
        # Test the variance against an exact result for uniform data.
        data = list(range(10000))
        random.shuffle(data)
        expected = (10000**2 - 1)/12  # Exact value.
        self.assertEqual(self.func(data), expected)

    def test_ints(self):
        # Test population variance with int data.
        data = [4, 7, 13, 16]
        exact = 22.5
        self.assertEqual(self.func(data), exact)

    def test_fractions(self):
        # Test population variance with Fraction data.
        F = Fraction
        data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)]
        exact = F(3, 8)
        result = self.func(data)
        self.assertEqual(result, exact)
        # The result type must stay Fraction, not degrade to float.
        self.assertIsInstance(result, Fraction)

    def test_decimals(self):
        # Test population variance with Decimal data.
        D = Decimal
        data = [D("12.1"), D("12.2"), D("12.5"), D("12.9")]
        exact = D('0.096875')
        result = self.func(data)
        self.assertEqual(result, exact)
        # The result type must stay Decimal, not degrade to float.
        self.assertIsInstance(result, Decimal)

    def test_accuracy_bug_20499(self):
        # Int input whose variance is not an integer: expect the float 2/9.
        data = [0, 0, 1]
        exact = 2 / 9
        result = self.func(data)
        self.assertEqual(result, exact)
        self.assertIsInstance(result, float)


class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for sample variance.
    def setUp(self):
        self.func = statistics.variance

    def test_single_value(self):
        # Override method from VarianceStdevMixin.
# (Span resumes inside TestVariance.test_single_value; its def line is above.)
        # For sample variance a single data point is an error, not zero.
        for x in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')):
            self.assertRaises(statistics.StatisticsError, self.func, [x])

    def test_ints(self):
        # Test sample variance with int data.
        data = [4, 7, 13, 16]
        exact = 30
        self.assertEqual(self.func(data), exact)

    def test_fractions(self):
        # Test sample variance with Fraction data.
        F = Fraction
        data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)]
        exact = F(1, 2)
        result = self.func(data)
        self.assertEqual(result, exact)
        self.assertIsInstance(result, Fraction)

    def test_decimals(self):
        # Test sample variance with Decimal data.
        D = Decimal
        data = [D(2), D(2), D(7), D(9)]
        exact = 4*D('9.5')/D(3)
        result = self.func(data)
        self.assertEqual(result, exact)
        self.assertIsInstance(result, Decimal)

    def test_center_not_at_mean(self):
        # Passing xbar explicitly changes the result when it is not the mean.
        data = (1.0, 2.0)
        self.assertEqual(self.func(data), 0.5)
        self.assertEqual(self.func(data, xbar=2.0), 1.0)

    def test_accuracy_bug_20499(self):
        # Int input whose variance is not an integer: expect the float 4/3.
        data = [0, 0, 2]
        exact = 4 / 3
        result = self.func(data)
        self.assertEqual(result, exact)
        self.assertIsInstance(result, float)

class TestPStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for population standard deviation.
    def setUp(self):
        self.func = statistics.pstdev

    def test_compare_to_variance(self):
        # Test that stdev is, in fact, the square root of variance.
# (Span resumes inside TestPStdev.test_compare_to_variance; def line is above.)
        data = [random.uniform(-17, 24) for _ in range(1000)]
        expected = math.sqrt(statistics.pvariance(data))
        self.assertEqual(self.func(data), expected)

    def test_center_not_at_mean(self):
        # See issue: 40855
        data = (3, 6, 7, 10)
        self.assertEqual(self.func(data), 2.5)
        self.assertEqual(self.func(data, mu=0.5), 6.5)

class TestSqrtHelpers(unittest.TestCase):

    def test_integer_sqrt_of_frac_rto(self):
        # Exhaustively check small n/m pairs.
        for n, m in itertools.product(range(100), range(1, 1000)):
            r = statistics._integer_sqrt_of_frac_rto(n, m)
            self.assertIsInstance(r, int)
            if r*r*m == n:
                # Root is exact
                continue
            # Inexact, so the root should be odd
            self.assertEqual(r&1, 1)
            # Verify correct rounding
            self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)

    @requires_IEEE_754
    @support.requires_resource('cpu')
    def test_float_sqrt_of_frac(self):

        def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
            # True if x lies between the squares of the midpoints separating
            # `root` from its two neighbouring floats.
            if not x:
                return root == 0.0

            # Extract adjacent representable floats
            r_up: float = math.nextafter(root, math.inf)
            r_down: float = math.nextafter(root, -math.inf)
            assert r_down < root < r_up

            # Convert to fractions for exact arithmetic
            frac_root: Fraction = Fraction(root)
            half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
            half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2

            # Check a closed interval.
            # Does not test for a midpoint rounding rule.
2159 return half_way_down ** 2 <= x <= half_way_up ** 2 2160 2161 randrange = random.randrange 2162 2163 for i in range(60_000): 2164 numerator: int = randrange(10 ** randrange(50)) 2165 denonimator: int = randrange(10 ** randrange(50)) + 1 2166 with self.subTest(numerator=numerator, denonimator=denonimator): 2167 x: Fraction = Fraction(numerator, denonimator) 2168 root: float = statistics._float_sqrt_of_frac(numerator, denonimator) 2169 self.assertTrue(is_root_correctly_rounded(x, root)) 2170 2171 # Verify that corner cases and error handling match math.sqrt() 2172 self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0) 2173 with self.assertRaises(ValueError): 2174 statistics._float_sqrt_of_frac(-1, 1) 2175 with self.assertRaises(ValueError): 2176 statistics._float_sqrt_of_frac(1, -1) 2177 2178 # Error handling for zero denominator matches that for Fraction(1, 0) 2179 with self.assertRaises(ZeroDivisionError): 2180 statistics._float_sqrt_of_frac(1, 0) 2181 2182 # The result is well defined if both inputs are negative 2183 self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1)) 2184 2185 def test_decimal_sqrt_of_frac(self): 2186 root: Decimal 2187 numerator: int 2188 denominator: int 2189 2190 for root, numerator, denominator in [ 2191 (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj 2192 (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up 2193 (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down 2194 ]: 2195 with decimal.localcontext(decimal.DefaultContext): 2196 self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root) 2197 2198 # Confirm expected root with a quad precision decimal computation 2199 with decimal.localcontext(decimal.DefaultContext) as ctx: 2200 ctx.prec *= 4 2201 high_prec_ratio = 
Decimal(numerator) / Decimal(denominator) 2202 ctx.rounding = decimal.ROUND_05UP 2203 high_prec_root = high_prec_ratio.sqrt() 2204 with decimal.localcontext(decimal.DefaultContext): 2205 target_root = +high_prec_root 2206 self.assertEqual(root, target_root) 2207 2208 # Verify that corner cases and error handling match Decimal.sqrt() 2209 self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0) 2210 with self.assertRaises(decimal.InvalidOperation): 2211 statistics._decimal_sqrt_of_frac(-1, 1) 2212 with self.assertRaises(decimal.InvalidOperation): 2213 statistics._decimal_sqrt_of_frac(1, -1) 2214 2215 # Error handling for zero denominator matches that for Fraction(1, 0) 2216 with self.assertRaises(ZeroDivisionError): 2217 statistics._decimal_sqrt_of_frac(1, 0) 2218 2219 # The result is well defined if both inputs are negative 2220 self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1)) 2221 2222 2223class TestStdev(VarianceStdevMixin, NumericTestCase): 2224 # Tests for sample standard deviation. 2225 def setUp(self): 2226 self.func = statistics.stdev 2227 2228 def test_single_value(self): 2229 # Override method from VarianceStdevMixin. 2230 for x in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')): 2231 self.assertRaises(statistics.StatisticsError, self.func, [x]) 2232 2233 def test_compare_to_variance(self): 2234 # Test that stdev is, in fact, the square root of variance. 
# (Span resumes inside TestStdev.test_compare_to_variance; def line is above.)
        data = [random.uniform(-2, 9) for _ in range(1000)]
        expected = math.sqrt(statistics.variance(data))
        self.assertAlmostEqual(self.func(data), expected)

    def test_center_not_at_mean(self):
        data = (1.0, 2.0)
        self.assertEqual(self.func(data, xbar=2.0), 1.0)

class TestGeometricMean(unittest.TestCase):

    def test_basics(self):
        geometric_mean = statistics.geometric_mean
        self.assertAlmostEqual(geometric_mean([54, 24, 36]), 36.0)
        self.assertAlmostEqual(geometric_mean([4.0, 9.0]), 6.0)
        self.assertAlmostEqual(geometric_mean([17.625]), 17.625)

        # Fixed seed so the random inputs below are reproducible.
        random.seed(86753095551212)
        for rng in [
            range(1, 100),
            range(1, 1_000),
            range(1, 10_000),
            range(500, 10_000, 3),
            range(10_000, 500, -3),
            [12, 17, 13, 5, 120, 7],
            [random.expovariate(50.0) for i in range(1_000)],
            [random.lognormvariate(20.0, 3.0) for i in range(2_000)],
            [random.triangular(2000, 3000, 2200) for i in range(3_000)],
        ]:
            # Reference value: n-th root of the product, computed in Decimal.
            gm_decimal = math.prod(map(Decimal, rng)) ** (Decimal(1) / len(rng))
            gm_float = geometric_mean(rng)
            self.assertTrue(math.isclose(gm_float, float(gm_decimal)))

    def test_various_input_types(self):
        geometric_mean = statistics.geometric_mean
        D = Decimal
        F = Fraction
        # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25
        expected_mean = 4.18886
        for data, kind in [
            ([3.5, 4.0, 5.25], 'floats'),
            ([D('3.5'), D('4.0'), D('5.25')], 'decimals'),
            ([F(7, 2), F(4, 1), F(21, 4)], 'fractions'),
            ([3.5, 4, F(21, 4)], 'mixed types'),
            ((3.5, 4.0, 5.25), 'tuple'),
            (iter([3.5, 4.0, 5.25]), 'iterator'),
        ]:
            actual_mean = geometric_mean(data)
            # `kind` is passed as the failure message to identify the case.
            self.assertIs(type(actual_mean), float, kind)
            self.assertAlmostEqual(actual_mean, expected_mean, places=5)

    def test_big_and_small(self):
        geometric_mean = statistics.geometric_mean

        # Avoid overflow to infinity
        large = 2.0 ** 1000
        big_gm = geometric_mean([54.0 * large, 24.0 * large, 36.0 * large])
        self.assertTrue(math.isclose(big_gm, 36.0 * large))
        self.assertFalse(math.isinf(big_gm))

        # Avoid underflow to zero
        small = 2.0 ** -1000
        small_gm = geometric_mean([54.0 * small, 24.0 * small, 36.0 * small])
        self.assertTrue(math.isclose(small_gm, 36.0 * small))
        self.assertNotEqual(small_gm, 0.0)

    def test_error_cases(self):
        geometric_mean = statistics.geometric_mean
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(StatisticsError):
            geometric_mean([])  # empty input
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, -4.0, 5.25])  # negative input
        with self.assertRaises(StatisticsError):
            geometric_mean([0.0, -4.0, 5.25])  # negative input with zero
        with self.assertRaises(StatisticsError):
            geometric_mean([3.5, -math.inf, 5.25])  # negative infinity
        with self.assertRaises(StatisticsError):
            geometric_mean(iter([]))  # empty iterator
        with self.assertRaises(TypeError):
            geometric_mean(None)  # non-iterable input
        with self.assertRaises(TypeError):
            geometric_mean([10, None, 20])  # non-numeric input
        with self.assertRaises(TypeError):
            geometric_mean()  # missing data argument
        with self.assertRaises(TypeError):
            geometric_mean([10, 20, 60], 70)  # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        geometric_mean = statistics.geometric_mean
        NaN = float('Nan')
        Inf = float('Inf')
        self.assertTrue(math.isnan(geometric_mean([10, NaN])), 'nan')
        self.assertTrue(math.isnan(geometric_mean([NaN, Inf])), 'nan and infinity')
        self.assertTrue(math.isinf(geometric_mean([10, Inf])), 'infinity')
        with self.assertRaises(ValueError):
            geometric_mean([Inf, -Inf])

        # Cases with zero
        self.assertEqual(geometric_mean([3, 0.0, 5]), 0.0)  # Any zero gives a zero
        self.assertEqual(geometric_mean([3, -0.0, 5]), 0.0)  # Negative zero allowed
        self.assertTrue(math.isnan(geometric_mean([0, NaN])))  # NaN beats zero
        self.assertTrue(math.isnan(geometric_mean([0, Inf])))  # Because 0.0 * Inf -> NaN

    def test_mixed_int_and_float(self):
        # Regression test for b.p.o. issue #28327
        geometric_mean = statistics.geometric_mean
        expected_mean = 3.80675409583932
        # Same values with the int/float boundary moved one place at a time.
        values = [
            [2, 3, 5, 7],
            [2, 3, 5, 7.0],
            [2, 3, 5.0, 7.0],
            [2, 3.0, 5.0, 7.0],
            [2.0, 3.0, 5.0, 7.0],
        ]
        for v in values:
            with self.subTest(v=v):
                actual_mean = geometric_mean(v)
                self.assertAlmostEqual(actual_mean, expected_mean, places=5)


class TestKDE(unittest.TestCase):

    def test_kde(self):
        kde = statistics.kde
        StatisticsError = statistics.StatisticsError

        kernels = ['normal', 'gauss', 'logistic', 'sigmoid', 'rectangular',
                   'uniform', 'triangular', 'parabolic', 'epanechnikov',
                   'quartic', 'biweight', 'triweight', 'cosine']

        sample = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2]

        # The approximate integral of a PDF should be close to 1.0

        def integrate(func, low, high, steps=10_000):
            "Numeric approximation of a definite function integral."
# (Span resumes inside integrate(), the helper nested in TestKDE.test_kde;
# both def lines are above.)
            # Midpoint rule: evaluate func at the centre of each of `steps`
            # equal-width slices and scale the sum by the slice width.
            dx = (high - low) / steps
            midpoints = (low + (i + 1/2) * dx for i in range(steps))
            return sum(map(func, midpoints)) * dx

        for kernel in kernels:
            with self.subTest(kernel=kernel):
                f_hat = kde(sample, h=1.5, kernel=kernel)
                area = integrate(f_hat, -20, 20)
                self.assertAlmostEqual(area, 1.0, places=4)

        # Check CDF against an integral of the PDF

        data = [3, 5, 10, 12]
        h = 2.3
        x = 10.5
        for kernel in kernels:
            with self.subTest(kernel=kernel):
                cdf = kde(data, h, kernel, cumulative=True)
                f_hat = kde(data, h, kernel)
                # More steps than the default for this comparison.
                area = integrate(f_hat, -20, x, 100_000)
                self.assertAlmostEqual(cdf(x), area, places=4)

        # Check error cases

        with self.assertRaises(StatisticsError):
            kde([], h=1.0)  # Empty dataset
        with self.assertRaises(TypeError):
            kde(['abc', 'def'], 1.5)  # Non-numeric data
        with self.assertRaises(TypeError):
            kde(iter(sample), 1.5)  # Data is not a sequence
        with self.assertRaises(StatisticsError):
            kde(sample, h=0.0)  # Zero bandwidth
        with self.assertRaises(StatisticsError):
            kde(sample, h=-1.0)  # Negative bandwidth
        with self.assertRaises(TypeError):
            kde(sample, h='str')  # Wrong bandwidth type
        with self.assertRaises(StatisticsError):
            kde(sample, h=1.0, kernel='bogus')  # Invalid kernel
        with self.assertRaises(TypeError):
            kde(sample, 1.0, 'gauss', True)  # Positional cumulative argument

        # Test name and docstring of the generated function

        h = 1.5
        kernel = 'cosine'
        f_hat = kde(sample, h, kernel)
        self.assertEqual(f_hat.__name__, 'pdf')
        self.assertIn(kernel, f_hat.__doc__)
        self.assertIn(repr(h), f_hat.__doc__)

        # Test closed interval for the support boundaries.
        # In particular, 'uniform' should be non-zero at the boundaries.
# (Span resumes inside TestKDE.test_kde; its def line is above.)

        f_hat = kde([0], 1.0, 'uniform')
        self.assertEqual(f_hat(-1.0), 1/2)
        self.assertEqual(f_hat(1.0), 1/2)

        # Test online updates to data

        data = [1, 2]
        f_hat = kde(data, 5.0, 'triangular')
        self.assertEqual(f_hat(100), 0.0)
        # Mutating the list after building f_hat must change its output.
        data.append(100)
        self.assertGreater(f_hat(100), 0.0)

    def test_kde_kernel_invcdfs(self):
        kernel_invcdfs = statistics._kernel_invcdfs
        kde = statistics.kde

        # Verify that cdf / invcdf will round trip
        xarr = [i/100 for i in range(-100, 101)]
        for kernel, invcdf in kernel_invcdfs.items():
            with self.subTest(kernel=kernel):
                # One data point at 0 with h=1; compared against the raw
                # kernel inverse CDF below.
                cdf = kde([0.0], h=1.0, kernel=kernel, cumulative=True)
                for x in xarr:
                    self.assertAlmostEqual(invcdf(cdf(x)), x, places=5)

    @support.requires_resource('cpu')
    def test_kde_random(self):
        kde_random = statistics.kde_random
        StatisticsError = statistics.StatisticsError
        kernels = ['normal', 'gauss', 'logistic', 'sigmoid', 'rectangular',
                   'uniform', 'triangular', 'parabolic', 'epanechnikov',
                   'quartic', 'biweight', 'triweight', 'cosine']
        sample = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2]

        # Smoke test

        for kernel in kernels:
            with self.subTest(kernel=kernel):
                rand = kde_random(sample, h=1.5, kernel=kernel)
                # The values are not inspected; the draws only need to succeed.
                selections = [rand() for i in range(10)]

        # Check error cases

        with self.assertRaises(StatisticsError):
            kde_random([], h=1.0)  # Empty dataset
        with self.assertRaises(TypeError):
            kde_random(['abc', 'def'], 1.5)  # Non-numeric data
        with self.assertRaises(TypeError):
            kde_random(iter(sample), 1.5)  # Data is not a sequence
        with self.assertRaises(StatisticsError):
            kde_random(sample, h=-1.0)  # Negative bandwidth
        with self.assertRaises(StatisticsError):
            kde_random(sample, h=0.0)  # Zero bandwidth
        with self.assertRaises(TypeError):
            kde_random(sample, h='str')  # Wrong bandwidth type
        with self.assertRaises(StatisticsError):
            kde_random(sample, h=1.0, kernel='bogus')  # Invalid kernel

        # Test name and docstring of the generated function

        h = 1.5
        kernel = 'cosine'
        rand = kde_random(sample, h, kernel)
        self.assertEqual(rand.__name__, 'rand')
        self.assertIn(kernel, rand.__doc__)
        self.assertIn(repr(h), rand.__doc__)

        # Approximate distribution test: Compare a random sample to the expected distribution

        data = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2, 7.8, 14.3, 15.1, 15.3, 15.8, 17.0]
        xarr = [x / 10 for x in range(-100, 250)]
        n = 1_000_000
        h = 1.75
        dx = 0.1

        # Both helpers read big_sample / F_hat from the enclosing scope; those
        # names are rebound for each kernel in the loop below.
        def p_observed(x):
            # P(x <= X < x+dx)
            i = bisect.bisect_left(big_sample, x)
            j = bisect.bisect_left(big_sample, x + dx)
            return (j - i) / len(big_sample)

        def p_expected(x):
            # P(x <= X < x+dx)
            return F_hat(x + dx) - F_hat(x)

        for kernel in kernels:
            with self.subTest(kernel=kernel):

                rand = kde_random(data, h, kernel, seed=8675309**2)
                # Sorted so that p_observed can count with bisect.
                big_sample = sorted([rand() for i in range(n)])
                F_hat = statistics.kde(data, h, kernel, cumulative=True)

                for x in xarr:
                    self.assertTrue(math.isclose(p_observed(x), p_expected(x), abs_tol=0.0005))

        # Test online updates to data

        data = [1, 2]
        rand = kde_random(data, 5, 'triangular')
        self.assertLess(max([rand() for i in range(5000)]), 10)
        # Mutating the list after building rand must change what it draws.
        data.append(100)
        self.assertGreater(max(rand() for i in range(5000)), 10)


class TestQuantiles(unittest.TestCase):

    def test_specific_cases(self):
        # Match results computed by hand and cross-checked
        # against the PERCENTILE.EXC function in MS Excel.
2533 quantiles = statistics.quantiles 2534 data = [120, 200, 250, 320, 350] 2535 random.shuffle(data) 2536 for n, expected in [ 2537 (1, []), 2538 (2, [250.0]), 2539 (3, [200.0, 320.0]), 2540 (4, [160.0, 250.0, 335.0]), 2541 (5, [136.0, 220.0, 292.0, 344.0]), 2542 (6, [120.0, 200.0, 250.0, 320.0, 350.0]), 2543 (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]), 2544 (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]), 2545 (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0, 2546 350.0, 365.0]), 2547 (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0, 2548 320.0, 332.0, 344.0, 356.0, 368.0]), 2549 ]: 2550 self.assertEqual(expected, quantiles(data, n=n)) 2551 self.assertEqual(len(quantiles(data, n=n)), n - 1) 2552 # Preserve datatype when possible 2553 for datatype in (float, Decimal, Fraction): 2554 result = quantiles(map(datatype, data), n=n) 2555 self.assertTrue(all(type(x) == datatype) for x in result) 2556 self.assertEqual(result, list(map(datatype, expected))) 2557 # Quantiles should be idempotent 2558 if len(expected) >= 2: 2559 self.assertEqual(quantiles(expected, n=n), expected) 2560 # Cross-check against method='inclusive' which should give 2561 # the same result after adding in minimum and maximum values 2562 # extrapolated from the two lowest and two highest points. 
2563 sdata = sorted(data) 2564 lo = 2 * sdata[0] - sdata[1] 2565 hi = 2 * sdata[-1] - sdata[-2] 2566 padded_data = data + [lo, hi] 2567 self.assertEqual( 2568 quantiles(data, n=n), 2569 quantiles(padded_data, n=n, method='inclusive'), 2570 (n, data), 2571 ) 2572 # Invariant under translation and scaling 2573 def f(x): 2574 return 3.5 * x - 1234.675 2575 exp = list(map(f, expected)) 2576 act = quantiles(map(f, data), n=n) 2577 self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act))) 2578 # Q2 agrees with median() 2579 for k in range(2, 60): 2580 data = random.choices(range(100), k=k) 2581 q1, q2, q3 = quantiles(data) 2582 self.assertEqual(q2, statistics.median(data)) 2583 2584 def test_specific_cases_inclusive(self): 2585 # Match results computed by hand and cross-checked 2586 # against the PERCENTILE.INC function in MS Excel 2587 # and against the quantile() function in SciPy. 2588 quantiles = statistics.quantiles 2589 data = [100, 200, 400, 800] 2590 random.shuffle(data) 2591 for n, expected in [ 2592 (1, []), 2593 (2, [300.0]), 2594 (3, [200.0, 400.0]), 2595 (4, [175.0, 300.0, 500.0]), 2596 (5, [160.0, 240.0, 360.0, 560.0]), 2597 (6, [150.0, 200.0, 300.0, 400.0, 600.0]), 2598 (8, [137.5, 175, 225.0, 300.0, 375.0, 500.0,650.0]), 2599 (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]), 2600 (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0, 2601 500.0, 600.0, 700.0]), 2602 (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0, 2603 400.0, 480.0, 560.0, 640.0, 720.0]), 2604 ]: 2605 self.assertEqual(expected, quantiles(data, n=n, method="inclusive")) 2606 self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1) 2607 # Preserve datatype when possible 2608 for datatype in (float, Decimal, Fraction): 2609 result = quantiles(map(datatype, data), n=n, method="inclusive") 2610 self.assertTrue(all(type(x) == datatype) for x in result) 2611 self.assertEqual(result, list(map(datatype, expected))) 2612 # 
Invariant under translation and scaling 2613 def f(x): 2614 return 3.5 * x - 1234.675 2615 exp = list(map(f, expected)) 2616 act = quantiles(map(f, data), n=n, method="inclusive") 2617 self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act))) 2618 # Natural deciles 2619 self.assertEqual(quantiles([0, 100], n=10, method='inclusive'), 2620 [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]) 2621 self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'), 2622 [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]) 2623 # Whenever n is smaller than the number of data points, running 2624 # method='inclusive' should give the same result as method='exclusive' 2625 # after the two included extreme points are removed. 2626 data = [random.randrange(10_000) for i in range(501)] 2627 actual = quantiles(data, n=32, method='inclusive') 2628 data.remove(min(data)) 2629 data.remove(max(data)) 2630 expected = quantiles(data, n=32) 2631 self.assertEqual(expected, actual) 2632 # Q2 agrees with median() 2633 for k in range(2, 60): 2634 data = random.choices(range(100), k=k) 2635 q1, q2, q3 = quantiles(data, method='inclusive') 2636 self.assertEqual(q2, statistics.median(data)) 2637 # Base case with a single data point: When estimating quantiles from 2638 # a sample, we want to be able to add one sample point at a time, 2639 # getting increasingly better estimates. 
# (Span resumes at the end of TestQuantiles.test_specific_cases_inclusive.)
        # NOTE(review): both calls below use the exclusive method (the first
        # by omitting `method`), although this is the inclusive test; confirm
        # whether one of them was meant to pass method='inclusive'.
        self.assertEqual(quantiles([10], n=4), [10.0, 10.0, 10.0])
        self.assertEqual(quantiles([10], n=4, method='exclusive'), [10.0, 10.0, 10.0])

    def test_equal_inputs(self):
        # With identical data points every cut point equals that value.
        quantiles = statistics.quantiles
        for n in range(2, 10):
            data = [10.0] * n
            self.assertEqual(quantiles(data), [10.0, 10.0, 10.0])
            self.assertEqual(quantiles(data, method='inclusive'),
                            [10.0, 10.0, 10.0])

    def test_equal_sized_groups(self):
        quantiles = statistics.quantiles
        total = 10_000
        data = [random.expovariate(0.2) for i in range(total)]
        # Keep drawing until there are `total` distinct values.
        # NOTE(review): duplicates are appended to, not replaced, so len(data)
        # could exceed `total` after this loop; confirm that is acceptable.
        while len(set(data)) != total:
            data.append(random.expovariate(0.2))
        data.sort()

        # Cases where the group size exactly divides the total
        for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000):
            group_size = total // n
            # bisect gives, for each cut point, how many data points are <= it.
            self.assertEqual(
                [bisect.bisect(data, q) for q in quantiles(data, n=n)],
                list(range(group_size, total, group_size)))

        # When the group sizes can't be exactly equal, they should
        # differ by no more than one
        for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769):
            group_sizes = {total // n, total // n + 1}
            pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)]
            # Sizes of the groups between consecutive cut points.
            sizes = {q - p for p, q in zip(pos, pos[1:])}
            self.assertTrue(sizes <= group_sizes)

    def test_error_cases(self):
        quantiles = statistics.quantiles
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(TypeError):
            quantiles()  # Missing arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 13, n=4)  # Too many arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 4)  # n is a positional argument
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=0)  # n is zero
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=-1)  # n is negative
        with self.assertRaises(TypeError):
quantiles([10, 20, 30], n=1.5) # n is not an integer 2689 with self.assertRaises(ValueError): 2690 quantiles([10, 20, 30], method='X') # method is unknown 2691 with self.assertRaises(StatisticsError): 2692 quantiles([], n=4) # not enough data points 2693 with self.assertRaises(TypeError): 2694 quantiles([10, None, 30], n=4) # data is non-numeric 2695 2696 2697class TestBivariateStatistics(unittest.TestCase): 2698 2699 def test_unequal_size_error(self): 2700 for x, y in [ 2701 ([1, 2, 3], [1, 2]), 2702 ([1, 2], [1, 2, 3]), 2703 ]: 2704 with self.assertRaises(statistics.StatisticsError): 2705 statistics.covariance(x, y) 2706 with self.assertRaises(statistics.StatisticsError): 2707 statistics.correlation(x, y) 2708 with self.assertRaises(statistics.StatisticsError): 2709 statistics.linear_regression(x, y) 2710 2711 def test_small_sample_error(self): 2712 for x, y in [ 2713 ([], []), 2714 ([], [1, 2,]), 2715 ([1, 2,], []), 2716 ([1,], [1,]), 2717 ([1,], [1, 2,]), 2718 ([1, 2,], [1,]), 2719 ]: 2720 with self.assertRaises(statistics.StatisticsError): 2721 statistics.covariance(x, y) 2722 with self.assertRaises(statistics.StatisticsError): 2723 statistics.correlation(x, y) 2724 with self.assertRaises(statistics.StatisticsError): 2725 statistics.linear_regression(x, y) 2726 2727 2728class TestCorrelationAndCovariance(unittest.TestCase): 2729 2730 def test_results(self): 2731 for x, y, result in [ 2732 ([1, 2, 3], [1, 2, 3], 1), 2733 ([1, 2, 3], [-1, -2, -3], -1), 2734 ([1, 2, 3], [3, 2, 1], -1), 2735 ([1, 2, 3], [1, 2, 1], 0), 2736 ([1, 2, 3], [1, 3, 2], 0.5), 2737 ]: 2738 self.assertAlmostEqual(statistics.correlation(x, y), result) 2739 self.assertAlmostEqual(statistics.covariance(x, y), result) 2740 2741 def test_different_scales(self): 2742 x = [1, 2, 3] 2743 y = [10, 30, 20] 2744 self.assertAlmostEqual(statistics.correlation(x, y), 0.5) 2745 self.assertAlmostEqual(statistics.covariance(x, y), 5) 2746 2747 y = [.1, .2, .3] 2748 
self.assertAlmostEqual(statistics.correlation(x, y), 1) 2749 self.assertAlmostEqual(statistics.covariance(x, y), 0.1) 2750 2751 def test_sqrtprod_helper_function_fundamentals(self): 2752 # Verify that results are close to sqrt(x * y) 2753 for i in range(100): 2754 x = random.expovariate() 2755 y = random.expovariate() 2756 expected = math.sqrt(x * y) 2757 actual = statistics._sqrtprod(x, y) 2758 with self.subTest(x=x, y=y, expected=expected, actual=actual): 2759 self.assertAlmostEqual(expected, actual) 2760 2761 x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661 2762 self.assertEqual(statistics._sqrtprod(x, y), target) 2763 self.assertNotEqual(math.sqrt(x * y), target) 2764 2765 # Test that range extremes avoid underflow and overflow 2766 smallest = sys.float_info.min * sys.float_info.epsilon 2767 self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest) 2768 biggest = sys.float_info.max 2769 self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest) 2770 2771 # Check special values and the sign of the result 2772 special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0, 2773 math.nan, -math.nan, math.inf, -math.inf] 2774 for x, y in itertools.product(special_values, repeat=2): 2775 try: 2776 expected = math.sqrt(x * y) 2777 except ValueError: 2778 expected = 'ValueError' 2779 try: 2780 actual = statistics._sqrtprod(x, y) 2781 except ValueError: 2782 actual = 'ValueError' 2783 with self.subTest(x=x, y=y, expected=expected, actual=actual): 2784 if isinstance(expected, str) and expected == 'ValueError': 2785 self.assertEqual(actual, 'ValueError') 2786 continue 2787 self.assertIsInstance(actual, float) 2788 if math.isnan(expected): 2789 self.assertTrue(math.isnan(actual)) 2790 continue 2791 self.assertEqual(actual, expected) 2792 self.assertEqual(sign(actual), sign(expected)) 2793 2794 @requires_IEEE_754 2795 @unittest.skipIf(HAVE_DOUBLE_ROUNDING, 2796 "accuracy not guaranteed on machines with double rounding") 2797 
@support.cpython_only # Allow for a weaker sumprod() implmentation 2798 def test_sqrtprod_helper_function_improved_accuracy(self): 2799 # Test a known example where accuracy is improved 2800 x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661 2801 self.assertEqual(statistics._sqrtprod(x, y), target) 2802 self.assertNotEqual(math.sqrt(x * y), target) 2803 2804 def reference_value(x: float, y: float) -> float: 2805 x = decimal.Decimal(x) 2806 y = decimal.Decimal(y) 2807 with decimal.localcontext() as ctx: 2808 ctx.prec = 200 2809 return float((x * y).sqrt()) 2810 2811 # Verify that the new function with improved accuracy 2812 # agrees with a reference value more often than old version. 2813 new_agreements = 0 2814 old_agreements = 0 2815 for i in range(10_000): 2816 x = random.expovariate() 2817 y = random.expovariate() 2818 new = statistics._sqrtprod(x, y) 2819 old = math.sqrt(x * y) 2820 ref = reference_value(x, y) 2821 new_agreements += (new == ref) 2822 old_agreements += (old == ref) 2823 self.assertGreater(new_agreements, old_agreements) 2824 2825 def test_correlation_spearman(self): 2826 # https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php 2827 # Compare with: 2828 # >>> import scipy.stats.mstats 2829 # >>> scipy.stats.mstats.spearmanr(reading, mathematics) 2830 # SpearmanrResult(correlation=0.6686960980480712, pvalue=0.03450954165178532) 2831 # And Wolfram Alpha gives: 0.668696 2832 # https://www.wolframalpha.com/input?i=SpearmanRho%5B%7B56%2C+75%2C+45%2C+71%2C+61%2C+64%2C+58%2C+80%2C+76%2C+61%7D%2C+%7B66%2C+70%2C+40%2C+60%2C+65%2C+56%2C+59%2C+77%2C+67%2C+63%7D%5D 2833 reading = [56, 75, 45, 71, 61, 64, 58, 80, 76, 61] 2834 mathematics = [66, 70, 40, 60, 65, 56, 59, 77, 67, 63] 2835 self.assertAlmostEqual(statistics.correlation(reading, mathematics, method='ranked'), 2836 0.6686960980480712) 2837 2838 with self.assertRaises(ValueError): 2839 statistics.correlation(reading, mathematics, 
method='bad_method') 2840 2841class TestLinearRegression(unittest.TestCase): 2842 2843 def test_constant_input_error(self): 2844 x = [1, 1, 1,] 2845 y = [1, 2, 3,] 2846 with self.assertRaises(statistics.StatisticsError): 2847 statistics.linear_regression(x, y) 2848 2849 def test_results(self): 2850 for x, y, true_intercept, true_slope in [ 2851 ([1, 2, 3], [0, 0, 0], 0, 0), 2852 ([1, 2, 3], [1, 2, 3], 0, 1), 2853 ([1, 2, 3], [100, 100, 100], 100, 0), 2854 ([1, 2, 3], [12, 14, 16], 10, 2), 2855 ([1, 2, 3], [-1, -2, -3], 0, -1), 2856 ([1, 2, 3], [21, 22, 23], 20, 1), 2857 ([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1), 2858 ]: 2859 slope, intercept = statistics.linear_regression(x, y) 2860 self.assertAlmostEqual(intercept, true_intercept) 2861 self.assertAlmostEqual(slope, true_slope) 2862 2863 def test_proportional(self): 2864 x = [10, 20, 30, 40] 2865 y = [180, 398, 610, 799] 2866 slope, intercept = statistics.linear_regression(x, y, proportional=True) 2867 self.assertAlmostEqual(slope, 20 + 1/150) 2868 self.assertEqual(intercept, 0.0) 2869 2870 def test_float_output(self): 2871 x = [Fraction(2, 3), Fraction(3, 4)] 2872 y = [Fraction(4, 5), Fraction(5, 6)] 2873 slope, intercept = statistics.linear_regression(x, y) 2874 self.assertTrue(isinstance(slope, float)) 2875 self.assertTrue(isinstance(intercept, float)) 2876 slope, intercept = statistics.linear_regression(x, y, proportional=True) 2877 self.assertTrue(isinstance(slope, float)) 2878 self.assertTrue(isinstance(intercept, float)) 2879 2880class TestNormalDist: 2881 2882 # General note on precision: The pdf(), cdf(), and overlap() methods 2883 # depend on functions in the math libraries that do not make 2884 # explicit accuracy guarantees. Accordingly, some of the accuracy 2885 # tests below may fail if the underlying math functions are 2886 # inaccurate. There isn't much we can do about this short of 2887 # implementing our own implementations from scratch. 
2888 2889 def test_slots(self): 2890 nd = self.module.NormalDist(300, 23) 2891 with self.assertRaises(TypeError): 2892 vars(nd) 2893 self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma')) 2894 2895 def test_instantiation_and_attributes(self): 2896 nd = self.module.NormalDist(500, 17) 2897 self.assertEqual(nd.mean, 500) 2898 self.assertEqual(nd.stdev, 17) 2899 self.assertEqual(nd.variance, 17**2) 2900 2901 # default arguments 2902 nd = self.module.NormalDist() 2903 self.assertEqual(nd.mean, 0) 2904 self.assertEqual(nd.stdev, 1) 2905 self.assertEqual(nd.variance, 1**2) 2906 2907 # error case: negative sigma 2908 with self.assertRaises(self.module.StatisticsError): 2909 self.module.NormalDist(500, -10) 2910 2911 # verify that subclass type is honored 2912 class NewNormalDist(self.module.NormalDist): 2913 pass 2914 nnd = NewNormalDist(200, 5) 2915 self.assertEqual(type(nnd), NewNormalDist) 2916 2917 def test_alternative_constructor(self): 2918 NormalDist = self.module.NormalDist 2919 data = [96, 107, 90, 92, 110] 2920 # list input 2921 self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9)) 2922 # tuple input 2923 self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9)) 2924 # iterator input 2925 self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9)) 2926 # error cases 2927 with self.assertRaises(self.module.StatisticsError): 2928 NormalDist.from_samples([]) # empty input 2929 with self.assertRaises(self.module.StatisticsError): 2930 NormalDist.from_samples([10]) # only one input 2931 2932 # verify that subclass type is honored 2933 class NewNormalDist(NormalDist): 2934 pass 2935 nnd = NewNormalDist.from_samples(data) 2936 self.assertEqual(type(nnd), NewNormalDist) 2937 2938 def test_sample_generation(self): 2939 NormalDist = self.module.NormalDist 2940 mu, sigma = 10_000, 3.0 2941 X = NormalDist(mu, sigma) 2942 n = 1_000 2943 data = X.samples(n) 2944 self.assertEqual(len(data), n) 2945 self.assertEqual(set(map(type, 
data)), {float}) 2946 # mean(data) expected to fall within 8 standard deviations 2947 xbar = self.module.mean(data) 2948 self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8) 2949 2950 # verify that seeding makes reproducible sequences 2951 n = 100 2952 data1 = X.samples(n, seed='happiness and joy') 2953 data2 = X.samples(n, seed='trouble and despair') 2954 data3 = X.samples(n, seed='happiness and joy') 2955 data4 = X.samples(n, seed='trouble and despair') 2956 self.assertEqual(data1, data3) 2957 self.assertEqual(data2, data4) 2958 self.assertNotEqual(data1, data2) 2959 2960 def test_pdf(self): 2961 NormalDist = self.module.NormalDist 2962 X = NormalDist(100, 15) 2963 # Verify peak around center 2964 self.assertLess(X.pdf(99), X.pdf(100)) 2965 self.assertLess(X.pdf(101), X.pdf(100)) 2966 # Test symmetry 2967 for i in range(50): 2968 self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i)) 2969 # Test vs CDF 2970 dx = 2.0 ** -10 2971 for x in range(90, 111): 2972 est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx 2973 self.assertAlmostEqual(X.pdf(x), est_pdf, places=4) 2974 # Test vs table of known values -- CRC 26th Edition 2975 Z = NormalDist() 2976 for x, px in enumerate([ 2977 0.3989, 0.3989, 0.3989, 0.3988, 0.3986, 2978 0.3984, 0.3982, 0.3980, 0.3977, 0.3973, 2979 0.3970, 0.3965, 0.3961, 0.3956, 0.3951, 2980 0.3945, 0.3939, 0.3932, 0.3925, 0.3918, 2981 0.3910, 0.3902, 0.3894, 0.3885, 0.3876, 2982 0.3867, 0.3857, 0.3847, 0.3836, 0.3825, 2983 0.3814, 0.3802, 0.3790, 0.3778, 0.3765, 2984 0.3752, 0.3739, 0.3725, 0.3712, 0.3697, 2985 0.3683, 0.3668, 0.3653, 0.3637, 0.3621, 2986 0.3605, 0.3589, 0.3572, 0.3555, 0.3538, 2987 ]): 2988 self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4) 2989 self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4) 2990 # Error case: variance is zero 2991 Y = NormalDist(100, 0) 2992 with self.assertRaises(self.module.StatisticsError): 2993 Y.pdf(90) 2994 # Special values 2995 self.assertEqual(X.pdf(float('-Inf')), 0.0) 2996 
self.assertEqual(X.pdf(float('Inf')), 0.0) 2997 self.assertTrue(math.isnan(X.pdf(float('NaN')))) 2998 2999 def test_cdf(self): 3000 NormalDist = self.module.NormalDist 3001 X = NormalDist(100, 15) 3002 cdfs = [X.cdf(x) for x in range(1, 200)] 3003 self.assertEqual(set(map(type, cdfs)), {float}) 3004 # Verify montonic 3005 self.assertEqual(cdfs, sorted(cdfs)) 3006 # Verify center (should be exact) 3007 self.assertEqual(X.cdf(100), 0.50) 3008 # Check against a table of known values 3009 # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative 3010 Z = NormalDist() 3011 for z, cum_prob in [ 3012 (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798), 3013 (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930), 3014 (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900), 3015 (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807), 3016 (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998), 3017 ]: 3018 self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5) 3019 self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5) 3020 # Error case: variance is zero 3021 Y = NormalDist(100, 0) 3022 with self.assertRaises(self.module.StatisticsError): 3023 Y.cdf(90) 3024 # Special values 3025 self.assertEqual(X.cdf(float('-Inf')), 0.0) 3026 self.assertEqual(X.cdf(float('Inf')), 1.0) 3027 self.assertTrue(math.isnan(X.cdf(float('NaN')))) 3028 3029 @support.skip_if_pgo_task 3030 @support.requires_resource('cpu') 3031 def test_inv_cdf(self): 3032 NormalDist = self.module.NormalDist 3033 3034 # Center case should be exact. 3035 iq = NormalDist(100, 15) 3036 self.assertEqual(iq.inv_cdf(0.50), iq.mean) 3037 3038 # Test versus a published table of known percentage points. 
3039 # See the second table at the bottom of the page here: 3040 # http://people.bath.ac.uk/masss/tables/normaltable.pdf 3041 Z = NormalDist() 3042 pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891, 3043 4.417, 4.892, 5.327, 5.731, 6.109), 3044 2.5: (0.674, 1.960, 2.807, 3.481, 4.056, 3045 4.565, 5.026, 5.451, 5.847, 6.219), 3046 1.0: (1.282, 2.326, 3.090, 3.719, 4.265, 3047 4.753, 5.199, 5.612, 5.998, 6.361)} 3048 for base, row in pp.items(): 3049 for exp, x in enumerate(row, start=1): 3050 p = base * 10.0 ** (-exp) 3051 self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3) 3052 p = 1.0 - p 3053 self.assertAlmostEqual(Z.inv_cdf(p), x, places=3) 3054 3055 # Match published example for MS Excel 3056 # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13 3057 self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002) 3058 3059 # One million equally spaced probabilities 3060 n = 2**20 3061 for p in range(1, n): 3062 p /= n 3063 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 3064 3065 # One hundred ever smaller probabilities to test tails out to 3066 # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50 3067 for e in range(1, 51): 3068 p = 2.0 ** (-e) 3069 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 3070 p = 1.0 - p 3071 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 3072 3073 # Now apply cdf() first. Near the tails, the round-trip loses 3074 # precision and is ill-conditioned (small changes in the inputs 3075 # give large changes in the output), so only check to 5 places. 
3076 for x in range(200): 3077 self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5) 3078 3079 # Error cases: 3080 with self.assertRaises(self.module.StatisticsError): 3081 iq.inv_cdf(0.0) # p is zero 3082 with self.assertRaises(self.module.StatisticsError): 3083 iq.inv_cdf(-0.1) # p under zero 3084 with self.assertRaises(self.module.StatisticsError): 3085 iq.inv_cdf(1.0) # p is one 3086 with self.assertRaises(self.module.StatisticsError): 3087 iq.inv_cdf(1.1) # p over one 3088 3089 # Supported case: 3090 iq = NormalDist(100, 0) # sigma is zero 3091 self.assertEqual(iq.inv_cdf(0.5), 100) 3092 3093 # Special values 3094 self.assertTrue(math.isnan(Z.inv_cdf(float('NaN')))) 3095 3096 def test_quantiles(self): 3097 # Quartiles of a standard normal distribution 3098 Z = self.module.NormalDist() 3099 for n, expected in [ 3100 (1, []), 3101 (2, [0.0]), 3102 (3, [-0.4307, 0.4307]), 3103 (4 ,[-0.6745, 0.0, 0.6745]), 3104 ]: 3105 actual = Z.quantiles(n=n) 3106 self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001) 3107 for e, a in zip(expected, actual))) 3108 3109 def test_overlap(self): 3110 NormalDist = self.module.NormalDist 3111 3112 # Match examples from Imman and Bradley 3113 for X1, X2, published_result in [ 3114 (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258), 3115 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993), 3116 ]: 3117 self.assertAlmostEqual(X1.overlap(X2), published_result, places=4) 3118 self.assertAlmostEqual(X2.overlap(X1), published_result, places=4) 3119 3120 # Check against integration of the PDF 3121 def overlap_numeric(X, Y, *, steps=8_192, z=5): 3122 'Numerical integration cross-check for overlap() ' 3123 fsum = math.fsum 3124 center = (X.mean + Y.mean) / 2.0 3125 width = z * max(X.stdev, Y.stdev) 3126 start = center - width 3127 dx = 2.0 * width / steps 3128 x_arr = [start + i*dx for i in range(steps)] 3129 xp = list(map(X.pdf, x_arr)) 3130 yp = list(map(Y.pdf, x_arr)) 3131 total = max(fsum(xp), fsum(yp)) 3132 return fsum(map(min, 
xp, yp)) / total 3133 3134 for X1, X2 in [ 3135 # Examples from Imman and Bradley 3136 (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)), 3137 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)), 3138 # Example from https://www.rasch.org/rmt/rmt101r.htm 3139 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)), 3140 # Gender heights from http://www.usablestats.com/lessons/normal 3141 (NormalDist(70, 4), NormalDist(65, 3.5)), 3142 # Misc cases with equal standard deviations 3143 (NormalDist(100, 15), NormalDist(110, 15)), 3144 (NormalDist(-100, 15), NormalDist(110, 15)), 3145 (NormalDist(-100, 15), NormalDist(-110, 15)), 3146 # Misc cases with unequal standard deviations 3147 (NormalDist(100, 12), NormalDist(100, 15)), 3148 (NormalDist(100, 12), NormalDist(110, 15)), 3149 (NormalDist(100, 12), NormalDist(150, 15)), 3150 (NormalDist(100, 12), NormalDist(150, 35)), 3151 # Misc cases with small values 3152 (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)), 3153 (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)), 3154 (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)), 3155 ]: 3156 self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5) 3157 self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5) 3158 3159 # Error cases 3160 X = NormalDist() 3161 with self.assertRaises(TypeError): 3162 X.overlap() # too few arguments 3163 with self.assertRaises(TypeError): 3164 X.overlap(X, X) # too may arguments 3165 with self.assertRaises(TypeError): 3166 X.overlap(None) # right operand not a NormalDist 3167 with self.assertRaises(self.module.StatisticsError): 3168 X.overlap(NormalDist(1, 0)) # right operand sigma is zero 3169 with self.assertRaises(self.module.StatisticsError): 3170 NormalDist(1, 0).overlap(X) # left operand sigma is zero 3171 3172 def test_zscore(self): 3173 NormalDist = self.module.NormalDist 3174 X = NormalDist(100, 15) 3175 self.assertEqual(X.zscore(142), 2.8) 3176 self.assertEqual(X.zscore(58), -2.8) 3177 self.assertEqual(X.zscore(100), 
0.0) 3178 with self.assertRaises(TypeError): 3179 X.zscore() # too few arguments 3180 with self.assertRaises(TypeError): 3181 X.zscore(1, 1) # too may arguments 3182 with self.assertRaises(TypeError): 3183 X.zscore(None) # non-numeric type 3184 with self.assertRaises(self.module.StatisticsError): 3185 NormalDist(1, 0).zscore(100) # sigma is zero 3186 3187 def test_properties(self): 3188 X = self.module.NormalDist(100, 15) 3189 self.assertEqual(X.mean, 100) 3190 self.assertEqual(X.median, 100) 3191 self.assertEqual(X.mode, 100) 3192 self.assertEqual(X.stdev, 15) 3193 self.assertEqual(X.variance, 225) 3194 3195 def test_same_type_addition_and_subtraction(self): 3196 NormalDist = self.module.NormalDist 3197 X = NormalDist(100, 12) 3198 Y = NormalDist(40, 5) 3199 self.assertEqual(X + Y, NormalDist(140, 13)) # __add__ 3200 self.assertEqual(X - Y, NormalDist(60, 13)) # __sub__ 3201 3202 def test_translation_and_scaling(self): 3203 NormalDist = self.module.NormalDist 3204 X = NormalDist(100, 15) 3205 y = 10 3206 self.assertEqual(+X, NormalDist(100, 15)) # __pos__ 3207 self.assertEqual(-X, NormalDist(-100, 15)) # __neg__ 3208 self.assertEqual(X + y, NormalDist(110, 15)) # __add__ 3209 self.assertEqual(y + X, NormalDist(110, 15)) # __radd__ 3210 self.assertEqual(X - y, NormalDist(90, 15)) # __sub__ 3211 self.assertEqual(y - X, NormalDist(-90, 15)) # __rsub__ 3212 self.assertEqual(X * y, NormalDist(1000, 150)) # __mul__ 3213 self.assertEqual(y * X, NormalDist(1000, 150)) # __rmul__ 3214 self.assertEqual(X / y, NormalDist(10, 1.5)) # __truediv__ 3215 with self.assertRaises(TypeError): # __rtruediv__ 3216 y / X 3217 3218 def test_unary_operations(self): 3219 NormalDist = self.module.NormalDist 3220 X = NormalDist(100, 12) 3221 Y = +X 3222 self.assertIsNot(X, Y) 3223 self.assertEqual(X.mean, Y.mean) 3224 self.assertEqual(X.stdev, Y.stdev) 3225 Y = -X 3226 self.assertIsNot(X, Y) 3227 self.assertEqual(X.mean, -Y.mean) 3228 self.assertEqual(X.stdev, Y.stdev) 3229 3230 def 
test_equality(self): 3231 NormalDist = self.module.NormalDist 3232 nd1 = NormalDist() 3233 nd2 = NormalDist(2, 4) 3234 nd3 = NormalDist() 3235 nd4 = NormalDist(2, 4) 3236 nd5 = NormalDist(2, 8) 3237 nd6 = NormalDist(8, 4) 3238 self.assertNotEqual(nd1, nd2) 3239 self.assertEqual(nd1, nd3) 3240 self.assertEqual(nd2, nd4) 3241 self.assertNotEqual(nd2, nd5) 3242 self.assertNotEqual(nd2, nd6) 3243 3244 # Test NotImplemented when types are different 3245 class A: 3246 def __eq__(self, other): 3247 return 10 3248 a = A() 3249 self.assertEqual(nd1.__eq__(a), NotImplemented) 3250 self.assertEqual(nd1 == a, 10) 3251 self.assertEqual(a == nd1, 10) 3252 3253 # All subclasses to compare equal giving the same behavior 3254 # as list, tuple, int, float, complex, str, dict, set, etc. 3255 class SizedNormalDist(NormalDist): 3256 def __init__(self, mu, sigma, n): 3257 super().__init__(mu, sigma) 3258 self.n = n 3259 s = SizedNormalDist(100, 15, 57) 3260 nd4 = NormalDist(100, 15) 3261 self.assertEqual(s, nd4) 3262 3263 # Don't allow duck type equality because we wouldn't 3264 # want a lognormal distribution to compare equal 3265 # to a normal distribution with the same parameters 3266 class LognormalDist: 3267 def __init__(self, mu, sigma): 3268 self.mu = mu 3269 self.sigma = sigma 3270 lnd = LognormalDist(100, 15) 3271 nd = NormalDist(100, 15) 3272 self.assertNotEqual(nd, lnd) 3273 3274 def test_copy(self): 3275 nd = self.module.NormalDist(37.5, 5.625) 3276 nd1 = copy.copy(nd) 3277 self.assertEqual(nd, nd1) 3278 nd2 = copy.deepcopy(nd) 3279 self.assertEqual(nd, nd2) 3280 3281 def test_pickle(self): 3282 nd = self.module.NormalDist(37.5, 5.625) 3283 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3284 with self.subTest(proto=proto): 3285 pickled = pickle.loads(pickle.dumps(nd, protocol=proto)) 3286 self.assertEqual(nd, pickled) 3287 3288 def test_hashability(self): 3289 ND = self.module.NormalDist 3290 s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)} 3291 
self.assertEqual(len(s), 3) 3292 3293 def test_repr(self): 3294 nd = self.module.NormalDist(37.5, 5.625) 3295 self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)') 3296 3297# Swapping the sys.modules['statistics'] is to solving the 3298# _pickle.PicklingError: 3299# Can't pickle <class 'statistics.NormalDist'>: 3300# it's not the same object as statistics.NormalDist 3301class TestNormalDistPython(unittest.TestCase, TestNormalDist): 3302 module = py_statistics 3303 def setUp(self): 3304 sys.modules['statistics'] = self.module 3305 3306 def tearDown(self): 3307 sys.modules['statistics'] = statistics 3308 3309 3310@unittest.skipUnless(c_statistics, 'requires _statistics') 3311class TestNormalDistC(unittest.TestCase, TestNormalDist): 3312 module = c_statistics 3313 def setUp(self): 3314 sys.modules['statistics'] = self.module 3315 3316 def tearDown(self): 3317 sys.modules['statistics'] = statistics 3318 3319 3320# === Run tests === 3321 3322def load_tests(loader, tests, ignore): 3323 """Used for doctest/unittest integration.""" 3324 tests.addTests(doctest.DocTestSuite()) 3325 tests.addTests(doctest.DocTestSuite(statistics)) 3326 return tests 3327 3328 3329if __name__ == "__main__": 3330 unittest.main() 3331