"""Tests for using ``unittest.mock`` helpers as context managers.

Covers ``patch`` / ``patch.object`` / ``patch.dict`` in ``with`` statements,
mocks acting as context managers, and the fake file handles produced by
``mock_open``.
"""
import unittest
from warnings import catch_warnings

try:
    from unittest.test.testmock.support import is_instance
except ImportError:
    # ``unittest.test`` is not importable on newer Pythons (the unittest test
    # suite was moved out of the ``unittest`` package); fall back to an
    # equivalent helper so this module can still be imported there.
    def is_instance(obj, klass):
        """Version of isinstance that doesn't access ``__class__``."""
        return issubclass(type(obj), klass)

from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call


# Module-level globals that the tests patch via '%s.something' % __name__.
something = sentinel.Something
something_else = sentinel.SomethingElse


class WithTest(unittest.TestCase):
    """``patch`` and friends used as context managers, plus mock CM support."""

    def test_with_statement(self):
        with patch('%s.something' % __name__, sentinel.Something2):
            self.assertEqual(something, sentinel.Something2, "unpatched")
        self.assertEqual(something, sentinel.Something)


    def test_with_statement_exception(self):
        # Use a dedicated exception type: catching a broad ``Exception`` here
        # would also swallow an AssertionError raised by the inner check and
        # let a broken patch pass unnoticed.
        class SampleException(Exception):
            pass

        with self.assertRaises(SampleException):
            with patch('%s.something' % __name__, sentinel.Something2):
                self.assertEqual(something, sentinel.Something2, "unpatched")
                raise SampleException('pow')
        # The patch must be undone even though the body raised.
        self.assertEqual(something, sentinel.Something)


    def test_with_statement_as(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertTrue(is_instance(mock_something, MagicMock),
                            "patching wrong type")
        self.assertEqual(something, sentinel.Something)


    def test_patch_object_with_statement(self):
        class Foo(object):
            something = 'foo'
        original = Foo.something
        with patch.object(Foo, 'something'):
            self.assertNotEqual(Foo.something, original, "unpatched")
        self.assertEqual(Foo.something, original)


    def test_with_statement_nested(self):
        with catch_warnings(record=True):
            with patch('%s.something' % __name__) as mock_something, patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something, mock_something, "unpatched")
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)


    def test_with_statement_specified(self):
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
        self.assertEqual(something, sentinel.Something)


    def testContextManagerMocking(self):
        # A plain Mock only works in ``with`` once the dunders are set by hand.
        mock = Mock()
        mock.__enter__ = Mock()
        mock.__exit__ = Mock()
        mock.__exit__.return_value = False

        with mock as m:
            self.assertEqual(m, mock.__enter__.return_value)
        mock.__enter__.assert_called_with()
        mock.__exit__.assert_called_with(None, None, None)


    def test_context_manager_with_magic_mock(self):
        mock = MagicMock()

        # The TypeError must propagate out of the MagicMock context manager.
        with self.assertRaises(TypeError):
            with mock:
                'foo' + 3
        mock.__enter__.assert_called_with()
        self.assertTrue(mock.__exit__.called)


    def test_with_statement_same_attribute(self):
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something' % __name__) as mock_again:
                self.assertEqual(something, mock_again, "unpatched")

            # Leaving the inner patch restores the outer one, not the original.
            self.assertEqual(something, mock_something,
                             "restored with wrong instance")

        self.assertEqual(something, sentinel.Something, "not restored")


    def test_with_statement_imbricated(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)


    def test_dict_context_manager(self):
        foo = {}
        with patch.dict(foo, {'a': 'b'}):
            self.assertEqual(foo, {'a': 'b'})
        self.assertEqual(foo, {})

        # The dict is restored even when the body raises.
        with self.assertRaises(NameError):
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo, {'a': 'b'})
                raise NameError('Konrad')

        self.assertEqual(foo, {})

    def test_double_patch_instance_method(self):
        class C:
            def f(self):
                pass

        c = C()

        # Patching the same instance method twice: calls go to the innermost
        # active patch, and the outer one is reinstated afterwards.
        with patch.object(c, 'f', autospec=True) as patch1:
            with patch.object(c, 'f', autospec=True) as patch2:
                c.f()
            self.assertEqual(patch2.call_count, 1)
            self.assertEqual(patch1.call_count, 0)
            c.f()
        self.assertEqual(patch1.call_count, 1)


