"""Tests for using ``unittest.mock`` patchers and mocks with the ``with`` statement."""

import unittest
from warnings import catch_warnings

from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call

try:
    from unittest.test.testmock.support import is_instance
except ImportError:
    # The ``unittest.test`` package is not present on every installation
    # (and was relocated out of ``unittest`` in newer CPython releases), so
    # fall back to an equivalent local helper instead of failing at import.
    def is_instance(obj, klass):
        """Return True if type(obj) is a subclass of klass, without reading __class__."""
        return issubclass(type(obj), klass)


# Module-level names that the tests below patch by dotted path
# ('<this module>.something') and then check for restoration.
something = sentinel.Something
something_else = sentinel.SomethingElse


class SampleException(Exception):
    """Exception raised inside patched blocks to check that patches are undone."""


class WithTest(unittest.TestCase):
    """Check that patchers and mocks behave correctly as context managers."""

    def test_with_statement(self):
        with patch('%s.something' % __name__, sentinel.Something2):
            self.assertEqual(something, sentinel.Something2, "unpatched")
        self.assertEqual(something, sentinel.Something)

    def test_with_statement_exception(self):
        # The patch must be undone even when the body raises.
        with self.assertRaises(SampleException):
            with patch('%s.something' % __name__, sentinel.Something2):
                self.assertEqual(something, sentinel.Something2, "unpatched")
                raise SampleException()
        self.assertEqual(something, sentinel.Something)

    def test_with_statement_as(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertTrue(is_instance(mock_something, MagicMock),
                            "patching wrong type")
        self.assertEqual(something, sentinel.Something)

    def test_patch_object_with_statement(self):
        class Foo(object):
            something = 'foo'

        original = Foo.something
        with patch.object(Foo, 'something'):
            self.assertNotEqual(Foo.something, original, "unpatched")
        self.assertEqual(Foo.something, original)

    def test_with_statement_nested(self):
        with catch_warnings(record=True):
            with patch('%s.something' % __name__) as mock_something, \
                    patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something, mock_something, "unpatched")
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)

    def test_with_statement_specified(self):
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
        self.assertEqual(something, sentinel.Something)

    def testContextManagerMocking(self):
        # A plain Mock needs __enter__/__exit__ configured explicitly.
        mock = Mock()
        mock.__enter__ = Mock()
        mock.__exit__ = Mock()
        mock.__exit__.return_value = False

        with mock as m:
            self.assertEqual(m, mock.__enter__.return_value)
        mock.__enter__.assert_called_with()
        mock.__exit__.assert_called_with(None, None, None)

    def test_context_manager_with_magic_mock(self):
        mock = MagicMock()

        # The TypeError raised in the body must propagate out of the block.
        with self.assertRaises(TypeError):
            with mock:
                'foo' + 3
        mock.__enter__.assert_called_with()
        self.assertTrue(mock.__exit__.called)

    def test_with_statement_same_attribute(self):
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something' % __name__) as mock_again:
                self.assertEqual(something, mock_again, "unpatched")

            # Leaving the inner patch restores the outer patched value.
            self.assertEqual(something, mock_something,
                             "restored with wrong instance")

        self.assertEqual(something, sentinel.Something, "not restored")

    def test_with_statement_imbricated(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)

    def test_dict_context_manager(self):
        foo = {}
        with patch.dict(foo, {'a': 'b'}):
            self.assertEqual(foo, {'a': 'b'})
        self.assertEqual(foo, {})

        # The dict must also be restored when the body raises.
        with self.assertRaises(NameError):
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo, {'a': 'b'})
                raise NameError('Konrad')

        self.assertEqual(foo, {})

    def test_double_patch_instance_method(self):
        class C:
            def f(self): pass

        c = C()

        with patch.object(c, 'f', autospec=True) as patch1:
            with patch.object(c, 'f', autospec=True) as patch2:
                c.f()
            self.assertEqual(patch2.call_count, 1)
            self.assertEqual(patch1.call_count, 0)
            c.f()
        self.assertEqual(patch1.call_count, 1)


