• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "modules/audio_processing/aec3/matched_filter.h"
11 
12 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
13 #include "rtc_base/system/arch.h"
14 
15 #if defined(WEBRTC_HAS_NEON)
16 #include <arm_neon.h>
17 #endif
18 #if defined(WEBRTC_ARCH_X86_FAMILY)
19 #include <emmintrin.h>
20 #endif
21 #include <algorithm>
22 #include <cstddef>
23 #include <initializer_list>
24 #include <iterator>
25 #include <numeric>
26 
27 #include "modules/audio_processing/aec3/downsampled_render_buffer.h"
28 #include "modules/audio_processing/logging/apm_data_dumper.h"
29 #include "rtc_base/checks.h"
30 #include "rtc_base/logging.h"
31 
32 namespace webrtc {
33 namespace aec3 {
34 
35 #if defined(WEBRTC_HAS_NEON)
36 
// Matched-filter core for ARM NEON. Implements the same algorithm as the
// portable MatchedFilterCore(), but processes four filter taps per iteration.
//
// For each capture sample y[i], the filter h is applied to the circular render
// buffer x starting at x_start_index, wrapping to x[0] at the end of x. The
// squared error y[i] - h * x is added to *error_sum. When the render energy
// x * x exceeds x2_sum_threshold and |y[i]| < 32000 (the capture sample is not
// treated as saturated), h is updated in an NLMS manner with step size
// `smoothing` and *filters_updated is set to true. After each sample,
// x_start_index is moved back by one position (circularly).
//
// Preconditions: h.size() is a multiple of 4, h.size() <= x.size() and
// x_start_index < x.size(). The filter span may wrap around x at most once,
// which is why every pass over h is split into exactly two contiguous chunks.
void MatchedFilterCore_NEON(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  // The two-chunk loops below can only handle a single wraparound of x.
  RTC_DCHECK_LE(h_size, x_size);

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];

    // Initialize values for the accumulation.
    float32x4_t s_128 = vdupq_n_f32(0);
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;

    // Compute loop chunk sizes until, and after, the wraparound of the circular
    // buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    const int chunk2 = h_size - chunk1;

    // Perform the loop in two chunks: h_p continues across chunks while x_p
    // restarts at the beginning of x after the first one.
    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_4 = limit >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        const float32x4_t x_k = vld1q_f32(x_p);
        const float32x4_t h_k = vld1q_f32(h_p);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
        s_128 = vmlaq_f32(s_128, h_k, x_k);
      }

      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }

      x_p = &x[0];
    }

    // Combine the accumulated vector and scalar values. The lanes are copied
    // out with a store instead of type-punning the vector through a float
    // pointer; the summation order is unchanged.
    float lanes[4];
    vst1q_f32(lanes, x2_sum_128);
    x2_sum += lanes[0] + lanes[1] + lanes[2] + lanes[3];
    vst1q_f32(lanes, s_128);
    s += lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Compute the matched filter error.
    const float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);

      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      // A separate, writable pointer is used so that the read-only h_p of the
      // filtering pass above is not shadowed.
      float* h_update_p = &h[0];
      x_p = &x[x_start_index];

      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        // Perform 128 bit vector operations.
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_update_p += 4, x_p += 4) {
          // Load the data into 128 bit vectors.
          float32x4_t h_k = vld1q_f32(h_update_p);
          const float32x4_t x_k = vld1q_f32(x_p);
          // Compute h = h + alpha * x.
          h_k = vmlaq_f32(h_k, alpha_128, x_k);

          // Store the result.
          vst1q_f32(h_update_p, h_k);
        }

        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0;
             --k, ++h_update_p, ++x_p) {
          *h_update_p += alpha * *x_p;
        }

        x_p = &x[0];
      }

      *filters_updated = true;
    }

    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}
