/*
 * Copyright (c) 2018 Gregor Richards
 * Copyright (c) 2017 Mozilla
 * Copyright (c) 2005-2009 Xiph.Org Foundation
 * Copyright (c) 2007-2008 CSIRO
 * Copyright (c) 2008-2011 Octasic Inc.
 * Copyright (c) Jean-Marc Valin
 * Copyright (c) 2019 Paul B Mahol
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * - Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 *
 * - Redistributions in binary form must reproduce the above copyright
 *   notice, this list of conditions and the following disclaimer in the
 *   documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
33 
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

#include "libavutil/avassert.h"
#include "libavutil/avstring.h"
#include "libavutil/float_dsp.h"
#include "libavutil/mem_internal.h"
#include "libavutil/opt.h"
#include "libavutil/tx.h"
#include "avfilter.h"
#include "audio.h"
#include "filters.h"
#include "formats.h"
46 
/* Frame geometry, in samples. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)  /* 480: samples consumed per frame */
#define WINDOW_SIZE (2*FRAME_SIZE)          /* 960: previous frame + current frame */
#define FREQ_SIZE (FRAME_SIZE + 1)          /* 481: bins kept from the WINDOW_SIZE-point FFT */

/* Pitch analysis geometry, in samples. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* NOTE: evaluates its argument twice; do not pass expressions with side effects. */
#define SQUARE(x) ((x)*(x))

/* Number of frequency bands (eband5ms has one boundary entry per band). */
#define NB_BANDS 22

#define CEPS_MEM 8
#define NB_DELTA_CEPS 6

/* 22 + 3*6 + 2 = 42 */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)

/* Scale factor 1/256; it is applied outside this section
 * (presumably to the integer weights read from the model file — TODO confirm). */
#define WEIGHTS_SCALE (1.f/256)

/* NOTE(review): the model loader bounds layer sizes with a literal 128;
 * presumably meant to match this constant — confirm. */
#define MAX_NEURONS 128

/* Activation identifiers stored in DenseLayer/GRULayer::activation. */
#define ACTIVATION_TANH    0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU    2

/* Unity gain constant. */
#define Q15ONE 1.0f
75 
/**
 * Fully connected layer description.
 *
 * The arrays are heap-allocated by rnnoise_model_from_file() and released by
 * rnnoise_model_free().
 */
typedef struct DenseLayer {
    const float *bias;          /* nb_neurons values */
    const float *input_weights; /* nb_inputs * nb_neurons values */
    int nb_inputs;
    int nb_neurons;
    int activation;             /* one of ACTIVATION_* */
} DenseLayer;
83 
/**
 * GRU layer description.
 *
 * The arrays are heap-allocated by rnnoise_model_from_file() and released by
 * rnnoise_model_free(). The two weight arrays are stored by the loader in a
 * padded layout: the first dimension is rounded up to a multiple of 4 and
 * three blocks (len2 == 3) are interleaved per neuron.
 */
typedef struct GRULayer {
    const float *bias;              /* 3 * nb_neurons values */
    const float *input_weights;     /* padded nb_inputs x nb_neurons x 3 */
    const float *recurrent_weights; /* padded nb_neurons x nb_neurons x 3 */
    int nb_inputs;
    int nb_neurons;
    int activation;                 /* one of ACTIVATION_* */
} GRULayer;
92 
93 typedef struct RNNModel {
94     int input_dense_size;
95     const DenseLayer *input_dense;
96 
97     int vad_gru_size;
98     const GRULayer *vad_gru;
99 
100     int noise_gru_size;
101     const GRULayer *noise_gru;
102 
103     int denoise_gru_size;
104     const GRULayer *denoise_gru;
105 
106     int denoise_output_size;
107     const DenseLayer *denoise_output;
108 
109     int vad_output_size;
110     const DenseLayer *vad_output;
111 } RNNModel;
112 
113 typedef struct RNNState {
114     float *vad_gru_state;
115     float *noise_gru_state;
116     float *denoise_gru_state;
117     RNNModel *model;
118 } RNNState;
119 
120 typedef struct DenoiseState {
121     float analysis_mem[FRAME_SIZE];
122     float cepstral_mem[CEPS_MEM][NB_BANDS];
123     int memid;
124     DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
125     float pitch_buf[PITCH_BUF_SIZE];
126     float pitch_enh_buf[PITCH_BUF_SIZE];
127     float last_gain;
128     int last_period;
129     float mem_hp_x[2];
130     float lastg[NB_BANDS];
131     float history[FRAME_SIZE];
132     RNNState rnn[2];
133     AVTXContext *tx, *txi;
134     av_tx_fn tx_fn, txi_fn;
135 } DenoiseState;
136 
137 typedef struct AudioRNNContext {
138     const AVClass *class;
139 
140     char *model_name;
141     float mix;
142 
143     int channels;
144     DenoiseState *st;
145 
146     DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
147     DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];
148 
149     RNNModel *model[2];
150 
151     AVFloatDSPContext *fdsp;
152 } AudioRNNContext;
153 
/* Activation identifiers as they appear in the model file; the loader maps
 * them onto the in-memory ACTIVATION_* constants. */
#define F_ACTIVATION_TANH       0
#define F_ACTIVATION_SIGMOID    1
#define F_ACTIVATION_RELU       2
157 
/**
 * Release a model and every layer/array it owns.
 *
 * Safe to call with NULL and with a partially constructed model: layers that
 * were never allocated are skipped, and av_free() accepts NULL arrays.
 *
 * @param model model to free, may be NULL
 */
static void rnnoise_model_free(RNNModel *model)
{
/* The casts drop the const qualifier the layer structs put on their arrays. */
#define FREE_DENSE(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)
#define FREE_GRU(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->recurrent_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)

    if (!model)
        return;
    FREE_DENSE(input_dense);
    FREE_GRU(vad_gru);
    FREE_GRU(noise_gru);
    FREE_GRU(denoise_gru);
    FREE_DENSE(denoise_output);
    FREE_DENSE(vad_output);
    av_free(model);
}
187 
/**
 * Parse a text model ("rnnoise-nu model file version 1") from an open stream.
 *
 * The file holds six layers in a fixed order: input_dense, vad_gru,
 * noise_gru, denoise_gru, denoise_output, vad_output. Each layer starts with
 * a header line (sizes and activation id) followed by one line per array.
 * All numbers are read as decimal integers and stored unscaled as floats.
 *
 * @param f   stream positioned at the start of the model; not closed here
 * @param rnn receives the newly allocated model on success; untouched on
 *            failure. Free it with rnnoise_model_free().
 * @return 0 on success, AVERROR_INVALIDDATA for a bad header,
 *         AVERROR(ENOMEM) on allocation failure, AVERROR(EINVAL) for a
 *         malformed or out-of-range value or when vad_output does not have
 *         exactly one neuron.
 */
static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
{
    RNNModel *ret = NULL;
    DenseLayer *input_dense;
    GRULayer *vad_gru;
    GRULayer *noise_gru;
    GRULayer *denoise_gru;
    DenseLayer *denoise_output;
    DenseLayer *vad_output;
    int in;

    if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
        return AVERROR_INVALIDDATA;

    ret = av_calloc(1, sizeof(RNNModel));
    if (!ret)
        return AVERROR(ENOMEM);

/* Allocate a zeroed layer and attach it to ret right away, so that every
 * later error path can release it through rnnoise_model_free(). */
#define ALLOC_LAYER(type, name) \
    name = av_calloc(1, sizeof(type)); \
    if (!name) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    ret->name = name

    ALLOC_LAYER(DenseLayer, input_dense);
    ALLOC_LAYER(GRULayer, vad_gru);
    ALLOC_LAYER(GRULayer, noise_gru);
    ALLOC_LAYER(GRULayer, denoise_gru);
    ALLOC_LAYER(DenseLayer, denoise_output);
    ALLOC_LAYER(DenseLayer, vad_output);

/* Read one integer in [0, 128] into name. */
#define INPUT_VAL(name) do { \
    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
        rnnoise_model_free(ret); \
        return AVERROR(EINVAL); \
    } \
    name = in; \
    } while (0)

