# Copyright (c) 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Server side audio utilities functions for Brillo."""

import contextlib
import logging
import numpy
import os
import shlex
import struct
import subprocess
import tempfile
import wave

from autotest_lib.client.common_lib import error


_BITS_PER_BYTE = 8

# Thresholds used when comparing files.
#
# The frequency threshold used when comparing files. The frequency of the
# recorded audio has to be within this fraction of the frequency of the
# original audio (0.01 means 1%).
_FREQUENCY_THRESHOLD = 0.01
# Noise threshold controls how much noise is allowed as a fraction of the
# magnitude of the peak frequency after taking an FFT. The magnitude of all
# the other frequencies in the signal should stay below this fraction of the
# magnitude of the main frequency (0.05 means 5%).
_FFT_NOISE_THRESHOLD = 0.05

# Command used to encode audio. If you want to test with something different,
# this should be changed.
_ENCODING_CMD = 'sox'

# struct format character for each supported PCM sample width (in bytes).
_SAMPLE_FORMATS = {1: 'B', 2: 'h', 4: 'i'}


def extract_wav_frames(wave_file):
    """Extract all frames from a WAV file.

    @param wave_file: A Wave_read object representing a WAV file opened for
                      reading.

    @return: A list containing the samples in the WAV file as signed
             integers. Samples of different channels are interleaved, i.e.
             channel c of an n-channel file is found at frames[c::n].

    @raise ValueError: The sample width of the file is not 1, 2 or 4 bytes.
    """
    sample_width = wave_file.getsampwidth()
    sample_format = _SAMPLE_FORMATS.get(sample_width)
    if sample_format is None:
        raise ValueError('Unsupported sample width')

    num_frames = wave_file.getnframes()
    # One frame holds one sample per channel, so the total number of values
    # to unpack is frames * channels.
    num_samples = num_frames * wave_file.getnchannels()
    frames = list(struct.unpack('%d%s' % (num_samples, sample_format),
                                wave_file.readframes(num_frames)))

    # Since 8-bit PCM is unsigned with an offset of 128, we subtract the offset
    # to make it signed since the rest of the code assumes signed numbers.
    if sample_width == 1:
        frames = [val - 128 for val in frames]

    return frames


def check_wav_file(filename, num_channels=None, sample_rate=None,
                   sample_width=None):
    """Checks a WAV file and returns its peak PCM values.

    @param filename: Input WAV file to analyze.
    @param num_channels: Number of channels to expect (None to not check).
    @param sample_rate: Sample rate to expect (None to not check).
    @param sample_width: Sample width to expect (None to not check).

    @return A list of the absolute maximum PCM values for each channel in the
            WAV file.

    @raise ValueError: Failed to process the WAV file or validate an attribute.
    """
    try:
        with contextlib.closing(wave.open(filename, 'rb')) as chk_file:
            # Read the attributes while the file is still open.
            actual_channels = chk_file.getnchannels()
            actual_rate = chk_file.getframerate()
            actual_width = chk_file.getsampwidth()

            if num_channels is not None and actual_channels != num_channels:
                raise ValueError('Expected %d channels but got %d instead.' %
                                 (num_channels, actual_channels))
            if sample_rate is not None and actual_rate != sample_rate:
                raise ValueError('Expected sample rate %d but got %d instead.'
                                 % (sample_rate, actual_rate))
            if sample_width is not None and actual_width != sample_width:
                raise ValueError('Expected sample width %d but got %d instead.'
                                 % (sample_width, actual_width))
            frames = extract_wav_frames(chk_file)
    except (wave.Error, EOFError, struct.error) as e:
        # EOFError is raised for empty files and struct.error for files whose
        # data is shorter than their header claims; report both as ValueError
        # like every other processing failure.
        raise ValueError('Error processing WAV file: %s' % e)

    peaks = []
    for channel in range(actual_channels):
        channel_frames = frames[channel::actual_channels]
        if not channel_frames:
            raise ValueError('WAV file contains no audio frames.')
        peaks.append(max(abs(val) for val in channel_frames))
    return peaks


def _make_temp_file(suffix, temp_dir):
    """Create an empty temporary file and return its path.

    @param suffix: Suffix (including the dot) of the file to create.
    @param temp_dir: Directory in which the file is created.

    @return The path of the newly created file.
    """
    fd, path = tempfile.mkstemp(prefix='sine-', suffix=suffix, dir=temp_dir)
    # Only the name is needed; close the descriptor so it is not leaked.
    os.close(fd)
    return path


