# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the eval_scores module.
"""

import os
import shutil
import tempfile
import unittest

import pydub

from . import data_access
from . import eval_scores
from . import eval_scores_factory
from . import signal_processing


class TestEvalScores(unittest.TestCase):
    """Unit tests for the eval_scores module.

    Each test gets a fresh temporary output folder holding two fake
    one-second white noise tracks (reference and tested signal).
    """

    def setUp(self):
        """Create temporary output folder and two audio track files."""
        self._output_path = tempfile.mkdtemp()
        # Register the removal right away: unittest does not call tearDown()
        # when setUp() raises, so a failure while generating or saving the
        # fixtures below would otherwise leak the temporary folder.
        self.addCleanup(shutil.rmtree, self._output_path)

        # Create fake reference and tested (i.e., APM output) audio track
        # files.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_reference_signal = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
                silence))
        fake_tested_signal = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
                silence))

        # Save fake audio tracks.
        self._fake_reference_signal_filepath = os.path.join(
            self._output_path, 'fake_ref.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_reference_signal_filepath, fake_reference_signal)
        self._fake_tested_signal_filepath = os.path.join(
            self._output_path, 'fake_test.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_tested_signal_filepath, fake_tested_signal)

    def testRegisteredClasses(self):
        """Run every registered score worker and check it yields a float.

        The 'thd' and 'echo_metric' scores are skipped here.
        """
        # Evaluation score names to exclude (tested separately).
        exceptions = ['thd', 'echo_metric']

        # Preliminary check.
        self.assertTrue(os.path.exists(self._output_path))

        # Check that there is at least one registered evaluation score
        # worker.
        registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instance evaluation score workers factory with fake dependencies:
        # the POLQA tool path points to 'fake_polqa' next to this file and no
        # echo metric tool is provided.
        eval_score_workers_factory = (
            eval_scores_factory.EvaluationScoreWorkerFactory(
                polqa_tool_bin_path=os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    'fake_polqa'),
                echo_metric_tool_bin_path=None
            ))
        eval_score_workers_factory.SetScoreFilenamePrefix('scores-')

        # Try each registered evaluation score worker.
        for eval_score_name in registered_classes:
            if eval_score_name in exceptions:
                continue

            # Instance evaluation score worker.
            eval_score_worker = eval_score_workers_factory.GetInstance(
                registered_classes[eval_score_name])

            # Set reference and test file paths, then run.
            eval_score_worker.SetReferenceSignalFilepath(
                self._fake_reference_signal_filepath)
            eval_score_worker.SetTestedSignalFilepath(
                self._fake_tested_signal_filepath)
            eval_score_worker.Run(self._output_path)

            # Check output: the score file must load back as a float.
            score = data_access.ScoreFile.Load(
                eval_score_worker.output_filepath)
            self.assertIsInstance(score, float)

    def testTotalHarmonicDistorsionScore(self):
        """Check that the THD score grows as a pure tone gets more distorted.

        Three signals are scored in order (pure tone, pure tone mixed with
        white noise, white noise only) and each score must be strictly
        greater than the previous one.
        """
        # Init.
        pure_tone_freq = 5000.0
        eval_score_worker = eval_scores.TotalHarmonicDistorsionScore(
            'scores-')
        eval_score_worker.SetInputSignalMetadata({
            'signal': 'pure_tone',
            'frequency': pure_tone_freq,
            'test_data_gen_name': 'identity',
            'test_data_gen_config': 'default',
        })
        template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

        # Create 3 test signals: pure tone, pure tone + white noise, white
        # noise only.
        pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
            template, pure_tone_freq)
        white_noise = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
                template))
        noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
            pure_tone, white_noise)

        # The same temporary file is overwritten for every tested signal.
        tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')

        # Compute scores for increasingly distorted pure tone signals.
        scores = []
        for tested_signal in (pure_tone, noisy_tone, white_noise):
            # Save signal.
            signal_processing.SignalProcessingUtils.SaveWav(
                tmp_filepath, tested_signal)

            # Compute score.
            eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
            eval_score_worker.Run(self._output_path)
            scores.append(eval_score_worker.score)

            # Remove output file to avoid caching.
            os.remove(eval_score_worker.output_filepath)

        # Validate scores (lowest score with a pure tone). Separate
        # assertLess calls report the offending values on failure.
        pure_tone_score, noisy_tone_score, white_noise_score = scores
        self.assertLess(pure_tone_score, noisy_tone_score)
        self.assertLess(noisy_tone_score, white_noise_score)