• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
"""Stuff to parse WAVE files.

Usage.

Reading WAVE files:
      f = wave.open(file, 'r')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods read(), seek(), and close().
When the setpos() and rewind() methods are not used, the seek()
method is not necessary.

This returns an instance of a class with the following public methods:
      getnchannels()  -- returns number of audio channels (1 for
                         mono, 2 for stereo)
      getsampwidth()  -- returns sample width in bytes
      getframerate()  -- returns sampling frequency
      getnframes()    -- returns number of audio frames
      getcomptype()   -- returns compression type ('NONE' for linear samples)
      getcompname()   -- returns human-readable version of
                         compression type ('not compressed' for linear
                         samples)
      getparams()     -- returns a namedtuple consisting of all of the
                         above in the above order
      getmarkers()    -- returns None (for compatibility with the
                         old aifc module; deprecated, scheduled for
                         removal in Python 3.15)
      getmark(id)     -- raises an error since the mark does not
                         exist (for compatibility with the old aifc
                         module; deprecated, scheduled for removal in
                         Python 3.15)
      readframes(n)   -- returns at most n frames of audio
      rewind()        -- rewind to the beginning of the audio stream
      setpos(pos)     -- seek to the specified position
      tell()          -- return the current position
      close()         -- close the instance; the underlying file is
                         closed only if it was opened by name
The position returned by tell() and the position given to setpos()
are compatible and have nothing to do with the actual position in the
file.
The close() method is called automatically when the class instance
is destroyed, and when leaving a ``with`` block that uses the instance
as a context manager.

Writing WAVE files:
      f = wave.open(file, 'w')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods write(), tell(), seek(),
flush(), and close().

This returns an instance of a class with the following public methods:
      setnchannels(n) -- set the number of channels
      setsampwidth(n) -- set the sample width in bytes (1 to 4)
      setframerate(n) -- set the frame rate
      setnframes(n)   -- set the number of frames
      setcomptype(type, name)
                      -- set the compression type and the
                         human-readable compression type (only 'NONE'
                         is supported)
      setparams(tuple)
                      -- set all parameters at once
      tell()          -- return the number of frames written so far
      writeframesraw(data)
                      -- write audio frames without patching up the
                         file header
      writeframes(data)
                      -- write audio frames and patch up the file header
      close()         -- patch up the file header, flush, and close the
                         output file if it was opened by name
You should set the parameters before the first writeframesraw or
writeframes.  The total number of frames does not need to be set,
but when it is set to the correct value, the header does not have to
be patched up.
It is best to first set all parameters, optionally including the
compression type, and then write audio frames using writeframesraw.
When all frames have been written, either call writeframes(b'') or
close() to patch up the sizes in the header.
The close() method is called automatically when the class instance
is destroyed, and when leaving a ``with`` block that uses the instance
as a context manager.
"""
73
74from collections import namedtuple
75import builtins
76import struct
77import sys
78
79
# Public names exported by ``from wave import *``.
__all__ = ["open", "Error", "Wave_read", "Wave_write"]
81
class Error(Exception):
    """Raised for malformed WAVE data and for invalid use of a reader or writer."""
