"""Tests for the asyncio support in unittest.mock.

Covers how patch/patch.object, create_autospec, spec/spec_set, MagicMock and
AsyncMock handle coroutine functions, async context managers, async
iterators, and the await-related assertion helpers (assert_awaited & co).
"""
import asyncio
import gc
import inspect
import re
import unittest
from contextlib import contextmanager

from asyncio import run, iscoroutinefunction
from unittest import IsolatedAsyncioTestCase
from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
                           create_autospec, sentinel, _CallList)


def tearDownModule():
    # Reset the event loop policy to the default once this module's tests
    # have finished.
    asyncio.set_event_loop_policy(None)


class AsyncClass:
    """Spec/patch target mixing sync and async (instance/class/static) methods."""
    def __init__(self): pass
    async def async_method(self): pass
    def normal_method(self): pass

    @classmethod
    async def async_class_method(cls): pass

    @staticmethod
    async def async_static_method(): pass


class AwaitableClass:
    """Defines __await__ but has no coroutine methods."""
    def __await__(self): yield

async def async_func(): pass

async def async_func_args(a, b, *, c): pass

def normal_func(): pass

class NormalClass(object):
    """Spec/patch target with only synchronous methods."""
    def a(self): pass


# Dotted import paths of the sample classes, used as patch() targets.
async_foo_name = f'{__name__}.AsyncClass'
normal_foo_name = f'{__name__}.NormalClass'


@contextmanager
def assertNeverAwaited(test):
    """Assert that the wrapped code leaves a coroutine that is never awaited.

    *test* is the running TestCase; the "was never awaited" RuntimeWarning
    must be emitted before the context exits.
    """
    with test.assertWarnsRegex(RuntimeWarning, "was never awaited$"):
        yield
        # In non-CPython implementations of Python, this is needed because timely
        # deallocation is not guaranteed by the garbage collector.
        gc.collect()


