• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import asyncio
2import gc
3import inspect
4import re
5import unittest
6from contextlib import contextmanager
7
8from asyncio import run, iscoroutinefunction
9from unittest import IsolatedAsyncioTestCase
10from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
11                           create_autospec, sentinel, _CallList)
12
13
def tearDownModule():
    """Undo event-loop-policy changes made while this module's tests ran.

    Passing ``None`` to ``asyncio.set_event_loop_policy`` reinstalls the
    default policy.
    """
    asyncio.set_event_loop_policy(None)
16
17
class AsyncClass:
    """Fixture mixing coroutine methods (instance, class, static) with a plain one."""

    def __init__(self):
        pass

    async def async_method(self):
        pass

    def normal_method(self):
        pass

    @classmethod
    async def async_class_method(cls):
        pass

    @staticmethod
    async def async_static_method():
        pass
28
29
class AwaitableClass:
    """Minimal awaitable: ``__await__`` is a generator that yields once."""

    def __await__(self):
        yield
32
async def async_func():
    """Coroutine fixture that takes no arguments and returns None."""
34
async def async_func_args(a, b, *, c):
    """Coroutine fixture with two regular parameters and keyword-only ``c``."""
36
def normal_func():
    """Plain synchronous fixture function; returns None."""
38
class NormalClass:
    """Fixture class with a single synchronous method."""

    def a(self):
        pass
41
42
# Dotted paths of the two fixture classes, used as patch() targets.
async_foo_name = __name__ + '.AsyncClass'
normal_foo_name = __name__ + '.NormalClass'
45
46
@contextmanager
def assertNeverAwaited(test):
    """Require a "... was never awaited" RuntimeWarning from the managed block.

    *test* is the TestCase whose ``assertWarnsRegex`` performs the check.
    """
    pattern = "was never awaited$"
    with test.assertWarnsRegex(RuntimeWarning, pattern):
        yield
        # Implementations other than CPython may not finalize the abandoned
        # coroutine promptly, so force a collection while still inside the
        # warning check.
        gc.collect()
54
55
class AsyncPatchDecoratorTest(unittest.TestCase):
    """patch() / patch.object() used as decorators on coroutine targets."""

    def test_is_coroutine_function_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def check(patched_method):
            self.assertTrue(iscoroutinefunction(patched_method))

        check()

    def test_is_async_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def via_object(patched_method):
            coro = patched_method()
            self.assertTrue(inspect.isawaitable(coro))
            run(coro)

        # Same check, but the target is named by a dotted path string.
        @patch(f'{async_foo_name}.async_method')
        def via_dotted_path(patched_method):
            coro = patched_method()
            self.assertTrue(inspect.isawaitable(coro))
            run(coro)

        via_object()
        via_dotted_path()

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def check(patched_method):
            self.assertIsInstance(patched_method, AsyncMock)

        check()

    def test_is_AsyncMock_patch_staticmethod(self):
        @patch.object(AsyncClass, 'async_static_method')
        def check(patched_method):
            self.assertIsInstance(patched_method, AsyncMock)

        check()

    def test_is_AsyncMock_patch_classmethod(self):
        @patch.object(AsyncClass, 'async_class_method')
        def check(patched_method):
            self.assertIsInstance(patched_method, AsyncMock)

        check()

    def test_async_def_patch(self):
        # Decorators apply bottom-up, so the innermost patch's mock arrives
        # as the first positional argument.
        @patch(f"{__name__}.async_func", return_value=1)
        @patch(f"{__name__}.async_func_args", return_value=2)
        async def check(args_mock, no_args_mock):
            self.assertEqual(args_mock._mock_name, "async_func_args")
            self.assertEqual(no_args_mock._mock_name, "async_func")

            # The module-level names themselves are replaced while patched.
            self.assertIsInstance(async_func, AsyncMock)
            self.assertIsInstance(async_func_args, AsyncMock)

            self.assertEqual(await async_func(), 1)
            self.assertEqual(await async_func_args(1, 2, c=3), 2)

        run(check())
        # Once the patches are undone the real coroutine function is back.
        self.assertTrue(inspect.iscoroutinefunction(async_func))
