"""Shared mixins for the audio-file module test suites.

None of the classes here is a runnable test case by itself.  A concrete
subclass must also provide the ``unittest.TestCase`` API (``assertEqual``,
``assertRaises``, ...), typically by inheriting from ``unittest.TestCase``.
It must also define the class attributes the mixins read:

* ``module`` -- the module under test; it must provide ``open()`` and
  ``Error`` (and ``openfp()`` for ``AudioMiscTests``).
* ``nchannels``, ``sampwidth``, ``framerate``, ``nframes``, ``comptype``,
  ``compname`` -- the audio parameters to write or expect.
* ``frames`` -- the raw frame bytes to write or expect.
* ``sndfilename`` and ``sndfilenframes`` -- only for
  ``AudioTestsWithSourceFile``: the reference file in the ``audiodata``
  directory and its frame count.
"""
import array
import io
import pickle
from unittest import mock

try:
    from test.support import findfile, TESTFN, unlink
except ImportError:
    # Later CPython versions moved the file helpers to test.support.os_helper.
    from test.support import findfile
    from test.support.os_helper import TESTFN, unlink


class UnseekableIO(io.FileIO):
    """A real on-disk file whose tell() and seek() always fail.

    Used to check that the audio modules can read from and write to streams
    that do not support random access.
    """

    def tell(self):
        raise io.UnsupportedOperation

    def seek(self, *args, **kwargs):
        raise io.UnsupportedOperation


class AudioTests:
    """Base mixin: per-test bookkeeping and parameter assertions.

    ``self.f`` and ``self.fout`` hold the audio objects a test opened for
    reading and for writing.  tearDown() closes both and removes TESTFN.
    """

    # Expected value of ``file.closed`` after an audio object opened on an
    # existing file object has been closed.  A subclass sets this to True
    # when the module under test closes the file object it was handed.
    close_fd = False

    def setUp(self):
        self.f = self.fout = None

    def tearDown(self):
        # Each cleanup step runs even if an earlier one raises.  For example,
        # close() on a writer whose header was never completed raises
        # module.Error.  Without this, such a failure would leak the other
        # audio object and leave TESTFN behind for the next test.  The first
        # exception still propagates.
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports exactly the given parameters.

        The check covers three things:

        * the individual getters;
        * the value returned by getparams(), both as a plain tuple and
          through its named fields;
        * that the params object survives a pickle round-trip under every
          pickle protocol.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)


class AudioMiscTests(AudioTests):
    """Module-level tests that do not need any audio data."""

    def test_openfp_deprecated(self):
        # openfp() must emit a DeprecationWarning and forward its arguments
        # unchanged to the module's open().
        arg = "arg"
        mode = "mode"
        with mock.patch(f"{self.module.__name__}.open") as mock_open, \
             self.assertWarns(DeprecationWarning):
            self.module.openfp(arg, mode=mode)
        mock_open.assert_called_with(arg, mode=mode)


