1 /*
2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 *
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
16 *
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
42 #include "avfilter.h"
43 #include "audio.h"
44 #include "filters.h"
45 #include "formats.h"
46
/* One analysis frame is 120 << FRAME_SIZE_SHIFT = 480 samples, i.e. 10 ms at
 * the 48 kHz sample rate this filter accepts (see query_formats()). */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
/* Each transform covers the previous frame followed by the current one
 * (see frame_analysis()). */
#define WINDOW_SIZE (2*FRAME_SIZE)
/* Bins kept from a real-input WINDOW_SIZE-point FFT (DC up to Nyquist). */
#define FREQ_SIZE (FRAME_SIZE + 1)

/* Pitch search limits, in samples at the input rate. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
/* Input history needed to look back PITCH_MAX_PERIOD samples from a
 * PITCH_FRAME_SIZE analysis segment. */
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* Note: evaluates its argument twice. */
#define SQUARE(x) ((x)*(x))

/* Number of frequency bands (entries of eband5ms). */
#define NB_BANDS 22

/* Depth of the cepstral history ring buffer (DenoiseState.cepstral_mem). */
#define CEPS_MEM 8
/* Number of leading cepstral coefficients that get delta features. */
#define NB_DELTA_CEPS 6

/* Feature vector size: NB_BANDS cepstral values, three groups of
 * NB_DELTA_CEPS (first/second cepstral differences and the DCT of the
 * pitch correlation), plus pitch period and spectral variability
 * (see compute_frame_features()). */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)

/* NOTE(review): not used in this part of the file; presumably scales the
 * integer weights read from the model file — confirm. */
#define WEIGHTS_SCALE (1.f/256)

/* Upper bound on layer width; equal to the per-value limit of 128 that
 * INPUT_VAL enforces when a model is loaded. */
#define MAX_NEURONS 128

/* Internal activation identifiers stored in DenseLayer/GRULayer.activation. */
#define ACTIVATION_TANH 0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU 2

/* Unity constant (1.0f) used by the pitch analysis code. */
#define Q15ONE 1.0f
/* Fully connected layer as loaded by rnnoise_model_from_file(). */
typedef struct DenseLayer {
    const float *bias;          /* nb_neurons values */
    const float *input_weights; /* nb_inputs * nb_neurons values, in file order */
    int nb_inputs;              /* input width */
    int nb_neurons;             /* output width */
    int activation;             /* one of ACTIVATION_TANH / _SIGMOID / _RELU */
} DenseLayer;
83
/* Recurrent (GRU) layer as loaded by rnnoise_model_from_file().
 * Both weight matrices consist of 3 blocks (presumably one per GRU gate)
 * and are stored re-ordered and zero padded by INPUT_ARRAY3. */
typedef struct GRULayer {
    const float *bias;              /* nb_neurons * 3 values */
    const float *input_weights;     /* FFALIGN(nb_inputs, 4) * FFALIGN(nb_neurons, 4) * 3 floats */
    const float *recurrent_weights; /* FFALIGN(nb_neurons, 4) * FFALIGN(nb_neurons, 4) * 3 floats */
    int nb_inputs;                  /* input width */
    int nb_neurons;                 /* hidden state width */
    int activation;                 /* one of ACTIVATION_TANH / _SIGMOID / _RELU */
} GRULayer;
92
/* Complete denoising network. Each *_size field equals nb_neurons of the
 * matching layer (set while parsing in rnnoise_model_from_file()). */
typedef struct RNNModel {
    int input_dense_size;
    const DenseLayer *input_dense;

    int vad_gru_size;
    const GRULayer *vad_gru;

    int noise_gru_size;
    const GRULayer *noise_gru;

    int denoise_gru_size;
    const GRULayer *denoise_gru;

    int denoise_output_size;
    const DenseLayer *denoise_output;

    int vad_output_size;
    const DenseLayer *vad_output; /* must have exactly 1 neuron */
} RNNModel;
112
/* Per-channel GRU hidden states for one model. Each buffer is allocated
 * zeroed by config_input() with FFALIGN(<layer>_size, 16) floats. */
typedef struct RNNState {
    float *vad_gru_state;
    float *noise_gru_state;
    float *denoise_gru_state;
    RNNModel *model; /* model these states belong to */
} RNNState;
119
/* Per-channel processing state. */
typedef struct DenoiseState {
    float analysis_mem[FRAME_SIZE];         /* previous input frame: first half of the analysis window */
    float cepstral_mem[CEPS_MEM][NB_BANDS]; /* ring buffer of past cepstra */
    int memid;                              /* cepstral_mem slot the next non-silent frame writes */
    DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE]; /* second half of the last windowed synthesis output, overlap-added into the next frame */
    float pitch_buf[PITCH_BUF_SIZE];        /* most recent PITCH_BUF_SIZE input samples for pitch analysis */
    float pitch_enh_buf[PITCH_BUF_SIZE];    /* NOTE(review): not used in this part of the file */
    float last_gain;                        /* pitch gain of the previous frame (fed to remove_doubling()) */
    int last_period;                        /* pitch period of the previous frame (fed to remove_doubling()) */
    float mem_hp_x[2];                      /* NOTE(review): presumably biquad() state of an input filter; not used in this part of the file */
    float lastg[NB_BANDS];                  /* NOTE(review): not used in this part of the file */
    float history[FRAME_SIZE];              /* dry signal mixed into the output by frame_synthesis() */
    RNNState rnn[2];                        /* rnn[0] is bound to s->model[0] by config_input(); rnn[1] is not used in this part of the file */
    AVTXContext *tx, *txi;                  /* forward / inverse WINDOW_SIZE-point complex FFT */
    av_tx_fn tx_fn, txi_fn;
} DenoiseState;
136
/* Filter private context. */
typedef struct AudioRNNContext {
    const AVClass *class;

    char *model_name; /* NOTE(review): not used in this part of the file; presumably the model path option */
    float mix;        /* output = denoised * mix + dry * (1 - max(mix, 0)), see frame_synthesis() */

    int channels;     /* input channel count, set by config_input() */
    DenoiseState *st; /* per-channel state, `channels` entries */

    DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE]; /* analysis/synthesis window */
    DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)]; /* DCT basis rows used by dct(); padded to a multiple of 4 */

    RNNModel *model[2]; /* model[0] is the model config_input() binds to rnn[0] */

    AVFloatDSPContext *fdsp;
} AudioRNNContext;
153
/* Activation codes as stored in model files. INPUT_ACTIVATION in
 * rnnoise_model_from_file() maps them to ACTIVATION_*; any unknown code
 * falls back to ACTIVATION_TANH. */
#define F_ACTIVATION_TANH 0
#define F_ACTIVATION_SIGMOID 1
#define F_ACTIVATION_RELU 2
157
/**
 * Free an RNNModel with all of its layers and their weight/bias arrays.
 *
 * Accepts NULL. Also accepts a partially built model: a layer pointer that
 * is still NULL is skipped, and NULL arrays inside a layer are simply passed
 * to av_free(), which ignores them.
 */
static void rnnoise_model_free(RNNModel *model)
{
/* Not used by this function. */
#define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
/* Release DenseLayer member `name` of `model` and its arrays. */
#define FREE_DENSE(name) do {                           \
        if (model->name) {                              \
            av_free((void *) model->name->input_weights); \
            av_free((void *) model->name->bias);        \
            av_free((void *) model->name);              \
        }                                               \
    } while (0)