115
116
class AsyncPatchCMTest(unittest.TestCase):
    """patch() / patch.object() used as context managers on coroutine targets."""

    def test_is_async_function_cm(self):
        with patch.object(AsyncClass, 'async_method') as patched_method:
            self.assertTrue(iscoroutinefunction(patched_method))

    def test_is_async_cm(self):
        with patch.object(AsyncClass, 'async_method') as patched_method:
            coro = patched_method()
            self.assertTrue(inspect.isawaitable(coro))
            run(coro)

    def test_is_AsyncMock_cm(self):
        with patch.object(AsyncClass, 'async_method') as patched_method:
            self.assertIsInstance(patched_method, AsyncMock)

    def test_async_def_cm(self):
        async def scenario():
            with patch(f"{__name__}.async_func", AsyncMock()):
                self.assertIsInstance(async_func, AsyncMock)
            # Leaving the with-block restores the real coroutine function.
            self.assertTrue(inspect.iscoroutinefunction(async_func))

        run(scenario())
148
149
class AsyncMockTest(unittest.TestCase):
    """Basic coroutine-function and awaitable properties of AsyncMock."""

    def test_iscoroutinefunction_default(self):
        mock = AsyncMock()
        self.assertTrue(iscoroutinefunction(mock))

    def test_iscoroutinefunction_function(self):
        async def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_isawaitable(self):
        mock = AsyncMock()
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        # Await the coroutine so it is not reported as "never awaited".
        run(m)
        self.assertIn('assert_awaited', dir(mock))

    def test_iscoroutinefunction_normal_function(self):
        # Even with a synchronous spec, an AsyncMock is a coroutine function.
        def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_future_isfuture(self):
        # Create the Future on a private loop that is never installed as the
        # current event loop. Installing it and then closing it would leave a
        # closed loop behind in global asyncio state for later tests.
        loop = asyncio.new_event_loop()
        try:
            fut = loop.create_future()
        finally:
            loop.close()
        mock = AsyncMock(fut)
        self.assertIsInstance(mock, asyncio.Future)
182
183
class AsyncAutospecTest(unittest.TestCase):
    """create_autospec() and patch(autospec=True) with coroutine targets."""

    def test_is_AsyncMock_patch(self):
        @patch(async_foo_name, autospec=True)
        def check_async_method(mock_cls):
            self.assertIsInstance(mock_cls.async_method, AsyncMock)
            self.assertIsInstance(mock_cls, MagicMock)

        @patch(async_foo_name, autospec=True)
        def check_normal_method(mock_cls):
            self.assertIsInstance(mock_cls.normal_method, MagicMock)

        check_async_method()
        check_normal_method()

    def test_create_autospec_instance(self):
        # instance=True on a coroutine function is rejected with RuntimeError.
        with self.assertRaises(RuntimeError):
            create_autospec(async_func, instance=True)

    @unittest.skip('Broken test from https://bugs.python.org/issue37251')
    def test_create_autospec_awaitable_class(self):
        self.assertIsInstance(create_autospec(AwaitableClass), AsyncMock)

    def test_create_autospec(self):
        spec = create_autospec(async_func_args)
        awaitable = spec(1, 2, c=3)

        async def main():
            await awaitable

        # Calling alone leaves the await bookkeeping untouched.
        self.assertEqual(spec.await_count, 0)
        self.assertIsNone(spec.await_args)
        self.assertEqual(spec.await_args_list, [])
        spec.assert_not_awaited()

        run(main())

        expected = call(1, 2, c=3)
        self.assertTrue(iscoroutinefunction(spec))
        self.assertTrue(asyncio.iscoroutine(awaitable))
        self.assertEqual(spec.await_count, 1)
        self.assertEqual(spec.await_args, expected)
        self.assertEqual(spec.await_args_list, [expected])
        spec.assert_awaited_once()
        spec.assert_awaited_once_with(1, 2, c=3)
        spec.assert_awaited_with(1, 2, c=3)
        spec.assert_awaited()

        # Arguments the spec was never awaited with must not match.
        with self.assertRaises(AssertionError):
            spec.assert_any_await(e=1)

    def test_patch_with_autospec(self):
        async def scenario():
            with patch(f"{__name__}.async_func_args", autospec=True) as mocked:
                awaitable = mocked(1, 2, c=3)
                self.assertIsInstance(mocked.mock, AsyncMock)

                self.assertTrue(iscoroutinefunction(mocked))
                self.assertTrue(asyncio.iscoroutine(awaitable))
                self.assertTrue(inspect.isawaitable(awaitable))

                # Nothing has been awaited yet.
                self.assertEqual(mocked.await_count, 0)
                self.assertEqual(mocked.await_args_list, [])
                self.assertIsNone(mocked.await_args)
                mocked.assert_not_awaited()

                await awaitable

            # The await history is still readable after the patch is undone.
            self.assertEqual(mocked.await_count, 1)
            self.assertEqual(mocked.await_args, call(1, 2, c=3))
            self.assertEqual(mocked.await_args_list, [call(1, 2, c=3)])
            mocked.assert_awaited_once()
            mocked.assert_awaited_once_with(1, 2, c=3)
            mocked.assert_awaited_with(1, 2, c=3)
            mocked.assert_awaited()

            # reset_mock() clears the await history as well.
            mocked.reset_mock()
            self.assertEqual(mocked.await_count, 0)
            self.assertIsNone(mocked.await_args)
            self.assertEqual(mocked.await_args_list, [])

        run(scenario())
