1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 2# 3# Use of this source code is governed by a BSD-style license 4# that can be found in the LICENSE file in the root of the source 5# tree. An additional intellectual property rights grant can be found 6# in the file PATENTS. All contributing project authors may 7# be found in the AUTHORS file in the root of the source tree. 8"""Evaluation score abstract class and implementations. 9""" 10 11from __future__ import division 12import logging 13import os 14import re 15import subprocess 16import sys 17 18try: 19 import numpy as np 20except ImportError: 21 logging.critical('Cannot import the third-party Python package numpy') 22 sys.exit(1) 23 24from . import data_access 25from . import exceptions 26from . import signal_processing 27 28 29class EvaluationScore(object): 30 31 NAME = None 32 REGISTERED_CLASSES = {} 33 34 def __init__(self, score_filename_prefix): 35 self._score_filename_prefix = score_filename_prefix 36 self._input_signal_metadata = None 37 self._reference_signal = None 38 self._reference_signal_filepath = None 39 self._tested_signal = None 40 self._tested_signal_filepath = None 41 self._output_filepath = None 42 self._score = None 43 self._render_signal_filepath = None 44 45 @classmethod 46 def RegisterClass(cls, class_to_register): 47 """Registers an EvaluationScore implementation. 48 49 Decorator to automatically register the classes that extend EvaluationScore. 50 Example usage: 51 52 @EvaluationScore.RegisterClass 53 class AudioLevelScore(EvaluationScore): 54 pass 55 """ 56 cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register 57 return class_to_register 58 59 @property 60 def output_filepath(self): 61 return self._output_filepath 62 63 @property 64 def score(self): 65 return self._score 66 67 def SetInputSignalMetadata(self, metadata): 68 """Sets input signal metadata. 69 70 Args: 71 metadata: dict instance. 72 """ 73 self._input_signal_metadata = metadata 74 75 def SetReferenceSignalFilepath(self, filepath): 76 """Sets the path to the audio track used as reference signal. 77 78 Args: 79 filepath: path to the reference audio track. 80 """ 81 self._reference_signal_filepath = filepath 82 83 def SetTestedSignalFilepath(self, filepath): 84 """Sets the path to the audio track used as test signal. 85 86 Args: 87 filepath: path to the test audio track. 88 """ 89 self._tested_signal_filepath = filepath 90 91 def SetRenderSignalFilepath(self, filepath): 92 """Sets the path to the audio track used as render signal. 93 94 Args: 95 filepath: path to the test audio track. 96 """ 97 self._render_signal_filepath = filepath 98 99 def Run(self, output_path): 100 """Extracts the score for the set test data pair. 101 102 Args: 103 output_path: path to the directory where the output is written. 104 """ 105 self._output_filepath = os.path.join( 106 output_path, self._score_filename_prefix + self.NAME + '.txt') 107 try: 108 # If the score has already been computed, load. 109 self._LoadScore() 110 logging.debug('score found and loaded') 111 except IOError: 112 # Compute the score. 113 logging.debug('score not found, compute') 114 self._Run(output_path) 115 116 def _Run(self, output_path): 117 # Abstract method. 118 raise NotImplementedError() 119 120 def _LoadReferenceSignal(self): 121 assert self._reference_signal_filepath is not None 122 self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav( 123 self._reference_signal_filepath) 124 125 def _LoadTestedSignal(self): 126 assert self._tested_signal_filepath is not None 127 self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav( 128 self._tested_signal_filepath) 129 130 def _LoadScore(self): 131 return data_access.ScoreFile.Load(self._output_filepath) 132 133 def _SaveScore(self): 134 return data_access.ScoreFile.Save(self._output_filepath, self._score) 135 136 137@EvaluationScore.RegisterClass 138class AudioLevelPeakScore(EvaluationScore): 139 """Peak audio level score. 140 141 Defined as the difference between the peak audio level of the tested and 142 the reference signals. 143 144 Unit: dB 145 Ideal: 0 dB 146 Worst case: +/-inf dB 147 """ 148 149 NAME = 'audio_level_peak' 150 151 def __init__(self, score_filename_prefix): 152 EvaluationScore.__init__(self, score_filename_prefix) 153 154 def _Run(self, output_path): 155 self._LoadReferenceSignal() 156 self._LoadTestedSignal() 157 self._score = self._tested_signal.dBFS - self._reference_signal.dBFS 158 self._SaveScore() 159 160 161@EvaluationScore.RegisterClass 162class MeanAudioLevelScore(EvaluationScore): 163 """Mean audio level score. 164 165 Defined as the difference between the mean audio level of the tested and 166 the reference signals. 167 168 Unit: dB 169 Ideal: 0 dB 170 Worst case: +/-inf dB 171 """ 172 173 NAME = 'audio_level_mean' 174 175 def __init__(self, score_filename_prefix): 176 EvaluationScore.__init__(self, score_filename_prefix) 177 178 def _Run(self, output_path): 179 self._LoadReferenceSignal() 180 self._LoadTestedSignal() 181 182 dbfs_diffs_sum = 0.0 183 seconds = min(len(self._tested_signal), len( 184 self._reference_signal)) // 1000 185 for t in range(seconds): 186 t0 = t * seconds 187 t1 = t0 + seconds 188 dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS - 189 self._reference_signal[t0:t1].dBFS) 190 self._score = dbfs_diffs_sum / float(seconds) 191 self._SaveScore() 192 193 194@EvaluationScore.RegisterClass 195class EchoMetric(EvaluationScore): 196 """Echo score. 197 198 Proportion of detected echo. 199 200 Unit: ratio 201 Ideal: 0 202 Worst case: 1 203 """ 204 205 NAME = 'echo_metric' 206 207 def __init__(self, score_filename_prefix, echo_detector_bin_filepath): 208 EvaluationScore.__init__(self, score_filename_prefix) 209 210 # POLQA binary file path. 211 self._echo_detector_bin_filepath = echo_detector_bin_filepath 212 if not os.path.exists(self._echo_detector_bin_filepath): 213 logging.error('cannot find EchoMetric tool binary file') 214 raise exceptions.FileNotFoundError() 215 216 self._echo_detector_bin_path, _ = os.path.split( 217 self._echo_detector_bin_filepath) 218 219 def _Run(self, output_path): 220 echo_detector_out_filepath = os.path.join(output_path, 221 'echo_detector.out') 222 if os.path.exists(echo_detector_out_filepath): 223 os.unlink(echo_detector_out_filepath) 224 225 logging.debug("Render signal filepath: %s", 226 self._render_signal_filepath) 227 if not os.path.exists(self._render_signal_filepath): 228 logging.error( 229 "Render input required for evaluating the echo metric.") 230 231 args = [ 232 self._echo_detector_bin_filepath, '--output_file', 233 echo_detector_out_filepath, '--', '-i', 234 self._tested_signal_filepath, '-ri', self._render_signal_filepath 235 ] 236 logging.debug(' '.join(args)) 237 subprocess.call(args, cwd=self._echo_detector_bin_path) 238 239 # Parse Echo detector tool output and extract the score. 240 self._score = self._ParseOutputFile(echo_detector_out_filepath) 241 self._SaveScore() 242 243 @classmethod 244 def _ParseOutputFile(cls, echo_metric_file_path): 245 """ 246 Parses the POLQA tool output formatted as a table ('-t' option). 247 248 Args: 249 polqa_out_filepath: path to the POLQA tool output file. 250 251 Returns: 252 The score as a number in [0, 1]. 253 """ 254 with open(echo_metric_file_path) as f: 255 return float(f.read()) 256 257 258@EvaluationScore.RegisterClass 259class PolqaScore(EvaluationScore): 260 """POLQA score. 261 262 See http://www.polqa.info/. 263 264 Unit: MOS 265 Ideal: 4.5 266 Worst case: 1.0 267 """ 268 269 NAME = 'polqa' 270 271 def __init__(self, score_filename_prefix, polqa_bin_filepath): 272 EvaluationScore.__init__(self, score_filename_prefix) 273 274 # POLQA binary file path. 275 self._polqa_bin_filepath = polqa_bin_filepath 276 if not os.path.exists(self._polqa_bin_filepath): 277 logging.error('cannot find POLQA tool binary file') 278 raise exceptions.FileNotFoundError() 279 280 # Path to the POLQA directory with binary and license files. 281 self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath) 282 283 def _Run(self, output_path): 284 polqa_out_filepath = os.path.join(output_path, 'polqa.out') 285 if os.path.exists(polqa_out_filepath): 286 os.unlink(polqa_out_filepath) 287 288 args = [ 289 self._polqa_bin_filepath, 290 '-t', 291 '-q', 292 '-Overwrite', 293 '-Ref', 294 self._reference_signal_filepath, 295 '-Test', 296 self._tested_signal_filepath, 297 '-LC', 298 'NB', 299 '-Out', 300 polqa_out_filepath, 301 ] 302 logging.debug(' '.join(args)) 303 subprocess.call(args, cwd=self._polqa_tool_path) 304 305 # Parse POLQA tool output and extract the score. 306 polqa_output = self._ParseOutputFile(polqa_out_filepath) 307 self._score = float(polqa_output['PolqaScore']) 308 309 self._SaveScore() 310 311 @classmethod 312 def _ParseOutputFile(cls, polqa_out_filepath): 313 """ 314 Parses the POLQA tool output formatted as a table ('-t' option). 315 316 Args: 317 polqa_out_filepath: path to the POLQA tool output file. 318 319 Returns: 320 A dict. 321 """ 322 data = [] 323 with open(polqa_out_filepath) as f: 324 for line in f: 325 line = line.strip() 326 if len(line) == 0 or line.startswith('*'): 327 # Ignore comments. 328 continue 329 # Read fields. 330 data.append(re.split(r'\t+', line)) 331 332 # Two rows expected (header and values). 333 assert len(data) == 2, 'Cannot parse POLQA output' 334 number_of_fields = len(data[0]) 335 assert number_of_fields == len(data[1]) 336 337 # Build and return a dictionary with field names (header) as keys and the 338 # corresponding field values as values. 339 return { 340 data[0][index]: data[1][index] 341 for index in range(number_of_fields) 342 } 343 344 345@EvaluationScore.RegisterClass 346class TotalHarmonicDistorsionScore(EvaluationScore): 347 """Total harmonic distorsion plus noise score. 348 349 Total harmonic distorsion plus noise score. 350 See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN". 351 352 Unit: -. 353 Ideal: 0. 354 Worst case: +inf 355 """ 356 357 NAME = 'thd' 358 359 def __init__(self, score_filename_prefix): 360 EvaluationScore.__init__(self, score_filename_prefix) 361 self._input_frequency = None 362 363 def _Run(self, output_path): 364 self._CheckInputSignal() 365 366 self._LoadTestedSignal() 367 if self._tested_signal.channels != 1: 368 raise exceptions.EvaluationScoreException( 369 'unsupported number of channels') 370 samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData( 371 self._tested_signal) 372 373 # Init. 374 num_samples = len(samples) 375 duration = len(self._tested_signal) / 1000.0 376 scaling = 2.0 / num_samples 377 max_freq = self._tested_signal.frame_rate / 2 378 f0_freq = float(self._input_frequency) 379 t = np.linspace(0, duration, num_samples) 380 381 # Analyze harmonics. 382 b_terms = [] 383 n = 1 384 while f0_freq * n < max_freq: 385 x_n = np.sum( 386 samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling 387 y_n = np.sum( 388 samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling 389 b_terms.append(np.sqrt(x_n**2 + y_n**2)) 390 n += 1 391 392 output_without_fundamental = samples - b_terms[0] * np.sin( 393 2.0 * np.pi * f0_freq * t) 394 distortion_and_noise = np.sqrt( 395 np.sum(output_without_fundamental**2) * np.pi * scaling) 396 397 # TODO(alessiob): Fix or remove if not needed. 398 # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0] 399 400 # TODO(alessiob): Check the range of `thd_plus_noise` and update the class 401 # docstring above if accordingly. 402 thd_plus_noise = distortion_and_noise / b_terms[0] 403 404 self._score = thd_plus_noise 405 self._SaveScore() 406 407 def _CheckInputSignal(self): 408 # Check input signal and get properties. 409 try: 410 if self._input_signal_metadata['signal'] != 'pure_tone': 411 raise exceptions.EvaluationScoreException( 412 'The THD score requires a pure tone as input signal') 413 self._input_frequency = self._input_signal_metadata['frequency'] 414 if self._input_signal_metadata[ 415 'test_data_gen_name'] != 'identity' or ( 416 self._input_signal_metadata['test_data_gen_config'] != 417 'default'): 418 raise exceptions.EvaluationScoreException( 419 'The THD score cannot be used with any test data generator other ' 420 'than "identity"') 421 except TypeError: 422 raise exceptions.EvaluationScoreException( 423 'The THD score requires an input signal with associated metadata' 424 ) 425 except KeyError: 426 raise exceptions.EvaluationScoreException( 427 'Invalid input signal metadata to compute the THD score') 428