/* Read an activation id; anything other than sigmoid/relu becomes tanh. */
#define INPUT_ACTIVATION(name) do { \
    int activation; \
    INPUT_VAL(activation); \
    switch (activation) { \
    case F_ACTIVATION_SIGMOID: \
        name = ACTIVATION_SIGMOID; \
        break; \
    case F_ACTIVATION_RELU: \
        name = ACTIVATION_RELU; \
        break; \
    default: \
        name = ACTIVATION_TANH; \
    } \
    } while (0)

/* Read len integers into a new array. The array is assigned to name before
 * it is filled so that it is freed with the model if parsing fails. */
#define INPUT_ARRAY(name, len) do { \
    float *values = av_calloc((len), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int i = 0; i < (len); i++) { \
        if (fscanf(f, "%d", &in) != 1) { \
            rnnoise_model_free(ret); \
            return AVERROR(EINVAL); \
        } \
        values[i] = in; \
    } \
    } while (0)

/* Read len0 * len2 * len1 integers (file order: k over len0, i over len2,
 * j over len1) and store them transposed, with the len0 dimension padded to
 * a multiple of 4: values[(j * len2 + i) * FFALIGN(len0, 4) + k]. */
#define INPUT_ARRAY3(name, len0, len1, len2) do { \
    float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int k = 0; k < (len0); k++) { \
        for (int i = 0; i < (len2); i++) { \
            for (int j = 0; j < (len1); j++) { \
                if (fscanf(f, "%d", &in) != 1) { \
                    rnnoise_model_free(ret); \
                    return AVERROR(EINVAL); \
                } \
                values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
            } \
        } \
    } \
    } while (0)

/* Discard the rest of the current line (or everything up to EOF). */
#define NEW_LINE() do { \
    int c; \
    while ((c = fgetc(f)) != EOF) { \
        if (c == '\n') \
            break; \
    } \
    } while (0)

#define INPUT_DENSE(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons); \
    NEW_LINE(); \
    } while (0)

#define INPUT_GRU(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
    NEW_LINE(); \
    } while (0)

    INPUT_DENSE(input_dense);
    INPUT_GRU(vad_gru);
    INPUT_GRU(noise_gru);
    INPUT_GRU(denoise_gru);
    INPUT_DENSE(denoise_output);
    INPUT_DENSE(vad_output);

    if (vad_output->nb_neurons != 1) {
        rnnoise_model_free(ret);
        return AVERROR(EINVAL);
    }

    *rnn = ret;

    return 0;
}
330 
/**
 * Advertise the supported formats: planar float samples, any channel count,
 * 48000 Hz only.
 *
 * @return 0 or a positive value on success, a negative AVERROR code on failure
 */
static int query_formats(AVFilterContext *ctx)
{
    AVFilterFormats *formats = NULL;
    AVFilterChannelLayouts *layouts = NULL;
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP,
        AV_SAMPLE_FMT_NONE
    };
    static const int sample_rates[] = { 48000, -1 };
    int ret;

    formats = ff_make_format_list(sample_fmts);
    if (!formats)
        return AVERROR(ENOMEM);
    ret = ff_set_common_formats(ctx, formats);
    if (ret < 0)
        return ret;

    layouts = ff_all_channel_counts();
    if (!layouts)
        return AVERROR(ENOMEM);

    ret = ff_set_common_channel_layouts(ctx, layouts);
    if (ret < 0)
        return ret;

    formats = ff_make_format_list(sample_rates);
    if (!formats)
        return AVERROR(ENOMEM);
    return ff_set_common_samplerates(ctx, formats);
}
361 
/**
 * Set up per-channel state for the input link: the DenoiseState array, the
 * zeroed GRU state vectors for model 0 and the forward/inverse FFT contexts.
 *
 * May run more than once on the same context; existing FFT contexts are
 * kept, GRU state vectors are released and re-created zeroed.
 *
 * NOTE(review): if this is called again with a larger channel count, s->st
 * is reused at its old size. Handling that needs per-channel teardown that
 * lives outside this function — confirm whether reconfiguration with a
 * different channel count can happen.
 *
 * @return 0 on success, a negative AVERROR code on failure
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;
    int ret;

    s->channels = inlink->channels;

    if (!s->st)
        s->st = av_calloc(s->channels, sizeof(DenoiseState));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        st->rnn[0].model = s->model[0];

        /* Drop any state from a previous configuration instead of leaking it. */
        av_freep(&st->rnn[0].vad_gru_state);
        av_freep(&st->rnn[0].noise_gru_state);
        av_freep(&st->rnn[0].denoise_gru_state);

        st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
        st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
        st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
        if (!st->rnn[0].vad_gru_state ||
            !st->rnn[0].noise_gru_state ||
            !st->rnn[0].denoise_gru_state)
            return AVERROR(ENOMEM);
    }

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        /* Only check ret when av_tx_init() actually ran; it is otherwise unset. */
        if (!st->tx) {
            ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }

        if (!st->txi) {
            ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }
    }

    return 0;
}
404 
/**
 * Second-order recursive filter whose state is carried in mem between calls.
 *
 * x[i] is read before y[i] is written, so y may alias x.
 *
 * @param y   output, N samples
 * @param mem two state values, updated in place
 * @param x   input, N samples
 * @param b   two feed-forward coefficients
 * @param a   two feedback coefficients
 * @param N   number of samples
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    for (int i = 0; i < N; i++) {
        float xi, yi;

        xi = x[i];
        yi = x[i] + mem[0];
        mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
        mem[1] = (b[1]*xi - a[1]*yi);
        y[i] = yi;
    }
}
418 
/* Typed wrappers around memmove/memset/memcpy; n counts elements of *dst.
 * The "0*((dst)-(src))" term adds nothing to the size: subtracting the two
 * pointers only compiles when dst and src have compatible pointer types. */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
#define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
422 
/**
 * Forward FFT of a real window.
 *
 * @param st  state providing the forward transform context
 * @param out receives the first FREQ_SIZE complex bins
 * @param in  WINDOW_SIZE real samples
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    AVComplexFloat x[WINDOW_SIZE];
    AVComplexFloat y[WINDOW_SIZE];

    for (int i = 0; i < WINDOW_SIZE; i++) {
        x[i].re = in[i];
        x[i].im = 0;
    }

    /* The stride is the size of one sample, which is a complex value here. */
    st->tx_fn(st->tx, y, x, sizeof(AVComplexFloat));

    RNN_COPY(out, y, FREQ_SIZE);
}
437 
/**
 * Inverse FFT back to a real window.
 *
 * The upper bins are rebuilt as the complex conjugate mirror of the lower
 * ones before transforming, and the real part of the result is divided by
 * WINDOW_SIZE.
 *
 * @param st  state providing the inverse transform context
 * @param out receives WINDOW_SIZE real samples
 * @param in  FREQ_SIZE complex bins
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    AVComplexFloat x[WINDOW_SIZE];
    AVComplexFloat y[WINDOW_SIZE];

    RNN_COPY(x, in, FREQ_SIZE);

    for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
        x[i].re =  x[WINDOW_SIZE - i].re;
        x[i].im = -x[WINDOW_SIZE - i].im;
    }

    /* The stride is the size of one sample, which is a complex value here. */
    st->txi_fn(st->txi, y, x, sizeof(AVComplexFloat));

    for (int i = 0; i < WINDOW_SIZE; i++)
        out[i] = y[i].re / WINDOW_SIZE;
}
455 
/* Band boundaries; users shift each entry left by FRAME_SIZE_SHIFT to get an
 * FFT bin index. The header row gives the frequency each boundary
 * corresponds to. */