266
267
class AsyncSpecTest(unittest.TestCase):
    """How a spec decides between AsyncMock and synchronous mocks."""

    def test_spec_normal_methods_on_class(self):
        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test method types with {mock_type}"):
                mock = mock_type(AsyncClass)
                self.assertIsInstance(mock.async_method, AsyncMock)
                self.assertIsInstance(mock.normal_method, MagicMock)

    def test_spec_normal_methods_on_class_with_mock(self):
        mock = Mock(AsyncClass)
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, Mock)

    def test_spec_mock_type_kw(self):
        def check(mock_type):
            async_mock = mock_type(spec=async_func)
            self.assertIsInstance(async_mock, mock_type)
            # A coroutine-function spec makes calls return awaitables.
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(spec=normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec kwarg with {mock_type}"):
                check(mock_type)

    def test_spec_mock_type_positional(self):
        def check(mock_type):
            async_mock = mock_type(async_func)
            self.assertIsInstance(async_mock, mock_type)
            with assertNeverAwaited(self):
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec positional with {mock_type}"):
                check(mock_type)

    def test_spec_as_normal_kw_AsyncMock(self):
        mock = AsyncMock(spec=normal_func)
        self.assertIsInstance(mock, AsyncMock)
        coro = mock()
        self.assertTrue(inspect.isawaitable(coro))
        run(coro)

    def test_spec_as_normal_positional_AsyncMock(self):
        mock = AsyncMock(normal_func)
        self.assertIsInstance(mock, AsyncMock)
        coro = mock()
        self.assertTrue(inspect.isawaitable(coro))
        run(coro)

    def test_spec_async_mock(self):
        @patch.object(AsyncClass, 'async_method', spec=True)
        def check(patched_method):
            self.assertIsInstance(patched_method, AsyncMock)

        check()

    def test_spec_parent_not_async_attribute_is(self):
        @patch(async_foo_name, spec=True)
        def check(mock_cls):
            self.assertIsInstance(mock_cls, MagicMock)
            self.assertIsInstance(mock_cls.async_method, AsyncMock)

        check()

    def test_target_async_spec_not(self):
        # A synchronous spec takes precedence over the async patch target.
        @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
        def check(patched_method):
            self.assertIsInstance(patched_method, MagicMock)
            self.assertFalse(inspect.iscoroutine(patched_method))
            self.assertFalse(inspect.isawaitable(patched_method))

        check()

    def test_target_not_async_spec_is(self):
        # ...and an async spec takes precedence over a synchronous target.
        @patch.object(NormalClass, 'a', spec=async_func)
        def check(patched_method):
            self.assertIsInstance(patched_method, AsyncMock)

        check()

    def test_spec_async_attributes(self):
        @patch(normal_foo_name, spec=AsyncClass)
        def check(mock_cls):
            self.assertIsInstance(mock_cls.async_method, AsyncMock)
            self.assertIsInstance(mock_cls, MagicMock)

        check()
363
364
class AsyncSpecSetTest(unittest.TestCase):
    """spec_set combined with coroutine specs."""

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method', spec_set=True)
        def check(patched_method):
            self.assertIsInstance(patched_method, AsyncMock)

        check()

    def test_is_async_AsyncMock(self):
        mock = AsyncMock(spec_set=AsyncClass.async_method)
        self.assertTrue(iscoroutinefunction(mock))
        self.assertIsInstance(mock, AsyncMock)

    def test_is_child_AsyncMock(self):
        mock = MagicMock(spec_set=AsyncClass)
        # Each child mirrors whether the spec attribute is a coroutine.
        self.assertTrue(iscoroutinefunction(mock.async_method))
        self.assertFalse(iscoroutinefunction(mock.normal_method))
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, MagicMock)
        self.assertIsInstance(mock, MagicMock)

    def test_magicmock_lambda_spec(self):
        parent = MagicMock()
        parent.mock_func = MagicMock(spec=lambda x: x)

        with patch.object(parent, "mock_func") as patched:
            self.assertIsInstance(patched, MagicMock)
