• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018 Gregor Richards
3  * Copyright (c) 2017 Mozilla
4  * Copyright (c) 2005-2009 Xiph.Org Foundation
5  * Copyright (c) 2007-2008 CSIRO
6  * Copyright (c) 2008-2011 Octasic Inc.
7  * Copyright (c) Jean-Marc Valin
8  * Copyright (c) 2019 Paul B Mahol
9  *
10  * Redistribution and use in source and binary forms, with or without
11  * modification, are permitted provided that the following conditions
12  * are met:
13  *
14  * - Redistributions of source code must retain the above copyright
15  *   notice, this list of conditions and the following disclaimer.
16  *
17  * - Redistributions in binary form must reproduce the above copyright
18  *   notice, this list of conditions and the following disclaimer in the
19  *   documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
25  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33 
34 #include <float.h>
35 
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
42 #include "avfilter.h"
43 #include "audio.h"
44 #include "filters.h"
45 #include "formats.h"
46 
/* Frame geometry: FRAME_SIZE = 120 << 2 = 480 new samples per frame; at the
 * 48 kHz rate negotiated in query_formats() that is 10 ms. */
47 #define FRAME_SIZE_SHIFT 2
48 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
/* Analysis/synthesis window: previous frame followed by the current frame. */
49 #define WINDOW_SIZE (2*FRAME_SIZE)
/* Spectrum bins kept from a WINDOW_SIZE-point FFT (bins 0 .. WINDOW_SIZE/2). */
50 #define FREQ_SIZE (FRAME_SIZE + 1)
51

/* Pitch search limits, in samples. PITCH_BUF_SIZE is the length of the
 * pitch_buf/pitch_enh_buf histories in DenoiseState (longest period plus
 * one pitch frame). */
52 #define PITCH_MIN_PERIOD 60
53 #define PITCH_MAX_PERIOD 768
54 #define PITCH_FRAME_SIZE 960
55 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
56

/* Evaluates x twice: never pass an expression with side effects. */
57 #define SQUARE(x) ((x)*(x))
58

/* Number of spectral bands; eband5ms[] holds the NB_BANDS band edges. */
59 #define NB_BANDS 22
60

/* Rows of cepstral history in DenoiseState.cepstral_mem. */
61 #define CEPS_MEM 8
62 #define NB_DELTA_CEPS 6
63

/* Per-frame feature count: 22 + 3 * 6 + 2 = 42. Presumably the length of the
 * network input vector — confirm against compute_frame_features(). */
64 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
65

/* Model files store integer weights (parsed with "%d"); presumably they are
 * multiplied by this factor when the network is evaluated — TODO confirm. */
66 #define WEIGHTS_SCALE (1.f/256)
67

/* Largest layer dimension accepted by the model parser (INPUT_VAL rejects
 * values above 128). */
68 #define MAX_NEURONS 128
69

/* Internal activation ids stored in DenseLayer/GRULayer.activation. */
70 #define ACTIVATION_TANH    0
71 #define ACTIVATION_SIGMOID 1
72 #define ACTIVATION_RELU    2
73

/* Float stand-in for a Q15 fixed-point "1.0" (assumed to be kept from the
 * fixed-point code this was derived from). */
74 #define Q15ONE 1.0f
75 
/* Fully connected layer as loaded by rnnoise_model_from_file(). */
76 typedef struct DenseLayer {
77     const float *bias;           /* nb_neurons entries */
78     const float *input_weights;  /* nb_inputs * nb_neurons entries, in file order */
79     int nb_inputs;
80     int nb_neurons;
81     int activation;              /* one of ACTIVATION_* */
82 } DenseLayer;
83 
/* Gated recurrent layer as loaded by rnnoise_model_from_file(). Every neuron
 * has 3 weight sets (presumably one per GRU gate — confirm with the GRU code). */
84 typedef struct GRULayer {
85     const float *bias;               /* 3 * nb_neurons entries */
86     const float *input_weights;      /* INPUT_ARRAY3 layout [neuron][3][input], input dim padded to a multiple of 4, padding zeroed */
87     const float *recurrent_weights;  /* same layout with nb_neurons inputs */
88     int nb_inputs;
89     int nb_neurons;
90     int activation;                  /* one of ACTIVATION_* */
91 } GRULayer;
92 
/* Complete denoising network: six layers, each paired with its output size
 * (every *_size field mirrors that layer's nb_neurons). The model and all of
 * its layers are released together by rnnoise_model_free(). */
93 typedef struct RNNModel {
94     int input_dense_size;
95     const DenseLayer *input_dense;
96 
97     int vad_gru_size;
98     const GRULayer *vad_gru;
99 
100     int noise_gru_size;
101     const GRULayer *noise_gru;
102 
103     int denoise_gru_size;
104     const GRULayer *denoise_gru;
105 
106     int denoise_output_size;
107     const DenseLayer *denoise_output;
108 
109     int vad_output_size;    /* required to be 1 by the model parser */
110     const DenseLayer *vad_output;
111 } RNNModel;
112 
/* Per-channel recurrent state for one model. */
113 typedef struct RNNState {
114     float *vad_gru_state;     /* vad_gru_size floats; allocation rounded up to a multiple of 16 */
115     float *noise_gru_state;   /* noise_gru_size floats, allocated likewise */
116     float *denoise_gru_state; /* denoise_gru_size floats, allocated likewise */
117     RNNModel *model;          /* set from AudioRNNContext.model[0] in config_input() */
118 } RNNState;
119 
/* Per-channel denoiser state; AudioRNNContext.st holds one per channel. */
120 typedef struct DenoiseState {
121     float analysis_mem[FRAME_SIZE];   /* previous input frame: first half of the analysis window */
122     float cepstral_mem[CEPS_MEM][NB_BANDS];
123     int memid;                        /* presumably the current row of cepstral_mem — confirm */
124     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE]; /* overlap-add tail kept by frame_synthesis() */
125     float pitch_buf[PITCH_BUF_SIZE];
126     float pitch_enh_buf[PITCH_BUF_SIZE];
127     float last_gain;                  /* presumably prev_gain for remove_doubling() — confirm */
128     int last_period;                  /* presumably prev_period for remove_doubling() — confirm */
129     float mem_hp_x[2];                /* presumably biquad() state for an input filter — confirm */
130     float lastg[NB_BANDS];
131     float history[FRAME_SIZE];        /* dry signal mixed into the output by frame_synthesis() */
132     RNNState rnn[2];                  /* config_input() initializes rnn[0] only */
133     AVTXContext *tx, *txi;            /* forward / inverse WINDOW_SIZE-point complex FFT */
134     av_tx_fn tx_fn, txi_fn;
135 } DenoiseState;
136 
/* Filter private context. */
137 typedef struct AudioRNNContext {
138     const AVClass *class;
139
140     char *model_name;
141     float mix;              /* wet gain in frame_synthesis(); the dry gain is 1 - max(mix, 0) */
142
143     int channels;           /* input channel count, set in config_input() */
144     DenoiseState *st;       /* array of `channels` per-channel states */
145
146     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];  /* analysis/synthesis window */
147     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)]; /* rows consumed by dct(), padded to a multiple of 4 */
148
149     RNNModel *model[2];     /* model[0] is the model used by config_input(); the role of model[1] is not visible here */
150
151     AVFloatDSPContext *fdsp;
152 } AudioRNNContext;
153 
/* Activation codes as written in model files. INPUT_ACTIVATION() in
 * rnnoise_model_from_file() maps them to ACTIVATION_*, treating any other
 * value as tanh. */
154 #define F_ACTIVATION_TANH       0
155 #define F_ACTIVATION_SIGMOID    1
156 #define F_ACTIVATION_RELU       2
157 
/**
 * Free a model produced by rnnoise_model_from_file(), together with every
 * layer and weight array it owns.
 *
 * Layers that are still NULL are skipped. This lets the parser use this
 * function to clean up a partially loaded model. Passing NULL is a no-op.
 * All of these buffers come from av_calloc(), so they are released with
 * av_free() only, never with libc free().
 *
 * @param model model to free, may be NULL
 */