static const uint8_t eband5ms[] = {
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
  0,  1,  2,  3,  4,   5, 6,  7,  8,  10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
460 
compute_band_energy(float * bandE,const AVComplexFloat * X)461 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
462 {
463     float sum[NB_BANDS] = {0};
464 
465     for (int i = 0; i < NB_BANDS - 1; i++) {
466         int band_size;
467 
468         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
469         for (int j = 0; j < band_size; j++) {
470             float tmp, frac = (float)j / band_size;
471 
472             tmp         = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
473             tmp        += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
474             sum[i]     += (1.f - frac) * tmp;
475             sum[i + 1] +=        frac  * tmp;
476         }
477     }
478 
479     sum[0] *= 2;
480     sum[NB_BANDS - 1] *= 2;
481 
482     for (int i = 0; i < NB_BANDS; i++)
483         bandE[i] = sum[i];
484 }
485 
compute_band_corr(float * bandE,const AVComplexFloat * X,const AVComplexFloat * P)486 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
487 {
488     float sum[NB_BANDS] = { 0 };
489 
490     for (int i = 0; i < NB_BANDS - 1; i++) {
491         int band_size;
492 
493         band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
494         for (int j = 0; j < band_size; j++) {
495             float tmp, frac = (float)j / band_size;
496 
497             tmp  = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
498             tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
499             sum[i]     += (1 - frac) * tmp;
500             sum[i + 1] +=      frac  * tmp;
501         }
502     }
503 
504     sum[0] *= 2;
505     sum[NB_BANDS-1] *= 2;
506 
507     for (int i = 0; i < NB_BANDS; i++)
508         bandE[i] = sum[i];
509 }
510 
/**
 * Analyse one input frame: build the window from the previous and current
 * frame, apply the window function, transform and compute band energies.
 *
 * st->analysis_mem is replaced by the current frame for the next call.
 *
 * @param X  receives FREQ_SIZE complex bins
 * @param Ex receives NB_BANDS band energies
 * @param in FRAME_SIZE input samples
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);

    RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
    s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
    forward_transform(st, X, x);
    compute_band_energy(Ex, X);
}
522 
/**
 * Synthesise one output frame from a spectrum by windowed overlap-add, then
 * blend it with st->history.
 *
 * The first half of the windowed inverse transform is added to the tail kept
 * from the previous call and written to out; the second half becomes the new
 * tail. Finally out = out * mix + history * (1 - max(mix, 0)).
 *
 * @param out receives FRAME_SIZE samples
 * @param y   FREQ_SIZE complex bins
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
    const float *src = st->history;
    const float mix = s->mix;
    const float imix = 1.f - FFMAX(mix, 0.f);

    inverse_transform(st, x, y);
    s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
    s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
    RNN_COPY(out, x, FRAME_SIZE);
    RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);

    for (int n = 0; n < FRAME_SIZE; n++)
        out[n] = out[n] * mix + src[n] * imix;
}
539 
/**
 * Accumulate four consecutive cross-correlation lags at once:
 * sum[k] += x[0]*y[k] + x[1]*y[k+1] + ... + x[len-1]*y[len-1+k], k = 0..3.
 *
 * The results are added to sum, so the caller must initialise it.
 * y[0], y[1] and y[2] are read unconditionally, and y is read up to index
 * len + 2.
 *
 * The main loop handles four samples per iteration while rotating the roles
 * of y_0..y_3, so only one new y value is loaded per x sample. The three
 * trailing blocks handle the up to three leftover samples in the same
 * rotation order.
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    float y_0, y_1, y_2, y_3 = 0;
    int j;

    y_0 = *y++;
    y_1 = *y++;
    y_2 = *y++;

    for (j = 0; j < len - 3; j += 4) {
        float tmp;

        tmp = *x++;
        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
        tmp = *x++;
        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
        tmp = *x++;
        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
        tmp = *x++;
        y_2 = *y++;
        sum[0] += tmp * y_3;
        sum[1] += tmp * y_0;
        sum[2] += tmp * y_1;
        sum[3] += tmp * y_2;
    }

    if (j++ < len) {
        float tmp = *x++;

        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
    }

    if (j++ < len) {
        float tmp = *x++;

        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
    }

    if (j < len) {
        float tmp = *x++;

        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
    }
}
608 
/**
 * Dot product of two float vectors.
 *
 * @return x[0]*y[0] + ... + x[N-1]*y[N-1], or 0 when N <= 0
 */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float xy = 0.f;

    for (int i = 0; i < N; i++)
        xy += x[i] * y[i];

    return xy;
}
619 
/**
 * Cross-correlation of x against y for lags 0 .. max_pitch-1:
 * xcorr[i] = sum over j < len of x[j] * y[i + j].
 *
 * Lags are computed four at a time with xcorr_kernel(); that kernel always
 * loads its first three y values, so y is read up to three elements further
 * than the plain sum needs for those groups.
 *
 * @param xcorr     receives max_pitch values
 * @param len       number of samples correlated per lag
 * @param max_pitch number of lags
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int i;

    for (i = 0; i < max_pitch - 3; i += 4) {
        float sum[4] = { 0, 0, 0, 0};

        xcorr_kernel(x, y + i, sum, len);

        xcorr[i]     = sum[0];
        xcorr[i + 1] = sum[1];
        xcorr[i + 2] = sum[2];
        xcorr[i + 3] = sum[3];
    }
    /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
    for (; i < max_pitch; i++) {
        xcorr[i] = celt_inner_prod(x, y + i, len);
    }
}
640 
celt_autocorr(const float * x,float * ac,const float * window,int overlap,int lag,int n)641 static int celt_autocorr(const float *x,   /*  in: [0...n-1] samples x   */
642                          float       *ac,  /* out: [0...lag-1] ac values */
643                          const float *window,
644                          int          overlap,
645                          int          lag,
646                          int          n)
647 {
648     int fastN = n - lag;
649     int shift;
650     const float *xptr;
651     float xx[PITCH_BUF_SIZE>>1];
652 
653     if (overlap == 0) {
654         xptr = x;
655     } else {
656         for (int i = 0; i < n; i++)
657             xx[i] = x[i];
658         for (int i = 0; i < overlap; i++) {
659             xx[i] = x[i] * window[i];
660             xx[n-i-1] = x[n-i-1] * window[i];
661         }
662         xptr = xx;
663     }
664 
665     shift = 0;
666     celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
667 
668     for (int k = 0; k <= lag; k++) {
669         float d = 0.f;
670 
671         for (int i = k + fastN; i < n; i++)
672             d += xptr[i] * xptr[i-k];
673         ac[k] += d;
674     }
675 
676     return shift;
677 }
678 
/**
 * Compute p LPC coefficients from autocorrelation values by a
 * Levinson-style recursion.
 *
 * lpc is cleared first. If ac[0] is zero it stays all zero. The recursion
 * stops early once the remaining error falls below 0.001 * ac[0], leaving
 * the higher-order coefficients at zero.
 *
 * @param lpc out: [0...p-1] LPC coefficients
 * @param ac  in:  [0...p] autocorrelation values
 * @param p   filter order
 */