/* Release GRULayer member `name` of `model` and its arrays. */
#define FREE_GRU(name) do {                             \
        if (model->name) {                              \
            av_free((void *) model->name->input_weights); \
            av_free((void *) model->name->recurrent_weights); \
            av_free((void *) model->name->bias);        \
            av_free((void *) model->name);              \
        }                                               \
    } while (0)

    if (model) {
        /* Dense layers first, then the recurrent ones, then the container. */
        FREE_DENSE(input_dense);
        FREE_DENSE(denoise_output);
        FREE_DENSE(vad_output);

        FREE_GRU(vad_gru);
        FREE_GRU(noise_gru);
        FREE_GRU(denoise_gru);

        av_free(model);
    }
}
187
/**
 * Parse an rnnoise-nu text model from @p f.
 *
 * Format: the header "rnnoise-nu model file version 1", then six layers in
 * the order input_dense, vad_gru, noise_gru, denoise_gru, denoise_output,
 * vad_output. Each layer starts with the integers
 * "nb_inputs nb_neurons activation", followed by its weight and bias arrays
 * as whitespace-separated integers. After each header and each array,
 * NEW_LINE() discards the rest of the current line.
 *
 * The sizes are read from untrusted input. Besides the per-value range check,
 * the layer topology is checked against the fixed-size data the network is
 * wired to (NB_FEATURES inputs, NB_BANDS gains, one VAD value). This prevents
 * a malformed file from causing out-of-bounds accesses at run time.
 *
 * @param f   open model file, read from its current position
 * @param rnn on success receives the new model; release it with
 *            rnnoise_model_free()
 * @return 0 on success, AVERROR_INVALIDDATA for a bad header,
 *         AVERROR(EINVAL) for malformed or inconsistent layer data,
 *         AVERROR(ENOMEM) on allocation failure
 */
static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
{
    RNNModel *ret = NULL;
    DenseLayer *input_dense;
    GRULayer *vad_gru;
    GRULayer *noise_gru;
    GRULayer *denoise_gru;
    DenseLayer *denoise_output;
    DenseLayer *vad_output;
    int in;

    if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
        return AVERROR_INVALIDDATA;

    ret = av_calloc(1, sizeof(RNNModel));
    if (!ret)
        return AVERROR(ENOMEM);

/* Allocate a zeroed layer and attach it to the model right away, so that
 * rnnoise_model_free(ret) releases it on any later error. */
#define ALLOC_LAYER(type, name) \
    name = av_calloc(1, sizeof(type)); \
    if (!name) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    ret->name = name

    ALLOC_LAYER(DenseLayer, input_dense);
    ALLOC_LAYER(GRULayer, vad_gru);
    ALLOC_LAYER(GRULayer, noise_gru);
    ALLOC_LAYER(GRULayer, denoise_gru);
    ALLOC_LAYER(DenseLayer, denoise_output);
    ALLOC_LAYER(DenseLayer, vad_output);

/* Read one size/activation integer, limited to [0, MAX_NEURONS]. */
#define INPUT_VAL(name) do { \
    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > MAX_NEURONS) { \
        rnnoise_model_free(ret); \
        return AVERROR(EINVAL); \
    } \
    name = in; \
    } while (0)

/* Read an activation code and map it to ACTIVATION_*; unknown codes mean tanh. */
#define INPUT_ACTIVATION(name) do { \
    int activation; \
    INPUT_VAL(activation); \
    switch (activation) { \
    case F_ACTIVATION_SIGMOID: \
        name = ACTIVATION_SIGMOID; \
        break; \
    case F_ACTIVATION_RELU: \
        name = ACTIVATION_RELU; \
        break; \
    default: \
        name = ACTIVATION_TANH; \
    } \
    } while (0)

/* Read `len` integers into a new float array (attached before filling so
 * that it is freed with the model on a parse error). */
#define INPUT_ARRAY(name, len) do { \
    float *values = av_calloc((len), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int i = 0; i < (len); i++) { \
        if (fscanf(f, "%d", &in) != 1) { \
            rnnoise_model_free(ret); \
            return AVERROR(EINVAL); \
        } \
        values[i] = in; \
    } \
    } while (0)

/* Read a len0 x len2 x len1 matrix stored in file order k, i, j into a
 * zero-padded buffer laid out as [j][i][k] with k padded to a multiple of 4. */
#define INPUT_ARRAY3(name, len0, len1, len2) do { \
    float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int k = 0; k < (len0); k++) { \
        for (int i = 0; i < (len2); i++) { \
            for (int j = 0; j < (len1); j++) { \
                if (fscanf(f, "%d", &in) != 1) { \
                    rnnoise_model_free(ret); \
                    return AVERROR(EINVAL); \
                } \
                values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
            } \
        } \
    } \
    } while (0)

/* Skip the remainder of the current line. */
#define NEW_LINE() do { \
    int c; \
    while ((c = fgetc(f)) != EOF) { \
        if (c == '\n') \
        break; \
    } \
    } while (0)

#define INPUT_DENSE(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons); \
    NEW_LINE(); \
    } while (0)

#define INPUT_GRU(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
    NEW_LINE(); \
    } while (0)

    INPUT_DENSE(input_dense);
    INPUT_GRU(vad_gru);
    INPUT_GRU(noise_gru);
    INPUT_GRU(denoise_gru);
    INPUT_DENSE(denoise_output);
    INPUT_DENSE(vad_output);

    /* INPUT_VAL bounds each size on its own; also require that the layers
     * fit together and fit the fixed-size data around them:
     *  - input_dense consumes the NB_FEATURES values of compute_frame_features(),
     *  - vad_gru is fed by input_dense,
     *  - vad_output / denoise_output read the full state of vad_gru /
     *    denoise_gru respectively,
     *  - vad_output yields a single value and denoise_output is expected to
     *    yield one gain per band (NB_BANDS). */
    if (input_dense->nb_inputs     != NB_FEATURES             ||
        vad_gru->nb_inputs         != input_dense->nb_neurons ||
        vad_output->nb_inputs      != vad_gru->nb_neurons     ||
        vad_output->nb_neurons     != 1                       ||
        denoise_output->nb_inputs  != denoise_gru->nb_neurons ||
        denoise_output->nb_neurons != NB_BANDS) {
        rnnoise_model_free(ret);
        return AVERROR(EINVAL);
    }

    *rnn = ret;

    return 0;
}
330
/**
 * Negotiate formats: planar float samples, any channel count, and a fixed
 * 48 kHz sample rate.
 *
 * @return 0 on success, a negative AVERROR code on failure
 */
