1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 2# 3# Use of this source code is governed by a BSD-style license 4# that can be found in the LICENSE file in the root of the source 5# tree. An additional intellectual property rights grant can be found 6# in the file PATENTS. All contributing project authors may 7# be found in the AUTHORS file in the root of the source tree. 8 9"""Evaluation score abstract class and implementations. 10""" 11 12from __future__ import division 13import logging 14import os 15import re 16import subprocess 17import sys 18 19try: 20 import numpy as np 21except ImportError: 22 logging.critical('Cannot import the third-party Python package numpy') 23 sys.exit(1) 24 25from . import data_access 26from . import exceptions 27from . import signal_processing 28 29 30class EvaluationScore(object): 31 32 NAME = None 33 REGISTERED_CLASSES = {} 34 35 def __init__(self, score_filename_prefix): 36 self._score_filename_prefix = score_filename_prefix 37 self._input_signal_metadata = None 38 self._reference_signal = None 39 self._reference_signal_filepath = None 40 self._tested_signal = None 41 self._tested_signal_filepath = None 42 self._output_filepath = None 43 self._score = None 44 self._render_signal_filepath = None 45 46 @classmethod 47 def RegisterClass(cls, class_to_register): 48 """Registers an EvaluationScore implementation. 49 50 Decorator to automatically register the classes that extend EvaluationScore. 51 Example usage: 52 53 @EvaluationScore.RegisterClass 54 class AudioLevelScore(EvaluationScore): 55 pass 56 """ 57 cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register 58 return class_to_register 59 60 @property 61 def output_filepath(self): 62 return self._output_filepath 63 64 @property 65 def score(self): 66 return self._score 67 68 def SetInputSignalMetadata(self, metadata): 69 """Sets input signal metadata. 70 71 Args: 72 metadata: dict instance. 73 """ 74 self._input_signal_metadata = metadata 75 76 def SetReferenceSignalFilepath(self, filepath): 77 """Sets the path to the audio track used as reference signal. 78 79 Args: 80 filepath: path to the reference audio track. 81 """ 82 self._reference_signal_filepath = filepath 83 84 def SetTestedSignalFilepath(self, filepath): 85 """Sets the path to the audio track used as test signal. 86 87 Args: 88 filepath: path to the test audio track. 89 """ 90 self._tested_signal_filepath = filepath 91 92 def SetRenderSignalFilepath(self, filepath): 93 """Sets the path to the audio track used as render signal. 94 95 Args: 96 filepath: path to the test audio track. 97 """ 98 self._render_signal_filepath = filepath 99 100 def Run(self, output_path): 101 """Extracts the score for the set test data pair. 102 103 Args: 104 output_path: path to the directory where the output is written. 105 """ 106 self._output_filepath = os.path.join( 107 output_path, self._score_filename_prefix + self.NAME + '.txt') 108 try: 109 # If the score has already been computed, load. 110 self._LoadScore() 111 logging.debug('score found and loaded') 112 except IOError: 113 # Compute the score. 114 logging.debug('score not found, compute') 115 self._Run(output_path) 116 117 def _Run(self, output_path): 118 # Abstract method. 119 raise NotImplementedError() 120 121 def _LoadReferenceSignal(self): 122 assert self._reference_signal_filepath is not None 123 self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav( 124 self._reference_signal_filepath) 125 126 def _LoadTestedSignal(self): 127 assert self._tested_signal_filepath is not None 128 self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav( 129 self._tested_signal_filepath) 130 131 132 def _LoadScore(self): 133 return data_access.ScoreFile.Load(self._output_filepath) 134 135 def _SaveScore(self): 136 return data_access.ScoreFile.Save(self._output_filepath, self._score) 137 138 139@EvaluationScore.RegisterClass 140class AudioLevelPeakScore(EvaluationScore): 141 """Peak audio level score. 142 143 Defined as the difference between the peak audio level of the tested and 144 the reference signals. 145 146 Unit: dB 147 Ideal: 0 dB 148 Worst case: +/-inf dB 149 """ 150 151 NAME = 'audio_level_peak' 152 153 def __init__(self, score_filename_prefix): 154 EvaluationScore.__init__(self, score_filename_prefix) 155 156 def _Run(self, output_path): 157 self._LoadReferenceSignal() 158 self._LoadTestedSignal() 159 self._score = self._tested_signal.dBFS - self._reference_signal.dBFS 160 self._SaveScore() 161 162 163@EvaluationScore.RegisterClass 164class MeanAudioLevelScore(EvaluationScore): 165 """Mean audio level score. 166 167 Defined as the difference between the mean audio level of the tested and 168 the reference signals. 169 170 Unit: dB 171 Ideal: 0 dB 172 Worst case: +/-inf dB 173 """ 174 175 NAME = 'audio_level_mean' 176 177 def __init__(self, score_filename_prefix): 178 EvaluationScore.__init__(self, score_filename_prefix) 179 180 def _Run(self, output_path): 181 self._LoadReferenceSignal() 182 self._LoadTestedSignal() 183 184 dbfs_diffs_sum = 0.0 185 seconds = min(len(self._tested_signal), len(self._reference_signal)) // 1000 186 for t in range(seconds): 187 t0 = t * seconds 188 t1 = t0 + seconds 189 dbfs_diffs_sum += ( 190 self._tested_signal[t0:t1].dBFS - self._reference_signal[t0:t1].dBFS) 191 self._score = dbfs_diffs_sum / float(seconds) 192 self._SaveScore() 193 194 195@EvaluationScore.RegisterClass 196class EchoMetric(EvaluationScore): 197 """Echo score. 198 199 Proportion of detected echo. 200 201 Unit: ratio 202 Ideal: 0 203 Worst case: 1 204 """ 205 206 NAME = 'echo_metric' 207 208 def __init__(self, score_filename_prefix, echo_detector_bin_filepath): 209 EvaluationScore.__init__(self, score_filename_prefix) 210 211 # POLQA binary file path. 212 self._echo_detector_bin_filepath = echo_detector_bin_filepath 213 if not os.path.exists(self._echo_detector_bin_filepath): 214 logging.error('cannot find EchoMetric tool binary file') 215 raise exceptions.FileNotFoundError() 216 217 self._echo_detector_bin_path, _ = os.path.split( 218 self._echo_detector_bin_filepath) 219 220 def _Run(self, output_path): 221 echo_detector_out_filepath = os.path.join(output_path, 'echo_detector.out') 222 if os.path.exists(echo_detector_out_filepath): 223 os.unlink(echo_detector_out_filepath) 224 225 logging.debug("Render signal filepath: %s", self._render_signal_filepath) 226 if not os.path.exists(self._render_signal_filepath): 227 logging.error("Render input required for evaluating the echo metric.") 228 229 args = [ 230 self._echo_detector_bin_filepath, 231 '--output_file', echo_detector_out_filepath, 232 '--', 233 '-i', self._tested_signal_filepath, 234 '-ri', self._render_signal_filepath 235 ] 236 logging.debug(' '.join(args)) 237 subprocess.call(args, cwd=self._echo_detector_bin_path) 238 239 # Parse Echo detector tool output and extract the score. 240 self._score = self._ParseOutputFile(echo_detector_out_filepath) 241 self._SaveScore() 242 243 @classmethod 244 def _ParseOutputFile(cls, echo_metric_file_path): 245 """ 246 Parses the POLQA tool output formatted as a table ('-t' option). 247 248 Args: 249 polqa_out_filepath: path to the POLQA tool output file. 250 251 Returns: 252 The score as a number in [0, 1]. 253 """ 254 with open(echo_metric_file_path) as f: 255 return float(f.read()) 256 257@EvaluationScore.RegisterClass 258class PolqaScore(EvaluationScore): 259 """POLQA score. 260 261 See http://www.polqa.info/. 262 263 Unit: MOS 264 Ideal: 4.5 265 Worst case: 1.0 266 """ 267 268 NAME = 'polqa' 269 270 def __init__(self, score_filename_prefix, polqa_bin_filepath): 271 EvaluationScore.__init__(self, score_filename_prefix) 272 273 # POLQA binary file path. 274 self._polqa_bin_filepath = polqa_bin_filepath 275 if not os.path.exists(self._polqa_bin_filepath): 276 logging.error('cannot find POLQA tool binary file') 277 raise exceptions.FileNotFoundError() 278 279 # Path to the POLQA directory with binary and license files. 280 self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath) 281 282 def _Run(self, output_path): 283 polqa_out_filepath = os.path.join(output_path, 'polqa.out') 284 if os.path.exists(polqa_out_filepath): 285 os.unlink(polqa_out_filepath) 286 287 args = [ 288 self._polqa_bin_filepath, '-t', '-q', '-Overwrite', 289 '-Ref', self._reference_signal_filepath, 290 '-Test', self._tested_signal_filepath, 291 '-LC', 'NB', 292 '-Out', polqa_out_filepath, 293 ] 294 logging.debug(' '.join(args)) 295 subprocess.call(args, cwd=self._polqa_tool_path) 296 297 # Parse POLQA tool output and extract the score. 298 polqa_output = self._ParseOutputFile(polqa_out_filepath) 299 self._score = float(polqa_output['PolqaScore']) 300 301 self._SaveScore() 302 303 @classmethod 304 def _ParseOutputFile(cls, polqa_out_filepath): 305 """ 306 Parses the POLQA tool output formatted as a table ('-t' option). 307 308 Args: 309 polqa_out_filepath: path to the POLQA tool output file. 310 311 Returns: 312 A dict. 313 """ 314 data = [] 315 with open(polqa_out_filepath) as f: 316 for line in f: 317 line = line.strip() 318 if len(line) == 0 or line.startswith('*'): 319 # Ignore comments. 320 continue 321 # Read fields. 322 data.append(re.split(r'\t+', line)) 323 324 # Two rows expected (header and values). 325 assert len(data) == 2, 'Cannot parse POLQA output' 326 number_of_fields = len(data[0]) 327 assert number_of_fields == len(data[1]) 328 329 # Build and return a dictionary with field names (header) as keys and the 330 # corresponding field values as values. 331 return {data[0][index]: data[1][index] for index in range(number_of_fields)} 332 333 334@EvaluationScore.RegisterClass 335class TotalHarmonicDistorsionScore(EvaluationScore): 336 """Total harmonic distorsion plus noise score. 337 338 Total harmonic distorsion plus noise score. 339 See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN". 340 341 Unit: -. 342 Ideal: 0. 343 Worst case: +inf 344 """ 345 346 NAME = 'thd' 347 348 def __init__(self, score_filename_prefix): 349 EvaluationScore.__init__(self, score_filename_prefix) 350 self._input_frequency = None 351 352 def _Run(self, output_path): 353 self._CheckInputSignal() 354 355 self._LoadTestedSignal() 356 if self._tested_signal.channels != 1: 357 raise exceptions.EvaluationScoreException( 358 'unsupported number of channels') 359 samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData( 360 self._tested_signal) 361 362 # Init. 363 num_samples = len(samples) 364 duration = len(self._tested_signal) / 1000.0 365 scaling = 2.0 / num_samples 366 max_freq = self._tested_signal.frame_rate / 2 367 f0_freq = float(self._input_frequency) 368 t = np.linspace(0, duration, num_samples) 369 370 # Analyze harmonics. 371 b_terms = [] 372 n = 1 373 while f0_freq * n < max_freq: 374 x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling 375 y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling 376 b_terms.append(np.sqrt(x_n**2 + y_n**2)) 377 n += 1 378 379 output_without_fundamental = samples - b_terms[0] * np.sin( 380 2.0 * np.pi * f0_freq * t) 381 distortion_and_noise = np.sqrt(np.sum( 382 output_without_fundamental**2) * np.pi * scaling) 383 384 # TODO(alessiob): Fix or remove if not needed. 385 # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0] 386 387 # TODO(alessiob): Check the range of |thd_plus_noise| and update the class 388 # docstring above if accordingly. 389 thd_plus_noise = distortion_and_noise / b_terms[0] 390 391 self._score = thd_plus_noise 392 self._SaveScore() 393 394 def _CheckInputSignal(self): 395 # Check input signal and get properties. 396 try: 397 if self._input_signal_metadata['signal'] != 'pure_tone': 398 raise exceptions.EvaluationScoreException( 399 'The THD score requires a pure tone as input signal') 400 self._input_frequency = self._input_signal_metadata['frequency'] 401 if self._input_signal_metadata['test_data_gen_name'] != 'identity' or ( 402 self._input_signal_metadata['test_data_gen_config'] != 'default'): 403 raise exceptions.EvaluationScoreException( 404 'The THD score cannot be used with any test data generator other ' 405 'than "identity"') 406 except TypeError: 407 raise exceptions.EvaluationScoreException( 408 'The THD score requires an input signal with associated metadata') 409 except KeyError: 410 raise exceptions.EvaluationScoreException( 411 'Invalid input signal metadata to compute the THD score') 412