static void rnnoise_model_free(RNNModel *model)
{
/* Free a DenseLayer member of model and its weight/bias arrays. */
#define FREE_DENSE(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)
/* Free a GRULayer member of model and its weight/bias arrays. */
#define FREE_GRU(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->recurrent_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)

    if (!model)
        return;
    FREE_DENSE(input_dense);
    FREE_GRU(vad_gru);
    FREE_GRU(noise_gru);
    FREE_GRU(denoise_gru);
    FREE_DENSE(denoise_output);
    FREE_DENSE(vad_output);
    av_free(model);
}
187 
/**
 * Load an "rnnoise-nu" text model from an open file.
 *
 * Expected layout:
 * - a header line "rnnoise-nu model file version 1";
 * - the layers input_dense, vad_gru, noise_gru, denoise_gru, denoise_output
 *   and vad_output, in that order.
 *
 * Each layer starts with three integers (nb_inputs, nb_neurons and an
 * activation code), each in [0, MAX_NEURONS]. Its integer weight and bias
 * arrays follow. After each section the parser skips to the end of the
 * current line. GRU matrices are stored in the padded INPUT_ARRAY3 layout.
 * vad_output must have exactly one neuron.
 *
 * @param f   file positioned at the header line
 * @param rnn on success receives the new model. The caller owns it and must
 *            release it with rnnoise_model_free(). On failure it is left
 *            untouched.
 * @return 0 on success, AVERROR_INVALIDDATA for a missing or unsupported
 *         header, AVERROR(EINVAL) for malformed or out-of-range data, or
 *         AVERROR(ENOMEM) on allocation failure. Every error path frees
 *         whatever was allocated so far.
 */
static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
{
    RNNModel *ret = NULL;
    DenseLayer *input_dense;
    GRULayer *vad_gru;
    GRULayer *noise_gru;
    GRULayer *denoise_gru;
    DenseLayer *denoise_output;
    DenseLayer *vad_output;
    int in;

    if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
        return AVERROR_INVALIDDATA;

    ret = av_calloc(1, sizeof(RNNModel));
    if (!ret)
        return AVERROR(ENOMEM);

/* Allocate one zeroed layer and attach it to the model immediately, so any
 * later failure releases it through rnnoise_model_free(). The do/while(0)
 * wrapper makes the macro behave as a single statement. */
#define ALLOC_LAYER(type, name) do { \
    name = av_calloc(1, sizeof(type)); \
    if (!name) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    ret->name = name; \
    } while (0)

    ALLOC_LAYER(DenseLayer, input_dense);
    ALLOC_LAYER(GRULayer, vad_gru);
    ALLOC_LAYER(GRULayer, noise_gru);
    ALLOC_LAYER(GRULayer, denoise_gru);
    ALLOC_LAYER(DenseLayer, denoise_output);
    ALLOC_LAYER(DenseLayer, vad_output);

/* Read one integer in [0, MAX_NEURONS] (layer sizes and activation codes). */
#define INPUT_VAL(name) do { \
    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > MAX_NEURONS) { \
        rnnoise_model_free(ret); \
        return AVERROR(EINVAL); \
    } \
    name = in; \
    } while (0)

/* Read a file activation code and map it to ACTIVATION_*; unknown -> tanh. */
#define INPUT_ACTIVATION(name) do { \
    int activation; \
    INPUT_VAL(activation); \
    switch (activation) { \
    case F_ACTIVATION_SIGMOID: \
        name = ACTIVATION_SIGMOID; \
        break; \
    case F_ACTIVATION_RELU: \
        name = ACTIVATION_RELU; \
        break; \
    default: \
        name = ACTIVATION_TANH; \
    } \
    } while (0)

/* Read len integers into a newly allocated float array stored in name. */
#define INPUT_ARRAY(name, len) do { \
    float *values = av_calloc((len), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int i = 0; i < (len); i++) { \
        if (fscanf(f, "%d", &in) != 1) { \
            rnnoise_model_free(ret); \
            return AVERROR(EINVAL); \
        } \
        values[i] = in; \
    } \
    } while (0)

/* Read a len0 x len2 x len1 matrix (file order) and store it transposed as
 * [len1][len2][len0], with the len0 dimension padded to a multiple of 4.
 * Padding entries stay zero from av_calloc(). */
#define INPUT_ARRAY3(name, len0, len1, len2) do { \
    float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int k = 0; k < (len0); k++) { \
        for (int i = 0; i < (len2); i++) { \
            for (int j = 0; j < (len1); j++) { \
                if (fscanf(f, "%d", &in) != 1) { \
                    rnnoise_model_free(ret); \
                    return AVERROR(EINVAL); \
                } \
                values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
            } \
        } \
    } \
    } while (0)

/* Skip the rest of the current line (or to EOF). */
#define NEW_LINE() do { \
    int c; \
    while ((c = fgetc(f)) != EOF) { \
        if (c == '\n') \
        break; \
    } \
    } while (0)

#define INPUT_DENSE(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons); \
    NEW_LINE(); \
    } while (0)

#define INPUT_GRU(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
    NEW_LINE(); \
    } while (0)

    INPUT_DENSE(input_dense);
    INPUT_GRU(vad_gru);
    INPUT_GRU(noise_gru);
    INPUT_GRU(denoise_gru);
    INPUT_DENSE(denoise_output);
    INPUT_DENSE(vad_output);

    /* The voice-activity head must produce a single probability. */
    if (vad_output->nb_neurons != 1) {
        rnnoise_model_free(ret);
        return AVERROR(EINVAL);
    }

    *rnn = ret;

    return 0;
}
330 
/**
 * Format negotiation: planar float samples, any channel count, and a fixed
 * 48 kHz sample rate.
 *
 * @return 0 on success or the first negative AVERROR from the helpers
 */
static int query_formats(AVFilterContext *ctx)
{
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP,
        AV_SAMPLE_FMT_NONE
    };
    int sample_rates[] = { 48000, -1 };
    int err;

    if ((err = ff_set_common_formats_from_list(ctx, sample_fmts)) < 0)
        return err;
    if ((err = ff_set_common_all_channel_counts(ctx)) < 0)
        return err;

    return ff_set_common_samplerates_from_list(ctx, sample_rates);
}
349 
/**
 * Set up per-channel denoiser state for the input link.
 *
 * The function:
 * - allocates the DenoiseState array the first time it runs;
 * - (re)allocates the rnn[0] GRU state buffers for the current
 *   s->model[0];
 * - creates the forward and inverse FFT contexts once per channel.
 *
 * The `if (!s->st)` / `if (!st->tx)` guards show it is meant to run again
 * on an already configured context. The previous GRU state buffers are
 * therefore freed before being replaced; overwriting them would leak. The
 * fresh buffers are zeroed, so the recurrent state restarts from zero.
 *
 * @return 0 on success or a negative AVERROR code
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;
    int ret = 0;

    s->channels = inlink->ch_layout.nb_channels;

    if (!s->st)
        s->st = av_calloc(s->channels, sizeof(DenoiseState));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];
        RNNState *rnn = &st->rnn[0];

        /* Drop buffers from an earlier configuration; av_freep() is a no-op
         * on the NULL pointers left by av_calloc() on the first run. */
        av_freep(&rnn->vad_gru_state);
        av_freep(&rnn->noise_gru_state);
        av_freep(&rnn->denoise_gru_state);

        rnn->model = s->model[0];
        rnn->vad_gru_state     = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
        rnn->noise_gru_state   = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
        rnn->denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
        if (!rnn->vad_gru_state ||
            !rnn->noise_gru_state ||
            !rnn->denoise_gru_state)
            return AVERROR(ENOMEM);
    }

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        if (!st->tx)
            ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
        if (ret < 0)
            return ret;

        if (!st->txi)
            ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
        if (ret < 0)
            return ret;
    }

    return ret;
}
392 
/**
 * Run a second-order IIR section over N samples.
 *
 * Each output is y[n] = x[n] + mem[0]. The two state words are then
 * advanced from the current input and output using the feed-forward taps
 * b[0..1] and feedback taps a[0..1]. x and y may be the same buffer,
 * because each input sample is read before its output is stored.
 *
 * @param y   output, N samples
 * @param mem filter state (2 values), updated in place
 * @param x   input, N samples
 * @param b   feed-forward taps
 * @param a   feedback taps
 * @param N   number of samples; nothing happens when N <= 0
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    for (int n = 0; n < N; n++) {
        const float in  = x[n];
        const float out = in + mem[0];

        mem[0] = mem[1] + (b[0] * in - a[0] * out);
        mem[1] = (b[1] * in - a[1] * out);
        y[n]   = out;
    }
}
406 
/* Element-count wrappers around memmove/memset/memcpy: n counts elements of
 * *dst, not bytes. The "0*((dst)-(src))" term contributes nothing at run
 * time. The pointer subtraction only compiles when dst and src point to
 * compatible types, so mismatched copies fail at build time. */