static int query_formats(AVFilterContext *ctx)
{
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP,
        AV_SAMPLE_FMT_NONE
    };
    int sample_rates[] = { 48000, -1 };
    int ret;

    if ((ret = ff_set_common_formats_from_list(ctx, sample_fmts)) < 0)
        return ret;
    if ((ret = ff_set_common_all_channel_counts(ctx)) < 0)
        return ret;

    return ff_set_common_samplerates_from_list(ctx, sample_rates);
}
349
/**
 * Set up the per-channel state for the input link.
 *
 * Allocates the DenoiseState array (first call only). Binds rnn[0] of every
 * channel to s->model[0] with fresh zeroed GRU state buffers. Creates the
 * forward/inverse WINDOW_SIZE-point FFT contexts if they do not exist yet.
 *
 * @return 0 on success, a negative AVERROR code on failure
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;
    int ret = 0;

    s->channels = inlink->ch_layout.nb_channels;

    /* NOTE(review): s->st is only allocated on the first call; this assumes
     * the channel count never grows between calls — confirm. */
    if (!s->st)
        s->st = av_calloc(s->channels, sizeof(DenoiseState));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int i = 0; i < s->channels; i++) {
        RNNState *rnn = &s->st[i].rnn[0];
        const RNNModel *model = s->model[0];

        rnn->model = s->model[0];

        /* This function may run more than once (hence the guards on s->st
         * and on the FFT contexts). Release buffers left in this slot by a
         * previous run instead of leaking them; the replacements start
         * zeroed, exactly as on the first run. */
        av_freep(&rnn->vad_gru_state);
        av_freep(&rnn->noise_gru_state);
        av_freep(&rnn->denoise_gru_state);

        /* Each state is rounded up to a multiple of 16 floats, zero padded. */
        rnn->vad_gru_state     = av_calloc(FFALIGN(model->vad_gru_size,     16), sizeof(float));
        rnn->noise_gru_state   = av_calloc(FFALIGN(model->noise_gru_size,   16), sizeof(float));
        rnn->denoise_gru_state = av_calloc(FFALIGN(model->denoise_gru_size, 16), sizeof(float));
        if (!rnn->vad_gru_state ||
            !rnn->noise_gru_state ||
            !rnn->denoise_gru_state)
            return AVERROR(ENOMEM);
    }

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        if (!st->tx) {
            ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }

        if (!st->txi) {
            ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }
    }

    return ret;
}
392
/**
 * Second-order IIR section in transposed direct form II whose numerator's
 * leading coefficient is fixed at 1:
 *
 *   H(z) = (1 + b[0] z^-1 + b[1] z^-2) / (1 + a[0] z^-1 + a[1] z^-2)
 *
 * @param y   output, N samples; may be the same buffer as x
 * @param mem filter state, updated in place
 * @param x   input, N samples
 * @param b   numerator coefficients 1 and 2
 * @param a   denominator coefficients 1 and 2
 * @param N   number of samples
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    for (int n = 0; n < N; n++) {
        /* Read the input before writing y, so in-place use works. */
        const float in  = x[n];
        const float out = in + mem[0];

        mem[0] = mem[1] + (b[0] * in - a[0] * out);
        mem[1] = (b[1] * in - a[1] * out);
        y[n]   = out;
    }
}
406
/* memmove / memset / memcpy of n elements of *dst. The "0*((dst)-(src))"
 * term adds nothing to the byte count; it exists so that dst and src must
 * be pointers to compatible types, otherwise the subtraction does not compile
 * cleanly. */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
#define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
410
/**
 * Forward FFT of one real WINDOW_SIZE-sample window.
 *
 * The samples are promoted to complex values with a zero imaginary part and
 * transformed with st->tx. Only the first FREQ_SIZE bins (DC to Nyquist) are
 * copied to out; the rest mirror them for real input.
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    AVComplexFloat time_buf[WINDOW_SIZE];
    AVComplexFloat freq_buf[WINDOW_SIZE];

    for (int n = 0; n < WINDOW_SIZE; n++)
        time_buf[n] = (AVComplexFloat){ .re = in[n], .im = 0.f };

    st->tx_fn(st->tx, freq_buf, time_buf, sizeof(float));

    RNN_COPY(out, freq_buf, FREQ_SIZE);
}
425
/**
 * Counterpart of forward_transform().
 *
 * Rebuilds the full spectrum from the FREQ_SIZE stored bins, filling the
 * upper half with the complex conjugates of the mirrored bins. It then runs
 * the inverse FFT (st->txi) and writes the real part divided by WINDOW_SIZE.
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    AVComplexFloat spec[WINDOW_SIZE];
    AVComplexFloat time_buf[WINDOW_SIZE];

    RNN_COPY(spec, in, FREQ_SIZE);

    /* Mirror sources (indices 1 .. WINDOW_SIZE - FREQ_SIZE) all lie in the
     * copied range, so the fill order does not matter. */
    for (int k = WINDOW_SIZE - 1; k >= FREQ_SIZE; k--) {
        const AVComplexFloat mirror = spec[WINDOW_SIZE - k];

        spec[k].re =  mirror.re;
        spec[k].im = -mirror.im;
    }

    st->txi_fn(st->txi, time_buf, spec, sizeof(float));

    for (int n = 0; n < WINDOW_SIZE; n++)
        out[n] = time_buf[n].re / WINDOW_SIZE;
}
443
/* Band edges in units of 4 FFT bins (shift left by FRAME_SIZE_SHIFT to get
 * the bin index). With a WINDOW_SIZE = 960 point FFT at 48 kHz one bin is
 * 50 Hz, so one unit is 200 Hz; the comment row gives each edge in Hz. */