class AsyncPatchDecoratorTest(unittest.TestCase):
    """patch/patch.object used as decorators on async targets."""

    def test_is_coroutine_function_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertTrue(iscoroutinefunction(mock_method))
        test_async()

    def test_is_async_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            run(m)

        @patch(f'{async_foo_name}.async_method')
        def test_no_parent_attribute(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            run(m)

        test_async()
        test_no_parent_attribute()

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_staticmethod(self):
        @patch.object(AsyncClass, 'async_static_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_classmethod(self):
        @patch.object(AsyncClass, 'async_class_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_patch(self):
        @patch(f"{__name__}.async_func", return_value=1)
        @patch(f"{__name__}.async_func_args", return_value=2)
        async def test_async(func_args_mock, func_mock):
            self.assertEqual(func_args_mock._mock_name, "async_func_args")
            self.assertEqual(func_mock._mock_name, "async_func")

            self.assertIsInstance(async_func, AsyncMock)
            self.assertIsInstance(async_func_args, AsyncMock)

            self.assertEqual(await async_func(), 1)
            self.assertEqual(await async_func_args(1, 2, c=3), 2)

        run(test_async())
        # Once the patch is undone the real coroutine function is restored.
        self.assertTrue(inspect.iscoroutinefunction(async_func))


class AsyncPatchCMTest(unittest.TestCase):
    """patch/patch.object used as context managers on async targets."""

    def test_is_async_function_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertTrue(iscoroutinefunction(mock_method))

        test_async()

    def test_is_async_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                m = mock_method()
                self.assertTrue(inspect.isawaitable(m))
                run(m)

        test_async()

    def test_is_AsyncMock_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_cm(self):
        async def test_async():
            with patch(f"{__name__}.async_func", AsyncMock()):
                self.assertIsInstance(async_func, AsyncMock)
            self.assertTrue(inspect.iscoroutinefunction(async_func))

        run(test_async())


class AsyncMockTest(unittest.TestCase):
    """Basic AsyncMock behaviour: coroutine-function detection and awaitables."""

    def test_iscoroutinefunction_default(self):
        mock = AsyncMock()
        self.assertTrue(iscoroutinefunction(mock))

    def test_iscoroutinefunction_function(self):
        async def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_isawaitable(self):
        mock = AsyncMock()
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)
        self.assertIn('assert_awaited', dir(mock))

    def test_iscoroutinefunction_normal_function(self):
        def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_future_isfuture(self):
        # Build the future from a private loop without installing it as the
        # thread's current event loop: the loop is closed right away, and a
        # closed loop left installed would leak into later tests.
        loop = asyncio.new_event_loop()
        try:
            fut = loop.create_future()
        finally:
            loop.close()
        mock = AsyncMock(fut)
        self.assertIsInstance(mock, asyncio.Future)


class AsyncAutospecTest(unittest.TestCase):
    """create_autospec and patch(autospec=True) with async callables."""

    def test_is_AsyncMock_patch(self):
        @patch(async_foo_name, autospec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method.async_method, AsyncMock)
            self.assertIsInstance(mock_method, MagicMock)

        @patch(async_foo_name, autospec=True)
        def test_normal_method(mock_method):
            self.assertIsInstance(mock_method.normal_method, MagicMock)

        test_async()
        test_normal_method()

    def test_create_autospec_instance(self):
        with self.assertRaises(RuntimeError):
            create_autospec(async_func, instance=True)

    @unittest.skip('Broken test from https://bugs.python.org/issue37251')
    def test_create_autospec_awaitable_class(self):
        self.assertIsInstance(create_autospec(AwaitableClass), AsyncMock)

    def test_create_autospec(self):
        spec = create_autospec(async_func_args)
        awaitable = spec(1, 2, c=3)
        async def main():
            await awaitable

        # Calling alone must not count as an await.
        self.assertEqual(spec.await_count, 0)
        self.assertIsNone(spec.await_args)
        self.assertEqual(spec.await_args_list, [])
        spec.assert_not_awaited()

        run(main())

        self.assertTrue(iscoroutinefunction(spec))
        self.assertTrue(asyncio.iscoroutine(awaitable))
        self.assertEqual(spec.await_count, 1)
        self.assertEqual(spec.await_args, call(1, 2, c=3))
        self.assertEqual(spec.await_args_list, [call(1, 2, c=3)])
        spec.assert_awaited_once()
        spec.assert_awaited_once_with(1, 2, c=3)
        spec.assert_awaited_with(1, 2, c=3)
        spec.assert_awaited()

        with self.assertRaises(AssertionError):
            spec.assert_any_await(e=1)

    def test_patch_with_autospec(self):

        async def test_async():
            with patch(f"{__name__}.async_func_args", autospec=True) as mock_method:
                awaitable = mock_method(1, 2, c=3)
                self.assertIsInstance(mock_method.mock, AsyncMock)

                self.assertTrue(iscoroutinefunction(mock_method))
                self.assertTrue(asyncio.iscoroutine(awaitable))
                self.assertTrue(inspect.isawaitable(awaitable))

                # Verify the default values during mock setup
                self.assertEqual(mock_method.await_count, 0)
                self.assertEqual(mock_method.await_args_list, [])
                self.assertIsNone(mock_method.await_args)
                mock_method.assert_not_awaited()

                await awaitable

            self.assertEqual(mock_method.await_count, 1)
            self.assertEqual(mock_method.await_args, call(1, 2, c=3))
            self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)])
            mock_method.assert_awaited_once()
            mock_method.assert_awaited_once_with(1, 2, c=3)
            mock_method.assert_awaited_with(1, 2, c=3)
            mock_method.assert_awaited()

            # reset_mock() must also clear the await bookkeeping.
            mock_method.reset_mock()
            self.assertEqual(mock_method.await_count, 0)
            self.assertIsNone(mock_method.await_args)
            self.assertEqual(mock_method.await_args_list, [])

        run(test_async())


