1 /*
2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 *
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
16 *
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/mem_internal.h"
40 #include "libavutil/opt.h"
41 #include "libavutil/tx.h"
42 #include "avfilter.h"
43 #include "audio.h"
44 #include "filters.h"
45 #include "formats.h"
46
/* One frame is 120 << FRAME_SIZE_SHIFT = 480 samples. The analysis window
 * spans two frames, and the forward transform keeps FRAME_SIZE + 1 bins. */
#define FRAME_SIZE_SHIFT 2
#define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
#define WINDOW_SIZE (2*FRAME_SIZE)
#define FREQ_SIZE (FRAME_SIZE + 1)

/* Pitch search limits (in samples) and the pitch history length: the
 * longest period plus one pitch analysis frame. */
#define PITCH_MIN_PERIOD 60
#define PITCH_MAX_PERIOD 768
#define PITCH_FRAME_SIZE 960
#define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)

/* Note: evaluates its argument twice. */
#define SQUARE(x) ((x)*(x))

/* Number of frequency bands (one per entry of eband5ms[]). */
#define NB_BANDS 22

/* Depth of the cepstral history ring in DenoiseState, and the number of
 * cepstral coefficients used for the delta and pitch-correlation features. */
#define CEPS_MEM 8
#define NB_DELTA_CEPS 6

/* Feature vector length: NB_BANDS band cepstra, two sets of NB_DELTA_CEPS
 * deltas, NB_DELTA_CEPS pitch-correlation coefficients, the pitch period and
 * the spectral variability (see compute_frame_features()). */
#define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)

/* Not referenced in the code visible here; presumably the scale applied to
 * the integer weights read from the model file — confirm. */
#define WEIGHTS_SCALE (1.f/256)

/* Upper bound on layer sizes; the model loader rejects values above 128. */
#define MAX_NEURONS 128

/* Internal activation identifiers stored in DenseLayer/GRULayer. */
#define ACTIVATION_TANH 0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_RELU 2

/* Used by the celt_* helpers; in this floating-point code "Q15 one" is 1.0. */
#define Q15ONE 1.0f
75
/* A fully-connected layer as filled in by rnnoise_model_from_file(). */
typedef struct DenseLayer {
    const float *bias;          /* nb_neurons values */
    const float *input_weights; /* nb_inputs * nb_neurons values, file order */
    int nb_inputs;
    int nb_neurons;
    int activation;             /* one of ACTIVATION_* */
} DenseLayer;
83
/* A GRU layer as filled in by rnnoise_model_from_file(). Both weight arrays
 * hold three blocks per neuron (presumably the GRU gates) and are rearranged
 * by INPUT_ARRAY3 so that, for each neuron and block, the weights over the
 * inputs are contiguous and padded to a multiple of 4 floats. */
typedef struct GRULayer {
    const float *bias;              /* 3 * nb_neurons values */
    const float *input_weights;     /* weights over nb_inputs inputs */
    const float *recurrent_weights; /* weights over nb_neurons state values */
    int nb_inputs;
    int nb_neurons;
    int activation;                 /* one of ACTIVATION_* */
} GRULayer;
92
/* A complete denoising model. Each <layer>_size field is copied from that
 * layer's nb_neurons by the loader (INPUT_DENSE / INPUT_GRU). */
typedef struct RNNModel {
    int input_dense_size;
    const DenseLayer *input_dense;

    int vad_gru_size;
    const GRULayer *vad_gru;

    int noise_gru_size;
    const GRULayer *noise_gru;

    int denoise_gru_size;
    const GRULayer *denoise_gru;

    int denoise_output_size;
    const DenseLayer *denoise_output;

    /* The loader requires exactly one neuron here. */
    int vad_output_size;
    const DenseLayer *vad_output;
} RNNModel;
112
/* Per-channel recurrent state for one model. config_input() allocates each
 * buffer with FFALIGN(<layer>_size, 16) zeroed floats. */
typedef struct RNNState {
    float *vad_gru_state;
    float *noise_gru_state;
    float *denoise_gru_state;
    RNNModel *model; /* model these buffers were sized for */
} RNNState;
119
/* Per-channel denoiser state. */
typedef struct DenoiseState {
    float analysis_mem[FRAME_SIZE];         /* previous input frame: first half of the analysis window */
    float cepstral_mem[CEPS_MEM][NB_BANDS]; /* ring buffer of recent band cepstra */
    int memid;                              /* next slot to write in cepstral_mem */
    DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE]; /* overlap-add tail from the previous frame */
    float pitch_buf[PITCH_BUF_SIZE];        /* sliding history of input samples for pitch analysis */
    float pitch_enh_buf[PITCH_BUF_SIZE];    /* not used in this part of the file — TODO confirm purpose */
    float last_gain;                        /* pitch gain of the previous frame (remove_doubling) */
    int last_period;                        /* pitch period of the previous frame (remove_doubling) */
    float mem_hp_x[2];                      /* presumably biquad() state for input filtering — confirm */
    float lastg[NB_BANDS];                  /* presumably previous per-band gains — confirm */
    float history[FRAME_SIZE];              /* dry signal mixed into the output by frame_synthesis() */
    RNNState rnn[2];                        /* config_input() sets up rnn[0] only */
    AVTXContext *tx, *txi;                  /* forward / inverse WINDOW_SIZE-point complex FFT */
    av_tx_fn tx_fn, txi_fn;
} DenoiseState;
136
/* Filter private context. */
typedef struct AudioRNNContext {
    const AVClass *class; /* kept as the first member */

    char *model_name;     /* presumably the model file option — confirm */
    float mix;            /* wet/dry mix applied in frame_synthesis() */

    int channels;
    DenoiseState *st;     /* one entry per channel, allocated in config_input() */

    /* Window applied before the forward and after the inverse transform. */
    DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
    /* DCT basis; rows are padded to a multiple of 4 floats for
     * scalarproduct_float() in dct(). */
    DECLARE_ALIGNED(32, float, dct_table)[FFALIGN(NB_BANDS, 4)][FFALIGN(NB_BANDS, 4)];

    RNNModel *model[2];   /* config_input() uses model[0] */

    AVFloatDSPContext *fdsp;
} AudioRNNContext;
153
/* Activation codes as they appear in a model file. INPUT_ACTIVATION in
 * rnnoise_model_from_file() maps them to ACTIVATION_*; any other code is
 * treated as tanh. */
#define F_ACTIVATION_TANH 0
#define F_ACTIVATION_SIGMOID 1
#define F_ACTIVATION_RELU 2
157
/**
 * Free a model built by rnnoise_model_from_file(), together with every layer
 * and weight array it owns.
 *
 * Safe to call with NULL and with a partially built model (the loader calls
 * it on failure, when some layer pointers may still be NULL). All memory here
 * comes from av_calloc(), so it is released with av_free() only.
 */
static void rnnoise_model_free(RNNModel *model)
{
/* Free a dense layer and its arrays, if that layer was allocated. */
#define FREE_DENSE(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)
/* Free a GRU layer and its arrays, if that layer was allocated. */
#define FREE_GRU(name) do { \
    if (model->name) { \
        av_free((void *) model->name->input_weights); \
        av_free((void *) model->name->recurrent_weights); \
        av_free((void *) model->name->bias); \
        av_free((void *) model->name); \
    } \
    } while (0)

    if (!model)
        return;
    FREE_DENSE(input_dense);
    FREE_GRU(vad_gru);
    FREE_GRU(noise_gru);
    FREE_GRU(denoise_gru);
    FREE_DENSE(denoise_output);
    FREE_DENSE(vad_output);
    av_free(model);
}
187
/**
 * Parse a text "rnnoise-nu" model (format version 1) from f.
 *
 * After the header line, six layers are read in this order: input_dense,
 * vad_gru, noise_gru, denoise_gru, denoise_output, vad_output. Each layer
 * starts with nb_inputs, nb_neurons and an activation code, followed by its
 * weight array(s) and its bias array as whitespace-separated integers. The
 * parser skips to the next line after the header and after each array.
 *
 * @param f   model file, read from its current position
 * @param rnn on success receives the new model (release with
 *            rnnoise_model_free())
 * @return 0 on success; AVERROR_INVALIDDATA for a bad header or version,
 *         AVERROR(EINVAL) for malformed or out-of-range values,
 *         AVERROR(ENOMEM) on allocation failure. On any error everything
 *         allocated so far is freed and *rnn is not written.
 */
