1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 2# 3# Use of this source code is governed by a BSD-style license 4# that can be found in the LICENSE file in the root of the source 5# tree. An additional intellectual property rights grant can be found 6# in the file PATENTS. All contributing project authors may 7# be found in the AUTHORS file in the root of the source tree. 8 9"""Extraction of annotations from audio files. 10""" 11 12from __future__ import division 13import logging 14import os 15import shutil 16import struct 17import subprocess 18import sys 19import tempfile 20 21try: 22 import numpy as np 23except ImportError: 24 logging.critical('Cannot import the third-party Python package numpy') 25 sys.exit(1) 26 27from . import external_vad 28from . import exceptions 29from . import signal_processing 30 31 32class AudioAnnotationsExtractor(object): 33 """Extracts annotations from audio files. 34 """ 35 36 class VadType(object): 37 ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard. 38 WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h 39 WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h 40 41 def __init__(self, value): 42 if (not isinstance(value, int)) or not 0 <= value <= 7: 43 raise exceptions.InitializationException( 44 'Invalid vad type: ' + value) 45 self._value = value 46 47 def Contains(self, vad_type): 48 return self._value | vad_type == self._value 49 50 def __str__(self): 51 vads = [] 52 if self.Contains(self.ENERGY_THRESHOLD): 53 vads.append("energy") 54 if self.Contains(self.WEBRTC_COMMON_AUDIO): 55 vads.append("common_audio") 56 if self.Contains(self.WEBRTC_APM): 57 vads.append("apm") 58 return "VadType({})".format(", ".join(vads)) 59 60 _OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz' 61 62 # Level estimation params. 63 _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0) 64 _LEVEL_FRAME_SIZE_MS = 1.0 65 # The time constants in ms indicate the time it takes for the level estimate 66 # to go down/up by 1 db if the signal is zero. 67 _LEVEL_ATTACK_MS = 5.0 68 _LEVEL_DECAY_MS = 20.0 69 70 # VAD params. 71 _VAD_THRESHOLD = 1 72 _VAD_WEBRTC_PATH = os.path.join(os.path.dirname( 73 os.path.abspath(__file__)), os.pardir, os.pardir) 74 _VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad') 75 76 _VAD_WEBRTC_APM_PATH = os.path.join( 77 _VAD_WEBRTC_PATH, 'apm_vad') 78 79 def __init__(self, vad_type, external_vads=None): 80 self._signal = None 81 self._level = None 82 self._level_frame_size = None 83 self._common_audio_vad = None 84 self._energy_vad = None 85 self._apm_vad_probs = None 86 self._apm_vad_rms = None 87 self._vad_frame_size = None 88 self._vad_frame_size_ms = None 89 self._c_attack = None 90 self._c_decay = None 91 92 self._vad_type = self.VadType(vad_type) 93 logging.info('VADs used for annotations: ' + str(self._vad_type)) 94 95 if external_vads is None: 96 external_vads = {} 97 self._external_vads = external_vads 98 99 assert len(self._external_vads) == len(external_vads), ( 100 'The external VAD names must be unique.') 101 for vad in external_vads.values(): 102 if not isinstance(vad, external_vad.ExternalVad): 103 raise exceptions.InitializationException( 104 'Invalid vad type: ' + str(type(vad))) 105 logging.info('External VAD used for annotation: ' + 106 str(vad.name)) 107 108 assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \ 109 self._VAD_WEBRTC_COMMON_AUDIO_PATH 110 assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \ 111 self._VAD_WEBRTC_APM_PATH 112 113 @classmethod 114 def GetOutputFileNameTemplate(cls): 115 return cls._OUTPUT_FILENAME_TEMPLATE 116 117 def GetLevel(self): 118 return self._level 119 120 def GetLevelFrameSize(self): 121 return self._level_frame_size 122 123 @classmethod 124 def GetLevelFrameSizeMs(cls): 125 return cls._LEVEL_FRAME_SIZE_MS 126 127 def GetVadOutput(self, vad_type): 128 if vad_type == self.VadType.ENERGY_THRESHOLD: 129 return self._energy_vad 130 elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO: 131 return self._common_audio_vad 132 elif vad_type == self.VadType.WEBRTC_APM: 133 return (self._apm_vad_probs, self._apm_vad_rms) 134 else: 135 raise exceptions.InitializationException( 136 'Invalid vad type: ' + vad_type) 137 138 def GetVadFrameSize(self): 139 return self._vad_frame_size 140 141 def GetVadFrameSizeMs(self): 142 return self._vad_frame_size_ms 143 144 def Extract(self, filepath): 145 # Load signal. 146 self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath) 147 if self._signal.channels != 1: 148 raise NotImplementedError('Multiple-channel annotations not implemented') 149 150 # Level estimation params. 151 self._level_frame_size = int(self._signal.frame_rate / 1000 * ( 152 self._LEVEL_FRAME_SIZE_MS)) 153 self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else ( 154 self._ONE_DB_REDUCTION ** ( 155 self._LEVEL_FRAME_SIZE_MS / self._LEVEL_ATTACK_MS)) 156 self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else ( 157 self._ONE_DB_REDUCTION ** ( 158 self._LEVEL_FRAME_SIZE_MS / self._LEVEL_DECAY_MS)) 159 160 # Compute level. 161 self._LevelEstimation() 162 163 # Ideal VAD output, it requires clean speech with high SNR as input. 164 if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD): 165 # Naive VAD based on level thresholding. 166 vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD) 167 self._energy_vad = np.uint8(self._level > vad_threshold) 168 self._vad_frame_size = self._level_frame_size 169 self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS 170 if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO): 171 # WebRTC common_audio/ VAD. 172 self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate) 173 if self._vad_type.Contains(self.VadType.WEBRTC_APM): 174 # WebRTC modules/audio_processing/ VAD. 175 self._RunWebRtcApmVad(filepath) 176 for extvad_name in self._external_vads: 177 self._external_vads[extvad_name].Run(filepath) 178 179 def Save(self, output_path, annotation_name=""): 180 ext_kwargs = {'extvad_conf-' + ext_vad: 181 self._external_vads[ext_vad].GetVadOutput() 182 for ext_vad in self._external_vads} 183 np.savez_compressed( 184 file=os.path.join( 185 output_path, 186 self.GetOutputFileNameTemplate().format(annotation_name)), 187 level=self._level, 188 level_frame_size=self._level_frame_size, 189 level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS, 190 vad_output=self._common_audio_vad, 191 vad_energy_output=self._energy_vad, 192 vad_frame_size=self._vad_frame_size, 193 vad_frame_size_ms=self._vad_frame_size_ms, 194 vad_probs=self._apm_vad_probs, 195 vad_rms=self._apm_vad_rms, 196 **ext_kwargs 197 ) 198 199 def _LevelEstimation(self): 200 # Read samples. 201 samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData( 202 self._signal).astype(np.float32) / 32768.0 203 num_frames = len(samples) // self._level_frame_size 204 num_samples = num_frames * self._level_frame_size 205 206 # Envelope. 207 self._level = np.max(np.reshape(np.abs(samples[:num_samples]), ( 208 num_frames, self._level_frame_size)), axis=1) 209 assert len(self._level) == num_frames 210 211 # Envelope smoothing. 212 smooth = lambda curr, prev, k: (1 - k) * curr + k * prev 213 self._level[0] = smooth(self._level[0], 0.0, self._c_attack) 214 for i in range(1, num_frames): 215 self._level[i] = smooth( 216 self._level[i], self._level[i - 1], self._c_attack if ( 217 self._level[i] > self._level[i - 1]) else self._c_decay) 218 219 def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate): 220 self._common_audio_vad = None 221 self._vad_frame_size = None 222 223 # Create temporary output path. 224 tmp_path = tempfile.mkdtemp() 225 output_file_path = os.path.join( 226 tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp') 227 228 # Call WebRTC VAD. 229 try: 230 subprocess.call([ 231 self._VAD_WEBRTC_COMMON_AUDIO_PATH, 232 '-i', wav_file_path, 233 '-o', output_file_path 234 ], cwd=self._VAD_WEBRTC_PATH) 235 236 # Read bytes. 237 with open(output_file_path, 'rb') as f: 238 raw_data = f.read() 239 240 # Parse side information. 241 self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0] 242 self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000 243 assert self._vad_frame_size_ms in [10, 20, 30] 244 extra_bits = struct.unpack('B', raw_data[-1])[0] 245 assert 0 <= extra_bits <= 8 246 247 # Init VAD vector. 248 num_bytes = len(raw_data) 249 num_frames = 8 * (num_bytes - 2) - extra_bits # 8 frames for each byte. 250 self._common_audio_vad = np.zeros(num_frames, np.uint8) 251 252 # Read VAD decisions. 253 for i, byte in enumerate(raw_data[1:-1]): 254 byte = struct.unpack('B', byte)[0] 255 for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)): 256 self._common_audio_vad[i * 8 + j] = int(byte & 1) 257 byte = byte >> 1 258 except Exception as e: 259 logging.error('Error while running the WebRTC VAD (' + e.message + ')') 260 finally: 261 if os.path.exists(tmp_path): 262 shutil.rmtree(tmp_path) 263 264 def _RunWebRtcApmVad(self, wav_file_path): 265 # Create temporary output path. 266 tmp_path = tempfile.mkdtemp() 267 output_file_path_probs = os.path.join( 268 tmp_path, os.path.split(wav_file_path)[1] + '_vad_probs.tmp') 269 output_file_path_rms = os.path.join( 270 tmp_path, os.path.split(wav_file_path)[1] + '_vad_rms.tmp') 271 272 # Call WebRTC VAD. 273 try: 274 subprocess.call([ 275 self._VAD_WEBRTC_APM_PATH, 276 '-i', wav_file_path, 277 '-o_probs', output_file_path_probs, 278 '-o_rms', output_file_path_rms 279 ], cwd=self._VAD_WEBRTC_PATH) 280 281 # Parse annotations. 282 self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double) 283 self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double) 284 assert len(self._apm_vad_rms) == len(self._apm_vad_probs) 285 286 except Exception as e: 287 logging.error('Error while running the WebRTC APM VAD (' + 288 e.message + ')') 289 finally: 290 if os.path.exists(tmp_path): 291 shutil.rmtree(tmp_path) 292