# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the input mixer module.
"""

import logging
import os
import shutil
import tempfile
import unittest

import mock

from . import exceptions
from . import input_mixer
from . import signal_processing


class TestApmInputMixer(unittest.TestCase):
    """Unit tests for the ApmInputMixer class.
    """

    # Audio track file names created in setUp().
    _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']

    # Target peak power level (dBFS) of each audio track file created in
    # setUp(). These values are hand-crafted so that mixing capture with
    # echo_2 saturates, whereas mixing capture with echo_1 does not.
    # None means that the power is not changed.
    _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]

    # Audio track file durations in milliseconds.
    _DURATIONS = [1000, 1000, 1000, 800, 1200]

    _SAMPLE_RATE = 48000

    def setUp(self):
        """Creates a temporary folder with one pure-tone WAV file per track.

        For each name in _FILENAMES a pure tone of the matching duration is
        written to <tmp>/<name>.wav; its path and sample count are stored in
        self._audio_tracks[name] under the 'filepath' and 'num_samples' keys.
        """
        self._tmp_path = tempfile.mkdtemp()

        # Create audio track files.
        self._audio_tracks = {}
        for filename, peak_power, duration in zip(self._FILENAMES,
                                                  self._MAX_PEAK_POWER_LEVELS,
                                                  self._DURATIONS):
            audio_track_filepath = os.path.join(self._tmp_path,
                                                '{}.wav'.format(filename))

            # Create a pure tone with the target peak power level.
            template = signal_processing.SignalProcessingUtils.GenerateSilence(
                duration=duration, sample_rate=self._SAMPLE_RATE)
            signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
                template)
            if peak_power is not None:
                # Shift the gain so that the signal peak lands on peak_power.
                signal = signal.apply_gain(-signal.max_dBFS + peak_power)

            signal_processing.SignalProcessingUtils.SaveWav(
                audio_track_filepath, signal)
            self._audio_tracks[filename] = {
                'filepath':
                    audio_track_filepath,
                'num_samples':
                    signal_processing.SignalProcessingUtils.CountSamples(signal)
            }

    def tearDown(self):
        """Recursively deletes temporary folders."""
        shutil.rmtree(self._tmp_path)

    def testCheckMixSameDuration(self):
        """Checks the duration when mixing capture and echo with same duration."""
        mix_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_1']['filepath'])
        self.assertTrue(os.path.exists(mix_filepath))

        mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        self.assertEqual(
            self._audio_tracks['capture']['num_samples'],
            signal_processing.SignalProcessingUtils.CountSamples(mix))

    def testRejectShorterEcho(self):
        """Rejects echo signals that are shorter than the capture signal."""
        with self.assertRaises(exceptions.InputMixerException):
            input_mixer.ApmInputMixer.Mix(
                self._tmp_path, self._audio_tracks['capture']['filepath'],
                self._audio_tracks['shorter']['filepath'])

    def testCheckMixDurationWithLongerEcho(self):
        """Checks the duration when mixing an echo longer than the capture."""
        mix_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['longer']['filepath'])
        self.assertTrue(os.path.exists(mix_filepath))

        # The mix is expected to be as long as the capture, not the echo.
        mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        self.assertEqual(
            self._audio_tracks['capture']['num_samples'],
            signal_processing.SignalProcessingUtils.CountSamples(mix))

    def testCheckOutputFileNamesConflict(self):
        """Checks that different echo files lead to different output file names."""
        mix1_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_1']['filepath'])
        self.assertTrue(os.path.exists(mix1_filepath))

        mix2_filepath = input_mixer.ApmInputMixer.Mix(
            self._tmp_path, self._audio_tracks['capture']['filepath'],
            self._audio_tracks['echo_2']['filepath'])
        self.assertTrue(os.path.exists(mix2_filepath))

        self.assertNotEqual(mix1_filepath, mix2_filepath)

    def testHardClippingLogExpected(self):
        """Checks that hard clipping warning is raised when occurring."""
        # Patch logging.warning only for the duration of the call; the patch
        # is undone on exit so the mock does not leak into other tests.
        with mock.patch.object(logging, 'warning') as mock_warning:
            input_mixer.ApmInputMixer.Mix(
                self._tmp_path, self._audio_tracks['capture']['filepath'],
                self._audio_tracks['echo_2']['filepath'])
        mock_warning.assert_called_once_with(
            input_mixer.ApmInputMixer.HardClippingLogMessage())

    def testHardClippingLogNotExpected(self):
        """Checks that hard clipping warning is not raised when not occurring."""
        # Scoped patch: logging.warning is restored when the block exits.
        with mock.patch.object(logging, 'warning') as mock_warning:
            input_mixer.ApmInputMixer.Mix(
                self._tmp_path, self._audio_tracks['capture']['filepath'],
                self._audio_tracks['echo_1']['filepath'])
        self.assertNotIn(
            mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
            mock_warning.call_args_list)