static int rnnoise_model_from_file(FILE *f, RNNModel **rnn)
{
    RNNModel *ret = NULL;
    DenseLayer *input_dense;
    GRULayer *vad_gru;
    GRULayer *noise_gru;
    GRULayer *denoise_gru;
    DenseLayer *denoise_output;
    DenseLayer *vad_output;
    int in;

    /* Header line; only version 1 is accepted. */
    if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
        return AVERROR_INVALIDDATA;

    ret = av_calloc(1, sizeof(RNNModel));
    if (!ret)
        return AVERROR(ENOMEM);

/* Allocate a zeroed layer into both the local `name` and ret->name, so that
 * rnnoise_model_free(ret) releases it on any later failure. Expands to
 * several statements (no do/while wrapper). */
#define ALLOC_LAYER(type, name) \
    name = av_calloc(1, sizeof(type)); \
    if (!name) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    ret->name = name

    ALLOC_LAYER(DenseLayer, input_dense);
    ALLOC_LAYER(GRULayer, vad_gru);
    ALLOC_LAYER(GRULayer, noise_gru);
    ALLOC_LAYER(GRULayer, denoise_gru);
    ALLOC_LAYER(DenseLayer, denoise_output);
    ALLOC_LAYER(DenseLayer, vad_output);

/* Read one integer in [0, 128] (the MAX_NEURONS bound) into `name`. */
#define INPUT_VAL(name) do { \
    if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
        rnnoise_model_free(ret); \
        return AVERROR(EINVAL); \
    } \
    name = in; \
    } while (0)

/* Read a file activation code (F_ACTIVATION_*) and store the matching
 * ACTIVATION_* value; unrecognized codes fall back to tanh. */
#define INPUT_ACTIVATION(name) do { \
    int activation; \
    INPUT_VAL(activation); \
    switch (activation) { \
    case F_ACTIVATION_SIGMOID: \
        name = ACTIVATION_SIGMOID; \
        break; \
    case F_ACTIVATION_RELU: \
        name = ACTIVATION_RELU; \
        break; \
    default: \
        name = ACTIVATION_TANH; \
    } \
    } while (0)

/* Allocate `len` floats and attach them to `name` before reading, so a parse
 * failure frees them along with the model; then read `len` integers. */
#define INPUT_ARRAY(name, len) do { \
    float *values = av_calloc((len), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int i = 0; i < (len); i++) { \
        if (fscanf(f, "%d", &in) != 1) { \
            rnnoise_model_free(ret); \
            return AVERROR(EINVAL); \
        } \
        values[i] = in; \
    } \
    } while (0)

/* Read a len0 x len2 x len1 weight tensor (file order: k over len0
 * outermost, then i over len2, then j over len1 fastest) and store it so
 * that for each j and i the len0 values over k are contiguous, each such row
 * padded to FFALIGN(len0, 4) floats. Padding stays zero from av_calloc();
 * the allocation also rounds len1 up to a multiple of 4. */
#define INPUT_ARRAY3(name, len0, len1, len2) do { \
    float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
    if (!values) { \
        rnnoise_model_free(ret); \
        return AVERROR(ENOMEM); \
    } \
    name = values; \
    for (int k = 0; k < (len0); k++) { \
        for (int i = 0; i < (len2); i++) { \
            for (int j = 0; j < (len1); j++) { \
                if (fscanf(f, "%d", &in) != 1) { \
                    rnnoise_model_free(ret); \
                    return AVERROR(EINVAL); \
                } \
                values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
            } \
        } \
    } \
    } while (0)

/* Skip the rest of the current line (or up to EOF). */
#define NEW_LINE() do { \
    int c; \
    while ((c = fgetc(f)) != EOF) { \
        if (c == '\n') \
        break; \
    } \
    } while (0)

/* Dense layer: header, nb_inputs * nb_neurons weights, nb_neurons biases.
 * Also records ret-><name>_size. */
#define INPUT_DENSE(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons); \
    NEW_LINE(); \
    } while (0)

/* GRU layer: header, input weights (3 blocks over nb_inputs), recurrent
 * weights (3 blocks over nb_neurons), then 3 * nb_neurons biases. Also
 * records ret-><name>_size. */
#define INPUT_GRU(name) do { \
    INPUT_VAL(name->nb_inputs); \
    INPUT_VAL(name->nb_neurons); \
    ret->name ## _size = name->nb_neurons; \
    INPUT_ACTIVATION(name->activation); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
    NEW_LINE(); \
    INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
    NEW_LINE(); \
    } while (0)

    INPUT_DENSE(input_dense);
    INPUT_GRU(vad_gru);
    INPUT_GRU(noise_gru);
    INPUT_GRU(denoise_gru);
    INPUT_DENSE(denoise_output);
    INPUT_DENSE(vad_output);

    /* The VAD head must produce a single value.
     * NOTE(review): nothing checks that each layer's nb_inputs matches the size
     * of the data that feeds it; if the inference code trusts nb_inputs, a
     * malformed model file could make it read past its buffers — confirm. */
    if (vad_output->nb_neurons != 1) {
        rnnoise_model_free(ret);
        return AVERROR(EINVAL);
    }

    *rnn = ret;

    return 0;
}
330
/* Format negotiation: planar float samples, any channel count, 48 kHz only. */
static int query_formats(AVFilterContext *ctx)
{
    static const enum AVSampleFormat sample_fmts[] = {
        AV_SAMPLE_FMT_FLTP,
        AV_SAMPLE_FMT_NONE
    };
    int sample_rates[] = { 48000, -1 };
    AVFilterChannelLayouts *layouts;
    AVFilterFormats *formats;
    int ret;

    if (!(formats = ff_make_format_list(sample_fmts)))
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_formats(ctx, formats)) < 0)
        return ret;

    if (!(layouts = ff_all_channel_counts()))
        return AVERROR(ENOMEM);
    if ((ret = ff_set_common_channel_layouts(ctx, layouts)) < 0)
        return ret;

    if (!(formats = ff_make_format_list(sample_rates)))
        return AVERROR(ENOMEM);
    return ff_set_common_samplerates(ctx, formats);
}
361
/**
 * Configure the input link: allocate the per-channel DenoiseState array, the
 * recurrent state buffers for model[0], and the forward/inverse FFT contexts.
 *
 * The `if (!...)` guards show this may run on an already-configured filter;
 * existing s->st and FFT contexts are then reused. Memory allocated here is
 * released elsewhere (presumably the filter's uninit — not visible here).
 *
 * @return 0 on success or a negative AVERROR code
 */
static int config_input(AVFilterLink *inlink)
{
    AVFilterContext *ctx = inlink->dst;
    AudioRNNContext *s = ctx->priv;
    int ret = 0;

    s->channels = inlink->channels;

    /* NOTE(review): s->st is only allocated on the first call; a later call
     * with more channels would index past it — confirm the channel count
     * cannot change between calls. */
    if (!s->st)
        s->st = av_calloc(s->channels, sizeof(DenoiseState));
    if (!s->st)
        return AVERROR(ENOMEM);

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        /* NOTE(review): on a repeated call these replace buffers from the
         * previous call without freeing them — confirm ownership. */
        st->rnn[0].model = s->model[0];
        st->rnn[0].vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->vad_gru_size, 16));
        st->rnn[0].noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->noise_gru_size, 16));
        st->rnn[0].denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model[0]->denoise_gru_size, 16));
        if (!st->rnn[0].vad_gru_state ||
            !st->rnn[0].noise_gru_state ||
            !st->rnn[0].denoise_gru_state)
            return AVERROR(ENOMEM);
    }

    for (int i = 0; i < s->channels; i++) {
        DenoiseState *st = &s->st[i];

        /* Test ret only right after an av_tx_init() call: when a context
         * already exists nothing assigns ret, and reading it uninitialized
         * would be undefined behaviour. */
        if (!st->tx) {
            ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }

        if (!st->txi) {
            ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
            if (ret < 0)
                return ret;
        }
    }

    return 0;
}
404
/*
 * Second-order IIR section in transposed direct form II with
 * H(z) = (1 + b[0] z^-1 + b[1] z^-2) / (1 + a[0] z^-1 + a[1] z^-2).
 * Filters N samples from x into y (in-place y == x is fine) and carries the
 * two state values across calls in mem[].
 */