407 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
408 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
409 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
410 
/**
 * Forward FFT of one WINDOW_SIZE block of real samples.
 *
 * The real input is widened to complex (imaginary part zero) and run
 * through the complex FFT in st->tx. Only the first FREQ_SIZE bins are
 * copied to @p out: the spectrum of a real signal is conjugate-symmetric,
 * so the rest is redundant.
 *
 * @param st  per-channel state holding the forward FFT context
 * @param out output spectrum, FREQ_SIZE bins
 * @param in  input samples, WINDOW_SIZE values
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    /* st->tx is created without AV_TX_UNALIGNED, so av_tx expects aligned
     * input/output buffers. */
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    for (int i = 0; i < WINDOW_SIZE; i++) {
        x[i].re = in[i];
        x[i].im = 0;
    }

    /* For AV_TX_FLOAT_FFT the stride is the size of one complex sample. */
    st->tx_fn(st->tx, y, x, sizeof(AVComplexFloat));

    RNN_COPY(out, y, FREQ_SIZE);
}
425 
/**
 * Inverse of forward_transform(): FREQ_SIZE bins back to WINDOW_SIZE real
 * samples.
 *
 * The upper half of the spectrum is rebuilt by conjugate symmetry. The
 * complex inverse FFT in st->txi is then applied, and the real part is
 * scaled by 1 / WINDOW_SIZE, because the av_tx FFT output is not
 * normalized.
 *
 * @param st  per-channel state holding the inverse FFT context
 * @param out output samples, WINDOW_SIZE values
 * @param in  input spectrum, FREQ_SIZE bins
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    /* st->txi is created without AV_TX_UNALIGNED, so av_tx expects aligned
     * input/output buffers. */
    LOCAL_ALIGNED_32(AVComplexFloat, x, [WINDOW_SIZE]);
    LOCAL_ALIGNED_32(AVComplexFloat, y, [WINDOW_SIZE]);

    RNN_COPY(x, in, FREQ_SIZE);

    /* Mirror the conjugate of bins 1 .. WINDOW_SIZE - FREQ_SIZE. */
    for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
        x[i].re =  x[WINDOW_SIZE - i].re;
        x[i].im = -x[WINDOW_SIZE - i].im;
    }

    /* For AV_TX_FLOAT_FFT the stride is the size of one complex sample. */
    st->txi_fn(st->txi, y, x, sizeof(AVComplexFloat));

    for (int i = 0; i < WINDOW_SIZE; i++)
        out[i] = y[i].re / WINDOW_SIZE;
}
443 
/* Edges of the NB_BANDS triangular bands, in units of 1 << FRAME_SIZE_SHIFT
 * = 4 FFT bins (compute_band_energy() shifts them by FRAME_SIZE_SHIFT).
 * With a WINDOW_SIZE (960) point FFT at 48 kHz one bin is 50 Hz. One unit
 * is therefore 200 Hz, and the last edge (100) is 20 kHz, matching the
 * frequency labels below. */
444 static const uint8_t eband5ms[] = {
445 /*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
446   0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
447 };
448 
compute_band_energy(float * bandE,const AVComplexFloat * X)449 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
450 {
451     float sum[NB_BANDS] = {0};
452 
453     for (int i = 0; i < NB_BANDS - 1; i++) {
454         int band_size;
455 
456         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
457         for (int j = 0; j < band_size; j++) {
458             float tmp, frac = (float)j / band_size;
459 
460             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
461             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
462             sum[i]     += (1.f - frac) * tmp;
463             sum[i + 1] +=        frac  * tmp;
464         }
465     }
466 
467     sum[0] *= 2;
468     sum[NB_BANDS - 1] *= 2;
469 
470     for (int i = 0; i < NB_BANDS; i++)
471         bandE[i] = sum[i];
472 }
473 
compute_band_corr(float * bandE,const AVComplexFloat * X,const AVComplexFloat * P)474 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
475 {
476     float sum[NB_BANDS] = { 0 };
477 
478     for (int i = 0; i < NB_BANDS - 1; i++) {
479         int band_size;
480 
481         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
482         for (int j = 0; j < band_size; j++) {
483             float tmp, frac = (float)j / band_size;
484 
485             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
486             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
487             sum[i]     += (1 - frac) * tmp;
488             sum[i + 1] +=      frac  * tmp;
489         }
490     }
491 
492     sum[0] *= 2;
493     sum[NB_BANDS-1] *= 2;
494 
495     for (int i = 0; i < NB_BANDS; i++)
496         bandE[i] = sum[i];
497 }
498 
/**
 * Build the windowed analysis block for one frame and compute its spectrum.
 *
 * The previous frame (st->analysis_mem) and the FRAME_SIZE new samples of
 * @p in form one WINDOW_SIZE block. The block is multiplied by s->window,
 * transformed into X, and reduced to band energies Ex. The new input then
 * becomes analysis_mem for the next call.
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    float *const second_half = buf + FRAME_SIZE;

    RNN_COPY(buf, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(second_half, in, FRAME_SIZE);
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);

    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    forward_transform(st, X, buf);
    compute_band_energy(Ex, X);
}
510 
/**
 * Turn spectrum y back into FRAME_SIZE output samples and apply the
 * dry/wet mix.
 *
 * The steps are:
 * - inverse-transform y and window the result;
 * - overlap-add the tail stored by the previous call, and keep the second
 *   half of the block as the new tail;
 * - mix: out = denoised * mix + st->history * (1 - max(mix, 0)).
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    const float *dry_src = st->history;
    const float wet = s->mix;
    const float dry = 1.f - FFMAX(wet, 0.f);

    inverse_transform(st, buf, y);
    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    s->fdsp->vector_fmac_scalar(buf, st->synthesis_mem, 1.f, FRAME_SIZE);

    RNN_COPY(out, buf, FRAME_SIZE);
    RNN_COPY(st->synthesis_mem, buf + FRAME_SIZE, FRAME_SIZE);

    for (int i = 0; i < FRAME_SIZE; i++)
        out[i] = out[i] * wet + dry_src[i] * dry;
}
527 
/*
 * Accumulate four lags of cross-correlation at once:
 *   sum[k] += x[0] * y[k] + x[1] * y[k + 1] + ... + x[len-1] * y[k + len - 1]
 * for k = 0..3.
 *
 * The caller must initialize sum[] because it is added to, not overwritten.
 * y is read at indices 0 .. len + 2, and at least y[0..2] is always read.
 * The main loop is unrolled by 4, rotating y_0..y_3 so that each input
 * sample of y is loaded only once.
 */
xcorr_kernel(const float * x,const float * y,float sum[4],int len)528 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
529 {
530     float y_0, y_1, y_2, y_3 = 0;
531     int j;
532 
533     y_0 = *y++;
534     y_1 = *y++;
535     y_2 = *y++;
536 
537     for (j = 0; j < len - 3; j += 4) {
538         float tmp;
539 
540         tmp = *x++;
541         y_3 = *y++;
542         sum[0] += tmp * y_0;
543         sum[1] += tmp * y_1;
544         sum[2] += tmp * y_2;
545         sum[3] += tmp * y_3;
546         tmp = *x++;
547         y_0 = *y++;
548         sum[0] += tmp * y_1;
549         sum[1] += tmp * y_2;
550         sum[2] += tmp * y_3;
551         sum[3] += tmp * y_0;
552         tmp = *x++;
553         y_1 = *y++;
554         sum[0] += tmp * y_2;
555         sum[1] += tmp * y_3;
556         sum[2] += tmp * y_0;
557         sum[3] += tmp * y_1;
558         tmp = *x++;
559         y_2 = *y++;
560         sum[0] += tmp * y_3;
561         sum[1] += tmp * y_0;
562         sum[2] += tmp * y_1;
563         sum[3] += tmp * y_2;
564     }
565 
    /* Up to three samples remain; each step continues the register
     * rotation where the unrolled loop left off. */
566     if (j++ < len) {
567         float tmp = *x++;
568 
569         y_3 = *y++;
570         sum[0] += tmp * y_0;
571         sum[1] += tmp * y_1;
572         sum[2] += tmp * y_2;
573         sum[3] += tmp * y_3;
574     }
575 
576     if (j++ < len) {
577         float tmp=*x++;
578 
579         y_0 = *y++;
580         sum[0] += tmp * y_1;
581         sum[1] += tmp * y_2;
582         sum[2] += tmp * y_3;
583         sum[3] += tmp * y_0;
584     }
585 
586     if (j < len) {
587         float tmp=*x++;
588 
589         y_1 = *y++;
590         sum[0] += tmp * y_2;
591         sum[1] += tmp * y_3;
592         sum[2] += tmp * y_0;
593         sum[3] += tmp * y_1;
594     }
595 }
596 
/**
 * Dot product of x[0..N-1] and y[0..N-1], accumulated front to back.
 * Returns 0 when N <= 0.
 */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;

    for (int remaining = N; remaining > 0; remaining--)
        acc += *x++ * *y++;

    return acc;
}
607 
/**
 * Cross-correlate x with max_pitch shifted copies of y:
 *   xcorr[lag] = sum_{j < len} x[j] * y[lag + j],  0 <= lag < max_pitch.
 *
 * Lags are handled four at a time with xcorr_kernel(). Any leftover lags
 * (max_pitch not a multiple of 4) are computed one by one with
 * celt_inner_prod().
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    while (lag < max_pitch - 3) {
        float acc[4] = { 0, 0, 0, 0 };

        xcorr_kernel(x, y + lag, acc, len);
        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = acc[k];
        lag += 4;
    }

    for (; lag < max_pitch; lag++)
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
}
628 
celt_autocorr(const float * x,float * ac,const float * window,int overlap,int lag,int n)629 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
630                          float       *ac,  /* out: [0...lag-1] ac values */
631                          const float *window,
632                          int          overlap,
633                          int          lag,
634                          int          n)
635 {
636     int fastN = n - lag;
637     int shift;
638     const float *xptr;
639     float xx[PITCH_BUF_SIZE>>1];
640 
641     if (overlap == 0) {
642         xptr = x;
643     } else {
644         for (int i = 0; i < n; i++)
645             xx[i] = x[i];
646         for (int i = 0; i < overlap; i++) {
647             xx[i] = x[i] * window[i];
648             xx[n-i-1] = x[n-i-1] * window[i];
649         }
650         xptr = xx;
651     }
652 
653     shift = 0;
654     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
655 
656     for (int k = 0; k <= lag; k++) {
657         float d = 0.f;
658 
659         for (int i = k + fastN; i < n; i++)
660             d += xptr[i] * xptr[i-k];
661         ac[k] += d;
662     }
663 
664     return shift;
665 }
666 
/**
 * Levinson-Durbin recursion: p LPC coefficients from p + 1 autocorrelation
 * values.
 *
 * lpc is cleared first and left all-zero when ac[0] == 0. The recursion
 * stops early once the prediction error drops below 0.001 * ac[0]
 * (30 dB of prediction gain).
 *
 * @param lpc output, p coefficients
 * @param ac  input, autocorrelation at lags 0..p
 * @param p   predictor order
 */
