• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from warnings import catch_warnings
3
4from unittest.test.testmock.support import is_instance
5from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
6
7
8
# Module-level targets for the WithTest cases: they are patched by dotted
# name ('%s.something' % __name__) and must be restored to these sentinels.
something, something_else = sentinel.Something, sentinel.SomethingElse
11
12
13
class WithTest(unittest.TestCase):
    """Context-manager behaviour of patch(), patch.object() and patch.dict().

    Most tests patch the module-level names ``something`` and
    ``something_else`` by dotted path (``'%s.something' % __name__``) and
    then check that the original sentinel values are back in place.
    """

    def test_with_statement(self):
        """patch() installs ``new`` inside the block and restores on exit."""
        with patch('%s.something' % __name__, sentinel.Something2):
            self.assertEqual(something, sentinel.Something2, "unpatched")
        self.assertEqual(something, sentinel.Something)

    def test_with_statement_exception(self):
        """An exception raised under patch() propagates and is not swallowed.

        The patch must also be undone. A dedicated exception type is expected,
        rather than catching ``Exception``, so that an AssertionError from the
        inner assertEqual is reported as a test failure. Otherwise it would be
        caught together with the deliberately raised exception.
        """
        class SampleException(Exception):
            pass

        with self.assertRaises(SampleException,
                               msg="patch swallowed exception"):
            with patch('%s.something' % __name__, sentinel.Something2):
                self.assertEqual(something, sentinel.Something2, "unpatched")
                raise SampleException('pow')
        self.assertEqual(something, sentinel.Something)

    def test_with_statement_as(self):
        """Without ``new``, patch() installs a MagicMock and binds it via ``as``."""
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertTrue(is_instance(mock_something, MagicMock),
                            "patching wrong type")
        self.assertEqual(something, sentinel.Something)

    def test_patch_object_with_statement(self):
        """patch.object() replaces a class attribute only inside the block."""
        class Foo(object):
            something = 'foo'
        original = Foo.something
        with patch.object(Foo, 'something'):
            self.assertNotEqual(Foo.something, original, "unpatched")
        self.assertEqual(Foo.something, original)

    def test_with_statement_nested(self):
        """Two patches in one ``with`` statement are both active, then undone."""
        # Any warnings raised inside are recorded by catch_warnings, not shown.
        with catch_warnings(record=True):
            with patch('%s.something' % __name__) as mock_something, \
                 patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something, mock_something, "unpatched")
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)

    def test_with_statement_specified(self):
        """When ``new`` is given, ``as`` binds that exact object."""
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
        self.assertEqual(something, sentinel.Something)

    def testContextManagerMocking(self):
        """A plain Mock with __enter__/__exit__ assigned works with ``with``."""
        mock = Mock()
        mock.__enter__ = Mock()
        mock.__exit__ = Mock()
        # A falsy __exit__ result means exceptions would not be suppressed.
        mock.__exit__.return_value = False

        with mock as m:
            self.assertEqual(m, mock.__enter__.return_value)
        mock.__enter__.assert_called_with()
        mock.__exit__.assert_called_with(None, None, None)

    def test_context_manager_with_magic_mock(self):
        """A MagicMock context manager lets an exception from the body escape."""
        mock = MagicMock()

        with self.assertRaises(TypeError):
            with mock:
                'foo' + 3  # deliberately raises TypeError inside the block
        mock.__enter__.assert_called_with()
        self.assertTrue(mock.__exit__.called)

    def test_with_statement_same_attribute(self):
        """Nested patches of one attribute are restored innermost-first."""
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something' % __name__) as mock_again:
                self.assertEqual(something, mock_again, "unpatched")

            # Leaving the inner block must bring back the outer patch value.
            self.assertEqual(something, mock_something,
                             "restored with wrong instance")

        self.assertEqual(something, sentinel.Something, "not restored")

    def test_with_statement_imbricated(self):
        """Patches of different attributes nested in separate ``with`` blocks."""
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)

    def test_dict_context_manager(self):
        """patch.dict() restores the dict on normal exit and on an exception."""
        foo = {}
        with patch.dict(foo, {'a': 'b'}):
            self.assertEqual(foo, {'a': 'b'})
        self.assertEqual(foo, {})

        with self.assertRaises(NameError):
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo, {'a': 'b'})
                raise NameError('Konrad')

        self.assertEqual(foo, {})