391
392
class AsyncArguments(IsolatedAsyncioTestCase):
    """return_value, side_effect and wraps as observed by code awaiting an AsyncMock."""

    async def test_add_return_value(self):
        # One-argument spec, matching how the mock is called below.
        async def addition(var): pass

        mock = AsyncMock(addition, return_value=10)
        output = await mock(5)

        self.assertEqual(output, 10)

    async def test_add_side_effect_exception(self):
        async def addition(var): pass
        mock = AsyncMock(addition, side_effect=Exception('err'))
        # Match the message too: a bare assertRaises(Exception) would also be
        # satisfied by any unrelated error raised while calling the mock.
        with self.assertRaisesRegex(Exception, '^err$'):
            await mock(5)

    async def test_add_side_effect_coroutine(self):
        # A coroutine-function side_effect is awaited and its result returned.
        async def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_normal_function(self):
        def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_iterable(self):
        vals = [1, 2, 3]
        mock = AsyncMock(side_effect=vals)
        for item in vals:
            self.assertEqual(await mock(), item)

        # Once the iterable is exhausted the await raises StopAsyncIteration.
        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_add_side_effect_exception_iterable(self):
        class SampleException(Exception):
            pass

        # Exception instances inside a side_effect iterable are raised.
        vals = [1, SampleException("foo")]
        mock = AsyncMock(side_effect=vals)
        self.assertEqual(await mock(), 1)

        with self.assertRaises(SampleException):
            await mock()

    async def test_return_value_AsyncMock(self):
        # An AsyncMock used as return_value is handed back as-is, not awaited.
        value = AsyncMock(return_value=10)
        mock = AsyncMock(return_value=value)
        result = await mock()
        self.assertIs(result, value)

    async def test_return_value_awaitable(self):
        fut = asyncio.Future()
        fut.set_result(None)
        mock = AsyncMock(return_value=fut)
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

    async def test_side_effect_awaitable_values(self):
        fut = asyncio.Future()
        fut.set_result(None)

        mock = AsyncMock(side_effect=[fut])
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_side_effect_is_AsyncMock(self):
        effect = AsyncMock(return_value=10)
        mock = AsyncMock(side_effect=effect)

        result = await mock()
        self.assertEqual(result, 10)

    async def test_wraps_coroutine(self):
        value = asyncio.Future()

        ran = False
        async def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_wraps_normal_function(self):
        value = 1

        ran = False
        def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_await_args_list_order(self):
        # await_args_list follows await order; call_args_list follows call order.
        async_mock = AsyncMock()
        mock2 = async_mock(2)
        mock1 = async_mock(1)
        await mock1
        await mock2
        async_mock.assert_has_awaits([call(1), call(2)])
        self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
        self.assertEqual(async_mock.call_args_list, [call(2), call(1)])
512
513
class AsyncMagicMethods(unittest.TestCase):
    """Which magic-method children are async or sync on MagicMock / AsyncMock."""

    def test_async_magic_methods_return_async_mocks(self):
        m_mock = MagicMock()
        for name in ('__aenter__', '__aexit__', '__anext__'):
            self.assertIsInstance(getattr(m_mock, name), AsyncMock)
        # __aiter__ is a synchronous method, so it stays a MagicMock.
        self.assertIsInstance(m_mock.__aiter__, MagicMock)

    def test_sync_magic_methods_return_magic_mocks(self):
        a_mock = AsyncMock()
        for name in ('__enter__', '__exit__', '__next__', '__len__'):
            self.assertIsInstance(getattr(a_mock, name), MagicMock)

    def test_magicmock_has_async_magic_methods(self):
        m_mock = MagicMock()
        for name in ("__aenter__", "__aexit__", "__anext__"):
            self.assertTrue(hasattr(m_mock, name))

    def test_asyncmock_has_sync_magic_methods(self):
        a_mock = AsyncMock()
        for name in ("__enter__", "__exit__", "__next__", "__len__"):
            self.assertTrue(hasattr(a_mock, name))

    def test_magic_methods_are_async_functions(self):
        m_mock = MagicMock()
        aenter, aexit = m_mock.__aenter__, m_mock.__aexit__
        self.assertIsInstance(aenter, AsyncMock)
        self.assertIsInstance(aexit, AsyncMock)
        # AsyncMocks are also coroutine functions.
        self.assertTrue(iscoroutinefunction(aenter))
        self.assertTrue(iscoroutinefunction(aexit))
