1 /*
2 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/agc2/signal_classifier.h"
12
13 #include <algorithm>
14 #include <numeric>
15 #include <vector>
16
17 #include "api/array_view.h"
18 #include "modules/audio_processing/agc2/down_sampler.h"
19 #include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
20 #include "modules/audio_processing/logging/apm_data_dumper.h"
21 #include "rtc_base/checks.h"
22 #include "system_wrappers/include/cpu_features_wrapper.h"
23
24 namespace webrtc {
25 namespace {
26
// Reports whether the SSE2 code path may be used. Builds that do not target
// the x86 family always answer `false` at compile time; x86 builds ask the
// runtime CPU feature query.
bool IsSse2Available() {
#if !defined(WEBRTC_ARCH_X86_FAMILY)
  return false;
#else
  const bool cpu_has_sse2 = WebRtc_GetCPUInfo(kSSE2) != 0;
  return cpu_has_sse2;
#endif
}
34
RemoveDcLevel(rtc::ArrayView<float> x)35 void RemoveDcLevel(rtc::ArrayView<float> x) {
36 RTC_DCHECK_LT(0, x.size());
37 float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);
38 mean /= x.size();
39
40 for (float& v : x) {
41 v -= mean;
42 }
43 }
44
PowerSpectrum(const OouraFft * ooura_fft,rtc::ArrayView<const float> x,rtc::ArrayView<float> spectrum)45 void PowerSpectrum(const OouraFft* ooura_fft,
46 rtc::ArrayView<const float> x,
47 rtc::ArrayView<float> spectrum) {
48 RTC_DCHECK_EQ(65, spectrum.size());
49 RTC_DCHECK_EQ(128, x.size());
50 float X[128];
51 std::copy(x.data(), x.data() + x.size(), X);
52 ooura_fft->Fft(X);
53
54 float* X_p = X;
55 RTC_DCHECK_EQ(X_p, &X[0]);
56 spectrum[0] = (*X_p) * (*X_p);
57 ++X_p;
58 RTC_DCHECK_EQ(X_p, &X[1]);
59 spectrum[64] = (*X_p) * (*X_p);
60 for (int k = 1; k < 64; ++k) {
61 ++X_p;
62 RTC_DCHECK_EQ(X_p, &X[2 * k]);
63 spectrum[k] = (*X_p) * (*X_p);
64 ++X_p;
65 RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
66 spectrum[k] += (*X_p) * (*X_p);
67 }
68 }
69
ClassifySignal(rtc::ArrayView<const float> signal_spectrum,rtc::ArrayView<const float> noise_spectrum,ApmDataDumper * data_dumper)70 webrtc::SignalClassifier::SignalType ClassifySignal(
71 rtc::ArrayView<const float> signal_spectrum,
72 rtc::ArrayView<const float> noise_spectrum,
73 ApmDataDumper* data_dumper) {
74 int num_stationary_bands = 0;
75 int num_highly_nonstationary_bands = 0;
76
77 // Detect stationary and highly nonstationary bands.
78 for (size_t k = 1; k < 40; k++) {
79 if (signal_spectrum[k] < 3 * noise_spectrum[k] &&
80 signal_spectrum[k] * 3 > noise_spectrum[k]) {
81 ++num_stationary_bands;
82 } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {
83 ++num_highly_nonstationary_bands;
84 }
85 }
86
87 data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
88 data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
89 &num_highly_nonstationary_bands);
90
91 // Use the detected number of bands to classify the overall signal
92 // stationarity.
93 if (num_stationary_bands > 15) {
94 return SignalClassifier::SignalType::kStationary;
95 } else {
96 return SignalClassifier::SignalType::kNonStationary;
97 }
98 }
99
100 } // namespace
101
FrameExtender(size_t frame_size,size_t extended_frame_size)102 SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
103 size_t extended_frame_size)
104 : x_old_(extended_frame_size - frame_size, 0.f) {}
105
106 SignalClassifier::FrameExtender::~FrameExtender() = default;
107
ExtendFrame(rtc::ArrayView<const float> x,rtc::ArrayView<float> x_extended)108 void SignalClassifier::FrameExtender::ExtendFrame(
109 rtc::ArrayView<const float> x,
110 rtc::ArrayView<float> x_extended) {
111 RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
112 std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());
113 std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());
114 std::copy(x_extended.data() + x_extended.size() - x_old_.size(),
115 x_extended.data() + x_extended.size(), x_old_.data());
116 }
117
SignalClassifier(ApmDataDumper * data_dumper)118 SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
119 : data_dumper_(data_dumper),
120 down_sampler_(data_dumper_),
121 noise_spectrum_estimator_(data_dumper_),
122 ooura_fft_(IsSse2Available()) {
123 Initialize(48000);
124 }
~SignalClassifier()125 SignalClassifier::~SignalClassifier() {}
126
Initialize(int sample_rate_hz)127 void SignalClassifier::Initialize(int sample_rate_hz) {
128 down_sampler_.Initialize(sample_rate_hz);
129 noise_spectrum_estimator_.Initialize();
130 frame_extender_.reset(new FrameExtender(80, 128));
131 sample_rate_hz_ = sample_rate_hz;
132 initialization_frames_left_ = 2;
133 consistent_classification_counter_ = 3;
134 last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
135 }
136
Analyze(rtc::ArrayView<const float> signal)137 SignalClassifier::SignalType SignalClassifier::Analyze(
138 rtc::ArrayView<const float> signal) {
139 RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100);
140
141 // Compute the signal power spectrum.
142 float downsampled_frame[80];
143 down_sampler_.DownSample(signal, downsampled_frame);
144 float extended_frame[128];
145 frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
146 RemoveDcLevel(extended_frame);
147 float signal_spectrum[65];
148 PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
149
150 // Classify the signal based on the estimate of the noise spectrum and the
151 // signal spectrum estimate.
152 const SignalType signal_type = ClassifySignal(
153 signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(),
154 data_dumper_);
155
156 // Update the noise spectrum based on the signal spectrum.
157 noise_spectrum_estimator_.Update(signal_spectrum,
158 initialization_frames_left_ > 0);
159
160 // Update the number of frames until a reliable signal spectrum is achieved.
161 initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
162
163 if (last_signal_type_ == signal_type) {
164 consistent_classification_counter_ =
165 std::max(0, consistent_classification_counter_ - 1);
166 } else {
167 last_signal_type_ = signal_type;
168 consistent_classification_counter_ = 3;
169 }
170
171 if (consistent_classification_counter_ > 0) {
172 return SignalClassifier::SignalType::kNonStationary;
173 }
174 return signal_type;
175 }
176
177 } // namespace webrtc
178