class TestMockOpen(unittest.TestCase):
    """Behaviour of the file-like handles produced by ``mock_open``."""

    def test_mock_open(self):
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')


    def test_mock_open_context_manager(self):
        mock = mock_open()
        handle = mock.return_value
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()

        expected_calls = [call('foo'), call().__enter__(), call().read(),
                          call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)
        self.assertIs(f, handle)

    def test_mock_open_context_manager_multiple_times(self):
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()
            with open('bar') as f:
                f.read()

        expected_calls = [
            call('foo'), call().__enter__(), call().read(),
            call().__exit__(None, None, None),
            call('bar'), call().__enter__(), call().read(),
            call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)

    def test_explicit_mock(self):
        # mock_open can configure a caller-supplied mock instead of creating one.
        mock = MagicMock()
        mock_open(mock)

        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')


    def test_read_data(self):
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.read()

        self.assertEqual(result, 'foo')


    def test_readline_data(self):
        # Check that readline will return all the lines from the fake file
        # And that once fully consumed, readline will return an empty string.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            line2 = h.readline()
            line3 = h.readline()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(line2, 'bar\n')
        self.assertEqual(line3, 'baz\n')
        self.assertEqual(h.readline(), '')

        # Check that we properly emulate a file that doesn't end in a newline
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readline()
        self.assertEqual(result, 'foo')
        self.assertEqual(h.readline(), '')


    def test_dunder_iter_data(self):
        # Check that dunder_iter will return all the lines from the fake file.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            lines = [l for l in h]
        self.assertEqual(lines[0], 'foo\n')
        self.assertEqual(lines[1], 'bar\n')
        self.assertEqual(lines[2], 'baz\n')
        self.assertEqual(h.readline(), '')


    def test_readlines_data(self):
        # Test that emulating a file that ends in a newline character works
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()
        self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])

        # Test that files without a final newline will also be correctly
        # emulated
        mock = mock_open(read_data='foo\nbar\nbaz')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()

        self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])


    def test_read_bytes(self):
        mock = mock_open(read_data=b'\xc6')
        with patch('%s.open' % __name__, mock, create=True):
            with open('abc', 'rb') as f:
                result = f.read()
        self.assertEqual(result, b'\xc6')


    def test_readline_bytes(self):
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                line1 = f.readline()
                line2 = f.readline()
                line3 = f.readline()
        self.assertEqual(line1, b'abc\n')
        self.assertEqual(line2, b'def\n')
        self.assertEqual(line3, b'ghi\n')


    def test_readlines_bytes(self):
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                result = f.readlines()
        self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])


    def test_mock_open_read_with_argument(self):
        # At one point calling read with an argument was broken
        # for mocks returned by mock_open.  Older mock_open implementations
        # ignore the size and return all of the data, while newer ones honour
        # it like a real file object; accept either instead of pinning the
        # size-ignoring behaviour.
        some_data = 'foo\nbar\nbaz'
        mock = mock_open(read_data=some_data)
        self.assertIn(mock().read(10), (some_data, some_data[:10]))


    def test_interleaved_reads(self):
        # Test that calling read, readline, and readlines pulls data
        # sequentially from the data we preload with
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.readlines()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, ['bar\n', 'baz\n'])

        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.read()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, 'bar\nbaz\n')


    def test_overriding_return_values(self):
        # Explicit return_values on the handle take precedence over read_data.
        mock = mock_open(read_data='foo')
        handle = mock()

        handle.read.return_value = 'bar'
        handle.readline.return_value = 'bar'
        handle.readlines.return_value = ['bar']

        self.assertEqual(handle.read(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readlines(), ['bar'])

        # call repeatedly to check that a StopIteration is not propagated
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readline(), 'bar')


if __name__ == '__main__':
    unittest.main()