1 /*
2 * Copyright (c) 2018 Gregor Richards
3 * Copyright (c) 2017 Mozilla
4 * Copyright (c) 2005-2009 Xiph.Org Foundation
5 * Copyright (c) 2007-2008 CSIRO
6 * Copyright (c) 2008-2011 Octasic Inc.
7 * Copyright (c) Jean-Marc Valin
8 * Copyright (c) 2019 Paul B Mahol
9 *
10 * Redistribution and use in source and binary forms, with or without
11 * modification, are permitted provided that the following conditions
12 * are met:
13 *
14 * - Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
16 *
17 * - Redistributions in binary form must reproduce the above copyright
18 * notice, this list of conditions and the following disclaimer in the
19 * documentation and/or other materials provided with the distribution.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR
25 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
29 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
30 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
31 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 */
33
34 #include <float.h>
35
36 #include "libavutil/avassert.h"
37 #include "libavutil/avstring.h"
38 #include "libavutil/float_dsp.h"
39 #include "libavutil/opt.h"
40 #include "libavutil/tx.h"
41 #include "avfilter.h"
42 #include "audio.h"
43 #include "filters.h"
44 #include "formats.h"
45
46 #define FRAME_SIZE_SHIFT 2
47 #define FRAME_SIZE (120<<FRAME_SIZE_SHIFT)
48 #define WINDOW_SIZE (2*FRAME_SIZE)
49 #define FREQ_SIZE (FRAME_SIZE + 1)
50
51 #define PITCH_MIN_PERIOD 60
52 #define PITCH_MAX_PERIOD 768
53 #define PITCH_FRAME_SIZE 960
54 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
55
56 #define SQUARE(x) ((x)*(x))
57
58 #define NB_BANDS 22
59
60 #define CEPS_MEM 8
61 #define NB_DELTA_CEPS 6
62
63 #define NB_FEATURES (NB_BANDS+3*NB_DELTA_CEPS+2)
64
65 #define WEIGHTS_SCALE (1.f/256)
66
67 #define MAX_NEURONS 128
68
69 #define ACTIVATION_TANH 0
70 #define ACTIVATION_SIGMOID 1
71 #define ACTIVATION_RELU 2
72
73 #define Q15ONE 1.0f
74
75 typedef struct DenseLayer {
76 const float *bias;
77 const float *input_weights;
78 int nb_inputs;
79 int nb_neurons;
80 int activation;
81 } DenseLayer;
82
83 typedef struct GRULayer {
84 const float *bias;
85 const float *input_weights;
86 const float *recurrent_weights;
87 int nb_inputs;
88 int nb_neurons;
89 int activation;
90 } GRULayer;
91
92 typedef struct RNNModel {
93 int input_dense_size;
94 const DenseLayer *input_dense;
95
96 int vad_gru_size;
97 const GRULayer *vad_gru;
98
99 int noise_gru_size;
100 const GRULayer *noise_gru;
101
102 int denoise_gru_size;
103 const GRULayer *denoise_gru;
104
105 int denoise_output_size;
106 const DenseLayer *denoise_output;
107
108 int vad_output_size;
109 const DenseLayer *vad_output;
110 } RNNModel;
111
112 typedef struct RNNState {
113 float *vad_gru_state;
114 float *noise_gru_state;
115 float *denoise_gru_state;
116 RNNModel *model;
117 } RNNState;
118
119 typedef struct DenoiseState {
120 float analysis_mem[FRAME_SIZE];
121 float cepstral_mem[CEPS_MEM][NB_BANDS];
122 int memid;
123 DECLARE_ALIGNED(32, float, synthesis_mem)[FRAME_SIZE];
124 float pitch_buf[PITCH_BUF_SIZE];
125 float pitch_enh_buf[PITCH_BUF_SIZE];
126 float last_gain;
127 int last_period;
128 float mem_hp_x[2];
129 float lastg[NB_BANDS];
130 RNNState rnn;
131 AVTXContext *tx, *txi;
132 av_tx_fn tx_fn, txi_fn;
133 } DenoiseState;
134
135 typedef struct AudioRNNContext {
136 const AVClass *class;
137
138 char *model_name;
139
140 int channels;
141 DenoiseState *st;
142
143 DECLARE_ALIGNED(32, float, window)[WINDOW_SIZE];
144 float dct_table[NB_BANDS*NB_BANDS];
145
146 RNNModel *model;
147
148 AVFloatDSPContext *fdsp;
149 } AudioRNNContext;
150
151 #define F_ACTIVATION_TANH 0
152 #define F_ACTIVATION_SIGMOID 1
153 #define F_ACTIVATION_RELU 2
154
rnnoise_model_free(RNNModel * model)155 static void rnnoise_model_free(RNNModel *model)
156 {
157 #define FREE_MAYBE(ptr) do { if (ptr) free(ptr); } while (0)
158 #define FREE_DENSE(name) do { \
159 if (model->name) { \
160 av_free((void *) model->name->input_weights); \
161 av_free((void *) model->name->bias); \
162 av_free((void *) model->name); \
163 } \
164 } while (0)
165 #define FREE_GRU(name) do { \
166 if (model->name) { \
167 av_free((void *) model->name->input_weights); \
168 av_free((void *) model->name->recurrent_weights); \
169 av_free((void *) model->name->bias); \
170 av_free((void *) model->name); \
171 } \
172 } while (0)
173
174 if (!model)
175 return;
176 FREE_DENSE(input_dense);
177 FREE_GRU(vad_gru);
178 FREE_GRU(noise_gru);
179 FREE_GRU(denoise_gru);
180 FREE_DENSE(denoise_output);
181 FREE_DENSE(vad_output);
182 av_free(model);
183 }
184
rnnoise_model_from_file(FILE * f)185 static RNNModel *rnnoise_model_from_file(FILE *f)
186 {
187 RNNModel *ret;
188 DenseLayer *input_dense;
189 GRULayer *vad_gru;
190 GRULayer *noise_gru;
191 GRULayer *denoise_gru;
192 DenseLayer *denoise_output;
193 DenseLayer *vad_output;
194 int in;
195
196 if (fscanf(f, "rnnoise-nu model file version %d\n", &in) != 1 || in != 1)
197 return NULL;
198
199 ret = av_calloc(1, sizeof(RNNModel));
200 if (!ret)
201 return NULL;
202
203 #define ALLOC_LAYER(type, name) \
204 name = av_calloc(1, sizeof(type)); \
205 if (!name) { \
206 rnnoise_model_free(ret); \
207 return NULL; \
208 } \
209 ret->name = name
210
211 ALLOC_LAYER(DenseLayer, input_dense);
212 ALLOC_LAYER(GRULayer, vad_gru);
213 ALLOC_LAYER(GRULayer, noise_gru);
214 ALLOC_LAYER(GRULayer, denoise_gru);
215 ALLOC_LAYER(DenseLayer, denoise_output);
216 ALLOC_LAYER(DenseLayer, vad_output);
217
218 #define INPUT_VAL(name) do { \
219 if (fscanf(f, "%d", &in) != 1 || in < 0 || in > 128) { \
220 rnnoise_model_free(ret); \
221 return NULL; \
222 } \
223 name = in; \
224 } while (0)
225
226 #define INPUT_ACTIVATION(name) do { \
227 int activation; \
228 INPUT_VAL(activation); \
229 switch (activation) { \
230 case F_ACTIVATION_SIGMOID: \
231 name = ACTIVATION_SIGMOID; \
232 break; \
233 case F_ACTIVATION_RELU: \
234 name = ACTIVATION_RELU; \
235 break; \
236 default: \
237 name = ACTIVATION_TANH; \
238 } \
239 } while (0)
240
241 #define INPUT_ARRAY(name, len) do { \
242 float *values = av_calloc((len), sizeof(float)); \
243 if (!values) { \
244 rnnoise_model_free(ret); \
245 return NULL; \
246 } \
247 name = values; \
248 for (int i = 0; i < (len); i++) { \
249 if (fscanf(f, "%d", &in) != 1) { \
250 rnnoise_model_free(ret); \
251 return NULL; \
252 } \
253 values[i] = in; \
254 } \
255 } while (0)
256
257 #define INPUT_ARRAY3(name, len0, len1, len2) do { \
258 float *values = av_calloc(FFALIGN((len0), 4) * FFALIGN((len1), 4) * (len2), sizeof(float)); \
259 if (!values) { \
260 rnnoise_model_free(ret); \
261 return NULL; \
262 } \
263 name = values; \
264 for (int k = 0; k < (len0); k++) { \
265 for (int i = 0; i < (len2); i++) { \
266 for (int j = 0; j < (len1); j++) { \
267 if (fscanf(f, "%d", &in) != 1) { \
268 rnnoise_model_free(ret); \
269 return NULL; \
270 } \
271 values[j * (len2) * FFALIGN((len0), 4) + i * FFALIGN((len0), 4) + k] = in; \
272 } \
273 } \
274 } \
275 } while (0)
276
277 #define INPUT_DENSE(name) do { \
278 INPUT_VAL(name->nb_inputs); \
279 INPUT_VAL(name->nb_neurons); \
280 ret->name ## _size = name->nb_neurons; \
281 INPUT_ACTIVATION(name->activation); \
282 INPUT_ARRAY(name->input_weights, name->nb_inputs * name->nb_neurons); \
283 INPUT_ARRAY(name->bias, name->nb_neurons); \
284 } while (0)
285
286 #define INPUT_GRU(name) do { \
287 INPUT_VAL(name->nb_inputs); \
288 INPUT_VAL(name->nb_neurons); \
289 ret->name ## _size = name->nb_neurons; \
290 INPUT_ACTIVATION(name->activation); \
291 INPUT_ARRAY3(name->input_weights, name->nb_inputs, name->nb_neurons, 3); \
292 INPUT_ARRAY3(name->recurrent_weights, name->nb_neurons, name->nb_neurons, 3); \
293 INPUT_ARRAY(name->bias, name->nb_neurons * 3); \
294 } while (0)
295
296 INPUT_DENSE(input_dense);
297 INPUT_GRU(vad_gru);
298 INPUT_GRU(noise_gru);
299 INPUT_GRU(denoise_gru);
300 INPUT_DENSE(denoise_output);
301 INPUT_DENSE(vad_output);
302
303 if (vad_output->nb_neurons != 1) {
304 rnnoise_model_free(ret);
305 return NULL;
306 }
307
308 return ret;
309 }
310
query_formats(AVFilterContext * ctx)311 static int query_formats(AVFilterContext *ctx)
312 {
313 AVFilterFormats *formats = NULL;
314 AVFilterChannelLayouts *layouts = NULL;
315 static const enum AVSampleFormat sample_fmts[] = {
316 AV_SAMPLE_FMT_FLTP,
317 AV_SAMPLE_FMT_NONE
318 };
319 int ret, sample_rates[] = { 48000, -1 };
320
321 formats = ff_make_format_list(sample_fmts);
322 if (!formats)
323 return AVERROR(ENOMEM);
324 ret = ff_set_common_formats(ctx, formats);
325 if (ret < 0)
326 return ret;
327
328 layouts = ff_all_channel_counts();
329 if (!layouts)
330 return AVERROR(ENOMEM);
331
332 ret = ff_set_common_channel_layouts(ctx, layouts);
333 if (ret < 0)
334 return ret;
335
336 formats = ff_make_format_list(sample_rates);
337 if (!formats)
338 return AVERROR(ENOMEM);
339 return ff_set_common_samplerates(ctx, formats);
340 }
341
config_input(AVFilterLink * inlink)342 static int config_input(AVFilterLink *inlink)
343 {
344 AVFilterContext *ctx = inlink->dst;
345 AudioRNNContext *s = ctx->priv;
346 int ret;
347
348 s->channels = inlink->channels;
349
350 s->st = av_calloc(s->channels, sizeof(DenoiseState));
351 if (!s->st)
352 return AVERROR(ENOMEM);
353
354 for (int i = 0; i < s->channels; i++) {
355 DenoiseState *st = &s->st[i];
356
357 st->rnn.model = s->model;
358 st->rnn.vad_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->vad_gru_size, 16));
359 st->rnn.noise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->noise_gru_size, 16));
360 st->rnn.denoise_gru_state = av_calloc(sizeof(float), FFALIGN(s->model->denoise_gru_size, 16));
361 if (!st->rnn.vad_gru_state ||
362 !st->rnn.noise_gru_state ||
363 !st->rnn.denoise_gru_state)
364 return AVERROR(ENOMEM);
365
366 ret = av_tx_init(&st->tx, &st->tx_fn, AV_TX_FLOAT_FFT, 0, WINDOW_SIZE, NULL, 0);
367 if (ret < 0)
368 return ret;
369
370 ret = av_tx_init(&st->txi, &st->txi_fn, AV_TX_FLOAT_FFT, 1, WINDOW_SIZE, NULL, 0);
371 if (ret < 0)
372 return ret;
373 }
374
375 return 0;
376 }
377
biquad(float * y,float mem[2],const float * x,const float * b,const float * a,int N)378 static void biquad(float *y, float mem[2], const float *x,
379 const float *b, const float *a, int N)
380 {
381 for (int i = 0; i < N; i++) {
382 float xi, yi;
383
384 xi = x[i];
385 yi = x[i] + mem[0];
386 mem[0] = mem[1] + (b[0]*xi - a[0]*yi);
387 mem[1] = (b[1]*xi - a[1]*yi);
388 y[i] = yi;
389 }
390 }
391
392 #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
393 #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst))))
394 #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) ))
395
forward_transform(DenoiseState * st,AVComplexFloat * out,const float * in)396 static void forward_transform(DenoiseState *st, AVComplexFloat *out, const float *in)
397 {
398 AVComplexFloat x[WINDOW_SIZE];
399 AVComplexFloat y[WINDOW_SIZE];
400
401 for (int i = 0; i < WINDOW_SIZE; i++) {
402 x[i].re = in[i];
403 x[i].im = 0;
404 }
405
406 st->tx_fn(st->tx, y, x, sizeof(float));
407
408 RNN_COPY(out, y, FREQ_SIZE);
409 }
410
inverse_transform(DenoiseState * st,float * out,const AVComplexFloat * in)411 static void inverse_transform(DenoiseState *st, float *out, const AVComplexFloat *in)
412 {
413 AVComplexFloat x[WINDOW_SIZE];
414 AVComplexFloat y[WINDOW_SIZE];
415
416 for (int i = 0; i < FREQ_SIZE; i++)
417 x[i] = in[i];
418
419 for (int i = FREQ_SIZE; i < WINDOW_SIZE; i++) {
420 x[i].re = x[WINDOW_SIZE - i].re;
421 x[i].im = -x[WINDOW_SIZE - i].im;
422 }
423
424 st->txi_fn(st->txi, y, x, sizeof(float));
425
426 for (int i = 0; i < WINDOW_SIZE; i++)
427 out[i] = y[i].re / WINDOW_SIZE;
428 }
429
430 static const uint8_t eband5ms[] = {
431 /*0 200 400 600 800 1k 1.2 1.4 1.6 2k 2.4 2.8 3.2 4k 4.8 5.6 6.8 8k 9.6 12k 15.6 20k*/
432 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100
433 };
434
compute_band_energy(float * bandE,const AVComplexFloat * X)435 static void compute_band_energy(float *bandE, const AVComplexFloat *X)
436 {
437 float sum[NB_BANDS] = {0};
438
439 for (int i = 0; i < NB_BANDS - 1; i++) {
440 int band_size;
441
442 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
443 for (int j = 0; j < band_size; j++) {
444 float tmp, frac = (float)j / band_size;
445
446 tmp = SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].re);
447 tmp += SQUARE(X[(eband5ms[i] << FRAME_SIZE_SHIFT) + j].im);
448 sum[i] += (1.f - frac) * tmp;
449 sum[i + 1] += frac * tmp;
450 }
451 }
452
453 sum[0] *= 2;
454 sum[NB_BANDS - 1] *= 2;
455
456 for (int i = 0; i < NB_BANDS; i++)
457 bandE[i] = sum[i];
458 }
459
compute_band_corr(float * bandE,const AVComplexFloat * X,const AVComplexFloat * P)460 static void compute_band_corr(float *bandE, const AVComplexFloat *X, const AVComplexFloat *P)
461 {
462 float sum[NB_BANDS] = { 0 };
463
464 for (int i = 0; i < NB_BANDS - 1; i++) {
465 int band_size;
466
467 band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
468 for (int j = 0; j < band_size; j++) {
469 float tmp, frac = (float)j / band_size;
470
471 tmp = X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].re;
472 tmp += X[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im * P[(eband5ms[i]<<FRAME_SIZE_SHIFT) + j].im;
473 sum[i] += (1 - frac) * tmp;
474 sum[i + 1] += frac * tmp;
475 }
476 }
477
478 sum[0] *= 2;
479 sum[NB_BANDS-1] *= 2;
480
481 for (int i = 0; i < NB_BANDS; i++)
482 bandE[i] = sum[i];
483 }
484
frame_analysis(AudioRNNContext * s,DenoiseState * st,AVComplexFloat * X,float * Ex,const float * in)485 static void frame_analysis(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, float *Ex, const float *in)
486 {
487 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
488
489 RNN_COPY(x, st->analysis_mem, FRAME_SIZE);
490 RNN_COPY(x + FRAME_SIZE, in, FRAME_SIZE);
491 RNN_COPY(st->analysis_mem, in, FRAME_SIZE);
492 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
493 forward_transform(st, X, x);
494 compute_band_energy(Ex, X);
495 }
496
frame_synthesis(AudioRNNContext * s,DenoiseState * st,float * out,const AVComplexFloat * y)497 static void frame_synthesis(AudioRNNContext *s, DenoiseState *st, float *out, const AVComplexFloat *y)
498 {
499 LOCAL_ALIGNED_32(float, x, [WINDOW_SIZE]);
500
501 inverse_transform(st, x, y);
502 s->fdsp->vector_fmul(x, x, s->window, WINDOW_SIZE);
503 s->fdsp->vector_fmac_scalar(x, st->synthesis_mem, 1.f, FRAME_SIZE);
504 RNN_COPY(out, x, FRAME_SIZE);
505 RNN_COPY(st->synthesis_mem, &x[FRAME_SIZE], FRAME_SIZE);
506 }
507
xcorr_kernel(const float * x,const float * y,float sum[4],int len)508 static inline void xcorr_kernel(const float *x, const float *y, float sum[4], int len)
509 {
510 float y_0, y_1, y_2, y_3 = 0;
511 int j;
512
513 y_0 = *y++;
514 y_1 = *y++;
515 y_2 = *y++;
516
517 for (j = 0; j < len - 3; j += 4) {
518 float tmp;
519
520 tmp = *x++;
521 y_3 = *y++;
522 sum[0] += tmp * y_0;
523 sum[1] += tmp * y_1;
524 sum[2] += tmp * y_2;
525 sum[3] += tmp * y_3;
526 tmp = *x++;
527 y_0 = *y++;
528 sum[0] += tmp * y_1;
529 sum[1] += tmp * y_2;
530 sum[2] += tmp * y_3;
531 sum[3] += tmp * y_0;
532 tmp = *x++;
533 y_1 = *y++;
534 sum[0] += tmp * y_2;
535 sum[1] += tmp * y_3;
536 sum[2] += tmp * y_0;
537 sum[3] += tmp * y_1;
538 tmp = *x++;
539 y_2 = *y++;
540 sum[0] += tmp * y_3;
541 sum[1] += tmp * y_0;
542 sum[2] += tmp * y_1;
543 sum[3] += tmp * y_2;
544 }
545
546 if (j++ < len) {
547 float tmp = *x++;
548
549 y_3 = *y++;
550 sum[0] += tmp * y_0;
551 sum[1] += tmp * y_1;
552 sum[2] += tmp * y_2;
553 sum[3] += tmp * y_3;
554 }
555
556 if (j++ < len) {
557 float tmp=*x++;
558
559 y_0 = *y++;
560 sum[0] += tmp * y_1;
561 sum[1] += tmp * y_2;
562 sum[2] += tmp * y_3;
563 sum[3] += tmp * y_0;
564 }
565
566 if (j < len) {
567 float tmp=*x++;
568
569 y_1 = *y++;
570 sum[0] += tmp * y_2;
571 sum[1] += tmp * y_3;
572 sum[2] += tmp * y_0;
573 sum[3] += tmp * y_1;
574 }
575 }
576
celt_inner_prod(const float * x,const float * y,int N)577 static inline float celt_inner_prod(const float *x,
578 const float *y, int N)
579 {
580 float xy = 0.f;
581
582 for (int i = 0; i < N; i++)
583 xy += x[i] * y[i];
584
585 return xy;
586 }
587
celt_pitch_xcorr(const float * x,const float * y,float * xcorr,int len,int max_pitch)588 static void celt_pitch_xcorr(const float *x, const float *y,
589 float *xcorr, int len, int max_pitch)
590 {
591 int i;
592
593 for (i = 0; i < max_pitch - 3; i += 4) {
594 float sum[4] = { 0, 0, 0, 0};
595
596 xcorr_kernel(x, y + i, sum, len);
597
598 xcorr[i] = sum[0];
599 xcorr[i + 1] = sum[1];
600 xcorr[i + 2] = sum[2];
601 xcorr[i + 3] = sum[3];
602 }
603 /* In case max_pitch isn't a multiple of 4, do non-unrolled version. */
604 for (; i < max_pitch; i++) {
605 xcorr[i] = celt_inner_prod(x, y + i, len);
606 }
607 }
608
celt_autocorr(const float * x,float * ac,const float * window,int overlap,int lag,int n)609 static int celt_autocorr(const float *x, /* in: [0...n-1] samples x */
610 float *ac, /* out: [0...lag-1] ac values */
611 const float *window,
612 int overlap,
613 int lag,
614 int n)
615 {
616 int fastN = n - lag;
617 int shift;
618 const float *xptr;
619 float xx[PITCH_BUF_SIZE>>1];
620
621 if (overlap == 0) {
622 xptr = x;
623 } else {
624 for (int i = 0; i < n; i++)
625 xx[i] = x[i];
626 for (int i = 0; i < overlap; i++) {
627 xx[i] = x[i] * window[i];
628 xx[n-i-1] = x[n-i-1] * window[i];
629 }
630 xptr = xx;
631 }
632
633 shift = 0;
634 celt_pitch_xcorr(xptr, xptr, ac, fastN, lag+1);
635
636 for (int k = 0; k <= lag; k++) {
637 float d = 0.f;
638
639 for (int i = k + fastN; i < n; i++)
640 d += xptr[i] * xptr[i-k];
641 ac[k] += d;
642 }
643
644 return shift;
645 }
646
celt_lpc(float * lpc,const float * ac,int p)647 static void celt_lpc(float *lpc, /* out: [0...p-1] LPC coefficients */
648 const float *ac, /* in: [0...p] autocorrelation values */
649 int p)
650 {
651 float r, error = ac[0];
652
653 RNN_CLEAR(lpc, p);
654 if (ac[0] != 0) {
655 for (int i = 0; i < p; i++) {
656 /* Sum up this iteration's reflection coefficient */
657 float rr = 0;
658 for (int j = 0; j < i; j++)
659 rr += (lpc[j] * ac[i - j]);
660 rr += ac[i + 1];
661 r = -rr/error;
662 /* Update LPC coefficients and total error */
663 lpc[i] = r;
664 for (int j = 0; j < (i + 1) >> 1; j++) {
665 float tmp1, tmp2;
666 tmp1 = lpc[j];
667 tmp2 = lpc[i-1-j];
668 lpc[j] = tmp1 + (r*tmp2);
669 lpc[i-1-j] = tmp2 + (r*tmp1);
670 }
671
672 error = error - (r * r *error);
673 /* Bail out once we get 30 dB gain */
674 if (error < .001f * ac[0])
675 break;
676 }
677 }
678 }
679
celt_fir5(const float * x,const float * num,float * y,int N,float * mem)680 static void celt_fir5(const float *x,
681 const float *num,
682 float *y,
683 int N,
684 float *mem)
685 {
686 float num0, num1, num2, num3, num4;
687 float mem0, mem1, mem2, mem3, mem4;
688
689 num0 = num[0];
690 num1 = num[1];
691 num2 = num[2];
692 num3 = num[3];
693 num4 = num[4];
694 mem0 = mem[0];
695 mem1 = mem[1];
696 mem2 = mem[2];
697 mem3 = mem[3];
698 mem4 = mem[4];
699
700 for (int i = 0; i < N; i++) {
701 float sum = x[i];
702
703 sum += (num0*mem0);
704 sum += (num1*mem1);
705 sum += (num2*mem2);
706 sum += (num3*mem3);
707 sum += (num4*mem4);
708 mem4 = mem3;
709 mem3 = mem2;
710 mem2 = mem1;
711 mem1 = mem0;
712 mem0 = x[i];
713 y[i] = sum;
714 }
715
716 mem[0] = mem0;
717 mem[1] = mem1;
718 mem[2] = mem2;
719 mem[3] = mem3;
720 mem[4] = mem4;
721 }
722
pitch_downsample(float * x[],float * x_lp,int len,int C)723 static void pitch_downsample(float *x[], float *x_lp,
724 int len, int C)
725 {
726 float ac[5];
727 float tmp=Q15ONE;
728 float lpc[4], mem[5]={0,0,0,0,0};
729 float lpc2[5];
730 float c1 = .8f;
731
732 for (int i = 1; i < len >> 1; i++)
733 x_lp[i] = .5f * (.5f * (x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]);
734 x_lp[0] = .5f * (.5f * (x[0][1])+x[0][0]);
735 if (C==2) {
736 for (int i = 1; i < len >> 1; i++)
737 x_lp[i] += (.5f * (.5f * (x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]));
738 x_lp[0] += .5f * (.5f * (x[1][1])+x[1][0]);
739 }
740
741 celt_autocorr(x_lp, ac, NULL, 0, 4, len>>1);
742
743 /* Noise floor -40 dB */
744 ac[0] *= 1.0001f;
745 /* Lag windowing */
746 for (int i = 1; i <= 4; i++) {
747 /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/
748 ac[i] -= ac[i]*(.008f*i)*(.008f*i);
749 }
750
751 celt_lpc(lpc, ac, 4);
752 for (int i = 0; i < 4; i++) {
753 tmp = .9f * tmp;
754 lpc[i] = (lpc[i] * tmp);
755 }
756 /* Add a zero */
757 lpc2[0] = lpc[0] + .8f;
758 lpc2[1] = lpc[1] + (c1 * lpc[0]);
759 lpc2[2] = lpc[2] + (c1 * lpc[1]);
760 lpc2[3] = lpc[3] + (c1 * lpc[2]);
761 lpc2[4] = (c1 * lpc[3]);
762 celt_fir5(x_lp, lpc2, x_lp, len>>1, mem);
763 }
764
dual_inner_prod(const float * x,const float * y01,const float * y02,int N,float * xy1,float * xy2)765 static inline void dual_inner_prod(const float *x, const float *y01, const float *y02,
766 int N, float *xy1, float *xy2)
767 {
768 float xy01 = 0, xy02 = 0;
769
770 for (int i = 0; i < N; i++) {
771 xy01 += (x[i] * y01[i]);
772 xy02 += (x[i] * y02[i]);
773 }
774
775 *xy1 = xy01;
776 *xy2 = xy02;
777 }
778
compute_pitch_gain(float xy,float xx,float yy)779 static float compute_pitch_gain(float xy, float xx, float yy)
780 {
781 return xy / sqrtf(1.f + xx * yy);
782 }
783
784 static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2};
remove_doubling(float * x,int maxperiod,int minperiod,int N,int * T0_,int prev_period,float prev_gain)785 static float remove_doubling(float *x, int maxperiod, int minperiod, int N,
786 int *T0_, int prev_period, float prev_gain)
787 {
788 int k, i, T, T0;
789 float g, g0;
790 float pg;
791 float xy,xx,yy,xy2;
792 float xcorr[3];
793 float best_xy, best_yy;
794 int offset;
795 int minperiod0;
796 float yy_lookup[PITCH_MAX_PERIOD+1];
797
798 minperiod0 = minperiod;
799 maxperiod /= 2;
800 minperiod /= 2;
801 *T0_ /= 2;
802 prev_period /= 2;
803 N /= 2;
804 x += maxperiod;
805 if (*T0_>=maxperiod)
806 *T0_=maxperiod-1;
807
808 T = T0 = *T0_;
809 dual_inner_prod(x, x, x-T0, N, &xx, &xy);
810 yy_lookup[0] = xx;
811 yy=xx;
812 for (i = 1; i <= maxperiod; i++) {
813 yy = yy+(x[-i] * x[-i])-(x[N-i] * x[N-i]);
814 yy_lookup[i] = FFMAX(0, yy);
815 }
816 yy = yy_lookup[T0];
817 best_xy = xy;
818 best_yy = yy;
819 g = g0 = compute_pitch_gain(xy, xx, yy);
820 /* Look for any pitch at T/k */
821 for (k = 2; k <= 15; k++) {
822 int T1, T1b;
823 float g1;
824 float cont=0;
825 float thresh;
826 T1 = (2*T0+k)/(2*k);
827 if (T1 < minperiod)
828 break;
829 /* Look for another strong correlation at T1b */
830 if (k==2)
831 {
832 if (T1+T0>maxperiod)
833 T1b = T0;
834 else
835 T1b = T0+T1;
836 } else
837 {
838 T1b = (2*second_check[k]*T0+k)/(2*k);
839 }
840 dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2);
841 xy = .5f * (xy + xy2);
842 yy = .5f * (yy_lookup[T1] + yy_lookup[T1b]);
843 g1 = compute_pitch_gain(xy, xx, yy);
844 if (FFABS(T1-prev_period)<=1)
845 cont = prev_gain;
846 else if (FFABS(T1-prev_period)<=2 && 5 * k * k < T0)
847 cont = prev_gain * .5f;
848 else
849 cont = 0;
850 thresh = FFMAX(.3f, (.7f * g0) - cont);
851 /* Bias against very high pitch (very short period) to avoid false-positives
852 due to short-term correlation */
853 if (T1<3*minperiod)
854 thresh = FFMAX(.4f, (.85f * g0) - cont);
855 else if (T1<2*minperiod)
856 thresh = FFMAX(.5f, (.9f * g0) - cont);
857 if (g1 > thresh)
858 {
859 best_xy = xy;
860 best_yy = yy;
861 T = T1;
862 g = g1;
863 }
864 }
865 best_xy = FFMAX(0, best_xy);
866 if (best_yy <= best_xy)
867 pg = Q15ONE;
868 else
869 pg = best_xy/(best_yy + 1);
870
871 for (k = 0; k < 3; k++)
872 xcorr[k] = celt_inner_prod(x, x-(T+k-1), N);
873 if ((xcorr[2]-xcorr[0]) > .7f * (xcorr[1]-xcorr[0]))
874 offset = 1;
875 else if ((xcorr[0]-xcorr[2]) > (.7f * (xcorr[1] - xcorr[2])))
876 offset = -1;
877 else
878 offset = 0;
879 if (pg > g)
880 pg = g;
881 *T0_ = 2*T+offset;
882
883 if (*T0_<minperiod0)
884 *T0_=minperiod0;
885 return pg;
886 }
887
find_best_pitch(float * xcorr,float * y,int len,int max_pitch,int * best_pitch)888 static void find_best_pitch(float *xcorr, float *y, int len,
889 int max_pitch, int *best_pitch)
890 {
891 float best_num[2];
892 float best_den[2];
893 float Syy = 1.f;
894
895 best_num[0] = -1;
896 best_num[1] = -1;
897 best_den[0] = 0;
898 best_den[1] = 0;
899 best_pitch[0] = 0;
900 best_pitch[1] = 1;
901
902 for (int j = 0; j < len; j++)
903 Syy += y[j] * y[j];
904
905 for (int i = 0; i < max_pitch; i++) {
906 if (xcorr[i]>0) {
907 float num;
908 float xcorr16;
909
910 xcorr16 = xcorr[i];
911 /* Considering the range of xcorr16, this should avoid both underflows
912 and overflows (inf) when squaring xcorr16 */
913 xcorr16 *= 1e-12f;
914 num = xcorr16 * xcorr16;
915 if ((num * best_den[1]) > (best_num[1] * Syy)) {
916 if ((num * best_den[0]) > (best_num[0] * Syy)) {
917 best_num[1] = best_num[0];
918 best_den[1] = best_den[0];
919 best_pitch[1] = best_pitch[0];
920 best_num[0] = num;
921 best_den[0] = Syy;
922 best_pitch[0] = i;
923 } else {
924 best_num[1] = num;
925 best_den[1] = Syy;
926 best_pitch[1] = i;
927 }
928 }
929 }
930 Syy += y[i+len]*y[i+len] - y[i] * y[i];
931 Syy = FFMAX(1, Syy);
932 }
933 }
934
pitch_search(const float * x_lp,float * y,int len,int max_pitch,int * pitch)935 static void pitch_search(const float *x_lp, float *y,
936 int len, int max_pitch, int *pitch)
937 {
938 int lag;
939 int best_pitch[2]={0,0};
940 int offset;
941
942 float x_lp4[WINDOW_SIZE];
943 float y_lp4[WINDOW_SIZE];
944 float xcorr[WINDOW_SIZE];
945
946 lag = len+max_pitch;
947
948 /* Downsample by 2 again */
949 for (int j = 0; j < len >> 2; j++)
950 x_lp4[j] = x_lp[2*j];
951 for (int j = 0; j < lag >> 2; j++)
952 y_lp4[j] = y[2*j];
953
954 /* Coarse search with 4x decimation */
955
956 celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2);
957
958 find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch);
959
960 /* Finer search with 2x decimation */
961 for (int i = 0; i < max_pitch >> 1; i++) {
962 float sum;
963 xcorr[i] = 0;
964 if (FFABS(i-2*best_pitch[0])>2 && FFABS(i-2*best_pitch[1])>2)
965 continue;
966 sum = celt_inner_prod(x_lp, y+i, len>>1);
967 xcorr[i] = FFMAX(-1, sum);
968 }
969
970 find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch);
971
972 /* Refine by pseudo-interpolation */
973 if (best_pitch[0] > 0 && best_pitch[0] < (max_pitch >> 1) - 1) {
974 float a, b, c;
975
976 a = xcorr[best_pitch[0] - 1];
977 b = xcorr[best_pitch[0]];
978 c = xcorr[best_pitch[0] + 1];
979 if (c - a > .7f * (b - a))
980 offset = 1;
981 else if (a - c > .7f * (b-c))
982 offset = -1;
983 else
984 offset = 0;
985 } else {
986 offset = 0;
987 }
988
989 *pitch = 2 * best_pitch[0] - offset;
990 }
991
dct(AudioRNNContext * s,float * out,const float * in)992 static void dct(AudioRNNContext *s, float *out, const float *in)
993 {
994 for (int i = 0; i < NB_BANDS; i++) {
995 float sum = 0.f;
996
997 for (int j = 0; j < NB_BANDS; j++) {
998 sum += in[j] * s->dct_table[j * NB_BANDS + i];
999 }
1000 out[i] = sum * sqrtf(2.f / 22);
1001 }
1002 }
1003
compute_frame_features(AudioRNNContext * s,DenoiseState * st,AVComplexFloat * X,AVComplexFloat * P,float * Ex,float * Ep,float * Exp,float * features,const float * in)1004 static int compute_frame_features(AudioRNNContext *s, DenoiseState *st, AVComplexFloat *X, AVComplexFloat *P,
1005 float *Ex, float *Ep, float *Exp, float *features, const float *in)
1006 {
1007 float E = 0;
1008 float *ceps_0, *ceps_1, *ceps_2;
1009 float spec_variability = 0;
1010 float Ly[NB_BANDS];
1011 LOCAL_ALIGNED_32(float, p, [WINDOW_SIZE]);
1012 float pitch_buf[PITCH_BUF_SIZE>>1];
1013 int pitch_index;
1014 float gain;
1015 float *(pre[1]);
1016 float tmp[NB_BANDS];
1017 float follow, logMax;
1018
1019 frame_analysis(s, st, X, Ex, in);
1020 RNN_MOVE(st->pitch_buf, &st->pitch_buf[FRAME_SIZE], PITCH_BUF_SIZE-FRAME_SIZE);
1021 RNN_COPY(&st->pitch_buf[PITCH_BUF_SIZE-FRAME_SIZE], in, FRAME_SIZE);
1022 pre[0] = &st->pitch_buf[0];
1023 pitch_downsample(pre, pitch_buf, PITCH_BUF_SIZE, 1);
1024 pitch_search(pitch_buf+(PITCH_MAX_PERIOD>>1), pitch_buf, PITCH_FRAME_SIZE,
1025 PITCH_MAX_PERIOD-3*PITCH_MIN_PERIOD, &pitch_index);
1026 pitch_index = PITCH_MAX_PERIOD-pitch_index;
1027
1028 gain = remove_doubling(pitch_buf, PITCH_MAX_PERIOD, PITCH_MIN_PERIOD,
1029 PITCH_FRAME_SIZE, &pitch_index, st->last_period, st->last_gain);
1030 st->last_period = pitch_index;
1031 st->last_gain = gain;
1032
1033 for (int i = 0; i < WINDOW_SIZE; i++)
1034 p[i] = st->pitch_buf[PITCH_BUF_SIZE-WINDOW_SIZE-pitch_index+i];
1035
1036 s->fdsp->vector_fmul(p, p, s->window, WINDOW_SIZE);
1037 forward_transform(st, P, p);
1038 compute_band_energy(Ep, P);
1039 compute_band_corr(Exp, X, P);
1040
1041 for (int i = 0; i < NB_BANDS; i++)
1042 Exp[i] = Exp[i] / sqrtf(.001f+Ex[i]*Ep[i]);
1043
1044 dct(s, tmp, Exp);
1045
1046 for (int i = 0; i < NB_DELTA_CEPS; i++)
1047 features[NB_BANDS+2*NB_DELTA_CEPS+i] = tmp[i];
1048
1049 features[NB_BANDS+2*NB_DELTA_CEPS] -= 1.3;
1050 features[NB_BANDS+2*NB_DELTA_CEPS+1] -= 0.9;
1051 features[NB_BANDS+3*NB_DELTA_CEPS] = .01*(pitch_index-300);
1052 logMax = -2;
1053 follow = -2;
1054
1055 for (int i = 0; i < NB_BANDS; i++) {
1056 Ly[i] = log10f(1e-2f + Ex[i]);
1057 Ly[i] = FFMAX(logMax-7, FFMAX(follow-1.5, Ly[i]));
1058 logMax = FFMAX(logMax, Ly[i]);
1059 follow = FFMAX(follow-1.5, Ly[i]);
1060 E += Ex[i];
1061 }
1062
1063 if (E < 0.04f) {
1064 /* If there's no audio, avoid messing up the state. */
1065 RNN_CLEAR(features, NB_FEATURES);
1066 return 1;
1067 }
1068
1069 dct(s, features, Ly);
1070 features[0] -= 12;
1071 features[1] -= 4;
1072 ceps_0 = st->cepstral_mem[st->memid];
1073 ceps_1 = (st->memid < 1) ? st->cepstral_mem[CEPS_MEM+st->memid-1] : st->cepstral_mem[st->memid-1];
1074 ceps_2 = (st->memid < 2) ? st->cepstral_mem[CEPS_MEM+st->memid-2] : st->cepstral_mem[st->memid-2];
1075
1076 for (int i = 0; i < NB_BANDS; i++)
1077 ceps_0[i] = features[i];
1078
1079 st->memid++;
1080 for (int i = 0; i < NB_DELTA_CEPS; i++) {
1081 features[i] = ceps_0[i] + ceps_1[i] + ceps_2[i];
1082 features[NB_BANDS+i] = ceps_0[i] - ceps_2[i];
1083 features[NB_BANDS+NB_DELTA_CEPS+i] = ceps_0[i] - 2*ceps_1[i] + ceps_2[i];
1084 }
1085 /* Spectral variability features. */
1086 if (st->memid == CEPS_MEM)
1087 st->memid = 0;
1088
1089 for (int i = 0; i < CEPS_MEM; i++) {
1090 float mindist = 1e15f;
1091 for (int j = 0; j < CEPS_MEM; j++) {
1092 float dist = 0.f;
1093 for (int k = 0; k < NB_BANDS; k++) {
1094 float tmp;
1095
1096 tmp = st->cepstral_mem[i][k] - st->cepstral_mem[j][k];
1097 dist += tmp*tmp;
1098 }
1099
1100 if (j != i)
1101 mindist = FFMIN(mindist, dist);
1102 }
1103
1104 spec_variability += mindist;
1105 }
1106
1107 features[NB_BANDS+3*NB_DELTA_CEPS+1] = spec_variability/CEPS_MEM-2.1;
1108
1109 return 0;
1110 }
1111
interp_band_gain(float * g,const float * bandE)1112 static void interp_band_gain(float *g, const float *bandE)
1113 {
1114 memset(g, 0, sizeof(*g) * FREQ_SIZE);
1115
1116 for (int i = 0; i < NB_BANDS - 1; i++) {
1117 const int band_size = (eband5ms[i + 1] - eband5ms[i]) << FRAME_SIZE_SHIFT;
1118
1119 for (int j = 0; j < band_size; j++) {
1120 float frac = (float)j / band_size;
1121
1122 g[(eband5ms[i] << FRAME_SIZE_SHIFT) + j] = (1.f - frac) * bandE[i] + frac * bandE[i + 1];
1123 }
1124 }
1125 }
1126
pitch_filter(AVComplexFloat * X,const AVComplexFloat * P,const float * Ex,const float * Ep,const float * Exp,const float * g)1127 static void pitch_filter(AVComplexFloat *X, const AVComplexFloat *P, const float *Ex, const float *Ep,
1128 const float *Exp, const float *g)
1129 {
1130 float newE[NB_BANDS];
1131 float r[NB_BANDS];
1132 float norm[NB_BANDS];
1133 float rf[FREQ_SIZE] = {0};
1134 float normf[FREQ_SIZE]={0};
1135
1136 for (int i = 0; i < NB_BANDS; i++) {
1137 if (Exp[i]>g[i]) r[i] = 1;
1138 else r[i] = SQUARE(Exp[i])*(1-SQUARE(g[i]))/(.001 + SQUARE(g[i])*(1-SQUARE(Exp[i])));
1139 r[i] = sqrtf(av_clipf(r[i], 0, 1));
1140 r[i] *= sqrtf(Ex[i]/(1e-8+Ep[i]));
1141 }
1142 interp_band_gain(rf, r);
1143 for (int i = 0; i < FREQ_SIZE; i++) {
1144 X[i].re += rf[i]*P[i].re;
1145 X[i].im += rf[i]*P[i].im;
1146 }
1147 compute_band_energy(newE, X);
1148 for (int i = 0; i < NB_BANDS; i++) {
1149 norm[i] = sqrtf(Ex[i] / (1e-8+newE[i]));
1150 }
1151 interp_band_gain(normf, norm);
1152 for (int i = 0; i < FREQ_SIZE; i++) {
1153 X[i].re *= normf[i];
1154 X[i].im *= normf[i];
1155 }
1156 }
1157
1158 static const float tansig_table[201] = {
1159 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f,
1160 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f,
1161 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f,
1162 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f,
1163 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f,
1164 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f,
1165 0.833655f, 0.845456f, 0.856485f, 0.866784f, 0.876393f,
1166 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f,
1167 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f,
1168 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f,
1169 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f,
1170 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f,
1171 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f,
1172 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f,
1173 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f,
1174 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f,
1175 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f,
1176 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f,
1177 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f,
1178 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f,
1179 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f,
1180 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f,
1181 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f,
1182 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f,
1183 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f,
1184 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f,
1185 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f,
1186 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f,
1187 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f,
1188 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f,
1189 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f,
1190 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f,
1191 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f,
1192 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f,
1193 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f,
1194 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f,
1195 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1196 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f,
1197 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1198 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f,
1199 1.000000f,
1200 };
1201
tansig_approx(float x)1202 static inline float tansig_approx(float x)
1203 {
1204 float y, dy;
1205 float sign=1;
1206 int i;
1207
1208 /* Tests are reversed to catch NaNs */
1209 if (!(x<8))
1210 return 1;
1211 if (!(x>-8))
1212 return -1;
1213 /* Another check in case of -ffast-math */
1214
1215 if (isnan(x))
1216 return 0;
1217
1218 if (x < 0) {
1219 x=-x;
1220 sign=-1;
1221 }
1222 i = (int)floor(.5f+25*x);
1223 x -= .04f*i;
1224 y = tansig_table[i];
1225 dy = 1-y*y;
1226 y = y + x*dy*(1 - y*x);
1227 return sign*y;
1228 }
1229
sigmoid_approx(float x)1230 static inline float sigmoid_approx(float x)
1231 {
1232 return .5f + .5f*tansig_approx(.5f*x);
1233 }
1234
compute_dense(const DenseLayer * layer,float * output,const float * input)1235 static void compute_dense(const DenseLayer *layer, float *output, const float *input)
1236 {
1237 const int N = layer->nb_neurons, M = layer->nb_inputs, stride = N;
1238
1239 for (int i = 0; i < N; i++) {
1240 /* Compute update gate. */
1241 float sum = layer->bias[i];
1242
1243 for (int j = 0; j < M; j++)
1244 sum += layer->input_weights[j * stride + i] * input[j];
1245
1246 output[i] = WEIGHTS_SCALE * sum;
1247 }
1248
1249 if (layer->activation == ACTIVATION_SIGMOID) {
1250 for (int i = 0; i < N; i++)
1251 output[i] = sigmoid_approx(output[i]);
1252 } else if (layer->activation == ACTIVATION_TANH) {
1253 for (int i = 0; i < N; i++)
1254 output[i] = tansig_approx(output[i]);
1255 } else if (layer->activation == ACTIVATION_RELU) {
1256 for (int i = 0; i < N; i++)
1257 output[i] = FFMAX(0, output[i]);
1258 } else {
1259 av_assert0(0);
1260 }
1261 }
1262
compute_gru(AudioRNNContext * s,const GRULayer * gru,float * state,const float * input)1263 static void compute_gru(AudioRNNContext *s, const GRULayer *gru, float *state, const float *input)
1264 {
1265 LOCAL_ALIGNED_32(float, z, [MAX_NEURONS]);
1266 LOCAL_ALIGNED_32(float, r, [MAX_NEURONS]);
1267 LOCAL_ALIGNED_32(float, h, [MAX_NEURONS]);
1268 const int M = gru->nb_inputs;
1269 const int N = gru->nb_neurons;
1270 const int AN = FFALIGN(N, 4);
1271 const int AM = FFALIGN(M, 4);
1272 const int stride = 3 * AN, istride = 3 * AM;
1273
1274 for (int i = 0; i < N; i++) {
1275 /* Compute update gate. */
1276 float sum = gru->bias[i];
1277
1278 sum += s->fdsp->scalarproduct_float(gru->input_weights + i * istride, input, AM);
1279 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + i * stride, state, AN);
1280 z[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1281 }
1282
1283 for (int i = 0; i < N; i++) {
1284 /* Compute reset gate. */
1285 float sum = gru->bias[N + i];
1286
1287 sum += s->fdsp->scalarproduct_float(gru->input_weights + AM + i * istride, input, AM);
1288 sum += s->fdsp->scalarproduct_float(gru->recurrent_weights + AN + i * stride, state, AN);
1289 r[i] = sigmoid_approx(WEIGHTS_SCALE * sum);
1290 }
1291
1292 for (int i = 0; i < N; i++) {
1293 /* Compute output. */
1294 float sum = gru->bias[2 * N + i];
1295
1296 sum += s->fdsp->scalarproduct_float(gru->input_weights + 2 * AM + i * istride, input, AM);
1297 for (int j = 0; j < N; j++)
1298 sum += gru->recurrent_weights[2 * AN + i * stride + j] * state[j] * r[j];
1299
1300 if (gru->activation == ACTIVATION_SIGMOID)
1301 sum = sigmoid_approx(WEIGHTS_SCALE * sum);
1302 else if (gru->activation == ACTIVATION_TANH)
1303 sum = tansig_approx(WEIGHTS_SCALE * sum);
1304 else if (gru->activation == ACTIVATION_RELU)
1305 sum = FFMAX(0, WEIGHTS_SCALE * sum);
1306 else
1307 av_assert0(0);
1308 h[i] = z[i] * state[i] + (1.f - z[i]) * sum;
1309 }
1310
1311 RNN_COPY(state, h, N);
1312 }
1313
1314 #define INPUT_SIZE 42
1315
compute_rnn(AudioRNNContext * s,RNNState * rnn,float * gains,float * vad,const float * input)1316 static void compute_rnn(AudioRNNContext *s, RNNState *rnn, float *gains, float *vad, const float *input)
1317 {
1318 LOCAL_ALIGNED_32(float, dense_out, [MAX_NEURONS]);
1319 LOCAL_ALIGNED_32(float, noise_input, [MAX_NEURONS * 3]);
1320 LOCAL_ALIGNED_32(float, denoise_input, [MAX_NEURONS * 3]);
1321
1322 compute_dense(rnn->model->input_dense, dense_out, input);
1323 compute_gru(s, rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
1324 compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
1325
1326 for (int i = 0; i < rnn->model->input_dense_size; i++)
1327 noise_input[i] = dense_out[i];
1328 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1329 noise_input[i + rnn->model->input_dense_size] = rnn->vad_gru_state[i];
1330 for (int i = 0; i < INPUT_SIZE; i++)
1331 noise_input[i + rnn->model->input_dense_size + rnn->model->vad_gru_size] = input[i];
1332
1333 compute_gru(s, rnn->model->noise_gru, rnn->noise_gru_state, noise_input);
1334
1335 for (int i = 0; i < rnn->model->vad_gru_size; i++)
1336 denoise_input[i] = rnn->vad_gru_state[i];
1337 for (int i = 0; i < rnn->model->noise_gru_size; i++)
1338 denoise_input[i + rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
1339 for (int i = 0; i < INPUT_SIZE; i++)
1340 denoise_input[i + rnn->model->vad_gru_size + rnn->model->noise_gru_size] = input[i];
1341
1342 compute_gru(s, rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
1343 compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
1344 }
1345
rnnoise_channel(AudioRNNContext * s,DenoiseState * st,float * out,const float * in)1346 static float rnnoise_channel(AudioRNNContext *s, DenoiseState *st, float *out, const float *in)
1347 {
1348 AVComplexFloat X[FREQ_SIZE];
1349 AVComplexFloat P[WINDOW_SIZE];
1350 float x[FRAME_SIZE];
1351 float Ex[NB_BANDS], Ep[NB_BANDS];
1352 float Exp[NB_BANDS];
1353 float features[NB_FEATURES];
1354 float g[NB_BANDS];
1355 float gf[FREQ_SIZE];
1356 float vad_prob = 0;
1357 static const float a_hp[2] = {-1.99599, 0.99600};
1358 static const float b_hp[2] = {-2, 1};
1359 int silence;
1360
1361 biquad(x, st->mem_hp_x, in, b_hp, a_hp, FRAME_SIZE);
1362 silence = compute_frame_features(s, st, X, P, Ex, Ep, Exp, features, x);
1363
1364 if (!silence) {
1365 compute_rnn(s, &st->rnn, g, &vad_prob, features);
1366 pitch_filter(X, P, Ex, Ep, Exp, g);
1367 for (int i = 0; i < NB_BANDS; i++) {
1368 float alpha = .6f;
1369
1370 g[i] = FFMAX(g[i], alpha * st->lastg[i]);
1371 st->lastg[i] = g[i];
1372 }
1373
1374 interp_band_gain(gf, g);
1375
1376 for (int i = 0; i < FREQ_SIZE; i++) {
1377 X[i].re *= gf[i];
1378 X[i].im *= gf[i];
1379 }
1380 }
1381
1382 frame_synthesis(s, st, out, X);
1383
1384 return vad_prob;
1385 }
1386
1387 typedef struct ThreadData {
1388 AVFrame *in, *out;
1389 } ThreadData;
1390
rnnoise_channels(AVFilterContext * ctx,void * arg,int jobnr,int nb_jobs)1391 static int rnnoise_channels(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
1392 {
1393 AudioRNNContext *s = ctx->priv;
1394 ThreadData *td = arg;
1395 AVFrame *in = td->in;
1396 AVFrame *out = td->out;
1397 const int start = (out->channels * jobnr) / nb_jobs;
1398 const int end = (out->channels * (jobnr+1)) / nb_jobs;
1399
1400 for (int ch = start; ch < end; ch++) {
1401 rnnoise_channel(s, &s->st[ch],
1402 (float *)out->extended_data[ch],
1403 (const float *)in->extended_data[ch]);
1404 }
1405
1406 return 0;
1407 }
1408
filter_frame(AVFilterLink * inlink,AVFrame * in)1409 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
1410 {
1411 AVFilterContext *ctx = inlink->dst;
1412 AVFilterLink *outlink = ctx->outputs[0];
1413 AVFrame *out = NULL;
1414 ThreadData td;
1415
1416 out = ff_get_audio_buffer(outlink, FRAME_SIZE);
1417 if (!out) {
1418 av_frame_free(&in);
1419 return AVERROR(ENOMEM);
1420 }
1421 out->pts = in->pts;
1422
1423 td.in = in; td.out = out;
1424 ctx->internal->execute(ctx, rnnoise_channels, &td, NULL, FFMIN(outlink->channels,
1425 ff_filter_get_nb_threads(ctx)));
1426
1427 av_frame_free(&in);
1428 return ff_filter_frame(outlink, out);
1429 }
1430
activate(AVFilterContext * ctx)1431 static int activate(AVFilterContext *ctx)
1432 {
1433 AVFilterLink *inlink = ctx->inputs[0];
1434 AVFilterLink *outlink = ctx->outputs[0];
1435 AVFrame *in = NULL;
1436 int ret;
1437
1438 FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
1439
1440 ret = ff_inlink_consume_samples(inlink, FRAME_SIZE, FRAME_SIZE, &in);
1441 if (ret < 0)
1442 return ret;
1443
1444 if (ret > 0)
1445 return filter_frame(inlink, in);
1446
1447 FF_FILTER_FORWARD_STATUS(inlink, outlink);
1448 FF_FILTER_FORWARD_WANTED(outlink, inlink);
1449
1450 return FFERROR_NOT_READY;
1451 }
1452
init(AVFilterContext * ctx)1453 static av_cold int init(AVFilterContext *ctx)
1454 {
1455 AudioRNNContext *s = ctx->priv;
1456 FILE *f;
1457
1458 s->fdsp = avpriv_float_dsp_alloc(0);
1459 if (!s->fdsp)
1460 return AVERROR(ENOMEM);
1461
1462 if (!s->model_name)
1463 return AVERROR(EINVAL);
1464 f = av_fopen_utf8(s->model_name, "r");
1465 if (!f)
1466 return AVERROR(EINVAL);
1467
1468 s->model = rnnoise_model_from_file(f);
1469 fclose(f);
1470 if (!s->model)
1471 return AVERROR(EINVAL);
1472
1473 for (int i = 0; i < FRAME_SIZE; i++) {
1474 s->window[i] = sin(.5*M_PI*sin(.5*M_PI*(i+.5)/FRAME_SIZE) * sin(.5*M_PI*(i+.5)/FRAME_SIZE));
1475 s->window[WINDOW_SIZE - 1 - i] = s->window[i];
1476 }
1477
1478 for (int i = 0; i < NB_BANDS; i++) {
1479 for (int j = 0; j < NB_BANDS; j++) {
1480 s->dct_table[i*NB_BANDS + j] = cosf((i + .5f) * j * M_PI / NB_BANDS);
1481 if (j == 0)
1482 s->dct_table[i*NB_BANDS + j] *= sqrtf(.5);
1483 }
1484 }
1485
1486 return 0;
1487 }
1488
uninit(AVFilterContext * ctx)1489 static av_cold void uninit(AVFilterContext *ctx)
1490 {
1491 AudioRNNContext *s = ctx->priv;
1492
1493 av_freep(&s->fdsp);
1494 rnnoise_model_free(s->model);
1495 s->model = NULL;
1496
1497 if (s->st) {
1498 for (int ch = 0; ch < s->channels; ch++) {
1499 av_freep(&s->st[ch].rnn.vad_gru_state);
1500 av_freep(&s->st[ch].rnn.noise_gru_state);
1501 av_freep(&s->st[ch].rnn.denoise_gru_state);
1502 av_tx_uninit(&s->st[ch].tx);
1503 av_tx_uninit(&s->st[ch].txi);
1504 }
1505 }
1506 av_freep(&s->st);
1507 }
1508
1509 static const AVFilterPad inputs[] = {
1510 {
1511 .name = "default",
1512 .type = AVMEDIA_TYPE_AUDIO,
1513 .config_props = config_input,
1514 },
1515 { NULL }
1516 };
1517
1518 static const AVFilterPad outputs[] = {
1519 {
1520 .name = "default",
1521 .type = AVMEDIA_TYPE_AUDIO,
1522 },
1523 { NULL }
1524 };
1525
1526 #define OFFSET(x) offsetof(AudioRNNContext, x)
1527 #define AF AV_OPT_FLAG_AUDIO_PARAM|AV_OPT_FLAG_FILTERING_PARAM
1528
1529 static const AVOption arnndn_options[] = {
1530 { "model", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1531 { "m", "set model name", OFFSET(model_name), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, AF },
1532 { NULL }
1533 };
1534
1535 AVFILTER_DEFINE_CLASS(arnndn);
1536
1537 AVFilter ff_af_arnndn = {
1538 .name = "arnndn",
1539 .description = NULL_IF_CONFIG_SMALL("Reduce noise from speech using Recurrent Neural Networks."),
1540 .query_formats = query_formats,
1541 .priv_size = sizeof(AudioRNNContext),
1542 .priv_class = &arnndn_class,
1543 .activate = activate,
1544 .init = init,
1545 .uninit = uninit,
1546 .inputs = inputs,
1547 .outputs = outputs,
1548 .flags = AVFILTER_FLAG_SLICE_THREADS,
1549 };
1550