• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8"""Unit tests for the input mixer module.
9"""
10
11import logging
12import os
13import shutil
14import tempfile
15import unittest
16
17import mock
18
19from . import exceptions
20from . import input_mixer
21from . import signal_processing
22
23
class TestApmInputMixer(unittest.TestCase):
    """Unit tests for the ApmInputMixer class."""

    # Audio track file names created in setUp().
    _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']

    # Target peak power level (dBFS) of each audio track file created in setUp().
    # These values are hand-crafted in order to make saturation happen when
    # capture and echo_2 are mixed and the contrary for capture and echo_1.
    # None means that the power is not changed.
    _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]

    # Audio track file durations in milliseconds.
    _DURATIONS = [1000, 1000, 1000, 800, 1200]

    _SAMPLE_RATE = 48000

    def setUp(self):
        """Creates temporary data.

        Writes one pure-tone WAV file per entry of _FILENAMES into a fresh
        temporary directory and records, in self._audio_tracks, the file path
        and sample count of each track keyed by its name.
        """
        self._tmp_path = tempfile.mkdtemp()

        # Create audio track files.
        self._audio_tracks = {}
        for filename, peak_power, duration in zip(self._FILENAMES,
                                                  self._MAX_PEAK_POWER_LEVELS,
                                                  self._DURATIONS):
            audio_track_filepath = os.path.join(self._tmp_path,
                                                '{}.wav'.format(filename))

            # Create a pure tone with the target peak power level.
            template = signal_processing.SignalProcessingUtils.GenerateSilence(
                duration=duration, sample_rate=self._SAMPLE_RATE)
            signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
                template)
            if peak_power is not None:
                # Gain offset chosen to move the current peak (max_dBFS) onto
                # the requested peak_power level.
                signal = signal.apply_gain(-signal.max_dBFS + peak_power)

            signal_processing.SignalProcessingUtils.SaveWav(
                audio_track_filepath, signal)
            self._audio_tracks[filename] = {
                'filepath':
                audio_track_filepath,
                'num_samples':
                signal_processing.SignalProcessingUtils.CountSamples(signal),
            }

    def tearDown(self):
        """Recursively deletes temporary folders."""
        shutil.rmtree(self._tmp_path)

    def _MixCaptureWith(self, echo_track_name):
        """Mixes the 'capture' track with the named echo track.

        Args:
          echo_track_name: key of self._audio_tracks used as the echo input.

        Returns:
          The value returned by ApmInputMixer.Mix (used by the tests as the
          path of the mixed file).
        """
        return input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks[echo_track_name]['filepath'])

    def _AssertMixMatchesCaptureLength(self, mix_filepath):
        """Asserts that the mix file exists and has as many samples as capture."""
        self.assertTrue(os.path.exists(mix_filepath))
        mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        self.assertEqual(
            self._audio_tracks['capture']['num_samples'],
            signal_processing.SignalProcessingUtils.CountSamples(mix))

    def testCheckMixSameDuration(self):
        """Checks the duration when mixing capture and echo with same duration."""
        self._AssertMixMatchesCaptureLength(self._MixCaptureWith('echo_1'))

    def testRejectShorterEcho(self):
        """Rejects echo signals that are shorter than the capture signal."""
        with self.assertRaises(exceptions.InputMixerException):
            self._MixCaptureWith('shorter')

    def testCheckMixDurationWithLongerEcho(self):
        """Checks the duration when mixing an echo longer than the capture."""
        self._AssertMixMatchesCaptureLength(self._MixCaptureWith('longer'))

    def testCheckOutputFileNamesConflict(self):
        """Checks that different echo files lead to different output file names."""
        mix1_filepath = self._MixCaptureWith('echo_1')
        self.assertTrue(os.path.exists(mix1_filepath))

        mix2_filepath = self._MixCaptureWith('echo_2')
        self.assertTrue(os.path.exists(mix2_filepath))

        self.assertNotEqual(mix1_filepath, mix2_filepath)

    def testHardClippingLogExpected(self):
        """Checks that hard clipping warning is raised when occurring."""
        # Patch (instead of assigning to) logging.warning so the real function
        # is restored on exit and the mock does not leak into other tests.
        with mock.patch.object(logging, 'warning') as mock_warning:
            self._MixCaptureWith('echo_2')
        mock_warning.assert_called_once_with(
            input_mixer.ApmInputMixer.HardClippingLogMessage())

    def testHardClippingLogNotExpected(self):
        """Checks that hard clipping warning is not raised when not occurring."""
        # Patch (instead of assigning to) logging.warning so the real function
        # is restored on exit and the mock does not leak into other tests.
        with mock.patch.object(logging, 'warning') as mock_warning:
            self._MixCaptureWith('echo_1')
        self.assertNotIn(
            mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
            mock_warning.call_args_list)