static void biquad(float *y, float mem[2], const float *x,
                   const float *b, const float *a, int N)
{
    float s0 = mem[0];
    float s1 = mem[1];

    for (int n = 0; n < N; n++) {
        const float in  = x[n];
        const float out = in + s0;

        s0 = s1 + (b[0]*in - a[0]*out);
        s1 = (b[1]*in - a[1]*out);
        y[n] = out;
    }

    mem[0] = s0;
    mem[1] = s1;
}
418
/* memmove / memset / memcpy of n elements of *dst's type. The
 * 0*((dst)-(src)) term adds nothing to the byte count; it exists so the
 * compiler rejects dst/src pointers of incompatible types, since pointer
 * subtraction requires compatible pointee types. */
#define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
#define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
#define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
422
/*
 * Forward complex FFT of WINDOW_SIZE real samples (imaginary parts zero);
 * only the first FREQ_SIZE bins are stored into out.
 */
static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
{
    AVComplexFloat time_buf[WINDOW_SIZE];
    AVComplexFloat freq_buf[WINDOW_SIZE];

    for (int n = 0; n < WINDOW_SIZE; n++)
        time_buf[n] = (AVComplexFloat){ .re = in[n], .im = 0 };

    st->tx_fn(st->tx, freq_buf, time_buf, sizeof(float));

    RNN_COPY(out, freq_buf, FREQ_SIZE);
}
437
/*
 * Inverse of forward_transform(): the FREQ_SIZE bins in `in` are extended
 * to a full WINDOW_SIZE spectrum by Hermitian symmetry
 * (bin[WINDOW_SIZE - k] = conj(bin[k])), inverse-transformed, and the real
 * part divided by WINDOW_SIZE is written to out[0..WINDOW_SIZE-1].
 */
static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
{
    AVComplexFloat spec[WINDOW_SIZE];
    AVComplexFloat time_buf[WINDOW_SIZE];

    RNN_COPY(spec, in, FREQ_SIZE);

    for (int bin = FREQ_SIZE; bin < WINDOW_SIZE; bin++) {
        const AVComplexFloat *mirror = &spec[WINDOW_SIZE - bin];

        spec[bin].re = mirror->re;
        spec[bin].im = -mirror->im;
    }

    st->txi_fn(st->txi, time_buf, spec, sizeof(float));

    for (int n = 0; n < WINDOW_SIZE; n++)
        out[n] = time_buf[n].re / WINDOW_SIZE;
}
455
/* NB_BANDS band edges in units of (1 << FRAME_SIZE_SHIFT) = 4 FFT bins;
 * users shift each entry left by FRAME_SIZE_SHIFT to get a bin index. With
 * a WINDOW_SIZE (960) point transform at the 48 kHz rate set in
 * query_formats(), one unit is 200 Hz; the comment row gives the edge
 * frequency of each entry. */
static const uint8_t eband5ms[] = {
/*0  200 400 600 800  1k 1.2 1.4 1.6  2k 2.4 2.8 3.2  4k 4.8 5.6 6.8  8k 9.6 12k 15.6 20k*/
  0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
};
460
compute_band_energy(float * bandE,const AVComplexFloat * X)461 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
462 {
463 float sum[NB_BANDS] = {0};
464
465 for (int i = 0; i < NB_BANDS - 1; i++) {
466 int band_size;
467
468 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
469 for (int j = 0; j < band_size; j++) {
470 float tmp, frac = (float)j / band_size;
471
472 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
473 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
474 sum[i] += (1.f - frac) * tmp;
475 sum[i + 1] += frac * tmp;
476 }
477 }
478
479 sum[0] *= 2;
480 sum[NB_BANDS - 1] *= 2;
481
482 for (int i = 0; i < NB_BANDS; i++)
483 bandE[i] = sum[i];
484 }
485
compute_band_corr(float * bandE,const AVComplexFloat * X,const AVComplexFloat * P)486 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
487 {
488 float sum[NB_BANDS] = { 0 };
489
490 for (int i = 0; i < NB_BANDS - 1; i++) {
491 int band_size;
492
493 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
494 for (int j = 0; j < band_size; j++) {
495 float tmp, frac = (float)j / band_size;
496
497 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
498 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
499 sum[i] += (1 - frac) * tmp;
500 sum[i + 1] += frac * tmp;
501 }
502 }
503
504 sum[0] *= 2;
505 sum[NB_BANDS-1] *= 2;
506
507 for (int i = 0; i < NB_BANDS; i++)
508 bandE[i] = sum[i];
509 }
510
/*
 * Build the analysis window from the previous frame and `in` (FRAME_SIZE
 * samples), remember `in` for the next call, then apply the window and
 * compute the spectrum X and band energies Ex.
 */
static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    float *const cur = buf + FRAME_SIZE;

    RNN_COPY(buf, st->analysis_mem, FRAME_SIZE);
    RNN_COPY(cur, in, FRAME_SIZE);
    RNN_COPY(st->analysis_mem, in, FRAME_SIZE);

    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    forward_transform(st, X, buf);
    compute_band_energy(Ex, X);
}
522
/*
 * Inverse-transform spectrum y, window it, overlap-add the tail kept from the
 * previous frame and store the new tail. The FRAME_SIZE output samples are
 * then blended with st->history: out * mix + history * (1 - max(mix, 0)).
 */
static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
{
    LOCAL_ALIGNED_32(float, buf, [WINDOW_SIZE]);
    const float wet = s->mix;
    const float dry = 1.f - FFMAX(wet, 0.f);
    const float *dry_src = st->history;

    inverse_transform(st, buf, y);
    s->fdsp->vector_fmul(buf, buf, s->window, WINDOW_SIZE);
    s->fdsp->vector_fmac_scalar(buf, st->synthesis_mem, 1.f, FRAME_SIZE);

    RNN_COPY(out, buf, FRAME_SIZE);
    RNN_COPY(st->synthesis_mem, buf + FRAME_SIZE, FRAME_SIZE);

    for (int i = 0; i < FRAME_SIZE; i++)
        out[i] = out[i] * wet + dry_src[i] * dry;
}
539
/**
 * Accumulate four cross-correlation lags at once:
 * sum[m] += x[j] * y[j + m] for j = 0..len-1 and m = 0..3.
 *
 * Hand-unrolled by four: y_0..y_3 rotate through the four lags so each y
 * sample is loaded once. Reads x[0..len-1] and y[0..len+2] (three y samples
 * are loaded up front even when len is 0). Results are added to, not stored
 * in, sum[].
 */
static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
{
    float y_0, y_1, y_2, y_3 = 0;
    int j;

    /* Prime the rotation with y[0..2]. */
    y_0 = *y++;
    y_1 = *y++;
    y_2 = *y++;

    /* Main loop: four x samples per iteration, one new y sample each. */
    for (j = 0; j < len - 3; j += 4) {
        float tmp;

        tmp = *x++;
        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
        tmp = *x++;
        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
        tmp = *x++;
        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
        tmp = *x++;
        y_2 = *y++;
        sum[0] += tmp * y_3;
        sum[1] += tmp * y_0;
        sum[2] += tmp * y_1;
        sum[3] += tmp * y_2;
    }

    /* Tail: up to three remaining x samples, continuing the rotation. */
    if (j++ < len) {
        float tmp = *x++;

        y_3 = *y++;
        sum[0] += tmp * y_0;
        sum[1] += tmp * y_1;
        sum[2] += tmp * y_2;
        sum[3] += tmp * y_3;
    }

    if (j++ < len) {
        float tmp=*x++;

        y_0 = *y++;
        sum[0] += tmp * y_1;
        sum[1] += tmp * y_2;
        sum[2] += tmp * y_3;
        sum[3] += tmp * y_0;
    }

    if (j < len) {
        float tmp=*x++;

        y_1 = *y++;
        sum[0] += tmp * y_2;
        sum[1] += tmp * y_3;
        sum[2] += tmp * y_0;
        sum[3] += tmp * y_1;
    }
}
608
/* Dot product of x[0..N-1] and y[0..N-1], accumulated in index order;
 * 0 when N <= 0. */
static inline float celt_inner_prod(const float *x,
                                    const float *y, int N)
{
    float acc = 0.f;

    for (int remaining = N; remaining > 0; remaining--)
        acc += *x++ * *y++;

    return acc;
}
619
/*
 * Cross-correlation xcorr[lag] = sum_{j<len} x[j] * y[j + lag] for
 * lag = 0..max_pitch-1. Groups of four lags go through xcorr_kernel();
 * any remainder uses celt_inner_prod().
 */