def _run_command(cmd):
    """Run an external command and log an error if it does not succeed.

    @param cmd: The command to run, as a list of arguments.
    """
    try:
        returncode = subprocess.call(cmd)
    except OSError as e:
        logging.error('Failed to execute %s: %s', cmd[0], e)
        return
    if returncode != 0:
        logging.error('Command %s exited with status %d.', cmd[0], returncode)


def generate_sine_file(host, num_channels, sample_rate, sample_width,
                       duration_secs, sine_frequency, temp_dir,
                       file_format='wav'):
    """Generate a sine file and push it to the DUT.

    @param host: An object representing the DUT.
    @param num_channels: Number of channels to use.
    @param sample_rate: Sample rate to use for sine wave generation.
    @param sample_width: Sample width to use for sine wave generation.
    @param duration_secs: Duration in seconds to generate sine wave for.
    @param sine_frequency: Frequency to generate sine wave at.
    @param temp_dir: A temporary directory on the host.
    @param file_format: A string representing the encoding for the audio file.

    @return A tuple of the filename on the server and the DUT. The server
            side filename always refers to the generated WAV file, even when
            the file pushed to the DUT was converted to another format.
    """
    # The generated file is always WAV data, so it always gets a .wav suffix;
    # otherwise the conversion below would guess the wrong input type from
    # the file name.
    local_filename = _make_temp_file('.wav', temp_dir)
    # 8-bit PCM is unsigned, wider samples are signed.
    encoding = 'unsigned' if sample_width == 1 else 'signed'
    gen_file_cmd = ['sox', '-n', '-t', 'wav',
                    '-c', '%d' % num_channels,
                    '-e', encoding,
                    '-b', '%d' % (sample_width * _BITS_PER_BYTE),
                    '-r', '%d' % sample_rate,
                    local_filename,
                    'synth', '%d' % duration_secs,
                    'sine', '%d' % sine_frequency,
                    'vol', '0.9']
    logging.info('Command to generate sine wave: %s', ' '.join(gen_file_cmd))
    _run_command(gen_file_cmd)

    if file_format != 'wav':
        # Convert the file to the appropriate format.
        logging.info('Converting file to %s', file_format)
        local_encoded_filename = _make_temp_file('.' + file_format, temp_dir)
        cvt_file_cmd = shlex.split(_ENCODING_CMD) + [local_filename,
                                                     local_encoded_filename]
        logging.info('Command to convert file: %s', ' '.join(cvt_file_cmd))
        _run_command(cvt_file_cmd)
    else:
        local_encoded_filename = local_filename

    # TODO(ralphnathan): Find a better place to put this file once the SELinux
    # issues are resolved.
    dut_tmp_dir = '/data'
    remote_filename = os.path.join(dut_tmp_dir, 'sine.' + file_format)
    logging.info('Send file to DUT.')
    logging.info('remote_filename %s', remote_filename)
    host.send_file(local_encoded_filename, remote_filename)
    return local_filename, remote_filename


def _is_outside_frequency_threshold(freq_reference, freq_rec):
    """Compares the frequency of the recorded audio with the reference audio.

    This function checks whether a frequency differs from the reference
    frequency by more than the fraction given by _FREQUENCY_THRESHOLD.

    @param freq_reference: The dominant frequency in the reference audio file.
    @param freq_rec: The frequency to check, typically the dominant frequency
                     in the recorded audio file. A numpy array of frequencies
                     is accepted as well, in which case the comparison is
                     done element-wise.

    @return: True if freq_rec differs from freq_reference by more than
             _FREQUENCY_THRESHOLD * freq_reference, False otherwise. For an
             array input, a boolean array is returned.
    """
    # Comparing the absolute difference instead of the ratio avoids a division
    # by zero when the reference frequency is 0 Hz.
    allowed_deviation = _FREQUENCY_THRESHOLD * abs(freq_reference)
    return abs(freq_rec - freq_reference) > allowed_deviation