class AsyncSpecTest(unittest.TestCase):
    """Mock types created with a spec that contains async callables."""

    def test_spec_normal_methods_on_class(self):
        def inner_test(mock_type):
            mock = mock_type(AsyncClass)
            self.assertIsInstance(mock.async_method, AsyncMock)
            self.assertIsInstance(mock.normal_method, MagicMock)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test method types with {mock_type}"):
                inner_test(mock_type)

    def test_spec_normal_methods_on_class_with_mock(self):
        mock = Mock(AsyncClass)
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, Mock)

    def test_spec_mock_type_kw(self):
        def inner_test(mock_type):
            async_mock = mock_type(spec=async_func)
            self.assertIsInstance(async_mock, mock_type)
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(spec=normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec kwarg with {mock_type}"):
                inner_test(mock_type)

    def test_spec_mock_type_positional(self):
        def inner_test(mock_type):
            async_mock = mock_type(async_func)
            self.assertIsInstance(async_mock, mock_type)
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec positional with {mock_type}"):
                inner_test(mock_type)

    def test_spec_as_normal_kw_AsyncMock(self):
        mock = AsyncMock(spec=normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_as_normal_positional_AsyncMock(self):
        mock = AsyncMock(normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        run(m)

    def test_spec_async_mock(self):
        @patch.object(AsyncClass, 'async_method', spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_spec_parent_not_async_attribute_is(self):
        @patch(async_foo_name, spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertIsInstance(mock_method.async_method, AsyncMock)

        test_async()

    def test_target_async_spec_not(self):
        @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
        def test_async_attribute(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertFalse(inspect.iscoroutine(mock_method))
            self.assertFalse(inspect.isawaitable(mock_method))

        test_async_attribute()

    def test_target_not_async_spec_is(self):
        @patch.object(NormalClass, 'a', spec=async_func)
        def test_attribute_not_async_spec_is(mock_async_func):
            self.assertIsInstance(mock_async_func, AsyncMock)
        test_attribute_not_async_spec_is()

    def test_spec_async_attributes(self):
        @patch(normal_foo_name, spec=AsyncClass)
        def test_async_attributes_coroutines(MockNormalClass):
            self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
            self.assertIsInstance(MockNormalClass, MagicMock)

        test_async_attributes_coroutines()


class AsyncSpecSetTest(unittest.TestCase):
    """Mock types created with spec_set that contains async callables."""

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method', spec_set=True)
        def test_async(async_method):
            self.assertIsInstance(async_method, AsyncMock)
        test_async()

    def test_is_async_AsyncMock(self):
        mock = AsyncMock(spec_set=AsyncClass.async_method)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertIsInstance(mock, AsyncMock)

    def test_is_child_AsyncMock(self):
        mock = MagicMock(spec_set=AsyncClass)
        self.assertTrue(iscoroutinefunction(mock.async_method))
        self.assertFalse(iscoroutinefunction(mock.normal_method))
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, MagicMock)
        self.assertIsInstance(mock, MagicMock)

    def test_magicmock_lambda_spec(self):
        mock_obj = MagicMock()
        mock_obj.mock_func = MagicMock(spec=lambda x: x)

        with patch.object(mock_obj, "mock_func") as cm:
            self.assertIsInstance(cm, MagicMock)


class AsyncArguments(IsolatedAsyncioTestCase):
    """return_value, side_effect and wraps on AsyncMock, awaited in a loop."""

    async def test_add_return_value(self):
        async def addition(self, var): pass

        mock = AsyncMock(addition, return_value=10)
        output = await mock(5)

        self.assertEqual(output, 10)

    async def test_add_side_effect_exception(self):
        async def addition(var): pass
        mock = AsyncMock(addition, side_effect=Exception('err'))
        with self.assertRaises(Exception):
            await mock(5)

    async def test_add_side_effect_coroutine(self):
        async def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_normal_function(self):
        def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_iterable(self):
        vals = [1, 2, 3]
        mock = AsyncMock(side_effect=vals)
        for item in vals:
            self.assertEqual(await mock(), item)

        # An exhausted side_effect iterable surfaces as StopAsyncIteration.
        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_add_side_effect_exception_iterable(self):
        class SampleException(Exception):
            pass

        vals = [1, SampleException("foo")]
        mock = AsyncMock(side_effect=vals)
        self.assertEqual(await mock(), 1)

        # Exception instances in the iterable are raised, not returned.
        with self.assertRaises(SampleException):
            await mock()

    async def test_return_value_AsyncMock(self):
        value = AsyncMock(return_value=10)
        mock = AsyncMock(return_value=value)
        result = await mock()
        self.assertIs(result, value)

    async def test_return_value_awaitable(self):
        fut = asyncio.Future()
        fut.set_result(None)
        mock = AsyncMock(return_value=fut)
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

    async def test_side_effect_awaitable_values(self):
        fut = asyncio.Future()
        fut.set_result(None)

        mock = AsyncMock(side_effect=[fut])
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_side_effect_is_AsyncMock(self):
        effect = AsyncMock(return_value=10)
        mock = AsyncMock(side_effect=effect)

        result = await mock()
        self.assertEqual(result, 10)

    async def test_wraps_coroutine(self):
        value = asyncio.Future()

        ran = False
        async def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_wraps_normal_function(self):
        value = 1

        ran = False
        def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_await_args_list_order(self):
        async_mock = AsyncMock()
        mock2 = async_mock(2)
        mock1 = async_mock(1)
        await mock1
        await mock2
        # await_args_list follows await order; call_args_list follows call order.
        async_mock.assert_has_awaits([call(1), call(2)])
        self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
        self.assertEqual(async_mock.call_args_list, [call(2), call(1)])


class AsyncMagicMethods(unittest.TestCase):
    """Which magic-method attributes are AsyncMock vs MagicMock."""

    def test_async_magic_methods_return_async_mocks(self):
        m_mock = MagicMock()
        self.assertIsInstance(m_mock.__aenter__, AsyncMock)
        self.assertIsInstance(m_mock.__aexit__, AsyncMock)
        self.assertIsInstance(m_mock.__anext__, AsyncMock)
        # __aiter__ is actually a synchronous object
        # so should return a MagicMock
        self.assertIsInstance(m_mock.__aiter__, MagicMock)

    def test_sync_magic_methods_return_magic_mocks(self):
        a_mock = AsyncMock()
        self.assertIsInstance(a_mock.__enter__, MagicMock)
        self.assertIsInstance(a_mock.__exit__, MagicMock)
        self.assertIsInstance(a_mock.__next__, MagicMock)
        self.assertIsInstance(a_mock.__len__, MagicMock)

    def test_magicmock_has_async_magic_methods(self):
        m_mock = MagicMock()
        self.assertTrue(hasattr(m_mock, "__aenter__"))
        self.assertTrue(hasattr(m_mock, "__aexit__"))
        self.assertTrue(hasattr(m_mock, "__anext__"))

    def test_asyncmock_has_sync_magic_methods(self):
        a_mock = AsyncMock()
        self.assertTrue(hasattr(a_mock, "__enter__"))
        self.assertTrue(hasattr(a_mock, "__exit__"))
        self.assertTrue(hasattr(a_mock, "__next__"))
        self.assertTrue(hasattr(a_mock, "__len__"))

    def test_magic_methods_are_async_functions(self):
        m_mock = MagicMock()
        self.assertIsInstance(m_mock.__aenter__, AsyncMock)
        self.assertIsInstance(m_mock.__aexit__, AsyncMock)
        # AsyncMocks are also coroutine functions
        self.assertTrue(iscoroutinefunction(m_mock.__aenter__))
        self.assertTrue(iscoroutinefunction(m_mock.__aexit__))


class AsyncContextManagerTest(unittest.TestCase):
    """Mocks used with `async with`."""

    class WithAsyncContextManager:
        async def __aenter__(self, *args, **kwargs): pass

        async def __aexit__(self, *args, **kwargs): pass

    class WithSyncContextManager:
        def __enter__(self, *args, **kwargs): pass

        def __exit__(self, *args, **kwargs): pass

    class ProductionCode:
        # Example real-world(ish) code
        def __init__(self):
            self.session = None

        async def main(self):
            async with self.session.post('https://python.org') as response:
                val = await response.json()
                return val

    def test_set_return_value_of_aenter(self):
        def inner_test(mock_type):
            pc = self.ProductionCode()
            pc.session = MagicMock(name='sessionmock')
            cm = mock_type(name='magic_cm')
            response = AsyncMock(name='response')
            response.json = AsyncMock(return_value={'json': 123})
            cm.__aenter__.return_value = response
            pc.session.post.return_value = cm
            result = run(pc.main())
            self.assertEqual(result, {'json': 123})

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test set return value of aenter with {mock_type}"):
                inner_test(mock_type)

    def test_mock_supports_async_context_manager(self):
        def inner_test(mock_type):
            called = False
            cm = self.WithAsyncContextManager()
            cm_mock = mock_type(cm)

            async def use_context_manager():
                nonlocal called
                async with cm_mock as result:
                    called = True
                return result

            cm_result = run(use_context_manager())
            self.assertTrue(called)
            self.assertTrue(cm_mock.__aenter__.called)
            self.assertTrue(cm_mock.__aexit__.called)
            cm_mock.__aenter__.assert_awaited()
            cm_mock.__aexit__.assert_awaited()
            # We mock __aenter__ so it does not return self
            self.assertIsNot(cm_mock, cm_result)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test context manager magics with {mock_type}"):
                inner_test(mock_type)

    def test_mock_customize_async_context_manager(self):
        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        expected_result = object()
        mock_instance.__aenter__.return_value = expected_result

        async def use_context_manager():
            async with mock_instance as result:
                return result

        self.assertIs(run(use_context_manager()), expected_result)

    def test_mock_customize_async_context_manager_with_coroutine(self):
        enter_called = False
        exit_called = False

        async def enter_coroutine(*args):
            nonlocal enter_called
            enter_called = True

        async def exit_coroutine(*args):
            nonlocal exit_called
            exit_called = True

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        mock_instance.__aenter__ = enter_coroutine
        mock_instance.__aexit__ = exit_coroutine

        async def use_context_manager():
            async with mock_instance:
                pass

        run(use_context_manager())
        self.assertTrue(enter_called)
        self.assertTrue(exit_called)

    def test_context_manager_raise_exception_by_default(self):
        async def raise_in(context_manager):
            async with context_manager:
                raise TypeError()

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)
        with self.assertRaises(TypeError):
            run(raise_in(mock_instance))


class AsyncIteratorTest(unittest.TestCase):
    """Mocks used with `async for`."""

    class WithAsyncIterator(object):
        def __init__(self):
            self.items = ["foo", "NormalFoo", "baz"]

        def __aiter__(self): pass

        async def __anext__(self): pass

    def test_aiter_set_return_value(self):
        mock_iter = AsyncMock(name="tester")
        mock_iter.__aiter__.return_value = [1, 2, 3]
        async def main():
            return [i async for i in mock_iter]
        result = run(main())
        self.assertEqual(result, [1, 2, 3])

    def test_mock_aiter_and_anext_asyncmock(self):
        def inner_test(mock_type):
            instance = self.WithAsyncIterator()
            mock_instance = mock_type(instance)
            # Check that the mock and the real thing behave the same
            # __aiter__ is not actually async, so not a coroutinefunction
            self.assertFalse(iscoroutinefunction(instance.__aiter__))
            self.assertFalse(iscoroutinefunction(mock_instance.__aiter__))
            # __anext__ is async
            self.assertTrue(iscoroutinefunction(instance.__anext__))
            self.assertTrue(iscoroutinefunction(mock_instance.__anext__))

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test aiter and anext coroutine with {mock_type}"):
                inner_test(mock_type)

    def test_mock_async_for(self):
        async def iterate(iterator):
            accumulator = []
            async for item in iterator:
                accumulator.append(item)

            return accumulator

        expected = ["FOO", "BAR", "BAZ"]

        def test_default(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            self.assertEqual(run(iterate(mock_instance)), [])

        def test_set_return_value(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = expected[:]
            self.assertEqual(run(iterate(mock_instance)), expected)

        def test_set_return_value_iter(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = iter(expected[:])
            self.assertEqual(run(iterate(mock_instance)), expected)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"default value with {mock_type}"):
                test_default(mock_type)

            with self.subTest(f"set return_value with {mock_type}"):
                test_set_return_value(mock_type)

            with self.subTest(f"set return_value iterator with {mock_type}"):
                test_set_return_value_iter(mock_type)


class AsyncMockAssert(unittest.TestCase):
    """The await-tracking assertion helpers of AsyncMock."""

    def setUp(self):
        self.mock = AsyncMock()

    async def _runnable_test(self, *args, **kwargs):
        # Call *and* await self.mock with the given arguments.
        await self.mock(*args, **kwargs)

    async def _await_coroutine(self, coroutine):
        return await coroutine

    def test_assert_called_but_not_awaited(self):
        mock = AsyncMock(AsyncClass)
        with assertNeverAwaited(self):
            mock.async_method()
        self.assertTrue(iscoroutinefunction(mock.async_method))
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.assert_awaited()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

    def test_assert_called_then_awaited(self):
        mock = AsyncMock(AsyncClass)
        mock_coroutine = mock.async_method()
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

        run(self._await_coroutine(mock_coroutine))
        # Assert we haven't re-called the function
        mock.async_method.assert_called_once()
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
        mock.async_method.assert_awaited_once_with()

    def test_assert_called_and_awaited_at_same_time(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        with self.assertRaises(AssertionError):
            self.mock.assert_called()

        run(self._runnable_test())
        self.mock.assert_called_once()
        self.mock.assert_awaited_once()

    def test_assert_called_twice_and_awaited_once(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        # The first call will be awaited so no warning there
        # But this call will never get awaited, so it will warn here
        with assertNeverAwaited(self):
            mock.async_method()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
        mock.async_method.assert_called()
        run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()

    def test_assert_called_once_and_awaited_twice(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        mock.async_method.assert_called_once()
        run(self._await_coroutine(coroutine))
        with self.assertRaises(RuntimeError):
            # Cannot reuse already awaited coroutine
            run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()

    def test_assert_awaited_but_not_called(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()
        with self.assertRaises(TypeError):
            # You cannot await an AsyncMock, it must be a coroutine
            run(self._await_coroutine(self.mock))

        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()

    def test_assert_has_calls_not_awaits(self):
        kalls = [call('foo')]
        with assertNeverAwaited(self):
            self.mock('foo')
        self.mock.assert_has_calls(kalls)
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(kalls)

    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
        with assertNeverAwaited(self):
            self.mock()
        kalls_empty = [('', (), {})]
        self.assertEqual(self.mock.mock_calls, kalls_empty)

        with assertNeverAwaited(self):
            self.mock('foo')
        with assertNeverAwaited(self):
            self.mock('baz')
        mock_kalls = [call(), call('foo'), call('baz')]
        self.assertEqual(self.mock.mock_calls, mock_kalls)

    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
        a_class_mock = AsyncMock(AsyncClass)
        with assertNeverAwaited(self):
            a_class_mock.async_method()
        kalls_empty = [('', (), {})]
        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])

        with assertNeverAwaited(self):
            a_class_mock.async_method(1, 2, 3, a=4, b=5)
        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
        mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
        self.assertEqual(a_class_mock.mock_calls, mock_kalls)

    def test_async_method_calls_recorded(self):
        with assertNeverAwaited(self):
            self.mock.something(3, fish=None)
        with assertNeverAwaited(self):
            self.mock.something_else.something(6, cake=sentinel.Cake)

        self.assertEqual(
            self.mock.method_calls,
            [
                ("something", (3,), {'fish': None}),
                ("something_else.something", (6,), {'cake': sentinel.Cake}),
            ],
            "method calls not recorded correctly")
        self.assertEqual(self.mock.something_else.method_calls,
                         [("something", (6,), {'cake': sentinel.Cake})],
                         "method calls not recorded correctly")

    def test_async_arg_lists(self):
        def assert_attrs(mock):
            names = ('call_args_list', 'method_calls', 'mock_calls')
            for name in names:
                attr = getattr(mock, name)
                self.assertIsInstance(attr, _CallList)
                self.assertIsInstance(attr, list)
                self.assertEqual(attr, [])

        assert_attrs(self.mock)
        with assertNeverAwaited(self):
            self.mock()
        with assertNeverAwaited(self):
            self.mock(1, 2)
        with assertNeverAwaited(self):
            self.mock(a=3)

        self.mock.reset_mock()
        assert_attrs(self.mock)

        a_mock = AsyncMock(AsyncClass)
        with assertNeverAwaited(self):
            a_mock.async_method()
        with assertNeverAwaited(self):
            a_mock.async_method(1, a=3)

        a_mock.reset_mock()
        assert_attrs(a_mock)

    def test_assert_awaited(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        run(self._runnable_test())
        self.mock.assert_awaited()

    def test_assert_awaited_once(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

        run(self._runnable_test())
        self.mock.assert_awaited_once()

        run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

    def test_assert_awaited_with(self):
        msg = 'Not awaited'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        run(self._runnable_test())
        msg = 'expected await not found'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        run(self._runnable_test('foo'))
        self.mock.assert_awaited_with('foo')

        run(self._runnable_test('SomethingElse'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_with('foo')

    def test_assert_awaited_once_with(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

        run(self._runnable_test('foo'))
        self.mock.assert_awaited_once_with('foo')

        run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

    def test_assert_any_wait(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        run(self._runnable_test('foo'))
        self.mock.assert_any_await('foo')

        run(self._runnable_test('SomethingElse'))
        self.mock.assert_any_await('foo')

    def test_assert_has_awaits_no_order(self):
        calls = [call('foo'), call('baz')]

        with self.assertRaises(AssertionError) as cm:
            self.mock.assert_has_awaits(calls)
        self.assertEqual(len(cm.exception.args), 1)

        run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        run(self._runnable_test('baz'))
        self.mock.assert_has_awaits(calls)

        run(self._runnable_test('SomethingElse'))
        self.mock.assert_has_awaits(calls)

    def test_awaits_asserts_with_any(self):
        class Foo:
            def __eq__(self, other): pass

        run(self._runnable_test(Foo(), 1))

        self.mock.assert_has_awaits([call(ANY, 1)])
        self.mock.assert_awaited_with(ANY, 1)
        self.mock.assert_any_await(ANY, 1)

    def test_awaits_asserts_with_spec_and_any(self):
        class Foo:
            def __eq__(self, other): pass

        mock_with_spec = AsyncMock(spec=Foo)

        async def _custom_mock_runnable_test(*args):
            await mock_with_spec(*args)

        run(_custom_mock_runnable_test(Foo(), 1))
        mock_with_spec.assert_has_awaits([call(ANY, 1)])
        mock_with_spec.assert_awaited_with(ANY, 1)
        mock_with_spec.assert_any_await(ANY, 1)

    def test_assert_has_awaits_ordered(self):
        calls = [call('foo'), call('baz')]
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        run(self._runnable_test('bamf'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        run(self._runnable_test('foo'))
        self.mock.assert_has_awaits(calls, any_order=True)

        run(self._runnable_test('qux'))
        self.mock.assert_has_awaits(calls, any_order=True)

    def test_assert_not_awaited(self):
        self.mock.assert_not_awaited()

        run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_not_awaited()

    def test_assert_has_awaits_not_matching_spec_error(self):
        async def f(x=None): pass

        self.mock = AsyncMock(spec=f)
        run(self._runnable_test(1))

        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape('Awaits not found.\n'
                              'Expected: [call()]\n'
                              'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call()])
        self.assertIsNone(cm.exception.__cause__)

        # A call that does not bind to the spec's signature is reported as an
        # error (and chained as __cause__) rather than a plain mismatch.
        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape(
                        'Error processing expected awaits.\n'
                        "Errors: [None, TypeError('too many positional "
                        "arguments')]\n"
                        'Expected: [call(), call(1, 2)]\n'
                        'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call(), call(1, 2)])
        self.assertIsInstance(cm.exception.__cause__, TypeError)