static void celt_lpc(float *lpc, const float *ac, int p)
{
    float err = ac[0];

    RNN_CLEAR(lpc, p);
    if (ac[0] == 0)
        return;

    for (int i = 0; i < p; i++) {
        float acc = 0;
        float refl;

        /* Reflection coefficient for this order. */
        for (int j = 0; j < i; j++)
            acc += (lpc[j] * ac[i - j]);
        acc += ac[i + 1];
        refl = -acc / err;

        /* Update the predictor symmetrically from both ends. */
        lpc[i] = refl;
        for (int j = 0, half = (i + 1) >> 1; j < half; j++) {
            const float lo = lpc[j];
            const float hi = lpc[i - 1 - j];

            lpc[j]         = lo + (refl * hi);
            lpc[i - 1 - j] = hi + (refl * lo);
        }

        err = err - (refl * refl * err);
        if (err < .001f * ac[0])
            break;
    }
}
699 
/**
 * 5-tap FIR filter: y[i] = x[i] + sum_{k=0..4} num[k] * x[i - 1 - k].
 *
 * Samples from before x[0] come from mem, with mem[0] the most recent.
 * On return mem holds the shifted input history in the same order. x and y
 * may be the same buffer, because each input is read before its output is
 * written.
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float taps[5], hist[5];

    for (int k = 0; k < 5; k++) {
        taps[k] = num[k];
        hist[k] = mem[k];
    }

    for (int i = 0; i < N; i++) {
        const float in = x[i];
        float acc = in;

        for (int k = 0; k < 5; k++)
            acc += (taps[k] * hist[k]);
        for (int k = 4; k > 0; k--)
            hist[k] = hist[k - 1];
        hist[0] = in;
        y[i] = acc;
    }

    for (int k = 0; k < 5; k++)
        mem[k] = hist[k];
}
742 
pitch_downsample(float * x[],float * x_lp,int len,int C)743 static void pitch_downsample(float *x[], float *x_lp,
744                              int len, int C)
745 {
746     float ac[5];
747     float tmp=Q15ONE;
748     float lpc[4], mem[5]={0,0,0,0,0};
749     float lpc2[5];
750     float c1 = .8f;
751 
752     for (int i = 1; i < len >> 1; i++)
753         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
754     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
755     if (C==2) {
756         for (int i = 1; i < len >> 1; i++)
757             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
758         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
759     }
760 
761     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
762 
763     /* Noise floor -40 dB */
764     ac[0] *= 1.0001f;
765     /* Lag windowing */
766     for (int i = 1; i <= 4; i++) {
767         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
768         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
769     }
770 
771     celt_lpc(lpc, ac, 4);
772     for (int i = 0; i < 4; i++) {
773         tmp = .9f * tmp;
774         lpc[i] = (lpc[i] * tmp);
775     }
776     /* Add a zero */
777     lpc2[0] = lpc[0] + .8f;
778     lpc2[1] = lpc[1] + (c1 * lpc[0]);
779     lpc2[2] = lpc[2] + (c1 * lpc[1]);
780     lpc2[3] = lpc[3] + (c1 * lpc[2]);
781     lpc2[4] = (c1 * lpc[3]);
782     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
783 }
784 
/**
 * Two dot products sharing the same left operand:
 * *xy1 = x . y01 and *xy2 = x . y02 over N samples (both 0 when N <= 0).
 */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;
    int i = 0;

    while (i < N) {
        const float xi = x[i];

        acc1 += (xi * y01[i]);
        acc2 += (xi * y02[i]);
        i++;
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
798 
compute_pitch_gain(float xy,float xx,float yy)799 static float compute_pitch_gain(float xy, float xx, float yy)
800 {
801     return xy / sqrtf(1.f + xx * yy);
802 }
803 
/* For a candidate period T1 = T0 / k (k = 3..15), remove_doubling() also
 * checks the correlation at second_check[k] * T0 / k. Entries 0-2 are not
 * used: k starts at 2, and k == 2 uses T0 + T1 instead. */
804 static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/*
 * Refine a pitch period estimate by testing its submultiples (to undo
 * period doubling).
 *
 * All period arguments and N are halved on entry and the result is scaled
 * back, so x is presumably the 2x-decimated signal — confirm against the
 * caller. After halving:
 * - x + maxperiod is treated as the start of the current frame, with
 *   maxperiod samples of history before it;
 * - T0 = *T0_ (clamped below maxperiod) is the starting candidate;
 * - candidates T1 ~= T0 / k for k = 2..15 are each scored with a normalized
 *   correlation g1, averaged with a second lag T1b;
 * - a candidate replaces the current choice when g1 beats a threshold
 *   derived from the gain g0 at T0. The threshold is lowered when T1 is
 *   near prev_period (continuity bonus from prev_gain) and raised for
 *   short periods.
 * The chosen period is then nudged by +-1 (full rate) using correlations
 * at T - 1, T and T + 1.
 *
 * x            input signal (see above for the assumed layout)
 * maxperiod    longest period searched, full-rate samples
 * minperiod    shortest period searched, full-rate samples
 * N            frame length, full-rate samples
 * T0_          in: initial period; out: refined period, never below minperiod
 * prev_period  previous frame's period, used for the continuity bonus
 * prev_gain    previous frame's gain, used for the continuity bonus
 *
 * Returns the pitch gain for the chosen period: best_xy / (best_yy + 1), or
 * Q15ONE when best_yy <= best_xy. Either value is capped at the chosen
 * candidate's normalized correlation g.
 */
remove_doubling(float * x,int maxperiod,int minperiod,int N,int * T0_,int prev_period,float prev_gain)805 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
806                              int *T0_, int prev_period, float prev_gain)
807 {
808     int k, i, T, T0;
809     float g, g0;
810     float pg;
811     float xy,xx,yy,xy2;
812     float xcorr[3];
813     float best_xy, best_yy;
814     int offset;
815     int minperiod0;
816     float yy_lookup[PITCH_MAX_PERIOD+1];
817 
818     minperiod0 = minperiod;
819     maxperiod /= 2;
820     minperiod /= 2;
821     *T0_ /= 2;
822     prev_period /= 2;
823     N /= 2;
824     x += maxperiod;
825     if (*T0_>=maxperiod)
826         *T0_=maxperiod-1;
827 
    /* yy_lookup[i] = energy of x[-i .. N-1-i], updated incrementally and
     * clamped at 0 to absorb rounding drift. */
828     T = T0 = *T0_;
829     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
830     yy_lookup[0] = xx;
831     yy=xx;
832     for (i = 1; i <= maxperiod; i++) {
833         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
834         yy_lookup[i] = FFMAX(0, yy);
835     }
836     yy = yy_lookup[T0];
837     best_xy = xy;
838     best_yy = yy;
839     g = g0 = compute_pitch_gain(xy, xx, yy);
840     /* Look for any pitch at T/k */
841     for (k = 2; k <= 15; k++) {
842         int T1, T1b;
843         float g1;
844         float cont=0;
845         float thresh;
846         T1 = (2*T0+k)/(2*k);
847         if (T1 < minperiod)
848             break;
849         /* Look for another strong correlation at T1b */
850         if (k==2)
851         {
852             if (T1+T0>maxperiod)
853                 T1b = T0;
854             else
855                 T1b = T0+T1;
856         } else
857         {
858             T1b = (2*second_check[k]*T0+k)/(2*k);
859         }
860         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
861         xy = .5f * (xy + xy2);
862         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
863         g1 = compute_pitch_gain(xy, xx, yy);
864         if (FFABS(T1-prev_period)<=1)
865             cont = prev_gain;
866         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
867             cont = prev_gain * .5f;
868         else
869             cont = 0;
870         thresh = FFMAX(.3f, (.7f * g0) - cont);
        /* NOTE(review): T1 < 2*minperiod implies T1 < 3*minperiod, so the
           "else if" branch below can never run. This ordering is presumably
           inherited from upstream code. Swapping the tests would change the
           pitch decisions, and therefore the features fed to the model, so
           it is only flagged here. Confirm intent before changing. */
871         /* Bias against very high pitch (very short period) to avoid false-positives
872            due to short-term correlation */
873         if (T1<3*minperiod)
874             thresh = FFMAX(.4f, (.85f * g0) - cont);
875         else if (T1<2*minperiod)
876             thresh = FFMAX(.5f, (.9f * g0) - cont);
877         if (g1 > thresh)
878         {
879             best_xy = xy;
880             best_yy = yy;
881             T = T1;
882             g = g1;
883         }
884     }
885     best_xy = FFMAX(0, best_xy);
886     if (best_yy <= best_xy)
887         pg = Q15ONE;
888     else
889         pg = best_xy/(best_yy + 1);
890 
    /* Correlations at lags T-1, T, T+1 decide a +-1 full-rate offset. */
891     for (k = 0; k < 3; k++)
892         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
893     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
894         offset = 1;
895     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
896         offset = -1;
897     else
898         offset = 0;
899     if (pg > g)
900         pg = g;
901     *T0_ = 2*T+offset;
902 
903     if (*T0_<minperiod0)
904         *T0_=minperiod0;
905     return pg;
906 }
907 
/**
 * Pick the two lags with the largest normalized squared correlation.
 *
 * Only lags with positive xcorr are considered. Each is scored as
 * xcorr^2 / Syy, where Syy is 1 plus the energy of y[lag .. lag + len - 1],
 * updated incrementally and kept >= 1. Fractions are compared by
 * cross-multiplication. Defaults are best_pitch = {0, 1}.
 *
 * @param xcorr      correlation per lag, max_pitch values
 * @param y          reference signal, len + max_pitch values
 * @param len        correlation window length
 * @param max_pitch  number of lags
 * @param best_pitch output: best lag, then second-best lag
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float top_num = -1, second_num = -1;
    float top_den = 0, second_den = 0;
    int top_lag = 0, second_lag = 1;
    float energy = 1.f;

    for (int j = 0; j < len; j++)
        energy += y[j] * y[j];

    for (int lag = 0; lag < max_pitch; lag++) {
        if (xcorr[lag] > 0) {
            /* Pre-scale so squaring neither underflows nor overflows. */
            float scaled = xcorr[lag];
            float num;

            scaled *= 1e-12f;
            num = scaled * scaled;

            if ((num * second_den) > (second_num * energy)) {
                if ((num * top_den) > (top_num * energy)) {
                    second_num = top_num;
                    second_den = top_den;
                    second_lag = top_lag;
                    top_num    = num;
                    top_den    = energy;
                    top_lag    = lag;
                } else {
                    second_num = num;
                    second_den = energy;
                    second_lag = lag;
                }
            }
        }
        energy += y[lag + len] * y[lag + len] - y[lag] * y[lag];
        energy = FFMAX(1, energy);
    }

    best_pitch[0] = top_lag;
    best_pitch[1] = second_lag;
}
954 
pitch_search(const float * x_lp,float * y,int len,int max_pitch,int * pitch)955 static void pitch_search(const float *x_lp, float *y,
956                          int len, int max_pitch, int *pitch)
957 {
958     int lag;
959     int best_pitch[2]={0,0};
960     int offset;
961 
962     float x_lp4[WINDOW_SIZE];
963     float y_lp4[WINDOW_SIZE];
964     float xcorr[WINDOW_SIZE];
965 
966     lag = len+max_pitch;
967 
968     /* Downsample by 2 again */
969     for (int j = 0; j < len >> 2; j++)
970         x_lp4[j] = x_lp[2*j];
971     for (int j = 0; j < lag >> 2; j++)
972         y_lp4[j] = y[2*j];
973 
974     /* Coarse search with 4x decimation */
975 
976     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
977 
978     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
979 
980     /* Finer search with 2x decimation */
981     for (int i = 0; i < max_pitch >> 1; i++) {
982         float sum;
983         xcorr[i] = 0;
984         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
985             continue;
986         sum = celt_inner_prod(x_lp, y+i, len>>1);
987         xcorr[i] = FFMAX(-1, sum);
988     }
989 
990     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
991 
992     /* Refine by pseudo-interpolation */
993     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
994         float a, b, c;
995 
996         a = xcorr[best_pitch[0] - 1];
997         b = xcorr[best_pitch[0]];
998         c = xcorr[best_pitch[0] + 1];
999         if (c - a > .7f * (b - a))
1000             offset = 1;
1001         else if (a - c > .7f * (b-c))
1002             offset = -1;
1003         else
1004             offset = 0;
1005     } else {
1006         offset = 0;
1007     }
1008 
1009     *pitch = 2 * best_pitch[0] - offset;
1010 }
1011 
/*
 * DCT-II of NB_BANDS values with orthonormal scaling: out[i] is the dot
 * product of in with row i of s->dct_table, as filled in by init(), times
 * sqrt(2 / NB_BANDS).
 *
 * scalarproduct_float() needs a length that is a multiple of 4, so it runs
 * over FFALIGN(NB_BANDS, 4) elements. Callers only hold NB_BANDS valid
 * inputs, so the input is first copied into a zero-padded, aligned scratch
 * buffer. Without this, the padding lanes would be read past the end of the
 * caller's array, and a stray NaN/Inf there would poison the sum even against
 * a zero table entry.
 *
 * Assumes each s->dct_table row is readable for FFALIGN(NB_BANDS, 4) floats,
 * as the original code already required.
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    LOCAL_ALIGNED_32(float, padded, [FFALIGN(NB_BANDS, 4)]);
    const float scale = sqrtf(2.f / NB_BANDS);

    memcpy(padded, in, NB_BANDS * sizeof(*padded));
    memset(padded + NB_BANDS, 0, (FFALIGN(NB_BANDS, 4) - NB_BANDS) * sizeof(*padded));

    for (int i = 0; i < NB_BANDS; i++) {
        const float sum = s->fdsp->scalarproduct_float(padded, s->dct_table[i],
                                                       FFALIGN(NB_BANDS, 4));

        out[i] = sum * scale;
    }
}
1021 
/*
 * Analyse one FRAME_SIZE block and build the network's feature vector.
 *
 * in        FRAME_SIZE samples (rnnoise_channel() passes the biquad-filtered frame)
 * X, Ex     spectrum and per-band energies of the frame, filled by frame_analysis()
 * P, Ep     spectrum and per-band energies of the pitch-delayed, windowed history
 * Exp       per-band correlation of X and P, divided by sqrt(.001 + Ex * Ep)
 * features  output vector; layout on the non-silent path:
 *   [0, NB_DELTA_CEPS)                         cepstrum summed over the last 3 frames
 *   [NB_DELTA_CEPS, NB_BANDS)                  cepstrum of the current frame
 *   [NB_BANDS, NB_BANDS+NB_DELTA_CEPS)         first cepstral difference
 *   [NB_BANDS+NB_DELTA_CEPS, +NB_DELTA_CEPS)   second cepstral difference
 *   [NB_BANDS+2*NB_DELTA_CEPS, +NB_DELTA_CEPS) DCT of Exp (offsets on the first two)
 *   NB_BANDS+3*NB_DELTA_CEPS                   pitch period, .01 * (pitch_index - 300)
 *   NB_BANDS+3*NB_DELTA_CEPS+1                 spectral variability
 *
 * The pitch history and st->last_period/last_gain are updated for every frame.
 * When the summed band energy is below 0.04, features is zeroed, the cepstral
 * history is left untouched and 1 is returned; otherwise 0 is returned.
 */
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
{
    float E = 0;
    float *ceps_0, *ceps_1, *ceps_2;
    float spec_variability = 0;
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
    float pitch_buf[PITCH_BUF_SIZE>>1];
    int pitch_index;
    float gain;
    float *(pre[1]);
    float tmp[NB_BANDS];
    float follow, logMax;

    frame_analysis(s, st, X, Ex, in);
    /* Slide the pitch history by one frame and append the new samples. */
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
    pre[0] = &st->pitch_buf[0];
    /* pitch_buf (local) holds a half-length copy of the history for the search. */
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
    pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
            PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
    /* Turn the search position into the delay used to fetch p below. */
    pitch_index = PITCH_MAX_PERIOD-pitch_index;

    /* remove_doubling() refines pitch_index using the previous frame's period and gain. */
    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
            PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
    st->last_period = pitch_index;
    st->last_gain = gain;

    /* Window of the history delayed by pitch_index samples. */
    for (int i = 0; i < WINDOW_SIZE; i++)
        p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];

    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
    forward_transform(st, P, p);
    compute_band_energy(Ep, P);
    compute_band_corr(Exp, X, P);

    /* Normalise the correlation by both energies; .001 guards against division by zero. */
    for (int i = 0; i < NB_BANDS; i++)
        Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);

    dct(s, tmp, Exp);

    for (int i = 0; i < NB_DELTA_CEPS; i++)
        features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];

    features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
    features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
    features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
    logMax = -2;
    follow = -2;

    /* log10 band energies, floored at 7 below the running maximum and at a
     * 1.5-per-band decay of the preceding bands; E accumulates total energy. */
    for (int i = 0; i < NB_BANDS; i++) {
        Ly[i] = log10f(1e-2f + Ex[i]);
        Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
        logMax = FFMAX(logMax, Ly[i]);
        follow = FFMAX(follow-1.5, Ly[i]);
        E += Ex[i];
    }

    if (E < 0.04f) {
        /* If there's no audio, avoid messing up the state. */
        RNN_CLEAR(features, NB_FEATURES);
        return 1;
    }

    /* Cepstrum of the log energies, stored in the CEPS_MEM-entry ring buffer
     * st->cepstral_mem; ceps_1 and ceps_2 are the two previous entries. */
    dct(s, features, Ly);
    features[0] -= 12;
    features[1] -= 4;
    ceps_0 = st->cepstral_mem[st->memid];
    ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
    ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];

    for (int i = 0; i < NB_BANDS; i++)
        ceps_0[i] = features[i];

    st->memid++;
    /* Smoothed value and first/second differences over the last three frames. */
    for (int i = 0; i < NB_DELTA_CEPS; i++) {
        features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
        features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
        features[NB_BANDS+NB_DELTA_CEPS+i] =  ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
    }
    /* Wrap the ring-buffer index. */
    if (st->memid == CEPS_MEM)
        st->memid = 0;

    /* Spectral variability: for each stored cepstrum, the squared distance to
     * its nearest other entry; the mean (offset by 2.1) is the last feature. */
    for (int i = 0; i < CEPS_MEM; i++) {
        float mindist = 1e15f;
        for (int j = 0; j < CEPS_MEM; j++) {
            float dist = 0.f;
            for (int k = 0; k < NB_BANDS; k++) {
                float tmp;

                tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
                dist += tmp*tmp;
            }

            if (j != i)
                mindist = FFMIN(mindist, dist);
        }

        spec_variability += mindist;
    }

    features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;

    return 0;
}
1129 
interp_band_gain(float * g,const float * bandE)1130 static void interp_band_gain(float *g, const float *bandE)
1131 {
1132     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1133 
1134     for (int i = 0; i < NB_BANDS - 1; i++) {
1135         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1136 
1137         for (int j = 0; j < band_size; j++) {
1138             float frac = (float)j / band_size;
1139 
1140             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1141         }
1142     }
1143 }
1144 
/*
 * Mix the pitch-delayed spectrum P into X, then renormalise each band of X
 * back to its original energy Ex.
 *
 * The per-band mixing gain is sqrt(clip(ratio, 0, 1)) * sqrt(Ex / Ep).
 * ratio is 1 where the pitch correlation Exp exceeds the network gain g.
 * Otherwise it is Exp^2 (1 - g^2) / (.001 + g^2 (1 - Exp^2)).
 * Band gains are spread to bins with interp_band_gain().
 */