static void celt_lpc(float *lpc,
                     const float *ac,
                     int p)
{
    float r, error = ac[0];

    memset(lpc, 0, p * sizeof(*lpc));
    if (ac[0] != 0) {
        for (int i = 0; i < p; i++) {
            /* Sum up this iteration's reflection coefficient */
            float rr = 0;
            for (int j = 0; j < i; j++)
                rr += (lpc[j] * ac[i - j]);
            rr += ac[i + 1];
            r = -rr/error;
            /*  Update LPC coefficients and total error */
            lpc[i] = r;
            for (int j = 0; j < (i + 1) >> 1; j++) {
                float tmp1, tmp2;
                tmp1 = lpc[j];
                tmp2 = lpc[i-1-j];
                lpc[j]     = tmp1 + (r*tmp2);
                lpc[i-1-j] = tmp2 + (r*tmp1);
            }

            error = error - (r * r *error);
            /* Bail out once we get 30 dB gain */
            if (error < .001f * ac[0])
                break;
        }
    }
}
711 
/**
 * Five-tap FIR filter with persistent history:
 * y[i] = x[i] + num[0]*x[i-1] + num[1]*x[i-2] + ... + num[4]*x[i-5],
 * where samples before x[0] come from mem (mem[0] being the most recent).
 *
 * x[i] is saved into the history before y[i] is stored, so y may alias x.
 *
 * @param x   input, N samples
 * @param num five filter coefficients
 * @param y   output, N samples
 * @param N   number of samples
 * @param mem five history values, updated in place
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float num0, num1, num2, num3, num4;
    float mem0, mem1, mem2, mem3, mem4;

    num0 = num[0];
    num1 = num[1];
    num2 = num[2];
    num3 = num[3];
    num4 = num[4];
    mem0 = mem[0];
    mem1 = mem[1];
    mem2 = mem[2];
    mem3 = mem[3];
    mem4 = mem[4];

    for (int i = 0; i < N; i++) {
        float sum = x[i];

        sum += (num0*mem0);
        sum += (num1*mem1);
        sum += (num2*mem2);
        sum += (num3*mem3);
        sum += (num4*mem4);
        mem4 = mem3;
        mem3 = mem2;
        mem2 = mem1;
        mem1 = mem0;
        mem0 = x[i];
        y[i] = sum;
    }

    mem[0] = mem0;
    mem[1] = mem1;
    mem[2] = mem2;
    mem[3] = mem3;
    mem[4] = mem4;
}
754 
pitch_downsample(float * x[],float * x_lp,int len,int C)755 static void pitch_downsample(float *x[], float *x_lp,
756                              int len, int C)
757 {
758     float ac[5];
759     float tmp=Q15ONE;
760     float lpc[4], mem[5]={0,0,0,0,0};
761     float lpc2[5];
762     float c1 = .8f;
763 
764     for (int i = 1; i < len >> 1; i++)
765         x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
766     x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
767     if (C==2) {
768         for (int i = 1; i < len >> 1; i++)
769             x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
770         x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
771     }
772 
773     celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
774 
775     /* Noise floor -40 dB */
776     ac[0] *= 1.0001f;
777     /* Lag windowing */
778     for (int i = 1; i <= 4; i++) {
779         /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
780         ac[i] -= ac[i]*(.008f*i)*(.008f*i);
781     }
782 
783     celt_lpc(lpc, ac, 4);
784     for (int i = 0; i < 4; i++) {
785         tmp = .9f * tmp;
786         lpc[i] = (lpc[i] * tmp);
787     }
788     /* Add a zero */
789     lpc2[0] = lpc[0] + .8f;
790     lpc2[1] = lpc[1] + (c1 * lpc[0]);
791     lpc2[2] = lpc[2] + (c1 * lpc[1]);
792     lpc2[3] = lpc[3] + (c1 * lpc[2]);
793     lpc2[4] = (c1 * lpc[3]);
794     celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
795 }
796 
/**
 * Two dot products sharing the same left operand, computed in one pass.
 *
 * @param xy1 receives sum of x[i] * y01[i]
 * @param xy2 receives sum of x[i] * y02[i]
 */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float xy01 = 0, xy02 = 0;

    for (int i = 0; i < N; i++) {
        xy01 += (x[i] * y01[i]);
        xy02 += (x[i] * y02[i]);
    }

    *xy1 = xy01;
    *xy2 = xy02;
}
810 
compute_pitch_gain(float xy,float xx,float yy)811 static float compute_pitch_gain(float xy, float xx, float yy)
812 {
813     return xy / sqrtf(1.f + xx * yy);
814 }
815 
/* Multiplier used by remove_doubling() to pick the second lag checked
 * alongside T0/k, indexed by k (entries for k = 3..15 are used). */