class TestMockOpen(unittest.TestCase):
    """Check ``mock_open`` when installed as this module's ``open``."""

    def test_mock_open(self):
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')

    def test_mock_open_context_manager(self):
        mock = mock_open()
        handle = mock.return_value
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()

        expected_calls = [call('foo'), call().__enter__(), call().read(),
                          call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)
        self.assertIs(f, handle)

    def test_mock_open_context_manager_multiple_times(self):
        mock = mock_open()
        with patch('%s.open' % __name__, mock, create=True):
            with open('foo') as f:
                f.read()
            with open('bar') as f:
                f.read()

        expected_calls = [
            call('foo'), call().__enter__(), call().read(),
            call().__exit__(None, None, None),
            call('bar'), call().__enter__(), call().read(),
            call().__exit__(None, None, None)]
        self.assertEqual(mock.mock_calls, expected_calls)

    def test_explicit_mock(self):
        mock = MagicMock()
        mock_open(mock)

        with patch('%s.open' % __name__, mock, create=True) as patched:
            self.assertIs(patched, mock)
            open('foo')

        mock.assert_called_once_with('foo')

    def test_read_data(self):
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.read()

        self.assertEqual(result, 'foo')

    def test_readline_data(self):
        # Check that readline will return all the lines from the fake file
        # And that once fully consumed, readline will return an empty string.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            line2 = h.readline()
            line3 = h.readline()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(line2, 'bar\n')
        self.assertEqual(line3, 'baz\n')
        self.assertEqual(h.readline(), '')

        # Check that we properly emulate a file that doesn't end in a newline
        mock = mock_open(read_data='foo')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readline()
        self.assertEqual(result, 'foo')
        self.assertEqual(h.readline(), '')

    def test_dunder_iter_data(self):
        # Check that dunder_iter will return all the lines from the fake file.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            lines = [line for line in h]
        self.assertEqual(lines[0], 'foo\n')
        self.assertEqual(lines[1], 'bar\n')
        self.assertEqual(lines[2], 'baz\n')
        self.assertEqual(h.readline(), '')
        with self.assertRaises(StopIteration):
            next(h)

    def test_next_data(self):
        # Check that next will correctly return the next available
        # line and plays well with the dunder_iter part.
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = next(h)
            line2 = next(h)
            lines = [line for line in h]
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(line2, 'bar\n')
        self.assertEqual(lines[0], 'baz\n')
        self.assertEqual(h.readline(), '')

    def test_readlines_data(self):
        # Test that emulating a file that ends in a newline character works
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()
        self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])

        # Test that files without a final newline will also be correctly
        # emulated
        mock = mock_open(read_data='foo\nbar\nbaz')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            result = h.readlines()

        self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])

    def test_read_bytes(self):
        mock = mock_open(read_data=b'\xc6')
        with patch('%s.open' % __name__, mock, create=True):
            with open('abc', 'rb') as f:
                result = f.read()
        self.assertEqual(result, b'\xc6')

    def test_readline_bytes(self):
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                line1 = f.readline()
                line2 = f.readline()
                line3 = f.readline()
        self.assertEqual(line1, b'abc\n')
        self.assertEqual(line2, b'def\n')
        self.assertEqual(line3, b'ghi\n')

    def test_readlines_bytes(self):
        m = mock_open(read_data=b'abc\ndef\nghi\n')
        with patch('%s.open' % __name__, m, create=True):
            with open('abc', 'rb') as f:
                result = f.readlines()
        self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])

    def test_mock_open_read_with_argument(self):
        # At one point calling read with an argument was broken
        # for mocks returned by mock_open
        some_data = 'foo\nbar\nbaz'
        mock = mock_open(read_data=some_data)
        self.assertEqual(mock().read(10), some_data[:10])
        self.assertEqual(mock().read(10), some_data[:10])

        f = mock()
        self.assertEqual(f.read(10), some_data[:10])
        self.assertEqual(f.read(10), some_data[10:])

    def test_interleaved_reads(self):
        # Test that calling read, readline, and readlines pulls data
        # sequentially from the data we preload with
        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.readlines()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, ['bar\n', 'baz\n'])

        mock = mock_open(read_data='foo\nbar\nbaz\n')
        with patch('%s.open' % __name__, mock, create=True):
            h = open('bar')
            line1 = h.readline()
            rest = h.read()
        self.assertEqual(line1, 'foo\n')
        self.assertEqual(rest, 'bar\nbaz\n')

    def test_overriding_return_values(self):
        mock = mock_open(read_data='foo')
        handle = mock()

        handle.read.return_value = 'bar'
        handle.readline.return_value = 'bar'
        handle.readlines.return_value = ['bar']

        self.assertEqual(handle.read(), 'bar')
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readlines(), ['bar'])

        # call repeatedly to check that a StopIteration is not propagated
        self.assertEqual(handle.readline(), 'bar')
        self.assertEqual(handle.readline(), 'bar')


if __name__ == '__main__':
    unittest.main()