static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
                         const float *Exp, const float *g)
{
    float band_energy[NB_BANDS];
    float pgain[NB_BANDS];
    float renorm[NB_BANDS];
    float pgain_bins[FREQ_SIZE] = {0};
    float renorm_bins[FREQ_SIZE] = {0};

    for (int b = 0; b < NB_BANDS; b++) {
        float ratio;

        /* Keep this comparison form: a NaN Exp[b] must take the second branch. */
        if (Exp[b] > g[b])
            ratio = 1;
        else
            ratio = SQUARE(Exp[b]) * (1 - SQUARE(g[b])) / (.001 + SQUARE(g[b]) * (1 - SQUARE(Exp[b])));

        pgain[b]  = sqrtf(av_clipf(ratio, 0, 1));
        pgain[b] *= sqrtf(Ex[b] / (1e-8 + Ep[b]));
    }

    interp_band_gain(pgain_bins, pgain);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re += pgain_bins[k] * P[k].re;
        X[k].im += pgain_bins[k] * P[k].im;
    }

    /* Restore each band's energy after the mix. */
    compute_band_energy(band_energy, X);
    for (int b = 0; b < NB_BANDS; b++)
        renorm[b] = sqrtf(Ex[b] / (1e-8 + band_energy[b]));

    interp_band_gain(renorm_bins, renorm);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re *= renorm_bins[k];
        X[k].im *= renorm_bins[k];
    }
}
1175 
/*
 * tanh() sampled on a uniform grid: tansig_table[i] = tanh(0.04 * i) for
 * i = 0..200, i.e. arguments 0..8. tansig_approx() picks the nearest entry
 * via round(25 * |x|) and applies a correction from there.
 */