static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
remove_doubling(float * x,int maxperiod,int minperiod,int N,int * T0_,int prev_period,float prev_gain)817 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
818                              int *T0_, int prev_period, float prev_gain)
819 {
820     int k, i, T, T0;
821     float g, g0;
822     float pg;
823     float xy,xx,yy,xy2;
824     float xcorr[3];
825     float best_xy, best_yy;
826     int offset;
827     int minperiod0;
828     float yy_lookup[PITCH_MAX_PERIOD+1];
829 
830     minperiod0 = minperiod;
831     maxperiod /= 2;
832     minperiod /= 2;
833     *T0_ /= 2;
834     prev_period /= 2;
835     N /= 2;
836     x += maxperiod;
837     if (*T0_>=maxperiod)
838         *T0_=maxperiod-1;
839 
840     T = T0 = *T0_;
841     dual_inner_prod(x, x, x-T0, N, &xx, &xy);
842     yy_lookup[0] = xx;
843     yy=xx;
844     for (i = 1; i <= maxperiod; i++) {
845         yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
846         yy_lookup[i] = FFMAX(0, yy);
847     }
848     yy = yy_lookup[T0];
849     best_xy = xy;
850     best_yy = yy;
851     g = g0 = compute_pitch_gain(xy, xx, yy);
852     /* Look for any pitch at T/k */
853     for (k = 2; k <= 15; k++) {
854         int T1, T1b;
855         float g1;
856         float cont=0;
857         float thresh;
858         T1 = (2*T0+k)/(2*k);
859         if (T1 < minperiod)
860             break;
861         /* Look for another strong correlation at T1b */
862         if (k==2)
863         {
864             if (T1+T0>maxperiod)
865                 T1b = T0;
866             else
867                 T1b = T0+T1;
868         } else
869         {
870             T1b = (2*second_check[k]*T0+k)/(2*k);
871         }
872         dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
873         xy = .5f * (xy + xy2);
874         yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
875         g1 = compute_pitch_gain(xy, xx, yy);
876         if (FFABS(T1-prev_period)<=1)
877             cont = prev_gain;
878         else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
879             cont = prev_gain * .5f;
880         else
881             cont = 0;
882         thresh = FFMAX(.3f, (.7f * g0) - cont);
883         /* Bias against very high pitch (very short period) to avoid false-positives
884            due to short-term correlation */
885         if (T1<3*minperiod)
886             thresh = FFMAX(.4f, (.85f * g0) - cont);
887         else if (T1<2*minperiod)
888             thresh = FFMAX(.5f, (.9f * g0) - cont);
889         if (g1 > thresh)
890         {
891             best_xy = xy;
892             best_yy = yy;
893             T = T1;
894             g = g1;
895         }
896     }
897     best_xy = FFMAX(0, best_xy);
898     if (best_yy <= best_xy)
899         pg = Q15ONE;
900     else
901         pg = best_xy/(best_yy + 1);
902 
903     for (k = 0; k < 3; k++)
904         xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
905     if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
906         offset = 1;
907     else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
908         offset = -1;
909     else
910         offset = 0;
911     if (pg > g)
912         pg = g;
913     *T0_ = 2*T+offset;
914 
915     if (*T0_<minperiod0)
916         *T0_=minperiod0;
917     return pg;
918 }
919 
/**
 * Find the two lags with the largest xcorr[i]^2 / Syy(i), where Syy(i) is
 * 1 plus the energy of y[i .. i+len-1] (kept at or above 1).
 *
 * Ratios are compared by cross-multiplication, so no division is done.
 * Only lags with a positive correlation are considered. If none qualifies,
 * best_pitch stays at its initial {0, 1}.
 *
 * @param xcorr      max_pitch correlation values
 * @param y          signal; indices 0 .. len + max_pitch - 1 are read
 * @param len        window length for the energy
 * @param max_pitch  number of lags
 * @param best_pitch receives the best lag in [0] and the runner-up in [1]
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float best_num[2];
    float best_den[2];
    float Syy = 1.f;

    best_num[0] = -1;
    best_num[1] = -1;
    best_den[0] = 0;
    best_den[1] = 0;
    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        Syy += y[j] * y[j];

    for (int i = 0; i < max_pitch; i++) {
        if (xcorr[i]>0) {
            float num;
            float xcorr16;

            xcorr16 = xcorr[i];
            /* Considering the range of xcorr16, this should avoid both underflows
               and overflows (inf) when squaring xcorr16 */
            xcorr16 *= 1e-12f;
            num = xcorr16 * xcorr16;
            if ((num * best_den[1]) > (best_num[1] * Syy)) {
                if ((num * best_den[0]) > (best_num[0] * Syy)) {
                    /* New best: the old best becomes the runner-up. */
                    best_num[1] = best_num[0];
                    best_den[1] = best_den[0];
                    best_pitch[1] = best_pitch[0];
                    best_num[0] = num;
                    best_den[0] = Syy;
                    best_pitch[0] = i;
                } else {
                    best_num[1] = num;
                    best_den[1] = Syy;
                    best_pitch[1] = i;
                }
            }
        }
        /* Slide the energy window by one sample and keep it at or above 1. */
        Syy += y[i+len]*y[i+len] - y[i] * y[i];
        if (Syy < 1.f)
            Syy = 1.f;
    }
}
966 
pitch_search(const float * x_lp,float * y,int len,int max_pitch,int * pitch)967 static void pitch_search(const float *x_lp, float *y,
968                          int len, int max_pitch, int *pitch)
969 {
970     int lag;
971     int best_pitch[2]={0,0};
972     int offset;
973 
974     float x_lp4[WINDOW_SIZE];
975     float y_lp4[WINDOW_SIZE];
976     float xcorr[WINDOW_SIZE];
977 
978     lag = len+max_pitch;
979 
980     /* Downsample by 2 again */
981     for (int j = 0; j < len >> 2; j++)
982         x_lp4[j] = x_lp[2*j];
983     for (int j = 0; j < lag >> 2; j++)
984         y_lp4[j] = y[2*j];
985 
986     /* Coarse search with 4x decimation */
987 
988     celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
989 
990     find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
991 
992     /* Finer search with 2x decimation */
993     for (int i = 0; i < max_pitch >> 1; i++) {
994         float sum;
995         xcorr[i] = 0;
996         if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
997             continue;
998         sum = celt_inner_prod(x_lp, y+i, len>>1);
999         xcorr[i] = FFMAX(-1, sum);
1000     }
1001 
1002     find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
1003 
1004     /* Refine by pseudo-interpolation */
1005     if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
1006         float a, b, c;
1007 
1008         a = xcorr[best_pitch[0] - 1];
1009         b = xcorr[best_pitch[0]];
1010         c = xcorr[best_pitch[0] + 1];
1011         if (c - a > .7f * (b - a))
1012             offset = 1;
1013         else if (a - c > .7f * (b-c))
1014             offset = -1;
1015         else
1016             offset = 0;
1017     } else {
1018         offset = 0;
1019     }
1020 
1021     *pitch = 2 * best_pitch[0] - offset;
1022 }
1023 
/**
 * Project a band vector onto the rows of s->dct_table.
 *
 * out[i] is the scalar product of in with row i of the table, scaled by
 * sqrt(2/22), for i in [0, NB_BANDS).
 *
 * NOTE(review): each scalar product covers FFALIGN(NB_BANDS, 4) floats, so
 * in and every table row must be readable for that many elements --
 * confirm that callers' buffers are padded accordingly.
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    const int padded_len = FFALIGN(NB_BANDS, 4);
    const float scale = sqrtf(2.f / 22);

    for (int band = 0; band < NB_BANDS; band++) {
        const float dot = s->fdsp->scalarproduct_float(in, s->dct_table[band], padded_len);

        out[band] = dot * scale;
    }
}
1033 
/**
 * Analyse one input frame and fill the feature vector fed to the network.
 *
 * Steps, in order: spectral analysis of in (X, Ex), update of the pitch
 * history and pitch estimation, spectrum of the pitch-delayed window
 * (P, Ep), band correlation between both (Exp), then the cepstral,
 * delta-cepstral, pitch and spectral-variability features.
 *
 * @param s        filter context (DSP helpers, window, DCT table)
 * @param st       per-channel state; pitch history, last pitch/gain and
 *                 cepstral memory are updated
 * @param X        receives the spectrum of the frame
 * @param P        receives the spectrum of the pitch-delayed window
 * @param Ex       receives band energies of X
 * @param Ep       receives band energies of P
 * @param Exp      receives the normalised band correlation of X and P
 * @param features receives NB_FEATURES values
 * @param in       FRAME_SIZE input samples
 * @return 1 if the summed band energy is below 0.04 (features are then
 *         cleared and the cepstral memory is left untouched), 0 otherwise
 */
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
{
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
    float pitch_buf[PITCH_BUF_SIZE>>1];
    float corr_ceps[NB_BANDS];
    float *(pre[1]);
    /* Start of the pitch-correlation section of the feature vector. */
    float *const pitch_feat = features + NB_BANDS + 2*NB_DELTA_CEPS;
    float *ceps_0, *ceps_1, *ceps_2;
    float E = 0;
    float spec_variability = 0;
    float follow = -2;
    float logMax = -2;
    float gain;
    int pitch_index;
    int delayed_start;

    frame_analysis(s, st, X, Ex, in);

    /* Shift the pitch history by one frame and append the new samples. */
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);

    pre[0] = &st->pitch_buf[0];
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
    pitch_search(pitch_buf + (PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
                 PITCH_MAX_PERIOD - 3*PITCH_MIN_PERIOD, &pitch_index);
    pitch_index = PITCH_MAX_PERIOD - pitch_index;

    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
                           PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
    st->last_period = pitch_index;
    st->last_gain   = gain;

    /* Window of history delayed by the estimated pitch period. */
    delayed_start = PITCH_BUF_SIZE - WINDOW_SIZE - pitch_index;
    for (int n = 0; n < WINDOW_SIZE; n++)
        p[n] = st->pitch_buf[delayed_start + n];

    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
    forward_transform(st, P, p);
    compute_band_energy(Ep, P);
    compute_band_corr(Exp, X, P);

    for (int b = 0; b < NB_BANDS; b++)
        Exp[b] = Exp[b] / sqrtf(.001f+Ex[b]*Ep[b]);

    dct(s, corr_ceps, Exp);

    for (int c = 0; c < NB_DELTA_CEPS; c++)
        pitch_feat[c] = corr_ceps[c];

    pitch_feat[0] -= 1.3;
    pitch_feat[1] -= 0.9;
    pitch_feat[NB_DELTA_CEPS] = .01*(pitch_index-300);

    /* Log band energies, limited from below relative to the running
       maximum and to the previous band, plus the total energy. */
    for (int b = 0; b < NB_BANDS; b++) {
        float level = log10f(1e-2f + Ex[b]);

        level  = FFMAX(logMax-7, FFMAX(follow-1.5, level));
        Ly[b]  = level;
        logMax = FFMAX(logMax, level);
        follow = FFMAX(follow-1.5, level);
        E += Ex[b];
    }

    if (E < 0.04f) {
        /* If there's no audio, avoid messing up the state. */
        RNN_CLEAR(features, NB_FEATURES);
        return 1;
    }

    dct(s, features, Ly);
    features[0] -= 12;
    features[1] -= 4;

    /* Current slot of the circular cepstral memory and its two
       predecessors (indices wrap at CEPS_MEM). */
    ceps_0 = st->cepstral_mem[st->memid];
    ceps_1 = st->cepstral_mem[(st->memid < 1) ? CEPS_MEM+st->memid-1 : st->memid-1];
    ceps_2 = st->cepstral_mem[(st->memid < 2) ? CEPS_MEM+st->memid-2 : st->memid-2];

    for (int b = 0; b < NB_BANDS; b++)
        ceps_0[b] = features[b];

    st->memid++;
    for (int c = 0; c < NB_DELTA_CEPS; c++) {
        features[c] = ceps_0[c] + ceps_1[c] + ceps_2[c];
        features[NB_BANDS+c] = ceps_0[c] - ceps_2[c];
        features[NB_BANDS+NB_DELTA_CEPS+c] =  ceps_0[c] - 2*ceps_1[c] + ceps_2[c];
    }

    if (st->memid == CEPS_MEM)
        st->memid = 0;

    /* Spectral variability: for every stored cepstrum take the squared
       distance to its nearest other stored cepstrum, then average. */
    for (int a = 0; a < CEPS_MEM; a++) {
        float mindist = 1e15f;

        for (int b = 0; b < CEPS_MEM; b++) {
            float dist = 0.f;

            if (b == a)
                continue;

            for (int k = 0; k < NB_BANDS; k++) {
                float diff = st->cepstral_mem[a][k] - st->cepstral_mem[b][k];

                dist += diff*diff;
            }
            mindist = FFMIN(mindist, dist);
        }

        spec_variability += mindist;
    }

    pitch_feat[NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;

    return 0;
}
1141 
interp_band_gain(float * g,const float * bandE)1142 static void interp_band_gain(float *g, const float *bandE)
1143 {
1144     memset(g, 0, sizeof(*g) * FREQ_SIZE);
1145 
1146     for (int i = 0; i < NB_BANDS - 1; i++) {
1147         const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1148 
1149         for (int j = 0; j < band_size; j++) {
1150             float frac = (float)j / band_size;
1151 
1152             g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1153         }
1154     }
1155 }
1156 
/**
 * Mix the pitch-delayed spectrum P into X and restore the band energies.
 *
 * A per-band mixing factor is derived from the band correlation Exp and
 * the band gain g, interpolated to bins and used to add P to X. X is then
 * rescaled per band so that its band energies match Ex again.
 *
 * @param X   spectrum, modified in place (FREQ_SIZE bins)
 * @param P   spectrum of the pitch-delayed signal
 * @param Ex  band energies of X before filtering
 * @param Ep  band energies of P
 * @param Exp band correlation between X and P
 * @param g   band gains
 */
static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
                         const float *Exp, const float *g)
{
    float newE[NB_BANDS];
    float r[NB_BANDS];
    float norm[NB_BANDS];
    float rf[FREQ_SIZE] = {0};
    float normf[FREQ_SIZE]={0};

    /* Per-band mixing strength. */
    for (int b = 0; b < NB_BANDS; b++) {
        float strength;

        if (Exp[b]>g[b])
            strength = 1;
        else
            strength = SQUARE(Exp[b])*(1-SQUARE(g[b]))/(.001 + SQUARE(g[b])*(1-SQUARE(Exp[b])));

        r[b]  = sqrtf(av_clipf(strength, 0, 1));
        r[b] *= sqrtf(Ex[b]/(1e-8+Ep[b]));
    }

    interp_band_gain(rf, r);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re += rf[k]*P[k].re;
        X[k].im += rf[k]*P[k].im;
    }

    /* Renormalise so each band keeps its original energy. */
    compute_band_energy(newE, X);
    for (int b = 0; b < NB_BANDS; b++)
        norm[b] = sqrtf(Ex[b] / (1e-8+newE[b]));

    interp_band_gain(normf, norm);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re *= normf[k];
        X[k].im *= normf[k];
    }
}
1187 
/*
 * tanh(0.04 * i) for i = 0..200 (i.e. arguments 0 to 8), six decimals.
 * Ten entries per row: row k holds indices 10*k .. 10*k + 9.
 */