static void celt_pitch_xcorr(const float *x, const float *y,
                             float *xcorr, int len, int max_pitch)
{
    int lag = 0;

    while (lag + 3 < max_pitch) {
        float acc[4] = { 0, 0, 0, 0 };

        xcorr_kernel(x, y + lag, acc, len);
        for (int m = 0; m < 4; m++)
            xcorr[lag + m] = acc[m];
        lag += 4;
    }

    /* Leftover lags when max_pitch is not a multiple of 4. */
    for (; lag < max_pitch; lag++)
        xcorr[lag] = celt_inner_prod(x, y + lag, len);
}
640
celt_autocorr(const float * x,float * ac,const float * window,int overlap,int lag,int n)641 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
642 float *ac, /* out: [0...lag-1] ac values */
643 const float *window,
644 int overlap,
645 int lag,
646 int n)
647 {
648 int fastN = n - lag;
649 int shift;
650 const float *xptr;
651 float xx[PITCH_BUF_SIZE>>1];
652
653 if (overlap == 0) {
654 xptr = x;
655 } else {
656 for (int i = 0; i < n; i++)
657 xx[i] = x[i];
658 for (int i = 0; i < overlap; i++) {
659 xx[i] = x[i] * window[i];
660 xx[n-i-1] = x[n-i-1] * window[i];
661 }
662 xptr = xx;
663 }
664
665 shift = 0;
666 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
667
668 for (int k = 0; k <= lag; k++) {
669 float d = 0.f;
670
671 for (int i = k + fastN; i < n; i++)
672 d += xptr[i] * xptr[i-k];
673 ac[k] += d;
674 }
675
676 return shift;
677 }
678
/*
 * Levinson-Durbin recursion: derive p LPC coefficients lpc[0..p-1] from the
 * autocorrelation ac[0..p]. Output is all zero when ac[0] == 0. Stops early
 * once the prediction error drops below 0.001 * ac[0] (30 dB of gain).
 */
static void celt_lpc(float *lpc,
                     const float *ac,
                     int p)
{
    float err = ac[0];

    RNN_CLEAR(lpc, p);
    if (ac[0] == 0)
        return;

    for (int i = 0; i < p; i++) {
        float acc = 0;
        float refl;

        /* Reflection coefficient for this order. */
        for (int j = 0; j < i; j++)
            acc += (lpc[j] * ac[i - j]);
        acc += ac[i + 1];
        refl = -acc / err;

        /* Update the coefficients symmetrically, then the residual error. */
        lpc[i] = refl;
        for (int j = 0, half = (i + 1) >> 1; j < half; j++) {
            const float lo = lpc[j];
            const float hi = lpc[i - 1 - j];

            lpc[j]         = lo + (refl * hi);
            lpc[i - 1 - j] = hi + (refl * lo);
        }

        err = err - (refl * refl * err);
        if (err < .001f * ac[0])
            break;
    }
}
711
/*
 * 5-tap FIR filter y[i] = x[i] + sum_k num[k] * x[i - 1 - k], where the
 * previous inputs come from (and are saved back to) mem[0..4], most recent
 * first. x and y may be the same buffer.
 */
static void celt_fir5(const float *x,
                      const float *num,
                      float *y,
                      int N,
                      float *mem)
{
    float taps[5], hist[5];

    for (int k = 0; k < 5; k++) {
        taps[k] = num[k];
        hist[k] = mem[k];
    }

    for (int i = 0; i < N; i++) {
        const float in = x[i]; /* read before y[i] is written (in-place use) */
        float acc = in;

        for (int k = 0; k < 5; k++)
            acc += (taps[k] * hist[k]);

        for (int k = 4; k > 0; k--)
            hist[k] = hist[k - 1];
        hist[0] = in;
        y[i] = acc;
    }

    for (int k = 0; k < 5; k++)
        mem[k] = hist[k];
}
754
pitch_downsample(float * x[],float * x_lp,int len,int C)755 static void pitch_downsample(float *x[], float *x_lp,
756 int len, int C)
757 {
758 float ac[5];
759 float tmp=Q15ONE;
760 float lpc[4], mem[5]={0,0,0,0,0};
761 float lpc2[5];
762 float c1 = .8f;
763
764 for (int i = 1; i < len >> 1; i++)
765 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
766 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
767 if (C==2) {
768 for (int i = 1; i < len >> 1; i++)
769 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
770 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
771 }
772
773 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
774
775 /* Noise floor -40 dB */
776 ac[0] *= 1.0001f;
777 /* Lag windowing */
778 for (int i = 1; i <= 4; i++) {
779 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
780 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
781 }
782
783 celt_lpc(lpc, ac, 4);
784 for (int i = 0; i < 4; i++) {
785 tmp = .9f * tmp;
786 lpc[i] = (lpc[i] * tmp);
787 }
788 /* Add a zero */
789 lpc2[0] = lpc[0] + .8f;
790 lpc2[1] = lpc[1] + (c1 * lpc[0]);
791 lpc2[2] = lpc[2] + (c1 * lpc[1]);
792 lpc2[3] = lpc[3] + (c1 * lpc[2]);
793 lpc2[4] = (c1 * lpc[3]);
794 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
795 }
796
/* Two dot products sharing one pass over x:
 * *xy1 = x . y01 and *xy2 = x . y02 over N elements. */
static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
                                   int N, float *xy1, float *xy2)
{
    float acc1 = 0, acc2 = 0;
    int i = 0;

    while (i < N) {
        const float xi = x[i];

        acc1 += (xi * y01[i]);
        acc2 += (xi * y02[i]);
        i++;
    }

    *xy1 = acc1;
    *xy2 = acc2;
}
810
compute_pitch_gain(float xy,float xx,float yy)811 static float compute_pitch_gain(float xy, float xx, float yy)
812 {
813 return xy / sqrtf(1.f + xx * yy);
814 }
815
/* For candidate divisor k >= 3, remove_doubling() also checks the period
 * round(second_check[k] * T0 / k); entries for k < 3 are not used (k == 2 has
 * its own rule). */
static const uint8_t second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
/**
 * Refine a pitch period estimate by testing whether a submultiple T0/k
 * (k = 2..15) correlates strongly enough to be the true period, then apply a
 * +-1 sample correction from neighbouring correlations.
 *
 * Works at half resolution: all periods and N are halved internally and x is
 * advanced by maxperiod/2, so x must provide at least maxperiod/2 + N/2
 * samples (the 2x-decimated pitch buffer).
 *
 * @param x           decimated signal (presumably from pitch_downsample())
 * @param maxperiod   largest period, in full-rate samples
 * @param minperiod   smallest period, in full-rate samples
 * @param N           analysis length, in full-rate samples
 * @param T0_         in: initial period; out: refined period (full rate,
 *                    at least minperiod)
 * @param prev_period period chosen for the previous frame (full rate)
 * @param prev_gain   pitch gain of the previous frame
 * @return pitch gain for the chosen period, never above the chosen
 *         candidate's normalized correlation
 */
