• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from test.support import findfile, TESTFN, unlink
2import array
3import io
4from unittest import mock
5import pickle
6
7
class UnseekableIO(io.FileIO):
    """A ``FileIO`` whose positioning methods always fail.

    ``tell()`` and ``seek()`` both raise ``io.UnsupportedOperation``, so code
    using an instance can neither query nor change the stream position.
    """

    def tell(self):
        """Always raise ``io.UnsupportedOperation``."""
        raise io.UnsupportedOperation()

    def seek(self, *args, **kwargs):
        """Always raise ``io.UnsupportedOperation``; arguments are ignored."""
        raise io.UnsupportedOperation()
14
15
class AudioTests:
    """Common fixture and helper shared by the audio test mixins.

    ``setUp``/``tearDown`` manage the two handles ``self.f`` and
    ``self.fout`` and remove ``TESTFN``.  ``check_params`` calls
    ``self.assertEqual``, which is not defined here, so another base class
    has to provide it.
    """

    # Value that ``closed`` of a caller-supplied file object is compared
    # against after the audio object wrapping it has been closed.
    close_fd = False

    def setUp(self):
        """Start every test with no handle registered for cleanup."""
        self.f = self.fout = None

    def tearDown(self):
        """Close ``self.f`` and ``self.fout`` (when set) and remove TESTFN.

        The steps are chained with ``try``/``finally`` so that an exception
        raised while closing one handle neither leaves the other handle open
        nor skips removing the scratch file; the exception still propagates.
        """
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that the audio object *f* reports the given parameters.

        Each individual getter is checked first, then the object returned by
        ``f.getparams()`` is compared both as a whole tuple and attribute by
        attribute, and finally it must survive a pickle round trip with
        every available protocol.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        # The params object must compare equal after pickling and
        # unpickling, whatever the protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)
51
52
class AudioMiscTests(AudioTests):
    """Checks that only need ``self.module`` itself."""

    def test_openfp_deprecated(self):
        # openfp() has to emit a DeprecationWarning and pass its arguments
        # on to the module's open(), which is replaced by a mock here.
        target = f"{self.module.__name__}.open"
        positional, requested_mode = "arg", "mode"
        with mock.patch(target) as patched_open:
            with self.assertWarns(DeprecationWarning):
                self.module.openfp(positional, mode=requested_mode)
                patched_open.assert_called_with(positional,
                                                mode=requested_mode)
