• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8"""Unit tests for the eval_scores module.
9"""
10
11import os
12import shutil
13import tempfile
14import unittest
15
16import pydub
17
18from . import data_access
19from . import eval_scores
20from . import eval_scores_factory
21from . import signal_processing
22
23
class TestEvalScores(unittest.TestCase):
    """Unit tests for the eval_scores module."""

    def setUp(self):
        """Create a temporary output folder and two fake audio track files.

        The reference and tested tracks are independently generated white
        noise signals, each built from a 1000 ms, 48 kHz silent template, and
        saved as WAV files inside the temporary folder.
        """
        self._output_path = tempfile.mkdtemp()
        # unittest does not call tearDown() when setUp() raises, so register a
        # cleanup right away to avoid leaking the folder if creating the fake
        # tracks fails. ignore_errors=True turns it into a no-op once
        # tearDown() has already removed the folder.
        self.addCleanup(shutil.rmtree, self._output_path, ignore_errors=True)

        # Create fake reference and tested (i.e., APM output) audio track files.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_reference_signal = (signal_processing.SignalProcessingUtils.
                                 GenerateWhiteNoise(silence))
        fake_tested_signal = (signal_processing.SignalProcessingUtils.
                              GenerateWhiteNoise(silence))

        # Save fake audio tracks.
        self._fake_reference_signal_filepath = os.path.join(
            self._output_path, 'fake_ref.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_reference_signal_filepath, fake_reference_signal)
        self._fake_tested_signal_filepath = os.path.join(
            self._output_path, 'fake_test.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_tested_signal_filepath, fake_tested_signal)

    def tearDown(self):
        """Recursively delete the temporary output folder."""
        shutil.rmtree(self._output_path)

    def testRegisteredClasses(self):
        """Run each registered evaluation score worker on the fake tracks.

        Workers whose names are listed in `exceptions` are skipped (tested
        separately). For every other worker, the score file it writes must
        load back as a float.
        """
        # Evaluation score names to exclude (tested separately).
        exceptions = ['thd', 'echo_metric']

        # Preliminary check.
        self.assertTrue(os.path.exists(self._output_path))

        # Check that there is at least one registered evaluation score worker.
        registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instantiate the evaluation score workers factory with fake
        # dependencies: a 'fake_polqa' binary located next to this file and no
        # echo metric tool (presumably unnecessary because 'echo_metric' is
        # excluded above).
        eval_score_workers_factory = (
            eval_scores_factory.EvaluationScoreWorkerFactory(
                polqa_tool_bin_path=os.path.join(
                    os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
                echo_metric_tool_bin_path=None))
        eval_score_workers_factory.SetScoreFilenamePrefix('scores-')

        # Try each registered evaluation score worker.
        for eval_score_name, eval_score_class in registered_classes.items():
            if eval_score_name in exceptions:
                continue

            # Instantiate the evaluation score worker.
            eval_score_worker = eval_score_workers_factory.GetInstance(
                eval_score_class)

            # Set reference and tested signal file paths, then run.
            eval_score_worker.SetReferenceSignalFilepath(
                self._fake_reference_signal_filepath)
            eval_score_worker.SetTestedSignalFilepath(
                self._fake_tested_signal_filepath)
            eval_score_worker.Run(self._output_path)

            # Check output; name the worker so a failure is easy to locate.
            score = data_access.ScoreFile.Load(
                eval_score_worker.output_filepath)
            self.assertIsInstance(
                score, float,
                msg='score of "{}" is not a float'.format(eval_score_name))

    def testTotalHarmonicDistorsionScore(self):
        """Check that the THD score strictly increases with distortion.

        A pure tone, the same tone mixed with white noise, and white noise
        alone are scored in that order; each score must be strictly greater
        than the previous one (i.e., the pure tone gets the lowest score).
        """
        # Init.
        pure_tone_freq = 5000.0
        eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
        eval_score_worker.SetInputSignalMetadata({
            'signal': 'pure_tone',
            'frequency': pure_tone_freq,
            'test_data_gen_name': 'identity',
            'test_data_gen_config': 'default',
        })
        template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

        # Create 3 test signals: pure tone, pure tone + white noise, white noise
        # only.
        pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
            template, pure_tone_freq)
        white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            template)
        noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
            pure_tone, white_noise)

        # Signals ordered by increasing distortion; labels are used only in
        # assertion messages.
        tested_signals = [
            ('pure tone', pure_tone),
            ('noisy tone', noisy_tone),
            ('white noise', white_noise),
        ]

        # Compute scores for increasingly distorted pure tone signals.
        tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
        labeled_scores = []
        for label, tested_signal in tested_signals:
            # Save signal.
            signal_processing.SignalProcessingUtils.SaveWav(
                tmp_filepath, tested_signal)

            # Compute score.
            eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
            eval_score_worker.Run(self._output_path)
            labeled_scores.append((label, eval_score_worker.score))

            # Remove output file to avoid caching.
            os.remove(eval_score_worker.output_filepath)

        # Validate scores (lowest score with a pure tone): compare each pair of
        # consecutive signals so a failure reports which ordering was violated.
        for (lower_label, lower_score), (higher_label, higher_score) in zip(
                labeled_scores, labeled_scores[1:]):
            self.assertLess(
                lower_score, higher_score,
                msg='expected THD score of {} ({}) < THD score of {} ({})'.
                format(lower_label, lower_score, higher_label, higher_score))
138