static const uint8_t eband5ms[] = {
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
  0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
448
compute_band_energy(float * bandE,const AVComplexFloat * X)449 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
450 {
451 float sum[NB_BANDS] = {0};
452
453 for (int i = 0; i < NB_BANDS - 1; i++) {
454 int band_size;
455
456 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
457 for (int j = 0; j < band_size; j++) {
458 float tmp, frac = (float)j / band_size;
459
460 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
461 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
462 sum[i] += (1.f - frac) * tmp;
463 sum[i + 1] += frac * tmp;
464 }
465 }
466
467 sum[0] *= 2;
468 sum[NB_BANDS - 1] *= 2;
469
470 for (int i = 0; i < NB_BANDS; i++)
471 bandE[i] = sum[i];
472 }
473
compute_band_corr(float * bandE,const AVComplexFloat * X,const AVComplexFloat * P)474 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
475 {
476 float sum[NB_BANDS] = { 0 };
477
478 for (int i = 0; i < NB_BANDS - 1; i++) {
479 int band_size;
480
481 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
482 for (int j = 0; j < band_size; j++) {
483 float tmp, frac = (float)j / band_size;
484
485 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
486 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
487 sum[i] += (1 - frac) * tmp;
488 sum[i + 1] += frac * tmp;
489 }
490 }
491
492 sum[0] *= 2;
493 sum[NB_BANDS-1] *= 2;
494
495 for (int i = 0; i < NB_BANDS; i++)
496 bandE[i] = sum[i];
497 }
498
/**
 * Window and transform the newest frame.
 *
 * The analysis window is [previous frame | in]. It is multiplied by
 * s->window, transformed into X, and its band energies go to Ex. `in` is
 * then stored in st->analysis_mem to serve as the previous frame next time.
 *
 * @param X  output spectrum (FREQ_SIZE bins)
 * @param Ex output band energies (NB_BANDS)
 * @param in FRAME_SIZE new input samples
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    float *const older = buf;
    float *const newer = buf + FRAME_SIZE;

    RNN_COPY(newer, in, FRAME_SIZE);
    RNN_COPY(older, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);

    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    forward_transform(st, X, buf);
    compute_band_energy(Ex, X);
}
510
/**
 * Turn a processed spectrum back into FRAME_SIZE output samples.
 *
 * Steps:
 *  - inverse transform y and apply s->window;
 *  - overlap-add the tail saved from the previous call;
 *  - keep the new tail for the next call;
 *  - mix with the dry signal in st->history:
 *      out = out * mix + history * (1 - max(mix, 0))
 *
 * For a negative mix the dry weight is therefore 1.
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    const float wet = s->mix;
    const float dry = 1.f - FFMAX(wet, 0.f);
    const float *orig = st->history;

    inverse_transform(st, buf, y);
    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);

    /* Overlap-add with the second half of the previous window. */
    s->fdsp->vector_fmac_scalar(buf, st->synthesis_mem, 1.f, FRAME_SIZE);
    RNN_COPY(out, buf, FRAME_SIZE);
    RNN_COPY(st->synthesis_mem, buf + FRAME_SIZE, FRAME_SIZE);

    for (int i = 0; i < FRAME_SIZE; i++)
        out[i] = out[i] * wet + orig[i] * dry;
}
527
/**
 * Accumulate four lagged dot products:
 *   sum[k] += sum_{j=0}^{len-1} x[j] * y[j + k],  k = 0..3
 *
 * Reads x[0 .. len-1] and y[0 .. len+2]. The first three y values are always
 * read, even for len < 3. sum is accumulated into, so the caller must
 * initialize it. The main loop consumes four x samples per iteration while
 * y_0..y_3 act as a rotating window over y. The three trailing ifs handle
 * the remaining 1-3 samples.
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    float y_0, y_1, y_2, y_3 = 0;
    int j;

    y_0 = *y++;
    y_1 = *y++;
    y_2 = *y++;

    for (j = 0; j < len - 3; j += 4) {
        float tmp;

        tmp = *x++;
        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
        tmp = *x++;
        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
        tmp = *x++;
        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
        tmp = *x++;
        y_2 = *y++;
        sum[0] += tmp * y_3;
        sum[1] += tmp * y_0;
        sum[2] += tmp * y_1;
        sum[3] += tmp * y_2;
    }

    /* Remainder: up to three more x samples, continuing the rotation. */
    if (j++ < len) {
        float tmp = *x++;

        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
    }

    if (j++ < len) {
        float tmp=*x++;

        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
    }

    if (j < len) {
        float tmp=*x++;

        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
    }
}
596
/**
 * Dot product of x[0..N-1] and y[0..N-1], accumulated in index order.
 * Returns 0 when N <= 0.
 */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;

    for (int remaining = N; remaining > 0; remaining--)
        acc += *x++ * *y++;

    return acc;
}
607
/**
 * Cross-correlation of x against y for lags 0 .. max_pitch-1:
 *   xcorr[lag] = sum_{j<len} x[j] * y[j + lag]
 *
 * Lags are computed four at a time with xcorr_kernel(); lags left over when
 * max_pitch is not a multiple of 4 use celt_inner_prod().
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    while (lag + 3 < max_pitch) {
        float quad[4] = { 0 };

        xcorr_kernel(x, y + lag, quad, len);
        for (int k = 0; k < 4; k++)
            xcorr[lag + k] = quad[k];
        lag += 4;
    }

    /* Tail lags, one at a time. */
    for (; lag < max_pitch; lag++)
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
}
628
celt_autocorr(const float * x,float * ac,const float * window,int overlap,int lag,int n)629 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
630 float *ac, /* out: [0...lag-1] ac values */
631 const float *window,
632 int overlap,
633 int lag,
634 int n)
635 {
636 int fastN = n - lag;
637 int shift;
638 const float *xptr;
639 float xx[PITCH_BUF_SIZE>>1];
640
641 if (overlap == 0) {
642 xptr = x;
643 } else {
644 for (int i = 0; i < n; i++)
645 xx[i] = x[i];
646 for (int i = 0; i < overlap; i++) {
647 xx[i] = x[i] * window[i];
648 xx[n-i-1] = x[n-i-1] * window[i];
649 }
650 xptr = xx;
651 }
652
653 shift = 0;
654 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
655
656 for (int k = 0; k <= lag; k++) {
657 float d = 0.f;
658
659 for (int i = k + fastN; i < n; i++)
660 d += xptr[i] * xptr[i-k];
661 ac[k] += d;
662 }
663
664 return shift;
665 }
666
/**
 * LPC coefficients from autocorrelation values via the Levinson-Durbin
 * recursion.
 *
 * lpc is cleared first. If ac[0] is 0 the result stays all zero. The
 * recursion stops early once the prediction error falls below 0.001 * ac[0]
 * (30 dB); higher-order coefficients then remain 0.
 */
static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients      */
                     const float *ac,  /* in:  [0...p] autocorrelation values  */
                     int p)
{
    float r, error = ac[0];

    RNN_CLEAR(lpc, p);
    if (ac[0] != 0) {
        for (int i = 0; i < p; i++) {
            /* Sum up this iteration's reflection coefficient */
            float rr = 0;
            for (int j = 0; j < i; j++)
                rr += (lpc[j] * ac[i - j]);
            rr += ac[i + 1];
            r = -rr/error;
            /*  Update LPC coefficients and total error */
            lpc[i] = r;
            /* Symmetric in-place update of the lower-order coefficients. */
            for (int j = 0; j < (i + 1) >> 1; j++) {
                float tmp1, tmp2;
                tmp1 = lpc[j];
                tmp2 = lpc[i-1-j];
                lpc[j]     = tmp1 + (r*tmp2);
                lpc[i-1-j] = tmp2 + (r*tmp1);
            }

            error = error - (r * r *error);
            /* Bail out once we get 30 dB gain */
            if (error < .001f * ac[0])
                break;
        }
    }
}
699
/**
 * 5-tap FIR filter with the current input passed straight through:
 *
 *   y[i] = x[i] + num[0]*x[i-1] + num[1]*x[i-2] + ... + num[4]*x[i-5]
 *
 * Samples before x[0] come from mem (mem[0] is the most recent). On return,
 * mem holds the last five inputs. y may be the same buffer as x.
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float taps[5], hist[5];

    for (int k = 0; k < 5; k++) {
        taps[k] = num[k];
        hist[k] = mem[k];
    }

    for (int n = 0; n < N; n++) {
        const float in = x[n];   /* read before y[n] is written */
        float acc = in;

        for (int k = 0; k < 5; k++)
            acc += taps[k] * hist[k];

        for (int k = 4; k > 0; k--)
            hist[k] = hist[k - 1];
        hist[0] = in;

        y[n] = acc;
    }

    for (int k = 0; k < 5; k++)
        mem[k] = hist[k];
}
742
/**
 * Decimate by 2 for pitch analysis and apply a short LPC-based filter.
 *
 * x_lp[i] = .5 * (.5 * (x[2i-1] + x[2i+1]) + x[2i]) for each channel
 * (x[-1] treated as 0). When C == 2 the second channel's result is added.
 * A 4th-order LPC is then fitted to the decimated signal using:
 *  - a 1.0001 noise floor on ac[0];
 *  - lag windowing;
 *  - coefficients scaled by 0.9^(k+1).
 * It is combined with a (1 + 0.8 z^-1) zero and applied to x_lp in place
 * through celt_fir5() (presumably a whitening step).
 *
 * @param x    C input channels, len samples each
 * @param x_lp output, len >> 1 samples
 * @param len  input length
 * @param C    number of channels (1 or 2)
 */