static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
                             int *T0_, int prev_period, float prev_gain)
{
    int k, i, T, T0;
    float g, g0;
    float pg;
    float xy,xx,yy,xy2;
    float xcorr[3];
    float best_xy, best_yy;
    int offset;
    int minperiod0;
    float yy_lookup[PITCH_MAX_PERIOD+1];

    /* Move everything to the half-rate domain. */
    minperiod0 = minperiod;
    maxperiod /= 2;
    minperiod /= 2;
    *T0_ /= 2;
    prev_period /= 2;
    N /= 2;
    x += maxperiod;
    if (*T0_>=maxperiod)
        *T0_=maxperiod-1;

    T = T0 = *T0_;
    dual_inner_prod(x, x, x-T0, N, &xx, &xy);
    /* yy_lookup[i]: energy of x[-i..N-1-i], updated incrementally and
     * clamped at 0 against rounding drift. */
    yy_lookup[0] = xx;
    yy=xx;
    for (i = 1; i <= maxperiod; i++) {
        yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
        yy_lookup[i] = FFMAX(0, yy);
    }
    yy = yy_lookup[T0];
    best_xy = xy;
    best_yy = yy;
    g = g0 = compute_pitch_gain(xy, xx, yy);
    /* Look for any pitch at T/k */
    for (k = 2; k <= 15; k++) {
        int T1, T1b;
        float g1;
        float cont=0;
        float thresh;
        T1 = (2*T0+k)/(2*k); /* round(T0 / k) */
        if (T1 < minperiod)
            break;
        /* Look for another strong correlation at T1b */
        if (k==2)
        {
            if (T1+T0>maxperiod)
                T1b = T0;
            else
                T1b = T0+T1;
        } else
        {
            T1b = (2*second_check[k]*T0+k)/(2*k);
        }
        dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
        xy = .5f * (xy + xy2);
        yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
        g1 = compute_pitch_gain(xy, xx, yy);
        /* Lower the threshold when T1 is close to the previous frame's period. */
        if (FFABS(T1-prev_period)<=1)
            cont = prev_gain;
        else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
            cont = prev_gain * .5f;
        else
            cont = 0;
        thresh = FFMAX(.3f, (.7f * g0) - cont);
        /* Bias against very high pitch (very short period) to avoid false-positives
           due to short-term correlation */
        /* NOTE(review): T1 < 2*minperiod implies T1 < 3*minperiod, so the
         * `else if` below can never run. This appears to mirror the reference
         * Opus/RNNoise code; changing it would alter the features the model
         * was trained on — confirm upstream before touching. */
        if (T1<3*minperiod)
            thresh = FFMAX(.4f, (.85f * g0) - cont);
        else if (T1<2*minperiod)
            thresh = FFMAX(.5f, (.9f * g0) - cont);
        if (g1 > thresh)
        {
            best_xy = xy;
            best_yy = yy;
            T = T1;
            g = g1;
        }
    }
    best_xy = FFMAX(0, best_xy);
    if (best_yy <= best_xy)
        pg = Q15ONE;
    else
        pg = best_xy/(best_yy + 1);

    /* Correlations at T-1, T, T+1 decide a one-sample correction. */
    for (k = 0; k < 3; k++)
        xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
    if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
        offset = 1;
    else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
        offset = -1;
    else
        offset = 0;
    if (pg > g)
        pg = g;
    /* Back to full rate. */
    *T0_ = 2*T+offset;

    if (*T0_<minperiod0)
        *T0_=minperiod0;
    return pg;
}
919
/*
 * Choose the two lags in [0, max_pitch) with the largest normalized
 * xcorr^2 / energy(y window), considering only positive correlations.
 * best_pitch[0] gets the best lag and best_pitch[1] the runner-up; they
 * default to 0 and 1. y must hold len + max_pitch samples. xcorr is scaled
 * by 1e-12 before squaring so the square neither underflows nor overflows.
 */
static void find_best_pitch(float *xcorr, float *y, int len,
                            int max_pitch, int *best_pitch)
{
    float top_num = -1, second_num = -1;
    float top_den = 0, second_den = 0;
    float energy = 1.f;

    best_pitch[0] = 0;
    best_pitch[1] = 1;

    for (int j = 0; j < len; j++)
        energy += y[j] * y[j];

    for (int lag = 0; lag < max_pitch; lag++) {
        if (xcorr[lag] > 0) {
            const float scaled = xcorr[lag] * 1e-12f;
            const float num = scaled * scaled;

            /* Cross-multiplied comparisons avoid dividing by energy. */
            if ((num * second_den) > (second_num * energy)) {
                if ((num * top_den) > (top_num * energy)) {
                    second_num    = top_num;
                    second_den    = top_den;
                    best_pitch[1] = best_pitch[0];
                    top_num       = num;
                    top_den       = energy;
                    best_pitch[0] = lag;
                } else {
                    second_num    = num;
                    second_den    = energy;
                    best_pitch[1] = lag;
                }
            }
        }

        /* Slide the energy window one sample forward, floored at 1. */
        energy += y[lag + len] * y[lag + len] - y[lag] * y[lag];
        if (energy < 1)
            energy = 1;
    }
}
966
pitch_search(const float * x_lp,float * y,int len,int max_pitch,int * pitch)967 static void pitch_search(const float *x_lp, float *y,
968 int len, int max_pitch, int *pitch)
969 {
970 int lag;
971 int best_pitch[2]={0,0};
972 int offset;
973
974 float x_lp4[WINDOW_SIZE];
975 float y_lp4[WINDOW_SIZE];
976 float xcorr[WINDOW_SIZE];
977
978 lag = len+max_pitch;
979
980 /* Downsample by 2 again */
981 for (int j = 0; j < len >> 2; j++)
982 x_lp4[j] = x_lp[2*j];
983 for (int j = 0; j < lag >> 2; j++)
984 y_lp4[j] = y[2*j];
985
986 /* Coarse search with 4x decimation */
987
988 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
989
990 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
991
992 /* Finer search with 2x decimation */
993 for (int i = 0; i < max_pitch >> 1; i++) {
994 float sum;
995 xcorr[i] = 0;
996 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
997 continue;
998 sum = celt_inner_prod(x_lp, y+i, len>>1);
999 xcorr[i] = FFMAX(-1, sum);
1000 }
1001
1002 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
1003
1004 /* Refine by pseudo-interpolation */
1005 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
1006 float a, b, c;
1007
1008 a = xcorr[best_pitch[0] - 1];
1009 b = xcorr[best_pitch[0]];
1010 c = xcorr[best_pitch[0] + 1];
1011 if (c - a > .7f * (b - a))
1012 offset = 1;
1013 else if (a - c > .7f * (b-c))
1014 offset = -1;
1015 else
1016 offset = 0;
1017 } else {
1018 offset = 0;
1019 }
1020
1021 *pitch = 2 * best_pitch[0] - offset;
1022 }
1023
/**
 * Transform NB_BANDS values with the precomputed basis in s->dct_table and
 * scale the result by sqrt(2 / NB_BANDS); writes NB_BANDS values to out.
 *
 * scalarproduct_float() runs over FFALIGN(NB_BANDS, 4) elements, but callers
 * pass NB_BANDS-sized arrays (e.g. the Ly buffer in compute_frame_features()).
 * Passing `in` directly would read past its end, and uninitialized values
 * (possibly NaN/Inf) would leak into the sum. The input is therefore copied
 * into a 32-byte aligned scratch buffer whose padding is zeroed.
 */
static void dct(AudioRNNContext *s, float *out, const float *in)
{
    const int padded_len = FFALIGN(NB_BANDS, 4);
    const float norm = sqrtf(2.f / NB_BANDS);
    LOCAL_ALIGNED_32(float, padded, [FFALIGN(NB_BANDS, 4)]);

    RNN_COPY(padded, in, NB_BANDS);
    RNN_CLEAR(padded + NB_BANDS, padded_len - NB_BANDS);

    for (int i = 0; i < NB_BANDS; i++) {
        const float sum = s->fdsp->scalarproduct_float(padded, s->dct_table[i], padded_len);

        out[i] = sum * norm;
    }
}
1033
/**
 * Analyze one input frame and build the network's feature vector.
 *
 * Updates st: analysis memory, pitch history, last pitch period/gain and
 * (for non-silent frames) the cepstral history ring.
 *
 * @param X        out: spectrum of the current analysis window (FREQ_SIZE bins)
 * @param P        out: spectrum of the pitch-delayed window (FREQ_SIZE bins)
 * @param Ex       out: band energies of X (NB_BANDS)
 * @param Ep       out: band energies of P (NB_BANDS)
 * @param Exp      out: normalized X/P band correlation (NB_BANDS)
 * @param features out: NB_FEATURES values; all zero for a silent frame
 * @param in       FRAME_SIZE new input samples
 * @return 1 when the total band energy is below 0.04 (treated as silence:
 *         features are cleared and the cepstral history is left untouched),
 *         0 otherwise
 */
