• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS.  All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the simulation module.
"""
11
12import logging
13import os
14import shutil
15import tempfile
16import unittest
17
18import mock
19import pydub
20
21from . import audioproc_wrapper
22from . import eval_scores_factory
23from . import evaluation
24from . import external_vad
25from . import signal_processing
26from . import simulation
27from . import test_data_generation_factory
28
29
class TestApmModuleSimulator(unittest.TestCase):
  """Unit tests for the ApmModuleSimulator class.
  """

  def setUp(self):
    """Create temporary folders and fake audio track."""
    self._output_path = tempfile.mkdtemp()
    self._tmp_path = tempfile.mkdtemp()

    # Fake capture input: white noise generated from a 1000 ms, 48 kHz silent
    # segment, saved as a WAV file inside the output folder.
    silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
    fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
        silence)
    self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_audio_track_path, fake_signal)

  def tearDown(self):
    """Recursively delete temporary folders."""
    shutil.rmtree(self._output_path)
    shutil.rmtree(self._tmp_path)

  @staticmethod
  def _CreateTestDataGeneratorFactory():
    """Returns a test data generator factory with no external resources."""
    return test_data_generation_factory.TestDataGeneratorFactory(
        aechen_ir_database_path='',
        noise_tracks_path='',
        copy_with_identity=False)

  @staticmethod
  def _CreateEvaluationScoreFactory():
    """Returns an evaluation score factory that uses the fake POLQA tool."""
    return eval_scores_factory.EvaluationScoreWorkerFactory(
        polqa_tool_bin_path=os.path.join(
            os.path.dirname(__file__), 'fake_polqa'),
        echo_metric_tool_bin_path=None)

  @classmethod
  def _CreateSimulator(cls, ap_wrapper=None, evaluator=None, **kwargs):
    """Returns an ApmModuleSimulator wired with the test factories.

    Args:
      ap_wrapper: AudioProcWrapper instance to inject; when None, one using
        AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH is created.
      evaluator: ApmModuleEvaluator instance to inject; when None, a new one
        is created.
      **kwargs: extra keyword arguments (e.g., external_vads) forwarded as-is
        to the ApmModuleSimulator constructor.
    """
    if ap_wrapper is None:
      ap_wrapper = audioproc_wrapper.AudioProcWrapper(
          audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    if evaluator is None:
      evaluator = evaluation.ApmModuleEvaluator()
    return simulation.ApmModuleSimulator(
        test_data_generator_factory=cls._CreateTestDataGeneratorFactory(),
        evaluation_score_factory=cls._CreateEvaluationScoreFactory(),
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        **kwargs)

  def testSimulation(self):
    """Checks that Run() calls the APM wrapper and the evaluator.

    Both mocked Run methods must be called at least once per
    (config, input, test data generator) combination.
    """
    # Instance dependencies to mock and inject.
    ap_wrapper = audioproc_wrapper.AudioProcWrapper(
        audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
    evaluator = evaluation.ApmModuleEvaluator()
    ap_wrapper.Run = mock.MagicMock(name='Run')
    evaluator.Run = mock.MagicMock(name='Run')

    # Instance simulator (with a fake external VAD).
    simulator = self._CreateSimulator(
        ap_wrapper=ap_wrapper,
        evaluator=evaluator,
        external_vads={'fake': external_vad.ExternalVad(os.path.join(
            os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')})

    # What to simulate.
    config_files = ['apm_configs/default.json']
    input_files = [self._fake_audio_track_path]
    test_data_generators = ['identity', 'white_noise']
    eval_scores = ['audio_level_mean', 'polqa']

    # Run all simulations.
    simulator.Run(
        config_filepaths=config_files,
        capture_input_filepaths=input_files,
        test_data_generator_names=test_data_generators,
        eval_score_names=eval_scores,
        output_dir=self._output_path)

    # Check.
    # TODO(alessiob): Once the TestDataGenerator classes can be configured by
    # the client code (e.g., number of SNR pairs for the white noise test data
    # generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
    # is known; use that with assertEqual.
    min_number_of_simulations = len(config_files) * len(input_files) * len(
        test_data_generators)
    self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
                            min_number_of_simulations)
    self.assertGreaterEqual(len(evaluator.Run.call_args_list),
                            min_number_of_simulations)

  def testInputSignalCreation(self):
    """Checks that missing pure-tone input files are created by Run()."""
    simulator = self._CreateSimulator()

    # Nonexistent input files to be silently created.
    input_files = [
        os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
        os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
    ]
    self.assertFalse(any(os.path.exists(input_file)
                         for input_file in input_files))

    # The input files are created during the simulation.
    simulator.Run(
        config_filepaths=['apm_configs/default.json'],
        capture_input_filepaths=input_files,
        test_data_generator_names=['identity'],
        eval_score_names=['audio_level_peak'],
        output_dir=self._output_path)
    self.assertTrue(all(os.path.exists(input_file)
                        for input_file in input_files))

  def testPureToneGenerationWithTotalHarmonicDistorsion(self):
    """Checks that the THD score is accepted only with "identity" test data."""
    # Patch logging.warning only for the duration of this test. Assigning a
    # mock directly to logging.warning would never be undone and would leak
    # into every test that runs afterwards in the same process.
    with mock.patch.object(logging, 'warning') as mock_warning:
      simulator = self._CreateSimulator()

      # What to simulate.
      config_files = ['apm_configs/default.json']
      input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
      eval_scores = ['thd']

      # Should work.
      simulator.Run(
          config_filepaths=config_files,
          capture_input_filepaths=input_files,
          test_data_generator_names=['identity'],
          eval_score_names=eval_scores,
          output_dir=self._output_path)
      self.assertFalse(mock_warning.called)

      # Warning expected.
      simulator.Run(
          config_filepaths=config_files,
          capture_input_filepaths=input_files,
          test_data_generator_names=['white_noise'],  # Not allowed with THD.
          eval_score_names=eval_scores,
          output_dir=self._output_path)
      mock_warning.assert_called_with('the evaluation failed: %s', (
          'The THD score cannot be used with any test data generator other '
          'than "identity"'))
205