128
129
130
class TestMockOpen(unittest.TestCase):
    """Behaviour of mock_open() as a stand-in for the builtin open().

    Most tests patch ``open`` in this module (``create=True`` because the
    module has no ``open`` attribute of its own) and read from the fake file.
    """

    def test_mock_open(self):
        """Calls to the patched open() are recorded on the mock_open mock."""
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')

    def test_mock_open_context_manager(self):
        """``with open(...)`` yields the handle and the call sequence is recorded."""
        mock = mock_open()
        handle = mock.return_value
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()

        expected_calls = [call('foo'), call().__enter__(), call().read(),
                          call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)
        self.assertIs(f, handle)

    def test_mock_open_context_manager_multiple_times(self):
        """Consecutive ``with open(...)`` blocks are all recorded, in order."""
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()
            with open('bar') as f:
                f.read()

        expected_calls = [
            call('foo'), call().__enter__(), call().read(),
            call().__exit__(None, None, None),
            call('bar'), call().__enter__(), call().read(),
            call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)

    def test_explicit_mock(self):
        """mock_open() can configure a MagicMock supplied by the caller."""
        mock = MagicMock()
        mock_open(mock)

        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')

    def test_read_data(self):
        """read() returns the configured read_data."""
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.read()

        self.assertEqual(result, 'foo')

    def test_readline_data(self):
        """readline() returns successive lines, including a final partial line."""
        # Check that readline will return all the lines from the fake file.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            line2 = h.readline()
            line3 = h.readline()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(line2, 'bar\n')
        self.assertEqual(line3, 'baz\n')

        # Check that we properly emulate a file that doesn't end in a newline.
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readline()
        self.assertEqual(result, 'foo')

    def test_readlines_data(self):
        """readlines() splits read_data into lines, with or without a final newline."""
        # Emulating a file that ends in a newline character.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()
        self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])

        # Files without a final newline must also be emulated correctly.
        mock = mock_open(read_data='foo\nbar\nbaz')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()

        self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])

    def test_read_bytes(self):
        """Bytes read_data is returned as bytes by read()."""
        mock = mock_open(read_data=b'\xc6')
        with patch('%s.open' % __name__, mock, create=True):
            with open('abc', 'rb') as f:
                result = f.read()
        self.assertEqual(result, b'\xc6')

    def test_readline_bytes(self):
        """readline() works on bytes read_data."""
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                line1 = f.readline()
                line2 = f.readline()
                line3 = f.readline()
        self.assertEqual(line1, b'abc\n')
        self.assertEqual(line2, b'def\n')
        self.assertEqual(line3, b'ghi\n')

    def test_readlines_bytes(self):
        """readlines() works on bytes read_data."""
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                result = f.readlines()
        self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])

    def test_mock_open_read_with_argument(self):
        """read(size) honours the size argument like a real file object.

        At one point calling read with an argument was broken for mocks
        returned by mock_open. The expectations mirror a real text file:
        read(10) returns at most 10 characters, and a further read on the same
        handle continues where the previous one stopped. Each new call of the
        mock (each "open") starts again from the beginning of read_data.
        """
        # NOTE: older mock_open implementations ignored the size argument and
        # returned all remaining data; this asserts real file semantics.
        some_data = 'foo\nbar\nbaz'
        mock = mock_open(read_data=some_data)
        self.assertEqual(mock().read(10), some_data[:10])
        self.assertEqual(mock().read(10), some_data[:10])

        f = mock()
        self.assertEqual(f.read(10), some_data[:10])
        self.assertEqual(f.read(10), some_data[10:])

    def test_interleaved_reads(self):
        """read, readline and readlines consume read_data sequentially."""
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.readlines()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, ['bar\n', 'baz\n'])

        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.read()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, 'bar\nbaz\n')

    def test_overriding_return_values(self):
        """Explicit return_value settings on the handle override read_data."""
        mock = mock_open(read_data='foo')
        handle = mock()

        handle.read.return_value = 'bar'
        handle.readline.return_value = 'bar'
        handle.readlines.return_value = ['bar']

        self.assertEqual(handle.read(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readlines(), ['bar'])

        # Call repeatedly to check that a StopIteration is not propagated.
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
298
299
if __name__ == '__main__':
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()
302