551
class AsyncContextManagerTest(unittest.TestCase):
    """Mocks standing in for objects used with ``async with``."""

    class WithAsyncContextManager:
        # Spec class exposing only the async context-manager protocol.
        async def __aenter__(self, *args, **kwargs):
            pass

        async def __aexit__(self, *args, **kwargs):
            pass

    class WithSyncContextManager:
        # Spec class exposing only the sync context-manager protocol.
        def __enter__(self, *args, **kwargs):
            pass

        def __exit__(self, *args, **kwargs):
            pass

    class ProductionCode:
        # Example real-world(ish) code: awaits the JSON of a response obtained
        # from the async context manager returned by self.session.post().
        def __init__(self):
            self.session = None

        async def main(self):
            async with self.session.post('https://python.org') as response:
                return await response.json()

    def test_set_return_value_of_aenter(self):
        def check(mock_type):
            code = self.ProductionCode()
            code.session = MagicMock(name='sessionmock')
            context = mock_type(name='magic_cm')
            response = AsyncMock(name='response')
            response.json = AsyncMock(return_value={'json': 123})
            # "async with ... as response" binds __aenter__'s return value.
            context.__aenter__.return_value = response
            code.session.post.return_value = context
            self.assertEqual(run(code.main()), {'json': 123})

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test set return value of aenter with {mock_type}"):
                check(mock_type)

    def test_mock_supports_async_context_manager(self):
        def check(mock_type):
            entered = False
            cm_mock = mock_type(self.WithAsyncContextManager())

            async def use_context_manager():
                nonlocal entered
                async with cm_mock as result:
                    entered = True
                return result

            cm_result = run(use_context_manager())
            self.assertTrue(entered)
            for magic in (cm_mock.__aenter__, cm_mock.__aexit__):
                self.assertTrue(magic.called)
            cm_mock.__aenter__.assert_awaited()
            cm_mock.__aexit__.assert_awaited()
            # __aenter__ is itself mocked, so it does not return the mock.
            self.assertIsNot(cm_mock, cm_result)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test context manager magics with {mock_type}"):
                check(mock_type)

    def test_mock_customize_async_context_manager(self):
        mock_instance = MagicMock(self.WithAsyncContextManager())

        expected_result = object()
        mock_instance.__aenter__.return_value = expected_result

        async def use_context_manager():
            async with mock_instance as result:
                return result

        self.assertIs(run(use_context_manager()), expected_result)

    def test_mock_customize_async_context_manager_with_coroutine(self):
        events = []

        async def enter_coroutine(*args):
            events.append('enter')

        async def exit_coroutine(*args):
            events.append('exit')

        mock_instance = MagicMock(self.WithAsyncContextManager())
        # Plain coroutine functions can replace the mocked magic methods.
        mock_instance.__aenter__ = enter_coroutine
        mock_instance.__aexit__ = exit_coroutine

        async def use_context_manager():
            async with mock_instance:
                pass

        run(use_context_manager())
        self.assertIn('enter', events)
        self.assertIn('exit', events)

    def test_context_manager_raise_exception_by_default(self):
        # The default mocked __aexit__ lets exceptions from the body propagate.
        async def raise_in(context_manager):
            async with context_manager:
                raise TypeError()

        mock_instance = MagicMock(self.WithAsyncContextManager())
        with self.assertRaises(TypeError):
            run(raise_in(mock_instance))