class AudioWriteTests(AudioTests):
    """Round-trip tests: write ``self.frames`` and read them back."""

    def create_file(self, testfile):
        """Open *testfile* for writing and apply the class's audio parameters.

        *testfile* may be a file name or an open file object.  The writer is
        stored in ``self.fout`` so that tearDown() closes it.
        """
        f = self.fout = self.module.open(testfile, 'wb')
        f.setnchannels(self.nchannels)
        f.setsampwidth(self.sampwidth)
        f.setframerate(self.framerate)
        f.setcomptype(self.comptype, self.compname)
        return f

    def check_file(self, testfile, nframes, frames):
        """Reopen *testfile* and assert its header and its first frames.

        The header must report *nframes* frames, and reading *nframes*
        frames must return exactly *frames*.
        """
        with self.module.open(testfile, 'rb') as f:
            self.assertEqual(f.getnchannels(), self.nchannels)
            self.assertEqual(f.getsampwidth(), self.sampwidth)
            self.assertEqual(f.getframerate(), self.framerate)
            self.assertEqual(f.getnframes(), nframes)
            self.assertEqual(f.readframes(nframes), frames)

    def test_write_params(self):
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.nframes, self.comptype, self.compname)
        f.close()

    def test_write_context_manager_calls_close(self):
        # Close checks for a minimum header and will raise an error
        # if it is not set, so this proves that close is called.
        with self.assertRaises(self.module.Error):
            with self.module.open(TESTFN, 'wb'):
                pass
        with self.assertRaises(self.module.Error):
            with open(TESTFN, 'wb') as testfile:
                with self.module.open(testfile):
                    pass

    def test_context_manager_with_open_file(self):
        with open(TESTFN, 'wb') as testfile:
            with self.module.open(testfile) as f:
                f.setnchannels(self.nchannels)
                f.setsampwidth(self.sampwidth)
                f.setframerate(self.framerate)
                f.setcomptype(self.comptype, self.compname)
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'rb') as testfile:
            with self.module.open(testfile) as f:
                self.assertFalse(f.getfp().closed)
                params = f.getparams()
                self.assertEqual(params.nchannels, self.nchannels)
                self.assertEqual(params.sampwidth, self.sampwidth)
                self.assertEqual(params.framerate, self.framerate)
            if not self.close_fd:
                self.assertIsNone(f.getfp())
            self.assertEqual(testfile.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        # If the file doesn't get closed, this test won't fail, but it will
        # produce a resource leak warning.
        with self.module.open(TESTFN, 'wb') as f:
            f.setnchannels(self.nchannels)
            f.setsampwidth(self.sampwidth)
            f.setframerate(self.framerate)
            f.setcomptype(self.comptype, self.compname)
        with self.module.open(TESTFN) as f:
            self.assertFalse(f.getfp().closed)
            params = f.getparams()
            self.assertEqual(params.nchannels, self.nchannels)
            self.assertEqual(params.sampwidth, self.sampwidth)
            self.assertEqual(params.framerate, self.framerate)
        if not self.close_fd:
            self.assertIsNone(f.getfp())

    def test_write(self):
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_bytearray(self):
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(bytearray(self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_array(self):
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(array.array('h', self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_memoryview(self):
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(memoryview(self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        # The audio data follows a 13-byte junk prefix, so the writer must
        # not assume the stream starts at offset 0.  One frame fewer than
        # announced is written; the file must still read back correctly.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_multiple_writes(self):
        # Same 13-byte prefix; the frames are written in two calls, the
        # second of which contains exactly one frame.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes)
            framesize = self.nchannels * self.sampwidth
            f.writeframes(self.frames[:-framesize])
            f.writeframes(self.frames[-framesize:])
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_overflowed_write(self):
        # One frame more than announced is written; the file must end up
        # reporting the number of frames actually written.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_read(self):
        with self.create_file(TESTFN) as f:
            f.setnframes(self.nframes)
            f.writeframes(self.frames)

        with UnseekableIO(TESTFN, 'rb') as testfile:
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_write(self):
        with UnseekableIO(TESTFN, 'wb') as testfile:
            with self.create_file(testfile) as f:
                f.setnframes(self.nframes)
                f.writeframes(self.frames)

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        # On an unseekable stream the writer presumably cannot go back and
        # fix the header, so OSError from writing/closing is tolerated.  The
        # header is expected to keep the announced nframes + 1 count.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            try:
                f.writeframes(self.frames)
            except OSError:
                pass
            try:
                f.close()
            except OSError:
                pass

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        # Announce one frame too few on an unseekable stream.  Only the
        # announced nframes - 1 frames are expected to be readable back.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            try:
                f.writeframes(self.frames)
            except OSError:
                pass
            try:
                f.close()
            except OSError:
                pass

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            framesize = self.nchannels * self.sampwidth
            self.check_file(testfile, self.nframes - 1, self.frames[:-framesize])


class AudioTestsWithSourceFile(AudioTests):
    """Read-side tests against a reference file from the audiodata dir."""

    @classmethod
    def setUpClass(cls):
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        f = self.f = self.module.open(self.sndfilepath)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.sndfilenframes, self.comptype, self.compname)

    def test_close(self):
        with open(self.sndfilepath, 'rb') as testfile:
            f = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            f.close()
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'wb') as testfile:
            fout = self.fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            # No parameters were set, so closing the writer must fail.
            with self.assertRaises(self.module.Error):
                fout.close()
            self.assertEqual(testfile.closed, self.close_fd)
            fout.close()  # a second close() must be a no-op

    def test_read(self):
        framesize = self.nchannels * self.sampwidth
        chunk1 = self.frames[:2 * framesize]
        chunk2 = self.frames[2 * framesize: 4 * framesize]
        f = self.f = self.module.open(self.sndfilepath)
        self.assertEqual(f.readframes(0), b'')
        self.assertEqual(f.tell(), 0)
        self.assertEqual(f.readframes(2), chunk1)
        f.rewind()
        pos0 = f.tell()
        self.assertEqual(pos0, 0)
        self.assertEqual(f.readframes(2), chunk1)
        pos2 = f.tell()
        self.assertEqual(pos2, 2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos0)
        self.assertEqual(f.readframes(2), chunk1)
        # Positions outside [0, nframes] must be rejected.
        with self.assertRaises(self.module.Error):
            f.setpos(-1)
        with self.assertRaises(self.module.Error):
            f.setpos(f.getnframes() + 1)

    def test_copy(self):
        f = self.f = self.module.open(self.sndfilepath)
        fout = self.fout = self.module.open(TESTFN, 'wb')
        fout.setparams(f.getparams())
        # Copy in chunks of 1, 2, 3, ... frames so that reads and writes of
        # varying sizes (including a short final read) are exercised.
        i = 0
        n = f.getnframes()
        while n > 0:
            i += 1
            fout.writeframes(f.readframes(i))
            n -= i
        fout.close()
        fout = self.fout = self.module.open(TESTFN, 'rb')
        f.rewind()
        self.assertEqual(f.getparams(), fout.getparams())
        self.assertEqual(f.readframes(f.getnframes()),
                         fout.readframes(fout.getnframes()))

    def test_read_not_from_start(self):
        # Embed the reference file after a 13-byte prefix and read it from
        # an already-advanced file object.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as f:
                testfile.write(f.read())

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            with self.module.open(testfile, 'rb') as f:
                self.assertEqual(f.getnchannels(), self.nchannels)
                self.assertEqual(f.getsampwidth(), self.sampwidth)
                self.assertEqual(f.getframerate(), self.framerate)
                self.assertEqual(f.getnframes(), self.sndfilenframes)
                self.assertEqual(f.readframes(self.nframes), self.frames)