static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
                                  float *Ex, float *Ep, float *Exp, float *features, const float *in)
{
    float E = 0;
    float *ceps_0, *ceps_1, *ceps_2;
    float spec_variability = 0;
    LOCAL_ALIGNED_32(float, Ly, [NB_BANDS]);
    LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
    float pitch_buf[PITCH_BUF_SIZE>>1];
    int pitch_index;
    float gain;
    float *(pre[1]);
    float tmp[NB_BANDS];
    float follow, logMax;

    frame_analysis(s, st, X, Ex, in);
    /* Slide the pitch history by one frame and append the new samples. */
    RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
    RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
    pre[0] = &st->pitch_buf[0];
    /* Pitch estimation on a 2x-decimated copy of the history. */
    pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
    pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
                 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
    /* Convert the search lag into a period (presumably; the search runs
     * on the buffer offset by PITCH_MAX_PERIOD/2 — confirm). */
    pitch_index = PITCH_MAX_PERIOD-pitch_index;

    gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
                           PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
    st->last_period = pitch_index;
    st->last_gain = gain;

    /* Window of past input delayed by the detected period. */
    for (int i = 0; i < WINDOW_SIZE; i++)
        p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];

    s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
    forward_transform(st, P, p);
    compute_band_energy(Ep, P);
    compute_band_corr(Exp, X, P);

    /* Normalize the band correlation by both band energies. */
    for (int i = 0; i < NB_BANDS; i++)
        Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);

    dct(s, tmp, Exp);

    /* Feature layout: [0, NB_BANDS) band cepstrum (first NB_DELTA_CEPS
     * replaced by 3-frame sums), then first and second cepstral deltas,
     * then NB_DELTA_CEPS pitch-correlation coefficients, the pitch period
     * and the spectral variability. */
    for (int i = 0; i < NB_DELTA_CEPS; i++)
        features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];

    features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
    features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
    features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
    logMax = -2;
    follow = -2;

    /* Log band energies, floored at logMax - 7 and at a value that decays
     * by 1.5 per band from the previous band. */
    for (int i = 0; i < NB_BANDS; i++) {
        Ly[i] = log10f(1e-2f + Ex[i]);
        Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
        logMax = FFMAX(logMax, Ly[i]);
        follow = FFMAX(follow-1.5, Ly[i]);
        E += Ex[i];
    }

    if (E < 0.04f) {
        /* If there's no audio, avoid messing up the state. */
        RNN_CLEAR(features, NB_FEATURES);
        return 1;
    }

    dct(s, features, Ly);
    features[0] -= 12;
    features[1] -= 4;
    /* Current, previous and second-previous slots of the cepstral ring. */
    ceps_0 = st->cepstral_mem[st->memid];
    ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
    ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];

    for (int i = 0; i < NB_BANDS; i++)
        ceps_0[i] = features[i];

    st->memid++;
    for (int i = 0; i < NB_DELTA_CEPS; i++) {
        features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
        features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
        features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
    }
    /* Spectral variability features. */
    if (st->memid == CEPS_MEM)
        st->memid = 0;

    /* Average over the ring of each entry's squared distance to its
     * nearest other entry. */
    for (int i = 0; i < CEPS_MEM; i++) {
        float mindist = 1e15f;
        for (int j = 0; j < CEPS_MEM; j++) {
            float dist = 0.f;
            for (int k = 0; k < NB_BANDS; k++) {
                float tmp;

                tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
                dist += tmp*tmp;
            }

            if (j != i)
                mindist = FFMIN(mindist, dist);
        }

        spec_variability += mindist;
    }

    features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;

    return 0;
}
1141
interp_band_gain(float * g,const float * bandE)1142 static void interp_band_gain(float *g, const float *bandE)
1143 {
1144 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1145
1146 for (int i = 0; i < NB_BANDS - 1; i++) {
1147 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1148
1149 for (int j = 0; j < band_size; j++) {
1150 float frac = (float)j / band_size;
1151
1152 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1153 }
1154 }
1155 }
1156
/**
 * Pitch filter: mix a per-band scaled copy of the pitch spectrum P into X,
 * then rescale X so that every band returns to its original energy Ex.
 *
 * @param X   current spectrum (FREQ_SIZE bins), modified in place
 * @param P   spectrum of the pitch-delayed signal
 * @param Ex  per-band energy of X before filtering
 * @param Ep  per-band energy of P
 * @param Exp per-band normalised correlation between X and P
 * @param g   per-band gains produced by the network
 */
static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
                         const float *Exp, const float *g)
{
    float r[NB_BANDS], norm[NB_BANDS], newE[NB_BANDS];
    float rf[FREQ_SIZE] = { 0 };
    float normf[FREQ_SIZE] = { 0 };

    /* Pitch gain per band: 1 when the correlation already exceeds the network
     * gain, otherwise a ratio of the two clipped to [0, 1]. Its square root is
     * scaled by sqrt(Ex / Ep) to bring P to X's band energy. */
    for (int b = 0; b < NB_BANDS; b++) {
        float gain;

        if (Exp[b] > g[b]) {
            gain = 1;
        } else {
            gain = SQUARE(Exp[b]) * (1 - SQUARE(g[b])) /
                   (.001 + SQUARE(g[b]) * (1 - SQUARE(Exp[b])));
        }
        r[b] = sqrtf(av_clipf(gain, 0, 1)) * sqrtf(Ex[b] / (1e-8 + Ep[b]));
    }

    /* X += interpolated(r) * P */
    interp_band_gain(rf, r);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re += rf[k] * P[k].re;
        X[k].im += rf[k] * P[k].im;
    }

    /* Restore the original per-band energy. */
    compute_band_energy(newE, X);
    for (int b = 0; b < NB_BANDS; b++)
        norm[b] = sqrtf(Ex[b] / (1e-8 + newE[b]));

    interp_band_gain(normf, norm);
    for (int k = 0; k < FREQ_SIZE; k++) {
        X[k].re *= normf[k];
        X[k].im *= normf[k];
    }
}
1187
/*
 * tanh(i / 25) for i = 0..200, i.e. tanh sampled every 0.04 over [0, 8]
 * (entry 25 = tanh(1) = 0.761594). Used by tansig_approx(). The comment at the
 * start of each row gives the index of its first entry.
 */
static const float tansig_table[201] = {
    /*   0 */ 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
    /*   5 */ 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
    /*  10 */ 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
    /*  15 */ 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
    /*  20 */ 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
    /*  25 */ 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
    /*  30 */ 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
    /*  35 */ 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
    /*  40 */ 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
    /*  45 */ 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
    /*  50 */ 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
    /*  55 */ 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
    /*  60 */ 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
    /*  65 */ 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
    /*  70 */ 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
    /*  75 */ 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
    /*  80 */ 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
    /*  85 */ 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
    /*  90 */ 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
    /*  95 */ 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
    /* 100 */ 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
    /* 105 */ 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
    /* 110 */ 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
    /* 115 */ 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
    /* 120 */ 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
    /* 125 */ 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
    /* 130 */ 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
    /* 135 */ 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
    /* 140 */ 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
    /* 145 */ 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
    /* 150 */ 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
    /* 155 */ 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
    /* 160 */ 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
    /* 165 */ 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
    /* 170 */ 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
    /* 175 */ 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
    /* 180 */ 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    /* 185 */ 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
    /* 190 */ 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    /* 195 */ 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
    /* 200 */ 1.000000f,
};
1231
tansig_approx(float x)1232 static inline float tansig_approx(float x)
1233 {
1234 float y, dy;
1235 float sign=1;
1236 int i;
1237
1238 /* Tests are reversed to catch NaNs */
1239 if (!(x<8))
1240 return 1;
1241 if (!(x>-8))
1242 return -1;
1243 /* Another check in case of -ffast-math */
1244
1245 if (isnan(x))
1246 return 0;
1247
1248 if (x < 0) {
1249 x=-x;
1250 sign=-1;
1251 }
1252 i = (int)floor(.5f+25*x);
1253 x -= .04f*i;
1254 y = tansig_table[i];
1255 dy = 1-y*y;
1256 y = y + x*dy*(1 - y*x);
1257 return sign*y;
1258 }
1259
/* Logistic sigmoid via the identity sigmoid(x) = (1 + tanh(x / 2)) / 2. */
static inline float sigmoid_approx(float x)
{
    const float t = tansig_approx(.5f * x);

    return .5f + .5f * t;
}
1264
compute_dense(const DenseLayer * layer,float * output,const float * input)1265 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1266 {
1267 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1268
1269 for (int i = 0; i < N; i++) {
1270 /* Compute update gate. */
1271 float sum = layer->bias[i];
1272
1273 for (int j = 0; j < M; j++)
1274 sum += layer->input_weights[j * stride + i] * input[j];
1275
1276 output[i] = WEIGHTS_SCALE * sum;
1277 }
1278
1279 if (layer->activation == ACTIVATION_SIGMOID) {
1280 for (int i = 0; i < N; i++)
1281 output[i] = sigmoid_approx(output[i]);
1282 } else if (layer->activation == ACTIVATION_TANH) {
1283 for (int i = 0; i < N; i++)
1284 output[i] = tansig_approx(output[i]);
1285 } else if (layer->activation == ACTIVATION_RELU) {
1286 for (int i = 0; i < N; i++)
1287 output[i] = FFMAX(0, output[i]);
1288 } else {
1289 av_assert0(0);
1290 }
1291 }
1292
/*
 * One GRU time step; `state` (nb_neurons values) is updated in place.
 *
 * Layout implied by the indexing below: neuron n owns a row of 3 * in_pad
 * input weights and 3 * st_pad recurrent weights (counts rounded up to a
 * multiple of 4 for scalarproduct_float()), split into update-gate,
 * reset-gate and candidate segments in that order. Biases are three
 * consecutive runs of nb_neurons values in the same order.
 */