static void pitch_downsample(float *x[], float *x_lp,
                             int len, int C)
{
    float ac[5];
    float tmp=Q15ONE;
    float lpc[4], mem[5]={0,0,0,0,0};
    float lpc2[5];
    float c1 = .8f;

    for (int i = 1; i < len >> 1; i++)
        x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
    x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
    if (C==2) {
        for (int i = 1; i < len >> 1; i++)
            x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
        x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
    }

    /* ac[0..4] */
    celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);

    /* Noise floor -40 dB */
    ac[0] *= 1.0001f;
    /* Lag windowing */
    for (int i = 1; i <= 4; i++) {
        /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
        ac[i] -= ac[i]*(.008f*i)*(.008f*i);
    }

    celt_lpc(lpc, ac, 4);
    /* Bandwidth expansion: lpc[i] *= 0.9^(i+1). */
    for (int i = 0; i < 4; i++) {
        tmp = .9f * tmp;
        lpc[i] = (lpc[i] * tmp);
    }
    /* Add a zero */
    lpc2[0] = lpc[0] + .8f;
    lpc2[1] = lpc[1] + (c1 * lpc[0]);
    lpc2[2] = lpc[2] + (c1 * lpc[1]);
    lpc2[3] = lpc[3] + (c1 * lpc[2]);
    lpc2[4] = (c1 * lpc[3]);
    celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
}
784
/**
 * Two dot products sharing the same left operand, in one pass:
 *   *xy1 = x . y01,  *xy2 = x . y02   (N elements each, index order)
 */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;

    for (int remaining = N; remaining > 0; remaining--) {
        const float xv = *x++;

        acc1 += (xv * *y01++);
        acc2 += (xv * *y02++);
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
798
compute_pitch_gain(float xy,float xx,float yy)799 static float compute_pitch_gain(float xy, float xx, float yy)
800 {
801 return xy / sqrtf(1.f + xx * yy);
802 }
803
/* For k >= 3, the multiple of T0/k also checked for correlation in
 * remove_doubling() (T1b = second_check[k] * T0 / k, rounded). */
static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/**
 * Refine a pitch estimate by checking whether a submultiple T0/k
 * (k = 2..15) explains the signal well enough. This guards against picking
 * a multiple of the true period.
 *
 * x is expected to be 2x decimated (as produced by pitch_downsample());
 * maxperiod, minperiod, N, *T0_ and prev_period are given at the
 * undecimated rate and halved internally. x is read from x[0] up to
 * x[maxperiod/2 + N/2 - 1].
 *
 * @param T0_        in: coarse period; out: refined period (2*T + offset),
 *                   at least minperiod
 * @param prev_period period of the previous frame, favored via prev_gain
 * @param prev_gain   pitch gain of the previous frame
 * @return pitch gain for the chosen period, at most the normalized
 *         correlation g of that period
 */
static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
                             int *T0_, int prev_period, float prev_gain)
{
    int k, i, T, T0;
    float g, g0;
    float pg;
    float xy,xx,yy,xy2;
    float xcorr[3];
    float best_xy, best_yy;
    int offset;
    int minperiod0;
    float yy_lookup[PITCH_MAX_PERIOD+1];

    minperiod0 = minperiod;
    maxperiod /= 2;
    minperiod /= 2;
    *T0_ /= 2;
    prev_period /= 2;
    N /= 2;
    x += maxperiod;
    if (*T0_>=maxperiod)
        *T0_=maxperiod-1;

    T = T0 = *T0_;
    dual_inner_prod(x, x, x-T0, N, &xx, &xy);
    /* yy_lookup[i]: energy of x[-i .. N-i-1], updated incrementally, clamped at 0. */
    yy_lookup[0] = xx;
    yy=xx;
    for (i = 1; i <= maxperiod; i++) {
        yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
        yy_lookup[i] = FFMAX(0, yy);
    }
    yy = yy_lookup[T0];
    best_xy = xy;
    best_yy = yy;
    g = g0 = compute_pitch_gain(xy, xx, yy);
    /* Look for any pitch at T/k */
    for (k = 2; k <= 15; k++) {
        int T1, T1b;
        float g1;
        float cont=0;
        float thresh;
        T1 = (2*T0+k)/(2*k);
        if (T1 < minperiod)
            break;
        /* Look for another strong correlation at T1b */
        if (k==2)
        {
            if (T1+T0>maxperiod)
                T1b = T0;
            else
                T1b = T0+T1;
        } else
        {
            T1b = (2*second_check[k]*T0+k)/(2*k);
        }
        dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
        xy = .5f * (xy + xy2);
        yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
        g1 = compute_pitch_gain(xy, xx, yy);
        /* Continuity bonus when T1 is close to the previous frame's period. */
        if (FFABS(T1-prev_period)<=1)
            cont = prev_gain;
        else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
            cont = prev_gain * .5f;
        else
            cont = 0;
        thresh = FFMAX(.3f, (.7f * g0) - cont);
        /* Bias against very high pitch (very short period) to avoid false-positives
           due to short-term correlation */
        /* NOTE(review): for minperiod > 0, T1 < 2*minperiod implies
         * T1 < 3*minperiod, so the "else if" branch below is unreachable.
         * Left unchanged because reordering would alter the output. */
        if (T1<3*minperiod)
            thresh = FFMAX(.4f, (.85f * g0) - cont);
        else if (T1<2*minperiod)
            thresh = FFMAX(.5f, (.9f * g0) - cont);
        if (g1 > thresh)
        {
            best_xy = xy;
            best_yy = yy;
            T = T1;
            g = g1;
        }
    }
    best_xy = FFMAX(0, best_xy);
    if (best_yy <= best_xy)
        pg = Q15ONE;
    else
        pg = best_xy/(best_yy + 1);

    /* Sub-sample refinement from the correlations at T-1, T and T+1. */
    for (k = 0; k < 3; k++)
        xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
    if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
        offset = 1;
    else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
        offset = -1;
    else
        offset = 0;
    if (pg > g)
        pg = g;
    *T0_ = 2*T+offset;

    if (*T0_<minperiod0)
        *T0_=minperiod0;
    return pg;
}
907
/**
 * Pick the two lags with the highest normalized correlation.
 *
 * For each lag i < max_pitch with xcorr[i] > 0, the score is
 * (xcorr[i] * 1e-12)^2 / Syy(i). Syy(i) = 1 + energy of y[i .. i+len-1]; it
 * is updated incrementally and kept >= 1. best_pitch[0] receives the best lag
 * and best_pitch[1] the runner-up; they default to 0 and 1 when no lag
 * scores.
 *
 * Reads xcorr[0 .. max_pitch-1] and y[0 .. len+max_pitch-1].
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float best_num[2];
    float best_den[2];
    float Syy = 1.f;

    best_num[0] = -1;
    best_num[1] = -1;
    best_den[0] = 0;
    best_den[1] = 0;
    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        Syy += y[j] * y[j];

    for (int i = 0; i < max_pitch; i++) {
        if (xcorr[i]>0) {
            float num;
            float xcorr16;

            xcorr16 = xcorr[i];
            /* Considering the range of xcorr16, this should avoid both underflows
               and overflows (inf) when squaring xcorr16 */
            xcorr16 *= 1e-12f;
            num = xcorr16 * xcorr16;
            /* Compare num/Syy against the stored ratios by cross-multiplying. */
            if ((num * best_den[1]) > (best_num[1] * Syy)) {
                if ((num * best_den[0]) > (best_num[0] * Syy)) {
                    best_num[1] = best_num[0];
                    best_den[1] = best_den[0];
                    best_pitch[1] = best_pitch[0];
                    best_num[0] = num;
                    best_den[0] = Syy;
                    best_pitch[0] = i;
                } else {
                    best_num[1] = num;
                    best_den[1] = Syy;
                    best_pitch[1] = i;
                }
            }
        }
        /* Slide the energy window by one sample. */
        Syy += y[i+len]*y[i+len] - y[i] * y[i];
        Syy = FFMAX(1, Syy);
    }
}
954
/**
 * Coarse-to-fine open-loop pitch search.
 *
 * The 2x decimated inputs are decimated again, and the two best lags are
 * found at that resolution. Correlations at 2x resolution are then evaluated
 * only within +-2 of those candidates. The best lag is finally nudged by
 * +-1 using its neighbouring correlations.
 *
 * @param x_lp      2x decimated target segment (len >> 1 samples used)
 * @param y         2x decimated history searched against x_lp
 * @param len       segment length in undecimated samples
 * @param max_pitch number of lags searched, in undecimated samples
 * @param pitch     out: best lag (offset into y) in undecimated samples
 *
 * NOTE(review): the local buffers hold WINDOW_SIZE floats, which bounds
 * len >> 2, (len + max_pitch) >> 2 and max_pitch >> 1.
 */
