1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "modules/audio_processing/aec3/adaptive_fir_filter.h"
12
13 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
14 #include "rtc_base/system/arch.h"
15
16 #if defined(WEBRTC_HAS_NEON)
17 #include <arm_neon.h>
18 #endif
19 #if defined(WEBRTC_ARCH_X86_FAMILY)
20 #include <emmintrin.h>
21 #endif
22 #include <math.h>
23
24 #include <algorithm>
25 #include <functional>
26
27 #include "modules/audio_processing/aec3/fft_data.h"
28 #include "rtc_base/checks.h"
29
30 namespace webrtc {
31
32 namespace aec3 {
33
34 // Computes and stores the frequency response of the filter.
ComputeFrequencyResponse(size_t num_partitions,const std::vector<std::vector<FftData>> & H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)35 void ComputeFrequencyResponse(
36 size_t num_partitions,
37 const std::vector<std::vector<FftData>>& H,
38 std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
39 for (auto& H2_ch : *H2) {
40 H2_ch.fill(0.f);
41 }
42
43 const size_t num_render_channels = H[0].size();
44 RTC_DCHECK_EQ(H.size(), H2->capacity());
45 for (size_t p = 0; p < num_partitions; ++p) {
46 RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
47 for (size_t ch = 0; ch < num_render_channels; ++ch) {
48 for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) {
49 float tmp =
50 H[p][ch].re[j] * H[p][ch].re[j] + H[p][ch].im[j] * H[p][ch].im[j];
51 (*H2)[p][j] = std::max((*H2)[p][j], tmp);
52 }
53 }
54 }
55 }
56
57 #if defined(WEBRTC_HAS_NEON)
// Computes the per-partition frequency response of the filter (Neon variant).
// Same contract as ComputeFrequencyResponse: (*H2)[p][j] receives the maximum
// over render channels of re^2 + im^2. Bins [0, kFftLengthBy2) are processed
// four at a time with Neon intrinsics; the last bin, kFftLengthBy2, is handled
// with scalar code.
void ComputeFrequencyResponse_Neon(
    size_t num_partitions,
    const std::vector<std::vector<FftData>>& H,
    std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
  for (auto& H2_p : *H2) {
    H2_p.fill(0.f);
  }

  RTC_DCHECK_EQ(H.size(), H2->capacity());
  const size_t num_render_channels = H[0].size();
  for (size_t p = 0; p < num_partitions; ++p) {
    RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
    float* const H2_p = (*H2)[p].data();
    for (size_t ch = 0; ch < num_render_channels; ++ch) {
      const FftData& H_p_ch = H[p][ch];
      for (size_t j = 0; j < kFftLengthBy2; j += 4) {
        const float32x4_t re = vld1q_f32(&H_p_ch.re[j]);
        const float32x4_t im = vld1q_f32(&H_p_ch.im[j]);
        const float32x4_t power = vmlaq_f32(vmulq_f32(re, re), im, im);
        const float32x4_t current_max = vld1q_f32(H2_p + j);
        vst1q_f32(H2_p + j, vmaxq_f32(current_max, power));
      }
      const float re_last = H_p_ch.re[kFftLengthBy2];
      const float im_last = H_p_ch.im[kFftLengthBy2];
      const float power_last = re_last * re_last + im_last * im_last;
      H2_p[kFftLengthBy2] = std::max(H2_p[kFftLengthBy2], power_last);
    }
  }
}
87 #endif
88
89 #if defined(WEBRTC_ARCH_X86_FAMILY)
90 // Computes and stores the frequency response of the filter.
ComputeFrequencyResponse_Sse2(size_t num_partitions,const std::vector<std::vector<FftData>> & H,std::vector<std::array<float,kFftLengthBy2Plus1>> * H2)91 void ComputeFrequencyResponse_Sse2(
92 size_t num_partitions,
93 const std::vector<std::vector<FftData>>& H,
94 std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
95 for (auto& H2_ch : *H2) {
96 H2_ch.fill(0.f);
97 }
98
99 const size_t num_render_channels = H[0].size();
100 RTC_DCHECK_EQ(H.size(), H2->capacity());
101 // constexpr __mmmask8 kMaxMask = static_cast<__mmmask8>(256u);
102 for (size_t p = 0; p < num_partitions; ++p) {
103 RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
104 for (size_t ch = 0; ch < num_render_channels; ++ch) {
105 for (size_t j = 0; j < kFftLengthBy2; j += 4) {
106 const __m128 re = _mm_loadu_ps(&H[p][ch].re[j]);
107 const __m128 re2 = _mm_mul_ps(re, re);
108 const __m128 im = _mm_loadu_ps(&H[p][ch].im[j]);
109 const __m128 im2 = _mm_mul_ps(im, im);
110 const __m128 H2_new = _mm_add_ps(re2, im2);
111 __m128 H2_k_j = _mm_loadu_ps(&(*H2)[p][j]);
112 H2_k_j = _mm_max_ps(H2_k_j, H2_new);
113 _mm_storeu_ps(&(*H2)[p][j], H2_k_j);
114 }
115 float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] +
116 H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2];
117 (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new);
118 }
119 }
120 }
121 #endif
122
123 // Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)).
AdaptPartitions(const RenderBuffer & render_buffer,const FftData & G,size_t num_partitions,std::vector<std::vector<FftData>> * H)124 void AdaptPartitions(const RenderBuffer& render_buffer,
125 const FftData& G,
126 size_t num_partitions,
127 std::vector<std::vector<FftData>>* H) {
128 rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
129 render_buffer.GetFftBuffer();
130 size_t index = render_buffer.Position();
131 const size_t num_render_channels = render_buffer_data[index].size();
132 for (size_t p = 0; p < num_partitions; ++p) {
133 for (size_t ch = 0; ch < num_render_channels; ++ch) {
134 const FftData& X_p_ch = render_buffer_data[index][ch];
135 FftData& H_p_ch = (*H)[p][ch];
136 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
137 H_p_ch.re[k] += X_p_ch.re[k] * G.re[k] + X_p_ch.im[k] * G.im[k];
138 H_p_ch.im[k] += X_p_ch.re[k] * G.im[k] - X_p_ch.im[k] * G.re[k];
139 }
140 }
141 index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
142 }
143 }
144
145 #if defined(WEBRTC_HAS_NEON)
// Adapts the filter partitions as H(t+1) = H(t) + G(t) * conj(X(t)) (Neon
// variant). Bins [0, kFftLengthBy2) are updated four at a time with Neon
// intrinsics and the last bin, kFftLengthBy2, with scalar code. Filter
// partition p is paired with render buffer entry render_buffer.Position() + p
// until the end of the buffer is reached; the remaining partitions are paired
// with the buffer entries starting again from index 0.
void AdaptPartitions_Neon(const RenderBuffer& render_buffer,
                          const FftData& G,
                          size_t num_partitions,
                          std::vector<std::vector<FftData>>* H) {
  rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
      render_buffer.GetFftBuffer();
  const size_t num_render_channels = render_buffer_data[0].size();
  const size_t start_index = render_buffer.Position();
  // Number of partitions processed before the buffer index wraps to 0.
  const size_t num_before_wrap =
      std::min(render_buffer_data.size() - start_index, num_partitions);
  constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
  constexpr size_t kLastBin = kFftLengthBy2;

  size_t x_index = start_index;
  for (size_t p = 0; p < num_partitions; ++p, ++x_index) {
    if (p == num_before_wrap) {
      x_index = 0;
    }
    for (size_t ch = 0; ch < num_render_channels; ++ch) {
      FftData& H_p_ch = (*H)[p][ch];
      const FftData& X = render_buffer_data[x_index][ch];

      for (size_t n = 0; n < kNumFourBinBands; ++n) {
        const size_t k = 4 * n;
        const float32x4_t X_re = vld1q_f32(&X.re[k]);
        const float32x4_t X_im = vld1q_f32(&X.im[k]);
        const float32x4_t G_re = vld1q_f32(&G.re[k]);
        const float32x4_t G_im = vld1q_f32(&G.im[k]);
        // Real part of G * conj(X): X_re * G_re + X_im * G_im.
        const float32x4_t update_re =
            vmlaq_f32(vmulq_f32(X_re, G_re), X_im, G_im);
        // Imaginary part of G * conj(X): X_re * G_im - X_im * G_re.
        const float32x4_t update_im =
            vmlsq_f32(vmulq_f32(X_re, G_im), X_im, G_re);
        vst1q_f32(&H_p_ch.re[k],
                  vaddq_f32(vld1q_f32(&H_p_ch.re[k]), update_re));
        vst1q_f32(&H_p_ch.im[k],
                  vaddq_f32(vld1q_f32(&H_p_ch.im[k]), update_im));
      }

      // The last bin does not fill a full four-lane vector.
      H_p_ch.re[kLastBin] += X.re[kLastBin] * G.re[kLastBin] +
                             X.im[kLastBin] * G.im[kLastBin];
      H_p_ch.im[kLastBin] += X.re[kLastBin] * G.im[kLastBin] -
                             X.im[kLastBin] * G.re[kLastBin];
    }
  }
}
209 #endif
210
211 #if defined(WEBRTC_ARCH_X86_FAMILY)
212 // Adapts the filter partitions. (SSE2 variant)
AdaptPartitions_Sse2(const RenderBuffer & render_buffer,const FftData & G,size_t num_partitions,std::vector<std::vector<FftData>> * H)213 void AdaptPartitions_Sse2(const RenderBuffer& render_buffer,
214 const FftData& G,
215 size_t num_partitions,
216 std::vector<std::vector<FftData>>* H) {
217 rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
218 render_buffer.GetFftBuffer();
219 const size_t num_render_channels = render_buffer_data[0].size();
220 const size_t lim1 = std::min(
221 render_buffer_data.size() - render_buffer.Position(), num_partitions);
222 const size_t lim2 = num_partitions;
223 constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
224
225 size_t X_partition = render_buffer.Position();
226 size_t limit = lim1;
227 size_t p = 0;
228 do {
229 for (; p < limit; ++p, ++X_partition) {
230 for (size_t ch = 0; ch < num_render_channels; ++ch) {
231 FftData& H_p_ch = (*H)[p][ch];
232 const FftData& X = render_buffer_data[X_partition][ch];
233
234 for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
235 const __m128 G_re = _mm_loadu_ps(&G.re[k]);
236 const __m128 G_im = _mm_loadu_ps(&G.im[k]);
237 const __m128 X_re = _mm_loadu_ps(&X.re[k]);
238 const __m128 X_im = _mm_loadu_ps(&X.im[k]);
239 const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]);
240 const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]);
241 const __m128 a = _mm_mul_ps(X_re, G_re);
242 const __m128 b = _mm_mul_ps(X_im, G_im);
243 const __m128 c = _mm_mul_ps(X_re, G_im);
244 const __m128 d = _mm_mul_ps(X_im, G_re);
245 const __m128 e = _mm_add_ps(a, b);
246 const __m128 f = _mm_sub_ps(c, d);
247 const __m128 g = _mm_add_ps(H_re, e);
248 const __m128 h = _mm_add_ps(H_im, f);
249 _mm_storeu_ps(&H_p_ch.re[k], g);
250 _mm_storeu_ps(&H_p_ch.im[k], h);
251 }
252 }
253 }
254 X_partition = 0;
255 limit = lim2;
256 } while (p < lim2);
257
258 X_partition = render_buffer.Position();
259 limit = lim1;
260 p = 0;
261 do {
262 for (; p < limit; ++p, ++X_partition) {
263 for (size_t ch = 0; ch < num_render_channels; ++ch) {
264 FftData& H_p_ch = (*H)[p][ch];
265 const FftData& X = render_buffer_data[X_partition][ch];
266
267 H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
268 X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
269 H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
270 X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
271 }
272 }
273
274 X_partition = 0;
275 limit = lim2;
276 } while (p < lim2);
277 }
278 #endif
279
280 // Produces the filter output.
ApplyFilter(const RenderBuffer & render_buffer,size_t num_partitions,const std::vector<std::vector<FftData>> & H,FftData * S)281 void ApplyFilter(const RenderBuffer& render_buffer,
282 size_t num_partitions,
283 const std::vector<std::vector<FftData>>& H,
284 FftData* S) {
285 S->re.fill(0.f);
286 S->im.fill(0.f);
287
288 rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
289 render_buffer.GetFftBuffer();
290 size_t index = render_buffer.Position();
291 const size_t num_render_channels = render_buffer_data[index].size();
292 for (size_t p = 0; p < num_partitions; ++p) {
293 RTC_DCHECK_EQ(num_render_channels, H[p].size());
294 for (size_t ch = 0; ch < num_render_channels; ++ch) {
295 const FftData& X_p_ch = render_buffer_data[index][ch];
296 const FftData& H_p_ch = H[p][ch];
297 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
298 S->re[k] += X_p_ch.re[k] * H_p_ch.re[k] - X_p_ch.im[k] * H_p_ch.im[k];
299 S->im[k] += X_p_ch.re[k] * H_p_ch.im[k] + X_p_ch.im[k] * H_p_ch.re[k];
300 }
301 }
302 index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
303 }
304 }
305
306 #if defined(WEBRTC_HAS_NEON)
// Produces the filter output (Neon variant): S is the sum, over partitions p
// and render channels ch, of the complex product X[p][ch] * H[p][ch]. The
// render data for partition p is read from FFT buffer entry
// render_buffer.Position() + p, wrapping back to entry 0 after the last one,
// exactly as in ApplyFilter. Bins [0, kFftLengthBy2) are computed four at a
// time with Neon intrinsics and the last bin, kFftLengthBy2, with scalar code.
// Each output bin is accumulated in the same (partition, channel) order as in
// ApplyFilter.
void ApplyFilter_Neon(const RenderBuffer& render_buffer,
                      size_t num_partitions,
                      const std::vector<std::vector<FftData>>& H,
                      FftData* S) {
  // H must hold at least the partitions that are applied.
  RTC_DCHECK_GE(H.size(), num_partitions);
  S->Clear();

  rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
      render_buffer.GetFftBuffer();
  const size_t num_render_channels = render_buffer_data[0].size();
  const size_t last_index = render_buffer_data.size() - 1;
  constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
  constexpr size_t kLastBin = kFftLengthBy2;

  size_t x_index = render_buffer.Position();
  for (size_t p = 0; p < num_partitions; ++p) {
    RTC_DCHECK_EQ(num_render_channels, H[p].size());
    for (size_t ch = 0; ch < num_render_channels; ++ch) {
      const FftData& H_p_ch = H[p][ch];
      const FftData& X = render_buffer_data[x_index][ch];

      for (size_t n = 0; n < kNumFourBinBands; ++n) {
        const size_t k = 4 * n;
        const float32x4_t X_re = vld1q_f32(&X.re[k]);
        const float32x4_t X_im = vld1q_f32(&X.im[k]);
        const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]);
        const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]);
        // Real part of X * H: X_re * H_re - X_im * H_im.
        const float32x4_t prod_re =
            vmlsq_f32(vmulq_f32(X_re, H_re), X_im, H_im);
        // Imaginary part of X * H: X_re * H_im + X_im * H_re.
        const float32x4_t prod_im =
            vmlaq_f32(vmulq_f32(X_re, H_im), X_im, H_re);
        vst1q_f32(&S->re[k], vaddq_f32(vld1q_f32(&S->re[k]), prod_re));
        vst1q_f32(&S->im[k], vaddq_f32(vld1q_f32(&S->im[k]), prod_im));
      }

      // The last bin does not fill a full four-lane vector.
      S->re[kLastBin] += X.re[kLastBin] * H_p_ch.re[kLastBin] -
                         X.im[kLastBin] * H_p_ch.im[kLastBin];
      S->im[kLastBin] += X.re[kLastBin] * H_p_ch.im[kLastBin] +
                         X.im[kLastBin] * H_p_ch.re[kLastBin];
    }
    // Advance to the next render partition, wrapping like ApplyFilter does so
    // that the index stays within the FFT buffer.
    x_index = x_index < last_index ? x_index + 1 : 0;
  }
}
374 #endif
375
376 #if defined(WEBRTC_ARCH_X86_FAMILY)
377 // Produces the filter output (SSE2 variant).
ApplyFilter_Sse2(const RenderBuffer & render_buffer,size_t num_partitions,const std::vector<std::vector<FftData>> & H,FftData * S)378 void ApplyFilter_Sse2(const RenderBuffer& render_buffer,
379 size_t num_partitions,
380 const std::vector<std::vector<FftData>>& H,
381 FftData* S) {
382 // const RenderBuffer& render_buffer,
383 // rtc::ArrayView<const FftData> H,
384 // FftData* S) {
385 RTC_DCHECK_GE(H.size(), H.size() - 1);
386 S->re.fill(0.f);
387 S->im.fill(0.f);
388
389 rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
390 render_buffer.GetFftBuffer();
391 const size_t num_render_channels = render_buffer_data[0].size();
392 const size_t lim1 = std::min(
393 render_buffer_data.size() - render_buffer.Position(), num_partitions);
394 const size_t lim2 = num_partitions;
395 constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
396
397 size_t X_partition = render_buffer.Position();
398 size_t p = 0;
399 size_t limit = lim1;
400 do {
401 for (; p < limit; ++p, ++X_partition) {
402 for (size_t ch = 0; ch < num_render_channels; ++ch) {
403 const FftData& H_p_ch = H[p][ch];
404 const FftData& X = render_buffer_data[X_partition][ch];
405 for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
406 const __m128 X_re = _mm_loadu_ps(&X.re[k]);
407 const __m128 X_im = _mm_loadu_ps(&X.im[k]);
408 const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]);
409 const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]);
410 const __m128 S_re = _mm_loadu_ps(&S->re[k]);
411 const __m128 S_im = _mm_loadu_ps(&S->im[k]);
412 const __m128 a = _mm_mul_ps(X_re, H_re);
413 const __m128 b = _mm_mul_ps(X_im, H_im);
414 const __m128 c = _mm_mul_ps(X_re, H_im);
415 const __m128 d = _mm_mul_ps(X_im, H_re);
416 const __m128 e = _mm_sub_ps(a, b);
417 const __m128 f = _mm_add_ps(c, d);
418 const __m128 g = _mm_add_ps(S_re, e);
419 const __m128 h = _mm_add_ps(S_im, f);
420 _mm_storeu_ps(&S->re[k], g);
421 _mm_storeu_ps(&S->im[k], h);
422 }
423 }
424 }
425 limit = lim2;
426 X_partition = 0;
427 } while (p < lim2);
428
429 X_partition = render_buffer.Position();
430 p = 0;
431 limit = lim1;
432 do {
433 for (; p < limit; ++p, ++X_partition) {
434 for (size_t ch = 0; ch < num_render_channels; ++ch) {
435 const FftData& H_p_ch = H[p][ch];
436 const FftData& X = render_buffer_data[X_partition][ch];
437 S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] -
438 X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2];
439 S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] +
440 X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2];
441 }
442 }
443 limit = lim2;
444 X_partition = 0;
445 } while (p < lim2);
446 }
447 #endif
448
449 } // namespace aec3
450
451 namespace {
452
453 // Ensures that the newly added filter partitions after a size increase are set
454 // to zero.
ZeroFilter(size_t old_size,size_t new_size,std::vector<std::vector<FftData>> * H)455 void ZeroFilter(size_t old_size,
456 size_t new_size,
457 std::vector<std::vector<FftData>>* H) {
458 RTC_DCHECK_GE(H->size(), old_size);
459 RTC_DCHECK_GE(H->size(), new_size);
460
461 for (size_t p = old_size; p < new_size; ++p) {
462 RTC_DCHECK_EQ((*H)[p].size(), (*H)[0].size());
463 for (size_t ch = 0; ch < (*H)[0].size(); ++ch) {
464 (*H)[p][ch].Clear();
465 }
466 }
467 }
468
469 } // namespace
470
AdaptiveFirFilter(size_t max_size_partitions,size_t initial_size_partitions,size_t size_change_duration_blocks,size_t num_render_channels,Aec3Optimization optimization,ApmDataDumper * data_dumper)471 AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions,
472 size_t initial_size_partitions,
473 size_t size_change_duration_blocks,
474 size_t num_render_channels,
475 Aec3Optimization optimization,
476 ApmDataDumper* data_dumper)
477 : data_dumper_(data_dumper),
478 fft_(),
479 optimization_(optimization),
480 num_render_channels_(num_render_channels),
481 max_size_partitions_(max_size_partitions),
482 size_change_duration_blocks_(
483 static_cast<int>(size_change_duration_blocks)),
484 current_size_partitions_(initial_size_partitions),
485 target_size_partitions_(initial_size_partitions),
486 old_target_size_partitions_(initial_size_partitions),
487 H_(max_size_partitions_, std::vector<FftData>(num_render_channels_)) {
488 RTC_DCHECK(data_dumper_);
489 RTC_DCHECK_GE(max_size_partitions, initial_size_partitions);
490
491 RTC_DCHECK_LT(0, size_change_duration_blocks_);
492 one_by_size_change_duration_blocks_ = 1.f / size_change_duration_blocks_;
493
494 ZeroFilter(0, max_size_partitions_, &H_);
495
496 SetSizePartitions(current_size_partitions_, true);
497 }
498
499 AdaptiveFirFilter::~AdaptiveFirFilter() = default;
500
HandleEchoPathChange()501 void AdaptiveFirFilter::HandleEchoPathChange() {
502 // TODO(peah): Check the value and purpose of the code below.
503 ZeroFilter(current_size_partitions_, max_size_partitions_, &H_);
504 }
505
SetSizePartitions(size_t size,bool immediate_effect)506 void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) {
507 RTC_DCHECK_EQ(max_size_partitions_, H_.capacity());
508 RTC_DCHECK_LE(size, max_size_partitions_);
509
510 target_size_partitions_ = std::min(max_size_partitions_, size);
511 if (immediate_effect) {
512 size_t old_size_partitions_ = current_size_partitions_;
513 current_size_partitions_ = old_target_size_partitions_ =
514 target_size_partitions_;
515 ZeroFilter(old_size_partitions_, current_size_partitions_, &H_);
516
517 partition_to_constrain_ =
518 std::min(partition_to_constrain_, current_size_partitions_ - 1);
519 size_change_counter_ = 0;
520 } else {
521 size_change_counter_ = size_change_duration_blocks_;
522 }
523 }
524
UpdateSize()525 void AdaptiveFirFilter::UpdateSize() {
526 RTC_DCHECK_GE(size_change_duration_blocks_, size_change_counter_);
527 size_t old_size_partitions_ = current_size_partitions_;
528 if (size_change_counter_ > 0) {
529 --size_change_counter_;
530
531 auto average = [](float from, float to, float from_weight) {
532 return from * from_weight + to * (1.f - from_weight);
533 };
534
535 float change_factor =
536 size_change_counter_ * one_by_size_change_duration_blocks_;
537
538 current_size_partitions_ = average(old_target_size_partitions_,
539 target_size_partitions_, change_factor);
540
541 partition_to_constrain_ =
542 std::min(partition_to_constrain_, current_size_partitions_ - 1);
543 } else {
544 current_size_partitions_ = old_target_size_partitions_ =
545 target_size_partitions_;
546 }
547 ZeroFilter(old_size_partitions_, current_size_partitions_, &H_);
548 RTC_DCHECK_LE(0, size_change_counter_);
549 }
550
Filter(const RenderBuffer & render_buffer,FftData * S) const551 void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer,
552 FftData* S) const {
553 RTC_DCHECK(S);
554 switch (optimization_) {
555 #if defined(WEBRTC_ARCH_X86_FAMILY)
556 case Aec3Optimization::kSse2:
557 aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S);
558 break;
559 #endif
560 #if defined(WEBRTC_HAS_NEON)
561 case Aec3Optimization::kNeon:
562 aec3::ApplyFilter_Neon(render_buffer, current_size_partitions_, H_, S);
563 break;
564 #endif
565 default:
566 aec3::ApplyFilter(render_buffer, current_size_partitions_, H_, S);
567 }
568 }
569
Adapt(const RenderBuffer & render_buffer,const FftData & G)570 void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
571 const FftData& G) {
572 // Adapt the filter and update the filter size.
573 AdaptAndUpdateSize(render_buffer, G);
574
575 // Constrain the filter partitions in a cyclic manner.
576 Constrain();
577 }
578
Adapt(const RenderBuffer & render_buffer,const FftData & G,std::vector<float> * impulse_response)579 void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
580 const FftData& G,
581 std::vector<float>* impulse_response) {
582 // Adapt the filter and update the filter size.
583 AdaptAndUpdateSize(render_buffer, G);
584
585 // Constrain the filter partitions in a cyclic manner.
586 ConstrainAndUpdateImpulseResponse(impulse_response);
587 }
588
ComputeFrequencyResponse(std::vector<std::array<float,kFftLengthBy2Plus1>> * H2) const589 void AdaptiveFirFilter::ComputeFrequencyResponse(
590 std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) const {
591 RTC_DCHECK_GE(max_size_partitions_, H2->capacity());
592
593 H2->resize(current_size_partitions_);
594
595 switch (optimization_) {
596 #if defined(WEBRTC_ARCH_X86_FAMILY)
597 case Aec3Optimization::kSse2:
598 aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2);
599 break;
600 #endif
601 #if defined(WEBRTC_HAS_NEON)
602 case Aec3Optimization::kNeon:
603 aec3::ComputeFrequencyResponse_Neon(current_size_partitions_, H_, H2);
604 break;
605 #endif
606 default:
607 aec3::ComputeFrequencyResponse(current_size_partitions_, H_, H2);
608 }
609 }
610
AdaptAndUpdateSize(const RenderBuffer & render_buffer,const FftData & G)611 void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer,
612 const FftData& G) {
613 // Update the filter size if needed.
614 UpdateSize();
615
616 // Adapt the filter.
617 switch (optimization_) {
618 #if defined(WEBRTC_ARCH_X86_FAMILY)
619 case Aec3Optimization::kSse2:
620 aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_,
621 &H_);
622 break;
623 #endif
624 #if defined(WEBRTC_HAS_NEON)
625 case Aec3Optimization::kNeon:
626 aec3::AdaptPartitions_Neon(render_buffer, G, current_size_partitions_,
627 &H_);
628 break;
629 #endif
630 default:
631 aec3::AdaptPartitions(render_buffer, G, current_size_partitions_, &H_);
632 }
633 }
634
635 // Constrains the partition of the frequency domain filter to be limited in
636 // time via setting the relevant time-domain coefficients to zero and updates
637 // the corresponding values in an externally stored impulse response estimate.
ConstrainAndUpdateImpulseResponse(std::vector<float> * impulse_response)638 void AdaptiveFirFilter::ConstrainAndUpdateImpulseResponse(
639 std::vector<float>* impulse_response) {
640 RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_),
641 impulse_response->capacity());
642 impulse_response->resize(GetTimeDomainLength(current_size_partitions_));
643 std::array<float, kFftLength> h;
644 impulse_response->resize(GetTimeDomainLength(current_size_partitions_));
645 std::fill(
646 impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2,
647 impulse_response->begin() + (partition_to_constrain_ + 1) * kFftLengthBy2,
648 0.f);
649
650 for (size_t ch = 0; ch < num_render_channels_; ++ch) {
651 fft_.Ifft(H_[partition_to_constrain_][ch], &h);
652
653 static constexpr float kScale = 1.0f / kFftLengthBy2;
654 std::for_each(h.begin(), h.begin() + kFftLengthBy2,
655 [](float& a) { a *= kScale; });
656 std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
657
658 if (ch == 0) {
659 std::copy(
660 h.begin(), h.begin() + kFftLengthBy2,
661 impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2);
662 } else {
663 for (size_t k = 0, j = partition_to_constrain_ * kFftLengthBy2;
664 k < kFftLengthBy2; ++k, ++j) {
665 if (fabsf((*impulse_response)[j]) < fabsf(h[k])) {
666 (*impulse_response)[j] = h[k];
667 }
668 }
669 }
670
671 fft_.Fft(&h, &H_[partition_to_constrain_][ch]);
672 }
673
674 partition_to_constrain_ =
675 partition_to_constrain_ < (current_size_partitions_ - 1)
676 ? partition_to_constrain_ + 1
677 : 0;
678 }
679
680 // Constrains the a partiton of the frequency domain filter to be limited in
681 // time via setting the relevant time-domain coefficients to zero.
Constrain()682 void AdaptiveFirFilter::Constrain() {
683 std::array<float, kFftLength> h;
684 for (size_t ch = 0; ch < num_render_channels_; ++ch) {
685 fft_.Ifft(H_[partition_to_constrain_][ch], &h);
686
687 static constexpr float kScale = 1.0f / kFftLengthBy2;
688 std::for_each(h.begin(), h.begin() + kFftLengthBy2,
689 [](float& a) { a *= kScale; });
690 std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
691
692 fft_.Fft(&h, &H_[partition_to_constrain_][ch]);
693 }
694
695 partition_to_constrain_ =
696 partition_to_constrain_ < (current_size_partitions_ - 1)
697 ? partition_to_constrain_ + 1
698 : 0;
699 }
700
ScaleFilter(float factor)701 void AdaptiveFirFilter::ScaleFilter(float factor) {
702 for (auto& H_p : H_) {
703 for (auto& H_p_ch : H_p) {
704 for (auto& re : H_p_ch.re) {
705 re *= factor;
706 }
707 for (auto& im : H_p_ch.im) {
708 im *= factor;
709 }
710 }
711 }
712 }
713
714 // Set the filter coefficients.
SetFilter(size_t num_partitions,const std::vector<std::vector<FftData>> & H)715 void AdaptiveFirFilter::SetFilter(size_t num_partitions,
716 const std::vector<std::vector<FftData>>& H) {
717 const size_t min_num_partitions =
718 std::min(current_size_partitions_, num_partitions);
719 for (size_t p = 0; p < min_num_partitions; ++p) {
720 RTC_DCHECK_EQ(H_[p].size(), H[p].size());
721 RTC_DCHECK_EQ(num_render_channels_, H_[p].size());
722
723 for (size_t ch = 0; ch < num_render_channels_; ++ch) {
724 std::copy(H[p][ch].re.begin(), H[p][ch].re.end(), H_[p][ch].re.begin());
725 std::copy(H[p][ch].im.begin(), H[p][ch].im.end(), H_[p][ch].im.begin());
726 }
727 }
728 }
729
730 } // namespace webrtc
731