• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from test.support import findfile
2from test.support.os_helper import TESTFN, unlink
3import array
4import io
5import pickle
6
7
class UnseekableIO(io.FileIO):
    """A real file object that behaves like an unseekable stream (e.g. a pipe).

    Reads and writes go straight to the underlying file, but ``tell()`` and
    ``seek()`` always raise :exc:`io.UnsupportedOperation`.  ``seekable()``
    is overridden to report ``False`` so the object answers the io
    capability query consistently with those failures (the inherited
    ``FileIO.seekable()`` would report ``True`` for the real file).
    """

    def seekable(self):
        return False

    def tell(self):
        raise io.UnsupportedOperation('stream is not seekable')

    def seek(self, *args, **kwargs):
        raise io.UnsupportedOperation('stream is not seekable')
14
15
class AudioTests:
    """Common fixture and assertion helpers for the audio test mixins.

    ``self.f`` (a reader) and ``self.fout`` (a writer) hold audio objects a
    test opens, so that ``tearDown`` can close them even when the test
    fails.  ``close_fd`` is the value tests compare ``testfile.closed``
    against after closing an audio object built on a caller-supplied file
    object.  The ``assert*`` methods are expected to come from the concrete
    subclass (e.g. :class:`unittest.TestCase`).
    """

    close_fd = False

    def setUp(self):
        self.f = self.fout = None

    def tearDown(self):
        # Every cleanup step runs even if an earlier close() raises (e.g. a
        # writer whose header was never configured), so one failing close
        # cannot leak the other object or leave TESTFN behind.
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports the given stream parameters.

        The parameters are checked through the individual getters, through
        ``f.getparams()`` both as a tuple and by field name, and after a
        pickle round trip of the params object under every protocol.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)