static void pitch_search(const float *x_lp, float *y,
                         int len, int max_pitch, int *pitch)
{
    int lag;
    int best_pitch[2]={0,0};
    int offset;

    float x_lp4[WINDOW_SIZE];
    float y_lp4[WINDOW_SIZE];
    float xcorr[WINDOW_SIZE];

    lag = len+max_pitch;

    /* Downsample by 2 again */
    for (int j = 0; j < len >> 2; j++)
        x_lp4[j] = x_lp[2*j];
    for (int j = 0; j < lag >> 2; j++)
        y_lp4[j] = y[2*j];

    /* Coarse search with 4x decimation */

    celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);

    find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);

    /* Finer search with 2x decimation */
    for (int i = 0; i < max_pitch >> 1; i++) {
        float sum;
        xcorr[i] = 0;
        if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
            continue;
        sum = celt_inner_prod(x_lp, y+i, len>>1);
        xcorr[i] = FFMAX(-1, sum);
    }

    find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);

    /* Refine by pseudo-interpolation */
    if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
        float a, b, c;

        a = xcorr[best_pitch[0] - 1];
        b = xcorr[best_pitch[0]];
        c = xcorr[best_pitch[0] + 1];
        if (c - a > .7f * (b - a))
            offset = 1;
        else if (a - c > .7f * (b-c))
            offset = -1;
        else
            offset = 0;
    } else {
        offset = 0;
    }

    *pitch = 2 * best_pitch[0] - offset;
}
1011
/**
 * DCT of NB_BANDS band values using the basis rows in s->dct_table, scaled
 * by sqrt(2 / NB_BANDS).
 *
 * scalarproduct_float() runs over FFALIGN(NB_BANDS, 4) elements, the padded
 * row length of dct_table, while callers only provide NB_BANDS values. The
 * input is therefore first copied into a zero-padded, 32-byte-aligned
 * scratch buffer. Reading past the caller's array would be an out-of-bounds
 * access, and any NaN/Inf found there would poison the sum even when
 * multiplied by a zero table entry.
 *
 * @param s   filter context (dct_table and float DSP functions)
 * @param out NB_BANDS output coefficients
 * @param in  NB_BANDS input values
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    LOCAL_ALIGNED_32(float, padded, [FFALIGN(NB_BANDS, 4)]);
    const float scale = sqrtf(2.f / NB_BANDS);

    memcpy(padded, in, NB_BANDS * sizeof(*padded));
    memset(padded + NB_BANDS, 0, (FFALIGN(NB_BANDS, 4) - NB_BANDS) * sizeof(*padded));

    for (int i = 0; i < NB_BANDS; i++) {
        const float sum = s->fdsp->scalarproduct_float(padded, s->dct_table[i],
                                                       FFALIGN(NB_BANDS, 4));

        out[i] = sum * scale;
    }
}
1021
/**
 * Compute the network input features for one frame, plus the spectra that
 * the pitch filter needs later.
 *
 * Steps:
 *  - analyse the frame into X / Ex;
 *  - append it to st->pitch_buf;
 *  - estimate the pitch period on a 2x decimated copy and store it with its
 *    gain in st->last_period / st->last_gain;
 *  - transform the pitch-delayed signal into P / Ep;
 *  - compute the per-band correlation Exp, normalized in place by
 *    sqrt(.001 + Ex * Ep).
 *
 * Feature layout (NB_FEATURES values):
 *   [0, NB_DELTA_CEPS)               sum of the current and two previous cepstra
 *   [NB_DELTA_CEPS, NB_BANDS)        current cepstrum (DCT of floored log10 band energies)
 *   [NB_BANDS, +NB_DELTA_CEPS)       ceps_0 - ceps_2
 *   [.., +NB_DELTA_CEPS)             ceps_0 - 2*ceps_1 + ceps_2
 *   [.., +NB_DELTA_CEPS)             DCT of Exp (first two offset by -1.3 / -0.9)
 *   NB_BANDS + 3*NB_DELTA_CEPS       .01 * (pitch_index - 300)
 *   NB_BANDS + 3*NB_DELTA_CEPS + 1   mean nearest-neighbour distance in cepstral_mem, minus 2.1
 *
 * @return 1 when the total band energy is below 0.04. In that case features
 *         is zeroed and the cepstral history is not updated; the pitch state
 *         above has already been updated. Returns 0 otherwise.
 */
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
{
    float E = 0;
    float *ceps_0, *ceps_1, *ceps_2;
    float spec_variability = 0;
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
    float pitch_buf[PITCH_BUF_SIZE>>1];
    int pitch_index;
    float gain;
    float *(pre[1]);
    float tmp[NB_BANDS];
    float follow, logMax;

    frame_analysis(s, st, X, Ex, in);
    /* Shift the pitch history by one frame and append the new input. */
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
    pre[0] = &st->pitch_buf[0];
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
    pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
            PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
    /* Convert the lag found by pitch_search() into a period. */
    pitch_index = PITCH_MAX_PERIOD-pitch_index;

    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
            PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
    st->last_period = pitch_index;
    st->last_gain = gain;

    /* Window of input delayed by one pitch period. */
    for (int i = 0; i < WINDOW_SIZE; i++)
        p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];

    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
    forward_transform(st, P, p);
    compute_band_energy(Ep, P);
    compute_band_corr(Exp, X, P);

    for (int i = 0; i < NB_BANDS; i++)
        Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);

    dct(s, tmp, Exp);

    for (int i = 0; i < NB_DELTA_CEPS; i++)
        features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];

    features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
    features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
    features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
    logMax = -2;
    follow = -2;

    /* Log band energies, floored 7 below the running maximum and
     * 1.5 below the previous band's (decaying) value. */
    for (int i = 0; i < NB_BANDS; i++) {
        Ly[i] = log10f(1e-2f + Ex[i]);
        Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
        logMax = FFMAX(logMax, Ly[i]);
        follow = FFMAX(follow-1.5, Ly[i]);
        E += Ex[i];
    }

    if (E < 0.04f) {
        /* If there's no audio, avoid messing up the state. */
        RNN_CLEAR(features, NB_FEATURES);
        return 1;
    }

    dct(s, features, Ly);
    features[0] -= 12;
    features[1] -= 4;
    /* ceps_0: slot for this frame; ceps_1 / ceps_2: one and two frames back (ring buffer). */
    ceps_0 = st->cepstral_mem[st->memid];
    ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
    ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];

    for (int i = 0; i < NB_BANDS; i++)
        ceps_0[i] = features[i];

    st->memid++;
    for (int i = 0; i < NB_DELTA_CEPS; i++) {
        features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
        features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
        features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
    }
    /* Spectral variability features. */
    if (st->memid == CEPS_MEM)
        st->memid = 0;

    /* For each stored cepstrum, squared distance to its nearest other entry. */
    for (int i = 0; i < CEPS_MEM; i++) {
        float mindist = 1e15f;
        for (int j = 0; j < CEPS_MEM; j++) {
            float dist = 0.f;
            for (int k = 0; k < NB_BANDS; k++) {
                float tmp;

                tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
                dist += tmp*tmp;
            }

            if (j != i)
                mindist = FFMIN(mindist, dist);
        }

        spec_variability += mindist;
    }

    features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;

    return 0;
}
1129
interp_band_gain(float * g,const float * bandE)1130 static void interp_band_gain(float *g, const float *bandE)
1131 {
1132 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1133
1134 for (int i = 0; i < NB_BANDS - 1; i++) {
1135 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1136
1137 for (int j = 0; j < band_size; j++) {
1138 float frac = (float)j / band_size;
1139
1140 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1141 }
1142 }
1143 }
1144
pitch_filter(AVComplexFloat * X,const AVComplexFloat * P,const float * Ex,const float * Ep,const float * Exp,const float * g)1145 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1146 const float *Exp, const float *g)
1147 {
1148 float newE[NB_BANDS];
1149 float r[NB_BANDS];
1150 float norm[NB_BANDS];
1151 float rf[FREQ_SIZE] = {0};
1152 float normf[FREQ_SIZE]={0};
1153
1154 for (int i = 0; i < NB_BANDS; i++) {
1155 if (Exp[i]>g[i]) r[i] = 1;
1156 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1157 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1158 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1159 }
1160 interp_band_gain(rf, r);
1161 for (int i = 0; i < FREQ_SIZE; i++) {
1162 X[i].re += rf[i]*P[i].re;
1163 X[i].im += rf[i]*P[i].im;
1164 }
1165 compute_band_energy(newE, X);
1166 for (int i = 0; i < NB_BANDS; i++) {
1167 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1168 }
1169 interp_band_gain(normf, norm);
1170 for (int i = 0; i < FREQ_SIZE; i++) {
1171 X[i].re *= normf[i];
1172 X[i].im *= normf[i];
1173 }
1174 }
1175
/*
 * Lookup table for tansig_approx(): entry i holds tanh(0.04 * i), covering
 * x = 0.0 .. 8.0 in 201 steps (e.g. entry 25 = tanh(1.0) = 0.761594).
 * tansig_approx() picks the entry with i = floor(0.5 + 25 * |x|) and refines
 * it with a first-order correction.
 */
