• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Unit tests for the input mixer module.
10"""
11
12import logging
13import os
14import shutil
15import tempfile
16import unittest
17
18import mock
19
20from . import exceptions
21from . import input_mixer
22from . import signal_processing
23
24
class TestApmInputMixer(unittest.TestCase):
  """Unit tests for the ApmInputMixer class.
  """

  # Audio track file names created in setUp().
  _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']

  # Target peak power level (dBFS) of each audio track file created in setUp().
  # These values are hand-crafted in order to make saturation happen when
  # capture and echo_2 are mixed and the contrary for capture and echo_1.
  # None means that the power is not changed.
  _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]

  # Audio track file durations in milliseconds.
  _DURATIONS = [1000, 1000, 1000, 800, 1200]

  _SAMPLE_RATE = 48000

  def setUp(self):
    """Creates temporary data.

    Writes one pure-tone WAV file per entry of _FILENAMES into a fresh
    temporary directory and fills self._audio_tracks, which maps each track
    name to a dict with the written 'filepath' and its 'num_samples'.
    """
    self._tmp_path = tempfile.mkdtemp()

    # Create audio track files.
    self._audio_tracks = {}
    for filename, peak_power, duration in zip(
        self._FILENAMES, self._MAX_PEAK_POWER_LEVELS, self._DURATIONS):
      audio_track_filepath = os.path.join(
          self._tmp_path, '{}.wav'.format(filename))

      # Create a pure tone with the target peak power level.
      template = signal_processing.SignalProcessingUtils.GenerateSilence(
          duration=duration, sample_rate=self._SAMPLE_RATE)
      signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
          template)
      if peak_power is not None:
        # Gain (dB) that moves the current peak (max_dBFS) to peak_power.
        signal = signal.apply_gain(-signal.max_dBFS + peak_power)

      signal_processing.SignalProcessingUtils.SaveWav(
          audio_track_filepath, signal)
      self._audio_tracks[filename] = {
          'filepath': audio_track_filepath,
          'num_samples': signal_processing.SignalProcessingUtils.CountSamples(
              signal)
      }

  def tearDown(self):
    """Recursively deletes temporary folders."""
    shutil.rmtree(self._tmp_path)

  def _Mix(self, capture_name, echo_name):
    """Mixes two of the audio tracks created in setUp().

    Args:
      capture_name: key in self._audio_tracks of the capture track.
      echo_name: key in self._audio_tracks of the echo track.

    Returns:
      The file path returned by ApmInputMixer.Mix().
    """
    return input_mixer.ApmInputMixer.Mix(
        self._tmp_path,
        self._audio_tracks[capture_name]['filepath'],
        self._audio_tracks[echo_name]['filepath'])

  def _AssertMixHasCaptureLength(self, mix_filepath):
    """Asserts that the mix file has as many samples as the capture track."""
    mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
    self.assertEqual(self._audio_tracks['capture']['num_samples'],
                     signal_processing.SignalProcessingUtils.CountSamples(mix))

  def testCheckMixSameDuration(self):
    """Checks the duration when mixing capture and echo with same duration."""
    mix_filepath = self._Mix('capture', 'echo_1')
    self.assertTrue(os.path.exists(mix_filepath))
    self._AssertMixHasCaptureLength(mix_filepath)

  def testRejectShorterEcho(self):
    """Rejects echo signals that are shorter than the capture signal."""
    with self.assertRaises(exceptions.InputMixerException):
      self._Mix('capture', 'shorter')

  def testCheckMixDurationWithLongerEcho(self):
    """Checks the duration when mixing an echo longer than the capture."""
    mix_filepath = self._Mix('capture', 'longer')
    self.assertTrue(os.path.exists(mix_filepath))
    self._AssertMixHasCaptureLength(mix_filepath)

  def testCheckOutputFileNamesConflict(self):
    """Checks that different echo files lead to different output file names."""
    mix1_filepath = self._Mix('capture', 'echo_1')
    self.assertTrue(os.path.exists(mix1_filepath))

    mix2_filepath = self._Mix('capture', 'echo_2')
    self.assertTrue(os.path.exists(mix2_filepath))

    self.assertNotEqual(mix1_filepath, mix2_filepath)

  def testHardClippingLogExpected(self):
    """Checks that hard clipping warning is raised when occurring."""
    # Patch logging.warning only while mixing; assigning a MagicMock directly
    # to logging.warning would never be undone and would leak into later tests.
    with mock.patch.object(logging, 'warning') as mock_warning:
      self._Mix('capture', 'echo_2')
    mock_warning.assert_called_once_with(
        input_mixer.ApmInputMixer.HardClippingLogMessage())

  def testHardClippingLogNotExpected(self):
    """Checks that hard clipping warning is not raised when not occurring."""
    # Scoped patch; see testHardClippingLogExpected.
    with mock.patch.object(logging, 'warning') as mock_warning:
      self._Mix('capture', 'echo_1')
    self.assertNotIn(
        mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
        mock_warning.call_args_list)
145