1"""Test suite for statistics module, including helper NumericTestCase and 2approx_equal function. 3 4""" 5 6import bisect 7import collections 8import collections.abc 9import copy 10import decimal 11import doctest 12import math 13import pickle 14import random 15import sys 16import unittest 17from test import support 18from test.support import import_helper 19 20from decimal import Decimal 21from fractions import Fraction 22 23 24# Module to be tested. 25import statistics 26 27 28# === Helper functions and class === 29 30def sign(x): 31 """Return -1.0 for negatives, including -0.0, otherwise +1.0.""" 32 return math.copysign(1, x) 33 34def _nan_equal(a, b): 35 """Return True if a and b are both the same kind of NAN. 36 37 >>> _nan_equal(Decimal('NAN'), Decimal('NAN')) 38 True 39 >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN')) 40 True 41 >>> _nan_equal(Decimal('NAN'), Decimal('sNAN')) 42 False 43 >>> _nan_equal(Decimal(42), Decimal('NAN')) 44 False 45 46 >>> _nan_equal(float('NAN'), float('NAN')) 47 True 48 >>> _nan_equal(float('NAN'), 0.5) 49 False 50 51 >>> _nan_equal(float('NAN'), Decimal('NAN')) 52 False 53 54 NAN payloads are not compared. 55 """ 56 if type(a) is not type(b): 57 return False 58 if isinstance(a, float): 59 return math.isnan(a) and math.isnan(b) 60 aexp = a.as_tuple()[2] 61 bexp = b.as_tuple()[2] 62 return (aexp == bexp) and (aexp in ('n', 'N')) # Both NAN or both sNAN. 63 64 65def _calc_errors(actual, expected): 66 """Return the absolute and relative errors between two numbers. 67 68 >>> _calc_errors(100, 75) 69 (25, 0.25) 70 >>> _calc_errors(100, 100) 71 (0, 0.0) 72 73 Returns the (absolute error, relative error) between the two arguments. 74 """ 75 base = max(abs(actual), abs(expected)) 76 abs_err = abs(actual - expected) 77 rel_err = abs_err/base if base else float('inf') 78 return (abs_err, rel_err) 79 80 81def approx_equal(x, y, tol=1e-12, rel=1e-7): 82 """approx_equal(x, y [, tol [, rel]]) => True|False 83 84 Return True if numbers x and y are approximately equal, to within some 85 margin of error, otherwise return False. Numbers which compare equal 86 will also compare approximately equal. 87 88 x is approximately equal to y if the difference between them is less than 89 an absolute error tol or a relative error rel, whichever is bigger. 90 91 If given, both tol and rel must be finite, non-negative numbers. If not 92 given, default values are tol=1e-12 and rel=1e-7. 93 94 >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0) 95 True 96 >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0) 97 False 98 99 Absolute error is defined as abs(x-y); if that is less than or equal to 100 tol, x and y are considered approximately equal. 101 102 Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is 103 smaller, provided x or y are not zero. If that figure is less than or 104 equal to rel, x and y are considered approximately equal. 105 106 Complex numbers are not directly supported. If you wish to compare to 107 complex numbers, extract their real and imaginary parts and compare them 108 individually. 109 110 NANs always compare unequal, even with themselves. Infinities compare 111 approximately equal if they have the same sign (both positive or both 112 negative). Infinities with different signs compare unequal; so do 113 comparisons of infinities with finite numbers. 114 """ 115 if tol < 0 or rel < 0: 116 raise ValueError('error tolerances must be non-negative') 117 # NANs are never equal to anything, approximately or otherwise. 118 if math.isnan(x) or math.isnan(y): 119 return False 120 # Numbers which compare equal also compare approximately equal. 121 if x == y: 122 # This includes the case of two infinities with the same sign. 123 return True 124 if math.isinf(x) or math.isinf(y): 125 # This includes the case of two infinities of opposite sign, or 126 # one infinity and one finite number. 127 return False 128 # Two finite numbers. 129 actual_error = abs(x - y) 130 allowed_error = max(tol, rel*max(abs(x), abs(y))) 131 return actual_error <= allowed_error 132 133 134# This class exists only as somewhere to stick a docstring containing 135# doctests. The following docstring and tests were originally in a separate 136# module. Now that it has been merged in here, I need somewhere to hang the. 137# docstring. Ultimately, this class will die, and the information below will 138# either become redundant, or be moved into more appropriate places. 139class _DoNothing: 140 """ 141 When doing numeric work, especially with floats, exact equality is often 142 not what you want. Due to round-off error, it is often a bad idea to try 143 to compare floats with equality. Instead the usual procedure is to test 144 them with some (hopefully small!) allowance for error. 145 146 The ``approx_equal`` function allows you to specify either an absolute 147 error tolerance, or a relative error, or both. 148 149 Absolute error tolerances are simple, but you need to know the magnitude 150 of the quantities being compared: 151 152 >>> approx_equal(12.345, 12.346, tol=1e-3) 153 True 154 >>> approx_equal(12.345e6, 12.346e6, tol=1e-3) # tol is too small. 155 False 156 157 Relative errors are more suitable when the values you are comparing can 158 vary in magnitude: 159 160 >>> approx_equal(12.345, 12.346, rel=1e-4) 161 True 162 >>> approx_equal(12.345e6, 12.346e6, rel=1e-4) 163 True 164 165 but a naive implementation of relative error testing can run into trouble 166 around zero. 167 168 If you supply both an absolute tolerance and a relative error, the 169 comparison succeeds if either individual test succeeds: 170 171 >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4) 172 True 173 174 """ 175 pass 176 177 178 179# We prefer this for testing numeric values that may not be exactly equal, 180# and avoid using TestCase.assertAlmostEqual, because it sucks :-) 181 182py_statistics = import_helper.import_fresh_module('statistics', 183 blocked=['_statistics']) 184c_statistics = import_helper.import_fresh_module('statistics', 185 fresh=['_statistics']) 186 187 188class TestModules(unittest.TestCase): 189 func_names = ['_normal_dist_inv_cdf'] 190 191 def test_py_functions(self): 192 for fname in self.func_names: 193 self.assertEqual(getattr(py_statistics, fname).__module__, 'statistics') 194 195 @unittest.skipUnless(c_statistics, 'requires _statistics') 196 def test_c_functions(self): 197 for fname in self.func_names: 198 self.assertEqual(getattr(c_statistics, fname).__module__, '_statistics') 199 200 201class NumericTestCase(unittest.TestCase): 202 """Unit test class for numeric work. 203 204 This subclasses TestCase. In addition to the standard method 205 ``TestCase.assertAlmostEqual``, ``assertApproxEqual`` is provided. 206 """ 207 # By default, we expect exact equality, unless overridden. 208 tol = rel = 0 209 210 def assertApproxEqual( 211 self, first, second, tol=None, rel=None, msg=None 212 ): 213 """Test passes if ``first`` and ``second`` are approximately equal. 214 215 This test passes if ``first`` and ``second`` are equal to 216 within ``tol``, an absolute error, or ``rel``, a relative error. 217 218 If either ``tol`` or ``rel`` are None or not given, they default to 219 test attributes of the same name (by default, 0). 220 221 The objects may be either numbers, or sequences of numbers. Sequences 222 are tested element-by-element. 223 224 >>> class MyTest(NumericTestCase): 225 ... def test_number(self): 226 ... x = 1.0/6 227 ... y = sum([x]*6) 228 ... self.assertApproxEqual(y, 1.0, tol=1e-15) 229 ... def test_sequence(self): 230 ... a = [1.001, 1.001e-10, 1.001e10] 231 ... b = [1.0, 1e-10, 1e10] 232 ... self.assertApproxEqual(a, b, rel=1e-3) 233 ... 234 >>> import unittest 235 >>> from io import StringIO # Suppress test runner output. 236 >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest) 237 >>> unittest.TextTestRunner(stream=StringIO()).run(suite) 238 <unittest.runner.TextTestResult run=2 errors=0 failures=0> 239 240 """ 241 if tol is None: 242 tol = self.tol 243 if rel is None: 244 rel = self.rel 245 if ( 246 isinstance(first, collections.abc.Sequence) and 247 isinstance(second, collections.abc.Sequence) 248 ): 249 check = self._check_approx_seq 250 else: 251 check = self._check_approx_num 252 check(first, second, tol, rel, msg) 253 254 def _check_approx_seq(self, first, second, tol, rel, msg): 255 if len(first) != len(second): 256 standardMsg = ( 257 "sequences differ in length: %d items != %d items" 258 % (len(first), len(second)) 259 ) 260 msg = self._formatMessage(msg, standardMsg) 261 raise self.failureException(msg) 262 for i, (a,e) in enumerate(zip(first, second)): 263 self._check_approx_num(a, e, tol, rel, msg, i) 264 265 def _check_approx_num(self, first, second, tol, rel, msg, idx=None): 266 if approx_equal(first, second, tol, rel): 267 # Test passes. Return early, we are done. 268 return None 269 # Otherwise we failed. 270 standardMsg = self._make_std_err_msg(first, second, tol, rel, idx) 271 msg = self._formatMessage(msg, standardMsg) 272 raise self.failureException(msg) 273 274 @staticmethod 275 def _make_std_err_msg(first, second, tol, rel, idx): 276 # Create the standard error message for approx_equal failures. 277 assert first != second 278 template = ( 279 ' %r != %r\n' 280 ' values differ by more than tol=%r and rel=%r\n' 281 ' -> absolute error = %r\n' 282 ' -> relative error = %r' 283 ) 284 if idx is not None: 285 header = 'numeric sequences first differ at index %d.\n' % idx 286 template = header + template 287 # Calculate actual errors: 288 abs_err, rel_err = _calc_errors(first, second) 289 return template % (first, second, tol, rel, abs_err, rel_err) 290 291 292# ======================== 293# === Test the helpers === 294# ======================== 295 296class TestSign(unittest.TestCase): 297 """Test that the helper function sign() works correctly.""" 298 def testZeroes(self): 299 # Test that signed zeroes report their sign correctly. 300 self.assertEqual(sign(0.0), +1) 301 self.assertEqual(sign(-0.0), -1) 302 303 304# --- Tests for approx_equal --- 305 306class ApproxEqualSymmetryTest(unittest.TestCase): 307 # Test symmetry of approx_equal. 308 309 def test_relative_symmetry(self): 310 # Check that approx_equal treats relative error symmetrically. 311 # (a-b)/a is usually not equal to (a-b)/b. Ensure that this 312 # doesn't matter. 313 # 314 # Note: the reason for this test is that an early version 315 # of approx_equal was not symmetric. A relative error test 316 # would pass, or fail, depending on which value was passed 317 # as the first argument. 318 # 319 args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)] 320 args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)] 321 assert len(args1) == len(args2) 322 for a, b in zip(args1, args2): 323 self.do_relative_symmetry(a, b) 324 325 def do_relative_symmetry(self, a, b): 326 a, b = min(a, b), max(a, b) 327 assert a < b 328 delta = b - a # The absolute difference between the values. 329 rel_err1, rel_err2 = abs(delta/a), abs(delta/b) 330 # Choose an error margin halfway between the two. 331 rel = (rel_err1 + rel_err2)/2 332 # Now see that values a and b compare approx equal regardless of 333 # which is given first. 334 self.assertTrue(approx_equal(a, b, tol=0, rel=rel)) 335 self.assertTrue(approx_equal(b, a, tol=0, rel=rel)) 336 337 def test_symmetry(self): 338 # Test that approx_equal(a, b) == approx_equal(b, a) 339 args = [-23, -2, 5, 107, 93568] 340 delta = 2 341 for a in args: 342 for type_ in (int, float, Decimal, Fraction): 343 x = type_(a)*100 344 y = x + delta 345 r = abs(delta/max(x, y)) 346 # There are five cases to check: 347 # 1) actual error <= tol, <= rel 348 self.do_symmetry_test(x, y, tol=delta, rel=r) 349 self.do_symmetry_test(x, y, tol=delta+1, rel=2*r) 350 # 2) actual error > tol, > rel 351 self.do_symmetry_test(x, y, tol=delta-1, rel=r/2) 352 # 3) actual error <= tol, > rel 353 self.do_symmetry_test(x, y, tol=delta, rel=r/2) 354 # 4) actual error > tol, <= rel 355 self.do_symmetry_test(x, y, tol=delta-1, rel=r) 356 self.do_symmetry_test(x, y, tol=delta-1, rel=2*r) 357 # 5) exact equality test 358 self.do_symmetry_test(x, x, tol=0, rel=0) 359 self.do_symmetry_test(x, y, tol=0, rel=0) 360 361 def do_symmetry_test(self, a, b, tol, rel): 362 template = "approx_equal comparisons don't match for %r" 363 flag1 = approx_equal(a, b, tol, rel) 364 flag2 = approx_equal(b, a, tol, rel) 365 self.assertEqual(flag1, flag2, template.format((a, b, tol, rel))) 366 367 368class ApproxEqualExactTest(unittest.TestCase): 369 # Test the approx_equal function with exactly equal values. 370 # Equal values should compare as approximately equal. 371 # Test cases for exactly equal values, which should compare approx 372 # equal regardless of the error tolerances given. 373 374 def do_exactly_equal_test(self, x, tol, rel): 375 result = approx_equal(x, x, tol=tol, rel=rel) 376 self.assertTrue(result, 'equality failure for x=%r' % x) 377 result = approx_equal(-x, -x, tol=tol, rel=rel) 378 self.assertTrue(result, 'equality failure for x=%r' % -x) 379 380 def test_exactly_equal_ints(self): 381 # Test that equal int values are exactly equal. 382 for n in [42, 19740, 14974, 230, 1795, 700245, 36587]: 383 self.do_exactly_equal_test(n, 0, 0) 384 385 def test_exactly_equal_floats(self): 386 # Test that equal float values are exactly equal. 387 for x in [0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587]: 388 self.do_exactly_equal_test(x, 0, 0) 389 390 def test_exactly_equal_fractions(self): 391 # Test that equal Fraction values are exactly equal. 392 F = Fraction 393 for f in [F(1, 2), F(0), F(5, 3), F(9, 7), F(35, 36), F(3, 7)]: 394 self.do_exactly_equal_test(f, 0, 0) 395 396 def test_exactly_equal_decimals(self): 397 # Test that equal Decimal values are exactly equal. 398 D = Decimal 399 for d in map(D, "8.2 31.274 912.04 16.745 1.2047".split()): 400 self.do_exactly_equal_test(d, 0, 0) 401 402 def test_exactly_equal_absolute(self): 403 # Test that equal values are exactly equal with an absolute error. 404 for n in [16, 1013, 1372, 1198, 971, 4]: 405 # Test as ints. 406 self.do_exactly_equal_test(n, 0.01, 0) 407 # Test as floats. 408 self.do_exactly_equal_test(n/10, 0.01, 0) 409 # Test as Fractions. 410 f = Fraction(n, 1234) 411 self.do_exactly_equal_test(f, 0.01, 0) 412 413 def test_exactly_equal_absolute_decimals(self): 414 # Test equal Decimal values are exactly equal with an absolute error. 415 self.do_exactly_equal_test(Decimal("3.571"), Decimal("0.01"), 0) 416 self.do_exactly_equal_test(-Decimal("81.3971"), Decimal("0.01"), 0) 417 418 def test_exactly_equal_relative(self): 419 # Test that equal values are exactly equal with a relative error. 420 for x in [8347, 101.3, -7910.28, Fraction(5, 21)]: 421 self.do_exactly_equal_test(x, 0, 0.01) 422 self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01")) 423 424 def test_exactly_equal_both(self): 425 # Test that equal values are equal when both tol and rel are given. 426 for x in [41017, 16.742, -813.02, Fraction(3, 8)]: 427 self.do_exactly_equal_test(x, 0.1, 0.01) 428 D = Decimal 429 self.do_exactly_equal_test(D("7.2"), D("0.1"), D("0.01")) 430 431 432class ApproxEqualUnequalTest(unittest.TestCase): 433 # Unequal values should compare unequal with zero error tolerances. 434 # Test cases for unequal values, with exact equality test. 435 436 def do_exactly_unequal_test(self, x): 437 for a in (x, -x): 438 result = approx_equal(a, a+1, tol=0, rel=0) 439 self.assertFalse(result, 'inequality failure for x=%r' % a) 440 441 def test_exactly_unequal_ints(self): 442 # Test unequal int values are unequal with zero error tolerance. 443 for n in [951, 572305, 478, 917, 17240]: 444 self.do_exactly_unequal_test(n) 445 446 def test_exactly_unequal_floats(self): 447 # Test unequal float values are unequal with zero error tolerance. 448 for x in [9.51, 5723.05, 47.8, 9.17, 17.24]: 449 self.do_exactly_unequal_test(x) 450 451 def test_exactly_unequal_fractions(self): 452 # Test that unequal Fractions are unequal with zero error tolerance. 453 F = Fraction 454 for f in [F(1, 5), F(7, 9), F(12, 11), F(101, 99023)]: 455 self.do_exactly_unequal_test(f) 456 457 def test_exactly_unequal_decimals(self): 458 # Test that unequal Decimals are unequal with zero error tolerance. 459 for d in map(Decimal, "3.1415 298.12 3.47 18.996 0.00245".split()): 460 self.do_exactly_unequal_test(d) 461 462 463class ApproxEqualInexactTest(unittest.TestCase): 464 # Inexact test cases for approx_error. 465 # Test cases when comparing two values that are not exactly equal. 466 467 # === Absolute error tests === 468 469 def do_approx_equal_abs_test(self, x, delta): 470 template = "Test failure for x={!r}, y={!r}" 471 for y in (x + delta, x - delta): 472 msg = template.format(x, y) 473 self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg) 474 self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg) 475 476 def test_approx_equal_absolute_ints(self): 477 # Test approximate equality of ints with an absolute error. 478 for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]: 479 self.do_approx_equal_abs_test(n, 10) 480 self.do_approx_equal_abs_test(n, 2) 481 482 def test_approx_equal_absolute_floats(self): 483 # Test approximate equality of floats with an absolute error. 484 for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]: 485 self.do_approx_equal_abs_test(x, 1.5) 486 self.do_approx_equal_abs_test(x, 0.01) 487 self.do_approx_equal_abs_test(x, 0.0001) 488 489 def test_approx_equal_absolute_fractions(self): 490 # Test approximate equality of Fractions with an absolute error. 491 delta = Fraction(1, 29) 492 numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71] 493 for f in (Fraction(n, 29) for n in numerators): 494 self.do_approx_equal_abs_test(f, delta) 495 self.do_approx_equal_abs_test(f, float(delta)) 496 497 def test_approx_equal_absolute_decimals(self): 498 # Test approximate equality of Decimals with an absolute error. 499 delta = Decimal("0.01") 500 for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()): 501 self.do_approx_equal_abs_test(d, delta) 502 self.do_approx_equal_abs_test(-d, delta) 503 504 def test_cross_zero(self): 505 # Test for the case of the two values having opposite signs. 506 self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0)) 507 508 # === Relative error tests === 509 510 def do_approx_equal_rel_test(self, x, delta): 511 template = "Test failure for x={!r}, y={!r}" 512 for y in (x*(1+delta), x*(1-delta)): 513 msg = template.format(x, y) 514 self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg) 515 self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg) 516 517 def test_approx_equal_relative_ints(self): 518 # Test approximate equality of ints with a relative error. 519 self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36)) 520 self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37)) 521 # --- 522 self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125)) 523 self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125)) 524 self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125)) 525 526 def test_approx_equal_relative_floats(self): 527 # Test approximate equality of floats with a relative error. 528 for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]: 529 self.do_approx_equal_rel_test(x, 0.02) 530 self.do_approx_equal_rel_test(x, 0.0001) 531 532 def test_approx_equal_relative_fractions(self): 533 # Test approximate equality of Fractions with a relative error. 534 F = Fraction 535 delta = Fraction(3, 8) 536 for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]: 537 for d in (delta, float(delta)): 538 self.do_approx_equal_rel_test(f, d) 539 self.do_approx_equal_rel_test(-f, d) 540 541 def test_approx_equal_relative_decimals(self): 542 # Test approximate equality of Decimals with a relative error. 543 for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()): 544 self.do_approx_equal_rel_test(d, Decimal("0.001")) 545 self.do_approx_equal_rel_test(-d, Decimal("0.05")) 546 547 # === Both absolute and relative error tests === 548 549 # There are four cases to consider: 550 # 1) actual error <= both absolute and relative error 551 # 2) actual error <= absolute error but > relative error 552 # 3) actual error <= relative error but > absolute error 553 # 4) actual error > both absolute and relative error 554 555 def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag): 556 check = self.assertTrue if tol_flag else self.assertFalse 557 check(approx_equal(a, b, tol=tol, rel=0)) 558 check = self.assertTrue if rel_flag else self.assertFalse 559 check(approx_equal(a, b, tol=0, rel=rel)) 560 check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse 561 check(approx_equal(a, b, tol=tol, rel=rel)) 562 563 def test_approx_equal_both1(self): 564 # Test actual error <= both absolute and relative error. 565 self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True) 566 self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True) 567 568 def test_approx_equal_both2(self): 569 # Test actual error <= absolute error but > relative error. 570 self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False) 571 572 def test_approx_equal_both3(self): 573 # Test actual error <= relative error but > absolute error. 574 self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True) 575 576 def test_approx_equal_both4(self): 577 # Test actual error > both absolute and relative error. 578 self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False) 579 self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False) 580 581 582class ApproxEqualSpecialsTest(unittest.TestCase): 583 # Test approx_equal with NANs and INFs and zeroes. 584 585 def test_inf(self): 586 for type_ in (float, Decimal): 587 inf = type_('inf') 588 self.assertTrue(approx_equal(inf, inf)) 589 self.assertTrue(approx_equal(inf, inf, 0, 0)) 590 self.assertTrue(approx_equal(inf, inf, 1, 0.01)) 591 self.assertTrue(approx_equal(-inf, -inf)) 592 self.assertFalse(approx_equal(inf, -inf)) 593 self.assertFalse(approx_equal(inf, 1000)) 594 595 def test_nan(self): 596 for type_ in (float, Decimal): 597 nan = type_('nan') 598 for other in (nan, type_('inf'), 1000): 599 self.assertFalse(approx_equal(nan, other)) 600 601 def test_float_zeroes(self): 602 nzero = math.copysign(0.0, -1) 603 self.assertTrue(approx_equal(nzero, 0.0, tol=0.1, rel=0.1)) 604 605 def test_decimal_zeroes(self): 606 nzero = Decimal("-0.0") 607 self.assertTrue(approx_equal(nzero, Decimal(0), tol=0.1, rel=0.1)) 608 609 610class TestApproxEqualErrors(unittest.TestCase): 611 # Test error conditions of approx_equal. 612 613 def test_bad_tol(self): 614 # Test negative tol raises. 615 self.assertRaises(ValueError, approx_equal, 100, 100, -1, 0.1) 616 617 def test_bad_rel(self): 618 # Test negative rel raises. 619 self.assertRaises(ValueError, approx_equal, 100, 100, 1, -0.1) 620 621 622# --- Tests for NumericTestCase --- 623 624# The formatting routine that generates the error messages is complex enough 625# that it too needs testing. 626 627class TestNumericTestCase(unittest.TestCase): 628 # The exact wording of NumericTestCase error messages is *not* guaranteed, 629 # but we need to give them some sort of test to ensure that they are 630 # generated correctly. As a compromise, we look for specific substrings 631 # that are expected to be found even if the overall error message changes. 632 633 def do_test(self, args): 634 actual_msg = NumericTestCase._make_std_err_msg(*args) 635 expected = self.generate_substrings(*args) 636 for substring in expected: 637 self.assertIn(substring, actual_msg) 638 639 def test_numerictestcase_is_testcase(self): 640 # Ensure that NumericTestCase actually is a TestCase. 641 self.assertTrue(issubclass(NumericTestCase, unittest.TestCase)) 642 643 def test_error_msg_numeric(self): 644 # Test the error message generated for numeric comparisons. 645 args = (2.5, 4.0, 0.5, 0.25, None) 646 self.do_test(args) 647 648 def test_error_msg_sequence(self): 649 # Test the error message generated for sequence comparisons. 650 args = (3.75, 8.25, 1.25, 0.5, 7) 651 self.do_test(args) 652 653 def generate_substrings(self, first, second, tol, rel, idx): 654 """Return substrings we expect to see in error messages.""" 655 abs_err, rel_err = _calc_errors(first, second) 656 substrings = [ 657 'tol=%r' % tol, 658 'rel=%r' % rel, 659 'absolute error = %r' % abs_err, 660 'relative error = %r' % rel_err, 661 ] 662 if idx is not None: 663 substrings.append('differ at index %d' % idx) 664 return substrings 665 666 667# ======================================= 668# === Tests for the statistics module === 669# ======================================= 670 671 672class GlobalsTest(unittest.TestCase): 673 module = statistics 674 expected_metadata = ["__doc__", "__all__"] 675 676 def test_meta(self): 677 # Test for the existence of metadata. 678 for meta in self.expected_metadata: 679 self.assertTrue(hasattr(self.module, meta), 680 "%s not present" % meta) 681 682 def test_check_all(self): 683 # Check everything in __all__ exists and is public. 684 module = self.module 685 for name in module.__all__: 686 # No private names in __all__: 687 self.assertFalse(name.startswith("_"), 688 'private name "%s" in __all__' % name) 689 # And anything in __all__ must exist: 690 self.assertTrue(hasattr(module, name), 691 'missing name "%s" in __all__' % name) 692 693 694class DocTests(unittest.TestCase): 695 @unittest.skipIf(sys.flags.optimize >= 2, 696 "Docstrings are omitted with -OO and above") 697 def test_doc_tests(self): 698 failed, tried = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS) 699 self.assertGreater(tried, 0) 700 self.assertEqual(failed, 0) 701 702class StatisticsErrorTest(unittest.TestCase): 703 def test_has_exception(self): 704 errmsg = ( 705 "Expected StatisticsError to be a ValueError, but got a" 706 " subclass of %r instead." 707 ) 708 self.assertTrue(hasattr(statistics, 'StatisticsError')) 709 self.assertTrue( 710 issubclass(statistics.StatisticsError, ValueError), 711 errmsg % statistics.StatisticsError.__base__ 712 ) 713 714 715# === Tests for private utility functions === 716 717class ExactRatioTest(unittest.TestCase): 718 # Test _exact_ratio utility. 719 720 def test_int(self): 721 for i in (-20, -3, 0, 5, 99, 10**20): 722 self.assertEqual(statistics._exact_ratio(i), (i, 1)) 723 724 def test_fraction(self): 725 numerators = (-5, 1, 12, 38) 726 for n in numerators: 727 f = Fraction(n, 37) 728 self.assertEqual(statistics._exact_ratio(f), (n, 37)) 729 730 def test_float(self): 731 self.assertEqual(statistics._exact_ratio(0.125), (1, 8)) 732 self.assertEqual(statistics._exact_ratio(1.125), (9, 8)) 733 data = [random.uniform(-100, 100) for _ in range(100)] 734 for x in data: 735 num, den = statistics._exact_ratio(x) 736 self.assertEqual(x, num/den) 737 738 def test_decimal(self): 739 D = Decimal 740 _exact_ratio = statistics._exact_ratio 741 self.assertEqual(_exact_ratio(D("0.125")), (1, 8)) 742 self.assertEqual(_exact_ratio(D("12.345")), (2469, 200)) 743 self.assertEqual(_exact_ratio(D("-1.98")), (-99, 50)) 744 745 def test_inf(self): 746 INF = float("INF") 747 class MyFloat(float): 748 pass 749 class MyDecimal(Decimal): 750 pass 751 for inf in (INF, -INF): 752 for type_ in (float, MyFloat, Decimal, MyDecimal): 753 x = type_(inf) 754 ratio = statistics._exact_ratio(x) 755 self.assertEqual(ratio, (x, None)) 756 self.assertEqual(type(ratio[0]), type_) 757 self.assertTrue(math.isinf(ratio[0])) 758 759 def test_float_nan(self): 760 NAN = float("NAN") 761 class MyFloat(float): 762 pass 763 for nan in (NAN, MyFloat(NAN)): 764 ratio = statistics._exact_ratio(nan) 765 self.assertTrue(math.isnan(ratio[0])) 766 self.assertIs(ratio[1], None) 767 self.assertEqual(type(ratio[0]), type(nan)) 768 769 def test_decimal_nan(self): 770 NAN = Decimal("NAN") 771 sNAN = Decimal("sNAN") 772 class MyDecimal(Decimal): 773 pass 774 for nan in (NAN, MyDecimal(NAN), sNAN, MyDecimal(sNAN)): 775 ratio = statistics._exact_ratio(nan) 776 self.assertTrue(_nan_equal(ratio[0], nan)) 777 self.assertIs(ratio[1], None) 778 self.assertEqual(type(ratio[0]), type(nan)) 779 780 781class DecimalToRatioTest(unittest.TestCase): 782 # Test _exact_ratio private function. 783 784 def test_infinity(self): 785 # Test that INFs are handled correctly. 786 inf = Decimal('INF') 787 self.assertEqual(statistics._exact_ratio(inf), (inf, None)) 788 self.assertEqual(statistics._exact_ratio(-inf), (-inf, None)) 789 790 def test_nan(self): 791 # Test that NANs are handled correctly. 792 for nan in (Decimal('NAN'), Decimal('sNAN')): 793 num, den = statistics._exact_ratio(nan) 794 # Because NANs always compare non-equal, we cannot use assertEqual. 795 # Nor can we use an identity test, as we don't guarantee anything 796 # about the object identity. 797 self.assertTrue(_nan_equal(num, nan)) 798 self.assertIs(den, None) 799 800 def test_sign(self): 801 # Test sign is calculated correctly. 802 numbers = [Decimal("9.8765e12"), Decimal("9.8765e-12")] 803 for d in numbers: 804 # First test positive decimals. 805 assert d > 0 806 num, den = statistics._exact_ratio(d) 807 self.assertGreaterEqual(num, 0) 808 self.assertGreater(den, 0) 809 # Then test negative decimals. 810 num, den = statistics._exact_ratio(-d) 811 self.assertLessEqual(num, 0) 812 self.assertGreater(den, 0) 813 814 def test_negative_exponent(self): 815 # Test result when the exponent is negative. 816 t = statistics._exact_ratio(Decimal("0.1234")) 817 self.assertEqual(t, (617, 5000)) 818 819 def test_positive_exponent(self): 820 # Test results when the exponent is positive. 821 t = statistics._exact_ratio(Decimal("1.234e7")) 822 self.assertEqual(t, (12340000, 1)) 823 824 def test_regression_20536(self): 825 # Regression test for issue 20536. 826 # See http://bugs.python.org/issue20536 827 t = statistics._exact_ratio(Decimal("1e2")) 828 self.assertEqual(t, (100, 1)) 829 t = statistics._exact_ratio(Decimal("1.47e5")) 830 self.assertEqual(t, (147000, 1)) 831 832 833class IsFiniteTest(unittest.TestCase): 834 # Test _isfinite private function. 835 836 def test_finite(self): 837 # Test that finite numbers are recognised as finite. 838 for x in (5, Fraction(1, 3), 2.5, Decimal("5.5")): 839 self.assertTrue(statistics._isfinite(x)) 840 841 def test_infinity(self): 842 # Test that INFs are not recognised as finite. 843 for x in (float("inf"), Decimal("inf")): 844 self.assertFalse(statistics._isfinite(x)) 845 846 def test_nan(self): 847 # Test that NANs are not recognised as finite. 848 for x in (float("nan"), Decimal("NAN"), Decimal("sNAN")): 849 self.assertFalse(statistics._isfinite(x)) 850 851 852class CoerceTest(unittest.TestCase): 853 # Test that private function _coerce correctly deals with types. 854 855 # The coercion rules are currently an implementation detail, although at 856 # some point that should change. The tests and comments here define the 857 # correct implementation. 858 859 # Pre-conditions of _coerce: 860 # 861 # - The first time _sum calls _coerce, the 862 # - coerce(T, S) will never be called with bool as the first argument; 863 # this is a pre-condition, guarded with an assertion. 864 865 # 866 # - coerce(T, T) will always return T; we assume T is a valid numeric 867 # type. Violate this assumption at your own risk. 868 # 869 # - Apart from as above, bool is treated as if it were actually int. 870 # 871 # - coerce(int, X) and coerce(X, int) return X. 872 # - 873 def test_bool(self): 874 # bool is somewhat special, due to the pre-condition that it is 875 # never given as the first argument to _coerce, and that it cannot 876 # be subclassed. So we test it specially. 877 for T in (int, float, Fraction, Decimal): 878 self.assertIs(statistics._coerce(T, bool), T) 879 class MyClass(T): pass 880 self.assertIs(statistics._coerce(MyClass, bool), MyClass) 881 882 def assertCoerceTo(self, A, B): 883 """Assert that type A coerces to B.""" 884 self.assertIs(statistics._coerce(A, B), B) 885 self.assertIs(statistics._coerce(B, A), B) 886 887 def check_coerce_to(self, A, B): 888 """Checks that type A coerces to B, including subclasses.""" 889 # Assert that type A is coerced to B. 890 self.assertCoerceTo(A, B) 891 # Subclasses of A are also coerced to B. 892 class SubclassOfA(A): pass 893 self.assertCoerceTo(SubclassOfA, B) 894 # A, and subclasses of A, are coerced to subclasses of B. 895 class SubclassOfB(B): pass 896 self.assertCoerceTo(A, SubclassOfB) 897 self.assertCoerceTo(SubclassOfA, SubclassOfB) 898 899 def assertCoerceRaises(self, A, B): 900 """Assert that coercing A to B, or vice versa, raises TypeError.""" 901 self.assertRaises(TypeError, statistics._coerce, (A, B)) 902 self.assertRaises(TypeError, statistics._coerce, (B, A)) 903 904 def check_type_coercions(self, T): 905 """Check that type T coerces correctly with subclasses of itself.""" 906 assert T is not bool 907 # Coercing a type with itself returns the same type. 908 self.assertIs(statistics._coerce(T, T), T) 909 # Coercing a type with a subclass of itself returns the subclass. 910 class U(T): pass 911 class V(T): pass 912 class W(U): pass 913 for typ in (U, V, W): 914 self.assertCoerceTo(T, typ) 915 self.assertCoerceTo(U, W) 916 # Coercing two subclasses that aren't parent/child is an error. 917 self.assertCoerceRaises(U, V) 918 self.assertCoerceRaises(V, W) 919 920 def test_int(self): 921 # Check that int coerces correctly. 922 self.check_type_coercions(int) 923 for typ in (float, Fraction, Decimal): 924 self.check_coerce_to(int, typ) 925 926 def test_fraction(self): 927 # Check that Fraction coerces correctly. 928 self.check_type_coercions(Fraction) 929 self.check_coerce_to(Fraction, float) 930 931 def test_decimal(self): 932 # Check that Decimal coerces correctly. 933 self.check_type_coercions(Decimal) 934 935 def test_float(self): 936 # Check that float coerces correctly. 937 self.check_type_coercions(float) 938 939 def test_non_numeric_types(self): 940 for bad_type in (str, list, type(None), tuple, dict): 941 for good_type in (int, float, Fraction, Decimal): 942 self.assertCoerceRaises(good_type, bad_type) 943 944 def test_incompatible_types(self): 945 # Test that incompatible types raise. 946 for T in (float, Fraction): 947 class MySubclass(T): pass 948 self.assertCoerceRaises(T, Decimal) 949 self.assertCoerceRaises(MySubclass, Decimal) 950 951 952class ConvertTest(unittest.TestCase): 953 # Test private _convert function. 954 955 def check_exact_equal(self, x, y): 956 """Check that x equals y, and has the same type as well.""" 957 self.assertEqual(x, y) 958 self.assertIs(type(x), type(y)) 959 960 def test_int(self): 961 # Test conversions to int. 962 x = statistics._convert(Fraction(71), int) 963 self.check_exact_equal(x, 71) 964 class MyInt(int): pass 965 x = statistics._convert(Fraction(17), MyInt) 966 self.check_exact_equal(x, MyInt(17)) 967 968 def test_fraction(self): 969 # Test conversions to Fraction. 970 x = statistics._convert(Fraction(95, 99), Fraction) 971 self.check_exact_equal(x, Fraction(95, 99)) 972 class MyFraction(Fraction): 973 def __truediv__(self, other): 974 return self.__class__(super().__truediv__(other)) 975 x = statistics._convert(Fraction(71, 13), MyFraction) 976 self.check_exact_equal(x, MyFraction(71, 13)) 977 978 def test_float(self): 979 # Test conversions to float. 980 x = statistics._convert(Fraction(-1, 2), float) 981 self.check_exact_equal(x, -0.5) 982 class MyFloat(float): 983 def __truediv__(self, other): 984 return self.__class__(super().__truediv__(other)) 985 x = statistics._convert(Fraction(9, 8), MyFloat) 986 self.check_exact_equal(x, MyFloat(1.125)) 987 988 def test_decimal(self): 989 # Test conversions to Decimal. 990 x = statistics._convert(Fraction(1, 40), Decimal) 991 self.check_exact_equal(x, Decimal("0.025")) 992 class MyDecimal(Decimal): 993 def __truediv__(self, other): 994 return self.__class__(super().__truediv__(other)) 995 x = statistics._convert(Fraction(-15, 16), MyDecimal) 996 self.check_exact_equal(x, MyDecimal("-0.9375")) 997 998 def test_inf(self): 999 for INF in (float('inf'), Decimal('inf')): 1000 for inf in (INF, -INF): 1001 x = statistics._convert(inf, type(inf)) 1002 self.check_exact_equal(x, inf) 1003 1004 def test_nan(self): 1005 for nan in (float('nan'), Decimal('NAN'), Decimal('sNAN')): 1006 x = statistics._convert(nan, type(nan)) 1007 self.assertTrue(_nan_equal(x, nan)) 1008 1009 def test_invalid_input_type(self): 1010 with self.assertRaises(TypeError): 1011 statistics._convert(None, float) 1012 1013 1014class FailNegTest(unittest.TestCase): 1015 """Test _fail_neg private function.""" 1016 1017 def test_pass_through(self): 1018 # Test that values are passed through unchanged. 1019 values = [1, 2.0, Fraction(3), Decimal(4)] 1020 new = list(statistics._fail_neg(values)) 1021 self.assertEqual(values, new) 1022 1023 def test_negatives_raise(self): 1024 # Test that negatives raise an exception. 1025 for x in [1, 2.0, Fraction(3), Decimal(4)]: 1026 seq = [-x] 1027 it = statistics._fail_neg(seq) 1028 self.assertRaises(statistics.StatisticsError, next, it) 1029 1030 def test_error_msg(self): 1031 # Test that a given error message is used. 1032 msg = "badness #%d" % random.randint(10000, 99999) 1033 try: 1034 next(statistics._fail_neg([-1], msg)) 1035 except statistics.StatisticsError as e: 1036 errmsg = e.args[0] 1037 else: 1038 self.fail("expected exception, but it didn't happen") 1039 self.assertEqual(errmsg, msg) 1040 1041 1042class FindLteqTest(unittest.TestCase): 1043 # Test _find_lteq private function. 1044 1045 def test_invalid_input_values(self): 1046 for a, x in [ 1047 ([], 1), 1048 ([1, 2], 3), 1049 ([1, 3], 2) 1050 ]: 1051 with self.subTest(a=a, x=x): 1052 with self.assertRaises(ValueError): 1053 statistics._find_lteq(a, x) 1054 1055 def test_locate_successfully(self): 1056 for a, x, expected_i in [ 1057 ([1, 1, 1, 2, 3], 1, 0), 1058 ([0, 1, 1, 1, 2, 3], 1, 1), 1059 ([1, 2, 3, 3, 3], 3, 2) 1060 ]: 1061 with self.subTest(a=a, x=x): 1062 self.assertEqual(expected_i, statistics._find_lteq(a, x)) 1063 1064 1065class FindRteqTest(unittest.TestCase): 1066 # Test _find_rteq private function. 1067 1068 def test_invalid_input_values(self): 1069 for a, l, x in [ 1070 ([1], 2, 1), 1071 ([1, 3], 0, 2) 1072 ]: 1073 with self.assertRaises(ValueError): 1074 statistics._find_rteq(a, l, x) 1075 1076 def test_locate_successfully(self): 1077 for a, l, x, expected_i in [ 1078 ([1, 1, 1, 2, 3], 0, 1, 2), 1079 ([0, 1, 1, 1, 2, 3], 0, 1, 3), 1080 ([1, 2, 3, 3, 3], 0, 3, 4) 1081 ]: 1082 with self.subTest(a=a, l=l, x=x): 1083 self.assertEqual(expected_i, statistics._find_rteq(a, l, x)) 1084 1085 1086# === Tests for public functions === 1087 1088class UnivariateCommonMixin: 1089 # Common tests for most univariate functions that take a data argument. 1090 1091 def test_no_args(self): 1092 # Fail if given no arguments. 1093 self.assertRaises(TypeError, self.func) 1094 1095 def test_empty_data(self): 1096 # Fail when the data argument (first argument) is empty. 1097 for empty in ([], (), iter([])): 1098 self.assertRaises(statistics.StatisticsError, self.func, empty) 1099 1100 def prepare_data(self): 1101 """Return int data for various tests.""" 1102 data = list(range(10)) 1103 while data == sorted(data): 1104 random.shuffle(data) 1105 return data 1106 1107 def test_no_inplace_modifications(self): 1108 # Test that the function does not modify its input data. 1109 data = self.prepare_data() 1110 assert len(data) != 1 # Necessary to avoid infinite loop. 1111 assert data != sorted(data) 1112 saved = data[:] 1113 assert data is not saved 1114 _ = self.func(data) 1115 self.assertListEqual(data, saved, "data has been modified") 1116 1117 def test_order_doesnt_matter(self): 1118 # Test that the order of data points doesn't change the result. 1119 1120 # CAUTION: due to floating point rounding errors, the result actually 1121 # may depend on the order. Consider this test representing an ideal. 1122 # To avoid this test failing, only test with exact values such as ints 1123 # or Fractions. 1124 data = [1, 2, 3, 3, 3, 4, 5, 6]*100 1125 expected = self.func(data) 1126 random.shuffle(data) 1127 actual = self.func(data) 1128 self.assertEqual(expected, actual) 1129 1130 def test_type_of_data_collection(self): 1131 # Test that the type of iterable data doesn't effect the result. 1132 class MyList(list): 1133 pass 1134 class MyTuple(tuple): 1135 pass 1136 def generator(data): 1137 return (obj for obj in data) 1138 data = self.prepare_data() 1139 expected = self.func(data) 1140 for kind in (list, tuple, iter, MyList, MyTuple, generator): 1141 result = self.func(kind(data)) 1142 self.assertEqual(result, expected) 1143 1144 def test_range_data(self): 1145 # Test that functions work with range objects. 1146 data = range(20, 50, 3) 1147 expected = self.func(list(data)) 1148 self.assertEqual(self.func(data), expected) 1149 1150 def test_bad_arg_types(self): 1151 # Test that function raises when given data of the wrong type. 1152 1153 # Don't roll the following into a loop like this: 1154 # for bad in list_of_bad: 1155 # self.check_for_type_error(bad) 1156 # 1157 # Since assertRaises doesn't show the arguments that caused the test 1158 # failure, it is very difficult to debug these test failures when the 1159 # following are in a loop. 1160 self.check_for_type_error(None) 1161 self.check_for_type_error(23) 1162 self.check_for_type_error(42.0) 1163 self.check_for_type_error(object()) 1164 1165 def check_for_type_error(self, *args): 1166 self.assertRaises(TypeError, self.func, *args) 1167 1168 def test_type_of_data_element(self): 1169 # Check the type of data elements doesn't affect the numeric result. 1170 # This is a weaker test than UnivariateTypeMixin.testTypesConserved, 1171 # because it checks the numeric result by equality, but not by type. 1172 class MyFloat(float): 1173 def __truediv__(self, other): 1174 return type(self)(super().__truediv__(other)) 1175 def __add__(self, other): 1176 return type(self)(super().__add__(other)) 1177 __radd__ = __add__ 1178 1179 raw = self.prepare_data() 1180 expected = self.func(raw) 1181 for kind in (float, MyFloat, Decimal, Fraction): 1182 data = [kind(x) for x in raw] 1183 result = type(expected)(self.func(data)) 1184 self.assertEqual(result, expected) 1185 1186 1187class UnivariateTypeMixin: 1188 """Mixin class for type-conserving functions. 1189 1190 This mixin class holds test(s) for functions which conserve the type of 1191 individual data points. E.g. the mean of a list of Fractions should itself 1192 be a Fraction. 1193 1194 Not all tests to do with types need go in this class. Only those that 1195 rely on the function returning the same type as its input data. 1196 """ 1197 def prepare_types_for_conservation_test(self): 1198 """Return the types which are expected to be conserved.""" 1199 class MyFloat(float): 1200 def __truediv__(self, other): 1201 return type(self)(super().__truediv__(other)) 1202 def __rtruediv__(self, other): 1203 return type(self)(super().__rtruediv__(other)) 1204 def __sub__(self, other): 1205 return type(self)(super().__sub__(other)) 1206 def __rsub__(self, other): 1207 return type(self)(super().__rsub__(other)) 1208 def __pow__(self, other): 1209 return type(self)(super().__pow__(other)) 1210 def __add__(self, other): 1211 return type(self)(super().__add__(other)) 1212 __radd__ = __add__ 1213 return (float, Decimal, Fraction, MyFloat) 1214 1215 def test_types_conserved(self): 1216 # Test that functions keeps the same type as their data points. 1217 # (Excludes mixed data types.) This only tests the type of the return 1218 # result, not the value. 1219 data = self.prepare_data() 1220 for kind in self.prepare_types_for_conservation_test(): 1221 d = [kind(x) for x in data] 1222 result = self.func(d) 1223 self.assertIs(type(result), kind) 1224 1225 1226class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin): 1227 # Common test cases for statistics._sum() function. 1228 1229 # This test suite looks only at the numeric value returned by _sum, 1230 # after conversion to the appropriate type. 1231 def setUp(self): 1232 def simplified_sum(*args): 1233 T, value, n = statistics._sum(*args) 1234 return statistics._coerce(value, T) 1235 self.func = simplified_sum 1236 1237 1238class TestSum(NumericTestCase): 1239 # Test cases for statistics._sum() function. 1240 1241 # These tests look at the entire three value tuple returned by _sum. 1242 1243 def setUp(self): 1244 self.func = statistics._sum 1245 1246 def test_empty_data(self): 1247 # Override test for empty data. 1248 for data in ([], (), iter([])): 1249 self.assertEqual(self.func(data), (int, Fraction(0), 0)) 1250 1251 def test_ints(self): 1252 self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]), 1253 (int, Fraction(60), 8)) 1254 1255 def test_floats(self): 1256 self.assertEqual(self.func([0.25]*20), 1257 (float, Fraction(5.0), 20)) 1258 1259 def test_fractions(self): 1260 self.assertEqual(self.func([Fraction(1, 1000)]*500), 1261 (Fraction, Fraction(1, 2), 500)) 1262 1263 def test_decimals(self): 1264 D = Decimal 1265 data = [D("0.001"), D("5.246"), D("1.702"), D("-0.025"), 1266 D("3.974"), D("2.328"), D("4.617"), D("2.843"), 1267 ] 1268 self.assertEqual(self.func(data), 1269 (Decimal, Decimal("20.686"), 8)) 1270 1271 def test_compare_with_math_fsum(self): 1272 # Compare with the math.fsum function. 1273 # Ideally we ought to get the exact same result, but sometimes 1274 # we differ by a very slight amount :-( 1275 data = [random.uniform(-100, 1000) for _ in range(1000)] 1276 self.assertApproxEqual(float(self.func(data)[1]), math.fsum(data), rel=2e-16) 1277 1278 def test_strings_fail(self): 1279 # Sum of strings should fail. 1280 self.assertRaises(TypeError, self.func, [1, 2, 3], '999') 1281 self.assertRaises(TypeError, self.func, [1, 2, 3, '999']) 1282 1283 def test_bytes_fail(self): 1284 # Sum of bytes should fail. 1285 self.assertRaises(TypeError, self.func, [1, 2, 3], b'999') 1286 self.assertRaises(TypeError, self.func, [1, 2, 3, b'999']) 1287 1288 def test_mixed_sum(self): 1289 # Mixed input types are not (currently) allowed. 1290 # Check that mixed data types fail. 1291 self.assertRaises(TypeError, self.func, [1, 2.0, Decimal(1)]) 1292 # And so does mixed start argument. 1293 self.assertRaises(TypeError, self.func, [1, 2.0], Decimal(1)) 1294 1295 1296class SumTortureTest(NumericTestCase): 1297 def test_torture(self): 1298 # Tim Peters' torture test for sum, and variants of same. 1299 self.assertEqual(statistics._sum([1, 1e100, 1, -1e100]*10000), 1300 (float, Fraction(20000.0), 40000)) 1301 self.assertEqual(statistics._sum([1e100, 1, 1, -1e100]*10000), 1302 (float, Fraction(20000.0), 40000)) 1303 T, num, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000) 1304 self.assertIs(T, float) 1305 self.assertEqual(count, 40000) 1306 self.assertApproxEqual(float(num), 2.0e-96, rel=5e-16) 1307 1308 1309class SumSpecialValues(NumericTestCase): 1310 # Test that sum works correctly with IEEE-754 special values. 1311 1312 def test_nan(self): 1313 for type_ in (float, Decimal): 1314 nan = type_('nan') 1315 result = statistics._sum([1, nan, 2])[1] 1316 self.assertIs(type(result), type_) 1317 self.assertTrue(math.isnan(result)) 1318 1319 def check_infinity(self, x, inf): 1320 """Check x is an infinity of the same type and sign as inf.""" 1321 self.assertTrue(math.isinf(x)) 1322 self.assertIs(type(x), type(inf)) 1323 self.assertEqual(x > 0, inf > 0) 1324 assert x == inf 1325 1326 def do_test_inf(self, inf): 1327 # Adding a single infinity gives infinity. 1328 result = statistics._sum([1, 2, inf, 3])[1] 1329 self.check_infinity(result, inf) 1330 # Adding two infinities of the same sign also gives infinity. 1331 result = statistics._sum([1, 2, inf, 3, inf, 4])[1] 1332 self.check_infinity(result, inf) 1333 1334 def test_float_inf(self): 1335 inf = float('inf') 1336 for sign in (+1, -1): 1337 self.do_test_inf(sign*inf) 1338 1339 def test_decimal_inf(self): 1340 inf = Decimal('inf') 1341 for sign in (+1, -1): 1342 self.do_test_inf(sign*inf) 1343 1344 def test_float_mismatched_infs(self): 1345 # Test that adding two infinities of opposite sign gives a NAN. 1346 inf = float('inf') 1347 result = statistics._sum([1, 2, inf, 3, -inf, 4])[1] 1348 self.assertTrue(math.isnan(result)) 1349 1350 def test_decimal_extendedcontext_mismatched_infs_to_nan(self): 1351 # Test adding Decimal INFs with opposite sign returns NAN. 1352 inf = Decimal('inf') 1353 data = [1, 2, inf, 3, -inf, 4] 1354 with decimal.localcontext(decimal.ExtendedContext): 1355 self.assertTrue(math.isnan(statistics._sum(data)[1])) 1356 1357 def test_decimal_basiccontext_mismatched_infs_to_nan(self): 1358 # Test adding Decimal INFs with opposite sign raises InvalidOperation. 1359 inf = Decimal('inf') 1360 data = [1, 2, inf, 3, -inf, 4] 1361 with decimal.localcontext(decimal.BasicContext): 1362 self.assertRaises(decimal.InvalidOperation, statistics._sum, data) 1363 1364 def test_decimal_snan_raises(self): 1365 # Adding sNAN should raise InvalidOperation. 1366 sNAN = Decimal('sNAN') 1367 data = [1, sNAN, 2] 1368 self.assertRaises(decimal.InvalidOperation, statistics._sum, data) 1369 1370 1371# === Tests for averages === 1372 1373class AverageMixin(UnivariateCommonMixin): 1374 # Mixin class holding common tests for averages. 1375 1376 def test_single_value(self): 1377 # Average of a single value is the value itself. 1378 for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')): 1379 self.assertEqual(self.func([x]), x) 1380 1381 def prepare_values_for_repeated_single_test(self): 1382 return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712')) 1383 1384 def test_repeated_single_value(self): 1385 # The average of a single repeated value is the value itself. 1386 for x in self.prepare_values_for_repeated_single_test(): 1387 for count in (2, 5, 10, 20): 1388 with self.subTest(x=x, count=count): 1389 data = [x]*count 1390 self.assertEqual(self.func(data), x) 1391 1392 1393class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1394 def setUp(self): 1395 self.func = statistics.mean 1396 1397 def test_torture_pep(self): 1398 # "Torture Test" from PEP-450. 1399 self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1) 1400 1401 def test_ints(self): 1402 # Test mean with ints. 1403 data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9] 1404 random.shuffle(data) 1405 self.assertEqual(self.func(data), 4.8125) 1406 1407 def test_floats(self): 1408 # Test mean with floats. 1409 data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5] 1410 random.shuffle(data) 1411 self.assertEqual(self.func(data), 22.015625) 1412 1413 def test_decimals(self): 1414 # Test mean with Decimals. 1415 D = Decimal 1416 data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")] 1417 random.shuffle(data) 1418 self.assertEqual(self.func(data), D("3.5896")) 1419 1420 def test_fractions(self): 1421 # Test mean with Fractions. 1422 F = Fraction 1423 data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)] 1424 random.shuffle(data) 1425 self.assertEqual(self.func(data), F(1479, 1960)) 1426 1427 def test_inf(self): 1428 # Test mean with infinities. 1429 raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. 1430 for kind in (float, Decimal): 1431 for sign in (1, -1): 1432 inf = kind("inf")*sign 1433 data = raw + [inf] 1434 result = self.func(data) 1435 self.assertTrue(math.isinf(result)) 1436 self.assertEqual(result, inf) 1437 1438 def test_mismatched_infs(self): 1439 # Test mean with infinities of opposite sign. 1440 data = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')] 1441 result = self.func(data) 1442 self.assertTrue(math.isnan(result)) 1443 1444 def test_nan(self): 1445 # Test mean with NANs. 1446 raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later. 1447 for kind in (float, Decimal): 1448 inf = kind("nan") 1449 data = raw + [inf] 1450 result = self.func(data) 1451 self.assertTrue(math.isnan(result)) 1452 1453 def test_big_data(self): 1454 # Test adding a large constant to every data point. 1455 c = 1e9 1456 data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4] 1457 expected = self.func(data) + c 1458 assert expected != c 1459 result = self.func([x+c for x in data]) 1460 self.assertEqual(result, expected) 1461 1462 def test_doubled_data(self): 1463 # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z]. 1464 data = [random.uniform(-3, 5) for _ in range(1000)] 1465 expected = self.func(data) 1466 actual = self.func(data*2) 1467 self.assertApproxEqual(actual, expected) 1468 1469 def test_regression_20561(self): 1470 # Regression test for issue 20561. 1471 # See http://bugs.python.org/issue20561 1472 d = Decimal('1e4') 1473 self.assertEqual(statistics.mean([d]), d) 1474 1475 def test_regression_25177(self): 1476 # Regression test for issue 25177. 1477 # Ensure very big and very small floats don't overflow. 1478 # See http://bugs.python.org/issue25177. 1479 self.assertEqual(statistics.mean( 1480 [8.988465674311579e+307, 8.98846567431158e+307]), 1481 8.98846567431158e+307) 1482 big = 8.98846567431158e+307 1483 tiny = 5e-324 1484 for n in (2, 3, 5, 200): 1485 self.assertEqual(statistics.mean([big]*n), big) 1486 self.assertEqual(statistics.mean([tiny]*n), tiny) 1487 1488 1489class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1490 def setUp(self): 1491 self.func = statistics.harmonic_mean 1492 1493 def prepare_data(self): 1494 # Override mixin method. 1495 values = super().prepare_data() 1496 values.remove(0) 1497 return values 1498 1499 def prepare_values_for_repeated_single_test(self): 1500 # Override mixin method. 1501 return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125')) 1502 1503 def test_zero(self): 1504 # Test that harmonic mean returns zero when given zero. 1505 values = [1, 0, 2] 1506 self.assertEqual(self.func(values), 0) 1507 1508 def test_negative_error(self): 1509 # Test that harmonic mean raises when given a negative value. 1510 exc = statistics.StatisticsError 1511 for values in ([-1], [1, -2, 3]): 1512 with self.subTest(values=values): 1513 self.assertRaises(exc, self.func, values) 1514 1515 def test_invalid_type_error(self): 1516 # Test error is raised when input contains invalid type(s) 1517 for data in [ 1518 ['3.14'], # single string 1519 ['1', '2', '3'], # multiple strings 1520 [1, '2', 3, '4', 5], # mixed strings and valid integers 1521 [2.3, 3.4, 4.5, '5.6'] # only one string and valid floats 1522 ]: 1523 with self.subTest(data=data): 1524 with self.assertRaises(TypeError): 1525 self.func(data) 1526 1527 def test_ints(self): 1528 # Test harmonic mean with ints. 1529 data = [2, 4, 4, 8, 16, 16] 1530 random.shuffle(data) 1531 self.assertEqual(self.func(data), 6*4/5) 1532 1533 def test_floats_exact(self): 1534 # Test harmonic mean with some carefully chosen floats. 1535 data = [1/8, 1/4, 1/4, 1/2, 1/2] 1536 random.shuffle(data) 1537 self.assertEqual(self.func(data), 1/4) 1538 self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5) 1539 1540 def test_singleton_lists(self): 1541 # Test that harmonic mean([x]) returns (approximately) x. 1542 for x in range(1, 101): 1543 self.assertEqual(self.func([x]), x) 1544 1545 def test_decimals_exact(self): 1546 # Test harmonic mean with some carefully chosen Decimals. 1547 D = Decimal 1548 self.assertEqual(self.func([D(15), D(30), D(60), D(60)]), D(30)) 1549 data = [D("0.05"), D("0.10"), D("0.20"), D("0.20")] 1550 random.shuffle(data) 1551 self.assertEqual(self.func(data), D("0.10")) 1552 data = [D("1.68"), D("0.32"), D("5.94"), D("2.75")] 1553 random.shuffle(data) 1554 self.assertEqual(self.func(data), D(66528)/70723) 1555 1556 def test_fractions(self): 1557 # Test harmonic mean with Fractions. 1558 F = Fraction 1559 data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)] 1560 random.shuffle(data) 1561 self.assertEqual(self.func(data), F(7*420, 4029)) 1562 1563 def test_inf(self): 1564 # Test harmonic mean with infinity. 1565 values = [2.0, float('inf'), 1.0] 1566 self.assertEqual(self.func(values), 2.0) 1567 1568 def test_nan(self): 1569 # Test harmonic mean with NANs. 1570 values = [2.0, float('nan'), 1.0] 1571 self.assertTrue(math.isnan(self.func(values))) 1572 1573 def test_multiply_data_points(self): 1574 # Test multiplying every data point by a constant. 1575 c = 111 1576 data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4] 1577 expected = self.func(data)*c 1578 result = self.func([x*c for x in data]) 1579 self.assertEqual(result, expected) 1580 1581 def test_doubled_data(self): 1582 # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z]. 1583 data = [random.uniform(1, 5) for _ in range(1000)] 1584 expected = self.func(data) 1585 actual = self.func(data*2) 1586 self.assertApproxEqual(actual, expected) 1587 1588 def test_with_weights(self): 1589 self.assertEqual(self.func([40, 60], [5, 30]), 56.0) # common case 1590 self.assertEqual(self.func([40, 60], 1591 weights=[5, 30]), 56.0) # keyword argument 1592 self.assertEqual(self.func(iter([40, 60]), 1593 iter([5, 30])), 56.0) # iterator inputs 1594 self.assertEqual( 1595 self.func([Fraction(10, 3), Fraction(23, 5), Fraction(7, 2)], [5, 2, 10]), 1596 self.func([Fraction(10, 3)] * 5 + 1597 [Fraction(23, 5)] * 2 + 1598 [Fraction(7, 2)] * 10)) 1599 self.assertEqual(self.func([10], [7]), 10) # n=1 fast path 1600 with self.assertRaises(TypeError): 1601 self.func([1, 2, 3], [1, (), 3]) # non-numeric weight 1602 with self.assertRaises(statistics.StatisticsError): 1603 self.func([1, 2, 3], [1, 2]) # wrong number of weights 1604 with self.assertRaises(statistics.StatisticsError): 1605 self.func([10], [0]) # no non-zero weights 1606 with self.assertRaises(statistics.StatisticsError): 1607 self.func([10, 20], [0, 0]) # no non-zero weights 1608 1609 1610class TestMedian(NumericTestCase, AverageMixin): 1611 # Common tests for median and all median.* functions. 1612 def setUp(self): 1613 self.func = statistics.median 1614 1615 def prepare_data(self): 1616 """Overload method from UnivariateCommonMixin.""" 1617 data = super().prepare_data() 1618 if len(data)%2 != 1: 1619 data.append(2) 1620 return data 1621 1622 def test_even_ints(self): 1623 # Test median with an even number of int data points. 1624 data = [1, 2, 3, 4, 5, 6] 1625 assert len(data)%2 == 0 1626 self.assertEqual(self.func(data), 3.5) 1627 1628 def test_odd_ints(self): 1629 # Test median with an odd number of int data points. 1630 data = [1, 2, 3, 4, 5, 6, 9] 1631 assert len(data)%2 == 1 1632 self.assertEqual(self.func(data), 4) 1633 1634 def test_odd_fractions(self): 1635 # Test median works with an odd number of Fractions. 1636 F = Fraction 1637 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7)] 1638 assert len(data)%2 == 1 1639 random.shuffle(data) 1640 self.assertEqual(self.func(data), F(3, 7)) 1641 1642 def test_even_fractions(self): 1643 # Test median works with an even number of Fractions. 1644 F = Fraction 1645 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1646 assert len(data)%2 == 0 1647 random.shuffle(data) 1648 self.assertEqual(self.func(data), F(1, 2)) 1649 1650 def test_odd_decimals(self): 1651 # Test median works with an odd number of Decimals. 1652 D = Decimal 1653 data = [D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] 1654 assert len(data)%2 == 1 1655 random.shuffle(data) 1656 self.assertEqual(self.func(data), D('4.2')) 1657 1658 def test_even_decimals(self): 1659 # Test median works with an even number of Decimals. 1660 D = Decimal 1661 data = [D('1.2'), D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')] 1662 assert len(data)%2 == 0 1663 random.shuffle(data) 1664 self.assertEqual(self.func(data), D('3.65')) 1665 1666 1667class TestMedianDataType(NumericTestCase, UnivariateTypeMixin): 1668 # Test conservation of data element type for median. 1669 def setUp(self): 1670 self.func = statistics.median 1671 1672 def prepare_data(self): 1673 data = list(range(15)) 1674 assert len(data)%2 == 1 1675 while data == sorted(data): 1676 random.shuffle(data) 1677 return data 1678 1679 1680class TestMedianLow(TestMedian, UnivariateTypeMixin): 1681 def setUp(self): 1682 self.func = statistics.median_low 1683 1684 def test_even_ints(self): 1685 # Test median_low with an even number of ints. 1686 data = [1, 2, 3, 4, 5, 6] 1687 assert len(data)%2 == 0 1688 self.assertEqual(self.func(data), 3) 1689 1690 def test_even_fractions(self): 1691 # Test median_low works with an even number of Fractions. 1692 F = Fraction 1693 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1694 assert len(data)%2 == 0 1695 random.shuffle(data) 1696 self.assertEqual(self.func(data), F(3, 7)) 1697 1698 def test_even_decimals(self): 1699 # Test median_low works with an even number of Decimals. 1700 D = Decimal 1701 data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] 1702 assert len(data)%2 == 0 1703 random.shuffle(data) 1704 self.assertEqual(self.func(data), D('3.3')) 1705 1706 1707class TestMedianHigh(TestMedian, UnivariateTypeMixin): 1708 def setUp(self): 1709 self.func = statistics.median_high 1710 1711 def test_even_ints(self): 1712 # Test median_high with an even number of ints. 1713 data = [1, 2, 3, 4, 5, 6] 1714 assert len(data)%2 == 0 1715 self.assertEqual(self.func(data), 4) 1716 1717 def test_even_fractions(self): 1718 # Test median_high works with an even number of Fractions. 1719 F = Fraction 1720 data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)] 1721 assert len(data)%2 == 0 1722 random.shuffle(data) 1723 self.assertEqual(self.func(data), F(4, 7)) 1724 1725 def test_even_decimals(self): 1726 # Test median_high works with an even number of Decimals. 1727 D = Decimal 1728 data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')] 1729 assert len(data)%2 == 0 1730 random.shuffle(data) 1731 self.assertEqual(self.func(data), D('4.4')) 1732 1733 1734class TestMedianGrouped(TestMedian): 1735 # Test median_grouped. 1736 # Doesn't conserve data element types, so don't use TestMedianType. 1737 def setUp(self): 1738 self.func = statistics.median_grouped 1739 1740 def test_odd_number_repeated(self): 1741 # Test median.grouped with repeated median values. 1742 data = [12, 13, 14, 14, 14, 15, 15] 1743 assert len(data)%2 == 1 1744 self.assertEqual(self.func(data), 14) 1745 #--- 1746 data = [12, 13, 14, 14, 14, 14, 15] 1747 assert len(data)%2 == 1 1748 self.assertEqual(self.func(data), 13.875) 1749 #--- 1750 data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30] 1751 assert len(data)%2 == 1 1752 self.assertEqual(self.func(data, 5), 19.375) 1753 #--- 1754 data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28] 1755 assert len(data)%2 == 1 1756 self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8) 1757 1758 def test_even_number_repeated(self): 1759 # Test median.grouped with repeated median values. 1760 data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30] 1761 assert len(data)%2 == 0 1762 self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8) 1763 #--- 1764 data = [2, 3, 4, 4, 4, 5] 1765 assert len(data)%2 == 0 1766 self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8) 1767 #--- 1768 data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6] 1769 assert len(data)%2 == 0 1770 self.assertEqual(self.func(data), 4.5) 1771 #--- 1772 data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6] 1773 assert len(data)%2 == 0 1774 self.assertEqual(self.func(data), 4.75) 1775 1776 def test_repeated_single_value(self): 1777 # Override method from AverageMixin. 1778 # Yet again, failure of median_grouped to conserve the data type 1779 # causes me headaches :-( 1780 for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')): 1781 for count in (2, 5, 10, 20): 1782 data = [x]*count 1783 self.assertEqual(self.func(data), float(x)) 1784 1785 def test_odd_fractions(self): 1786 # Test median_grouped works with an odd number of Fractions. 1787 F = Fraction 1788 data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)] 1789 assert len(data)%2 == 1 1790 random.shuffle(data) 1791 self.assertEqual(self.func(data), 3.0) 1792 1793 def test_even_fractions(self): 1794 # Test median_grouped works with an even number of Fractions. 1795 F = Fraction 1796 data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)] 1797 assert len(data)%2 == 0 1798 random.shuffle(data) 1799 self.assertEqual(self.func(data), 3.25) 1800 1801 def test_odd_decimals(self): 1802 # Test median_grouped works with an odd number of Decimals. 1803 D = Decimal 1804 data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] 1805 assert len(data)%2 == 1 1806 random.shuffle(data) 1807 self.assertEqual(self.func(data), 6.75) 1808 1809 def test_even_decimals(self): 1810 # Test median_grouped works with an even number of Decimals. 1811 D = Decimal 1812 data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')] 1813 assert len(data)%2 == 0 1814 random.shuffle(data) 1815 self.assertEqual(self.func(data), 6.5) 1816 #--- 1817 data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')] 1818 assert len(data)%2 == 0 1819 random.shuffle(data) 1820 self.assertEqual(self.func(data), 7.0) 1821 1822 def test_interval(self): 1823 # Test median_grouped with interval argument. 1824 data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] 1825 self.assertEqual(self.func(data, 0.25), 2.875) 1826 data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75] 1827 self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8) 1828 data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340] 1829 self.assertEqual(self.func(data, 20), 265.0) 1830 1831 def test_data_type_error(self): 1832 # Test median_grouped with str, bytes data types for data and interval 1833 data = ["", "", ""] 1834 self.assertRaises(TypeError, self.func, data) 1835 #--- 1836 data = [b"", b"", b""] 1837 self.assertRaises(TypeError, self.func, data) 1838 #--- 1839 data = [1, 2, 3] 1840 interval = "" 1841 self.assertRaises(TypeError, self.func, data, interval) 1842 #--- 1843 data = [1, 2, 3] 1844 interval = b"" 1845 self.assertRaises(TypeError, self.func, data, interval) 1846 1847 1848class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin): 1849 # Test cases for the discrete version of mode. 1850 def setUp(self): 1851 self.func = statistics.mode 1852 1853 def prepare_data(self): 1854 """Overload method from UnivariateCommonMixin.""" 1855 # Make sure test data has exactly one mode. 1856 return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2] 1857 1858 def test_range_data(self): 1859 # Override test from UnivariateCommonMixin. 1860 data = range(20, 50, 3) 1861 self.assertEqual(self.func(data), 20) 1862 1863 def test_nominal_data(self): 1864 # Test mode with nominal data. 1865 data = 'abcbdb' 1866 self.assertEqual(self.func(data), 'b') 1867 data = 'fe fi fo fum fi fi'.split() 1868 self.assertEqual(self.func(data), 'fi') 1869 1870 def test_discrete_data(self): 1871 # Test mode with discrete numeric data. 1872 data = list(range(10)) 1873 for i in range(10): 1874 d = data + [i] 1875 random.shuffle(d) 1876 self.assertEqual(self.func(d), i) 1877 1878 def test_bimodal_data(self): 1879 # Test mode with bimodal data. 1880 data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9] 1881 assert data.count(2) == data.count(6) == 4 1882 # mode() should return 2, the first encountered mode 1883 self.assertEqual(self.func(data), 2) 1884 1885 def test_unique_data(self): 1886 # Test mode when data points are all unique. 1887 data = list(range(10)) 1888 # mode() should return 0, the first encountered mode 1889 self.assertEqual(self.func(data), 0) 1890 1891 def test_none_data(self): 1892 # Test that mode raises TypeError if given None as data. 1893 1894 # This test is necessary because the implementation of mode uses 1895 # collections.Counter, which accepts None and returns an empty dict. 1896 self.assertRaises(TypeError, self.func, None) 1897 1898 def test_counter_data(self): 1899 # Test that a Counter is treated like any other iterable. 1900 # We're making sure mode() first calls iter() on its input. 1901 # The concern is that a Counter of a Counter returns the original 1902 # unchanged rather than counting its keys. 1903 c = collections.Counter(a=1, b=2) 1904 # If iter() is called, mode(c) loops over the keys, ['a', 'b'], 1905 # all the counts will be 1, and the first encountered mode is 'a'. 1906 self.assertEqual(self.func(c), 'a') 1907 1908 1909class TestMultiMode(unittest.TestCase): 1910 1911 def test_basics(self): 1912 multimode = statistics.multimode 1913 self.assertEqual(multimode('aabbbbbbbbcc'), ['b']) 1914 self.assertEqual(multimode('aabbbbccddddeeffffgg'), ['b', 'd', 'f']) 1915 self.assertEqual(multimode(''), []) 1916 1917 1918class TestFMean(unittest.TestCase): 1919 1920 def test_basics(self): 1921 fmean = statistics.fmean 1922 D = Decimal 1923 F = Fraction 1924 for data, expected_mean, kind in [ 1925 ([3.5, 4.0, 5.25], 4.25, 'floats'), 1926 ([D('3.5'), D('4.0'), D('5.25')], 4.25, 'decimals'), 1927 ([F(7, 2), F(4, 1), F(21, 4)], 4.25, 'fractions'), 1928 ([True, False, True, True, False], 0.60, 'booleans'), 1929 ([3.5, 4, F(21, 4)], 4.25, 'mixed types'), 1930 ((3.5, 4.0, 5.25), 4.25, 'tuple'), 1931 (iter([3.5, 4.0, 5.25]), 4.25, 'iterator'), 1932 ]: 1933 actual_mean = fmean(data) 1934 self.assertIs(type(actual_mean), float, kind) 1935 self.assertEqual(actual_mean, expected_mean, kind) 1936 1937 def test_error_cases(self): 1938 fmean = statistics.fmean 1939 StatisticsError = statistics.StatisticsError 1940 with self.assertRaises(StatisticsError): 1941 fmean([]) # empty input 1942 with self.assertRaises(StatisticsError): 1943 fmean(iter([])) # empty iterator 1944 with self.assertRaises(TypeError): 1945 fmean(None) # non-iterable input 1946 with self.assertRaises(TypeError): 1947 fmean([10, None, 20]) # non-numeric input 1948 with self.assertRaises(TypeError): 1949 fmean() # missing data argument 1950 with self.assertRaises(TypeError): 1951 fmean([10, 20, 60], 70) # too many arguments 1952 1953 def test_special_values(self): 1954 # Rules for special values are inherited from math.fsum() 1955 fmean = statistics.fmean 1956 NaN = float('Nan') 1957 Inf = float('Inf') 1958 self.assertTrue(math.isnan(fmean([10, NaN])), 'nan') 1959 self.assertTrue(math.isnan(fmean([NaN, Inf])), 'nan and infinity') 1960 self.assertTrue(math.isinf(fmean([10, Inf])), 'infinity') 1961 with self.assertRaises(ValueError): 1962 fmean([Inf, -Inf]) 1963 1964 1965# === Tests for variances and standard deviations === 1966 1967class VarianceStdevMixin(UnivariateCommonMixin): 1968 # Mixin class holding common tests for variance and std dev. 1969 1970 # Subclasses should inherit from this before NumericTestClass, in order 1971 # to see the rel attribute below. See testShiftData for an explanation. 1972 1973 rel = 1e-12 1974 1975 def test_single_value(self): 1976 # Deviation of a single value is zero. 1977 for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')): 1978 self.assertEqual(self.func([x]), 0) 1979 1980 def test_repeated_single_value(self): 1981 # The deviation of a single repeated value is zero. 1982 for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')): 1983 for count in (2, 3, 5, 15): 1984 data = [x]*count 1985 self.assertEqual(self.func(data), 0) 1986 1987 def test_domain_error_regression(self): 1988 # Regression test for a domain error exception. 1989 # (Thanks to Geremy Condra.) 1990 data = [0.123456789012345]*10000 1991 # All the items are identical, so variance should be exactly zero. 1992 # We allow some small round-off error, but not much. 1993 result = self.func(data) 1994 self.assertApproxEqual(result, 0.0, tol=5e-17) 1995 self.assertGreaterEqual(result, 0) # A negative result must fail. 1996 1997 def test_shift_data(self): 1998 # Test that shifting the data by a constant amount does not affect 1999 # the variance or stdev. Or at least not much. 2000 2001 # Due to rounding, this test should be considered an ideal. We allow 2002 # some tolerance away from "no change at all" by setting tol and/or rel 2003 # attributes. Subclasses may set tighter or looser error tolerances. 2004 raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78] 2005 expected = self.func(raw) 2006 # Don't set shift too high, the bigger it is, the more rounding error. 2007 shift = 1e5 2008 data = [x + shift for x in raw] 2009 self.assertApproxEqual(self.func(data), expected) 2010 2011 def test_shift_data_exact(self): 2012 # Like test_shift_data, but result is always exact. 2013 raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16] 2014 assert all(x==int(x) for x in raw) 2015 expected = self.func(raw) 2016 shift = 10**9 2017 data = [x + shift for x in raw] 2018 self.assertEqual(self.func(data), expected) 2019 2020 def test_iter_list_same(self): 2021 # Test that iter data and list data give the same result. 2022 2023 # This is an explicit test that iterators and lists are treated the 2024 # same; justification for this test over and above the similar test 2025 # in UnivariateCommonMixin is that an earlier design had variance and 2026 # friends swap between one- and two-pass algorithms, which would 2027 # sometimes give different results. 2028 data = [random.uniform(-3, 8) for _ in range(1000)] 2029 expected = self.func(data) 2030 self.assertEqual(self.func(iter(data)), expected) 2031 2032 2033class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin): 2034 # Tests for population variance. 2035 def setUp(self): 2036 self.func = statistics.pvariance 2037 2038 def test_exact_uniform(self): 2039 # Test the variance against an exact result for uniform data. 2040 data = list(range(10000)) 2041 random.shuffle(data) 2042 expected = (10000**2 - 1)/12 # Exact value. 2043 self.assertEqual(self.func(data), expected) 2044 2045 def test_ints(self): 2046 # Test population variance with int data. 2047 data = [4, 7, 13, 16] 2048 exact = 22.5 2049 self.assertEqual(self.func(data), exact) 2050 2051 def test_fractions(self): 2052 # Test population variance with Fraction data. 2053 F = Fraction 2054 data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)] 2055 exact = F(3, 8) 2056 result = self.func(data) 2057 self.assertEqual(result, exact) 2058 self.assertIsInstance(result, Fraction) 2059 2060 def test_decimals(self): 2061 # Test population variance with Decimal data. 2062 D = Decimal 2063 data = [D("12.1"), D("12.2"), D("12.5"), D("12.9")] 2064 exact = D('0.096875') 2065 result = self.func(data) 2066 self.assertEqual(result, exact) 2067 self.assertIsInstance(result, Decimal) 2068 2069 def test_accuracy_bug_20499(self): 2070 data = [0, 0, 1] 2071 exact = 2 / 9 2072 result = self.func(data) 2073 self.assertEqual(result, exact) 2074 self.assertIsInstance(result, float) 2075 2076 2077class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin): 2078 # Tests for sample variance. 2079 def setUp(self): 2080 self.func = statistics.variance 2081 2082 def test_single_value(self): 2083 # Override method from VarianceStdevMixin. 2084 for x in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')): 2085 self.assertRaises(statistics.StatisticsError, self.func, [x]) 2086 2087 def test_ints(self): 2088 # Test sample variance with int data. 2089 data = [4, 7, 13, 16] 2090 exact = 30 2091 self.assertEqual(self.func(data), exact) 2092 2093 def test_fractions(self): 2094 # Test sample variance with Fraction data. 2095 F = Fraction 2096 data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)] 2097 exact = F(1, 2) 2098 result = self.func(data) 2099 self.assertEqual(result, exact) 2100 self.assertIsInstance(result, Fraction) 2101 2102 def test_decimals(self): 2103 # Test sample variance with Decimal data. 2104 D = Decimal 2105 data = [D(2), D(2), D(7), D(9)] 2106 exact = 4*D('9.5')/D(3) 2107 result = self.func(data) 2108 self.assertEqual(result, exact) 2109 self.assertIsInstance(result, Decimal) 2110 2111 def test_center_not_at_mean(self): 2112 data = (1.0, 2.0) 2113 self.assertEqual(self.func(data), 0.5) 2114 self.assertEqual(self.func(data, xbar=2.0), 1.0) 2115 2116 def test_accuracy_bug_20499(self): 2117 data = [0, 0, 2] 2118 exact = 4 / 3 2119 result = self.func(data) 2120 self.assertEqual(result, exact) 2121 self.assertIsInstance(result, float) 2122 2123class TestPStdev(VarianceStdevMixin, NumericTestCase): 2124 # Tests for population standard deviation. 2125 def setUp(self): 2126 self.func = statistics.pstdev 2127 2128 def test_compare_to_variance(self): 2129 # Test that stdev is, in fact, the square root of variance. 2130 data = [random.uniform(-17, 24) for _ in range(1000)] 2131 expected = math.sqrt(statistics.pvariance(data)) 2132 self.assertEqual(self.func(data), expected) 2133 2134 def test_center_not_at_mean(self): 2135 # See issue: 40855 2136 data = (3, 6, 7, 10) 2137 self.assertEqual(self.func(data), 2.5) 2138 self.assertEqual(self.func(data, mu=0.5), 6.5) 2139 2140class TestStdev(VarianceStdevMixin, NumericTestCase): 2141 # Tests for sample standard deviation. 2142 def setUp(self): 2143 self.func = statistics.stdev 2144 2145 def test_single_value(self): 2146 # Override method from VarianceStdevMixin. 2147 for x in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')): 2148 self.assertRaises(statistics.StatisticsError, self.func, [x]) 2149 2150 def test_compare_to_variance(self): 2151 # Test that stdev is, in fact, the square root of variance. 2152 data = [random.uniform(-2, 9) for _ in range(1000)] 2153 expected = math.sqrt(statistics.variance(data)) 2154 self.assertEqual(self.func(data), expected) 2155 2156 def test_center_not_at_mean(self): 2157 data = (1.0, 2.0) 2158 self.assertEqual(self.func(data, xbar=2.0), 1.0) 2159 2160class TestGeometricMean(unittest.TestCase): 2161 2162 def test_basics(self): 2163 geometric_mean = statistics.geometric_mean 2164 self.assertAlmostEqual(geometric_mean([54, 24, 36]), 36.0) 2165 self.assertAlmostEqual(geometric_mean([4.0, 9.0]), 6.0) 2166 self.assertAlmostEqual(geometric_mean([17.625]), 17.625) 2167 2168 random.seed(86753095551212) 2169 for rng in [ 2170 range(1, 100), 2171 range(1, 1_000), 2172 range(1, 10_000), 2173 range(500, 10_000, 3), 2174 range(10_000, 500, -3), 2175 [12, 17, 13, 5, 120, 7], 2176 [random.expovariate(50.0) for i in range(1_000)], 2177 [random.lognormvariate(20.0, 3.0) for i in range(2_000)], 2178 [random.triangular(2000, 3000, 2200) for i in range(3_000)], 2179 ]: 2180 gm_decimal = math.prod(map(Decimal, rng)) ** (Decimal(1) / len(rng)) 2181 gm_float = geometric_mean(rng) 2182 self.assertTrue(math.isclose(gm_float, float(gm_decimal))) 2183 2184 def test_various_input_types(self): 2185 geometric_mean = statistics.geometric_mean 2186 D = Decimal 2187 F = Fraction 2188 # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25 2189 expected_mean = 4.18886 2190 for data, kind in [ 2191 ([3.5, 4.0, 5.25], 'floats'), 2192 ([D('3.5'), D('4.0'), D('5.25')], 'decimals'), 2193 ([F(7, 2), F(4, 1), F(21, 4)], 'fractions'), 2194 ([3.5, 4, F(21, 4)], 'mixed types'), 2195 ((3.5, 4.0, 5.25), 'tuple'), 2196 (iter([3.5, 4.0, 5.25]), 'iterator'), 2197 ]: 2198 actual_mean = geometric_mean(data) 2199 self.assertIs(type(actual_mean), float, kind) 2200 self.assertAlmostEqual(actual_mean, expected_mean, places=5) 2201 2202 def test_big_and_small(self): 2203 geometric_mean = statistics.geometric_mean 2204 2205 # Avoid overflow to infinity 2206 large = 2.0 ** 1000 2207 big_gm = geometric_mean([54.0 * large, 24.0 * large, 36.0 * large]) 2208 self.assertTrue(math.isclose(big_gm, 36.0 * large)) 2209 self.assertFalse(math.isinf(big_gm)) 2210 2211 # Avoid underflow to zero 2212 small = 2.0 ** -1000 2213 small_gm = geometric_mean([54.0 * small, 24.0 * small, 36.0 * small]) 2214 self.assertTrue(math.isclose(small_gm, 36.0 * small)) 2215 self.assertNotEqual(small_gm, 0.0) 2216 2217 def test_error_cases(self): 2218 geometric_mean = statistics.geometric_mean 2219 StatisticsError = statistics.StatisticsError 2220 with self.assertRaises(StatisticsError): 2221 geometric_mean([]) # empty input 2222 with self.assertRaises(StatisticsError): 2223 geometric_mean([3.5, 0.0, 5.25]) # zero input 2224 with self.assertRaises(StatisticsError): 2225 geometric_mean([3.5, -4.0, 5.25]) # negative input 2226 with self.assertRaises(StatisticsError): 2227 geometric_mean(iter([])) # empty iterator 2228 with self.assertRaises(TypeError): 2229 geometric_mean(None) # non-iterable input 2230 with self.assertRaises(TypeError): 2231 geometric_mean([10, None, 20]) # non-numeric input 2232 with self.assertRaises(TypeError): 2233 geometric_mean() # missing data argument 2234 with self.assertRaises(TypeError): 2235 geometric_mean([10, 20, 60], 70) # too many arguments 2236 2237 def test_special_values(self): 2238 # Rules for special values are inherited from math.fsum() 2239 geometric_mean = statistics.geometric_mean 2240 NaN = float('Nan') 2241 Inf = float('Inf') 2242 self.assertTrue(math.isnan(geometric_mean([10, NaN])), 'nan') 2243 self.assertTrue(math.isnan(geometric_mean([NaN, Inf])), 'nan and infinity') 2244 self.assertTrue(math.isinf(geometric_mean([10, Inf])), 'infinity') 2245 with self.assertRaises(ValueError): 2246 geometric_mean([Inf, -Inf]) 2247 2248 2249class TestQuantiles(unittest.TestCase): 2250 2251 def test_specific_cases(self): 2252 # Match results computed by hand and cross-checked 2253 # against the PERCENTILE.EXC function in MS Excel. 2254 quantiles = statistics.quantiles 2255 data = [120, 200, 250, 320, 350] 2256 random.shuffle(data) 2257 for n, expected in [ 2258 (1, []), 2259 (2, [250.0]), 2260 (3, [200.0, 320.0]), 2261 (4, [160.0, 250.0, 335.0]), 2262 (5, [136.0, 220.0, 292.0, 344.0]), 2263 (6, [120.0, 200.0, 250.0, 320.0, 350.0]), 2264 (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]), 2265 (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]), 2266 (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0, 2267 350.0, 365.0]), 2268 (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0, 2269 320.0, 332.0, 344.0, 356.0, 368.0]), 2270 ]: 2271 self.assertEqual(expected, quantiles(data, n=n)) 2272 self.assertEqual(len(quantiles(data, n=n)), n - 1) 2273 # Preserve datatype when possible 2274 for datatype in (float, Decimal, Fraction): 2275 result = quantiles(map(datatype, data), n=n) 2276 self.assertTrue(all(type(x) == datatype) for x in result) 2277 self.assertEqual(result, list(map(datatype, expected))) 2278 # Quantiles should be idempotent 2279 if len(expected) >= 2: 2280 self.assertEqual(quantiles(expected, n=n), expected) 2281 # Cross-check against method='inclusive' which should give 2282 # the same result after adding in minimum and maximum values 2283 # extrapolated from the two lowest and two highest points. 2284 sdata = sorted(data) 2285 lo = 2 * sdata[0] - sdata[1] 2286 hi = 2 * sdata[-1] - sdata[-2] 2287 padded_data = data + [lo, hi] 2288 self.assertEqual( 2289 quantiles(data, n=n), 2290 quantiles(padded_data, n=n, method='inclusive'), 2291 (n, data), 2292 ) 2293 # Invariant under translation and scaling 2294 def f(x): 2295 return 3.5 * x - 1234.675 2296 exp = list(map(f, expected)) 2297 act = quantiles(map(f, data), n=n) 2298 self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act))) 2299 # Q2 agrees with median() 2300 for k in range(2, 60): 2301 data = random.choices(range(100), k=k) 2302 q1, q2, q3 = quantiles(data) 2303 self.assertEqual(q2, statistics.median(data)) 2304 2305 def test_specific_cases_inclusive(self): 2306 # Match results computed by hand and cross-checked 2307 # against the PERCENTILE.INC function in MS Excel 2308 # and against the quantile() function in SciPy. 2309 quantiles = statistics.quantiles 2310 data = [100, 200, 400, 800] 2311 random.shuffle(data) 2312 for n, expected in [ 2313 (1, []), 2314 (2, [300.0]), 2315 (3, [200.0, 400.0]), 2316 (4, [175.0, 300.0, 500.0]), 2317 (5, [160.0, 240.0, 360.0, 560.0]), 2318 (6, [150.0, 200.0, 300.0, 400.0, 600.0]), 2319 (8, [137.5, 175, 225.0, 300.0, 375.0, 500.0,650.0]), 2320 (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]), 2321 (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0, 2322 500.0, 600.0, 700.0]), 2323 (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0, 2324 400.0, 480.0, 560.0, 640.0, 720.0]), 2325 ]: 2326 self.assertEqual(expected, quantiles(data, n=n, method="inclusive")) 2327 self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1) 2328 # Preserve datatype when possible 2329 for datatype in (float, Decimal, Fraction): 2330 result = quantiles(map(datatype, data), n=n, method="inclusive") 2331 self.assertTrue(all(type(x) == datatype) for x in result) 2332 self.assertEqual(result, list(map(datatype, expected))) 2333 # Invariant under translation and scaling 2334 def f(x): 2335 return 3.5 * x - 1234.675 2336 exp = list(map(f, expected)) 2337 act = quantiles(map(f, data), n=n, method="inclusive") 2338 self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act))) 2339 # Natural deciles 2340 self.assertEqual(quantiles([0, 100], n=10, method='inclusive'), 2341 [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]) 2342 self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'), 2343 [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]) 2344 # Whenever n is smaller than the number of data points, running 2345 # method='inclusive' should give the same result as method='exclusive' 2346 # after the two included extreme points are removed. 2347 data = [random.randrange(10_000) for i in range(501)] 2348 actual = quantiles(data, n=32, method='inclusive') 2349 data.remove(min(data)) 2350 data.remove(max(data)) 2351 expected = quantiles(data, n=32) 2352 self.assertEqual(expected, actual) 2353 # Q2 agrees with median() 2354 for k in range(2, 60): 2355 data = random.choices(range(100), k=k) 2356 q1, q2, q3 = quantiles(data, method='inclusive') 2357 self.assertEqual(q2, statistics.median(data)) 2358 2359 def test_equal_inputs(self): 2360 quantiles = statistics.quantiles 2361 for n in range(2, 10): 2362 data = [10.0] * n 2363 self.assertEqual(quantiles(data), [10.0, 10.0, 10.0]) 2364 self.assertEqual(quantiles(data, method='inclusive'), 2365 [10.0, 10.0, 10.0]) 2366 2367 def test_equal_sized_groups(self): 2368 quantiles = statistics.quantiles 2369 total = 10_000 2370 data = [random.expovariate(0.2) for i in range(total)] 2371 while len(set(data)) != total: 2372 data.append(random.expovariate(0.2)) 2373 data.sort() 2374 2375 # Cases where the group size exactly divides the total 2376 for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000): 2377 group_size = total // n 2378 self.assertEqual( 2379 [bisect.bisect(data, q) for q in quantiles(data, n=n)], 2380 list(range(group_size, total, group_size))) 2381 2382 # When the group sizes can't be exactly equal, they should 2383 # differ by no more than one 2384 for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769): 2385 group_sizes = {total // n, total // n + 1} 2386 pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)] 2387 sizes = {q - p for p, q in zip(pos, pos[1:])} 2388 self.assertTrue(sizes <= group_sizes) 2389 2390 def test_error_cases(self): 2391 quantiles = statistics.quantiles 2392 StatisticsError = statistics.StatisticsError 2393 with self.assertRaises(TypeError): 2394 quantiles() # Missing arguments 2395 with self.assertRaises(TypeError): 2396 quantiles([10, 20, 30], 13, n=4) # Too many arguments 2397 with self.assertRaises(TypeError): 2398 quantiles([10, 20, 30], 4) # n is a positional argument 2399 with self.assertRaises(StatisticsError): 2400 quantiles([10, 20, 30], n=0) # n is zero 2401 with self.assertRaises(StatisticsError): 2402 quantiles([10, 20, 30], n=-1) # n is negative 2403 with self.assertRaises(TypeError): 2404 quantiles([10, 20, 30], n=1.5) # n is not an integer 2405 with self.assertRaises(ValueError): 2406 quantiles([10, 20, 30], method='X') # method is unknown 2407 with self.assertRaises(StatisticsError): 2408 quantiles([10], n=4) # not enough data points 2409 with self.assertRaises(TypeError): 2410 quantiles([10, None, 30], n=4) # data is non-numeric 2411 2412 2413class TestBivariateStatistics(unittest.TestCase): 2414 2415 def test_unequal_size_error(self): 2416 for x, y in [ 2417 ([1, 2, 3], [1, 2]), 2418 ([1, 2], [1, 2, 3]), 2419 ]: 2420 with self.assertRaises(statistics.StatisticsError): 2421 statistics.covariance(x, y) 2422 with self.assertRaises(statistics.StatisticsError): 2423 statistics.correlation(x, y) 2424 with self.assertRaises(statistics.StatisticsError): 2425 statistics.linear_regression(x, y) 2426 2427 def test_small_sample_error(self): 2428 for x, y in [ 2429 ([], []), 2430 ([], [1, 2,]), 2431 ([1, 2,], []), 2432 ([1,], [1,]), 2433 ([1,], [1, 2,]), 2434 ([1, 2,], [1,]), 2435 ]: 2436 with self.assertRaises(statistics.StatisticsError): 2437 statistics.covariance(x, y) 2438 with self.assertRaises(statistics.StatisticsError): 2439 statistics.correlation(x, y) 2440 with self.assertRaises(statistics.StatisticsError): 2441 statistics.linear_regression(x, y) 2442 2443 2444class TestCorrelationAndCovariance(unittest.TestCase): 2445 2446 def test_results(self): 2447 for x, y, result in [ 2448 ([1, 2, 3], [1, 2, 3], 1), 2449 ([1, 2, 3], [-1, -2, -3], -1), 2450 ([1, 2, 3], [3, 2, 1], -1), 2451 ([1, 2, 3], [1, 2, 1], 0), 2452 ([1, 2, 3], [1, 3, 2], 0.5), 2453 ]: 2454 self.assertAlmostEqual(statistics.correlation(x, y), result) 2455 self.assertAlmostEqual(statistics.covariance(x, y), result) 2456 2457 def test_different_scales(self): 2458 x = [1, 2, 3] 2459 y = [10, 30, 20] 2460 self.assertAlmostEqual(statistics.correlation(x, y), 0.5) 2461 self.assertAlmostEqual(statistics.covariance(x, y), 5) 2462 2463 y = [.1, .2, .3] 2464 self.assertAlmostEqual(statistics.correlation(x, y), 1) 2465 self.assertAlmostEqual(statistics.covariance(x, y), 0.1) 2466 2467 2468class TestLinearRegression(unittest.TestCase): 2469 2470 def test_constant_input_error(self): 2471 x = [1, 1, 1,] 2472 y = [1, 2, 3,] 2473 with self.assertRaises(statistics.StatisticsError): 2474 statistics.linear_regression(x, y) 2475 2476 def test_results(self): 2477 for x, y, true_intercept, true_slope in [ 2478 ([1, 2, 3], [0, 0, 0], 0, 0), 2479 ([1, 2, 3], [1, 2, 3], 0, 1), 2480 ([1, 2, 3], [100, 100, 100], 100, 0), 2481 ([1, 2, 3], [12, 14, 16], 10, 2), 2482 ([1, 2, 3], [-1, -2, -3], 0, -1), 2483 ([1, 2, 3], [21, 22, 23], 20, 1), 2484 ([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1), 2485 ]: 2486 slope, intercept = statistics.linear_regression(x, y) 2487 self.assertAlmostEqual(intercept, true_intercept) 2488 self.assertAlmostEqual(slope, true_slope) 2489 2490 2491class TestNormalDist: 2492 2493 # General note on precision: The pdf(), cdf(), and overlap() methods 2494 # depend on functions in the math libraries that do not make 2495 # explicit accuracy guarantees. Accordingly, some of the accuracy 2496 # tests below may fail if the underlying math functions are 2497 # inaccurate. There isn't much we can do about this short of 2498 # implementing our own implementations from scratch. 2499 2500 def test_slots(self): 2501 nd = self.module.NormalDist(300, 23) 2502 with self.assertRaises(TypeError): 2503 vars(nd) 2504 self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma')) 2505 2506 def test_instantiation_and_attributes(self): 2507 nd = self.module.NormalDist(500, 17) 2508 self.assertEqual(nd.mean, 500) 2509 self.assertEqual(nd.stdev, 17) 2510 self.assertEqual(nd.variance, 17**2) 2511 2512 # default arguments 2513 nd = self.module.NormalDist() 2514 self.assertEqual(nd.mean, 0) 2515 self.assertEqual(nd.stdev, 1) 2516 self.assertEqual(nd.variance, 1**2) 2517 2518 # error case: negative sigma 2519 with self.assertRaises(self.module.StatisticsError): 2520 self.module.NormalDist(500, -10) 2521 2522 # verify that subclass type is honored 2523 class NewNormalDist(self.module.NormalDist): 2524 pass 2525 nnd = NewNormalDist(200, 5) 2526 self.assertEqual(type(nnd), NewNormalDist) 2527 2528 def test_alternative_constructor(self): 2529 NormalDist = self.module.NormalDist 2530 data = [96, 107, 90, 92, 110] 2531 # list input 2532 self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9)) 2533 # tuple input 2534 self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9)) 2535 # iterator input 2536 self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9)) 2537 # error cases 2538 with self.assertRaises(self.module.StatisticsError): 2539 NormalDist.from_samples([]) # empty input 2540 with self.assertRaises(self.module.StatisticsError): 2541 NormalDist.from_samples([10]) # only one input 2542 2543 # verify that subclass type is honored 2544 class NewNormalDist(NormalDist): 2545 pass 2546 nnd = NewNormalDist.from_samples(data) 2547 self.assertEqual(type(nnd), NewNormalDist) 2548 2549 def test_sample_generation(self): 2550 NormalDist = self.module.NormalDist 2551 mu, sigma = 10_000, 3.0 2552 X = NormalDist(mu, sigma) 2553 n = 1_000 2554 data = X.samples(n) 2555 self.assertEqual(len(data), n) 2556 self.assertEqual(set(map(type, data)), {float}) 2557 # mean(data) expected to fall within 8 standard deviations 2558 xbar = self.module.mean(data) 2559 self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8) 2560 2561 # verify that seeding makes reproducible sequences 2562 n = 100 2563 data1 = X.samples(n, seed='happiness and joy') 2564 data2 = X.samples(n, seed='trouble and despair') 2565 data3 = X.samples(n, seed='happiness and joy') 2566 data4 = X.samples(n, seed='trouble and despair') 2567 self.assertEqual(data1, data3) 2568 self.assertEqual(data2, data4) 2569 self.assertNotEqual(data1, data2) 2570 2571 def test_pdf(self): 2572 NormalDist = self.module.NormalDist 2573 X = NormalDist(100, 15) 2574 # Verify peak around center 2575 self.assertLess(X.pdf(99), X.pdf(100)) 2576 self.assertLess(X.pdf(101), X.pdf(100)) 2577 # Test symmetry 2578 for i in range(50): 2579 self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i)) 2580 # Test vs CDF 2581 dx = 2.0 ** -10 2582 for x in range(90, 111): 2583 est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx 2584 self.assertAlmostEqual(X.pdf(x), est_pdf, places=4) 2585 # Test vs table of known values -- CRC 26th Edition 2586 Z = NormalDist() 2587 for x, px in enumerate([ 2588 0.3989, 0.3989, 0.3989, 0.3988, 0.3986, 2589 0.3984, 0.3982, 0.3980, 0.3977, 0.3973, 2590 0.3970, 0.3965, 0.3961, 0.3956, 0.3951, 2591 0.3945, 0.3939, 0.3932, 0.3925, 0.3918, 2592 0.3910, 0.3902, 0.3894, 0.3885, 0.3876, 2593 0.3867, 0.3857, 0.3847, 0.3836, 0.3825, 2594 0.3814, 0.3802, 0.3790, 0.3778, 0.3765, 2595 0.3752, 0.3739, 0.3725, 0.3712, 0.3697, 2596 0.3683, 0.3668, 0.3653, 0.3637, 0.3621, 2597 0.3605, 0.3589, 0.3572, 0.3555, 0.3538, 2598 ]): 2599 self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4) 2600 self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4) 2601 # Error case: variance is zero 2602 Y = NormalDist(100, 0) 2603 with self.assertRaises(self.module.StatisticsError): 2604 Y.pdf(90) 2605 # Special values 2606 self.assertEqual(X.pdf(float('-Inf')), 0.0) 2607 self.assertEqual(X.pdf(float('Inf')), 0.0) 2608 self.assertTrue(math.isnan(X.pdf(float('NaN')))) 2609 2610 def test_cdf(self): 2611 NormalDist = self.module.NormalDist 2612 X = NormalDist(100, 15) 2613 cdfs = [X.cdf(x) for x in range(1, 200)] 2614 self.assertEqual(set(map(type, cdfs)), {float}) 2615 # Verify montonic 2616 self.assertEqual(cdfs, sorted(cdfs)) 2617 # Verify center (should be exact) 2618 self.assertEqual(X.cdf(100), 0.50) 2619 # Check against a table of known values 2620 # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative 2621 Z = NormalDist() 2622 for z, cum_prob in [ 2623 (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798), 2624 (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930), 2625 (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900), 2626 (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807), 2627 (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998), 2628 ]: 2629 self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5) 2630 self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5) 2631 # Error case: variance is zero 2632 Y = NormalDist(100, 0) 2633 with self.assertRaises(self.module.StatisticsError): 2634 Y.cdf(90) 2635 # Special values 2636 self.assertEqual(X.cdf(float('-Inf')), 0.0) 2637 self.assertEqual(X.cdf(float('Inf')), 1.0) 2638 self.assertTrue(math.isnan(X.cdf(float('NaN')))) 2639 2640 @support.skip_if_pgo_task 2641 def test_inv_cdf(self): 2642 NormalDist = self.module.NormalDist 2643 2644 # Center case should be exact. 2645 iq = NormalDist(100, 15) 2646 self.assertEqual(iq.inv_cdf(0.50), iq.mean) 2647 2648 # Test versus a published table of known percentage points. 2649 # See the second table at the bottom of the page here: 2650 # http://people.bath.ac.uk/masss/tables/normaltable.pdf 2651 Z = NormalDist() 2652 pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891, 2653 4.417, 4.892, 5.327, 5.731, 6.109), 2654 2.5: (0.674, 1.960, 2.807, 3.481, 4.056, 2655 4.565, 5.026, 5.451, 5.847, 6.219), 2656 1.0: (1.282, 2.326, 3.090, 3.719, 4.265, 2657 4.753, 5.199, 5.612, 5.998, 6.361)} 2658 for base, row in pp.items(): 2659 for exp, x in enumerate(row, start=1): 2660 p = base * 10.0 ** (-exp) 2661 self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3) 2662 p = 1.0 - p 2663 self.assertAlmostEqual(Z.inv_cdf(p), x, places=3) 2664 2665 # Match published example for MS Excel 2666 # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13 2667 self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002) 2668 2669 # One million equally spaced probabilities 2670 n = 2**20 2671 for p in range(1, n): 2672 p /= n 2673 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 2674 2675 # One hundred ever smaller probabilities to test tails out to 2676 # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50 2677 for e in range(1, 51): 2678 p = 2.0 ** (-e) 2679 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 2680 p = 1.0 - p 2681 self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p) 2682 2683 # Now apply cdf() first. Near the tails, the round-trip loses 2684 # precision and is ill-conditioned (small changes in the inputs 2685 # give large changes in the output), so only check to 5 places. 2686 for x in range(200): 2687 self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5) 2688 2689 # Error cases: 2690 with self.assertRaises(self.module.StatisticsError): 2691 iq.inv_cdf(0.0) # p is zero 2692 with self.assertRaises(self.module.StatisticsError): 2693 iq.inv_cdf(-0.1) # p under zero 2694 with self.assertRaises(self.module.StatisticsError): 2695 iq.inv_cdf(1.0) # p is one 2696 with self.assertRaises(self.module.StatisticsError): 2697 iq.inv_cdf(1.1) # p over one 2698 with self.assertRaises(self.module.StatisticsError): 2699 iq = NormalDist(100, 0) # sigma is zero 2700 iq.inv_cdf(0.5) 2701 2702 # Special values 2703 self.assertTrue(math.isnan(Z.inv_cdf(float('NaN')))) 2704 2705 def test_quantiles(self): 2706 # Quartiles of a standard normal distribution 2707 Z = self.module.NormalDist() 2708 for n, expected in [ 2709 (1, []), 2710 (2, [0.0]), 2711 (3, [-0.4307, 0.4307]), 2712 (4 ,[-0.6745, 0.0, 0.6745]), 2713 ]: 2714 actual = Z.quantiles(n=n) 2715 self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001) 2716 for e, a in zip(expected, actual))) 2717 2718 def test_overlap(self): 2719 NormalDist = self.module.NormalDist 2720 2721 # Match examples from Imman and Bradley 2722 for X1, X2, published_result in [ 2723 (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258), 2724 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993), 2725 ]: 2726 self.assertAlmostEqual(X1.overlap(X2), published_result, places=4) 2727 self.assertAlmostEqual(X2.overlap(X1), published_result, places=4) 2728 2729 # Check against integration of the PDF 2730 def overlap_numeric(X, Y, *, steps=8_192, z=5): 2731 'Numerical integration cross-check for overlap() ' 2732 fsum = math.fsum 2733 center = (X.mean + Y.mean) / 2.0 2734 width = z * max(X.stdev, Y.stdev) 2735 start = center - width 2736 dx = 2.0 * width / steps 2737 x_arr = [start + i*dx for i in range(steps)] 2738 xp = list(map(X.pdf, x_arr)) 2739 yp = list(map(Y.pdf, x_arr)) 2740 total = max(fsum(xp), fsum(yp)) 2741 return fsum(map(min, xp, yp)) / total 2742 2743 for X1, X2 in [ 2744 # Examples from Imman and Bradley 2745 (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)), 2746 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)), 2747 # Example from https://www.rasch.org/rmt/rmt101r.htm 2748 (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)), 2749 # Gender heights from http://www.usablestats.com/lessons/normal 2750 (NormalDist(70, 4), NormalDist(65, 3.5)), 2751 # Misc cases with equal standard deviations 2752 (NormalDist(100, 15), NormalDist(110, 15)), 2753 (NormalDist(-100, 15), NormalDist(110, 15)), 2754 (NormalDist(-100, 15), NormalDist(-110, 15)), 2755 # Misc cases with unequal standard deviations 2756 (NormalDist(100, 12), NormalDist(100, 15)), 2757 (NormalDist(100, 12), NormalDist(110, 15)), 2758 (NormalDist(100, 12), NormalDist(150, 15)), 2759 (NormalDist(100, 12), NormalDist(150, 35)), 2760 # Misc cases with small values 2761 (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)), 2762 (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)), 2763 (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)), 2764 ]: 2765 self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5) 2766 self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5) 2767 2768 # Error cases 2769 X = NormalDist() 2770 with self.assertRaises(TypeError): 2771 X.overlap() # too few arguments 2772 with self.assertRaises(TypeError): 2773 X.overlap(X, X) # too may arguments 2774 with self.assertRaises(TypeError): 2775 X.overlap(None) # right operand not a NormalDist 2776 with self.assertRaises(self.module.StatisticsError): 2777 X.overlap(NormalDist(1, 0)) # right operand sigma is zero 2778 with self.assertRaises(self.module.StatisticsError): 2779 NormalDist(1, 0).overlap(X) # left operand sigma is zero 2780 2781 def test_zscore(self): 2782 NormalDist = self.module.NormalDist 2783 X = NormalDist(100, 15) 2784 self.assertEqual(X.zscore(142), 2.8) 2785 self.assertEqual(X.zscore(58), -2.8) 2786 self.assertEqual(X.zscore(100), 0.0) 2787 with self.assertRaises(TypeError): 2788 X.zscore() # too few arguments 2789 with self.assertRaises(TypeError): 2790 X.zscore(1, 1) # too may arguments 2791 with self.assertRaises(TypeError): 2792 X.zscore(None) # non-numeric type 2793 with self.assertRaises(self.module.StatisticsError): 2794 NormalDist(1, 0).zscore(100) # sigma is zero 2795 2796 def test_properties(self): 2797 X = self.module.NormalDist(100, 15) 2798 self.assertEqual(X.mean, 100) 2799 self.assertEqual(X.median, 100) 2800 self.assertEqual(X.mode, 100) 2801 self.assertEqual(X.stdev, 15) 2802 self.assertEqual(X.variance, 225) 2803 2804 def test_same_type_addition_and_subtraction(self): 2805 NormalDist = self.module.NormalDist 2806 X = NormalDist(100, 12) 2807 Y = NormalDist(40, 5) 2808 self.assertEqual(X + Y, NormalDist(140, 13)) # __add__ 2809 self.assertEqual(X - Y, NormalDist(60, 13)) # __sub__ 2810 2811 def test_translation_and_scaling(self): 2812 NormalDist = self.module.NormalDist 2813 X = NormalDist(100, 15) 2814 y = 10 2815 self.assertEqual(+X, NormalDist(100, 15)) # __pos__ 2816 self.assertEqual(-X, NormalDist(-100, 15)) # __neg__ 2817 self.assertEqual(X + y, NormalDist(110, 15)) # __add__ 2818 self.assertEqual(y + X, NormalDist(110, 15)) # __radd__ 2819 self.assertEqual(X - y, NormalDist(90, 15)) # __sub__ 2820 self.assertEqual(y - X, NormalDist(-90, 15)) # __rsub__ 2821 self.assertEqual(X * y, NormalDist(1000, 150)) # __mul__ 2822 self.assertEqual(y * X, NormalDist(1000, 150)) # __rmul__ 2823 self.assertEqual(X / y, NormalDist(10, 1.5)) # __truediv__ 2824 with self.assertRaises(TypeError): # __rtruediv__ 2825 y / X 2826 2827 def test_unary_operations(self): 2828 NormalDist = self.module.NormalDist 2829 X = NormalDist(100, 12) 2830 Y = +X 2831 self.assertIsNot(X, Y) 2832 self.assertEqual(X.mean, Y.mean) 2833 self.assertEqual(X.stdev, Y.stdev) 2834 Y = -X 2835 self.assertIsNot(X, Y) 2836 self.assertEqual(X.mean, -Y.mean) 2837 self.assertEqual(X.stdev, Y.stdev) 2838 2839 def test_equality(self): 2840 NormalDist = self.module.NormalDist 2841 nd1 = NormalDist() 2842 nd2 = NormalDist(2, 4) 2843 nd3 = NormalDist() 2844 nd4 = NormalDist(2, 4) 2845 nd5 = NormalDist(2, 8) 2846 nd6 = NormalDist(8, 4) 2847 self.assertNotEqual(nd1, nd2) 2848 self.assertEqual(nd1, nd3) 2849 self.assertEqual(nd2, nd4) 2850 self.assertNotEqual(nd2, nd5) 2851 self.assertNotEqual(nd2, nd6) 2852 2853 # Test NotImplemented when types are different 2854 class A: 2855 def __eq__(self, other): 2856 return 10 2857 a = A() 2858 self.assertEqual(nd1.__eq__(a), NotImplemented) 2859 self.assertEqual(nd1 == a, 10) 2860 self.assertEqual(a == nd1, 10) 2861 2862 # All subclasses to compare equal giving the same behavior 2863 # as list, tuple, int, float, complex, str, dict, set, etc. 2864 class SizedNormalDist(NormalDist): 2865 def __init__(self, mu, sigma, n): 2866 super().__init__(mu, sigma) 2867 self.n = n 2868 s = SizedNormalDist(100, 15, 57) 2869 nd4 = NormalDist(100, 15) 2870 self.assertEqual(s, nd4) 2871 2872 # Don't allow duck type equality because we wouldn't 2873 # want a lognormal distribution to compare equal 2874 # to a normal distribution with the same parameters 2875 class LognormalDist: 2876 def __init__(self, mu, sigma): 2877 self.mu = mu 2878 self.sigma = sigma 2879 lnd = LognormalDist(100, 15) 2880 nd = NormalDist(100, 15) 2881 self.assertNotEqual(nd, lnd) 2882 2883 def test_pickle_and_copy(self): 2884 nd = self.module.NormalDist(37.5, 5.625) 2885 nd1 = copy.copy(nd) 2886 self.assertEqual(nd, nd1) 2887 nd2 = copy.deepcopy(nd) 2888 self.assertEqual(nd, nd2) 2889 nd3 = pickle.loads(pickle.dumps(nd)) 2890 self.assertEqual(nd, nd3) 2891 2892 def test_hashability(self): 2893 ND = self.module.NormalDist 2894 s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)} 2895 self.assertEqual(len(s), 3) 2896 2897 def test_repr(self): 2898 nd = self.module.NormalDist(37.5, 5.625) 2899 self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)') 2900 2901# Swapping the sys.modules['statistics'] is to solving the 2902# _pickle.PicklingError: 2903# Can't pickle <class 'statistics.NormalDist'>: 2904# it's not the same object as statistics.NormalDist 2905class TestNormalDistPython(unittest.TestCase, TestNormalDist): 2906 module = py_statistics 2907 def setUp(self): 2908 sys.modules['statistics'] = self.module 2909 2910 def tearDown(self): 2911 sys.modules['statistics'] = statistics 2912 2913 2914@unittest.skipUnless(c_statistics, 'requires _statistics') 2915class TestNormalDistC(unittest.TestCase, TestNormalDist): 2916 module = c_statistics 2917 def setUp(self): 2918 sys.modules['statistics'] = self.module 2919 2920 def tearDown(self): 2921 sys.modules['statistics'] = statistics 2922 2923 2924# === Run tests === 2925 2926def load_tests(loader, tests, ignore): 2927 """Used for doctest/unittest integration.""" 2928 tests.addTests(doctest.DocTestSuite()) 2929 return tests 2930 2931 2932if __name__ == "__main__": 2933 unittest.main() 2934