62
63
class AudioWriteTests(AudioTests):
    """Tests that write audio data through ``self.module`` and read it back.

    The methods read ``module``, ``nchannels``, ``sampwidth``, ``framerate``,
    ``nframes``, ``comptype``, ``compname`` and ``frames`` from ``self``;
    none of them is defined in this class.
    """

    def _apply_write_params(self, writer):
        """Set channels, sample width, frame rate and compression on *writer*."""
        writer.setnchannels(self.nchannels)
        writer.setsampwidth(self.sampwidth)
        writer.setframerate(self.framerate)
        writer.setcomptype(self.comptype, self.compname)

    def _check_open_reader(self, reader):
        """Assert *reader* is backed by an open file with the expected params."""
        self.assertFalse(reader.getfp().closed)
        params = reader.getparams()
        self.assertEqual(params.nchannels, self.nchannels)
        self.assertEqual(params.sampwidth, self.sampwidth)
        self.assertEqual(params.framerate, self.framerate)

    def _check_roundtrip(self, convert=None):
        """Write ``self.frames`` to TESTFN and verify what is read back.

        *convert*, when given, is applied to ``self.frames`` to build the
        object handed to ``writeframes()``.
        """
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        payload = self.frames if convert is None else convert(self.frames)
        writer.writeframes(payload)
        writer.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def create_file(self, testfile):
        """Open *testfile* for writing with the expected parameters applied.

        The writer is also stored in ``self.fout`` so ``tearDown`` closes it.
        """
        writer = self.fout = self.module.open(testfile, 'wb')
        self._apply_write_params(writer)
        return writer

    def check_file(self, testfile, nframes, frames):
        """Read *testfile* and compare its parameters and frame data."""
        with self.module.open(testfile, 'rb') as reader:
            expectations = (
                (reader.getnchannels, self.nchannels),
                (reader.getsampwidth, self.sampwidth),
                (reader.getframerate, self.framerate),
                (reader.getnframes, nframes),
            )
            for getter, expected in expectations:
                self.assertEqual(getter(), expected)
            self.assertEqual(reader.readframes(nframes), frames)

    def test_write_params(self):
        # The writer itself must report the parameters it was given.
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        expected = (self.nchannels, self.sampwidth, self.framerate,
                    self.nframes, self.comptype, self.compname)
        self.check_params(writer, *expected)
        writer.close()

    def test_write_context_manager_calls_close(self):
        # Close checks for a minimum header and raises an error when it is
        # missing, so getting that error on leaving the ``with`` block proves
        # that close was called.
        with self.assertRaises(self.module.Error), \
             self.module.open(TESTFN, 'wb'):
            pass
        with self.assertRaises(self.module.Error), \
             open(TESTFN, 'wb') as raw, \
             self.module.open(raw):
            pass

    def test_context_manager_with_open_file(self):
        # Writing through an already open file object.
        with open(TESTFN, 'wb') as raw:
            with self.module.open(raw) as writer:
                self._apply_write_params(writer)
            self.assertEqual(raw.closed, self.close_fd)
        # Reading through an already open file object.
        with open(TESTFN, 'rb') as raw:
            with self.module.open(raw) as reader:
                self._check_open_reader(reader)
            if not self.close_fd:
                self.assertIsNone(reader.getfp())
            self.assertEqual(raw.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        # If the file doesn't get closed, this test won't fail, but it will
        # produce a resource leak warning.
        with self.module.open(TESTFN, 'wb') as writer:
            self._apply_write_params(writer)
        with self.module.open(TESTFN) as reader:
            self._check_open_reader(reader)
        if not self.close_fd:
            self.assertIsNone(reader.getfp())

    def test_write(self):
        # Frames passed exactly as stored in ``self.frames``.
        self._check_roundtrip()

    def test_write_bytearray(self):
        self._check_roundtrip(bytearray)

    def test_write_array(self):
        self._check_roundtrip(lambda raw: array.array('h', raw))

    def test_write_memoryview(self):
        self._check_roundtrip(memoryview)

    def test_incompleted_write(self):
        # One frame more is announced than is written; reading back must
        # report ``self.nframes`` frames after the untouched leading bytes.
        junk = b'ababagalamaga'
        with open(TESTFN, 'wb') as raw:
            raw.write(junk)
            writer = self.create_file(raw)
            writer.setnframes(self.nframes + 1)
            writer.writeframes(self.frames)
            writer.close()

        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(len(junk)), junk)
            self.check_file(raw, self.nframes, self.frames)

    def test_multiple_writes(self):
        # The data is written in two calls: everything but the last frame,
        # then the last frame.
        junk = b'ababagalamaga'
        with open(TESTFN, 'wb') as raw:
            raw.write(junk)
            writer = self.create_file(raw)
            writer.setnframes(self.nframes)
            framesize = self.nchannels * self.sampwidth
            head, tail = self.frames[:-framesize], self.frames[-framesize:]
            writer.writeframes(head)
            writer.writeframes(tail)
            writer.close()

        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(len(junk)), junk)
            self.check_file(raw, self.nframes, self.frames)

    def test_overflowed_write(self):
        # One frame fewer is announced than is written; reading back must
        # still report ``self.nframes`` frames.
        junk = b'ababagalamaga'
        with open(TESTFN, 'wb') as raw:
            raw.write(junk)
            writer = self.create_file(raw)
            writer.setnframes(self.nframes - 1)
            writer.writeframes(self.frames)
            writer.close()

        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(len(junk)), junk)
            self.check_file(raw, self.nframes, self.frames)

    def test_unseekable_read(self):
        # Write normally, then read through a stream that cannot seek.
        with self.create_file(TESTFN) as writer:
            writer.setnframes(self.nframes)
            writer.writeframes(self.frames)

        with UnseekableIO(TESTFN, 'rb') as raw:
            self.check_file(raw, self.nframes, self.frames)

    def test_unseekable_write(self):
        # Write through a stream that cannot seek, then read normally.
        with UnseekableIO(TESTFN, 'wb') as raw, \
             self.create_file(raw) as writer:
            writer.setnframes(self.nframes)
            writer.writeframes(self.frames)

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        # One frame more is announced than is written, on a stream that
        # cannot seek.  An OSError from writing or closing is tolerated; the
        # file must then report the announced ``self.nframes + 1`` frames.
        junk = b'ababagalamaga'
        with UnseekableIO(TESTFN, 'wb') as raw:
            raw.write(junk)
            writer = self.create_file(raw)
            writer.setnframes(self.nframes + 1)
            try:
                writer.writeframes(self.frames)
            except OSError:
                pass
            try:
                writer.close()
            except OSError:
                pass

        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(len(junk)), junk)
            self.check_file(raw, self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        # One frame fewer is announced than is written, on a stream that
        # cannot seek.  An OSError from writing or closing is tolerated; the
        # file must then report ``self.nframes - 1`` frames and return the
        # data without its last frame.
        junk = b'ababagalamaga'
        with UnseekableIO(TESTFN, 'wb') as raw:
            raw.write(junk)
            writer = self.create_file(raw)
            writer.setnframes(self.nframes - 1)
            try:
                writer.writeframes(self.frames)
            except OSError:
                pass
            try:
                writer.close()
            except OSError:
                pass

        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(len(junk)), junk)
            framesize = self.nchannels * self.sampwidth
            self.check_file(raw, self.nframes - 1, self.frames[:-framesize])
