# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

"""Unit tests for the input mixer module.
"""

import logging
import os
import shutil
import tempfile
import unittest

import mock

from . import exceptions
from . import input_mixer
from . import signal_processing


class TestApmInputMixer(unittest.TestCase):
    """Unit tests for the ApmInputMixer class.

    setUp() writes one pure-tone wav file per entry of _FILENAMES into a
    temporary folder; each test then mixes the 'capture' track with one of
    the other tracks through input_mixer.ApmInputMixer.Mix().
    """

    # Audio track file names created in setUp().
    _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']

    # Target peak power level (dBFS) of each audio track file created in
    # setUp(). These values are hand-crafted in order to make saturation
    # happen when capture and echo_2 are mixed and the contrary for capture
    # and echo_1. None means that the power is not changed.
    _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]

    # Audio track file durations in milliseconds.
    _DURATIONS = [1000, 1000, 1000, 800, 1200]

    _SAMPLE_RATE = 48000

    def setUp(self):
        """Creates temporary data."""
        self._tmp_path = tempfile.mkdtemp()
        # Register the removal right away: cleanups run even when the rest of
        # setUp() raises, whereas tearDown() would be skipped in that case.
        self.addCleanup(shutil.rmtree, self._tmp_path)

        utils = signal_processing.SignalProcessingUtils

        # Create audio track files.
        self._audio_tracks = {}
        for filename, peak_power, duration in zip(
                self._FILENAMES, self._MAX_PEAK_POWER_LEVELS,
                self._DURATIONS):
            audio_track_filepath = os.path.join(
                self._tmp_path, '{}.wav'.format(filename))

            # Create a pure tone with the target peak power level.
            template = utils.GenerateSilence(
                duration=duration, sample_rate=self._SAMPLE_RATE)
            signal = utils.GeneratePureTone(template)
            if peak_power is not None:
                # Shift the gain so that the peak lands on peak_power.
                signal = signal.apply_gain(-signal.max_dBFS + peak_power)

            utils.SaveWav(audio_track_filepath, signal)
            self._audio_tracks[filename] = {
                'filepath': audio_track_filepath,
                'num_samples': utils.CountSamples(signal),
            }

    def _MixCaptureWith(self, echo_name):
        """Mixes the 'capture' track with the track named echo_name.

        Args:
          echo_name: key of self._audio_tracks identifying the echo track.

        Returns:
          The value returned by input_mixer.ApmInputMixer.Mix(), which the
          tests use as the path to the mix file.
        """
        return input_mixer.ApmInputMixer.Mix(
            self._tmp_path,
            self._audio_tracks['capture']['filepath'],
            self._audio_tracks[echo_name]['filepath'])

    def testCheckMixSameDuration(self):
        """Checks the duration when mixing capture and echo with same duration.
        """
        mix_filepath = self._MixCaptureWith('echo_1')
        self.assertTrue(os.path.exists(mix_filepath))

        mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        self.assertEqual(
            self._audio_tracks['capture']['num_samples'],
            signal_processing.SignalProcessingUtils.CountSamples(mix))

    def testRejectShorterEcho(self):
        """Rejects echo signals that are shorter than the capture signal."""
        with self.assertRaises(exceptions.InputMixerException):
            self._MixCaptureWith('shorter')

    def testCheckMixDurationWithLongerEcho(self):
        """Checks the duration when mixing an echo longer than the capture."""
        mix_filepath = self._MixCaptureWith('longer')
        self.assertTrue(os.path.exists(mix_filepath))

        mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
        self.assertEqual(
            self._audio_tracks['capture']['num_samples'],
            signal_processing.SignalProcessingUtils.CountSamples(mix))

    def testCheckOutputFileNamesConflict(self):
        """Checks that different echo files lead to different output file names.
        """
        mix1_filepath = self._MixCaptureWith('echo_1')
        self.assertTrue(os.path.exists(mix1_filepath))

        mix2_filepath = self._MixCaptureWith('echo_2')
        self.assertTrue(os.path.exists(mix2_filepath))

        self.assertNotEqual(mix1_filepath, mix2_filepath)

    def testHardClippingLogExpected(self):
        """Checks that hard clipping warning is raised when occurring."""
        # Patch through a context manager so that logging.warning is restored
        # on exit and the mock does not leak into other tests.
        with mock.patch.object(logging, 'warning') as warning_mock:
            self._MixCaptureWith('echo_2')
        warning_mock.assert_called_once_with(
            input_mixer.ApmInputMixer.HardClippingLogMessage())

    def testHardClippingLogNotExpected(self):
        """Checks that hard clipping warning is not raised when not occurring.
        """
        # Patch through a context manager so that logging.warning is restored
        # on exit and the mock does not leak into other tests.
        with mock.patch.object(logging, 'warning') as warning_mock:
            self._MixCaptureWith('echo_1')
        self.assertNotIn(
            mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
            warning_mock.call_args_list)