84
85WAVE_FORMAT_PCM = 0x0001
86WAVE_FORMAT_EXTENSIBLE = 0xFFFE
87# Derived from uuid.UUID("00000001-0000-0010-8000-00aa00389b71").bytes_le
88KSDATAFORMAT_SUBTYPE_PCM = b'\x01\x00\x00\x00\x00\x00\x10\x00\x80\x00\x00\xaa\x008\x9bq'
89
90_array_fmts = None, 'b', 'h', None, 'i'
91
92_wave_params = namedtuple('_wave_params',
93                     'nchannels sampwidth framerate nframes comptype compname')
94
95
96def _byteswap(data, width):
97    swapped_data = bytearray(len(data))
98
99    for i in range(0, len(data), width):
100        for j in range(width):
101            swapped_data[i + width - 1 - j] = data[i + j]
102
103    return bytes(swapped_data)
104
105
106class _Chunk:
107    def __init__(self, file, align=True, bigendian=True, inclheader=False):
108        self.closed = False
109        self.align = align      # whether to align to word (2-byte) boundaries
110        if bigendian:
111            strflag = '>'
112        else:
113            strflag = '<'
114        self.file = file
115        self.chunkname = file.read(4)
116        if len(self.chunkname) < 4:
117            raise EOFError
118        try:
119            self.chunksize = struct.unpack_from(strflag+'L', file.read(4))[0]
120        except struct.error:
121            raise EOFError from None
122        if inclheader:
123            self.chunksize = self.chunksize - 8 # subtract header
124        self.size_read = 0
125        try:
126            self.offset = self.file.tell()
127        except (AttributeError, OSError):
128            self.seekable = False
129        else:
130            self.seekable = True
131
132    def getname(self):
133        """Return the name (ID) of the current chunk."""
134        return self.chunkname
135
136    def close(self):
137        if not self.closed:
138            try:
139                self.skip()
140            finally:
141                self.closed = True
142
143    def seek(self, pos, whence=0):
144        """Seek to specified position into the chunk.
145        Default position is 0 (start of chunk).
146        If the file is not seekable, this will result in an error.
147        """
148
149        if self.closed:
150            raise ValueError("I/O operation on closed file")
151        if not self.seekable:
152            raise OSError("cannot seek")
153        if whence == 1:
154            pos = pos + self.size_read
155        elif whence == 2:
156            pos = pos + self.chunksize
157        if pos < 0 or pos > self.chunksize:
158            raise RuntimeError
159        self.file.seek(self.offset + pos, 0)
160        self.size_read = pos
161
162    def tell(self):
163        if self.closed:
164            raise ValueError("I/O operation on closed file")
165        return self.size_read
166
167    def read(self, size=-1):
168        """Read at most size bytes from the chunk.
169        If size is omitted or negative, read until the end
170        of the chunk.
171        """
172
173        if self.closed:
174            raise ValueError("I/O operation on closed file")
175        if self.size_read >= self.chunksize:
176            return b''
177        if size < 0:
178            size = self.chunksize - self.size_read
179        if size > self.chunksize - self.size_read:
180            size = self.chunksize - self.size_read
181        data = self.file.read(size)
182        self.size_read = self.size_read + len(data)
183        if self.size_read == self.chunksize and \
184           self.align and \
185           (self.chunksize & 1):
186            dummy = self.file.read(1)
187            self.size_read = self.size_read + len(dummy)
188        return data
189
190    def skip(self):
191        """Skip the rest of the chunk.
192        If you are not interested in the contents of the chunk,
193        this method should be called so that the file points to
194        the start of the next chunk.
195        """
196
197        if self.closed:
198            raise ValueError("I/O operation on closed file")
199        if self.seekable:
200            try:
201                n = self.chunksize - self.size_read
202                # maybe fix alignment
203                if self.align and (self.chunksize & 1):
204                    n = n + 1
205                self.file.seek(n, 1)
206                self.size_read = self.size_read + n
207                return
208            except OSError:
209                pass
210        while self.size_read < self.chunksize:
211            n = min(8192, self.chunksize - self.size_read)
212            dummy = self.read(n)
213            if not dummy:
214                raise EOFError
215
216
class Wave_read:
    """Reader for WAVE files; normally created through open(f, 'rb').

    Variables available to the user through the methods of this class:
    _file -- a _Chunk wrapping the outer 'RIFF' chunk of the file passed
              to __init__(); returned by getfp()
    _nchannels -- the number of audio channels
              available through the getnchannels() method
    _nframes -- the number of audio frames
              available through the getnframes() method
    _sampwidth -- the number of bytes per audio sample
              available through the getsampwidth() method
    _framerate -- the sampling frequency
              available through the getframerate() method
    _comptype -- the compression type (always 'NONE'; only uncompressed
              PCM is accepted)
              available through the getcomptype() method
    _compname -- the human-readable compression type ('not compressed')
              available through the getcompname() method
    _soundpos -- the position in the audio stream, in frames
              available through the tell() method, set through the
              setpos() method

    These variables are used internally only:
    _fmt_chunk_read -- true once the 'fmt ' chunk has been read
    _data_seek_needed -- true iff the data chunk must be repositioned to
              _soundpos before the next readframes() can read
    _data_chunk -- the _Chunk for the 'data' chunk
    _framesize -- size of one frame in bytes (nchannels * sampwidth)
    _convert -- optional post-processing hook applied by readframes();
              always None in this module
    _i_opened_the_file -- the file object that __init__() opened when it
              was given a file name, else None; closed by close()
    """

    def initfp(self, file):
        """Parse the RIFF/WAVE header of *file* up to the start of the audio data.

        Raises Error if the file is not a RIFF/WAVE file, if the 'data'
        chunk precedes the 'fmt ' chunk, or if either chunk is missing.
        EOFError propagates if the RIFF header or the fmt chunk is truncated.
        """
        self._convert = None
        self._soundpos = 0
        self._file = _Chunk(file, bigendian=False)
        if self._file.getname() != b'RIFF':
            raise Error('file does not start with RIFF id')
        if self._file.read(4) != b'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = False
        self._data_chunk = None
        # Walk the sub-chunks of the RIFF chunk until the data chunk is
        # reached; chunks other than 'fmt ' and 'data' are skipped.
        while True:
            self._data_seek_needed = True
            try:
                chunk = _Chunk(self._file, bigendian=False)
            except EOFError:
                break
            chunkname = chunk.getname()
            if chunkname == b'fmt ':
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = True
            elif chunkname == b'data':
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = chunk
                self._nframes = chunk.chunksize // self._framesize
                # The stream is already positioned at the first audio byte.
                self._data_seek_needed = False
                break
            chunk.skip()
        if not self._fmt_chunk_read or not self._data_chunk:
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        """Open *f* (a file name or an open binary file object) for reading."""
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'rb')
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except BaseException:
            # Do not leak a file we opened ourselves; the caller never saw it.
            if self._i_opened_the_file:
                f.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Return the _Chunk wrapping the RIFF chunk (None after close())."""
        return self._file

    def rewind(self):
        """Reset the read position to the first frame."""
        self._data_seek_needed = True
        self._soundpos = 0

    def close(self):
        """Release the reader; the file is closed only if opened by name."""
        self._file = None
        file = self._i_opened_the_file
        if file:
            self._i_opened_the_file = None
            file.close()

    def tell(self):
        """Return the current position in frames."""
        return self._soundpos

    def getnchannels(self):
        return self._nchannels

    def getnframes(self):
        return self._nframes

    def getsampwidth(self):
        return self._sampwidth

    def getframerate(self):
        return self._framerate

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def getparams(self):
        """Return all parameters as a _wave_params namedtuple."""
        return _wave_params(self.getnchannels(), self.getsampwidth(),
                            self.getframerate(), self.getnframes(),
                            self.getcomptype(), self.getcompname())

    def getmarkers(self):
        import warnings
        warnings._deprecated("Wave_read.getmarkers", remove=(3, 15))
        return None

    def getmark(self, id):
        import warnings
        warnings._deprecated("Wave_read.getmark", remove=(3, 15))
        raise Error('no marks')

    def setpos(self, pos):
        """Move to frame *pos* (0 <= pos <= nframes); raise Error otherwise."""
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._data_seek_needed = True

    def readframes(self, nframes):
        """Read and return at most *nframes* frames of audio as bytes.

        Samples are stored little-endian in the file and are byte-swapped to
        native order on big-endian hosts (sample width > 1).
        """
        if self._data_seek_needed:
            # A single absolute seek covers both rewind() and setpos().
            self._data_chunk.seek(self._soundpos * self._framesize, 0)
            self._data_seek_needed = False
        if nframes == 0:
            return b''
        data = self._data_chunk.read(nframes * self._framesize)
        if self._sampwidth != 1 and sys.byteorder == 'big':
            data = _byteswap(data, self._sampwidth)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // self._framesize
        return data

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Parse the 'fmt ' chunk and set the format attributes.

        Accepts WAVE_FORMAT_PCM, and WAVE_FORMAT_EXTENSIBLE whose SubFormat
        is PCM.  Raises Error for any other format or for a zero sample
        width or channel count, and EOFError if the chunk is truncated.
        """
        try:
            (wFormatTag, self._nchannels, self._framerate,
             dwAvgBytesPerSec, wBlockAlign) = struct.unpack_from('<HHLLH', chunk.read(14))
        except struct.error:
            raise EOFError from None
        if wFormatTag != WAVE_FORMAT_PCM and wFormatTag != WAVE_FORMAT_EXTENSIBLE:
            raise Error('unknown format: %r' % (wFormatTag,))
        try:
            bits_per_sample = struct.unpack_from('<H', chunk.read(2))[0]
        except struct.error:
            raise EOFError from None
        if wFormatTag == WAVE_FORMAT_EXTENSIBLE:
            try:
                cbSize, wValidBitsPerSample, dwChannelMask = struct.unpack_from('<HHL', chunk.read(8))
                # Read the entire UUID from the chunk
                SubFormat = chunk.read(16)
                if len(SubFormat) < 16:
                    raise EOFError
            except struct.error:
                raise EOFError from None
            if SubFormat != KSDATAFORMAT_SUBTYPE_PCM:
                try:
                    import uuid
                    subformat_msg = f'unknown extended format: {uuid.UUID(bytes_le=SubFormat)}'
                except Exception:
                    subformat_msg = 'unknown extended format'
                raise Error(subformat_msg)
        # Round bits per sample up to whole bytes (e.g. 12 bits -> 2 bytes).
        self._sampwidth = (bits_per_sample + 7) // 8
        if not self._sampwidth:
            raise Error('bad sample width')
        if not self._nchannels:
            raise Error('bad # of channels')
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = 'NONE'
        self._compname = 'not compressed'
