• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
12 
13 #include <stdlib.h>
14 
15 #include <algorithm>
16 #include <cmath>
17 #include <cstddef>
18 #include <numeric>
19 
20 #include "modules/audio_processing/agc2/rnn_vad/common.h"
21 #include "rtc_base/checks.h"
22 
23 namespace webrtc {
24 namespace rnn_vad {
25 namespace {
26 
27 // Converts a lag to an inverted lag (only for 24kHz).
GetInvertedLag(size_t lag)28 size_t GetInvertedLag(size_t lag) {
29   RTC_DCHECK_LE(lag, kMaxPitch24kHz);
30   return kMaxPitch24kHz - lag;
31 }
32 
ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,size_t inv_lag,size_t max_pitch_period)33 float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
34                                   size_t inv_lag,
35                                   size_t max_pitch_period) {
36   RTC_DCHECK_LT(inv_lag, pitch_buf.size());
37   RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
38   RTC_DCHECK_LE(inv_lag, max_pitch_period);
39   // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
40   return std::inner_product(pitch_buf.begin() + max_pitch_period,
41                             pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
42 }
43 
// Returns a pseudo-interpolation offset in {-1, 0, +1} for an estimated pitch
// period |lag|, given the auto-correlation coefficients at |lag| - 1
// (|prev_auto_corr|), at |lag| (|lag_auto_corr|) and at |lag| + 1
// (|next_auto_corr|). +1 is returned when the next coefficient is clearly
// dominant, -1 when the previous one is, and 0 otherwise. |lag| itself does not
// affect the result.
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it
// is relevant only if the spectral analysis works at a sample rate that is
// twice as that of the pitch buffer (not so important instead for the estimated
// pitch period feature fed into the RNN).
int GetPitchPseudoInterpolationOffset(size_t lag,
                                      float prev_auto_corr,
                                      float lag_auto_corr,
                                      float next_auto_corr) {
  constexpr float kDominanceRatio = 0.7f;
  // The next coefficient wins first when both conditions hold.
  if ((next_auto_corr - prev_auto_corr) >
      kDominanceRatio * (lag_auto_corr - prev_auto_corr)) {
    return 1;
  }
  if ((prev_auto_corr - next_auto_corr) >
      kDominanceRatio * (lag_auto_corr - next_auto_corr)) {
    return -1;
  }
  return 0;
}
68 
69 // Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The
70 // output sample rate is twice as that of |lag|.
PitchPseudoInterpolationLagPitchBuf(size_t lag,rtc::ArrayView<const float,kBufSize24kHz> pitch_buf)71 size_t PitchPseudoInterpolationLagPitchBuf(
72     size_t lag,
73     rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
74   int offset = 0;
75   // Cannot apply pseudo-interpolation at the boundaries.
76   if (lag > 0 && lag < kMaxPitch24kHz) {
77     offset = GetPitchPseudoInterpolationOffset(
78         lag,
79         ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
80                                     kMaxPitch24kHz),
81         ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
82                                     kMaxPitch24kHz),
83         ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
84                                     kMaxPitch24kHz));
85   }
86   return 2 * lag + offset;
87 }
88 
89 // Refines a pitch period |inv_lag| encoded as inverted lag with
90 // pseudo-interpolation. The output sample rate is twice as that of
91 // |inv_lag|.
PitchPseudoInterpolationInvLagAutoCorr(size_t inv_lag,rtc::ArrayView<const float> auto_corr)92 size_t PitchPseudoInterpolationInvLagAutoCorr(
93     size_t inv_lag,
94     rtc::ArrayView<const float> auto_corr) {
95   int offset = 0;
96   // Cannot apply pseudo-interpolation at the boundaries.
97   if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) {
98     offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1],
99                                                auto_corr[inv_lag],
100                                                auto_corr[inv_lag - 1]);
101   }
102   // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
103   // be subtracted since |inv_lag| is an inverted lag but offset is a lag.
104   return 2 * inv_lag + offset;
105 }
106 
// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when
// looking for sub-harmonics.
// The values have been chosen to serve the following algorithm. Given the
// initial pitch period T, we examine whether one of its harmonics is the true
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
// these harmonics, in addition to the pitch gain of itself, we choose one
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
// gains). The multiplier n is chosen so that n*T/k is used only one time over
// all k. When for example k = 4, we should also expect a peak at 3*T/4. When
// k = 8 instead we don't want to look at 2*T/8, since we have already checked
// T/4 before. Instead, we look at T*3/8.
// The array can be generated in Python as follows:
//   from fractions import Fraction
//   # Smallest positive integer not in X.
//   def mex(X):
//     for i in range(1, int(max(X)+2)):
//       if i not in X:
//         return i
//   # Visited multiples of the period.
//   S = {1}
//   for n in range(2, 16):
//     sn = mex({n * i for i in S} | {1})
//     S = S | {Fraction(1, n), Fraction(sn, n)}
//     print(sn, end=', ')
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
    {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};

// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
// a sample rate of 24 kHz, indexed by k - 2 where k in {2, ..., 15} is the
// pitch period ratio. Computed as [5*k*k for k in range(2, 16)].
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
    {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
// Both tables are indexed with the same k - 2 offset.
static_assert(kInitialPitchPeriodThresholds.size() ==
                  kSubHarmonicMultipliers.size(),
              "");
138 
139 }  // namespace
140 
Decimate2x(rtc::ArrayView<const float,kBufSize24kHz> src,rtc::ArrayView<float,kBufSize12kHz> dst)141 void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
142                 rtc::ArrayView<float, kBufSize12kHz> dst) {
143   // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
144   static_assert(2 * dst.size() == src.size(), "");
145   for (size_t i = 0; i < dst.size(); ++i) {
146     dst[i] = src[2 * i];
147   }
148 }
149 
ComputePitchGainThreshold(int candidate_pitch_period,int pitch_period_ratio,int initial_pitch_period,float initial_pitch_gain,int prev_pitch_period,float prev_pitch_gain)150 float ComputePitchGainThreshold(int candidate_pitch_period,
151                                 int pitch_period_ratio,
152                                 int initial_pitch_period,
153                                 float initial_pitch_gain,
154                                 int prev_pitch_period,
155                                 float prev_pitch_gain) {
156   // Map arguments to more compact aliases.
157   const int& t1 = candidate_pitch_period;
158   const int& k = pitch_period_ratio;
159   const int& t0 = initial_pitch_period;
160   const float& g0 = initial_pitch_gain;
161   const int& t_prev = prev_pitch_period;
162   const float& g_prev = prev_pitch_gain;
163 
164   // Validate input.
165   RTC_DCHECK_GE(t1, 0);
166   RTC_DCHECK_GE(k, 2);
167   RTC_DCHECK_GE(t0, 0);
168   RTC_DCHECK_GE(t_prev, 0);
169 
170   // Compute a term that lowers the threshold when |t1| is close to the last
171   // estimated period |t_prev| - i.e., pitch tracking.
172   float lower_threshold_term = 0;
173   if (abs(t1 - t_prev) <= 1) {
174     // The candidate pitch period is within 1 sample from the previous one.
175     // Make the candidate at |t1| very easy to be accepted.
176     lower_threshold_term = g_prev;
177   } else if (abs(t1 - t_prev) == 2 &&
178              t0 > kInitialPitchPeriodThresholds[k - 2]) {
179     // The candidate pitch period is 2 samples far from the previous one and the
180     // period |t0| (from which |t1| has been derived) is greater than a
181     // threshold. Make |t1| easy to be accepted.
182     lower_threshold_term = 0.5f * g_prev;
183   }
184   // Set the threshold based on the gain of the initial estimate |t0|. Also
185   // reduce the chance of false positives caused by a bias towards high
186   // frequencies (originating from short-term correlations).
187   float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
188   if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
189     // High frequency.
190     threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
191   } else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
192     // Even higher frequency.
193     threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
194   }
195   return threshold;
196 }
197 
ComputeSlidingFrameSquareEnergies(rtc::ArrayView<const float,kBufSize24kHz> pitch_buf,rtc::ArrayView<float,kMaxPitch24kHz+1> yy_values)198 void ComputeSlidingFrameSquareEnergies(
199     rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
200     rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
201   float yy =
202       ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
203   yy_values[0] = yy;
204   for (size_t i = 1; i < yy_values.size(); ++i) {
205     RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
206     RTC_DCHECK_LE(i, kMaxPitch24kHz);
207     const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
208     const float new_coeff = pitch_buf[kMaxPitch24kHz - i];
209     yy -= old_coeff * old_coeff;
210     yy += new_coeff * new_coeff;
211     yy = std::max(0.f, yy);
212     yy_values[i] = yy;
213   }
214 }
215 
FindBestPitchPeriods(rtc::ArrayView<const float> auto_corr,rtc::ArrayView<const float> pitch_buf,size_t max_pitch_period)216 std::array<size_t, 2> FindBestPitchPeriods(
217     rtc::ArrayView<const float> auto_corr,
218     rtc::ArrayView<const float> pitch_buf,
219     size_t max_pitch_period) {
220   // Stores a pitch candidate period and strength information.
221   struct PitchCandidate {
222     // Pitch period encoded as inverted lag.
223     size_t period_inverted_lag = 0;
224     // Pitch strength encoded as a ratio.
225     float strength_numerator = -1.f;
226     float strength_denominator = 0.f;
227     // Compare the strength of two pitch candidates.
228     bool HasStrongerPitchThan(const PitchCandidate& b) const {
229       // Comparing the numerator/denominator ratios without using divisions.
230       return strength_numerator * b.strength_denominator >
231              b.strength_numerator * strength_denominator;
232     }
233   };
234 
235   RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
236   RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
237   const size_t frame_size = pitch_buf.size() - max_pitch_period;
238   // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
239   float yy =
240       std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
241                          pitch_buf.begin(), 1.f);
242   // Search best and second best pitches by looking at the scaled
243   // auto-correlation.
244   PitchCandidate candidate;
245   PitchCandidate best;
246   PitchCandidate second_best;
247   second_best.period_inverted_lag = 1;
248   for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
249     // A pitch candidate must have positive correlation.
250     if (auto_corr[inv_lag] > 0) {
251       candidate.period_inverted_lag = inv_lag;
252       candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag];
253       candidate.strength_denominator = yy;
254       if (candidate.HasStrongerPitchThan(second_best)) {
255         if (candidate.HasStrongerPitchThan(best)) {
256           second_best = best;
257           best = candidate;
258         } else {
259           second_best = candidate;
260         }
261       }
262     }
263     // Update |squared_energy_y| for the next inverted lag.
264     const float old_coeff = pitch_buf[inv_lag];
265     const float new_coeff = pitch_buf[inv_lag + frame_size];
266     yy -= old_coeff * old_coeff;
267     yy += new_coeff * new_coeff;
268     yy = std::max(0.f, yy);
269   }
270   return {{best.period_inverted_lag, second_best.period_inverted_lag}};
271 }
272 
RefinePitchPeriod48kHz(rtc::ArrayView<const float,kBufSize24kHz> pitch_buf,rtc::ArrayView<const size_t,2> inv_lags)273 size_t RefinePitchPeriod48kHz(
274     rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
275     rtc::ArrayView<const size_t, 2> inv_lags) {
276   // Compute the auto-correlation terms only for neighbors of the given pitch
277   // candidates (similar to what is done in ComputePitchAutoCorrelation(), but
278   // for a few lag values).
279   std::array<float, kNumInvertedLags24kHz> auto_corr;
280   auto_corr.fill(0.f);  // Zeros become ignored lags in FindBestPitchPeriods().
281   auto is_neighbor = [](size_t i, size_t j) {
282     return ((i > j) ? (i - j) : (j - i)) <= 2;
283   };
284   for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
285     if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1]))
286       auto_corr[inv_lag] =
287           ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz);
288   }
289   // Find best pitch at 24 kHz.
290   const auto pitch_candidates_inv_lags = FindBestPitchPeriods(
291       {auto_corr.data(), auto_corr.size()},
292       {pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz);
293   const auto inv_lag = pitch_candidates_inv_lags[0];  // Refine the best.
294   // Pseudo-interpolation.
295   return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr);
296 }
297 
CheckLowerPitchPeriodsAndComputePitchGain(rtc::ArrayView<const float,kBufSize24kHz> pitch_buf,int initial_pitch_period_48kHz,PitchInfo prev_pitch_48kHz)298 PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
299     rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
300     int initial_pitch_period_48kHz,
301     PitchInfo prev_pitch_48kHz) {
302   RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
303   RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
304   // Stores information for a refined pitch candidate.
305   struct RefinedPitchCandidate {
306     RefinedPitchCandidate() {}
307     RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
308         : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
309     int period_24kHz;
310     // Pitch strength information.
311     float gain;
312     // Additional pitch strength information used for the final estimation of
313     // pitch gain.
314     float xy;  // Cross-correlation.
315     float yy;  // Auto-correlation.
316   };
317 
318   // Initialize.
319   std::array<float, kMaxPitch24kHz + 1> yy_values;
320   ComputeSlidingFrameSquareEnergies(pitch_buf,
321                                     {yy_values.data(), yy_values.size()});
322   const float xx = yy_values[0];
323   // Helper lambdas.
324   const auto pitch_gain = [](float xy, float yy, float xx) {
325     RTC_DCHECK_LE(0.f, xx * yy);
326     return xy / std::sqrt(1.f + xx * yy);
327   };
328   // Initial pitch candidate gain.
329   RefinedPitchCandidate best_pitch;
330   best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
331                                      static_cast<int>(kMaxPitch24kHz - 1));
332   best_pitch.xy = ComputeAutoCorrelationCoeff(
333       pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
334   best_pitch.yy = yy_values[best_pitch.period_24kHz];
335   best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
336 
337   // Store the initial pitch period information.
338   const size_t initial_pitch_period = best_pitch.period_24kHz;
339   const float initial_pitch_gain = best_pitch.gain;
340 
341   // Given the initial pitch estimation, check lower periods (i.e., harmonics).
342   const auto alternative_period = [](int period, int k, int n) -> int {
343     RTC_DCHECK_GT(k, 0);
344     return (2 * n * period + k) / (2 * k);  // Same as round(n*period/k).
345   };
346   for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
347        ++k) {
348     int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
349     if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
350       break;
351     }
352     // When looking at |candidate_pitch_period|, we also look at one of its
353     // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
354     // |k| == 2 is a special case since |candidate_pitch_secondary_period| might
355     // be greater than the maximum pitch period.
356     int candidate_pitch_secondary_period = alternative_period(
357         initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
358     RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
359     if (k == 2 &&
360         candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
361       candidate_pitch_secondary_period = initial_pitch_period;
362     }
363     RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
364         << "The lower pitch period and the additional sub-harmonic must not "
365            "coincide.";
366     // Compute an auto-correlation score for the primary pitch candidate
367     // |candidate_pitch_period| by also looking at its possible sub-harmonic
368     // |candidate_pitch_secondary_period|.
369     float xy_primary_period = ComputeAutoCorrelationCoeff(
370         pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
371     float xy_secondary_period = ComputeAutoCorrelationCoeff(
372         pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
373         kMaxPitch24kHz);
374     float xy = 0.5f * (xy_primary_period + xy_secondary_period);
375     float yy = 0.5f * (yy_values[candidate_pitch_period] +
376                        yy_values[candidate_pitch_secondary_period]);
377     float candidate_pitch_gain = pitch_gain(xy, yy, xx);
378 
379     // Maybe update best period.
380     float threshold = ComputePitchGainThreshold(
381         candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain,
382         prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain);
383     if (candidate_pitch_gain > threshold) {
384       best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy};
385     }
386   }
387 
388   // Final pitch gain and period.
389   best_pitch.xy = std::max(0.f, best_pitch.xy);
390   RTC_DCHECK_LE(0.f, best_pitch.yy);
391   float final_pitch_gain = (best_pitch.yy <= best_pitch.xy)
392                                ? 1.f
393                                : best_pitch.xy / (best_pitch.yy + 1.f);
394   final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
395   int final_pitch_period_48kHz = std::max(
396       kMinPitch48kHz,
397       PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
398 
399   return {final_pitch_period_48kHz, final_pitch_gain};
400 }
401 
402 }  // namespace rnn_vad
403 }  // namespace webrtc
404