1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10 #include "modules/audio_processing/aec3/matched_filter.h"
11
12 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
13 #include "rtc_base/system/arch.h"
14
15 #if defined(WEBRTC_HAS_NEON)
16 #include <arm_neon.h>
17 #endif
18 #if defined(WEBRTC_ARCH_X86_FAMILY)
19 #include <emmintrin.h>
20 #endif
21 #include <algorithm>
22 #include <cstddef>
23 #include <initializer_list>
24 #include <iterator>
25 #include <numeric>
26
27 #include "modules/audio_processing/aec3/downsampled_render_buffer.h"
28 #include "modules/audio_processing/logging/apm_data_dumper.h"
29 #include "rtc_base/checks.h"
30 #include "rtc_base/logging.h"
31
32 namespace webrtc {
33 namespace aec3 {
34
35 #if defined(WEBRTC_HAS_NEON)
36
MatchedFilterCore_NEON(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)37 void MatchedFilterCore_NEON(size_t x_start_index,
38 float x2_sum_threshold,
39 float smoothing,
40 rtc::ArrayView<const float> x,
41 rtc::ArrayView<const float> y,
42 rtc::ArrayView<float> h,
43 bool* filters_updated,
44 float* error_sum) {
45 const int h_size = static_cast<int>(h.size());
46 const int x_size = static_cast<int>(x.size());
47 RTC_DCHECK_EQ(0, h_size % 4);
48
49 // Process for all samples in the sub-block.
50 for (size_t i = 0; i < y.size(); ++i) {
51 // Apply the matched filter as filter * x, and compute x * x.
52
53 RTC_DCHECK_GT(x_size, x_start_index);
54 const float* x_p = &x[x_start_index];
55 const float* h_p = &h[0];
56
57 // Initialize values for the accumulation.
58 float32x4_t s_128 = vdupq_n_f32(0);
59 float32x4_t x2_sum_128 = vdupq_n_f32(0);
60 float x2_sum = 0.f;
61 float s = 0;
62
63 // Compute loop chunk sizes until, and after, the wraparound of the circular
64 // buffer for x.
65 const int chunk1 =
66 std::min(h_size, static_cast<int>(x_size - x_start_index));
67
68 // Perform the loop in two chunks.
69 const int chunk2 = h_size - chunk1;
70 for (int limit : {chunk1, chunk2}) {
71 // Perform 128 bit vector operations.
72 const int limit_by_4 = limit >> 2;
73 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
74 // Load the data into 128 bit vectors.
75 const float32x4_t x_k = vld1q_f32(x_p);
76 const float32x4_t h_k = vld1q_f32(h_p);
77 // Compute and accumulate x * x and h * x.
78 x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
79 s_128 = vmlaq_f32(s_128, h_k, x_k);
80 }
81
82 // Perform non-vector operations for any remaining items.
83 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
84 const float x_k = *x_p;
85 x2_sum += x_k * x_k;
86 s += *h_p * x_k;
87 }
88
89 x_p = &x[0];
90 }
91
92 // Combine the accumulated vector and scalar values.
93 float* v = reinterpret_cast<float*>(&x2_sum_128);
94 x2_sum += v[0] + v[1] + v[2] + v[3];
95 v = reinterpret_cast<float*>(&s_128);
96 s += v[0] + v[1] + v[2] + v[3];
97
98 // Compute the matched filter error.
99 float e = y[i] - s;
100 const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
101 (*error_sum) += e * e;
102
103 // Update the matched filter estimate in an NLMS manner.
104 if (x2_sum > x2_sum_threshold && !saturation) {
105 RTC_DCHECK_LT(0.f, x2_sum);
106 const float alpha = smoothing * e / x2_sum;
107 const float32x4_t alpha_128 = vmovq_n_f32(alpha);
108
109 // filter = filter + smoothing * (y - filter * x) * x / x * x.
110 float* h_p = &h[0];
111 x_p = &x[x_start_index];
112
113 // Perform the loop in two chunks.
114 for (int limit : {chunk1, chunk2}) {
115 // Perform 128 bit vector operations.
116 const int limit_by_4 = limit >> 2;
117 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
118 // Load the data into 128 bit vectors.
119 float32x4_t h_k = vld1q_f32(h_p);
120 const float32x4_t x_k = vld1q_f32(x_p);
121 // Compute h = h + alpha * x.
122 h_k = vmlaq_f32(h_k, alpha_128, x_k);
123
124 // Store the result.
125 vst1q_f32(h_p, h_k);
126 }
127
128 // Perform non-vector operations for any remaining items.
129 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
130 *h_p += alpha * *x_p;
131 }
132
133 x_p = &x[0];
134 }
135
136 *filters_updated = true;
137 }
138
139 x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
140 }
141 }
142
143 #endif
144
145 #if defined(WEBRTC_ARCH_X86_FAMILY)
146
MatchedFilterCore_SSE2(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)147 void MatchedFilterCore_SSE2(size_t x_start_index,
148 float x2_sum_threshold,
149 float smoothing,
150 rtc::ArrayView<const float> x,
151 rtc::ArrayView<const float> y,
152 rtc::ArrayView<float> h,
153 bool* filters_updated,
154 float* error_sum) {
155 const int h_size = static_cast<int>(h.size());
156 const int x_size = static_cast<int>(x.size());
157 RTC_DCHECK_EQ(0, h_size % 4);
158
159 // Process for all samples in the sub-block.
160 for (size_t i = 0; i < y.size(); ++i) {
161 // Apply the matched filter as filter * x, and compute x * x.
162
163 RTC_DCHECK_GT(x_size, x_start_index);
164 const float* x_p = &x[x_start_index];
165 const float* h_p = &h[0];
166
167 // Initialize values for the accumulation.
168 __m128 s_128 = _mm_set1_ps(0);
169 __m128 x2_sum_128 = _mm_set1_ps(0);
170 float x2_sum = 0.f;
171 float s = 0;
172
173 // Compute loop chunk sizes until, and after, the wraparound of the circular
174 // buffer for x.
175 const int chunk1 =
176 std::min(h_size, static_cast<int>(x_size - x_start_index));
177
178 // Perform the loop in two chunks.
179 const int chunk2 = h_size - chunk1;
180 for (int limit : {chunk1, chunk2}) {
181 // Perform 128 bit vector operations.
182 const int limit_by_4 = limit >> 2;
183 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
184 // Load the data into 128 bit vectors.
185 const __m128 x_k = _mm_loadu_ps(x_p);
186 const __m128 h_k = _mm_loadu_ps(h_p);
187 const __m128 xx = _mm_mul_ps(x_k, x_k);
188 // Compute and accumulate x * x and h * x.
189 x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
190 const __m128 hx = _mm_mul_ps(h_k, x_k);
191 s_128 = _mm_add_ps(s_128, hx);
192 }
193
194 // Perform non-vector operations for any remaining items.
195 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
196 const float x_k = *x_p;
197 x2_sum += x_k * x_k;
198 s += *h_p * x_k;
199 }
200
201 x_p = &x[0];
202 }
203
204 // Combine the accumulated vector and scalar values.
205 float* v = reinterpret_cast<float*>(&x2_sum_128);
206 x2_sum += v[0] + v[1] + v[2] + v[3];
207 v = reinterpret_cast<float*>(&s_128);
208 s += v[0] + v[1] + v[2] + v[3];
209
210 // Compute the matched filter error.
211 float e = y[i] - s;
212 const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
213 (*error_sum) += e * e;
214
215 // Update the matched filter estimate in an NLMS manner.
216 if (x2_sum > x2_sum_threshold && !saturation) {
217 RTC_DCHECK_LT(0.f, x2_sum);
218 const float alpha = smoothing * e / x2_sum;
219 const __m128 alpha_128 = _mm_set1_ps(alpha);
220
221 // filter = filter + smoothing * (y - filter * x) * x / x * x.
222 float* h_p = &h[0];
223 x_p = &x[x_start_index];
224
225 // Perform the loop in two chunks.
226 for (int limit : {chunk1, chunk2}) {
227 // Perform 128 bit vector operations.
228 const int limit_by_4 = limit >> 2;
229 for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
230 // Load the data into 128 bit vectors.
231 __m128 h_k = _mm_loadu_ps(h_p);
232 const __m128 x_k = _mm_loadu_ps(x_p);
233
234 // Compute h = h + alpha * x.
235 const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
236 h_k = _mm_add_ps(h_k, alpha_x);
237
238 // Store the result.
239 _mm_storeu_ps(h_p, h_k);
240 }
241
242 // Perform non-vector operations for any remaining items.
243 for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
244 *h_p += alpha * *x_p;
245 }
246
247 x_p = &x[0];
248 }
249
250 *filters_updated = true;
251 }
252
253 x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
254 }
255 }
256 #endif
257
MatchedFilterCore(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)258 void MatchedFilterCore(size_t x_start_index,
259 float x2_sum_threshold,
260 float smoothing,
261 rtc::ArrayView<const float> x,
262 rtc::ArrayView<const float> y,
263 rtc::ArrayView<float> h,
264 bool* filters_updated,
265 float* error_sum) {
266 // Process for all samples in the sub-block.
267 for (size_t i = 0; i < y.size(); ++i) {
268 // Apply the matched filter as filter * x, and compute x * x.
269 float x2_sum = 0.f;
270 float s = 0;
271 size_t x_index = x_start_index;
272 for (size_t k = 0; k < h.size(); ++k) {
273 x2_sum += x[x_index] * x[x_index];
274 s += h[k] * x[x_index];
275 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
276 }
277
278 // Compute the matched filter error.
279 float e = y[i] - s;
280 const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
281 (*error_sum) += e * e;
282
283 // Update the matched filter estimate in an NLMS manner.
284 if (x2_sum > x2_sum_threshold && !saturation) {
285 RTC_DCHECK_LT(0.f, x2_sum);
286 const float alpha = smoothing * e / x2_sum;
287
288 // filter = filter + smoothing * (y - filter * x) * x / x * x.
289 size_t x_index = x_start_index;
290 for (size_t k = 0; k < h.size(); ++k) {
291 h[k] += alpha * x[x_index];
292 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
293 }
294 *filters_updated = true;
295 }
296
297 x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
298 }
299 }
300
301 } // namespace aec3
302
MatchedFilter(ApmDataDumper * data_dumper,Aec3Optimization optimization,size_t sub_block_size,size_t window_size_sub_blocks,int num_matched_filters,size_t alignment_shift_sub_blocks,float excitation_limit,float smoothing,float matching_filter_threshold)303 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
304 Aec3Optimization optimization,
305 size_t sub_block_size,
306 size_t window_size_sub_blocks,
307 int num_matched_filters,
308 size_t alignment_shift_sub_blocks,
309 float excitation_limit,
310 float smoothing,
311 float matching_filter_threshold)
312 : data_dumper_(data_dumper),
313 optimization_(optimization),
314 sub_block_size_(sub_block_size),
315 filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
316 filters_(
317 num_matched_filters,
318 std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
319 lag_estimates_(num_matched_filters),
320 filters_offsets_(num_matched_filters, 0),
321 excitation_limit_(excitation_limit),
322 smoothing_(smoothing),
323 matching_filter_threshold_(matching_filter_threshold) {
324 RTC_DCHECK(data_dumper);
325 RTC_DCHECK_LT(0, window_size_sub_blocks);
326 RTC_DCHECK((kBlockSize % sub_block_size) == 0);
327 RTC_DCHECK((sub_block_size % 4) == 0);
328 }
329
330 MatchedFilter::~MatchedFilter() = default;
331
Reset()332 void MatchedFilter::Reset() {
333 for (auto& f : filters_) {
334 std::fill(f.begin(), f.end(), 0.f);
335 }
336
337 for (auto& l : lag_estimates_) {
338 l = MatchedFilter::LagEstimate();
339 }
340 }
341
Update(const DownsampledRenderBuffer & render_buffer,rtc::ArrayView<const float> capture)342 void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
343 rtc::ArrayView<const float> capture) {
344 RTC_DCHECK_EQ(sub_block_size_, capture.size());
345 auto& y = capture;
346
347 const float x2_sum_threshold =
348 filters_[0].size() * excitation_limit_ * excitation_limit_;
349
350 // Apply all matched filters.
351 size_t alignment_shift = 0;
352 for (size_t n = 0; n < filters_.size(); ++n) {
353 float error_sum = 0.f;
354 bool filters_updated = false;
355
356 size_t x_start_index =
357 (render_buffer.read + alignment_shift + sub_block_size_ - 1) %
358 render_buffer.buffer.size();
359
360 switch (optimization_) {
361 #if defined(WEBRTC_ARCH_X86_FAMILY)
362 case Aec3Optimization::kSse2:
363 aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold,
364 smoothing_, render_buffer.buffer, y,
365 filters_[n], &filters_updated, &error_sum);
366 break;
367 #endif
368 #if defined(WEBRTC_HAS_NEON)
369 case Aec3Optimization::kNeon:
370 aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold,
371 smoothing_, render_buffer.buffer, y,
372 filters_[n], &filters_updated, &error_sum);
373 break;
374 #endif
375 default:
376 aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing_,
377 render_buffer.buffer, y, filters_[n],
378 &filters_updated, &error_sum);
379 }
380
381 // Compute anchor for the matched filter error.
382 const float error_sum_anchor =
383 std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
384
385 // Estimate the lag in the matched filter as the distance to the portion in
386 // the filter that contributes the most to the matched filter output. This
387 // is detected as the peak of the matched filter.
388 const size_t lag_estimate = std::distance(
389 filters_[n].begin(),
390 std::max_element(
391 filters_[n].begin(), filters_[n].end(),
392 [](float a, float b) -> bool { return a * a < b * b; }));
393
394 // Update the lag estimates for the matched filter.
395 lag_estimates_[n] = LagEstimate(
396 error_sum_anchor - error_sum,
397 (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
398 error_sum < matching_filter_threshold_ * error_sum_anchor),
399 lag_estimate + alignment_shift, filters_updated);
400
401 RTC_DCHECK_GE(10, filters_.size());
402 switch (n) {
403 case 0:
404 data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
405 break;
406 case 1:
407 data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
408 break;
409 case 2:
410 data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
411 break;
412 case 3:
413 data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
414 break;
415 case 4:
416 data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]);
417 break;
418 case 5:
419 data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]);
420 break;
421 case 6:
422 data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]);
423 break;
424 case 7:
425 data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]);
426 break;
427 case 8:
428 data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]);
429 break;
430 case 9:
431 data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]);
432 break;
433 default:
434 RTC_NOTREACHED();
435 }
436
437 alignment_shift += filter_intra_lag_shift_;
438 }
439 }
440
LogFilterProperties(int sample_rate_hz,size_t shift,size_t downsampling_factor) const441 void MatchedFilter::LogFilterProperties(int sample_rate_hz,
442 size_t shift,
443 size_t downsampling_factor) const {
444 size_t alignment_shift = 0;
445 constexpr int kFsBy1000 = 16;
446 for (size_t k = 0; k < filters_.size(); ++k) {
447 int start = static_cast<int>(alignment_shift * downsampling_factor);
448 int end = static_cast<int>((alignment_shift + filters_[k].size()) *
449 downsampling_factor);
450 RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: "
451 << (start - static_cast<int>(shift)) / kFsBy1000
452 << " ms, end: "
453 << (end - static_cast<int>(shift)) / kFsBy1000
454 << " ms.";
455 alignment_shift += filter_intra_lag_shift_;
456 }
457 }
458
459 } // namespace webrtc
460