416
417
class Wave_write:
    """Writer for WAVE files; normally created through open(f, 'wb').

    Variables settable by the user through the methods of this class:
    _file -- the open file with methods write(), tell(), seek(), flush()
              and close(); set through the __init__() method
    _comptype -- the compression type (only 'NONE' is supported)
              set through the setcomptype() or setparams() method
    _compname -- the human-readable compression type
              set through the setcomptype() or setparams() method
    _nchannels -- the number of audio channels
              set through the setnchannels() or setparams() method
    _sampwidth -- the number of bytes per audio sample (1 to 4)
              set through the setsampwidth() or setparams() method
    _framerate -- the sampling frequency
              set through the setframerate() or setparams() method
    _nframes -- the number of audio frames written to the header
              set through the setnframes() or setparams() method

    These variables are used internally only:
    _datalength -- the size of the audio samples written to the header
    _nframeswritten -- the number of frames actually written
    _datawritten -- the size of the audio samples actually written
    _headerwritten -- true once the RIFF/fmt/data header has been written
    _form_length_pos, _data_length_pos -- file offsets of the two size
              fields rewritten by _patchheader(); _form_length_pos is
              None (and _data_length_pos unset) if tell() failed
    _i_opened_the_file -- the file object that __init__() opened when it
              was given a file name, else None; closed by close()
    """

    # Class-level defaults so that close() -- also run from __del__ -- works
    # even if __init__() failed before initfp() ran (e.g. open() raised).
    _file = None
    _i_opened_the_file = None

    def __init__(self, f):
        """Open *f* (a file name or an open binary file object) for writing."""
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'wb')
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except BaseException:
            # Do not leak a file we opened ourselves.
            if self._i_opened_the_file:
                f.close()
            raise

    def initfp(self, file):
        """Bind *file* and reset all parameters and counters."""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False
        # Only uncompressed PCM is written, so default to it; without these,
        # getcomptype()/getcompname()/getparams() raised AttributeError
        # unless setcomptype() had been called.
        self._comptype = 'NONE'
        self._compname = 'not compressed'

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if nchannels < 1:
            raise Error('bad # of channels')
        self._nchannels = nchannels

    def getnchannels(self):
        if not self._nchannels:
            raise Error('number of channels not set')
        return self._nchannels

    def setsampwidth(self, sampwidth):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if sampwidth < 1 or sampwidth > 4:
            raise Error('bad sample width')
        self._sampwidth = sampwidth

    def getsampwidth(self):
        if not self._sampwidth:
            raise Error('sample width not set')
        return self._sampwidth

    def setframerate(self, framerate):
        """Set the frame rate, rounded to the nearest integer."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        # Validate after rounding: a positive rate below 0.5 would otherwise
        # be stored as 0, which the rest of the class treats as "not set".
        framerate = int(round(framerate))
        if framerate <= 0:
            raise Error('bad frame rate')
        self._framerate = framerate

    def getframerate(self):
        if not self._framerate:
            raise Error('frame rate not set')
        return self._framerate

    def setnframes(self, nframes):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self._nframes = nframes

    def getnframes(self):
        """Return the number of frames written so far (not the header value)."""
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if comptype not in ('NONE',):
            raise Error('unsupported compression type')
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def setparams(self, params):
        """Set all parameters from a 6-tuple as returned by getparams()."""
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        """Return the parameters as a _wave_params namedtuple.

        nframes is the header value (_nframes), not the count written so far.
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error('not all parameters set')
        return _wave_params(self._nchannels, self._sampwidth, self._framerate,
                            self._nframes, self._comptype, self._compname)

    def setmark(self, id, pos, name):
        import warnings
        warnings._deprecated("Wave_write.setmark", remove=(3, 15))
        raise Error('setmark() not supported')

    def getmark(self, id):
        import warnings
        warnings._deprecated("Wave_write.getmark", remove=(3, 15))
        raise Error('no marks')

    def getmarkers(self):
        import warnings
        warnings._deprecated("Wave_write.getmarkers", remove=(3, 15))
        return None

    def tell(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write audio frames without patching the header sizes.

        *data* is any bytes-like object in native byte order; it is
        byte-swapped to little-endian on big-endian hosts.
        """
        if not isinstance(data, (bytes, bytearray)):
            data = memoryview(data).cast('B')
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        if self._sampwidth != 1 and sys.byteorder == 'big':
            data = _byteswap(data, self._sampwidth)
        self._file.write(data)
        self._datawritten += len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write audio frames, then patch the header if the sizes changed."""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Finish the header, flush, and close the file if opened by name.

        Safe to call more than once; the file is released even if writing
        the header fails.
        """
        try:
            if self._file:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            file = self._i_opened_the_file
            if file:
                self._i_opened_the_file = None
                file.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header on first use; raise Error if a parameter is missing."""
        if not self._headerwritten:
            if not self._nchannels:
                raise Error('# channels not specified')
            if not self._sampwidth:
                raise Error('sample width not specified')
            if not self._framerate:
                raise Error('sampling rate not specified')
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF/WAVE header, a 16-byte PCM fmt chunk and the data header.

        If setnframes() was not used, the frame count is derived from
        *initlength*, the byte size of the first write.
        """
        assert not self._headerwritten
        self._file.write(b'RIFF')
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        try:
            self._form_length_pos = self._file.tell()
        except (AttributeError, OSError):
            self._form_length_pos = None
        # RIFF size = 'WAVE' (4) + fmt chunk (8 + 16) + data header (8) + data.
        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
            36 + self._datalength, b'WAVE', b'fmt ', 16,
            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
            self._nchannels * self._framerate * self._sampwidth,
            self._nchannels * self._sampwidth,
            self._sampwidth * 8, b'data'))
        if self._form_length_pos is not None:
            self._data_length_pos = self._file.tell()
        self._file.write(struct.pack('<L', self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite the RIFF and data size fields to match the bytes written.

        NOTE(review): requires a seekable file; when tell() failed in
        _write_header the stored offsets are missing and this fails, so the
        frame count must be set correctly up front for such files.
        """
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack('<L', 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack('<L', self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten
650
651
def open(f, mode=None):
    """Open a WAVE file and return a reader or writer for it.

    *f* is a file name (str) or an already-open binary file object.
    *mode* is 'r'/'rb' for reading or 'w'/'wb' for writing; when omitted,
    ``f.mode`` is used if *f* has that attribute, otherwise 'rb'.

    Returns a Wave_read or Wave_write instance; raises Error for any other
    mode.
    """
    if mode is None:
        mode = getattr(f, 'mode', 'rb')
    if mode in ('r', 'rb'):
        return Wave_read(f)
    if mode in ('w', 'wb'):
        return Wave_write(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")
664