142 
143 #endif
144 
145 #if defined(WEBRTC_ARCH_X86_FAMILY)
146 
MatchedFilterCore_SSE2(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)147 void MatchedFilterCore_SSE2(size_t x_start_index,
148                             float x2_sum_threshold,
149                             float smoothing,
150                             rtc::ArrayView<const float> x,
151                             rtc::ArrayView<const float> y,
152                             rtc::ArrayView<float> h,
153                             bool* filters_updated,
154                             float* error_sum) {
155   const int h_size = static_cast<int>(h.size());
156   const int x_size = static_cast<int>(x.size());
157   RTC_DCHECK_EQ(0, h_size % 4);
158 
159   // Process for all samples in the sub-block.
160   for (size_t i = 0; i < y.size(); ++i) {
161     // Apply the matched filter as filter * x, and compute x * x.
162 
163     RTC_DCHECK_GT(x_size, x_start_index);
164     const float* x_p = &x[x_start_index];
165     const float* h_p = &h[0];
166 
167     // Initialize values for the accumulation.
168     __m128 s_128 = _mm_set1_ps(0);
169     __m128 x2_sum_128 = _mm_set1_ps(0);
170     float x2_sum = 0.f;
171     float s = 0;
172 
173     // Compute loop chunk sizes until, and after, the wraparound of the circular
174     // buffer for x.
175     const int chunk1 =
176         std::min(h_size, static_cast<int>(x_size - x_start_index));
177 
178     // Perform the loop in two chunks.
179     const int chunk2 = h_size - chunk1;
180     for (int limit : {chunk1, chunk2}) {
181       // Perform 128 bit vector operations.
182       const int limit_by_4 = limit >> 2;
183       for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
184         // Load the data into 128 bit vectors.
185         const __m128 x_k = _mm_loadu_ps(x_p);
186         const __m128 h_k = _mm_loadu_ps(h_p);
187         const __m128 xx = _mm_mul_ps(x_k, x_k);
188         // Compute and accumulate x * x and h * x.
189         x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
190         const __m128 hx = _mm_mul_ps(h_k, x_k);
191         s_128 = _mm_add_ps(s_128, hx);
192       }
193 
194       // Perform non-vector operations for any remaining items.
195       for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
196         const float x_k = *x_p;
197         x2_sum += x_k * x_k;
198         s += *h_p * x_k;
199       }
200 
201       x_p = &x[0];
202     }
203 
204     // Combine the accumulated vector and scalar values.
205     float* v = reinterpret_cast<float*>(&x2_sum_128);
206     x2_sum += v[0] + v[1] + v[2] + v[3];
207     v = reinterpret_cast<float*>(&s_128);
208     s += v[0] + v[1] + v[2] + v[3];
209 
210     // Compute the matched filter error.
211     float e = y[i] - s;
212     const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
213     (*error_sum) += e * e;
214 
215     // Update the matched filter estimate in an NLMS manner.
216     if (x2_sum > x2_sum_threshold && !saturation) {
217       RTC_DCHECK_LT(0.f, x2_sum);
218       const float alpha = smoothing * e / x2_sum;
219       const __m128 alpha_128 = _mm_set1_ps(alpha);
220 
221       // filter = filter + smoothing * (y - filter * x) * x / x * x.
222       float* h_p = &h[0];
223       x_p = &x[x_start_index];
224 
225       // Perform the loop in two chunks.
226       for (int limit : {chunk1, chunk2}) {
227         // Perform 128 bit vector operations.
228         const int limit_by_4 = limit >> 2;
229         for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
230           // Load the data into 128 bit vectors.
231           __m128 h_k = _mm_loadu_ps(h_p);
232           const __m128 x_k = _mm_loadu_ps(x_p);
233 
234           // Compute h = h + alpha * x.
235           const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
236           h_k = _mm_add_ps(h_k, alpha_x);
237 
238           // Store the result.
239           _mm_storeu_ps(h_p, h_k);
240         }
241 
242         // Perform non-vector operations for any remaining items.
243         for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
244           *h_p += alpha * *x_p;
245         }
246 
247         x_p = &x[0];
248       }
249 
250       *filters_updated = true;
251     }
252 
253     x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
254   }
255 }
256 #endif
257 
MatchedFilterCore(size_t x_start_index,float x2_sum_threshold,float smoothing,rtc::ArrayView<const float> x,rtc::ArrayView<const float> y,rtc::ArrayView<float> h,bool * filters_updated,float * error_sum)258 void MatchedFilterCore(size_t x_start_index,
259                        float x2_sum_threshold,
260                        float smoothing,
261                        rtc::ArrayView<const float> x,
262                        rtc::ArrayView<const float> y,
263                        rtc::ArrayView<float> h,
264                        bool* filters_updated,
265                        float* error_sum) {
266   // Process for all samples in the sub-block.
267   for (size_t i = 0; i < y.size(); ++i) {
268     // Apply the matched filter as filter * x, and compute x * x.
269     float x2_sum = 0.f;
270     float s = 0;
271     size_t x_index = x_start_index;
272     for (size_t k = 0; k < h.size(); ++k) {
273       x2_sum += x[x_index] * x[x_index];
274       s += h[k] * x[x_index];
275       x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
276     }
277 
278     // Compute the matched filter error.
279     float e = y[i] - s;
280     const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
281     (*error_sum) += e * e;
282 
283     // Update the matched filter estimate in an NLMS manner.
284     if (x2_sum > x2_sum_threshold && !saturation) {
285       RTC_DCHECK_LT(0.f, x2_sum);
286       const float alpha = smoothing * e / x2_sum;
287 
288       // filter = filter + smoothing * (y - filter * x) * x / x * x.
289       size_t x_index = x_start_index;
290       for (size_t k = 0; k < h.size(); ++k) {
291         h[k] += alpha * x[x_index];
292         x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
293       }
294       *filters_updated = true;
295     }
296 
297     x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
298   }
299 }
300 
301 }  // namespace aec3
302 
MatchedFilter(ApmDataDumper * data_dumper,Aec3Optimization optimization,size_t sub_block_size,size_t window_size_sub_blocks,int num_matched_filters,size_t alignment_shift_sub_blocks,float excitation_limit,float smoothing,float matching_filter_threshold)303 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
304                              Aec3Optimization optimization,
305                              size_t sub_block_size,
306                              size_t window_size_sub_blocks,
307                              int num_matched_filters,
308                              size_t alignment_shift_sub_blocks,
309                              float excitation_limit,
310                              float smoothing,
311                              float matching_filter_threshold)
312     : data_dumper_(data_dumper),
313       optimization_(optimization),
314       sub_block_size_(sub_block_size),
315       filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
316       filters_(
317           num_matched_filters,
318           std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
319       lag_estimates_(num_matched_filters),
320       filters_offsets_(num_matched_filters, 0),
321       excitation_limit_(excitation_limit),
322       smoothing_(smoothing),
323       matching_filter_threshold_(matching_filter_threshold) {
324   RTC_DCHECK(data_dumper);
325   RTC_DCHECK_LT(0, window_size_sub_blocks);
326   RTC_DCHECK((kBlockSize % sub_block_size) == 0);
327   RTC_DCHECK((sub_block_size % 4) == 0);
328 }
329 
330 MatchedFilter::~MatchedFilter() = default;
331 
Reset()332 void MatchedFilter::Reset() {
333   for (auto& f : filters_) {
334     std::fill(f.begin(), f.end(), 0.f);
335   }
336 
337   for (auto& l : lag_estimates_) {
338     l = MatchedFilter::LagEstimate();
339   }
340 }
341 
Update(const DownsampledRenderBuffer & render_buffer,rtc::ArrayView<const float> capture)342 void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
343                            rtc::ArrayView<const float> capture) {
344   RTC_DCHECK_EQ(sub_block_size_, capture.size());
345   auto& y = capture;
346 
347   const float x2_sum_threshold =
348       filters_[0].size() * excitation_limit_ * excitation_limit_;
349 
350   // Apply all matched filters.
351   size_t alignment_shift = 0;
352   for (size_t n = 0; n < filters_.size(); ++n) {
353     float error_sum = 0.f;
354     bool filters_updated = false;
355 
356     size_t x_start_index =
357         (render_buffer.read + alignment_shift + sub_block_size_ - 1) %
358         render_buffer.buffer.size();
359 
360     switch (optimization_) {
361 #if defined(WEBRTC_ARCH_X86_FAMILY)
362       case Aec3Optimization::kSse2:
363         aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold,
364                                      smoothing_, render_buffer.buffer, y,
365                                      filters_[n], &filters_updated, &error_sum);
366         break;
367 #endif
368 #if defined(WEBRTC_HAS_NEON)
369       case Aec3Optimization::kNeon:
370         aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold,
371                                      smoothing_, render_buffer.buffer, y,
372                                      filters_[n], &filters_updated, &error_sum);
373         break;
374 #endif
375       default:
376         aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing_,
377                                 render_buffer.buffer, y, filters_[n],
378                                 &filters_updated, &error_sum);
379     }
380 
381     // Compute anchor for the matched filter error.
382     const float error_sum_anchor =
383         std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
384 
385     // Estimate the lag in the matched filter as the distance to the portion in
386     // the filter that contributes the most to the matched filter output. This
387     // is detected as the peak of the matched filter.
388     const size_t lag_estimate = std::distance(
389         filters_[n].begin(),
390         std::max_element(
391             filters_[n].begin(), filters_[n].end(),
392             [](float a, float b) -> bool { return a * a < b * b; }));
393 
394     // Update the lag estimates for the matched filter.
395     lag_estimates_[n] = LagEstimate(
396         error_sum_anchor - error_sum,
397         (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
398          error_sum < matching_filter_threshold_ * error_sum_anchor),
399         lag_estimate + alignment_shift, filters_updated);
400 
401     RTC_DCHECK_GE(10, filters_.size());
402     switch (n) {
403       case 0:
404         data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
405         break;
406       case 1:
407         data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
408         break;
409       case 2:
410         data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
411         break;
412       case 3:
413         data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
414         break;
415       case 4:
416         data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]);
417         break;
418       case 5:
419         data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]);
420         break;
421       case 6:
422         data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]);
423         break;
424       case 7:
425         data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]);
426         break;
427       case 8:
428         data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]);
429         break;
430       case 9:
431         data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]);
432         break;
433       default:
434         RTC_NOTREACHED();
435     }
436 
437     alignment_shift += filter_intra_lag_shift_;
438   }
439 }
440 
LogFilterProperties(int sample_rate_hz,size_t shift,size_t downsampling_factor) const441 void MatchedFilter::LogFilterProperties(int sample_rate_hz,
442                                         size_t shift,
443                                         size_t downsampling_factor) const {
444   size_t alignment_shift = 0;
445   constexpr int kFsBy1000 = 16;
446   for (size_t k = 0; k < filters_.size(); ++k) {
447     int start = static_cast<int>(alignment_shift * downsampling_factor);
448     int end = static_cast<int>((alignment_shift + filters_[k].size()) *
449                                downsampling_factor);
450     RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: "
451                         << (start - static_cast<int>(shift)) / kFsBy1000
452                         << " ms, end: "
453                         << (end - static_cast<int>(shift)) / kFsBy1000
454                         << " ms.";
455     alignment_shift += filter_intra_lag_shift_;
456   }
457 }
458 
459 }  // namespace webrtc
460