1from test.support import findfile 2from test.support.os_helper import TESTFN, unlink 3import array 4import io 5import pickle 6 7 8class UnseekableIO(io.FileIO): 9 def tell(self): 10 raise io.UnsupportedOperation 11 12 def seek(self, *args, **kwargs): 13 raise io.UnsupportedOperation 14 15 16class AudioTests: 17 close_fd = False 18 19 def setUp(self): 20 self.f = self.fout = None 21 22 def tearDown(self): 23 if self.f is not None: 24 self.f.close() 25 if self.fout is not None: 26 self.fout.close() 27 unlink(TESTFN) 28 29 def check_params(self, f, nchannels, sampwidth, framerate, nframes, 30 comptype, compname): 31 self.assertEqual(f.getnchannels(), nchannels) 32 self.assertEqual(f.getsampwidth(), sampwidth) 33 self.assertEqual(f.getframerate(), framerate) 34 self.assertEqual(f.getnframes(), nframes) 35 self.assertEqual(f.getcomptype(), comptype) 36 self.assertEqual(f.getcompname(), compname) 37 38 params = f.getparams() 39 self.assertEqual(params, 40 (nchannels, sampwidth, framerate, nframes, comptype, compname)) 41 self.assertEqual(params.nchannels, nchannels) 42 self.assertEqual(params.sampwidth, sampwidth) 43 self.assertEqual(params.framerate, framerate) 44 self.assertEqual(params.nframes, nframes) 45 self.assertEqual(params.comptype, comptype) 46 self.assertEqual(params.compname, compname) 47 48 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 49 dump = pickle.dumps(params, proto) 50 self.assertEqual(pickle.loads(dump), params) 51 52 53class AudioWriteTests(AudioTests): 54 55 def create_file(self, testfile): 56 f = self.fout = self.module.open(testfile, 'wb') 57 f.setnchannels(self.nchannels) 58 f.setsampwidth(self.sampwidth) 59 f.setframerate(self.framerate) 60 f.setcomptype(self.comptype, self.compname) 61 return f 62 63 def check_file(self, testfile, nframes, frames): 64 with self.module.open(testfile, 'rb') as f: 65 self.assertEqual(f.getnchannels(), self.nchannels) 66 self.assertEqual(f.getsampwidth(), self.sampwidth) 67 self.assertEqual(f.getframerate(), self.framerate) 68 self.assertEqual(f.getnframes(), nframes) 69 self.assertEqual(f.readframes(nframes), frames) 70 71 def test_write_params(self): 72 f = self.create_file(TESTFN) 73 f.setnframes(self.nframes) 74 f.writeframes(self.frames) 75 self.check_params(f, self.nchannels, self.sampwidth, self.framerate, 76 self.nframes, self.comptype, self.compname) 77 f.close() 78 79 def test_write_context_manager_calls_close(self): 80 # Close checks for a minimum header and will raise an error 81 # if it is not set, so this proves that close is called. 82 with self.assertRaises(self.module.Error): 83 with self.module.open(TESTFN, 'wb'): 84 pass 85 with self.assertRaises(self.module.Error): 86 with open(TESTFN, 'wb') as testfile: 87 with self.module.open(testfile): 88 pass 89 90 def test_context_manager_with_open_file(self): 91 with open(TESTFN, 'wb') as testfile: 92 with self.module.open(testfile) as f: 93 f.setnchannels(self.nchannels) 94 f.setsampwidth(self.sampwidth) 95 f.setframerate(self.framerate) 96 f.setcomptype(self.comptype, self.compname) 97 self.assertEqual(testfile.closed, self.close_fd) 98 with open(TESTFN, 'rb') as testfile: 99 with self.module.open(testfile) as f: 100 self.assertFalse(f.getfp().closed) 101 params = f.getparams() 102 self.assertEqual(params.nchannels, self.nchannels) 103 self.assertEqual(params.sampwidth, self.sampwidth) 104 self.assertEqual(params.framerate, self.framerate) 105 if not self.close_fd: 106 self.assertIsNone(f.getfp()) 107 self.assertEqual(testfile.closed, self.close_fd) 108 109 def test_context_manager_with_filename(self): 110 # If the file doesn't get closed, this test won't fail, but it will 111 # produce a resource leak warning. 112 with self.module.open(TESTFN, 'wb') as f: 113 f.setnchannels(self.nchannels) 114 f.setsampwidth(self.sampwidth) 115 f.setframerate(self.framerate) 116 f.setcomptype(self.comptype, self.compname) 117 with self.module.open(TESTFN) as f: 118 self.assertFalse(f.getfp().closed) 119 params = f.getparams() 120 self.assertEqual(params.nchannels, self.nchannels) 121 self.assertEqual(params.sampwidth, self.sampwidth) 122 self.assertEqual(params.framerate, self.framerate) 123 if not self.close_fd: 124 self.assertIsNone(f.getfp()) 125 126 def test_write(self): 127 f = self.create_file(TESTFN) 128 f.setnframes(self.nframes) 129 f.writeframes(self.frames) 130 f.close() 131 132 self.check_file(TESTFN, self.nframes, self.frames) 133 134 def test_write_bytearray(self): 135 f = self.create_file(TESTFN) 136 f.setnframes(self.nframes) 137 f.writeframes(bytearray(self.frames)) 138 f.close() 139 140 self.check_file(TESTFN, self.nframes, self.frames) 141 142 def test_write_array(self): 143 f = self.create_file(TESTFN) 144 f.setnframes(self.nframes) 145 f.writeframes(array.array('h', self.frames)) 146 f.close() 147 148 self.check_file(TESTFN, self.nframes, self.frames) 149 150 def test_write_memoryview(self): 151 f = self.create_file(TESTFN) 152 f.setnframes(self.nframes) 153 f.writeframes(memoryview(self.frames)) 154 f.close() 155 156 self.check_file(TESTFN, self.nframes, self.frames) 157 158 def test_incompleted_write(self): 159 with open(TESTFN, 'wb') as testfile: 160 testfile.write(b'ababagalamaga') 161 f = self.create_file(testfile) 162 f.setnframes(self.nframes + 1) 163 f.writeframes(self.frames) 164 f.close() 165 166 with open(TESTFN, 'rb') as testfile: 167 self.assertEqual(testfile.read(13), b'ababagalamaga') 168 self.check_file(testfile, self.nframes, self.frames) 169 170 def test_multiple_writes(self): 171 with open(TESTFN, 'wb') as testfile: 172 testfile.write(b'ababagalamaga') 173 f = self.create_file(testfile) 174 f.setnframes(self.nframes) 175 framesize = self.nchannels * self.sampwidth 176 f.writeframes(self.frames[:-framesize]) 177 f.writeframes(self.frames[-framesize:]) 178 f.close() 179 180 with open(TESTFN, 'rb') as testfile: 181 self.assertEqual(testfile.read(13), b'ababagalamaga') 182 self.check_file(testfile, self.nframes, self.frames) 183 184 def test_overflowed_write(self): 185 with open(TESTFN, 'wb') as testfile: 186 testfile.write(b'ababagalamaga') 187 f = self.create_file(testfile) 188 f.setnframes(self.nframes - 1) 189 f.writeframes(self.frames) 190 f.close() 191 192 with open(TESTFN, 'rb') as testfile: 193 self.assertEqual(testfile.read(13), b'ababagalamaga') 194 self.check_file(testfile, self.nframes, self.frames) 195 196 def test_unseekable_read(self): 197 with self.create_file(TESTFN) as f: 198 f.setnframes(self.nframes) 199 f.writeframes(self.frames) 200 201 with UnseekableIO(TESTFN, 'rb') as testfile: 202 self.check_file(testfile, self.nframes, self.frames) 203 204 def test_unseekable_write(self): 205 with UnseekableIO(TESTFN, 'wb') as testfile: 206 with self.create_file(testfile) as f: 207 f.setnframes(self.nframes) 208 f.writeframes(self.frames) 209 210 self.check_file(TESTFN, self.nframes, self.frames) 211 212 def test_unseekable_incompleted_write(self): 213 with UnseekableIO(TESTFN, 'wb') as testfile: 214 testfile.write(b'ababagalamaga') 215 f = self.create_file(testfile) 216 f.setnframes(self.nframes + 1) 217 try: 218 f.writeframes(self.frames) 219 except OSError: 220 pass 221 try: 222 f.close() 223 except OSError: 224 pass 225 226 with open(TESTFN, 'rb') as testfile: 227 self.assertEqual(testfile.read(13), b'ababagalamaga') 228 self.check_file(testfile, self.nframes + 1, self.frames) 229 230 def test_unseekable_overflowed_write(self): 231 with UnseekableIO(TESTFN, 'wb') as testfile: 232 testfile.write(b'ababagalamaga') 233 f = self.create_file(testfile) 234 f.setnframes(self.nframes - 1) 235 try: 236 f.writeframes(self.frames) 237 except OSError: 238 pass 239 try: 240 f.close() 241 except OSError: 242 pass 243 244 with open(TESTFN, 'rb') as testfile: 245 self.assertEqual(testfile.read(13), b'ababagalamaga') 246 framesize = self.nchannels * self.sampwidth 247 self.check_file(testfile, self.nframes - 1, self.frames[:-framesize]) 248 249 250class AudioTestsWithSourceFile(AudioTests): 251 252 @classmethod 253 def setUpClass(cls): 254 cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata') 255 256 def test_read_params(self): 257 f = self.f = self.module.open(self.sndfilepath) 258 #self.assertEqual(f.getfp().name, self.sndfilepath) 259 self.check_params(f, self.nchannels, self.sampwidth, self.framerate, 260 self.sndfilenframes, self.comptype, self.compname) 261 262 def test_close(self): 263 with open(self.sndfilepath, 'rb') as testfile: 264 f = self.f = self.module.open(testfile) 265 self.assertFalse(testfile.closed) 266 f.close() 267 self.assertEqual(testfile.closed, self.close_fd) 268 with open(TESTFN, 'wb') as testfile: 269 fout = self.fout = self.module.open(testfile, 'wb') 270 self.assertFalse(testfile.closed) 271 with self.assertRaises(self.module.Error): 272 fout.close() 273 self.assertEqual(testfile.closed, self.close_fd) 274 fout.close() # do nothing 275 276 def test_read(self): 277 framesize = self.nchannels * self.sampwidth 278 chunk1 = self.frames[:2 * framesize] 279 chunk2 = self.frames[2 * framesize: 4 * framesize] 280 f = self.f = self.module.open(self.sndfilepath) 281 self.assertEqual(f.readframes(0), b'') 282 self.assertEqual(f.tell(), 0) 283 self.assertEqual(f.readframes(2), chunk1) 284 f.rewind() 285 pos0 = f.tell() 286 self.assertEqual(pos0, 0) 287 self.assertEqual(f.readframes(2), chunk1) 288 pos2 = f.tell() 289 self.assertEqual(pos2, 2) 290 self.assertEqual(f.readframes(2), chunk2) 291 f.setpos(pos2) 292 self.assertEqual(f.readframes(2), chunk2) 293 f.setpos(pos0) 294 self.assertEqual(f.readframes(2), chunk1) 295 with self.assertRaises(self.module.Error): 296 f.setpos(-1) 297 with self.assertRaises(self.module.Error): 298 f.setpos(f.getnframes() + 1) 299 300 def test_copy(self): 301 f = self.f = self.module.open(self.sndfilepath) 302 fout = self.fout = self.module.open(TESTFN, 'wb') 303 fout.setparams(f.getparams()) 304 i = 0 305 n = f.getnframes() 306 while n > 0: 307 i += 1 308 fout.writeframes(f.readframes(i)) 309 n -= i 310 fout.close() 311 fout = self.fout = self.module.open(TESTFN, 'rb') 312 f.rewind() 313 self.assertEqual(f.getparams(), fout.getparams()) 314 self.assertEqual(f.readframes(f.getnframes()), 315 fout.readframes(fout.getnframes())) 316 317 def test_read_not_from_start(self): 318 with open(TESTFN, 'wb') as testfile: 319 testfile.write(b'ababagalamaga') 320 with open(self.sndfilepath, 'rb') as f: 321 testfile.write(f.read()) 322 323 with open(TESTFN, 'rb') as testfile: 324 self.assertEqual(testfile.read(13), b'ababagalamaga') 325 with self.module.open(testfile, 'rb') as f: 326 self.assertEqual(f.getnchannels(), self.nchannels) 327 self.assertEqual(f.getsampwidth(), self.sampwidth) 328 self.assertEqual(f.getframerate(), self.framerate) 329 self.assertEqual(f.getnframes(), self.sndfilenframes) 330 self.assertEqual(f.readframes(self.nframes), self.frames) 331