static const float tansig_table[201] = {
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f,
};
1219 
tansig_approx(float x)1220 static inline float tansig_approx(float x)
1221 {
1222     float y, dy;
1223     float sign=1;
1224     int i;
1225 
1226     /* Tests are reversed to catch NaNs */
1227     if (!(x<8))
1228         return 1;
1229     if (!(x>-8))
1230         return -1;
1231     /* Another check in case of -ffast-math */
1232 
1233     if (isnan(x))
1234        return 0;
1235 
1236     if (x < 0) {
1237        x=-x;
1238        sign=-1;
1239     }
1240     i = (int)floor(.5f+25*x);
1241     x -= .04f*i;
1242     y = tansig_table[i];
1243     dy = 1-y*y;
1244     y = y + x*dy*(1 - y*x);
1245     return sign*y;
1246 }
1247 
/* Logistic sigmoid via the identity sigmoid(x) = (1 + tanh(x / 2)) / 2. */
static inline float sigmoid_approx(float x)
{
    const float t = tansig_approx(.5f * x);

    return .5f + .5f * t;
}
1252 
compute_dense(const DenseLayer * layer,float * output,const float * input)1253 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1254 {
1255     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1256 
1257     for (int i = 0; i < N; i++) {
1258         /* Compute update gate. */
1259         float sum = layer->bias[i];
1260 
1261         for (int j = 0; j < M; j++)
1262             sum += layer->input_weights[j * stride + i] * input[j];
1263 
1264         output[i] = WEIGHTS_SCALE * sum;
1265     }
1266 
1267     if (layer->activation == ACTIVATION_SIGMOID) {
1268         for (int i = 0; i < N; i++)
1269             output[i] = sigmoid_approx(output[i]);
1270     } else if (layer->activation == ACTIVATION_TANH) {
1271         for (int i = 0; i < N; i++)
1272             output[i] = tansig_approx(output[i]);
1273     } else if (layer->activation == ACTIVATION_RELU) {
1274         for (int i = 0; i < N; i++)
1275             output[i] = FFMAX(0, output[i]);
1276     } else {
1277         av_assert0(0);
1278     }
1279 }
1280 
/*
 * One step of a GRU layer, updating state[0..nb_neurons) in place:
 *
 *   z = sigmoid(Wz x + Uz s + bz)         update gate
 *   r = sigmoid(Wr x + Ur s + br)         reset gate
 *   c = act(Wc x + Uc (r .* s) + bc)      candidate
 *   s = z .* s + (1 - z) .* c
 *
 * All pre-activations are scaled by WEIGHTS_SCALE.
 *
 * Each neuron owns one row per weight matrix, holding its z, r and c segments
 * back to back. Each segment is padded to a multiple of 4 so that
 * scalarproduct_float() can be used; input and state are therefore read over
 * FFALIGN(nb_inputs, 4) and FFALIGN(nb_neurons, 4) elements. The candidate's
 * recurrent term is summed directly because it needs r.
 */