static const float tansig_table[201] = {
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f, 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f, 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f, 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f, 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f, 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f, 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f, 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f, 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f, 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f, 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f, 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f, 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f, 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f, 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f, 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f, 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f, 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f, 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f,
};
1231 
tansig_approx(float x)1232 static inline float tansig_approx(float x)
1233 {
1234     float y, dy;
1235     float sign=1;
1236     int i;
1237 
1238     /* Tests are reversed to catch NaNs */
1239     if (!(x<8))
1240         return 1;
1241     if (!(x>-8))
1242         return -1;
1243     /* Another check in case of -ffast-math */
1244 
1245     if (isnan(x))
1246        return 0;
1247 
1248     if (x < 0) {
1249        x=-x;
1250        sign=-1;
1251     }
1252     i = (int)floor(.5f+25*x);
1253     x -= .04f*i;
1254     y = tansig_table[i];
1255     dy = 1-y*y;
1256     y = y + x*dy*(1 - y*x);
1257     return sign*y;
1258 }
1259 
/**
 * Approximate logistic sigmoid, computed as 0.5 + 0.5 * tanh(x / 2)
 * using tansig_approx().
 */
static inline float sigmoid_approx(float x)
{
    const float half_tanh = .5f*tansig_approx(.5f*x);

    return .5f + half_tanh;
}
1264 
compute_dense(const DenseLayer * layer,float * output,const float * input)1265 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1266 {
1267     const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1268 
1269     for (int i = 0; i < N; i++) {
1270         /* Compute update gate. */
1271         float sum = layer->bias[i];
1272 
1273         for (int j = 0; j < M; j++)
1274             sum += layer->input_weights[j * stride + i] * input[j];
1275 
1276         output[i] = WEIGHTS_SCALE * sum;
1277     }
1278 
1279     if (layer->activation == ACTIVATION_SIGMOID) {
1280         for (int i = 0; i < N; i++)
1281             output[i] = sigmoid_approx(output[i]);
1282     } else if (layer->activation == ACTIVATION_TANH) {
1283         for (int i = 0; i < N; i++)
1284             output[i] = tansig_approx(output[i]);
1285     } else if (layer->activation == ACTIVATION_RELU) {
1286         for (int i = 0; i < N; i++)
1287             output[i] = FFMAX(0, output[i]);
1288     } else {
1289         av_assert0(0);
1290     }
1291 }
1292 
/**
 * Advance a GRU layer by one step.
 *
 * Update gate z and reset gate r are sigmoid-activated; the candidate
 * output uses the layer's own activation and sees the previous state
 * gated by r. The new state is z * state + (1 - z) * candidate.
 *
 * Weight rows are laid out with a stride of three padded blocks
 * (update, reset, output), each padded to a multiple of four floats.
 *
 * @param s     filter context (provides the scalar product routine)
 * @param gru   layer description
 * @param state previous state on entry, new state on return
 * @param input layer input
 */