static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
{
    LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
    LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
    const int nb_cell = gru->nb_neurons;
    const int in_pad  = FFALIGN(gru->nb_inputs, 4);
    const int st_pad  = FFALIGN(nb_cell, 4);
    const int in_row  = 3 * in_pad;
    const int st_row  = 3 * st_pad;

    /* Update gate z. */
    for (int n = 0; n < nb_cell; n++) {
        float acc = gru->bias[n];

        acc += s->fdsp->scalarproduct_float(gru->input_weights + n * in_row, input, in_pad);
        acc += s->fdsp->scalarproduct_float(gru->recurrent_weights + n * st_row, state, st_pad);
        z[n] = sigmoid_approx(WEIGHTS_SCALE * acc);
    }

    /* Reset gate r. */
    for (int n = 0; n < nb_cell; n++) {
        float acc = gru->bias[nb_cell + n];

        acc += s->fdsp->scalarproduct_float(gru->input_weights + in_pad + n * in_row, input, in_pad);
        acc += s->fdsp->scalarproduct_float(gru->recurrent_weights + st_pad + n * st_row, state, st_pad);
        r[n] = sigmoid_approx(WEIGHTS_SCALE * acc);
    }

    /* Candidate state, blended with the old state through z. The new state is
     * collected in h and only copied back once every neuron has read the old one. */
    for (int n = 0; n < nb_cell; n++) {
        float acc = gru->bias[2 * nb_cell + n];

        acc += s->fdsp->scalarproduct_float(gru->input_weights + 2 * in_pad + n * in_row, input, in_pad);
        for (int k = 0; k < nb_cell; k++)
            acc += gru->recurrent_weights[2 * st_pad + n * st_row + k] * state[k] * r[k];

        switch (gru->activation) {
        case ACTIVATION_SIGMOID:
            acc = sigmoid_approx(WEIGHTS_SCALE * acc);
            break;
        case ACTIVATION_TANH:
            acc = tansig_approx(WEIGHTS_SCALE * acc);
            break;
        case ACTIVATION_RELU:
            acc = FFMAX(0, WEIGHTS_SCALE * acc);
            break;
        default:
            av_assert0(0);
            break;
        }
        h[n] = z[n] * state[n] + (1.f - z[n]) * acc;
    }

    RNN_COPY(state, h, nb_cell);
}
1343
1344 #define INPUT_SIZE 42
1345
compute_rnn(AudioRNNContext * s,RNNState * rnn,float * gains,float * vad,const float * input)1346 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1347 {
1348 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1349 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1350 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1351
1352 compute_dense(rnn->model->input_dense, dense_out, input);
1353 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1354 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1355
1356 memcpy(noise_input, dense_out, rnn->model->input_dense_size * sizeof(float));
1357 memcpy(noise_input + rnn->model->input_dense_size,
1358 rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1359 memcpy(noise_input + rnn->model->input_dense_size + rnn->model->vad_gru_size,
1360 input, INPUT_SIZE * sizeof(float));
1361
1362 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1363
1364 memcpy(denoise_input, rnn->vad_gru_state, rnn->model->vad_gru_size * sizeof(float));
1365 memcpy(denoise_input + rnn->model->vad_gru_size,
1366 rnn->noise_gru_state, rnn->model->noise_gru_size * sizeof(float));
1367 memcpy(denoise_input + rnn->model->vad_gru_size + rnn->model->noise_gru_size,
1368 input, INPUT_SIZE * sizeof(float));
1369
1370 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1371 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1372 }
1373
/**
 * Denoise one FRAME_SIZE block of a single channel.
 *
 * The block is passed through biquad() (a_hp/b_hp; presumably a high-pass
 * prefilter, per the names) and analysed into features. Unless
 * compute_frame_features() reports silence or the filter is disabled, the
 * network runs, pitch filtering is applied, and the smoothed per-band gains
 * are applied to the spectrum. The frame is then synthesised into `out`, and
 * the raw input block is saved in st->history.
 *
 * @return the network's voice-activity output, or 0 if the network was not run
 */
static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in,
                             int disabled)
{
    static const float a_hp[2] = {-1.99599, 0.99600};
    static const float b_hp[2] = {-2, 1};
    /* A band gain may drop to at most this fraction of its previous value per frame. */
    const float alpha = .6f;
    AVComplexFloat X[FREQ_SIZE];
    AVComplexFloat P[WINDOW_SIZE];
    float x[FRAME_SIZE];
    float Ex[NB_BANDS], Ep[NB_BANDS];
    LOCAL_ALIGNED_32(float, Exp, [NB_BANDS]);
    float features[NB_FEATURES];
    float g[NB_BANDS];
    float gf[FREQ_SIZE];
    float vad_prob = 0;
    int silence;

    biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
    silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);

    if (!(silence || disabled)) {
        compute_rnn(s, &st->rnn[0], g, &vad_prob, features);
        pitch_filter(X, P, Ex, Ep, Exp, g);

        for (int b = 0; b < NB_BANDS; b++) {
            g[b] = FFMAX(g[b], alpha * st->lastg[b]);
            st->lastg[b] = g[b];
        }

        interp_band_gain(gf, g);
        for (int k = 0; k < FREQ_SIZE; k++) {
            X[k].re *= gf[k];
            X[k].im *= gf[k];
        }
    }

    frame_synthesis(s, st, out, X);
    memcpy(st->history, in, FRAME_SIZE * sizeof(*st->history));

    return vad_prob;
}
1417
1418 typedef struct ThreadData {
1419 AVFrame *in, *out;
1420 } ThreadData;
1421
/*
 * Slice-thread worker: job `jobnr` of `nb_jobs` denoises the contiguous
 * channel range [first, last). Each channel has its own DenoiseState, so jobs
 * write disjoint state.
 */
static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
{
    AudioRNNContext *s = ctx->priv;
    const ThreadData *td = arg;
    const int nb_channels = td->out->channels;
    const int first = (nb_channels * jobnr) / nb_jobs;
    const int last  = (nb_channels * (jobnr + 1)) / nb_jobs;

    for (int ch = first; ch < last; ch++) {
        float *dst = (float *)td->out->extended_data[ch];
        const float *src = (const float *)td->in->extended_data[ch];

        rnnoise_channel(s, s->st + ch, dst, src, ctx->is_disabled);
    }

    return 0;
}
1440
/**
 * Denoise one block of audio and send it downstream.
 *
 * activate() obtains frames with
 * ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, ...). After EOF,
 * that call can return a final frame shorter than FRAME_SIZE. rnnoise_channel()
 * always reads FRAME_SIZE samples per channel, so a short frame is first
 * copied into a zero-padded FRAME_SIZE buffer (avoiding an out-of-bounds read).
 * The output is then trimmed back to the real sample count.
 *
 * Takes ownership of `in`.
 *
 * @return the result of ff_filter_frame(), or a negative AVERROR code
 */
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *ctx = inlink->dst;
    AVFilterLink *outlink = ctx->outputs[0];
    const int nb_samples = in->nb_samples;
    AVFrame *out = NULL;
    ThreadData td;

    if (nb_samples < FRAME_SIZE) {
        AVFrame *padded = ff_get_audio_buffer(inlink, FRAME_SIZE);

        if (!padded) {
            av_frame_free(&in);
            return AVERROR(ENOMEM);
        }
        av_samples_copy(padded->extended_data, in->extended_data, 0, 0,
                        nb_samples, in->channels, in->format);
        av_samples_set_silence(padded->extended_data, nb_samples,
                               FRAME_SIZE - nb_samples, in->channels, in->format);
        padded->pts = in->pts;
        av_frame_free(&in);
        in = padded;
    }

    out = ff_get_audio_buffer(outlink, FRAME_SIZE);
    if (!out) {
        av_frame_free(&in);
        return AVERROR(ENOMEM);
    }
    out->pts = in->pts;

    td.in = in; td.out = out;
    ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
                                                                   ff_filter_get_nb_threads(ctx)));

    /* Do not emit the zero padding added for a short trailing frame. */
    out->nb_samples = FFMIN(nb_samples, FRAME_SIZE);

    av_frame_free(&in);
    return ff_filter_frame(outlink, out);
}
1462
/*
 * Scheduler entry point. Processes one FRAME_SIZE block whenever the input
 * can supply one. Otherwise it forwards input status to the output, or
 * forwards the output's demand for data to the input.
 */