static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
{
    LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
    const int n_out    = gru->nb_neurons;
    const int a_in     = FFALIGN(gru->nb_inputs, 4);
    const int a_out    = FFALIGN(n_out, 4);
    const int w_stride = 3 * a_in;
    const int u_stride = 3 * a_out;

    /* Update and reset gates; both only depend on the old state. */
    for (int i = 0; i < n_out; i++) {
        const float *w = gru->input_weights     + i * w_stride;
        const float *u = gru->recurrent_weights + i * u_stride;
        float zsum = gru->bias[i];
        float rsum = gru->bias[n_out + i];

        zsum += s->fdsp->scalarproduct_float(w,         input, a_in);
        zsum += s->fdsp->scalarproduct_float(u,         state, a_out);
        rsum += s->fdsp->scalarproduct_float(w + a_in,  input, a_in);
        rsum += s->fdsp->scalarproduct_float(u + a_out, state, a_out);

        z[i] = sigmoid_approx(WEIGHTS_SCALE * zsum);
        r[i] = sigmoid_approx(WEIGHTS_SCALE * rsum);
    }

    /* Candidate activation and blend with the old state. */
    for (int i = 0; i < n_out; i++) {
        const float *u = gru->recurrent_weights + 2 * a_out + i * u_stride;
        float cand = gru->bias[2 * n_out + i];

        cand += s->fdsp->scalarproduct_float(gru->input_weights + 2 * a_in + i * w_stride,
                                             input, a_in);
        for (int j = 0; j < n_out; j++)
            cand += u[j] * state[j] * r[j];

        switch (gru->activation) {
        case ACTIVATION_SIGMOID:
            cand = sigmoid_approx(WEIGHTS_SCALE * cand);
            break;
        case ACTIVATION_TANH:
            cand = tansig_approx(WEIGHTS_SCALE * cand);
            break;
        case ACTIVATION_RELU:
            cand = FFMAX(0, WEIGHTS_SCALE * cand);
            break;
        default:
            av_assert0(0);
        }

        h[i] = z[i] * state[i] + (1.f - z[i]) * cand;
    }

    RNN_COPY(state, h, n_out);
}
1331 
1332 #define INPUT_SIZE 42
1333 
compute_rnn(AudioRNNContext * s,RNNState * rnn,float * gains,float * vad,const float * input)1334 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1335 {
1336     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1337     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1338     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1339 
1340     compute_dense(rnn->model->input_dense, dense_out, input);
1341     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1342     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1343 
1344     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1345     memcpy(noise_input + rnn->model->input_dense_size,
1346            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1347     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1348            input, INPUT_SIZE * sizeof(float));
1349 
1350     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1351 
1352     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1353     memcpy(denoise_input + rnn->model->vad_gru_size,
1354            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1355     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1356            input, INPUT_SIZE * sizeof(float));
1357 
1358     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1359     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1360 }
1361 
/*
 * Denoise one FRAME_SIZE block of a single channel.
 *
 * Steps:
 *  1. The block is pre-filtered with biquad(). b_hp = {-2, 1} suggests a
 *     DC-blocking high-pass; confirm against biquad().
 *  2. The filtered block is analysed into features.
 *  3. Unless the block is silent or the filter is timeline-disabled, the
 *     network runs. Its per-band gains, floored at 0.6x the previous frame's
 *     gains, are applied after pitch filtering.
 *  4. The spectrum is always resynthesised into out, and the unfiltered input
 *     is saved in st->history.
 *
 * Returns the network's VAD output, or 0 when the network was not run.
 */
static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
                             int disabled)
{
    static const float a_hp[2] = {-1.99599, 0.99600};
    static const float b_hp[2] = {-2, 1};
    const float decay = .6f;
    AVComplexFloat X[FREQ_SIZE];
    AVComplexFloat P[WINDOW_SIZE];
    LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
    float x[FRAME_SIZE];
    float Ex[NB_BANDS], Ep[NB_BANDS];
    float features[NB_FEATURES];
    float g[NB_BANDS];
    float gf[FREQ_SIZE];
    float vad_prob = 0;
    int run_network;

    biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
    /* Feature extraction always runs (it updates st) even when disabled. */
    run_network = !compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x) && !disabled;

    if (run_network) {
        compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
        pitch_filter(X, P, Ex, Ep, Exp, g);

        /* Limit how fast a band's gain may fall from one frame to the next. */
        for (int b = 0; b < NB_BANDS; b++)
            st->lastg[b] = g[b] = FFMAX(g[b], decay * st->lastg[b]);

        interp_band_gain(gf, g);
        for (int k = 0; k < FREQ_SIZE; k++) {
            X[k].re *= gf[k];
            X[k].im *= gf[k];
        }
    }

    frame_synthesis(s, st, out, X);
    memcpy(st->history, in, FRAME_SIZE * sizeof(*st->history));

    return vad_prob;
}
1405 
1406 typedef struct ThreadData {
1407     AVFrame *in, *out;
1408 } ThreadData;
1409 
/*
 * Slice-thread worker: denoise the contiguous range of channels assigned to
 * job jobnr. Channels are split as evenly as possible across nb_jobs.
 */