def _compare_frames(reference_file_frames, rec_file_frames, num_channels,
                    sample_rate):
    """Compares audio frames from the reference file and the recorded file.

    This method checks for two things:
      1. That the main frequency is the same in both the files. This is done
         using the FFT and observing the frequency corresponding to the
         peak.
      2. That there is no other dominant frequency in the recorded file.
         This is done by sweeping the frequency domain and checking that the
         magnitude of every frequency away from the main one is always less
         than the fraction _FFT_NOISE_THRESHOLD of the peak.

    The key assumption here is that the reference audio file contains only
    one frequency.

    @param reference_file_frames: Audio frames from the reference file.
    @param rec_file_frames: Audio frames from the recorded file.
    @param num_channels: Number of channels in the files.
    @param sample_rate: Sample rate of the files.

    @raise error.TestFail: One of the signals contains no audio frames.
    @raise error.TestFail: The frequency of the recorded signal doesn't
                           match that of the reference signal.
    @raise error.TestFail: There is too much noise in the recorded signal.
    """
    for channel in range(num_channels):
        # Samples are interleaved, pick the ones belonging to this channel.
        reference_data = reference_file_frames[channel::num_channels]
        rec_data = rec_file_frames[channel::num_channels]
        if len(reference_data) == 0 or len(rec_data) == 0:
            raise error.TestFail('No audio frames found for channel %i.' %
                                 channel)

        # Get fft and frequencies corresponding to the fft values.
        abs_fft_reference = numpy.abs(numpy.fft.rfft(reference_data))
        abs_fft_rec = numpy.abs(numpy.fft.rfft(rec_data))
        fft_freqs_reference = numpy.fft.rfftfreq(len(reference_data),
                                                 1.0 / sample_rate)
        fft_freqs_rec = numpy.fft.rfftfreq(len(rec_data), 1.0 / sample_rate)

        # Get frequency at highest peak.
        freq_reference = fft_freqs_reference[numpy.argmax(abs_fft_reference)]
        freq_rec = fft_freqs_rec[numpy.argmax(abs_fft_rec)]

        # Compare the two frequencies.
        logging.info('Golden frequency of channel %i is %f', channel,
                     freq_reference)
        logging.info('Recorded frequency of channel %i is %f', channel,
                     freq_rec)
        if _is_outside_frequency_threshold(freq_reference, freq_rec):
            raise error.TestFail('The recorded audio frequency does not match '
                                 'that of the audio played.')

        # Check for noise in the frequency domain: a bin counts as noise when
        # its own frequency is away from the reference frequency and its
        # magnitude exceeds _FFT_NOISE_THRESHOLD of the peak.
        # NOTE(review): the frequency test has to look at each bin's frequency
        # rather than at the (already validated) peak frequency, otherwise
        # this sweep can never detect anything. Spectral leakage next to the
        # peak may call for a wider exclusion band - confirm on real
        # recordings.
        fft_rec_peak_val = numpy.max(abs_fft_rec)
        is_noise = numpy.logical_and(
                _is_outside_frequency_threshold(freq_reference, fft_freqs_rec),
                abs_fft_rec > _FFT_NOISE_THRESHOLD * fft_rec_peak_val)
        noisy_freqs = fft_freqs_rec[is_noise]
        for noisy_freq in noisy_freqs:
            logging.warning('Unexpected frequency peak detected at %f Hz.',
                            noisy_freq)

        if len(noisy_freqs) > 0:
            raise error.TestFail('Signal is noiser than expected.')


def compare_file(reference_audio_filename, test_audio_filename):
    """Compares the recorded audio file to the reference audio file.

    @param reference_audio_filename : Reference audio file containing the
                                      reference signal.
    @param test_audio_filename: Audio file containing audio captured from
                                the test.

    @raise error.TestFail: The two files differ in channel count or sample
                           rate, or the recorded signal does not match the
                           reference signal (see _compare_frames).
    """
    with contextlib.closing(wave.open(reference_audio_filename,
                                      'rb')) as reference_file:
        with contextlib.closing(wave.open(test_audio_filename,
                                          'rb')) as rec_file:
            num_channels = reference_file.getnchannels()
            sample_rate = reference_file.getframerate()

            # The comparison de-interleaves both files with the same channel
            # count and maps FFT bins to frequencies with the same sample
            # rate, so both files have to agree on these.
            if rec_file.getnchannels() != num_channels:
                raise error.TestFail(
                        'Expected %d channels in the recorded file but got %d '
                        'instead.' % (num_channels, rec_file.getnchannels()))
            if rec_file.getframerate() != sample_rate:
                raise error.TestFail(
                        'Expected sample rate %d in the recorded file but got '
                        '%d instead.' % (sample_rate, rec_file.getframerate()))

            # Extract data from files.
            reference_file_frames = extract_wav_frames(reference_file)
            rec_file_frames = extract_wav_frames(rec_file)

    _compare_frames(reference_file_frames, rec_file_frames, num_channels,
                    sample_rate)