static int activate(AVFilterContext *ctx)
{
    AVFilterLink *inlink  = ctx->inputs[0];
    AVFilterLink *outlink = ctx->outputs[0];
    AVFrame *in = NULL;
    int ret;

    FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);

    ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
    if (ret > 0)
        return filter_frame(inlink, in);
    if (ret < 0)
        return ret;

    /* No block available yet. */
    FF_FILTER_FORWARD_STATUS(inlink, outlink);
    FF_FILTER_FORWARD_WANTED(outlink, inlink);

    return FFERROR_NOT_READY;
}
1484
/**
 * Load the model file named by the "model" option into *model.
 *
 * @return 0 on success, with *model set to a non-NULL model. Otherwise a
 *         negative AVERROR code: EINVAL if no file is given or it cannot be
 *         opened, the loader's error, or AVERROR_INVALIDDATA if the loader
 *         reported success without producing a model.
 */
static int open_model(AVFilterContext *ctx, RNNModel **model)
{
    AudioRNNContext *s = ctx->priv;
    int ret;
    FILE *f;

    if (!s->model_name) {
        av_log(ctx, AV_LOG_ERROR, "No model file specified\n");
        return AVERROR(EINVAL);
    }

    f = av_fopen_utf8(s->model_name, "r");
    if (!f) {
        av_log(ctx, AV_LOG_ERROR, "Failed to open model file: %s\n", s->model_name);
        return AVERROR(EINVAL);
    }

    ret = rnnoise_model_from_file(f, model);
    fclose(f);

    /* The previous `if (!*model || ret < 0) return ret;` returned 0 (success)
     * when no model was produced but ret >= 0, so a NULL model was accepted. */
    if (ret >= 0 && !*model)
        ret = AVERROR_INVALIDDATA;
    if (ret < 0) {
        av_log(ctx, AV_LOG_ERROR, "Failed to load model file: %s\n", s->model_name);
        return ret;
    }

    return 0;
}
1506
init(AVFilterContext * ctx)1507 static av_cold int init(AVFilterContext *ctx)
1508 {
1509 AudioRNNContext *s = ctx->priv;
1510 int ret;
1511
1512 s->fdsp = avpriv_float_dsp_alloc(0);
1513 if (!s->fdsp)
1514 return AVERROR(ENOMEM);
1515
1516 ret = open_model(ctx, &s->model[0]);
1517 if (ret < 0)
1518 return ret;
1519
1520 for (int i = 0; i < FRAME_SIZE; i++) {
1521 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1522 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1523 }
1524
1525 for (int i = 0; i < NB_BANDS; i++) {
1526 for (int j = 0; j < NB_BANDS; j++) {
1527 s->dct_table[j][i] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1528 if (j == 0)
1529 s->dct_table[j][i] *= sqrtf(.5);
1530 }
1531 }
1532
1533 return 0;
1534 }
1535
/*
 * Release model slot n and the per-channel GRU state buffers associated with
 * it. Pointers are reset to NULL, so calling this on an already-empty slot is
 * harmless as far as this function is concerned.
 */
static void free_model(AVFilterContext *ctx, int n)
{
    AudioRNNContext *s = ctx->priv;

    rnnoise_model_free(s->model[n]);
    s->model[n] = NULL;

    if (!s->st)
        return;

    for (int ch = 0; ch < s->channels; ch++) {
        RNNState *rnn = &s->st[ch].rnn[n];

        av_freep(&rnn->vad_gru_state);
        av_freep(&rnn->noise_gru_state);
        av_freep(&rnn->denoise_gru_state);
    }
}
1549
/**
 * Apply a runtime option change.
 *
 * For "model"/"m", the new model is loaded into slot 1 and swapped, together
 * with the per-channel RNN state, into slot 0. config_input() is then re-run
 * for it, and the old model (now in slot 1) is freed. If loading or
 * configuration fails, the previously active model stays in use, and
 * everything allocated for the rejected one is released instead of leaked.
 */
static int process_command(AVFilterContext *ctx, const char *cmd, const char *args,
                           char *res, int res_len, int flags)
{
    AudioRNNContext *s = ctx->priv;
    int ret;

    ret = ff_filter_process_command(ctx, cmd, args, res, res_len, flags);
    if (ret < 0)
        return ret;

    /* open_model() only depends on model_name. Other options (e.g. "mix")
     * need no reload, and reloading would swap in fresh per-channel RNN state. */
    if (strcmp(cmd, "model") && strcmp(cmd, "m"))
        return 0;

    ret = open_model(ctx, &s->model[1]);
    if (ret < 0) {
        /* Drop anything a failed load may have left in slot 1. */
        free_model(ctx, 1);
        return ret;
    }

    FFSWAP(RNNModel *, s->model[0], s->model[1]);
    for (int ch = 0; ch < s->channels; ch++)
        FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);

    ret = config_input(ctx->inputs[0]);
    if (ret < 0) {
        for (int ch = 0; ch < s->channels; ch++)
            FFSWAP(RNNState, s->st[ch].rnn[0], s->st[ch].rnn[1]);
        FFSWAP(RNNModel *, s->model[0], s->model[1]);
        /* Slot 1 now holds the rejected model plus whatever state
         * config_input() allocated for it; release both. */
        free_model(ctx, 1);
        return ret;
    }

    /* Slot 1 holds the previous model and its state. */
    free_model(ctx, 1);
    return 0;
}
1579
uninit(AVFilterContext * ctx)1580 static av_cold void uninit(AVFilterContext *ctx)
1581 {
1582 AudioRNNContext *s = ctx->priv;
1583
1584 av_freep(&s->fdsp);
1585 free_model(ctx, 0);
1586 for (int ch = 0; ch < s->channels && s->st; ch++) {
1587 av_tx_uninit(&s->st[ch].tx);
1588 av_tx_uninit(&s->st[ch].txi);
1589 }
1590 av_freep(&s->st);
1591 }
1592
1593 static const AVFilterPad inputs[] = {
1594 {
1595 .name = "default",
1596 .type = AVMEDIA_TYPE_AUDIO,
1597 .config_props = config_input,
1598 },
1599 { NULL }
1600 };
1601
1602 static const AVFilterPad outputs[] = {
1603 {
1604 .name = "default",
1605 .type = AVMEDIA_TYPE_AUDIO,
1606 },
1607 { NULL }
1608 };
1609
1610 #define OFFSET(x) offsetof(AudioRNNContext, x)
1611 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_RUNTIME_PARAM
1612
1613 static const AVOption arnndn_options[] = {
1614 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1615 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1616 { "mix", "set output vs input mix", OFFSET(mix), AV_OPT_TYPE_FLOAT, {.dbl=1.0},-1, 1, AF },
1617 { NULL }
1618 };
1619
1620 AVFILTER_DEFINE_CLASS(arnndn);
1621
1622 AVFilter ff_af_arnndn = {
1623 .name = "arnndn",
1624 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1625 .query_formats = query_formats,
1626 .priv_size = sizeof(AudioRNNContext),
1627 .priv_class = &arnndn_class,
1628 .activate = activate,
1629 .init = init,
1630 .uninit = uninit,
1631 .inputs = inputs,
1632 .outputs = outputs,
1633 .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_INTERNAL |
1634 AVFILTER_FLAG_SLICE_THREADS,
1635 .process_command = process_command,
1636 };
1637