static const float tansig_table[201] = {
    0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    1.000000f,
};
1219
tansig_approx(float x)1220 static inline float tansig_approx(float x)
1221 {
1222 float y, dy;
1223 float sign=1;
1224 int i;
1225
1226 /* Tests are reversed to catch NaNs */
1227 if (!(x<8))
1228 return 1;
1229 if (!(x>-8))
1230 return -1;
1231 /* Another check in case of -ffast-math */
1232
1233 if (isnan(x))
1234 return 0;
1235
1236 if (x < 0) {
1237 x=-x;
1238 sign=-1;
1239 }
1240 i = (int)floor(.5f+25*x);
1241 x -= .04f*i;
1242 y = tansig_table[i];
1243 dy = 1-y*y;
1244 y = y + x*dy*(1 - y*x);
1245 return sign*y;
1246 }
1247
sigmoid_approx(float x)1248 static inline float sigmoid_approx(float x)
1249 {
1250 return .5f + .5f*tansig_approx(.5f*x);
1251 }
1252
compute_dense(const DenseLayer * layer,float * output,const float * input)1253 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1254 {
1255 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1256
1257 for (int i = 0; i < N; i++) {
1258 /* Compute update gate. */
1259 float sum = layer->bias[i];
1260
1261 for (int j = 0; j < M; j++)
1262 sum += layer->input_weights[j * stride + i] * input[j];
1263
1264 output[i] = WEIGHTS_SCALE * sum;
1265 }
1266
1267 if (layer->activation == ACTIVATION_SIGMOID) {
1268 for (int i = 0; i < N; i++)
1269 output[i] = sigmoid_approx(output[i]);
1270 } else if (layer->activation == ACTIVATION_TANH) {
1271 for (int i = 0; i < N; i++)
1272 output[i] = tansig_approx(output[i]);
1273 } else if (layer->activation == ACTIVATION_RELU) {
1274 for (int i = 0; i < N; i++)
1275 output[i] = FFMAX(0, output[i]);
1276 } else {
1277 av_assert0(0);
1278 }
1279 }
1280
compute_gru(AudioRNNContext * s,const GRULayer * gru,float * state,const float * input)1281 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1282 {
1283 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1284 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1285 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1286 const int M = gru->nb_inputs;
1287 const int N = gru->nb_neurons;
1288 const int AN = FFALIGN(N, 4);
1289 const int AM = FFALIGN(M, 4);
1290 const int stride = 3 * AN, istride = 3 * AM;
1291
1292 for (int i = 0; i < N; i++) {
1293 /* Compute update gate. */
1294 float sum = gru->bias[i];
1295
1296 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1297 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1298 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1299 }
1300
1301 for (int i = 0; i < N; i++) {
1302 /* Compute reset gate. */
1303 float sum = gru->bias[N + i];
1304
1305 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1306 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1307 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1308 }
1309
1310 for (int i = 0; i < N; i++) {
1311 /* Compute output. */
1312 float sum = gru->bias[2 * N + i];
1313
1314 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1315 for (int j = 0; j < N; j++)
1316 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1317
1318 if (gru->activation == ACTIVATION_SIGMOID)
1319 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1320 else if (gru->activation == ACTIVATION_TANH)
1321 sum = tansig_approx(WEIGHTS_SCALE * sum);
1322 else if (gru->activation == ACTIVATION_RELU)
1323 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1324 else
1325 av_assert0(0);
1326 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1327 }
1328
1329 RNN_COPY(state, h, N);
1330 }
1331
1332 #define INPUT_SIZE 42
1333
compute_rnn(AudioRNNContext * s,RNNState * rnn,float * gains,float * vad,const float * input)1334 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1335 {
1336 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1337 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1338 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1339
1340 compute_dense(rnn->model->input_dense, dense_out, input);
1341 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1342 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1343
1344 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1345 memcpy(noise_input + rnn->model->input_dense_size,
1346 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1347 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1348 input, INPUT_SIZE * sizeof(float));
1349
1350 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1351
1352 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1353 memcpy(denoise_input + rnn->model->vad_gru_size,
1354 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1355 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1356 input, INPUT_SIZE * sizeof(float));
1357
1358 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1359 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1360 }
1361
rnnoise_channel(AudioRNNContext * s,DenoiseState * st,float * out,const float * in,int disabled)1362 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
1363 int disabled)
1364 {
1365 AVComplexFloat X[FREQ_SIZE];
1366 AVComplexFloat P[WINDOW_SIZE];
1367 float x[FRAME_SIZE];
1368 float Ex[NB_BANDS], Ep[NB_BANDS];
1369 LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
1370 float features[NB_FEATURES];
1371 float g[NB_BANDS];
1372 float gf[FREQ_SIZE];
1373 float vad_prob = 0;
1374 float *history = st->history;
1375 static const float a_hp[2] = {-1.99599, 0.99600};
1376 static const float b_hp[2] = {-2, 1};
1377 int silence;
1378
1379 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1380 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1381
1382 if (!silence && !disabled) {
1383 compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
1384 pitch_filter(X, P, Ex, Ep, Exp, g);
1385 for (int i = 0; i < NB_BANDS; i++) {
1386 float alpha = .6f;
1387
1388 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1389 st->lastg[i] = g[i];
1390 }
1391
1392 interp_band_gain(gf, g);
1393
1394 for (int i = 0; i < FREQ_SIZE; i++) {
1395 X[i].re *= gf[i];
1396 X[i].im *= gf[i];
1397 }
1398 }
1399
1400 frame_synthesis(s, st, out, X);
1401 memcpy(history, in, FRAME_SIZE * sizeof(*history));
1402
1403 return vad_prob;
1404 }
1405
1406 typedef struct ThreadData {
1407 AVFrame *in, *out;
1408 } ThreadData;
1409
rnnoise_channels(AVFilterContext * ctx,void * arg,int jobnr,int nb_jobs)1410 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1411 {
1412 AudioRNNContext *s = ctx->priv;
1413 ThreadData *td = arg;
1414 AVFrame *in = td->in;
1415 AVFrame *out = td->out;
1416 const int start = (out->ch_layout.nb_channels * jobnr) / nb_jobs;
1417 const int end = (out->ch_layout.nb_channels * (jobnr+1)) / nb_jobs;
1418
1419 for (int ch = start; ch < end; ch++) {
1420 rnnoise_channel(s, &s->st[ch],
1421 (float *)out->extended_data[ch],
1422 (const float *)in->extended_data[ch],
1423 ctx->is_disabled);
1424 }
1425
1426 return 0;
1427 }
1428
filter_frame(AVFilterLink * inlink,AVFrame * in)1429 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1430 {
1431 AVFilterContext *ctx = inlink->dst;
1432 AVFilterLink *outlink = ctx->outputs[0];
1433 AVFrame *out = NULL;
1434 ThreadData td;
1435
1436 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1437 if (!out) {
1438 av_frame_free(&in);
1439 return AVERROR(ENOMEM);
1440 }
1441 out->pts = in->pts;
1442
1443 td.in = in; td.out = out;
1444 ff_filter_execute(ctx, rnnoise_channels, &td, NULL,
1445 FFMIN(outlink->ch_layout.nb_channels, ff_filter_get_nb_threads(ctx)));
1446
1447 av_frame_free(&in);
1448 return ff_filter_frame(outlink, out);
1449 }
1450
activate(AVFilterContext * ctx)1451 static int activate(AVFilterContext *ctx)
1452 {
1453 AVFilterLink *inlink = ctx->inputs[0];
1454 AVFilterLink *outlink = ctx->outputs[0];
1455 AVFrame *in = NULL;
1456 int ret;
1457
1458 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1459
1460 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1461 if (ret < 0)
1462 return ret;
1463
1464 if (ret > 0)
1465 return filter_frame(inlink, in);
1466
1467 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1468 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1469
1470 return FFERROR_NOT_READY;
1471 }
1472
open_model(AVFilterContext * ctx,RNNModel ** model)1473 static int open_model(AVFilterContext *ctx, RNNModel **model)
1474 {
1475 AudioRNNContext *s = ctx->priv;
1476 int ret;
1477 FILE *f;
1478
1479 if (!s->model_name)
1480 return AVERROR(EINVAL);
1481 f = avpriv_fopen_utf8(s->model_name, "r");
1482 if (!f) {
1483 av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
1484 return AVERROR(EINVAL);
1485 }
1486
1487 ret = rnnoise_model_from_file(f, model);
1488 fclose(f);
1489 if (!*model || ret < 0)
1490 return ret;
1491
1492 return 0;
1493 }
1494
init(AVFilterContext * ctx)1495 static av_cold int init(AVFilterContext *ctx)
1496 {
1497 AudioRNNContext *s = ctx->priv;
1498 int ret;
1499
1500 s->fdsp = avpriv_float_dsp_alloc(0);
1501 if (!s->fdsp)
1502 return AVERROR(ENOMEM);
1503
1504 ret = open_model(ctx, &s->model[0]);
1505 if (ret < 0)
1506 return ret;
1507
1508 for (int i = 0; i < FRAME_SIZE; i++) {
1509 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1510 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1511 }
1512
1513 for (int i = 0; i < NB_BANDS; i++) {
1514 for (int j = 0; j < NB_BANDS; j++) {
1515 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1516 if (j == 0)
1517 s->dct_table[j][i] *= sqrtf(.5);
1518 }
1519 }
1520
1521 return 0;
1522 }
1523
free_model(AVFilterContext * ctx,int n)1524 static void free_model(AVFilterContext *ctx, int n)
1525 {
1526 AudioRNNContext *s = ctx->priv;
1527
1528 rnnoise_model_free(s->model[n]);
1529 s->model[n] = NULL;
1530
1531 for (int ch = 0; ch < s->channels && s->st; ch++) {
1532 av_freep(&s->st[ch].rnn[n].vad_gru_state);
1533 av_freep(&s->st[ch].rnn[n].noise_gru_state);
1534 av_freep(&s->st[ch].rnn[n].denoise_gru_state);
1535 }
1536 }
1537
process_command(AVFilterContext * ctx,const char * cmd,const char * args,char * res,int res_len,int flags)1538 static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
1539 char *res, int res_len, int flags)
1540 {
1541 AudioRNNContext *s = ctx->priv;
1542 int ret;
1543
1544 ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
1545 if (ret < 0)
1546 return ret;
1547
1548 ret = open_model(ctx, &s->model[1]);
1549 if (ret < 0)
1550 return ret;
1551
1552 FFSWAP(RNNModel *, s->model[0], s->model[1]);
1553 for (int ch = 0; ch < s->channels; ch++)
1554 FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1555
1556 ret = config_input(ctx->inputs[0]);
1557 if (ret < 0) {
1558 for (int ch = 0; ch < s->channels; ch++)
1559 FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
1560 FFSWAP(RNNModel *, s->model[0], s->model[1]);
1561 return ret;
1562 }
1563
1564 free_model(ctx, 1);
1565 return 0;
1566 }
1567
uninit(AVFilterContext * ctx)1568 static av_cold void uninit(AVFilterContext *ctx)
1569 {
1570 AudioRNNContext *s = ctx->priv;
1571
1572 av_freep(&s->fdsp);
1573 free_model(ctx, 0);
1574 for (int ch = 0; ch < s->channels && s->st; ch++) {
1575 av_tx_uninit(&s->st[ch].tx);
1576 av_tx_uninit(&s->st[ch].txi);
1577 }
1578 av_freep(&s->st);
1579 }
1580
1581 static const AVFilterPad inputs[] = {
1582 {
1583 .name = "default",
1584 .type = AVMEDIA_TYPE_AUDIO,
1585 .config_props = config_input,
1586 },
1587 };
1588
1589 static const AVFilterPad outputs[] = {
1590 {
1591 .name = "default",
1592 .type = AVMEDIA_TYPE_AUDIO,
1593 },
1594 };
1595
/* Byte offset of a field inside AudioRNNContext, for the option table. */
#define OFFSET(x) offsetof(AudioRNNContext, x)
/* Flags shared by every option: audio filtering parameter, also changeable at
 * runtime (runtime changes are handled by process_command()). */
#define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM

/* User-visible options. "m" is a short alias of "model"; both set model_name.
 * "mix" is a float limited to [-1, 1], default 1.0. */
static const AVOption arnndn_options[] = {
    { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
    { "mix", "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
    { NULL }
};

/* Defines arnndn_class (referenced by ff_af_arnndn) from arnndn_options. */
AVFILTER_DEFINE_CLASS(arnndn);
1607
1608 const AVFilter ff_af_arnndn = {
1609 .name = "arnndn",
1610 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1611 .priv_size = sizeof(AudioRNNContext),
1612 .priv_class = &arnndn_class,
1613 .activate = activate,
1614 .init = init,
1615 .uninit = uninit,
1616 FILTER_INPUTS(inputs),
1617 FILTER_OUTPUTS(outputs),
1618 FILTER_QUERY_FUNC(query_formats),
1619 .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1620 AVFILTER_FLAG_SLICE_THREADS,
1621 .process_command = process_command,
1622 };
1623