• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
2#
3# Use of this source code is governed by a BSD-style license
4# that can be found in the LICENSE file in the root of the source
5# tree. An additional intellectual property rights grant can be found
6# in the file PATENTS.  All contributing project authors may
7# be found in the AUTHORS file in the root of the source tree.
8
9"""Unit tests for the eval_scores module.
10"""
11
12import os
13import shutil
14import tempfile
15import unittest
16
17import pydub
18
19from . import data_access
20from . import eval_scores
21from . import eval_scores_factory
22from . import signal_processing
23
24
class TestEvalScores(unittest.TestCase):
  """Unit tests for the eval_scores module.

  Every test runs against a fresh temporary folder that holds a fake reference
  track and a fake tested track, both white noise generated from a silent
  template.
  """

  def setUp(self):
    """Create temporary output folder and two audio track files."""
    self._output_path = tempfile.mkdtemp()
    # Register the removal immediately: cleanup functions run even when the
    # remainder of setUp() raises, whereas tearDown() is skipped in that case
    # and the temporary folder would be leaked.
    self.addCleanup(shutil.rmtree, self._output_path)

    # Create fake reference and tested (i.e., APM output) audio track files.
    silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
    fake_reference_signal = (
        signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
    fake_tested_signal = (
        signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))

    # Save fake audio tracks.
    self._fake_reference_signal_filepath = os.path.join(
        self._output_path, 'fake_ref.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_reference_signal_filepath, fake_reference_signal)
    self._fake_tested_signal_filepath = os.path.join(
        self._output_path, 'fake_test.wav')
    signal_processing.SignalProcessingUtils.SaveWav(
        self._fake_tested_signal_filepath, fake_tested_signal)

  def testRegisteredClasses(self):
    """Run each registered score worker and check it produces a float score.

    Workers listed in `exceptions` are skipped here (tested separately).
    """
    # Evaluation score names to exclude (tested separately).
    exceptions = ['thd', 'echo_metric']

    # Preliminary check.
    self.assertTrue(os.path.exists(self._output_path))

    # Check that there is at least one registered evaluation score worker.
    registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
    self.assertIsInstance(registered_classes, dict)
    self.assertGreater(len(registered_classes), 0)

    # Instance evaluation score workers factory with fake dependencies.
    eval_score_workers_factory = (
        eval_scores_factory.EvaluationScoreWorkerFactory(
            polqa_tool_bin_path=os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
            echo_metric_tool_bin_path=None
        ))
    eval_score_workers_factory.SetScoreFilenamePrefix('scores-')

    # Try each registered evaluation score worker.
    for eval_score_name, eval_score_class in registered_classes.items():
      if eval_score_name in exceptions:
        continue

      # Instance evaluation score worker.
      eval_score_worker = eval_score_workers_factory.GetInstance(
          eval_score_class)

      # Set reference and test file paths, then run.
      eval_score_worker.SetReferenceSignalFilepath(
          self._fake_reference_signal_filepath)
      eval_score_worker.SetTestedSignalFilepath(
          self._fake_tested_signal_filepath)
      eval_score_worker.Run(self._output_path)

      # Check output: the score written by the worker must load as a float.
      score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
      self.assertIsInstance(score, float)

  def testTotalHarmonicDistorsionScore(self):
    """Check that the THD score grows as a pure tone gets more distorted."""
    # Init.
    pure_tone_freq = 5000.0
    eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
    eval_score_worker.SetInputSignalMetadata({
        'signal': 'pure_tone',
        'frequency': pure_tone_freq,
        'test_data_gen_name': 'identity',
        'test_data_gen_config': 'default',
    })
    template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

    # Create 3 test signals: pure tone, pure tone + white noise, white noise
    # only.
    pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
        template, pure_tone_freq)
    white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
        template)
    noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
        pure_tone, white_noise)

    # The same temporary file is overwritten for every tested signal.
    tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')

    # Compute scores for increasingly distorted pure tone signals.
    scores = []
    for tested_signal in (pure_tone, noisy_tone, white_noise):
      # Save signal.
      signal_processing.SignalProcessingUtils.SaveWav(
          tmp_filepath, tested_signal)

      # Compute score.
      eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
      eval_score_worker.Run(self._output_path)
      scores.append(eval_score_worker.score)

      # Remove output file to avoid caching.
      os.remove(eval_score_worker.output_filepath)

    # Validate scores (lowest score with a pure tone). Each pair is asserted
    # separately so that a failure reports the two offending values.
    pure_tone_score, noisy_tone_score, white_noise_score = scores
    self.assertGreater(noisy_tone_score, pure_tone_score)
    self.assertGreater(white_noise_score, noisy_tone_score)
134