• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from warnings import catch_warnings
3
4from unittest.test.testmock.support import is_instance
5from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
6
7
8
# Module-level sentinels patched by dotted name ('<module>.something',
# '<module>.something_else') in WithTest; each of those tests checks that these
# original objects are back in place once its ``with`` block has exited.
something, something_else = sentinel.Something, sentinel.SomethingElse
11
12
13
class WithTest(unittest.TestCase):
    """Tests for ``patch``, ``patch.object`` and ``patch.dict`` as context managers.

    Most cases patch the module-level sentinels ``something`` and
    ``something_else`` by dotted name (``'<this module>.something'``) and then
    check that the original sentinel is restored once the ``with`` block exits.
    """

    def test_with_statement(self):
        """An explicit replacement is visible only inside the ``with`` block."""
        with patch('%s.something' % __name__, sentinel.Something2):
            self.assertEqual(something, sentinel.Something2, "unpatched")
        self.assertEqual(something, sentinel.Something)


    def test_with_statement_exception(self):
        """An exception raised in the block propagates and the patch is undone."""
        # Raise and catch a private exception type only.  Catching the broad
        # ``Exception`` also swallowed the AssertionError of a failing check
        # inside the block, so that check could never fail the test.
        class _Pow(Exception):
            pass

        try:
            with patch('%s.something' % __name__, sentinel.Something2):
                self.assertEqual(something, sentinel.Something2, "unpatched")
                raise _Pow('pow')
        except _Pow:
            pass
        else:
            self.fail("patch swallowed exception")
        self.assertEqual(something, sentinel.Something)


    def test_with_statement_as(self):
        """With no replacement given, ``patch`` binds a MagicMock via ``as``."""
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            # Check the concrete type via type()/issubclass() rather than
            # isinstance(): isinstance() also consults ``__class__``, which a
            # mock created with a spec reports as the spec's class.  Written
            # inline so this class does not depend on the private
            # ``unittest.test.testmock.support`` helper module.
            self.assertTrue(issubclass(type(mock_something), MagicMock),
                            "patching wrong type")
        self.assertEqual(something, sentinel.Something)


    def test_patch_object_with_statement(self):
        """``patch.object`` replaces a class attribute only inside the block."""
        class Foo(object):
            something = 'foo'
        original = Foo.something
        with patch.object(Foo, 'something'):
            self.assertNotEqual(Foo.something, original, "unpatched")
        self.assertEqual(Foo.something, original)


    def test_with_statement_nested(self):
        """Two patches combined in a single ``with`` statement both apply and undo."""
        # record=True collects any warnings raised in the block instead of
        # emitting them.
        with catch_warnings(record=True):
            with patch('%s.something' % __name__) as mock_something, \
                    patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something, mock_something, "unpatched")
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)


    def test_with_statement_specified(self):
        """The object bound by ``as`` is the explicit replacement itself."""
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
        self.assertEqual(something, sentinel.Something)


    def testContextManagerMocking(self):
        """A plain Mock with assigned ``__enter__``/``__exit__`` works in ``with``."""
        mock = Mock()
        mock.__enter__ = Mock()
        mock.__exit__ = Mock()
        # A falsy __exit__ result means "do not suppress exceptions".
        mock.__exit__.return_value = False

        with mock as m:
            self.assertEqual(m, mock.__enter__.return_value)
        mock.__enter__.assert_called_with()
        mock.__exit__.assert_called_with(None, None, None)


    def test_context_manager_with_magic_mock(self):
        """A MagicMock's default ``__exit__`` lets exceptions propagate."""
        mock = MagicMock()

        with self.assertRaises(TypeError):
            with mock:
                'foo' + 3  # deliberate TypeError raised inside the block
        mock.__enter__.assert_called_with()
        self.assertTrue(mock.__exit__.called)


    def test_with_statement_same_attribute(self):
        """Re-patching a patched name restores the outer patch, then the original."""
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something' % __name__) as mock_again:
                self.assertEqual(something, mock_again, "unpatched")

            self.assertEqual(something, mock_something,
                             "restored with wrong instance")

        self.assertEqual(something, sentinel.Something, "not restored")


    def test_with_statement_imbricated(self):
        """Nested ``with`` blocks patching different names both unwind."""
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)


    def test_dict_context_manager(self):
        """``patch.dict`` restores the dict afterwards, even when the block raises."""
        foo = {}
        with patch.dict(foo, {'a': 'b'}):
            self.assertEqual(foo, {'a': 'b'})
        self.assertEqual(foo, {})

        with self.assertRaises(NameError):
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo, {'a': 'b'})
                raise NameError('Konrad')

        self.assertEqual(foo, {})

    def test_double_patch_instance_method(self):
        """Nested autospec patches of one instance method track calls separately."""
        class C:
            def f(self):
                pass

        c = C()

        with patch.object(c, 'f', autospec=True) as patch1:
            with patch.object(c, 'f', autospec=True) as patch2:
                c.f()
            self.assertEqual(patch2.call_count, 1)
            self.assertEqual(patch1.call_count, 0)
            # After the inner patch exits, the outer mock receives calls again.
            c.f()
        self.assertEqual(patch1.call_count, 1)
