# Source-browser navigation residue (not part of the module):
# Home | Line# | Scopes# | Navigate# | Raw | Download
from test.support import findfile, TESTFN, unlink
import array
import contextlib
import io
import pickle
5
6
class UnseekableIO(io.FileIO):
    """Binary file object that refuses every positioning request.

    Reading and writing are inherited unchanged from io.FileIO; only
    tell() and seek() are overridden, and both fail unconditionally.
    """

    def tell(self):
        """Always raise io.UnsupportedOperation."""
        raise io.UnsupportedOperation

    def seek(self, *args, **kwargs):
        """Always raise io.UnsupportedOperation, whatever the arguments."""
        raise io.UnsupportedOperation
13
14
class AudioTests:
    """Shared fixture handling and parameter checks for audio module tests.

    The methods call ``self.assertEqual`` and similar helpers that are not
    defined here, so this class is meant to be combined with a class that
    provides them (presumably ``unittest.TestCase`` -- confirm against the
    concrete test classes).
    """

    # Value that the subclasses' tests compare ``testfile.closed`` against
    # after closing an audio object that wraps a caller-supplied file object.
    close_fd = False

    def setUp(self):
        """Start each test with no reader (``f``) and no writer (``fout``)."""
        self.f = self.fout = None

    def tearDown(self):
        """Close whatever the test left open and remove the scratch file.

        Each step runs even if an earlier one raises, so a failing close()
        on the reader cannot leave the writer open or TESTFN behind.  The
        exception still propagates once the cleanup has finished.
        """
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports exactly the given parameters.

        The values are checked three ways: through the individual getters,
        through the object returned by getparams() (compared as a tuple and
        attribute by attribute), and after a pickle round trip of that
        object with every available protocol.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        # The params object must survive pickling under every protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)
