• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from warnings import catch_warnings
3
4from unittest.test.testmock.support import is_instance
5from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
6
7
8
# Module-level sentinel values. The tests below patch them by dotted name
# ('<this module>.something' / '<this module>.something_else') and then
# verify that the original sentinel is back in place afterwards.
something = sentinel.Something
something_else = sentinel.SomethingElse
11
12
class SampleException(Exception):
    """Exception raised deliberately inside ``with`` blocks to check that
    patches are undone when the body exits with an error."""
14
15
class WithTest(unittest.TestCase):
    """Tests for patchers and mocks used as context managers in ``with``."""

    def test_with_statement(self):
        with patch('%s.something' % __name__, sentinel.Something2):
            self.assertEqual(something, sentinel.Something2, "unpatched")
        self.assertEqual(something, sentinel.Something)

    def test_with_statement_exception(self):
        # The patch must be undone even when the ``with`` body raises.
        with self.assertRaises(SampleException):
            with patch('%s.something' % __name__, sentinel.Something2):
                self.assertEqual(something, sentinel.Something2, "unpatched")
                raise SampleException()
        self.assertEqual(something, sentinel.Something)

    def test_with_statement_as(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            # Check the real type rather than using isinstance(): a mock's
            # __class__ can be replaced (e.g. via spec), which can fool
            # isinstance() but not type().
            self.assertTrue(issubclass(type(mock_something), MagicMock),
                            "patching wrong type")
        self.assertEqual(something, sentinel.Something)

    def test_patch_object_with_statement(self):
        class Foo(object):
            something = 'foo'
        original = Foo.something
        with patch.object(Foo, 'something'):
            self.assertNotEqual(Foo.something, original, "unpatched")
        self.assertEqual(Foo.something, original)

    def test_with_statement_nested(self):
        # Two patchers entered by a single ``with`` statement.
        with patch('%s.something' % __name__) as mock_something, \
                patch('%s.something_else' % __name__) as mock_something_else:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertEqual(something_else, mock_something_else,
                             "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)

    def test_with_statement_specified(self):
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
        self.assertEqual(something, sentinel.Something)

    def testContextManagerMocking(self):
        # A plain Mock becomes usable in ``with`` once __enter__ and
        # __exit__ are assigned explicitly.
        mock = Mock()
        mock.__enter__ = Mock()
        mock.__exit__ = Mock()
        mock.__exit__.return_value = False

        with mock as m:
            self.assertEqual(m, mock.__enter__.return_value)
        mock.__enter__.assert_called_with()
        mock.__exit__.assert_called_with(None, None, None)

    def test_context_manager_with_magic_mock(self):
        # MagicMock works in ``with`` without configuration; an exception
        # raised in the body still propagates and __exit__ is still called.
        mock = MagicMock()

        with self.assertRaises(TypeError):
            with mock:
                'foo' + 3  # deliberately raises TypeError
        mock.__enter__.assert_called_with()
        self.assertTrue(mock.__exit__.called)

    def test_with_statement_same_attribute(self):
        # Patching an already-patched attribute: leaving the inner patch
        # restores the outer replacement, not the original value.
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something' % __name__) as mock_again:
                self.assertEqual(something, mock_again, "unpatched")

            self.assertEqual(something, mock_something,
                             "restored with wrong instance")

        self.assertEqual(something, sentinel.Something, "not restored")

    def test_with_statement_imbricated(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)

    def test_dict_context_manager(self):
        foo = {}
        with patch.dict(foo, {'a': 'b'}):
            self.assertEqual(foo, {'a': 'b'})
        self.assertEqual(foo, {})

        # The dictionary is restored even when the body raises.
        with self.assertRaises(NameError):
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo, {'a': 'b'})
                raise NameError('Konrad')

        self.assertEqual(foo, {})

    def test_double_patch_instance_method(self):
        class C:
            def f(self): pass

        c = C()

        with patch.object(c, 'f', autospec=True) as patch1:
            with patch.object(c, 'f', autospec=True) as patch2:
                c.f()
            self.assertEqual(patch2.call_count, 1)
            self.assertEqual(patch1.call_count, 0)
            # Once the inner patch has exited, c.f is the outer mock again.
            c.f()
        self.assertEqual(patch1.call_count, 1)
140
141
class TestMockOpen(unittest.TestCase):
    """Tests for ``mock_open``: the calls it records and how its handle
    serves ``read_data`` through read/readline/readlines/iteration."""

    def _patch_open(self, mock):
        """Return a patcher that replaces ``open`` in this module with *mock*.

        ``create=True`` is required because this module defines no ``open``
        attribute of its own (the name normally resolves to the builtin).
        """
        return patch('%s.open' % __name__, mock, create=True)

    def test_mock_open(self):
        mock = mock_open()
        with self._patch_open(mock) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')

    def test_mock_open_context_manager(self):
        mock = mock_open()
        handle = mock.return_value
        with self._patch_open(mock):
            with open('foo') as f:
                f.read()

        expected_calls = [call('foo'), call().__enter__(), call().read(),
                          call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)
        # Entering the handle's context yields the handle itself.
        self.assertIs(f, handle)

    def test_mock_open_context_manager_multiple_times(self):
        mock = mock_open()
        with self._patch_open(mock):
            with open('foo') as f:
                f.read()
            with open('bar') as f:
                f.read()

        expected_calls = [
            call('foo'), call().__enter__(), call().read(),
            call().__exit__(None, None, None),
            call('bar'), call().__enter__(), call().read(),
            call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)

    def test_explicit_mock(self):
        # mock_open can configure an existing mock instead of creating one.
        mock = MagicMock()
        mock_open(mock)

        with self._patch_open(mock) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')

    def test_read_data(self):
        mock = mock_open(read_data='foo')
        with self._patch_open(mock):
            h = open('bar')
            result = h.read()

        self.assertEqual(result, 'foo')

    def test_readline_data(self):
        # readline() returns each line of the fake file in turn and then
        # an empty string once the data is fully consumed.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with self._patch_open(mock):
            h = open('bar')
            line1 = h.readline()
            line2 = h.readline()
            line3 = h.readline()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(line2, 'bar\n')
        self.assertEqual(line3, 'baz\n')
        self.assertEqual(h.readline(), '')

        # A file whose data does not end in a newline is emulated too.
        mock = mock_open(read_data='foo')
        with self._patch_open(mock):
            h = open('bar')
            result = h.readline()
        self.assertEqual(result, 'foo')
        self.assertEqual(h.readline(), '')

    def test_dunder_iter_data(self):
        # Iterating the handle yields exactly the lines of the fake file.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with self._patch_open(mock):
            h = open('bar')
            lines = list(h)
        self.assertEqual(lines, ['foo\n', 'bar\n', 'baz\n'])
        self.assertEqual(h.readline(), '')
        with self.assertRaises(StopIteration):
            next(h)

    def test_next_data(self):
        # next() returns the following line and shares its position with
        # iteration over the handle.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with self._patch_open(mock):
            h = open('bar')
            line1 = next(h)
            line2 = next(h)
            lines = list(h)
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(line2, 'bar\n')
        self.assertEqual(lines, ['baz\n'])
        self.assertEqual(h.readline(), '')

    def test_readlines_data(self):
        # Data that ends in a newline character.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with self._patch_open(mock):
            h = open('bar')
            result = h.readlines()
        self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])

        # Data without a final newline: the last line is returned as-is.
        mock = mock_open(read_data='foo\nbar\nbaz')
        with self._patch_open(mock):
            h = open('bar')
            result = h.readlines()

        self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])

    def test_read_bytes(self):
        mock = mock_open(read_data=b'\xc6')
        with self._patch_open(mock):
            with open('abc', 'rb') as f:
                result = f.read()
        self.assertEqual(result, b'\xc6')

    def test_readline_bytes(self):
        mock = mock_open(read_data=b'abc\ndef\nghi\n')
        with self._patch_open(mock):
            with open('abc', 'rb') as f:
                line1 = f.readline()
                line2 = f.readline()
                line3 = f.readline()
        self.assertEqual(line1, b'abc\n')
        self.assertEqual(line2, b'def\n')
        self.assertEqual(line3, b'ghi\n')

    def test_readlines_bytes(self):
        mock = mock_open(read_data=b'abc\ndef\nghi\n')
        with self._patch_open(mock):
            with open('abc', 'rb') as f:
                result = f.readlines()
        self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])

    def test_mock_open_read_with_argument(self):
        # At one point calling read with an argument was broken
        # for mocks returned by mock_open.
        some_data = 'foo\nbar\nbaz'
        mock = mock_open(read_data=some_data)
        # Each call to the mock starts reading from the beginning again...
        self.assertEqual(mock().read(10), some_data[:10])
        self.assertEqual(mock().read(10), some_data[:10])

        # ...while repeated reads on one handle continue where they stopped.
        f = mock()
        self.assertEqual(f.read(10), some_data[:10])
        self.assertEqual(f.read(10), some_data[10:])

    def test_interleaved_reads(self):
        # Calling read, readline and readlines pulls data sequentially
        # from the preloaded read_data.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with self._patch_open(mock):
            h = open('bar')
            line1 = h.readline()
            rest = h.readlines()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, ['bar\n', 'baz\n'])

        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with self._patch_open(mock):
            h = open('bar')
            line1 = h.readline()
            rest = h.read()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, 'bar\nbaz\n')

    def test_overriding_return_values(self):
        # Explicitly set return values take precedence over read_data.
        mock = mock_open(read_data='foo')
        handle = mock()

        handle.read.return_value = 'bar'
        handle.readline.return_value = 'bar'
        handle.readlines.return_value = ['bar']

        self.assertEqual(handle.read(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readlines(), ['bar'])

        # Call repeatedly to check that a StopIteration is not propagated.
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
344
345
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
348