static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
{
    LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
    const int nb_in      = gru->nb_inputs;
    const int nb_out     = gru->nb_neurons;
    const int pad_out    = FFALIGN(nb_out, 4);
    const int pad_in     = FFALIGN(nb_in, 4);
    const int rec_stride = 3 * pad_out;
    const int in_stride  = 3 * pad_in;

    /* Update and reset gates (independent of each other). */
    for (int i = 0; i < nb_out; i++) {
        const int in_row  = i * in_stride;
        const int rec_row = i * rec_stride;
        float update = gru->bias[i];
        float reset  = gru->bias[nb_out + i];

        update += s->fdsp->scalarproduct_float(gru->input_weights + in_row, input, pad_in);
        update += s->fdsp->scalarproduct_float(gru->recurrent_weights + rec_row, state, pad_out);
        z[i] = sigmoid_approx(WEIGHTS_SCALE * update);

        reset += s->fdsp->scalarproduct_float(gru->input_weights + pad_in + in_row, input, pad_in);
        reset += s->fdsp->scalarproduct_float(gru->recurrent_weights + pad_out + rec_row, state, pad_out);
        r[i] = sigmoid_approx(WEIGHTS_SCALE * reset);
    }

    /* Candidate output and blend with the previous state. */
    for (int i = 0; i < nb_out; i++) {
        float sum = gru->bias[2 * nb_out + i];

        sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * pad_in + i * in_stride, input, pad_in);
        for (int j = 0; j < nb_out; j++)
            sum += gru->recurrent_weights[2 * pad_out + i * rec_stride + j] * state[j] * r[j];

        switch (gru->activation) {
        case ACTIVATION_SIGMOID:
            sum = sigmoid_approx(WEIGHTS_SCALE * sum);
            break;
        case ACTIVATION_TANH:
            sum = tansig_approx(WEIGHTS_SCALE * sum);
            break;
        case ACTIVATION_RELU:
            sum = FFMAX(0, WEIGHTS_SCALE * sum);
            break;
        default:
            av_assert0(0);
        }

        h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
    }

    RNN_COPY(state, h, nb_out);
}
1343 
1344 #define INPUT_SIZE 42
1345 
compute_rnn(AudioRNNContext * s,RNNState * rnn,float * gains,float * vad,const float * input)1346 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1347 {
1348     LOCAL_ALIGNED_32(float, dense_out,     [MAX_NEURONS]);
1349     LOCAL_ALIGNED_32(float, noise_input,   [MAX_NEURONS * 3]);
1350     LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1351 
1352     compute_dense(rnn->model->input_dense, dense_out, input);
1353     compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1354     compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1355 
1356     memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1357     memcpy(noise_input + rnn->model->input_dense_size,
1358            rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1359     memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1360            input, INPUT_SIZE * sizeof(float));
1361 
1362     compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1363 
1364     memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1365     memcpy(denoise_input + rnn->model->vad_gru_size,
1366            rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1367     memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1368            input, INPUT_SIZE * sizeof(float));
1369 
1370     compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1371     compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1372 }
1373 
/**
 * Denoise one frame of one channel.
 *
 * The input is high-pass filtered and analysed. Unless the frame is
 * classified as silent or the filter is disabled, the network is run, the
 * pitch filter is applied and the spectrum is scaled by the interpolated
 * band gains (each gain is kept at no less than 0.6 times the previous
 * frame's). The frame is then synthesised into out, and the unfiltered
 * input is stored in st->history.
 *
 * @param s        filter context
 * @param st       per-channel state
 * @param out      receives the output frame
 * @param in       FRAME_SIZE input samples
 * @param disabled non-zero to skip the network and gain stage
 * @return VAD output of the network, or 0 when it was not run
 */
static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
                             int disabled)
{
    static const float a_hp[2] = {-1.99599, 0.99600};
    static const float b_hp[2] = {-2, 1};
    AVComplexFloat X[FREQ_SIZE];
    AVComplexFloat P[WINDOW_SIZE];
    float x[FRAME_SIZE];
    float Ex[NB_BANDS], Ep[NB_BANDS];
    LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
    float features[NB_FEATURES];
    float g[NB_BANDS];
    float gf[FREQ_SIZE];
    float vad_prob = 0;
    int silence;

    biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
    silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);

    if (!(silence || disabled)) {
        const float alpha = .6f;

        compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
        pitch_filter(X, P, Ex, Ep, Exp, g);

        /* Limit how fast a band gain may drop from frame to frame. */
        for (int b = 0; b < NB_BANDS; b++) {
            g[b] = FFMAX(g[b], alpha * st->lastg[b]);
            st->lastg[b] = g[b];
        }

        interp_band_gain(gf, g);

        for (int k = 0; k < FREQ_SIZE; k++) {
            X[k].re *= gf[k];
            X[k].im *= gf[k];
        }
    }

    frame_synthesis(s, st, out, X);
    memcpy(st->history, in, FRAME_SIZE * sizeof(*st->history));

    return vad_prob;
}
1417 
1418 typedef struct ThreadData {
1419     AVFrame *in, *out;
1420 } ThreadData;
1421 
/**
 * Worker job: denoise the slice of channels assigned to job jobnr.
 *
 * Channels are split evenly over nb_jobs; job k handles channels
 * [channels * k / nb_jobs, channels * (k + 1) / nb_jobs).
 *
 * @param arg pointer to a ThreadData with the input and output frames
 * @return always 0
 */
static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
{
    AudioRNNContext *s = ctx->priv;
    const ThreadData *td = arg;
    const AVFrame *src = td->in;
    AVFrame *dst = td->out;
    const int first = (dst->channels * jobnr) / nb_jobs;
    const int last  = (dst->channels * (jobnr+1)) / nb_jobs;

    for (int ch = first; ch < last; ch++) {
        float *output      = (float *)dst->extended_data[ch];
        const float *input = (const float *)src->extended_data[ch];

        rnnoise_channel(s, &s->st[ch], output, input, ctx->is_disabled);
    }

    return 0;
}
1440 
/**
 * Denoise one FRAME_SIZE-sample input frame and send the result downstream.
 *
 * Takes ownership of in; it is freed on every path. The output frame
 * inherits the frame properties of the input (timestamps, metadata, side
 * data, ...) rather than only its pts.
 *
 * @param inlink input link the frame arrived on
 * @param in     input frame
 * @return result of ff_filter_frame(), or a negative AVERROR code if the
 *         output frame cannot be allocated or its properties copied
 */
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *ctx = inlink->dst;
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *out = NULL;
    ThreadData td;
    int ret;

    out = ff_get_audio_buffer(outlink, FRAME_SIZE);
    if (!out) {
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }

    /* Propagate all frame properties, not just the timestamp. */
    ret = av_frame_copy_props(out, in);
    if (ret < 0) {
        av_frame_free(&out);
        av_frame_free(&in);
        return ret;
    }

    td.in = in; td.out = out;
    ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
                                                                   ff_filter_get_nb_threads(ctx)));

    av_frame_free(&in);
    return ff_filter_frame(outlink, out);
}
1462 
/**
 * Scheduling callback: consume exactly FRAME_SIZE samples when available
 * and filter them; otherwise forward status and frame requests between
 * the input and output links.
 *
 * @return result of filter_frame() when a frame was consumed, a negative
 *         error from ff_inlink_consume_samples(), or FFERROR_NOT_READY
 *         (the FF_FILTER_FORWARD_* helper macros may also return early)
 */