50
51
class AudioWriteTests(AudioTests):
    """Tests that write audio data through ``self.module`` and read it back.

    The concrete test class is expected to supply ``module``, ``nchannels``,
    ``sampwidth``, ``framerate``, ``nframes``, ``comptype``, ``compname``
    and ``frames``; each of them is read below but none is defined here.
    """

    def create_file(self, testfile):
        """Open *testfile* for writing and apply the class's audio parameters.

        *testfile* is passed straight to ``self.module.open``; the tests
        below hand in both a path (TESTFN) and already-open binary file
        objects.  The writer is also stored in ``self.fout`` so that
        tearDown() closes it if the test does not get that far.
        """
        f = self.fout = self.module.open(testfile, 'wb')
        f.setnchannels(self.nchannels)
        f.setsampwidth(self.sampwidth)
        f.setframerate(self.framerate)
        f.setcomptype(self.comptype, self.compname)
        return f

    def check_file(self, testfile, nframes, frames):
        """Read *testfile* back and compare it with the expected content.

        Asserts the class's channel count, sample width and frame rate,
        that the reader reports *nframes* frames, and that reading
        *nframes* frames yields exactly *frames*.
        """
        with self.module.open(testfile, 'rb') as f:
            self.assertEqual(f.getnchannels(), self.nchannels)
            self.assertEqual(f.getsampwidth(), self.sampwidth)
            self.assertEqual(f.getframerate(), self.framerate)
            self.assertEqual(f.getnframes(), nframes)
            self.assertEqual(f.readframes(nframes), frames)

    def test_write_params(self):
        """A writer reports the parameters it was given while still open."""
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.nframes, self.comptype, self.compname)
        f.close()

    def test_write_context_manager_calls_close(self):
        """Leaving a ``with`` block closes the writer."""
        # Close checks for a minimum header and will raise an error
        # if it is not set, so this proves that close is called.
        with self.assertRaises(self.module.Error):
            with self.module.open(TESTFN, 'wb'):
                pass
        with self.assertRaises(self.module.Error):
            with open(TESTFN, 'wb') as testfile:
                with self.module.open(testfile):
                    pass

    def test_context_manager_with_open_file(self):
        """Context-manager use on top of a caller-supplied file object."""
        with open(TESTFN, 'wb') as testfile:
            # No mode is passed to module.open() for the file object.
            with self.module.open(testfile) as f:
                f.setnchannels(self.nchannels)
                f.setsampwidth(self.sampwidth)
                f.setframerate(self.framerate)
                f.setcomptype(self.comptype, self.compname)
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'rb') as testfile:
            with self.module.open(testfile) as f:
                self.assertFalse(f.getfp().closed)
                params = f.getparams()
                self.assertEqual(params.nchannels, self.nchannels)
                self.assertEqual(params.sampwidth, self.sampwidth)
                self.assertEqual(params.framerate, self.framerate)
            # After the block the reader must have dropped its file object
            # (only checked when close_fd is false).
            if not self.close_fd:
                self.assertIsNone(f.getfp())
            self.assertEqual(testfile.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        """Context-manager use when the module opens the file by name."""
        # If the file doesn't get closed, this test won't fail, but it will
        # produce a resource leak warning.
        with self.module.open(TESTFN, 'wb') as f:
            f.setnchannels(self.nchannels)
            f.setsampwidth(self.sampwidth)
            f.setframerate(self.framerate)
            f.setcomptype(self.comptype, self.compname)
        with self.module.open(TESTFN) as f:
            self.assertFalse(f.getfp().closed)
            params = f.getparams()
            self.assertEqual(params.nchannels, self.nchannels)
            self.assertEqual(params.sampwidth, self.sampwidth)
            self.assertEqual(params.framerate, self.framerate)
        if not self.close_fd:
            self.assertIsNone(f.getfp())

    def test_write(self):
        """Frames written as ``bytes`` are read back unchanged."""
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_bytearray(self):
        """writeframes() accepts a ``bytearray``."""
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(bytearray(self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_array(self):
        """writeframes() accepts an ``array.array`` built from the frames."""
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(array.array('h', self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_memoryview(self):
        """writeframes() accepts a ``memoryview``."""
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(memoryview(self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        """Declaring one frame more than is written still reads back right.

        The audio data is written after 13 bytes of unrelated content, so
        the writer does not start at offset 0 of the file.
        """
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_multiple_writes(self):
        """Frames delivered in two writeframes() calls form one stream."""
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes)
            framesize = self.nchannels * self.sampwidth
            # Everything but the last frame first, then the last frame.
            f.writeframes(self.frames[:-framesize])
            f.writeframes(self.frames[-framesize:])
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_overflowed_write(self):
        """Declaring one frame fewer than is written still reads back right."""
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_read(self):
        """A file can be read through an object that cannot seek or tell."""
        with self.create_file(TESTFN) as f:
            f.setnframes(self.nframes)
            f.writeframes(self.frames)

        with UnseekableIO(TESTFN, 'rb') as testfile:
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_write(self):
        """A file can be written through an object that cannot seek or tell."""
        with UnseekableIO(TESTFN, 'wb') as testfile:
            with self.create_file(testfile) as f:
                f.setnframes(self.nframes)
                f.writeframes(self.frames)

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        """Too few frames on an unseekable stream.

        The resulting file is expected to report the declared count
        (``nframes + 1``) while containing only the frames actually written.
        """
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            # OSError is deliberately tolerated from both calls (presumably
            # raised when the writer tries to reposition the stream --
            # confirm against the module under test).
            with contextlib.suppress(OSError):
                f.writeframes(self.frames)
            with contextlib.suppress(OSError):
                f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        """Too many frames on an unseekable stream.

        The resulting file is expected to report the declared count
        (``nframes - 1``); reading that many frames yields the written data
        minus its last frame.
        """
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            # OSError is deliberately tolerated from both calls (presumably
            # raised when the writer tries to reposition the stream --
            # confirm against the module under test).
            with contextlib.suppress(OSError):
                f.writeframes(self.frames)
            with contextlib.suppress(OSError):
                f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            framesize = self.nchannels * self.sampwidth
            self.check_file(testfile, self.nframes - 1, self.frames[:-framesize])
247
248
class AudioTestsWithSourceFile(AudioTests):
    """Tests that read a reference sound file through ``self.module``.

    The concrete test class is expected to supply ``module``,
    ``sndfilename``, ``sndfilenframes``, ``nchannels``, ``sampwidth``,
    ``framerate``, ``nframes``, ``comptype``, ``compname`` and ``frames``;
    each of them is read below but none is defined here.
    """

    @classmethod
    def setUpClass(cls):
        """Locate ``cls.sndfilename`` via findfile() in subdir 'audiodata'."""
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        """The reference file reports the expected parameters."""
        f = self.f = self.module.open(self.sndfilepath)
        # NOTE(review): the check below is disabled in the original source;
        # the reason is not recorded here.
        #self.assertEqual(f.getfp().name, self.sndfilepath)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.sndfilenframes, self.comptype, self.compname)

    def test_close(self):
        """close() on readers and writers wrapping caller-supplied files.

        After close() the underlying file's ``closed`` flag must equal
        ``self.close_fd``.  Closing a writer that never received a header
        raises ``module.Error`` the first time; a second close() is expected
        to be a no-op.
        """
        with open(self.sndfilepath, 'rb') as testfile:
            f = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            f.close()
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'wb') as testfile:
            fout = self.fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            with self.assertRaises(self.module.Error):
                fout.close()
            self.assertEqual(testfile.closed, self.close_fd)
            fout.close() # do nothing

    def test_read(self):
        """readframes(), tell(), rewind() and setpos() agree with each other."""
        framesize = self.nchannels * self.sampwidth
        chunk1 = self.frames[:2 * framesize]                 # frames 0-1
        chunk2 = self.frames[2 * framesize: 4 * framesize]   # frames 2-3
        f = self.f = self.module.open(self.sndfilepath)
        # Reading zero frames returns nothing and does not move the position.
        self.assertEqual(f.readframes(0), b'')
        self.assertEqual(f.tell(), 0)
        self.assertEqual(f.readframes(2), chunk1)
        f.rewind()
        pos0 = f.tell()
        self.assertEqual(pos0, 0)
        self.assertEqual(f.readframes(2), chunk1)
        pos2 = f.tell()
        self.assertEqual(pos2, 2)
        self.assertEqual(f.readframes(2), chunk2)
        # Positions obtained from tell() can be fed back to setpos().
        f.setpos(pos2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos0)
        self.assertEqual(f.readframes(2), chunk1)
        # Positions outside the file are rejected.
        with self.assertRaises(self.module.Error):
            f.setpos(-1)
        with self.assertRaises(self.module.Error):
            f.setpos(f.getnframes() + 1)

    def test_copy(self):
        """Copying the reference file frame by frame reproduces it."""
        f = self.f = self.module.open(self.sndfilepath)
        fout = self.fout = self.module.open(TESTFN, 'wb')
        fout.setparams(f.getparams())
        # Copy in chunks of 1, 2, 3, ... frames until the source's frame
        # count has been covered.
        chunk = 0
        remaining = f.getnframes()
        while remaining > 0:
            chunk += 1
            fout.writeframes(f.readframes(chunk))
            remaining -= chunk
        fout.close()
        # Reopen the copy for reading; keeping it in self.fout lets
        # tearDown() close it.
        fout = self.fout = self.module.open(TESTFN, 'rb')
        f.rewind()
        self.assertEqual(f.getparams(), fout.getparams())
        self.assertEqual(f.readframes(f.getnframes()),
                         fout.readframes(fout.getnframes()))

    def test_read_not_from_start(self):
        """Audio data can be read when it does not start at offset 0."""
        # Build a file made of 13 unrelated bytes followed by the
        # reference sound file.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as f:
                testfile.write(f.read())

        with open(TESTFN, 'rb') as testfile:
            # Consume the prefix so the reader starts at the audio data.
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            with self.module.open(testfile, 'rb') as f:
                self.assertEqual(f.getnchannels(), self.nchannels)
                self.assertEqual(f.getsampwidth(), self.sampwidth)
                self.assertEqual(f.getframerate(), self.framerate)
                self.assertEqual(f.getnframes(), self.sndfilenframes)
                self.assertEqual(f.readframes(self.nframes), self.frames)
330