664
665
class AsyncIteratorTest(unittest.TestCase):
    """Mocks standing in for objects used with ``async for``."""

    class WithAsyncIterator(object):
        def __init__(self):
            self.items = ["foo", "NormalFoo", "baz"]

        def __aiter__(self): pass

        async def __anext__(self): pass

    def test_aiter_set_return_value(self):
        mock_iter = AsyncMock(name="tester")
        # "async for" over the mock yields the items of __aiter__'s return value.
        mock_iter.__aiter__.return_value = [1, 2, 3]

        async def main():
            return [i async for i in mock_iter]

        result = run(main())
        self.assertEqual(result, [1, 2, 3])

    def test_mock_aiter_and_anext_asyncmock(self):
        def inner_test(mock_type):
            instance = self.WithAsyncIterator()
            mock_instance = mock_type(instance)
            # Check that the mock and the real thing behave the same.
            # __aiter__ is not actually async, so not a coroutinefunction
            self.assertFalse(iscoroutinefunction(instance.__aiter__))
            self.assertFalse(iscoroutinefunction(mock_instance.__aiter__))
            # __anext__ is async
            self.assertTrue(iscoroutinefunction(instance.__anext__))
            self.assertTrue(iscoroutinefunction(mock_instance.__anext__))

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test aiter and anext coroutine with {mock_type}"):
                inner_test(mock_type)

    def test_mock_async_for(self):
        async def iterate(iterator):
            accumulator = []
            async for item in iterator:
                accumulator.append(item)

            return accumulator

        expected = ["FOO", "BAR", "BAZ"]

        def test_default(mock_type):
            # With nothing configured, iteration finishes immediately.
            mock_instance = mock_type(self.WithAsyncIterator())
            self.assertEqual(run(iterate(mock_instance)), [])

        def test_set_return_value(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = expected[:]
            self.assertEqual(run(iterate(mock_instance)), expected)

        def test_set_return_value_iter(mock_type):
            # An iterator works as a return value just like a list.
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = iter(expected[:])
            self.assertEqual(run(iterate(mock_instance)), expected)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"default value with {mock_type}"):
                test_default(mock_type)

            with self.subTest(f"set return_value with {mock_type}"):
                test_set_return_value(mock_type)

            with self.subTest(f"set return_value iterator with {mock_type}"):
                test_set_return_value_iter(mock_type)