143
144
class TestMockOpen(unittest.TestCase):
    """Tests for ``mock_open`` as a stand-in for the builtin ``open``.

    ``open`` is patched on this module with ``create=True`` because the module
    has no ``open`` attribute of its own to replace.
    """

    def test_mock_open(self):
        """The mock_open mock is what ``patch`` installs, and it records the call."""
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')


    def test_mock_open_context_manager(self):
        """``with open(...) as f`` binds the shared handle and records each call."""
        mock = mock_open()
        handle = mock.return_value
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()

        expected_calls = [call('foo'), call().__enter__(), call().read(),
                          call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)
        self.assertIs(f, handle)

    def test_mock_open_context_manager_multiple_times(self):
        """Successive ``with open(...)`` blocks append to one call history."""
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()
            with open('bar') as f:
                f.read()

        expected_calls = [
            call('foo'), call().__enter__(), call().read(),
            call().__exit__(None, None, None),
            call('bar'), call().__enter__(), call().read(),
            call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)

    def test_explicit_mock(self):
        """``mock_open`` can configure a caller-supplied MagicMock."""
        mock = MagicMock()
        mock_open(mock)

        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')


    def test_read_data(self):
        """``read()`` returns the configured ``read_data``."""
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.read()

        self.assertEqual(result, 'foo')


    def test_readline_data(self):
        """``readline()`` yields each line, then ``''`` once the data is exhausted."""
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            line2 = h.readline()
            line3 = h.readline()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(line2, 'bar\n')
        self.assertEqual(line3, 'baz\n')
        self.assertEqual(h.readline(), '')

        # A file that does not end in a newline.
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readline()
        self.assertEqual(result, 'foo')
        self.assertEqual(h.readline(), '')


    def test_dunder_iter_data(self):
        """Iterating the handle yields every line and consumes the data."""
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            lines = list(h)
        self.assertEqual(lines, ['foo\n', 'bar\n', 'baz\n'])
        self.assertEqual(h.readline(), '')


    def test_readlines_data(self):
        """``readlines()`` splits the data, with or without a trailing newline."""
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()
        self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])

        mock = mock_open(read_data='foo\nbar\nbaz')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()

        self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])


    def test_read_bytes(self):
        """Bytes ``read_data`` is returned unchanged by ``read()``."""
        mock = mock_open(read_data=b'\xc6')
        with patch('%s.open' % __name__, mock, create=True):
            with open('abc', 'rb') as f:
                result = f.read()
        self.assertEqual(result, b'\xc6')


    def test_readline_bytes(self):
        """``readline()`` splits bytes ``read_data`` on ``b'\\n'``."""
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                line1 = f.readline()
                line2 = f.readline()
                line3 = f.readline()
        self.assertEqual(line1, b'abc\n')
        self.assertEqual(line2, b'def\n')
        self.assertEqual(line3, b'ghi\n')


    def test_readlines_bytes(self):
        """``readlines()`` returns a list of bytes lines."""
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                result = f.readlines()
        self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])


    def test_mock_open_read_with_argument(self):
        """``read(size)`` honours *size*, as a real file object does."""
        # At one point calling read with an argument was broken for mocks
        # returned by mock_open.  Since Python 3.8 the handle reads from an
        # in-memory stream, so the size argument limits the result instead of
        # being ignored (the old expectation was the whole payload).
        some_data = 'foo\nbar\nbaz'
        mock = mock_open(read_data=some_data)
        # Each call to the mock starts reading from the beginning again.
        self.assertEqual(mock().read(10), some_data[:10])
        self.assertEqual(mock().read(10), some_data[:10])

        # Reads on one handle continue where the previous read stopped.
        f = mock()
        self.assertEqual(f.read(10), some_data[:10])
        self.assertEqual(f.read(10), some_data[10:])


    def test_interleaved_reads(self):
        """``read``, ``readline`` and ``readlines`` consume the same data in sequence."""
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.readlines()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, ['bar\n', 'baz\n'])

        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.read()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, 'bar\nbaz\n')


    def test_overriding_return_values(self):
        """Explicit ``return_value`` settings take precedence over ``read_data``."""
        mock = mock_open(read_data='foo')
        handle = mock()

        handle.read.return_value = 'bar'
        handle.readline.return_value = 'bar'
        handle.readlines.return_value = ['bar']

        self.assertEqual(handle.read(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readlines(), ['bar'])

        # Call repeatedly to check that a StopIteration is not propagated.
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
327
328
if "__main__" == __name__:
    # Executed as a script rather than imported: let unittest's command-line
    # runner collect and run the TestCase classes defined in this module.
    unittest.main()
331