static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
{
    AudioRNNContext *s = ctx->priv;
    const ThreadData *td = arg;
    const int nb_channels = td->out->ch_layout.nb_channels;
    const int first = (nb_channels * jobnr) / nb_jobs;
    const int last  = (nb_channels * (jobnr + 1)) / nb_jobs;

    for (int ch = first; ch < last; ch++) {
        float       *dst = (float *)td->out->extended_data[ch];
        const float *src = (const float *)td->in->extended_data[ch];

        rnnoise_channel(s, &s->st[ch], dst, src, ctx->is_disabled);
    }

    return 0;
}
1428 
/*
 * Denoise one block of input and push the result downstream.
 * Takes ownership of in in every case.
 *
 * rnnoise_channel() always processes FRAME_SIZE float samples per channel
 * plane. activate() asks for exactly FRAME_SIZE samples, but once the input
 * has reached EOF, ff_inlink_consume_samples() can deliver a shorter final
 * frame. Such a frame is copied into a FRAME_SIZE buffer padded with silence,
 * so no channel is read past its end.
 */
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *ctx = inlink->dst;
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *out = NULL;
    ThreadData td;

    if (in->nb_samples < FRAME_SIZE) {
        const int nb_channels = inlink->ch_layout.nb_channels;
        AVFrame *padded = ff_get_audio_buffer(inlink, FRAME_SIZE);

        if (!padded) {
            av_frame_free(&in);
            return AVERROR(ENOMEM);
        }
        /* Planar float, one plane per channel, as rnnoise_channels() expects. */
        for (int ch = 0; ch < nb_channels; ch++) {
            float *dst = (float *)padded->extended_data[ch];

            memcpy(dst, in->extended_data[ch], in->nb_samples * sizeof(*dst));
            memset(dst + in->nb_samples, 0, (FRAME_SIZE - in->nb_samples) * sizeof(*dst));
        }
        padded->pts = in->pts;
        av_frame_free(&in);
        in = padded;
    }

    out = ff_get_audio_buffer(outlink, FRAME_SIZE);
    if (!out) {
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }
    out->pts = in->pts;

    td.in = in; td.out = out;
    ff_filter_execute(ctx, rnnoise_channels, &td, NULL,
                      FFMIN(outlink->ch_layout.nb_channels, ff_filter_get_nb_threads(ctx)));

    av_frame_free(&in);
    return ff_filter_frame(outlink, out);
}
1450 
/*
 * Scheduler callback.
 *
 * Propagates an output-side status back to the input. Otherwise it processes
 * one block whenever the input can supply FRAME_SIZE samples. When nothing
 * was consumed, it forwards the input status downstream or requests more
 * input.
 */
static int activate(AVFilterContext *ctx)
{
    AVFilterLink *inlink  = ctx->inputs[0];
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *frame = NULL;
    int consumed;

    FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);

    consumed = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &frame);
    if (consumed > 0)
        return filter_frame(inlink, frame);
    if (consumed < 0)
        return consumed;

    FF_FILTER_FORWARD_STATUS(inlink, outlink);
    FF_FILTER_FORWARD_WANTED(outlink, inlink);

    return FFERROR_NOT_READY;
}
1472 
/*
 * Load the model file named by the "model" option into *model.
 *
 * Returns 0 with *model set on success, or a negative AVERROR code. The
 * previous check "!*model || ret < 0" returned ret itself, so a loader that
 * left *model NULL with a non-negative code was reported as success and
 * handed the caller a NULL model. That case is now an error.
 */
static int open_model(AVFilterContext *ctx, RNNModel **model)
{
    AudioRNNContext *s = ctx->priv;
    int ret;
    FILE *f;

    if (!s->model_name) {
        av_log(ctx, AV_LOG_ERROR, "No model file specified\n");
        return AVERROR(EINVAL);
    }
    f = avpriv_fopen_utf8(s->model_name, "r");
    if (!f) {
        av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    ret = rnnoise_model_from_file(f, model);
    fclose(f);
    if (ret < 0)
        return ret;
    if (!*model) {
        av_log(ctx, AV_LOG_ERROR, "Failed to load model file: %s\n", s->model_name);
        return AVERROR_INVALIDDATA;
    }

    return 0;
}
1494 
init(AVFilterContext * ctx)1495 static av_cold int init(AVFilterContext *ctx)
1496 {
1497     AudioRNNContext *s = ctx->priv;
1498     int ret;
1499 
1500     s->fdsp = avpriv_float_dsp_alloc(0);
1501     if (!s->fdsp)
1502         return AVERROR(ENOMEM);
1503 
1504     ret = open_model(ctx, &s->model[0]);
1505     if (ret < 0)
1506         return ret;
1507 
1508     for (int i = 0; i < FRAME_SIZE; i++) {
1509         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1510         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1511     }
1512 
1513     for (int i = 0; i < NB_BANDS; i++) {
1514         for (int j = 0; j < NB_BANDS; j++) {
1515             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1516             if (j == 0)
1517                 s->dct_table[j][i] *= sqrtf(.5);
1518         }
1519     }
1520 
1521     return 0;
1522 }
1523 
/*
 * Release model slot n and the per-channel GRU states that belong to it.
 * The slot is left NULL. The channel loop is skipped when no channel state
 * has been allocated yet.
 */
static void free_model(AVFilterContext *ctx, int n)
{
    AudioRNNContext *s = ctx->priv;

    rnnoise_model_free(s->model[n]);
    s->model[n] = NULL;

    if (!s->st)
        return;

    for (int ch = 0; ch < s->channels; ch++) {
        av_freep(&s->st[ch].rnn[n].vad_gru_state);
        av_freep(&s->st[ch].rnn[n].noise_gru_state);
        av_freep(&s->st[ch].rnn[n].denoise_gru_state);
    }
}
1537 
/*
 * Apply a runtime option change.
 *
 * ff_filter_process_command() stores the new option value. Only a new model
 * file ("model"/"m") needs more work:
 *  1. The model is loaded into slot 1.
 *  2. It is swapped into slot 0 together with the per-channel RNN states.
 *  3. config_input() is rerun to set up state for it.
 *  4. On success the old model, now in slot 1, is released. On failure the
 *     swap is undone and the rejected model is released. Previously the
 *     rejected model and its partially allocated states stayed in slot 1,
 *     which uninit() never freed.
 *
 * Any other command returns right after the value is stored. It therefore
 * no longer re-reads the model file or resets the recurrent state. This
 * assumes config_input() does not depend on those options; TODO confirm.
 */
static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
                           char *res, int res_len, int flags)
{
    AudioRNNContext *s = ctx->priv;
    int ret;

    ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
    if (ret < 0)
        return ret;

    if (strcmp(cmd, "model") && strcmp(cmd, "m"))
        return 0;

    ret = open_model(ctx, &s->model[1]);
    if (ret < 0)
        return ret;

    FFSWAP(RNNModel *, s->model[0], s->model[1]);
    for (int ch = 0; ch < s->channels; ch++)
        FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);

    ret = config_input(ctx->inputs[0]);
    if (ret < 0) {
        for (int ch = 0; ch < s->channels; ch++)
            FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
        FFSWAP(RNNModel *, s->model[0], s->model[1]);
        /* Slot 1 now holds the rejected model and whatever state
         * config_input() allocated for it; release both. */
        free_model(ctx, 1);
        return ret;
    }

    free_model(ctx, 1);
    return 0;
}
1567 
uninit(AVFilterContext * ctx)1568 static av_cold void uninit(AVFilterContext *ctx)
1569 {
1570     AudioRNNContext *s = ctx->priv;
1571 
1572     av_freep(&s->fdsp);
1573     free_model(ctx, 0);
1574     for (int ch = 0; ch < s->channels && s->st; ch++) {
1575         av_tx_uninit(&s->st[ch].tx);
1576         av_tx_uninit(&s->st[ch].txi);
1577     }
1578     av_freep(&s->st);
1579 }
1580 
/* Single audio input. config_input() is defined elsewhere in this file;
 * process_command() also reruns it after swapping models, presumably to
 * (re)build per-channel state. */
static const AVFilterPad inputs[] = {
    {
        .name         = "default",
        .type         = AVMEDIA_TYPE_AUDIO,
        .config_props = config_input,
    },
};

/* Single audio output; frames are produced by filter_frame(). */
static const AVFilterPad outputs[] = {
    {
        .name          = "default",
        .type          = AVMEDIA_TYPE_AUDIO,
    },
};
1595 
#define OFFSET(x) offsetof(AudioRNNContext, x)
/* Every option is an audio filtering parameter that can also be changed at
 * runtime (handled by process_command()). */
#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM

static const AVOption arnndn_options[] = {
    /* "model" and its alias "m" set model_name; open_model() fails when it is unset. */
    { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    /* Output vs. input mix stored in s->mix (applied outside this section);
     * default 1.0, range [-1, 1]. */
    { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
    { NULL }
};

AVFILTER_DEFINE_CLASS(arnndn);
1607 
/*
 * arnndn filter definition.
 *
 * Timeline support is internal: rnnoise_channels() passes ctx->is_disabled
 * down to rnnoise_channel(). Channels are processed in parallel with slice
 * threads.
 */
const AVFilter ff_af_arnndn = {
    .name          = "arnndn",
    .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
    .priv_size     = sizeof(AudioRNNContext),
    .priv_class    = &arnndn_class,
    .activate      = activate,
    .init          = init,
    .uninit        = uninit,
    FILTER_INPUTS(inputs),
    FILTER_OUTPUTS(outputs),
    FILTER_QUERY_FUNC(query_formats),
    .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
                     AVFILTER_FLAG_SLICE_THREADS,
    .process_command = process_command,
};
1623