51
52
class AudioWriteTests(AudioTests):
    """Writing tests shared by the audio format test suites.

    Concrete subclasses provide ``module`` (the audio module under test),
    the expected stream parameters ``nchannels``, ``sampwidth``,
    ``framerate``, ``nframes``, ``comptype`` and ``compname``, and the raw
    sample data ``frames``.
    """

    # Arbitrary bytes written ahead of the audio data by several tests, to
    # check that the module handles a file object that does not start at
    # offset 0.
    _PREFIX = b'ababagalamaga'

    def _set_header(self, f):
        """Apply this test's channel, width, rate and compression settings."""
        f.setnchannels(self.nchannels)
        f.setsampwidth(self.sampwidth)
        f.setframerate(self.framerate)
        f.setcomptype(self.comptype, self.compname)

    def _check_header(self, f):
        """Assert that reader *f* is open and reports this test's settings."""
        self.assertFalse(f.getfp().closed)
        params = f.getparams()
        self.assertEqual(params.nchannels, self.nchannels)
        self.assertEqual(params.sampwidth, self.sampwidth)
        self.assertEqual(params.framerate, self.framerate)

    def _frame_size(self):
        """Return the number of bytes in one frame."""
        return self.nchannels * self.sampwidth

    def create_file(self, testfile):
        """Open *testfile* (a path or file object) for writing and configure it.

        The writer is also stored in ``self.fout`` so ``tearDown`` closes it
        if the test fails before doing so itself.
        """
        f = self.fout = self.module.open(testfile, 'wb')
        self._set_header(f)
        return f

    def check_file(self, testfile, nframes, frames):
        """Reopen *testfile* for reading and check its header and data."""
        with self.module.open(testfile, 'rb') as f:
            self.assertEqual(f.getnchannels(), self.nchannels)
            self.assertEqual(f.getsampwidth(), self.sampwidth)
            self.assertEqual(f.getframerate(), self.framerate)
            self.assertEqual(f.getnframes(), nframes)
            self.assertEqual(f.readframes(nframes), frames)

    def _write_and_check(self, data):
        """Write *data* as a complete stream to TESTFN, then verify it."""
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(data)
        f.close()
        self.check_file(TESTFN, self.nframes, self.frames)

    def _write_prefixed(self, nframes):
        """Write _PREFIX, then self.frames declared as *nframes*, to TESTFN."""
        with open(TESTFN, 'wb') as testfile:
            testfile.write(self._PREFIX)
            f = self.create_file(testfile)
            f.setnframes(nframes)
            f.writeframes(self.frames)
            f.close()

    def _write_prefixed_unseekable(self, nframes):
        """Like _write_prefixed, but through an unseekable file object."""
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(self._PREFIX)
            f = self.create_file(testfile)
            f.setnframes(nframes)
            # The declared and actual frame counts differ and the stream
            # cannot seek, so the module may raise OSError here
            # (io.UnsupportedOperation is an OSError); the test only checks
            # what reached the file.
            try:
                f.writeframes(self.frames)
            except OSError:
                pass
            try:
                f.close()
            except OSError:
                pass

    def _check_prefixed_file(self, nframes, frames):
        """Check that TESTFN holds _PREFIX followed by the expected stream."""
        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(len(self._PREFIX)), self._PREFIX)
            self.check_file(testfile, nframes, frames)

    def test_write_params(self):
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.nframes, self.comptype, self.compname)
        f.close()

    def test_write_context_manager_calls_close(self):
        # Close checks for a minimum header and will raise an error
        # if it is not set, so this proves that close is called.
        with self.assertRaises(self.module.Error):
            with self.module.open(TESTFN, 'wb'):
                pass
        with self.assertRaises(self.module.Error):
            with open(TESTFN, 'wb') as testfile:
                with self.module.open(testfile):
                    pass

    def test_context_manager_with_open_file(self):
        with open(TESTFN, 'wb') as testfile:
            with self.module.open(testfile) as f:
                self._set_header(f)
            # Leaving the audio object's context closes the caller's file
            # object only when close_fd says it should.
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'rb') as testfile:
            with self.module.open(testfile) as f:
                self._check_header(f)
            if not self.close_fd:
                self.assertIsNone(f.getfp())
            self.assertEqual(testfile.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        # If the file doesn't get closed, this test won't fail, but it will
        # produce a resource leak warning.
        with self.module.open(TESTFN, 'wb') as f:
            self._set_header(f)
        with self.module.open(TESTFN) as f:
            self._check_header(f)
        if not self.close_fd:
            self.assertIsNone(f.getfp())

    def test_write(self):
        self._write_and_check(self.frames)

    def test_write_bytearray(self):
        self._write_and_check(bytearray(self.frames))

    def test_write_array(self):
        self._write_and_check(array.array('h', self.frames))

    def test_write_memoryview(self):
        self._write_and_check(memoryview(self.frames))

    def test_incompleted_write(self):
        # Declare one frame more than is written; the header read back must
        # report only the frames actually written.
        self._write_prefixed(self.nframes + 1)
        self._check_prefixed_file(self.nframes, self.frames)

    def test_multiple_writes(self):
        framesize = self._frame_size()
        with open(TESTFN, 'wb') as testfile:
            testfile.write(self._PREFIX)
            f = self.create_file(testfile)
            f.setnframes(self.nframes)
            f.writeframes(self.frames[:-framesize])
            f.writeframes(self.frames[-framesize:])
            f.close()
        self._check_prefixed_file(self.nframes, self.frames)

    def test_overflowed_write(self):
        # Declare one frame fewer than is written; every written frame must
        # still be reported on read-back.
        self._write_prefixed(self.nframes - 1)
        self._check_prefixed_file(self.nframes, self.frames)

    def test_unseekable_read(self):
        with self.create_file(TESTFN) as f:
            f.setnframes(self.nframes)
            f.writeframes(self.frames)

        with UnseekableIO(TESTFN, 'rb') as testfile:
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_write(self):
        with UnseekableIO(TESTFN, 'wb') as testfile:
            with self.create_file(testfile) as f:
                f.setnframes(self.nframes)
                f.writeframes(self.frames)

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        # On an unseekable stream the header keeps the declared count (one
        # more than written); only the written frames can be read.
        self._write_prefixed_unseekable(self.nframes + 1)
        self._check_prefixed_file(self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        # The header keeps the declared count (one fewer than written), so
        # the last frame is not read back.
        self._write_prefixed_unseekable(self.nframes - 1)
        self._check_prefixed_file(self.nframes - 1,
                                  self.frames[:-self._frame_size()])
248
249
class AudioTestsWithSourceFile(AudioTests):
    """Reading tests driven by a sample file from the ``audiodata`` directory.

    Subclasses set ``sndfilename`` (the sample file's name) and its frame
    count ``sndfilenframes``, plus the attributes the checks below compare
    against: ``module``, ``nchannels``, ``sampwidth``, ``framerate``,
    ``nframes``, ``comptype``, ``compname`` and ``frames``.
    """

    @classmethod
    def setUpClass(cls):
        # Resolve the sample file's path once per class.
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        f = self.f = self.module.open(self.sndfilepath)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.sndfilenframes, self.comptype, self.compname)

    def test_close(self):
        with open(self.sndfilepath, 'rb') as testfile:
            f = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            f.close()
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'wb') as testfile:
            fout = self.fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            # No header parameters were set, so close() reports an error...
            with self.assertRaises(self.module.Error):
                fout.close()
            # ...but still handles the caller's file object per close_fd.
            self.assertEqual(testfile.closed, self.close_fd)
            fout.close()  # a second close() does nothing

    def test_read(self):
        """Check readframes/tell/rewind/setpos against the first four frames.

        Positions outside ``[0, getnframes()]`` must raise ``module.Error``.
        """
        framesize = self.nchannels * self.sampwidth
        chunk1 = self.frames[:2 * framesize]
        chunk2 = self.frames[2 * framesize: 4 * framesize]
        f = self.f = self.module.open(self.sndfilepath)
        self.assertEqual(f.readframes(0), b'')
        self.assertEqual(f.tell(), 0)
        self.assertEqual(f.readframes(2), chunk1)
        f.rewind()
        pos0 = f.tell()
        self.assertEqual(pos0, 0)
        self.assertEqual(f.readframes(2), chunk1)
        pos2 = f.tell()
        self.assertEqual(pos2, 2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos0)
        self.assertEqual(f.readframes(2), chunk1)
        with self.assertRaises(self.module.Error):
            f.setpos(-1)
        with self.assertRaises(self.module.Error):
            f.setpos(f.getnframes() + 1)

    def test_copy(self):
        f = self.f = self.module.open(self.sndfilepath)
        fout = self.fout = self.module.open(TESTFN, 'wb')
        fout.setparams(f.getparams())
        # Copy in growing chunks of 1, 2, 3, ... frames so reads and writes
        # of varying length are exercised; the last request may ask for
        # more frames than remain.
        i = 0
        n = f.getnframes()
        while n > 0:
            i += 1
            fout.writeframes(f.readframes(i))
            n -= i
        fout.close()
        fout = self.fout = self.module.open(TESTFN, 'rb')
        f.rewind()
        self.assertEqual(f.getparams(), fout.getparams())
        self.assertEqual(f.readframes(f.getnframes()),
                         fout.readframes(fout.getnframes()))

    def test_read_not_from_start(self):
        # Arbitrary bytes placed ahead of the audio data, so the module must
        # read from a file object that does not start at offset 0.
        prefix = b'ababagalamaga'
        with open(TESTFN, 'wb') as testfile:
            testfile.write(prefix)
            with open(self.sndfilepath, 'rb') as f:
                testfile.write(f.read())

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(len(prefix)), prefix)
            with self.module.open(testfile, 'rb') as f:
                self.assertEqual(f.getnchannels(), self.nchannels)
                self.assertEqual(f.getsampwidth(), self.sampwidth)
                self.assertEqual(f.getframerate(), self.framerate)
                self.assertEqual(f.getnframes(), self.sndfilenframes)
                self.assertEqual(f.readframes(self.nframes), self.frames)
331