# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the eval_scores module.
"""

import os
import shutil
import tempfile
import unittest

import pydub

from . import data_access
from . import eval_scores
from . import eval_scores_factory
from . import signal_processing


class TestEvalScores(unittest.TestCase):
    """Unit tests for the eval_scores module.
    """

    def setUp(self):
        """Create temporary output folder and two audio track files.

        Both tracks are independent white-noise signals generated from a
        silent 48 kHz template; they stand in for the reference and the
        tested (i.e., APM output) signals.
        """
        self._output_path = tempfile.mkdtemp()

        # Create fake reference and tested (i.e., APM output) audio track files.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_reference_signal = (signal_processing.SignalProcessingUtils.
                                 GenerateWhiteNoise(silence))
        fake_tested_signal = (signal_processing.SignalProcessingUtils.
                              GenerateWhiteNoise(silence))

        # Save fake audio tracks.
        self._fake_reference_signal_filepath = os.path.join(
            self._output_path, 'fake_ref.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_reference_signal_filepath, fake_reference_signal)
        self._fake_tested_signal_filepath = os.path.join(
            self._output_path, 'fake_test.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_tested_signal_filepath, fake_tested_signal)

    def tearDown(self):
        """Recursively delete temporary folder."""
        shutil.rmtree(self._output_path)

    def testRegisteredClasses(self):
        """Run every registered evaluation score worker on the fake tracks.

        Workers listed in `exceptions` are skipped because they are covered
        by dedicated tests or need tools that are not faked here. Each
        remaining worker must write a score file whose content loads as a
        float.
        """
        # Evaluation score names to exclude (tested separately).
        exceptions = {'thd', 'echo_metric'}

        # Preliminary check.
        self.assertTrue(os.path.exists(self._output_path))

        # Check that there is at least one registered evaluation score worker.
        registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instantiate the evaluation score workers factory with fake
        # dependencies (a fake POLQA binary next to this file, no echo metric
        # tool).
        eval_score_workers_factory = (
            eval_scores_factory.EvaluationScoreWorkerFactory(
                polqa_tool_bin_path=os.path.join(
                    os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
                echo_metric_tool_bin_path=None))
        eval_score_workers_factory.SetScoreFilenamePrefix('scores-')

        # Try each registered evaluation score worker.
        for eval_score_name, eval_score_class in registered_classes.items():
            if eval_score_name in exceptions:
                continue

            # Instantiate the evaluation score worker.
            eval_score_worker = eval_score_workers_factory.GetInstance(
                eval_score_class)

            # Set reference and test file paths, then run.
            eval_score_worker.SetReferenceSignalFilepath(
                self._fake_reference_signal_filepath)
            eval_score_worker.SetTestedSignalFilepath(
                self._fake_tested_signal_filepath)
            eval_score_worker.Run(self._output_path)

            # Check output; assertIsInstance reports the actual type on failure.
            score = data_access.ScoreFile.Load(
                eval_score_worker.output_filepath)
            self.assertIsInstance(score, float)

    def testTotalHarmonicDistorsionScore(self):
        """Check that the THD score grows as a pure tone gets more distorted.

        Scores are computed for a pure tone, the same tone mixed with white
        noise, and white noise alone; each must be strictly greater than the
        previous one.
        """
        # Init.
        pure_tone_freq = 5000.0
        eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
        eval_score_worker.SetInputSignalMetadata({
            'signal': 'pure_tone',
            'frequency': pure_tone_freq,
            'test_data_gen_name': 'identity',
            'test_data_gen_config': 'default',
        })
        template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

        # Create 3 test signals: pure tone, pure tone + white noise, white noise
        # only.
        pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
            template, pure_tone_freq)
        white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            template)
        noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
            pure_tone, white_noise)

        # Compute scores for increasingly distorted pure tone signals. The same
        # temporary file is overwritten for each signal.
        tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
        scores = []
        for tested_signal in (pure_tone, noisy_tone, white_noise):
            # Save signal.
            signal_processing.SignalProcessingUtils.SaveWav(
                tmp_filepath, tested_signal)

            # Compute score.
            eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
            eval_score_worker.Run(self._output_path)
            scores.append(eval_score_worker.score)

            # Remove output file to avoid caching.
            os.remove(eval_score_worker.output_filepath)

        # Validate scores (lowest score with a pure tone). Compare each adjacent
        # pair separately so a failure shows which pair broke the ordering.
        for less_distorted, more_distorted in zip(scores, scores[1:]):
            self.assertLess(less_distorted, more_distorted,
                            'THD scores not increasing: %r' % (scores, ))