259
260
class AudioTestsWithSourceFile(AudioTests):
    """Tests that read a reference sound file located with ``findfile``.

    The methods read ``sndfilename``, ``sndfilenframes``, ``frames``,
    ``nframes`` and the audio parameters from ``self``/``cls``; none of them
    is defined in this class.
    """

    @classmethod
    def setUpClass(cls):
        """Look up ``cls.sndfilename`` in the ``audiodata`` subdirectory."""
        located = findfile(cls.sndfilename, subdir='audiodata')
        cls.sndfilepath = located

    def test_read_params(self):
        reader = self.f = self.module.open(self.sndfilepath)
        # Left disabled:
        #self.assertEqual(reader.getfp().name, self.sndfilepath)
        expected = (self.nchannels, self.sampwidth, self.framerate,
                    self.sndfilenframes, self.comptype, self.compname)
        self.check_params(reader, *expected)

    def test_close(self):
        # Closing a reader: the wrapped file's ``closed`` flag must end up
        # equal to ``close_fd``.
        with open(self.sndfilepath, 'rb') as raw:
            reader = self.f = self.module.open(raw)
            self.assertFalse(raw.closed)
            reader.close()
            self.assertEqual(raw.closed, self.close_fd)
        # Closing a writer that was never configured raises the module's
        # Error; a second close() is then expected to be a no-op.
        with open(TESTFN, 'wb') as raw:
            writer = self.fout = self.module.open(raw, 'wb')
            self.assertFalse(raw.closed)
            self.assertRaises(self.module.Error, writer.close)
            self.assertEqual(raw.closed, self.close_fd)
            writer.close() # do nothing

    def test_read(self):
        framesize = self.nchannels * self.sampwidth
        first_pair = self.frames[:2 * framesize]
        second_pair = self.frames[2 * framesize: 4 * framesize]
        reader = self.f = self.module.open(self.sndfilepath)

        # Reading zero frames returns nothing and does not move.
        self.assertEqual(reader.readframes(0), b'')
        self.assertEqual(reader.tell(), 0)

        # rewind() goes back to frame 0.
        self.assertEqual(reader.readframes(2), first_pair)
        reader.rewind()
        start = reader.tell()
        self.assertEqual(start, 0)

        # Sequential reads advance two frames at a time.
        self.assertEqual(reader.readframes(2), first_pair)
        after_two = reader.tell()
        self.assertEqual(after_two, 2)
        self.assertEqual(reader.readframes(2), second_pair)

        # setpos() accepts the positions reported by tell().
        reader.setpos(after_two)
        self.assertEqual(reader.readframes(2), second_pair)
        reader.setpos(start)
        self.assertEqual(reader.readframes(2), first_pair)

        # Positions before the start or past the end are rejected.
        with self.assertRaises(self.module.Error):
            reader.setpos(-1)
        with self.assertRaises(self.module.Error):
            reader.setpos(reader.getnframes() + 1)

    def test_copy(self):
        source = self.f = self.module.open(self.sndfilepath)
        target = self.fout = self.module.open(TESTFN, 'wb')
        target.setparams(source.getparams())

        # Copy in chunks of 1, 2, 3, ... frames until the frame count of the
        # source is used up.
        chunk = 0
        remaining = source.getnframes()
        while remaining > 0:
            chunk += 1
            target.writeframes(source.readframes(chunk))
            remaining -= chunk
        target.close()

        # Reopen the copy and compare it with the rewound source.
        copy = self.fout = self.module.open(TESTFN, 'rb')
        source.rewind()
        self.assertEqual(source.getparams(), copy.getparams())
        self.assertEqual(source.readframes(source.getnframes()),
                         copy.readframes(copy.getnframes()))

    def test_read_not_from_start(self):
        # Put some leading bytes in front of a copy of the sound file ...
        junk = b'ababagalamaga'
        with open(TESTFN, 'wb') as dest:
            dest.write(junk)
            with open(self.sndfilepath, 'rb') as original:
                dest.write(original.read())

        # ... and open the audio data from a file object that is already
        # positioned just past those bytes.
        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(len(junk)), junk)
            with self.module.open(raw, 'rb') as reader:
                self.assertEqual(reader.getnchannels(), self.nchannels)
                self.assertEqual(reader.getsampwidth(), self.sampwidth)
                self.assertEqual(reader.getframerate(), self.framerate)
                self.assertEqual(reader.getnframes(), self.sndfilenframes)
                self.assertEqual(reader.readframes(self.nframes), self.frames)
342