static int activate(AVFilterContext *ctx)
{
    AVFilterLink *inlink = ctx->inputs[0];
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *frame = NULL;
    int got;

    FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);

    got = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &frame);
    if (got > 0)
        return filter_frame(inlink, frame);
    if (got < 0)
        return got;

    /* Nothing to process yet. */
    FF_FILTER_FORWARD_STATUS(inlink, outlink);
    FF_FILTER_FORWARD_WANTED(outlink, inlink);

    return FFERROR_NOT_READY;
}
1484 
/**
 * Load the model named by the "model" option.
 *
 * @param ctx   filter context; s->model_name names the file to read
 * @param model receives the loaded model
 * @return 0 on success; AVERROR(EINVAL) if no model name is set, the file
 *         cannot be opened or no model was produced; otherwise the
 *         negative error returned by rnnoise_model_from_file()
 */
static int open_model(AVFilterContext *ctx, RNNModel **model)
{
    AudioRNNContext *s = ctx->priv;
    int ret;
    FILE *f;

    if (!s->model_name)
        return AVERROR(EINVAL);
    f = av_fopen_utf8(s->model_name, "r");
    if (!f) {
        av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    ret = rnnoise_model_from_file(f, model);
    fclose(f);
    if (ret < 0)
        return ret;

    /* Never report success without a model: callers dereference it. */
    if (!*model) {
        av_log(ctx, AV_LOG_ERROR, "Failed to load model file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    return 0;
}
1506 
init(AVFilterContext * ctx)1507 static av_cold int init(AVFilterContext *ctx)
1508 {
1509     AudioRNNContext *s = ctx->priv;
1510     int ret;
1511 
1512     s->fdsp = avpriv_float_dsp_alloc(0);
1513     if (!s->fdsp)
1514         return AVERROR(ENOMEM);
1515 
1516     ret = open_model(ctx, &s->model[0]);
1517     if (ret < 0)
1518         return ret;
1519 
1520     for (int i = 0; i < FRAME_SIZE; i++) {
1521         s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1522         s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1523     }
1524 
1525     for (int i = 0; i < NB_BANDS; i++) {
1526         for (int j = 0; j < NB_BANDS; j++) {
1527             s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1528             if (j == 0)
1529                 s->dct_table[j][i] *= sqrtf(.5);
1530         }
1531     }
1532 
1533     return 0;
1534 }
1535 
/**
 * Release model slot n and the GRU state buffers of every channel that
 * belong to that slot.
 *
 * @param ctx filter context
 * @param n   model slot index
 */
static void free_model(AVFilterContext *ctx, int n)
{
    AudioRNNContext *s = ctx->priv;

    rnnoise_model_free(s->model[n]);
    s->model[n] = NULL;

    /* No per-channel state allocated: nothing more to release. */
    if (!s->st)
        return;

    for (int ch = 0; ch < s->channels; ch++) {
        RNNState *rnn = &s->st[ch].rnn[n];

        av_freep(&rnn->vad_gru_state);
        av_freep(&rnn->noise_gru_state);
        av_freep(&rnn->denoise_gru_state);
    }
}
1549 
/**
 * Runtime command handler: apply the option change, then load the model
 * named by the (possibly updated) "model" option and switch to it.
 *
 * The new model is loaded into slot 1 and swapped with slot 0 together
 * with the per-channel RNN states; config_input() then sets up slot 0.
 * On success the old model (now in slot 1) is freed. If config_input()
 * fails the swap is undone, the new model and any states allocated for it
 * are freed, and the previous model stays active.
 *
 * NOTE(review): the model is reloaded after every successful command, not
 * only after "model"/"m" -- confirm this is intended.
 *
 * @return 0 on success, a negative AVERROR code otherwise
 */
static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
                           char *res, int res_len, int flags)
{
    AudioRNNContext *s = ctx->priv;
    int ret;

    ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
    if (ret < 0)
        return ret;

    ret = open_model(ctx, &s->model[1]);
    if (ret < 0)
        return ret;

    FFSWAP(RNNModel *, s->model[0], s->model[1]);
    /* Same guard as free_model()/uninit(): s->st may not exist yet. */
    for (int ch = 0; ch < s->channels && s->st; ch++)
        FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);

    ret = config_input(ctx->inputs[0]);
    if (ret < 0) {
        /* Roll back to the previous model ... */
        for (int ch = 0; ch < s->channels && s->st; ch++)
            FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
        FFSWAP(RNNModel *, s->model[0], s->model[1]);
        /* ... and drop the rejected one so it is not leaked. */
        free_model(ctx, 1);
        return ret;
    }

    free_model(ctx, 1);
    return 0;
}
1579 
uninit(AVFilterContext * ctx)1580 static av_cold void uninit(AVFilterContext *ctx)
1581 {
1582     AudioRNNContext *s = ctx->priv;
1583 
1584     av_freep(&s->fdsp);
1585     free_model(ctx, 0);
1586     for (int ch = 0; ch < s->channels && s->st; ch++) {
1587         av_tx_uninit(&s->st[ch].tx);
1588         av_tx_uninit(&s->st[ch].txi);
1589     }
1590     av_freep(&s->st);
1591 }
1592 
1593 static const AVFilterPad inputs[] = {
1594     {
1595         .name         = "default",
1596         .type         = AVMEDIA_TYPE_AUDIO,
1597         .config_props = config_input,
1598     },
1599     { NULL }
1600 };
1601 
1602 static const AVFilterPad outputs[] = {
1603     {
1604         .name          = "default",
1605         .type          = AVMEDIA_TYPE_AUDIO,
1606     },
1607     { NULL }
1608 };
1609 
1610 #define OFFSET(x) offsetof(AudioRNNContext, x)
1611 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
1612 
1613 static const AVOption arnndn_options[] = {
1614     { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1615     { "m",     "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1616     { "mix",   "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1617     { NULL }
1618 };
1619 
1620 AVFILTER_DEFINE_CLASS(arnndn);
1621 
1622 AVFilter ff_af_arnndn = {
1623     .name          = "arnndn",
1624     .description   = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1625     .query_formats = query_formats,
1626     .priv_size     = sizeof(AudioRNNContext),
1627     .priv_class    = &arnndn_class,
1628     .activate      = activate,
1629     .init          = init,
1630     .uninit        = uninit,
1631     .inputs        = inputs,
1632     .outputs       = outputs,
1633     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1634                      AVFILTER_FLAG_SLICE_THREADS,
1635     .process_command = process_command,
1636 };
1637