733
734
735class AsyncMockAssert(unittest.TestCase):
736    def setUp(self):
737        self.mock = AsyncMock()
738
739    async def _runnable_test(self, *args, **kwargs):
740        await self.mock(*args, **kwargs)
741
742    async def _await_coroutine(self, coroutine):
743        return await coroutine
744
745    def test_assert_called_but_not_awaited(self):
746        mock = AsyncMock(AsyncClass)
747        with assertNeverAwaited(self):
748            mock.async_method()
749        self.assertTrue(iscoroutinefunction(mock.async_method))
750        mock.async_method.assert_called()
751        mock.async_method.assert_called_once()
752        mock.async_method.assert_called_once_with()
753        with self.assertRaises(AssertionError):
754            mock.assert_awaited()
755        with self.assertRaises(AssertionError):
756            mock.async_method.assert_awaited()
757
758    def test_assert_called_then_awaited(self):
759        mock = AsyncMock(AsyncClass)
760        mock_coroutine = mock.async_method()
761        mock.async_method.assert_called()
762        mock.async_method.assert_called_once()
763        mock.async_method.assert_called_once_with()
764        with self.assertRaises(AssertionError):
765            mock.async_method.assert_awaited()
766
767        run(self._await_coroutine(mock_coroutine))
768        # Assert we haven't re-called the function
769        mock.async_method.assert_called_once()
770        mock.async_method.assert_awaited()
771        mock.async_method.assert_awaited_once()
772        mock.async_method.assert_awaited_once_with()
773
774    def test_assert_called_and_awaited_at_same_time(self):
775        with self.assertRaises(AssertionError):
776            self.mock.assert_awaited()
777
778        with self.assertRaises(AssertionError):
779            self.mock.assert_called()
780
781        run(self._runnable_test())
782        self.mock.assert_called_once()
783        self.mock.assert_awaited_once()
784
785    def test_assert_called_twice_and_awaited_once(self):
786        mock = AsyncMock(AsyncClass)
787        coroutine = mock.async_method()
788        # The first call will be awaited so no warning there
789        # But this call will never get awaited, so it will warn here
790        with assertNeverAwaited(self):
791            mock.async_method()
792        with self.assertRaises(AssertionError):
793            mock.async_method.assert_awaited()
794        mock.async_method.assert_called()
795        run(self._await_coroutine(coroutine))
796        mock.async_method.assert_awaited()
797        mock.async_method.assert_awaited_once()
798
799    def test_assert_called_once_and_awaited_twice(self):
800        mock = AsyncMock(AsyncClass)
801        coroutine = mock.async_method()
802        mock.async_method.assert_called_once()
803        run(self._await_coroutine(coroutine))
804        with self.assertRaises(RuntimeError):
805            # Cannot reuse already awaited coroutine
806            run(self._await_coroutine(coroutine))
807        mock.async_method.assert_awaited()
808
809    def test_assert_awaited_but_not_called(self):
810        with self.assertRaises(AssertionError):
811            self.mock.assert_awaited()
812        with self.assertRaises(AssertionError):
813            self.mock.assert_called()
814        with self.assertRaises(TypeError):
815            # You cannot await an AsyncMock, it must be a coroutine
816            run(self._await_coroutine(self.mock))
817
818        with self.assertRaises(AssertionError):
819            self.mock.assert_awaited()
820        with self.assertRaises(AssertionError):
821            self.mock.assert_called()
822
823    def test_assert_has_calls_not_awaits(self):
824        kalls = [call('foo')]
825        with assertNeverAwaited(self):
826            self.mock('foo')
827        self.mock.assert_has_calls(kalls)
828        with self.assertRaises(AssertionError):
829            self.mock.assert_has_awaits(kalls)
830
831    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
832        with assertNeverAwaited(self):
833            self.mock()
834        kalls_empty = [('', (), {})]
835        self.assertEqual(self.mock.mock_calls, kalls_empty)
836
837        with assertNeverAwaited(self):
838            self.mock('foo')
839        with assertNeverAwaited(self):
840            self.mock('baz')
841        mock_kalls = ([call(), call('foo'), call('baz')])
842        self.assertEqual(self.mock.mock_calls, mock_kalls)
843
844    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
845        a_class_mock = AsyncMock(AsyncClass)
846        with assertNeverAwaited(self):
847            a_class_mock.async_method()
848        kalls_empty = [('', (), {})]
849        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
850        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])
851
852        with assertNeverAwaited(self):
853            a_class_mock.async_method(1, 2, 3, a=4, b=5)
854        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
855        mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
856        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
857        self.assertEqual(a_class_mock.mock_calls, mock_kalls)
858
859    def test_async_method_calls_recorded(self):
860        with assertNeverAwaited(self):
861            self.mock.something(3, fish=None)
862        with assertNeverAwaited(self):
863            self.mock.something_else.something(6, cake=sentinel.Cake)
864
865        self.assertEqual(self.mock.method_calls, [
866            ("something", (3,), {'fish': None}),
867            ("something_else.something", (6,), {'cake': sentinel.Cake})
868        ],
869            "method calls not recorded correctly")
870        self.assertEqual(self.mock.something_else.method_calls,
871                         [("something", (6,), {'cake': sentinel.Cake})],
872                         "method calls not recorded correctly")
873
874    def test_async_arg_lists(self):
875        def assert_attrs(mock):
876            names = ('call_args_list', 'method_calls', 'mock_calls')
877            for name in names:
878                attr = getattr(mock, name)
879                self.assertIsInstance(attr, _CallList)
880                self.assertIsInstance(attr, list)
881                self.assertEqual(attr, [])
882
883        assert_attrs(self.mock)
884        with assertNeverAwaited(self):
885            self.mock()
886        with assertNeverAwaited(self):
887            self.mock(1, 2)
888        with assertNeverAwaited(self):
889            self.mock(a=3)
890
891        self.mock.reset_mock()
892        assert_attrs(self.mock)
893
894        a_mock = AsyncMock(AsyncClass)
895        with assertNeverAwaited(self):
896            a_mock.async_method()
897        with assertNeverAwaited(self):
898            a_mock.async_method(1, a=3)
899
900        a_mock.reset_mock()
901        assert_attrs(a_mock)
902
903    def test_assert_awaited(self):
904        with self.assertRaises(AssertionError):
905            self.mock.assert_awaited()
906
907        run(self._runnable_test())
908        self.mock.assert_awaited()
909
910    def test_assert_awaited_once(self):
911        with self.assertRaises(AssertionError):
912            self.mock.assert_awaited_once()
913
914        run(self._runnable_test())
915        self.mock.assert_awaited_once()
916
917        run(self._runnable_test())
918        with self.assertRaises(AssertionError):
919            self.mock.assert_awaited_once()
920
921    def test_assert_awaited_with(self):
922        msg = 'Not awaited'
923        with self.assertRaisesRegex(AssertionError, msg):
924            self.mock.assert_awaited_with('foo')
925
926        run(self._runnable_test())
927        msg = 'expected await not found'
928        with self.assertRaisesRegex(AssertionError, msg):
929            self.mock.assert_awaited_with('foo')
930
931        run(self._runnable_test('foo'))
932        self.mock.assert_awaited_with('foo')
933
934        run(self._runnable_test('SomethingElse'))
935        with self.assertRaises(AssertionError):
936            self.mock.assert_awaited_with('foo')
937
938    def test_assert_awaited_once_with(self):
939        with self.assertRaises(AssertionError):
940            self.mock.assert_awaited_once_with('foo')
941
942        run(self._runnable_test('foo'))
943        self.mock.assert_awaited_once_with('foo')
944
945        run(self._runnable_test('foo'))
946        with self.assertRaises(AssertionError):
947            self.mock.assert_awaited_once_with('foo')
948
949    def test_assert_any_wait(self):
950        with self.assertRaises(AssertionError):
951            self.mock.assert_any_await('foo')
952
953        run(self._runnable_test('baz'))
954        with self.assertRaises(AssertionError):
955            self.mock.assert_any_await('foo')
956
957        run(self._runnable_test('foo'))
958        self.mock.assert_any_await('foo')
959
960        run(self._runnable_test('SomethingElse'))
961        self.mock.assert_any_await('foo')
962
963    def test_assert_has_awaits_no_order(self):
964        calls = [call('foo'), call('baz')]
965
966        with self.assertRaises(AssertionError) as cm:
967            self.mock.assert_has_awaits(calls)
968        self.assertEqual(len(cm.exception.args), 1)
969
970        run(self._runnable_test('foo'))
971        with self.assertRaises(AssertionError):
972            self.mock.assert_has_awaits(calls)
973
974        run(self._runnable_test('foo'))
975        with self.assertRaises(AssertionError):
976            self.mock.assert_has_awaits(calls)
977
978        run(self._runnable_test('baz'))
979        self.mock.assert_has_awaits(calls)
980
981        run(self._runnable_test('SomethingElse'))
982        self.mock.assert_has_awaits(calls)
983
984    def test_awaits_asserts_with_any(self):
985        class Foo:
986            def __eq__(self, other): pass
987
988        run(self._runnable_test(Foo(), 1))
989
990        self.mock.assert_has_awaits([call(ANY, 1)])
991        self.mock.assert_awaited_with(ANY, 1)
992        self.mock.assert_any_await(ANY, 1)
993
994    def test_awaits_asserts_with_spec_and_any(self):
995        class Foo:
996            def __eq__(self, other): pass
997
998        mock_with_spec = AsyncMock(spec=Foo)
999
1000        async def _custom_mock_runnable_test(*args):
1001            await mock_with_spec(*args)
1002
1003        run(_custom_mock_runnable_test(Foo(), 1))
1004        mock_with_spec.assert_has_awaits([call(ANY, 1)])
1005        mock_with_spec.assert_awaited_with(ANY, 1)
1006        mock_with_spec.assert_any_await(ANY, 1)
1007
1008    def test_assert_has_awaits_ordered(self):
1009        calls = [call('foo'), call('baz')]
1010        with self.assertRaises(AssertionError):
1011            self.mock.assert_has_awaits(calls, any_order=True)
1012
1013        run(self._runnable_test('baz'))
1014        with self.assertRaises(AssertionError):
1015            self.mock.assert_has_awaits(calls, any_order=True)
1016
1017        run(self._runnable_test('bamf'))
1018        with self.assertRaises(AssertionError):
1019            self.mock.assert_has_awaits(calls, any_order=True)
1020
1021        run(self._runnable_test('foo'))
1022        self.mock.assert_has_awaits(calls, any_order=True)
1023
1024        run(self._runnable_test('qux'))
1025        self.mock.assert_has_awaits(calls, any_order=True)
1026
1027    def test_assert_not_awaited(self):
1028        self.mock.assert_not_awaited()
1029
1030        run(self._runnable_test())
1031        with self.assertRaises(AssertionError):
1032            self.mock.assert_not_awaited()
1033
1034    def test_assert_has_awaits_not_matching_spec_error(self):
1035        async def f(x=None): pass
1036
1037        self.mock = AsyncMock(spec=f)
1038        run(self._runnable_test(1))
1039
1040        with self.assertRaisesRegex(
1041                AssertionError,
1042                '^{}$'.format(
1043                    re.escape('Awaits not found.\n'
1044                              'Expected: [call()]\n'
1045                              'Actual: [call(1)]'))) as cm:
1046            self.mock.assert_has_awaits([call()])
1047        self.assertIsNone(cm.exception.__cause__)
1048
1049        with self.assertRaisesRegex(
1050                AssertionError,
1051                '^{}$'.format(
1052                    re.escape(
1053                        'Error processing expected awaits.\n'
1054                        "Errors: [None, TypeError('too many positional "
1055                        "arguments')]\n"
1056                        'Expected: [call(), call(1, 2)]\n'
1057                        'Actual: [call(1)]'))) as cm:
1058            self.mock.assert_has_awaits([call(), call(1, 2)])
1059        self.assertIsInstance(cm.exception.__cause__, TypeError)
1060