/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <float.h>
#include <limits.h>
#include <math.h>

#include "config/aom_scale_rtcd.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/mathutils.h"
#include "aom_dsp/psnr.h"
#include "aom_mem/aom_mem.h"
#include "aom_ports/mem.h"
#include "av1/common/av1_common_int.h"
#include "av1/common/quant_common.h"
#include "av1/common/restoration.h"

#include "av1/encoder/av1_quantize.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"

// Number of Wiener iterations
#define NUM_WIENER_ITERS 5

// Penalty factor for use of dual sgr
#define DUAL_SGR_PENALTY_MULT 0.01

// Working precision for Wiener filter coefficients
#define WIENER_TAP_SCALE_FACTOR ((int64_t)1 << 16)

#define SGRPROJ_EP_GRP1_START_IDX 0
#define SGRPROJ_EP_GRP1_END_IDX 9
#define SGRPROJ_EP_GRP1_SEARCH_COUNT 4
#define SGRPROJ_EP_GRP2_3_SEARCH_COUNT 2
static const int sgproj_ep_grp1_seed[SGRPROJ_EP_GRP1_SEARCH_COUNT] = { 0, 3, 6,
                                                                       9 };
static const int sgproj_ep_grp2_3[SGRPROJ_EP_GRP2_3_SEARCH_COUNT][14] = {
  { 10, 10, 11, 11, 12, 12, 13, 13, 13, 13, -1, -1, -1, -1 },
  { 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15 }
};

#if DEBUG_LR_COSTING
RestorationUnitInfo lr_ref_params[RESTORE_TYPES][MAX_MB_PLANE]
                                 [MAX_LR_UNITS_W * MAX_LR_UNITS_H];
#endif  // DEBUG_LR_COSTING

typedef int64_t (*sse_extractor_type)(const YV12_BUFFER_CONFIG *a,
                                      const YV12_BUFFER_CONFIG *b);
typedef int64_t (*sse_part_extractor_type)(const YV12_BUFFER_CONFIG *a,
                                           const YV12_BUFFER_CONFIG *b,
                                           int hstart, int width, int vstart,
                                           int height);
typedef uint64_t (*var_part_extractor_type)(const YV12_BUFFER_CONFIG *a,
                                            int hstart, int width, int vstart,
                                            int height);

#if CONFIG_AV1_HIGHBITDEPTH
#define NUM_EXTRACTORS (3 * (1 + 1))
#else
#define NUM_EXTRACTORS 3
#endif
static const sse_part_extractor_type sse_part_extractors[NUM_EXTRACTORS] = {
  aom_get_y_sse_part,        aom_get_u_sse_part,
  aom_get_v_sse_part,
#if CONFIG_AV1_HIGHBITDEPTH
  aom_highbd_get_y_sse_part, aom_highbd_get_u_sse_part,
  aom_highbd_get_v_sse_part,
#endif
};
static const var_part_extractor_type var_part_extractors[NUM_EXTRACTORS] = {
  aom_get_y_var,        aom_get_u_var,        aom_get_v_var,
#if CONFIG_AV1_HIGHBITDEPTH
  aom_highbd_get_y_var, aom_highbd_get_u_var, aom_highbd_get_v_var,
#endif
};

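// Computes the sum of squared errors between the source and the filtered
// frame over a single restoration unit, selecting the extractor for the
// given plane and bit depth.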
static int64_t sse_restoration_unit(const RestorationTileLimits *limits,
                                    const YV12_BUFFER_CONFIG *src,
                                    const YV12_BUFFER_CONFIG *dst, int plane,
                                    int highbd) {
  return sse_part_extractors[3 * highbd + plane](
      src, dst, limits->h_start, limits->h_end - limits->h_start,
      limits->v_start, limits->v_end - limits->v_start);
}

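// Computes the variance of the source over a single restoration unit,
// again selecting the extractor by plane and bit depth.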
static uint64_t var_restoration_unit(const RestorationTileLimits *limits,
                                     const YV12_BUFFER_CONFIG *src, int plane,
                                     int highbd) {
  return var_part_extractors[3 * highbd + plane](
      src, limits->h_start, limits->h_end - limits->h_start, limits->v_start,
      limits->v_end - limits->v_start);
}

typedef struct {
  const YV12_BUFFER_CONFIG *src;
  YV12_BUFFER_CONFIG *dst;

  const AV1_COMMON *cm;
  const MACROBLOCK *x;
  int plane;
  int plane_w;
  int plane_h;
  RestUnitSearchInfo *rusi;

  // Speed features
  const LOOP_FILTER_SPEED_FEATURES *lpf_sf;

  uint8_t *dgd_buffer;
  int dgd_stride;
  const uint8_t *src_buffer;
  int src_stride;

  // SSE values for each restoration mode for the current RU
  // These are saved by each search function for use in search_switchable()
  int64_t sse[RESTORE_SWITCHABLE_TYPES];

  // This flag will be set based on the speed feature
  // 'prune_sgr_based_on_wiener'. 0 implies no pruning and 1 implies pruning.
  uint8_t skip_sgr_eval;

  // Total rate and distortion so far for each restoration type
  // These are initialised by reset_rsc in search_rest_type
  int64_t total_sse[RESTORE_TYPES];
  int64_t total_bits[RESTORE_TYPES];

  // Reference parameters for delta-coding
  //
  // For each restoration type, we need to store the latest parameter set which
  // has been used, so that we can properly cost up the next parameter set.
  // Note that we have two sets of these - one for the single-restoration-mode
  // search (ie, frame_restoration_type = RESTORE_WIENER or RESTORE_SGRPROJ)
  // and one for the switchable mode. This is because these two cases can lead
  // to different sets of parameters being signaled, but we don't know which
  // we will pick for sure until the end of the search process.
  WienerInfo ref_wiener;
  SgrprojInfo ref_sgrproj;
  WienerInfo switchable_ref_wiener;
  SgrprojInfo switchable_ref_sgrproj;

  // Buffers used to hold dgd-avg and src-avg data respectively during SIMD
  // call of Wiener filter.
  int16_t *dgd_avg;
  int16_t *src_avg;
} RestSearchCtxt;

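// Resets the delta-coding reference parameters to their defaults at the
// start of each tile, so the first unit in a tile is costed against the
// default Wiener/sgrproj parameters.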
static AOM_INLINE void rsc_on_tile(void *priv) {
  RestSearchCtxt *rsc = (RestSearchCtxt *)priv;
  set_default_wiener(&rsc->ref_wiener);
  set_default_sgrproj(&rsc->ref_sgrproj);
  set_default_wiener(&rsc->switchable_ref_wiener);
  set_default_sgrproj(&rsc->switchable_ref_sgrproj);
}

static AOM_INLINE void reset_rsc(RestSearchCtxt *rsc) {
  memset(rsc->total_sse, 0, sizeof(rsc->total_sse));
  memset(rsc->total_bits, 0, sizeof(rsc->total_bits));
}

static AOM_INLINE void init_rsc(const YV12_BUFFER_CONFIG *src,
                                const AV1_COMMON *cm, const MACROBLOCK *x,
                                const LOOP_FILTER_SPEED_FEATURES *lpf_sf,
                                int plane, RestUnitSearchInfo *rusi,
                                YV12_BUFFER_CONFIG *dst, RestSearchCtxt *rsc) {
  rsc->src = src;
  rsc->dst = dst;
  rsc->cm = cm;
  rsc->x = x;
  rsc->plane = plane;
  rsc->rusi = rusi;
  rsc->lpf_sf = lpf_sf;

  const YV12_BUFFER_CONFIG *dgd = &cm->cur_frame->buf;
  const int is_uv = plane != AOM_PLANE_Y;
  int plane_w, plane_h;
  av1_get_upsampled_plane_size(cm, is_uv, &plane_w, &plane_h);
  assert(plane_w == src->crop_widths[is_uv]);
  assert(plane_h == src->crop_heights[is_uv]);
  assert(src->crop_widths[is_uv] == dgd->crop_widths[is_uv]);
  assert(src->crop_heights[is_uv] == dgd->crop_heights[is_uv]);

  rsc->plane_w = plane_w;
  rsc->plane_h = plane_h;
  rsc->src_buffer = src->buffers[plane];
  rsc->src_stride = src->strides[is_uv];
  rsc->dgd_buffer = dgd->buffers[plane];
  rsc->dgd_stride = dgd->strides[is_uv];
}

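// Applies the restoration filter described by 'rui' to a single unit and
// returns the resulting SSE against the source.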
static int64_t try_restoration_unit(const RestSearchCtxt *rsc,
                                    const RestorationTileLimits *limits,
                                    const RestorationUnitInfo *rui) {
  const AV1_COMMON *const cm = rsc->cm;
  const int plane = rsc->plane;
  const int is_uv = plane > 0;
  const RestorationInfo *rsi = &cm->rst_info[plane];
  RestorationLineBuffers rlbs;
  const int bit_depth = cm->seq_params->bit_depth;
  const int highbd = cm->seq_params->use_highbitdepth;

  const YV12_BUFFER_CONFIG *fts = &cm->cur_frame->buf;
  // TODO(yunqing): For now, only use optimized LR filter in decoder. Can be
  // also used in encoder.
  const int optimized_lr = 0;

  av1_loop_restoration_filter_unit(
      limits, rui, &rsi->boundaries, &rlbs, rsc->plane_w, rsc->plane_h,
      is_uv && cm->seq_params->subsampling_x,
      is_uv && cm->seq_params->subsampling_y, highbd, bit_depth,
      fts->buffers[plane], fts->strides[is_uv], rsc->dst->buffers[plane],
      rsc->dst->strides[is_uv], cm->rst_tmpbuf, optimized_lr, cm->error);

  return sse_restoration_unit(limits, rsc->src, rsc->dst, plane, highbd);
}

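// Computes the sum of squared errors of the self-guided projection for
// 8-bit input. The four branches below specialize on which of the two
// filter radii params->r[0] and params->r[1] are non-zero.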
int64_t av1_lowbd_pixel_proj_error_c(const uint8_t *src8, int width, int height,
                                     int src_stride, const uint8_t *dat8,
                                     int dat_stride, int32_t *flt0,
                                     int flt0_stride, int32_t *flt1,
                                     int flt1_stride, int xq[2],
                                     const sgr_params_type *params) {
  int i, j;
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  int64_t err = 0;
  if (params->r[0] > 0 && params->r[1] > 0) {
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        assert(flt1[j] < (1 << 15) && flt1[j] > -(1 << 15));
        assert(flt0[j] < (1 << 15) && flt0[j] > -(1 << 15));
        const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS);
        int32_t v = u << SGRPROJ_PRJ_BITS;
        v += xq[0] * (flt0[j] - u) + xq[1] * (flt1[j] - u);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
    }
  } else if (params->r[0] > 0) {
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        assert(flt0[j] < (1 << 15) && flt0[j] > -(1 << 15));
        const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS);
        int32_t v = u << SGRPROJ_PRJ_BITS;
        v += xq[0] * (flt0[j] - u);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
      flt0 += flt0_stride;
    }
  } else if (params->r[1] > 0) {
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        assert(flt1[j] < (1 << 15) && flt1[j] > -(1 << 15));
        const int32_t u = (int32_t)(dat[j] << SGRPROJ_RST_BITS);
        int32_t v = u << SGRPROJ_PRJ_BITS;
        v += xq[1] * (flt1[j] - u);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) - src[j];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
      flt1 += flt1_stride;
    }
  } else {
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t e = (int32_t)(dat[j]) - src[j];
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
    }
  }

  return err;
}

#if CONFIG_AV1_HIGHBITDEPTH
int64_t av1_highbd_pixel_proj_error_c(const uint8_t *src8, int width,
                                      int height, int src_stride,
                                      const uint8_t *dat8, int dat_stride,
                                      int32_t *flt0, int flt0_stride,
                                      int32_t *flt1, int flt1_stride, int xq[2],
                                      const sgr_params_type *params) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  int i, j;
  int64_t err = 0;
  const int32_t half = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1);
  if (params->r[0] > 0 && params->r[1] > 0) {
    int xq0 = xq[0];
    int xq1 = xq[1];
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t d = dat[j];
        const int32_t s = src[j];
        const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
        int32_t v0 = flt0[j] - u;
        int32_t v1 = flt1[j] - u;
        int32_t v = half;
        v += xq0 * v0;
        v += xq1 * v1;
        const int32_t e = (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      flt0 += flt0_stride;
      flt1 += flt1_stride;
      src += src_stride;
    }
  } else if (params->r[0] > 0 || params->r[1] > 0) {
    int exq;
    int32_t *flt;
    int flt_stride;
    if (params->r[0] > 0) {
      exq = xq[0];
      flt = flt0;
      flt_stride = flt0_stride;
    } else {
      exq = xq[1];
      flt = flt1;
      flt_stride = flt1_stride;
    }
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t d = dat[j];
        const int32_t s = src[j];
        const int32_t u = (int32_t)(d << SGRPROJ_RST_BITS);
        int32_t v = half;
        v += exq * (flt[j] - u);
        const int32_t e = (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + d - s;
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      flt += flt_stride;
      src += src_stride;
    }
  } else {
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t d = dat[j];
        const int32_t s = src[j];
        const int32_t e = d - s;
        err += ((int64_t)e * e);
      }
      dat += dat_stride;
      src += src_stride;
    }
  }
  return err;
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

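// Decodes the signaled xqd values into projection coefficients xq and
// computes the resulting projection error, dispatching on bit depth.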
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                    int src_stride, const uint8_t *dat8,
                                    int dat_stride, int use_highbitdepth,
                                    int32_t *flt0, int flt0_stride,
                                    int32_t *flt1, int flt1_stride, int *xqd,
                                    const sgr_params_type *params) {
  int xq[2];
  av1_decode_xq(xqd, xq, params);

#if CONFIG_AV1_HIGHBITDEPTH
  if (use_highbitdepth) {
    return av1_highbd_pixel_proj_error(src8, width, height, src_stride, dat8,
                                       dat_stride, flt0, flt0_stride, flt1,
                                       flt1_stride, xq, params);

  } else {
    return av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8,
                                      dat_stride, flt0, flt0_stride, flt1,
                                      flt1_stride, xq, params);
  }
#else
  (void)use_highbitdepth;
  return av1_lowbd_pixel_proj_error(src8, width, height, src_stride, dat8,
                                    dat_stride, flt0, flt0_stride, flt1,
                                    flt1_stride, xq, params);
#endif
}

#define USE_SGRPROJ_REFINEMENT_SEARCH 1
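// Greedy refinement of the projection coefficients xqd: step each
// coefficient up or down while the error decreases, halving the step size
// each round.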
static int64_t finer_search_pixel_proj_error(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int use_highbitdepth, int32_t *flt0,
    int flt0_stride, int32_t *flt1, int flt1_stride, int start_step, int *xqd,
    const sgr_params_type *params) {
  int64_t err = get_pixel_proj_error(
      src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth, flt0,
      flt0_stride, flt1, flt1_stride, xqd, params);
  (void)start_step;
#if USE_SGRPROJ_REFINEMENT_SEARCH
  int64_t err2;
  int tap_min[] = { SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MIN1 };
  int tap_max[] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
  for (int s = start_step; s >= 1; s >>= 1) {
    for (int p = 0; p < 2; ++p) {
      if ((params->r[0] == 0 && p == 0) || (params->r[1] == 0 && p == 1)) {
        continue;
      }
      int skip = 0;
      do {
        if (xqd[p] - s >= tap_min[p]) {
          xqd[p] -= s;
          err2 =
              get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                   dat_stride, use_highbitdepth, flt0,
                                   flt0_stride, flt1, flt1_stride, xqd, params);
          if (err2 > err) {
            xqd[p] += s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
        if (xqd[p] + s <= tap_max[p]) {
          xqd[p] += s;
          err2 =
              get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                   dat_stride, use_highbitdepth, flt0,
                                   flt0_stride, flt1, flt1_stride, xqd, params);
          if (err2 > err) {
            xqd[p] -= s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
  }
#endif  // USE_SGRPROJ_REFINEMENT_SEARCH
  return err;
}

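// Divides and rounds to the nearest integer, with ties away from zero
// (assuming a positive divisor).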
static int64_t signed_rounded_divide(int64_t dividend, int64_t divisor) {
  if (dividend < 0)
    return (dividend - divisor / 2) / divisor;
  else
    return (dividend + divisor / 2) / divisor;
}

static AOM_INLINE void calc_proj_params_r0_r1_c(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      const int32_t u = (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
      const int32_t s =
          (int32_t)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f1 = (int32_t)flt0[i * flt0_stride + j] - u;
      const int32_t f2 = (int32_t)flt1[i * flt1_stride + j] - u;
      H[0][0] += (int64_t)f1 * f1;
      H[1][1] += (int64_t)f2 * f2;
      H[0][1] += (int64_t)f1 * f2;
      C[0] += (int64_t)f1 * s;
      C[1] += (int64_t)f2 * s;
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
}

static AOM_INLINE void calc_proj_params_r0_r1_high_bd_c(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      const int32_t u = (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
      const int32_t s =
          (int32_t)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f1 = (int32_t)flt0[i * flt0_stride + j] - u;
      const int32_t f2 = (int32_t)flt1[i * flt1_stride + j] - u;
      H[0][0] += (int64_t)f1 * f1;
      H[1][1] += (int64_t)f2 * f2;
      H[0][1] += (int64_t)f1 * f2;
      C[0] += (int64_t)f1 * s;
      C[1] += (int64_t)f2 * s;
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
}

static AOM_INLINE void calc_proj_params_r0_c(const uint8_t *src8, int width,
                                             int height, int src_stride,
                                             const uint8_t *dat8,
                                             int dat_stride, int32_t *flt0,
                                             int flt0_stride, int64_t H[2][2],
                                             int64_t C[2]) {
  const int size = width * height;
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      const int32_t u = (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
      const int32_t s =
          (int32_t)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f1 = (int32_t)flt0[i * flt0_stride + j] - u;
      H[0][0] += (int64_t)f1 * f1;
      C[0] += (int64_t)f1 * s;
    }
  }
  H[0][0] /= size;
  C[0] /= size;
}

static AOM_INLINE void calc_proj_params_r0_high_bd_c(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      const int32_t u = (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
      const int32_t s =
          (int32_t)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f1 = (int32_t)flt0[i * flt0_stride + j] - u;
      H[0][0] += (int64_t)f1 * f1;
      C[0] += (int64_t)f1 * s;
    }
  }
  H[0][0] /= size;
  C[0] /= size;
}

static AOM_INLINE void calc_proj_params_r1_c(const uint8_t *src8, int width,
                                             int height, int src_stride,
                                             const uint8_t *dat8,
                                             int dat_stride, int32_t *flt1,
                                             int flt1_stride, int64_t H[2][2],
                                             int64_t C[2]) {
  const int size = width * height;
  const uint8_t *src = src8;
  const uint8_t *dat = dat8;
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      const int32_t u = (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
      const int32_t s =
          (int32_t)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f2 = (int32_t)flt1[i * flt1_stride + j] - u;
      H[1][1] += (int64_t)f2 * f2;
      C[1] += (int64_t)f2 * s;
    }
  }
  H[1][1] /= size;
  C[1] /= size;
}

static AOM_INLINE void calc_proj_params_r1_high_bd_c(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt1, int flt1_stride,
    int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      const int32_t u = (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
      const int32_t s =
          (int32_t)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f2 = (int32_t)flt1[i * flt1_stride + j] - u;
      H[1][1] += (int64_t)f2 * f2;
      C[1] += (int64_t)f2 * s;
    }
  }
  H[1][1] /= size;
  C[1] /= size;
}

// The function calls 3 subfunctions for the following cases:
// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements
// of C and H need to be computed.
// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are
// non-zero and need to be computed.
// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are
// non-zero and need to be computed.
void av1_calc_proj_params_c(const uint8_t *src8, int width, int height,
                            int src_stride, const uint8_t *dat8, int dat_stride,
                            int32_t *flt0, int flt0_stride, int32_t *flt1,
                            int flt1_stride, int64_t H[2][2], int64_t C[2],
                            const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    calc_proj_params_r0_r1_c(src8, width, height, src_stride, dat8, dat_stride,
                             flt0, flt0_stride, flt1, flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    calc_proj_params_r0_c(src8, width, height, src_stride, dat8, dat_stride,
                          flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    calc_proj_params_r1_c(src8, width, height, src_stride, dat8, dat_stride,
                          flt1, flt1_stride, H, C);
  }
}

void av1_calc_proj_params_high_bd_c(const uint8_t *src8, int width, int height,
                                    int src_stride, const uint8_t *dat8,
                                    int dat_stride, int32_t *flt0,
                                    int flt0_stride, int32_t *flt1,
                                    int flt1_stride, int64_t H[2][2],
                                    int64_t C[2],
                                    const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    calc_proj_params_r0_r1_high_bd_c(src8, width, height, src_stride, dat8,
                                     dat_stride, flt0, flt0_stride, flt1,
                                     flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    calc_proj_params_r0_high_bd_c(src8, width, height, src_stride, dat8,
                                  dat_stride, flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    calc_proj_params_r1_high_bd_c(src8, width, height, src_stride, dat8,
                                  dat_stride, flt1, flt1_stride, H, C);
  }
}

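// Solves the normal equations H * xq = C (accumulated above by the
// av1_calc_proj_params functions) for the projection coefficients, scaled
// by (1 << SGRPROJ_PRJ_BITS). xq stays at { 0, 0 } if the system is
// ill-posed.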
static AOM_INLINE void get_proj_subspace(const uint8_t *src8, int width,
                                         int height, int src_stride,
                                         const uint8_t *dat8, int dat_stride,
                                         int use_highbitdepth, int32_t *flt0,
                                         int flt0_stride, int32_t *flt1,
                                         int flt1_stride, int *xq,
                                         const sgr_params_type *params) {
  int64_t H[2][2] = { { 0, 0 }, { 0, 0 } };
  int64_t C[2] = { 0, 0 };

  // Default values to be returned if the problem becomes ill-posed
  xq[0] = 0;
  xq[1] = 0;

  if (!use_highbitdepth) {
    if ((width & 0x7) == 0) {
      av1_calc_proj_params(src8, width, height, src_stride, dat8, dat_stride,
                           flt0, flt0_stride, flt1, flt1_stride, H, C, params);
    } else {
      av1_calc_proj_params_c(src8, width, height, src_stride, dat8, dat_stride,
                             flt0, flt0_stride, flt1, flt1_stride, H, C,
                             params);
    }
  }
#if CONFIG_AV1_HIGHBITDEPTH
  else {  // NOLINT
    if ((width & 0x7) == 0) {
      av1_calc_proj_params_high_bd(src8, width, height, src_stride, dat8,
                                   dat_stride, flt0, flt0_stride, flt1,
                                   flt1_stride, H, C, params);
    } else {
      av1_calc_proj_params_high_bd_c(src8, width, height, src_stride, dat8,
                                     dat_stride, flt0, flt0_stride, flt1,
                                     flt1_stride, H, C, params);
    }
  }
#endif

  if (params->r[0] == 0) {
    // H matrix is now only the scalar H[1][1]
    // C vector is now only the scalar C[1]
    const int64_t Det = H[1][1];
    if (Det == 0) return;  // ill-posed, return default values
    xq[0] = 0;
    xq[1] = (int)signed_rounded_divide(C[1] * (1 << SGRPROJ_PRJ_BITS), Det);
  } else if (params->r[1] == 0) {
    // H matrix is now only the scalar H[0][0]
    // C vector is now only the scalar C[0]
    const int64_t Det = H[0][0];
    if (Det == 0) return;  // ill-posed, return default values
    xq[0] = (int)signed_rounded_divide(C[0] * (1 << SGRPROJ_PRJ_BITS), Det);
    xq[1] = 0;
  } else {
    const int64_t Det = H[0][0] * H[1][1] - H[0][1] * H[1][0];
    if (Det == 0) return;  // ill-posed, return default values

    // If scaling up dividend would overflow, instead scale down the divisor
    const int64_t div1 = H[1][1] * C[0] - H[0][1] * C[1];
    if ((div1 > 0 && INT64_MAX / (1 << SGRPROJ_PRJ_BITS) < div1) ||
        (div1 < 0 && INT64_MIN / (1 << SGRPROJ_PRJ_BITS) > div1))
      xq[0] = (int)signed_rounded_divide(div1, Det / (1 << SGRPROJ_PRJ_BITS));
    else
      xq[0] = (int)signed_rounded_divide(div1 * (1 << SGRPROJ_PRJ_BITS), Det);

    const int64_t div2 = H[0][0] * C[1] - H[1][0] * C[0];
    if ((div2 > 0 && INT64_MAX / (1 << SGRPROJ_PRJ_BITS) < div2) ||
        (div2 < 0 && INT64_MIN / (1 << SGRPROJ_PRJ_BITS) > div2))
      xq[1] = (int)signed_rounded_divide(div2, Det / (1 << SGRPROJ_PRJ_BITS));
    else
      xq[1] = (int)signed_rounded_divide(div2 * (1 << SGRPROJ_PRJ_BITS), Det);
  }
}

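// Converts the projection coefficients xq into the clamped xqd form used
// for signaling (the counterpart of av1_decode_xq above).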
static AOM_INLINE void encode_xq(int *xq, int *xqd,
                                 const sgr_params_type *params) {
  if (params->r[0] == 0) {
    xqd[0] = 0;
    xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) - xq[1], SGRPROJ_PRJ_MIN1,
                   SGRPROJ_PRJ_MAX1);
  } else if (params->r[1] == 0) {
    xqd[0] = clamp(xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
    xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) - xqd[0], SGRPROJ_PRJ_MIN1,
                   SGRPROJ_PRJ_MAX1);
  } else {
    xqd[0] = clamp(xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
    xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1], SGRPROJ_PRJ_MIN1,
                   SGRPROJ_PRJ_MAX1);
  }
}

// Apply the self-guided filter across an entire restoration unit.
static AOM_INLINE void apply_sgr(int sgr_params_idx, const uint8_t *dat8,
                                 int width, int height, int dat_stride,
                                 int use_highbd, int bit_depth, int pu_width,
                                 int pu_height, int32_t *flt0, int32_t *flt1,
                                 int flt_stride,
                                 struct aom_internal_error_info *error_info) {
  for (int i = 0; i < height; i += pu_height) {
    const int h = AOMMIN(pu_height, height - i);
    int32_t *flt0_row = flt0 + i * flt_stride;
    int32_t *flt1_row = flt1 + i * flt_stride;
    const uint8_t *dat8_row = dat8 + i * dat_stride;

    // Iterate over the stripe in blocks of width pu_width
    for (int j = 0; j < width; j += pu_width) {
      const int w = AOMMIN(pu_width, width - j);
      if (av1_selfguided_restoration(
              dat8_row + j, w, h, dat_stride, flt0_row + j, flt1_row + j,
              flt_stride, sgr_params_idx, bit_depth, use_highbd) != 0) {
        aom_internal_error(
            error_info, AOM_CODEC_MEM_ERROR,
            "Error allocating buffer in av1_selfguided_restoration");
      }
    }
  }
}

static AOM_INLINE void compute_sgrproj_err(
    const uint8_t *dat8, const int width, const int height,
    const int dat_stride, const uint8_t *src8, const int src_stride,
    const int use_highbitdepth, const int bit_depth, const int pu_width,
    const int pu_height, const int ep, int32_t *flt0, int32_t *flt1,
    const int flt_stride, int *exqd, int64_t *err,
    struct aom_internal_error_info *error_info) {
  int exq[2];
  apply_sgr(ep, dat8, width, height, dat_stride, use_highbitdepth, bit_depth,
            pu_width, pu_height, flt0, flt1, flt_stride, error_info);
  const sgr_params_type *const params = &av1_sgr_params[ep];
  get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                    use_highbitdepth, flt0, flt_stride, flt1, flt_stride, exq,
                    params);
  encode_xq(exq, exqd, params);
  *err = finer_search_pixel_proj_error(
      src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth, flt0,
      flt_stride, flt1, flt_stride, 2, exqd, params);
}

static AOM_INLINE void get_best_error(int64_t *besterr, const int64_t err,
                                      const int *exqd, int *bestxqd,
                                      int *bestep, const int ep) {
  if (*besterr == -1 || err < *besterr) {
    *bestep = ep;
    *besterr = err;
    bestxqd[0] = exqd[0];
    bestxqd[1] = exqd[1];
  }
}

static SgrprojInfo search_selfguided_restoration(
    const uint8_t *dat8, int width, int height, int dat_stride,
    const uint8_t *src8, int src_stride, int use_highbitdepth, int bit_depth,
    int pu_width, int pu_height, int32_t *rstbuf, int enable_sgr_ep_pruning,
    struct aom_internal_error_info *error_info) {
  int32_t *flt0 = rstbuf;
  int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
  int ep, idx, bestep = 0;
  int64_t besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
  int flt_stride = ((width + 7) & ~7) + 8;
  assert(pu_width == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_width == RESTORATION_PROC_UNIT_SIZE);
  assert(pu_height == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_height == RESTORATION_PROC_UNIT_SIZE);
  if (!enable_sgr_ep_pruning) {
    for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
      int64_t err;
      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
                          flt0, flt1, flt_stride, exqd, &err, error_info);
      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
    }
  } else {
    // Evaluate the first four seed ep values in group 1
    for (idx = 0; idx < SGRPROJ_EP_GRP1_SEARCH_COUNT; idx++) {
      ep = sgproj_ep_grp1_seed[idx];
      int64_t err;
      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
                          flt0, flt1, flt_stride, exqd, &err, error_info);
      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
    }
    // Evaluate the ep values to the left and right of the winning seed ep
    int bestep_ref = bestep;
    for (ep = bestep_ref - 1; ep < bestep_ref + 2; ep += 2) {
      if (ep < SGRPROJ_EP_GRP1_START_IDX || ep > SGRPROJ_EP_GRP1_END_IDX)
        continue;
      int64_t err;
      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
                          flt0, flt1, flt_stride, exqd, &err, error_info);
      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
    }
    // Evaluate the last two groups
    for (idx = 0; idx < SGRPROJ_EP_GRP2_3_SEARCH_COUNT; idx++) {
      ep = sgproj_ep_grp2_3[idx][bestep];
      int64_t err;
      compute_sgrproj_err(dat8, width, height, dat_stride, src8, src_stride,
                          use_highbitdepth, bit_depth, pu_width, pu_height, ep,
                          flt0, flt1, flt_stride, exqd, &err, error_info);
      get_best_error(&besterr, err, exqd, bestxqd, &bestep, ep);
    }
  }

  SgrprojInfo ret;
  ret.ep = bestep;
  ret.xqd[0] = bestxqd[0];
  ret.xqd[1] = bestxqd[1];
  return ret;
}

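// Counts the bits needed to signal the given self-guided parameters, with
// the xqd values delta-coded against the reference parameters.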
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  int bits = SGRPROJ_PARAMS_BITS;
  const sgr_params_type *params = &av1_sgr_params[sgrproj_info->ep];
  if (params->r[0] > 0)
    bits += aom_count_primitive_refsubexpfin(
        SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
        ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
        sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  if (params->r[1] > 0)
    bits += aom_count_primitive_refsubexpfin(
        SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
        ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
        sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
  return bits;
}

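// Per-unit search callback for RESTORE_SGRPROJ: finds the best self-guided
// parameters for this unit, then compares the resulting RD cost against
// RESTORE_NONE and updates the running totals and reference parameters.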
static AOM_INLINE void search_sgrproj(
    const RestorationTileLimits *limits, int rest_unit_idx, void *priv,
    int32_t *tmpbuf, RestorationLineBuffers *rlbs,
    struct aom_internal_error_info *error_info) {
  (void)rlbs;
  RestSearchCtxt *rsc = (RestSearchCtxt *)priv;
  RestUnitSearchInfo *rusi = &rsc->rusi[rest_unit_idx];

  const MACROBLOCK *const x = rsc->x;
  const AV1_COMMON *const cm = rsc->cm;
  const int highbd = cm->seq_params->use_highbitdepth;
  const int bit_depth = cm->seq_params->bit_depth;

  const int64_t bits_none = x->mode_costs.sgrproj_restore_cost[0];
  // Prune evaluation of RESTORE_SGRPROJ if 'skip_sgr_eval' is set
  if (rsc->skip_sgr_eval) {
    rsc->total_bits[RESTORE_SGRPROJ] += bits_none;
    rsc->total_sse[RESTORE_SGRPROJ] += rsc->sse[RESTORE_NONE];
    rusi->best_rtype[RESTORE_SGRPROJ - 1] = RESTORE_NONE;
    rsc->sse[RESTORE_SGRPROJ] = INT64_MAX;
    return;
  }

  uint8_t *dgd_start =
      rsc->dgd_buffer + limits->v_start * rsc->dgd_stride + limits->h_start;
  const uint8_t *src_start =
      rsc->src_buffer + limits->v_start * rsc->src_stride + limits->h_start;

  const int is_uv = rsc->plane > 0;
  const int ss_x = is_uv && cm->seq_params->subsampling_x;
  const int ss_y = is_uv && cm->seq_params->subsampling_y;
  const int procunit_width = RESTORATION_PROC_UNIT_SIZE >> ss_x;
  const int procunit_height = RESTORATION_PROC_UNIT_SIZE >> ss_y;

  rusi->sgrproj = search_selfguided_restoration(
      dgd_start, limits->h_end - limits->h_start,
      limits->v_end - limits->v_start, rsc->dgd_stride, src_start,
      rsc->src_stride, highbd, bit_depth, procunit_width, procunit_height,
      tmpbuf, rsc->lpf_sf->enable_sgr_ep_pruning, error_info);

  RestorationUnitInfo rui;
  rui.restoration_type = RESTORE_SGRPROJ;
  rui.sgrproj_info = rusi->sgrproj;

  rsc->sse[RESTORE_SGRPROJ] = try_restoration_unit(rsc, limits, &rui);

  const int64_t bits_sgr =
      x->mode_costs.sgrproj_restore_cost[1] +
      (count_sgrproj_bits(&rusi->sgrproj, &rsc->ref_sgrproj)
       << AV1_PROB_COST_SHIFT);
  double cost_none = RDCOST_DBL_WITH_NATIVE_BD_DIST(
      x->rdmult, bits_none >> 4, rsc->sse[RESTORE_NONE], bit_depth);
  double cost_sgr = RDCOST_DBL_WITH_NATIVE_BD_DIST(
      x->rdmult, bits_sgr >> 4, rsc->sse[RESTORE_SGRPROJ], bit_depth);
  if (rusi->sgrproj.ep < 10)
    cost_sgr *=
        (1 + DUAL_SGR_PENALTY_MULT * rsc->lpf_sf->dual_sgr_penalty_level);

  RestorationType rtype =
      (cost_sgr < cost_none) ? RESTORE_SGRPROJ : RESTORE_NONE;
  rusi->best_rtype[RESTORE_SGRPROJ - 1] = rtype;

#if DEBUG_LR_COSTING
  // Store ref params for later checking
  lr_ref_params[RESTORE_SGRPROJ][rsc->plane][rest_unit_idx].sgrproj_info =
      rsc->ref_sgrproj;
#endif  // DEBUG_LR_COSTING

  rsc->total_sse[RESTORE_SGRPROJ] += rsc->sse[rtype];
  rsc->total_bits[RESTORE_SGRPROJ] +=
      (cost_sgr < cost_none) ? bits_sgr : bits_none;
  if (cost_sgr < cost_none) rsc->ref_sgrproj = rusi->sgrproj;
}

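// Accumulates the Wiener filter cross-correlation (M) and autocorrelation
// (H) statistics for a single row of pixels into 32-bit accumulators.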
static void acc_stat_one_line(const uint8_t *dgd, const uint8_t *src,
                              int dgd_stride, int h_start, int h_end,
                              uint8_t avg, const int wiener_halfwin,
                              const int wiener_win2, int32_t *M_int32,
                              int32_t *H_int32, int count) {
  int j, k, l;
  int16_t Y[WIENER_WIN2];

  for (j = h_start; j < h_end; j++) {
    const int16_t X = (int16_t)src[j] - (int16_t)avg;
    int idx = 0;
    for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
      for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
        Y[idx] =
            (int16_t)dgd[(count + l) * dgd_stride + (j + k)] - (int16_t)avg;
        idx++;
      }
    }
    assert(idx == wiener_win2);
    for (k = 0; k < wiener_win2; ++k) {
      M_int32[k] += (int32_t)Y[k] * X;
      for (l = k; l < wiener_win2; ++l) {
        // H is a symmetric matrix, so we only need to fill out the upper
        // triangle here. We can copy it down to the lower triangle outside
        // the (i, j) loops.
        H_int32[k * wiener_win2 + l] += (int32_t)Y[k] * Y[l];
      }
    }
  }
}

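// Accumulates M and H over the whole window, one row at a time, optionally
// skipping rows and scaling the per-row statistics by the downsampling
// factor to compensate.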
void av1_compute_stats_c(int wiener_win, const uint8_t *dgd, const uint8_t *src,
                         int16_t *dgd_avg, int16_t *src_avg, int h_start,
                         int h_end, int v_start, int v_end, int dgd_stride,
                         int src_stride, int64_t *M, int64_t *H,
                         int use_downsampled_wiener_stats) {
  (void)dgd_avg;
  (void)src_avg;
  int i, k, l;
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
  uint8_t avg = find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
  int32_t M_row[WIENER_WIN2] = { 0 };
  int32_t H_row[WIENER_WIN2 * WIENER_WIN2] = { 0 };
  int downsample_factor =
      use_downsampled_wiener_stats ? WIENER_STATS_DOWNSAMPLE_FACTOR : 1;

  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);

  for (i = v_start; i < v_end; i = i + downsample_factor) {
    if (use_downsampled_wiener_stats &&
        (v_end - i < WIENER_STATS_DOWNSAMPLE_FACTOR)) {
      downsample_factor = v_end - i;
    }

    memset(M_row, 0, sizeof(int32_t) * WIENER_WIN2);
    memset(H_row, 0, sizeof(int32_t) * WIENER_WIN2 * WIENER_WIN2);
    acc_stat_one_line(dgd, src + i * src_stride, dgd_stride, h_start, h_end,
                      avg, wiener_halfwin, wiener_win2, M_row, H_row, i);

    for (k = 0; k < wiener_win2; ++k) {
      // Scale M matrix based on the downsampling factor
      M[k] += ((int64_t)M_row[k] * downsample_factor);
      for (l = k; l < wiener_win2; ++l) {
        // H is a symmetric matrix, so we only need to fill out the upper
        // triangle here. We can copy it down to the lower triangle outside
        // the (i, j) loops.
        // Scale H Matrix based on the downsampling factor
        H[k * wiener_win2 + l] +=
            ((int64_t)H_row[k * wiener_win2 + l] * downsample_factor);
      }
    }
  }

  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
    }
  }
}

#if CONFIG_AV1_HIGHBITDEPTH
void av1_compute_stats_highbd_c(int wiener_win, const uint8_t *dgd8,
                                const uint8_t *src8, int h_start, int h_end,
                                int v_start, int v_end, int dgd_stride,
                                int src_stride, int64_t *M, int64_t *H,
                                aom_bit_depth_t bit_depth) {
  int i, j, k, l;
  int32_t Y[WIENER_WIN2];
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  uint16_t avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);

  uint8_t bit_depth_divider = 1;
  if (bit_depth == AOM_BITS_12)
    bit_depth_divider = 16;
  else if (bit_depth == AOM_BITS_10)
    bit_depth_divider = 4;

  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
      const int32_t X = (int32_t)src[i * src_stride + j] - (int32_t)avg;
      int idx = 0;
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
          Y[idx] = (int32_t)dgd[(i + l) * dgd_stride + (j + k)] - (int32_t)avg;
          idx++;
        }
      }
      assert(idx == wiener_win2);
      for (k = 0; k < wiener_win2; ++k) {
        M[k] += (int64_t)Y[k] * X;
        for (l = k; l < wiener_win2; ++l) {
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * wiener_win2 + l] += (int64_t)Y[k] * Y[l];
        }
      }
    }
  }
  for (k = 0; k < wiener_win2; ++k) {
    M[k] /= bit_depth_divider;
    H[k * wiener_win2 + k] /= bit_depth_divider;
    for (l = k + 1; l < wiener_win2; ++l) {
      H[k * wiener_win2 + l] /= bit_depth_divider;
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
    }
  }
}
#endif  // CONFIG_AV1_HIGHBITDEPTH

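// Maps a tap index to its symmetric counterpart: the Wiener filter is
// constrained to be symmetric, so only (wiener_win >> 1) + 1 distinct taps
// exist per direction.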
wrap_index(int i,int wiener_win)1101 static INLINE int wrap_index(int i, int wiener_win) {
1102   const int wiener_halfwin1 = (wiener_win >> 1) + 1;
1103   return (i >= wiener_halfwin1 ? wiener_win - 1 - i : i);
1104 }
1105 
1106 // Splits each w[i] into smaller components w1[i] and w2[i] such that
1107 // w[i] = w1[i] * WIENER_TAP_SCALE_FACTOR + w2[i].
split_wiener_filter_coefficients(int wiener_win,const int32_t * w,int32_t * w1,int32_t * w2)1108 static INLINE void split_wiener_filter_coefficients(int wiener_win,
1109                                                     const int32_t *w,
1110                                                     int32_t *w1, int32_t *w2) {
1111   for (int i = 0; i < wiener_win; i++) {
1112     w1[i] = w[i] / WIENER_TAP_SCALE_FACTOR;
1113     w2[i] = w[i] - w1[i] * WIENER_TAP_SCALE_FACTOR;
1114     assert(w[i] == w1[i] * WIENER_TAP_SCALE_FACTOR + w2[i]);
1115   }
1116 }
1117 
1118 // Calculates x * w / WIENER_TAP_SCALE_FACTOR, where
1119 // w = w1 * WIENER_TAP_SCALE_FACTOR + w2.
1120 //
1121 // The multiplication x * w may overflow, so we multiply x by the components of
1122 // w (w1 and w2) and combine the multiplication with the division.
multiply_and_scale(int64_t x,int32_t w1,int32_t w2)1123 static INLINE int64_t multiply_and_scale(int64_t x, int32_t w1, int32_t w2) {
1124   // Let y = x * w / WIENER_TAP_SCALE_FACTOR
1125   //       = x * (w1 * WIENER_TAP_SCALE_FACTOR + w2) / WIENER_TAP_SCALE_FACTOR
1126   const int64_t y = x * w1 + x * w2 / WIENER_TAP_SCALE_FACTOR;
1127   // Double-check the calculation using __int128.
1128   // TODO(wtc): Remove after 2024-04-30.
1129 #if !defined(NDEBUG) && defined(__GNUC__) && defined(__LP64__)
1130   const int32_t w = w1 * WIENER_TAP_SCALE_FACTOR + w2;
1131   const __int128 z = (__int128)x * w / WIENER_TAP_SCALE_FACTOR;
1132   assert(z >= INT64_MIN);
1133   assert(z <= INT64_MAX);
1134   assert(y == (int64_t)z);
1135 #endif
1136   return y;
1137 }

// Solve linear equations to find Wiener filter tap values
// Taps are output scaled by WIENER_TAP_SCALE_FACTOR
static int linsolve_wiener(int n, int64_t *A, int stride, int64_t *b,
                           int64_t *x) {
  for (int k = 0; k < n - 1; k++) {
    // Partial pivoting: bring the row with the largest pivot to the top
    for (int i = n - 1; i > k; i--) {
      // If row i has a better (bigger) pivot than row (i-1), swap them
      if (llabs(A[(i - 1) * stride + k]) < llabs(A[i * stride + k])) {
        for (int j = 0; j < n; j++) {
          const int64_t c = A[i * stride + j];
          A[i * stride + j] = A[(i - 1) * stride + j];
          A[(i - 1) * stride + j] = c;
        }
        const int64_t c = b[i];
        b[i] = b[i - 1];
        b[i - 1] = c;
      }
    }

    // b/278065963: The multiplies
    //   c / 256 * A[k * stride + j] / cd * 256
    // and
    //   c / 256 * b[k] / cd * 256
    // within Gaussian elimination can cause a signed integer overflow. Rework
    // the multiplies so that larger scaling is used without significantly
    // impacting the overall precision.
    //
    // Precision guidance:
    //   scale_threshold: Pick as high as possible.
    // For the max_abs_akj >= scale_threshold scenario:
    //   scaler_A: Pick as low as possible. Needed for A[(i + 1) * stride + j].
    //   scaler_c: Pick as low as possible while maintaining scaler_c >=
    //     (1 << 7). Needed for A[(i + 1) * stride + j] and b[i + 1].
    int64_t max_abs_akj = 0;
    for (int j = 0; j < n; j++) {
      const int64_t abs_akj = llabs(A[k * stride + j]);
      if (abs_akj > max_abs_akj) max_abs_akj = abs_akj;
    }
    const int scale_threshold = 1 << 22;
    const int scaler_A = max_abs_akj < scale_threshold ? 1 : (1 << 5);
    const int scaler_c = max_abs_akj < scale_threshold ? 1 : (1 << 7);
    const int scaler = scaler_c * scaler_A;

    // Forward elimination (convert A to row-echelon form)
    for (int i = k; i < n - 1; i++) {
      if (A[k * stride + k] == 0) return 0;
      const int64_t c = A[(i + 1) * stride + k] / scaler_c;
      const int64_t cd = A[k * stride + k];
      for (int j = 0; j < n; j++) {
        A[(i + 1) * stride + j] -=
            A[k * stride + j] / scaler_A * c / cd * scaler;
      }
      b[i + 1] -= c * b[k] / cd * scaler_c;
    }
  }
  // Back-substitution
  for (int i = n - 1; i >= 0; i--) {
    if (A[i * stride + i] == 0) return 0;
    int64_t c = 0;
    for (int j = i + 1; j <= n - 1; j++) {
      c += A[i * stride + j] * x[j] / WIENER_TAP_SCALE_FACTOR;
    }
    // Store filter taps x in scaled form.
    x[i] = WIENER_TAP_SCALE_FACTOR * (b[i] - c) / A[i * stride + i];
  }

  return 1;
}
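
// Worked sketch (hypothetical, well-scaled stats): for the 2x2 system with
//   A = 2^20 * [ 2 1 ; 1 3 ],  b = 2^20 * [ 3 ; 5 ],
// the exact solution is x = [ 4/5 ; 7/5 ], and linsolve_wiener() returns 1
// with x[0] == 52429 ~= 0.8 * WIENER_TAP_SCALE_FACTOR and
// x[1] == 91750 ~= 1.4 * WIENER_TAP_SCALE_FACTOR. A small-magnitude system
// would lose precision to the truncating integer divisions, but the
// correlation statistics fed in by the callers are large.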

// Fix vector b, update vector a
static AOM_INLINE void update_a_sep_sym(int wiener_win, int64_t **Mc,
                                        int64_t **Hc, int32_t *a,
                                        const int32_t *b) {
  int i, j;
  int64_t S[WIENER_WIN];
  int64_t A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
  int32_t b1[WIENER_WIN], b2[WIENER_WIN];
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; ++j) {
      const int jj = wrap_index(j, wiener_win);
      A[jj] += Mc[i][j] * b[i] / WIENER_TAP_SCALE_FACTOR;
    }
  }
  split_wiener_filter_coefficients(wiener_win, b, b1, b2);

  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
      int k, l;
      for (k = 0; k < wiener_win; ++k) {
        const int kk = wrap_index(k, wiener_win);
        for (l = 0; l < wiener_win; ++l) {
          const int ll = wrap_index(l, wiener_win);
          // Calculate
          // B[ll * wiener_halfwin1 + kk] +=
          //    Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] /
          //    WIENER_TAP_SCALE_FACTOR * b[j] / WIENER_TAP_SCALE_FACTOR;
          //
          // The last multiplication may overflow, so we combine the last
          // multiplication with the last division.
          const int64_t x = Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] /
                            WIENER_TAP_SCALE_FACTOR;
          // b[j] = b1[j] * WIENER_TAP_SCALE_FACTOR + b2[j]
          B[ll * wiener_halfwin1 + kk] += multiply_and_scale(x, b1[j], b2[j]);
        }
      }
    }
  }
  // Normalization enforcement in the system of equations itself: the taps must
  // sum to WIENER_TAP_SCALE_FACTOR, so the center tap equals
  // WIENER_TAP_SCALE_FACTOR - 2 * (sum of the other half-window taps).
  // Substituting that expression for the center tap eliminates one unknown and
  // folds the constraint directly into A and B.
  for (i = 0; i < wiener_halfwin1 - 1; ++i) {
    A[i] -=
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  }
  for (i = 0; i < wiener_halfwin1 - 1; ++i) {
    for (j = 0; j < wiener_halfwin1 - 1; ++j) {
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
    }
  }
  if (linsolve_wiener(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = WIENER_TAP_SCALE_FACTOR;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
    }
    for (i = 0; i < wiener_win; ++i) {
      a[i] = (int32_t)CLIP(S[i], -(1 << (WIENER_FILT_BITS - 1)),
                           (1 << (WIENER_FILT_BITS - 1)) - 1);
    }
  }
}

// Fix vector a, update vector b
static AOM_INLINE void update_b_sep_sym(int wiener_win, int64_t **Mc,
                                        int64_t **Hc, const int32_t *a,
                                        int32_t *b) {
  int i, j;
  int64_t S[WIENER_WIN];
  int64_t A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
  int32_t a1[WIENER_WIN], a2[WIENER_WIN];
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) {
      A[ii] += Mc[i][j] * a[j] / WIENER_TAP_SCALE_FACTOR;
    }
  }
  split_wiener_filter_coefficients(wiener_win, a, a1, a2);

  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) {
      const int jj = wrap_index(j, wiener_win);
      int k, l;
      for (k = 0; k < wiener_win; ++k) {
        for (l = 0; l < wiener_win; ++l) {
          // Calculate
          // B[jj * wiener_halfwin1 + ii] +=
          //     Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] /
          //     WIENER_TAP_SCALE_FACTOR * a[l] / WIENER_TAP_SCALE_FACTOR;
          //
          // The last multiplication may overflow, so we combine the last
          // multiplication with the last division.
          const int64_t x = Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] /
                            WIENER_TAP_SCALE_FACTOR;
          // a[l] = a1[l] * WIENER_TAP_SCALE_FACTOR + a2[l]
          B[jj * wiener_halfwin1 + ii] += multiply_and_scale(x, a1[l], a2[l]);
        }
      }
    }
  }
  // Normalization enforcement in the system of equations itself: the taps must
  // sum to WIENER_TAP_SCALE_FACTOR, so the center tap equals
  // WIENER_TAP_SCALE_FACTOR - 2 * (sum of the other half-window taps).
  // Substituting that expression for the center tap eliminates one unknown and
  // folds the constraint directly into A and B.
  for (i = 0; i < wiener_halfwin1 - 1; ++i) {
    A[i] -=
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  }
  for (i = 0; i < wiener_halfwin1 - 1; ++i) {
    for (j = 0; j < wiener_halfwin1 - 1; ++j) {
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
    }
  }
  if (linsolve_wiener(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = WIENER_TAP_SCALE_FACTOR;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
    }
    for (i = 0; i < wiener_win; ++i) {
      b[i] = (int32_t)CLIP(S[i], -(1 << (WIENER_FILT_BITS - 1)),
                           (1 << (WIENER_FILT_BITS - 1)) - 1);
    }
  }
}

static void wiener_decompose_sep_sym(int wiener_win, int64_t *M, int64_t *H,
                                     int32_t *a, int32_t *b) {
  static const int32_t init_filt[WIENER_WIN] = {
    WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
    WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
    WIENER_FILT_TAP0_MIDV,
  };
  int64_t *Hc[WIENER_WIN2];
  int64_t *Mc[WIENER_WIN];
  int i, j, iter;
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
  for (i = 0; i < wiener_win; i++) {
    a[i] = b[i] =
        WIENER_TAP_SCALE_FACTOR / WIENER_FILT_STEP * init_filt[i + plane_off];
  }
  for (i = 0; i < wiener_win; i++) {
    Mc[i] = M + i * wiener_win;
    for (j = 0; j < wiener_win; j++) {
      Hc[i * wiener_win + j] =
          H + i * wiener_win * wiener_win2 + j * wiener_win;
    }
  }

  iter = 1;
  while (iter < NUM_WIENER_ITERS) {
    update_a_sep_sym(wiener_win, Mc, Hc, a, b);
    update_b_sep_sym(wiener_win, Mc, Hc, a, b);
    iter++;
  }
}
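
// The loop above performs an alternating (coordinate-descent) minimization:
// starting from the default separable filter, each pass fixes the horizontal
// filter b and solves a small constrained least-squares system for the
// vertical filter a, then fixes a and solves for b. In exact arithmetic the
// modelled error is non-increasing from pass to pass, so the
// NUM_WIENER_ITERS - 1 passes used here are sufficient in practice.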

// Computes the function x'*H*x - 2*x'*M for the learned 2D filter x, and
// compares it against the identity filter; the final score is the difference
// between the two function values.
static int64_t compute_score(int wiener_win, int64_t *M, int64_t *H,
                             InterpKernel vfilt, InterpKernel hfilt) {
  int32_t ab[WIENER_WIN * WIENER_WIN];
  int16_t a[WIENER_WIN], b[WIENER_WIN];
  int64_t P = 0, Q = 0;
  int64_t iP = 0, iQ = 0;
  int64_t Score, iScore;
  int i, k, l;
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;

  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = WIENER_FILT_STEP;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = vfilt[i];
    b[i] = b[WIENER_WIN - i - 1] = hfilt[i];
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
  }
  memset(ab, 0, sizeof(ab));
  for (k = 0; k < wiener_win; ++k) {
    for (l = 0; l < wiener_win; ++l)
      ab[k * wiener_win + l] = a[l + plane_off] * b[k + plane_off];
  }
  for (k = 0; k < wiener_win2; ++k) {
    P += ab[k] * M[k] / WIENER_FILT_STEP / WIENER_FILT_STEP;
    for (l = 0; l < wiener_win2; ++l) {
      Q += ab[k] * H[k * wiener_win2 + l] * ab[l] / WIENER_FILT_STEP /
           WIENER_FILT_STEP / WIENER_FILT_STEP / WIENER_FILT_STEP;
    }
  }
  Score = Q - 2 * P;

  iP = M[wiener_win2 >> 1];
  iQ = H[(wiener_win2 >> 1) * wiener_win2 + (wiener_win2 >> 1)];
  iScore = iQ - 2 * iP;

  return Score - iScore;
}
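
// Rationale: for a (flattened) 2D filter x, the post-filtering SSE equals
// x'*H*x - 2*x'*M plus a constant (the source energy) that does not depend on
// x. A negative return value therefore means the learned separable filter is
// modelled to achieve a lower SSE than the identity filter, whose only nonzero
// tap is the center one.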

static AOM_INLINE void finalize_sym_filter(int wiener_win, int32_t *f,
                                           InterpKernel fi) {
  int i;
  const int wiener_halfwin = (wiener_win >> 1);

  for (i = 0; i < wiener_halfwin; ++i) {
    const int64_t dividend = (int64_t)f[i] * WIENER_FILT_STEP;
    const int64_t divisor = WIENER_TAP_SCALE_FACTOR;
    // Perform this division with proper rounding rather than truncation
    if (dividend < 0) {
      fi[i] = (int16_t)((dividend - (divisor / 2)) / divisor);
    } else {
      fi[i] = (int16_t)((dividend + (divisor / 2)) / divisor);
    }
  }
  // Specialize for 7-tap filter
  if (wiener_win == WIENER_WIN) {
    fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
    fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  } else {
    fi[2] = CLIP(fi[1], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
    fi[1] = CLIP(fi[0], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[0] = 0;
  }
  // Satisfy filter constraints
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];
  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}
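
// Rounding sketch (hypothetical tap value, assuming WIENER_FILT_STEP == 128
// and WIENER_TAP_SCALE_FACTOR == 1 << 16): f[i] = 25000 corresponds to
// 25000 * 128 / 65536 = 48.83 in filter-step units. Plain truncation would
// store 48, whereas the rounded division above stores
// (25000 * 128 + 32768) / 65536 = 49, the nearest representable value.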

static int count_wiener_bits(int wiener_win, WienerInfo *wiener_info,
                             WienerInfo *ref_wiener_info) {
  int bits = 0;
  if (wiener_win == WIENER_WIN)
    bits += aom_count_primitive_refsubexpfin(
        WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
        WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
      WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP2_MAXV - WIENER_FILT_TAP2_MINV + 1,
      WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV);
  if (wiener_win == WIENER_WIN)
    bits += aom_count_primitive_refsubexpfin(
        WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
        WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
      WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP2_MAXV - WIENER_FILT_TAP2_MINV + 1,
      WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV);
  return bits;
}
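
// Note that only taps 0..2 of each kernel are signalled: the mirrored taps
// follow from symmetry and the center tap from the normalization constraint,
// and tap 0 is skipped for the reduced (chroma) window, where it is always
// zero. Each tap is coded with a subexponential code relative to the
// corresponding tap of the reference filter, so small deltas are cheap.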

static int64_t finer_search_wiener(const RestSearchCtxt *rsc,
                                   const RestorationTileLimits *limits,
                                   RestorationUnitInfo *rui, int wiener_win) {
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  int64_t err = try_restoration_unit(rsc, limits, rui);

  if (rsc->lpf_sf->disable_wiener_coeff_refine_search) return err;

  // Refinement search around the wiener filter coefficients.
  int64_t err2;
  int tap_min[] = { WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP1_MINV,
                    WIENER_FILT_TAP2_MINV };
  int tap_max[] = { WIENER_FILT_TAP0_MAXV, WIENER_FILT_TAP1_MAXV,
                    WIENER_FILT_TAP2_MAXV };

  WienerInfo *plane_wiener = &rui->wiener_info;

  // printf("err  pre = %"PRId64"\n", err);
  const int start_step = 4;
  for (int s = start_step; s >= 1; s >>= 1) {
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
      int skip = 0;
      do {
        if (plane_wiener->hfilter[p] - s >= tap_min[p]) {
          plane_wiener->hfilter[p] -= s;
          plane_wiener->hfilter[WIENER_WIN - p - 1] -= s;
          plane_wiener->hfilter[WIENER_HALFWIN] += 2 * s;
          err2 = try_restoration_unit(rsc, limits, rui);
          if (err2 > err) {
            plane_wiener->hfilter[p] += s;
            plane_wiener->hfilter[WIENER_WIN - p - 1] += s;
            plane_wiener->hfilter[WIENER_HALFWIN] -= 2 * s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
        if (plane_wiener->hfilter[p] + s <= tap_max[p]) {
          plane_wiener->hfilter[p] += s;
          plane_wiener->hfilter[WIENER_WIN - p - 1] += s;
          plane_wiener->hfilter[WIENER_HALFWIN] -= 2 * s;
          err2 = try_restoration_unit(rsc, limits, rui);
          if (err2 > err) {
            plane_wiener->hfilter[p] -= s;
            plane_wiener->hfilter[WIENER_WIN - p - 1] -= s;
            plane_wiener->hfilter[WIENER_HALFWIN] += 2 * s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
      int skip = 0;
      do {
        if (plane_wiener->vfilter[p] - s >= tap_min[p]) {
          plane_wiener->vfilter[p] -= s;
          plane_wiener->vfilter[WIENER_WIN - p - 1] -= s;
          plane_wiener->vfilter[WIENER_HALFWIN] += 2 * s;
          err2 = try_restoration_unit(rsc, limits, rui);
          if (err2 > err) {
            plane_wiener->vfilter[p] += s;
            plane_wiener->vfilter[WIENER_WIN - p - 1] += s;
            plane_wiener->vfilter[WIENER_HALFWIN] -= 2 * s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
        if (plane_wiener->vfilter[p] + s <= tap_max[p]) {
          plane_wiener->vfilter[p] += s;
          plane_wiener->vfilter[WIENER_WIN - p - 1] += s;
          plane_wiener->vfilter[WIENER_HALFWIN] -= 2 * s;
          err2 = try_restoration_unit(rsc, limits, rui);
          if (err2 > err) {
            plane_wiener->vfilter[p] -= s;
            plane_wiener->vfilter[WIENER_WIN - p - 1] -= s;
            plane_wiener->vfilter[WIENER_HALFWIN] += 2 * s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
  }
  // printf("err post = %"PRId64"\n", err);
  return err;
}
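
// The refinement above is a greedy pattern search: for step sizes 4, 2 and 1
// it tries moving each signalled tap down and then up by the step, mirroring
// the change onto the symmetric tap and compensating the center tap by
// 2 * step so that the kernel stays normalized. A move is kept only if it does
// not increase the measured error, and at the largest step size the search
// keeps moving in a successful direction before shrinking the step.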

static AOM_INLINE void search_wiener(
    const RestorationTileLimits *limits, int rest_unit_idx, void *priv,
    int32_t *tmpbuf, RestorationLineBuffers *rlbs,
    struct aom_internal_error_info *error_info) {
  (void)tmpbuf;
  (void)rlbs;
  (void)error_info;
  RestSearchCtxt *rsc = (RestSearchCtxt *)priv;
  RestUnitSearchInfo *rusi = &rsc->rusi[rest_unit_idx];

  const MACROBLOCK *const x = rsc->x;
  const int64_t bits_none = x->mode_costs.wiener_restore_cost[0];

  // Skip the Wiener search for low-variance content
  if (rsc->lpf_sf->prune_wiener_based_on_src_var) {
    const int scale[3] = { 0, 1, 2 };
    // Obtain the normalized Qscale
    const int qs = av1_dc_quant_QTX(rsc->cm->quant_params.base_qindex, 0,
                                    rsc->cm->seq_params->bit_depth) >>
                   3;
    // Derive the threshold as (normalized Qscale)^2 * scale / 16
    const uint64_t thresh =
        (qs * qs * scale[rsc->lpf_sf->prune_wiener_based_on_src_var]) >> 4;
    const int highbd = rsc->cm->seq_params->use_highbitdepth;
    const uint64_t src_var =
        var_restoration_unit(limits, rsc->src, rsc->plane, highbd);
    // Do not perform the Wiener search if the source variance is below the
    // threshold or if the reconstruction error is zero
    int prune_wiener = (src_var < thresh) || (rsc->sse[RESTORE_NONE] == 0);
    if (prune_wiener) {
      rsc->total_bits[RESTORE_WIENER] += bits_none;
      rsc->total_sse[RESTORE_WIENER] += rsc->sse[RESTORE_NONE];
      rusi->best_rtype[RESTORE_WIENER - 1] = RESTORE_NONE;
      rsc->sse[RESTORE_WIENER] = INT64_MAX;
      if (rsc->lpf_sf->prune_sgr_based_on_wiener == 2) rsc->skip_sgr_eval = 1;
      return;
    }
  }

  const int wiener_win =
      (rsc->plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;

  int reduced_wiener_win = wiener_win;
  if (rsc->lpf_sf->reduce_wiener_window_size) {
    reduced_wiener_win =
        (rsc->plane == AOM_PLANE_Y) ? WIENER_WIN_REDUCED : WIENER_WIN_CHROMA;
  }

  int64_t M[WIENER_WIN2];
  int64_t H[WIENER_WIN2 * WIENER_WIN2];
  int32_t vfilter[WIENER_WIN], hfilter[WIENER_WIN];

#if CONFIG_AV1_HIGHBITDEPTH
  const AV1_COMMON *const cm = rsc->cm;
  if (cm->seq_params->use_highbitdepth) {
    // TODO(any): Add support for the use_downsampled_wiener_stats speed
    // feature in the HBD functions. Optimize the HBD intrinsics along the
    // lines of the LBD design (i.e., pre-calculate the d and s buffers and
    // avoid most of the C operations).
    av1_compute_stats_highbd(reduced_wiener_win, rsc->dgd_buffer,
                             rsc->src_buffer, limits->h_start, limits->h_end,
                             limits->v_start, limits->v_end, rsc->dgd_stride,
                             rsc->src_stride, M, H, cm->seq_params->bit_depth);
  } else {
    av1_compute_stats(reduced_wiener_win, rsc->dgd_buffer, rsc->src_buffer,
                      rsc->dgd_avg, rsc->src_avg, limits->h_start,
                      limits->h_end, limits->v_start, limits->v_end,
                      rsc->dgd_stride, rsc->src_stride, M, H,
                      rsc->lpf_sf->use_downsampled_wiener_stats);
  }
#else
  av1_compute_stats(reduced_wiener_win, rsc->dgd_buffer, rsc->src_buffer,
                    rsc->dgd_avg, rsc->src_avg, limits->h_start, limits->h_end,
                    limits->v_start, limits->v_end, rsc->dgd_stride,
                    rsc->src_stride, M, H,
                    rsc->lpf_sf->use_downsampled_wiener_stats);
#endif

  wiener_decompose_sep_sym(reduced_wiener_win, M, H, vfilter, hfilter);

  RestorationUnitInfo rui;
  memset(&rui, 0, sizeof(rui));
  rui.restoration_type = RESTORE_WIENER;
  finalize_sym_filter(reduced_wiener_win, vfilter, rui.wiener_info.vfilter);
  finalize_sym_filter(reduced_wiener_win, hfilter, rui.wiener_info.hfilter);

  // The filter score computes x'*H*x - 2*x'*M for the learned filter and
  // compares it against the identity filter. If there is no reduction in the
  // function value, the filter is reverted to identity.
  if (compute_score(reduced_wiener_win, M, H, rui.wiener_info.vfilter,
                    rui.wiener_info.hfilter) > 0) {
    rsc->total_bits[RESTORE_WIENER] += bits_none;
    rsc->total_sse[RESTORE_WIENER] += rsc->sse[RESTORE_NONE];
    rusi->best_rtype[RESTORE_WIENER - 1] = RESTORE_NONE;
    rsc->sse[RESTORE_WIENER] = INT64_MAX;
    if (rsc->lpf_sf->prune_sgr_based_on_wiener == 2) rsc->skip_sgr_eval = 1;
    return;
  }

  rsc->sse[RESTORE_WIENER] =
      finer_search_wiener(rsc, limits, &rui, reduced_wiener_win);
  rusi->wiener = rui.wiener_info;

  if (reduced_wiener_win != WIENER_WIN) {
    assert(rui.wiener_info.vfilter[0] == 0 &&
           rui.wiener_info.vfilter[WIENER_WIN - 1] == 0);
    assert(rui.wiener_info.hfilter[0] == 0 &&
           rui.wiener_info.hfilter[WIENER_WIN - 1] == 0);
  }

  const int64_t bits_wiener =
      x->mode_costs.wiener_restore_cost[1] +
      (count_wiener_bits(wiener_win, &rusi->wiener, &rsc->ref_wiener)
       << AV1_PROB_COST_SHIFT);

  double cost_none = RDCOST_DBL_WITH_NATIVE_BD_DIST(
      x->rdmult, bits_none >> 4, rsc->sse[RESTORE_NONE],
      rsc->cm->seq_params->bit_depth);
  double cost_wiener = RDCOST_DBL_WITH_NATIVE_BD_DIST(
      x->rdmult, bits_wiener >> 4, rsc->sse[RESTORE_WIENER],
      rsc->cm->seq_params->bit_depth);

  RestorationType rtype =
      (cost_wiener < cost_none) ? RESTORE_WIENER : RESTORE_NONE;
  rusi->best_rtype[RESTORE_WIENER - 1] = rtype;

  // Set 'skip_sgr_eval' based on rdcost ratio of RESTORE_WIENER and
  // RESTORE_NONE or based on best_rtype
  if (rsc->lpf_sf->prune_sgr_based_on_wiener == 1) {
    rsc->skip_sgr_eval = cost_wiener > (1.01 * cost_none);
  } else if (rsc->lpf_sf->prune_sgr_based_on_wiener == 2) {
    rsc->skip_sgr_eval = rusi->best_rtype[RESTORE_WIENER - 1] == RESTORE_NONE;
  }

#if DEBUG_LR_COSTING
  // Store ref params for later checking
  lr_ref_params[RESTORE_WIENER][rsc->plane][rest_unit_idx].wiener_info =
      rsc->ref_wiener;
#endif  // DEBUG_LR_COSTING

  rsc->total_sse[RESTORE_WIENER] += rsc->sse[rtype];
  rsc->total_bits[RESTORE_WIENER] +=
      (cost_wiener < cost_none) ? bits_wiener : bits_none;
  if (cost_wiener < cost_none) rsc->ref_wiener = rusi->wiener;
}

static AOM_INLINE void search_norestore(
    const RestorationTileLimits *limits, int rest_unit_idx, void *priv,
    int32_t *tmpbuf, RestorationLineBuffers *rlbs,
    struct aom_internal_error_info *error_info) {
  (void)rest_unit_idx;
  (void)tmpbuf;
  (void)rlbs;
  (void)error_info;

  RestSearchCtxt *rsc = (RestSearchCtxt *)priv;

  const int highbd = rsc->cm->seq_params->use_highbitdepth;
  rsc->sse[RESTORE_NONE] = sse_restoration_unit(
      limits, rsc->src, &rsc->cm->cur_frame->buf, rsc->plane, highbd);

  rsc->total_sse[RESTORE_NONE] += rsc->sse[RESTORE_NONE];
}

static AOM_INLINE void search_switchable(
    const RestorationTileLimits *limits, int rest_unit_idx, void *priv,
    int32_t *tmpbuf, RestorationLineBuffers *rlbs,
    struct aom_internal_error_info *error_info) {
  (void)limits;
  (void)tmpbuf;
  (void)rlbs;
  (void)error_info;
  RestSearchCtxt *rsc = (RestSearchCtxt *)priv;
  RestUnitSearchInfo *rusi = &rsc->rusi[rest_unit_idx];

  const MACROBLOCK *const x = rsc->x;

  const int wiener_win =
      (rsc->plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;

  double best_cost = 0;
  int64_t best_bits = 0;
  RestorationType best_rtype = RESTORE_NONE;

  for (RestorationType r = 0; r < RESTORE_SWITCHABLE_TYPES; ++r) {
    // If this restoration mode was skipped, or could not find a solution
    // that was better than RESTORE_NONE, then we can't select it here either.
    //
    // Note: It is possible for the restoration search functions to find a
    // filter which is better than RESTORE_NONE when looking purely at SSE, but
    // for it to be rejected overall due to its rate cost. In this case, there
    // is a chance that it may have a lower rate cost when looking at
    // RESTORE_SWITCHABLE, and so it might be acceptable here.
    //
    // Therefore we prune based on SSE, rather than on whether or not the
    // previous search function selected this mode.
    if (r > RESTORE_NONE) {
      if (rsc->sse[r] > rsc->sse[RESTORE_NONE]) continue;
    }

    const int64_t sse = rsc->sse[r];
    int64_t coeff_pcost = 0;
    switch (r) {
      case RESTORE_NONE: coeff_pcost = 0; break;
      case RESTORE_WIENER:
        coeff_pcost = count_wiener_bits(wiener_win, &rusi->wiener,
                                        &rsc->switchable_ref_wiener);
        break;
      case RESTORE_SGRPROJ:
        coeff_pcost =
            count_sgrproj_bits(&rusi->sgrproj, &rsc->switchable_ref_sgrproj);
        break;
      default: assert(0); break;
    }
    const int64_t coeff_bits = coeff_pcost << AV1_PROB_COST_SHIFT;
    const int64_t bits = x->mode_costs.switchable_restore_cost[r] + coeff_bits;
    double cost = RDCOST_DBL_WITH_NATIVE_BD_DIST(
        x->rdmult, bits >> 4, sse, rsc->cm->seq_params->bit_depth);
    if (r == RESTORE_SGRPROJ && rusi->sgrproj.ep < 10)
      cost *= (1 + DUAL_SGR_PENALTY_MULT * rsc->lpf_sf->dual_sgr_penalty_level);
    if (r == 0 || cost < best_cost) {
      best_cost = cost;
      best_bits = bits;
      best_rtype = r;
    }
  }

  rusi->best_rtype[RESTORE_SWITCHABLE - 1] = best_rtype;

#if DEBUG_LR_COSTING
  // Store ref params for later checking
  lr_ref_params[RESTORE_SWITCHABLE][rsc->plane][rest_unit_idx].wiener_info =
      rsc->switchable_ref_wiener;
  lr_ref_params[RESTORE_SWITCHABLE][rsc->plane][rest_unit_idx].sgrproj_info =
      rsc->switchable_ref_sgrproj;
#endif  // DEBUG_LR_COSTING

  rsc->total_sse[RESTORE_SWITCHABLE] += rsc->sse[best_rtype];
  rsc->total_bits[RESTORE_SWITCHABLE] += best_bits;
  if (best_rtype == RESTORE_WIENER) rsc->switchable_ref_wiener = rusi->wiener;
  if (best_rtype == RESTORE_SGRPROJ)
    rsc->switchable_ref_sgrproj = rusi->sgrproj;
}
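
// The DUAL_SGR_PENALTY_MULT factor above biases the switchable decision away
// from the dual self-guided filter configurations (ep < 10). For example, with
// dual_sgr_penalty_level == 2 the rdcost of such a unit is inflated by a
// factor of 1 + 0.01 * 2 = 1.02, so a dual-SGR choice must beat the
// alternatives by at least 2% to be selected.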

static AOM_INLINE void copy_unit_info(RestorationType frame_rtype,
                                      const RestUnitSearchInfo *rusi,
                                      RestorationUnitInfo *rui) {
  assert(frame_rtype > 0);
  rui->restoration_type = rusi->best_rtype[frame_rtype - 1];
  if (rui->restoration_type == RESTORE_WIENER)
    rui->wiener_info = rusi->wiener;
  else
    rui->sgrproj_info = rusi->sgrproj;
}

static void restoration_search(AV1_COMMON *cm, int plane, RestSearchCtxt *rsc,
                               bool *disable_lr_filter) {
  const BLOCK_SIZE sb_size = cm->seq_params->sb_size;
  const int mib_size_log2 = cm->seq_params->mib_size_log2;
  const CommonTileParams *tiles = &cm->tiles;
  const int is_uv = plane > 0;
  const int ss_y = is_uv && cm->seq_params->subsampling_y;
  RestorationInfo *rsi = &cm->rst_info[plane];
  const int ru_size = rsi->restoration_unit_size;
  const int ext_size = ru_size * 3 / 2;

  int plane_w, plane_h;
  av1_get_upsampled_plane_size(cm, is_uv, &plane_w, &plane_h);

  static const rest_unit_visitor_t funs[RESTORE_TYPES] = {
    search_norestore, search_wiener, search_sgrproj, search_switchable
  };

  const int plane_num_units = rsi->num_rest_units;
  const RestorationType num_rtypes =
      (plane_num_units > 1) ? RESTORE_TYPES : RESTORE_SWITCHABLE_TYPES;

  reset_rsc(rsc);

  // Iterate over restoration units in encoding order, so that each RU gets
  // the correct reference parameters when we cost it up. This is effectively
  // a nested iteration over:
  // * Each tile, order does not matter
  //   * Each superblock within that tile, in raster order
  //     * Each LR unit which is coded within that superblock, in raster order
  for (int tile_row = 0; tile_row < tiles->rows; tile_row++) {
    int sb_row_start = tiles->row_start_sb[tile_row];
    int sb_row_end = tiles->row_start_sb[tile_row + 1];
    for (int tile_col = 0; tile_col < tiles->cols; tile_col++) {
      int sb_col_start = tiles->col_start_sb[tile_col];
      int sb_col_end = tiles->col_start_sb[tile_col + 1];

      // Reset reference parameters for delta-coding at the start of each tile
      rsc_on_tile(rsc);

      for (int sb_row = sb_row_start; sb_row < sb_row_end; sb_row++) {
        int mi_row = sb_row << mib_size_log2;
        for (int sb_col = sb_col_start; sb_col < sb_col_end; sb_col++) {
          int mi_col = sb_col << mib_size_log2;

          int rcol0, rcol1, rrow0, rrow1;
          int has_lr_info = av1_loop_restoration_corners_in_sb(
              cm, plane, mi_row, mi_col, sb_size, &rcol0, &rcol1, &rrow0,
              &rrow1);

          if (!has_lr_info) continue;

          RestorationTileLimits limits;
          for (int rrow = rrow0; rrow < rrow1; rrow++) {
            int y0 = rrow * ru_size;
            int remaining_h = plane_h - y0;
            int h = (remaining_h < ext_size) ? remaining_h : ru_size;

            limits.v_start = y0;
            limits.v_end = y0 + h;
            assert(limits.v_end <= plane_h);
            // Offset upwards to align with the restoration processing stripe
            const int voffset = RESTORATION_UNIT_OFFSET >> ss_y;
            limits.v_start = AOMMAX(0, limits.v_start - voffset);
            if (limits.v_end < plane_h) limits.v_end -= voffset;

            for (int rcol = rcol0; rcol < rcol1; rcol++) {
              int x0 = rcol * ru_size;
              int remaining_w = plane_w - x0;
              int w = (remaining_w < ext_size) ? remaining_w : ru_size;

              limits.h_start = x0;
              limits.h_end = x0 + w;
              assert(limits.h_end <= plane_w);

              const int unit_idx = rrow * rsi->horz_units + rcol;

              rsc->skip_sgr_eval = 0;
              for (RestorationType r = RESTORE_NONE; r < num_rtypes; r++) {
                if (disable_lr_filter[r]) continue;

                funs[r](&limits, unit_idx, rsc, rsc->cm->rst_tmpbuf, NULL,
                        cm->error);
              }
            }
          }
        }
      }
    }
  }
}
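
// Unit-extent sketch (hypothetical sizes): with plane_h == 600 and
// ru_size == 256 (so ext_size == 384), a unit row starting at y0 == 256 has
// remaining_h == 344 < ext_size, so it takes h == 344 and absorbs the bottom
// remainder into one taller unit instead of leaving an 88-row sliver, while a
// row with at least ext_size remaining keeps the nominal h == ru_size.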

static INLINE void av1_derive_flags_for_lr_processing(
    const LOOP_FILTER_SPEED_FEATURES *lpf_sf, bool *disable_lr_filter) {
  const bool is_wiener_disabled = lpf_sf->disable_wiener_filter;
  const bool is_sgr_disabled = lpf_sf->disable_sgr_filter;

  // Enable the RESTORE_NONE option only if at least one of the Wiener and
  // self-guided filters is enabled.
  disable_lr_filter[RESTORE_NONE] = (is_wiener_disabled && is_sgr_disabled);

  disable_lr_filter[RESTORE_WIENER] = is_wiener_disabled;
  disable_lr_filter[RESTORE_SGRPROJ] = is_sgr_disabled;

  // Enable the switchable loop restoration mode only if both the Wiener and
  // self-guided filters are enabled.
  disable_lr_filter[RESTORE_SWITCHABLE] =
      (is_wiener_disabled || is_sgr_disabled);
}
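
// The resulting decision table (derived from the assignments above):
//
//   disable_wiener | disable_sgr | NONE    WIENER  SGRPROJ  SWITCHABLE
//   ---------------+-------------+-------------------------------------
//   false          | false       | search  search  search   search
//   false          | true        | search  search  skip     skip
//   true           | false       | search  skip    search   skip
//   true           | true        | skip    skip    skip     skip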

#define COUPLED_CHROMA_FROM_LUMA_RESTORATION 0
// Allocate both decoder-side and encoder-side info structs for a single plane.
// The unit size passed in should be the minimum size which we are going to
// search; before each search, set_restoration_unit_size() must be called to
// configure the actual size.
static RestUnitSearchInfo *allocate_search_structs(AV1_COMMON *cm,
                                                   RestorationInfo *rsi,
                                                   int is_uv,
                                                   int min_luma_unit_size) {
#if COUPLED_CHROMA_FROM_LUMA_RESTORATION
  int sx = cm->seq_params->subsampling_x;
  int sy = cm->seq_params->subsampling_y;
  int s = is_uv ? AOMMIN(sx, sy) : 0;
#else
  int s = 0;
#endif  // COUPLED_CHROMA_FROM_LUMA_RESTORATION
  int min_unit_size = min_luma_unit_size >> s;

  int plane_w, plane_h;
  av1_get_upsampled_plane_size(cm, is_uv, &plane_w, &plane_h);

  const int max_horz_units = av1_lr_count_units(min_unit_size, plane_w);
  const int max_vert_units = av1_lr_count_units(min_unit_size, plane_h);
  const int max_num_units = max_horz_units * max_vert_units;

  aom_free(rsi->unit_info);
  CHECK_MEM_ERROR(cm, rsi->unit_info,
                  (RestorationUnitInfo *)aom_memalign(
                      16, sizeof(*rsi->unit_info) * max_num_units));

  RestUnitSearchInfo *rusi;
  CHECK_MEM_ERROR(
      cm, rusi,
      (RestUnitSearchInfo *)aom_memalign(16, sizeof(*rusi) * max_num_units));

  // If the restoration unit dimensions are not multiples of
  // rsi->restoration_unit_size then some elements of the rusi array may be
  // left uninitialised when we reach copy_unit_info(...). This is not a
  // problem, as these elements are ignored later, but in order to quiet
  // Valgrind's warnings we initialise the array below.
  memset(rusi, 0, sizeof(*rusi) * max_num_units);

  return rusi;
}

static void set_restoration_unit_size(AV1_COMMON *cm, RestorationInfo *rsi,
                                      int is_uv, int luma_unit_size) {
#if COUPLED_CHROMA_FROM_LUMA_RESTORATION
  int sx = cm->seq_params->subsampling_x;
  int sy = cm->seq_params->subsampling_y;
  int s = is_uv ? AOMMIN(sx, sy) : 0;
#else
  int s = 0;
#endif  // COUPLED_CHROMA_FROM_LUMA_RESTORATION
  int unit_size = luma_unit_size >> s;

  int plane_w, plane_h;
  av1_get_upsampled_plane_size(cm, is_uv, &plane_w, &plane_h);

  const int horz_units = av1_lr_count_units(unit_size, plane_w);
  const int vert_units = av1_lr_count_units(unit_size, plane_h);

  rsi->restoration_unit_size = unit_size;
  rsi->num_rest_units = horz_units * vert_units;
  rsi->horz_units = horz_units;
  rsi->vert_units = vert_units;
}
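
// For example (hypothetical dimensions, no superres): for the chroma planes of
// a 1920x1080 4:2:0 source, av1_get_upsampled_plane_size() reports 960x540, so
// with unit_size == 256 the grid is av1_lr_count_units(256, 960) by
// av1_lr_count_units(256, 540) units and num_rest_units is their product.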

void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) {
  AV1_COMMON *const cm = &cpi->common;
  MACROBLOCK *const x = &cpi->td.mb;
  const SequenceHeader *const seq_params = cm->seq_params;
  const LOOP_FILTER_SPEED_FEATURES *lpf_sf = &cpi->sf.lpf_sf;
  const int num_planes = av1_num_planes(cm);
  const int highbd = cm->seq_params->use_highbitdepth;
  assert(!cm->features.all_lossless);

  av1_fill_lr_rates(&x->mode_costs, x->e_mbd.tile_ctx);

  // Select the unit size based on speed feature settings, and allocate
  // rusi structs based on this size
  int min_lr_unit_size = cpi->sf.lpf_sf.min_lr_unit_size;
  int max_lr_unit_size = cpi->sf.lpf_sf.max_lr_unit_size;

  // The minimum allowed unit size at the syntax level is 1 superblock.
  // Apply this constraint here so that the speed features code which sets
  // cpi->sf.lpf_sf.min_lr_unit_size does not need to know the superblock size.
  min_lr_unit_size =
      AOMMAX(min_lr_unit_size, block_size_wide[cm->seq_params->sb_size]);

  for (int plane = 0; plane < num_planes; ++plane) {
    cpi->pick_lr_ctxt.rusi[plane] = allocate_search_structs(
        cm, &cm->rst_info[plane], plane > 0, min_lr_unit_size);
  }

  x->rdmult = cpi->rd.RDMULT;

  // Allocate the frame buffer trial_frame_rst, which is used to temporarily
  // store the loop restored frame.
  if (aom_realloc_frame_buffer(
          &cpi->trial_frame_rst, cm->superres_upscaled_width,
          cm->superres_upscaled_height, seq_params->subsampling_x,
          seq_params->subsampling_y, highbd, AOM_RESTORATION_FRAME_BORDER,
          cm->features.byte_alignment, NULL, NULL, NULL, false, 0))
    aom_internal_error(cm->error, AOM_CODEC_MEM_ERROR,
                       "Failed to allocate trial restored frame buffer");

  RestSearchCtxt rsc;

  // The buffers 'src_avg' and 'dgd_avg' are used to compute the H and M
  // buffers. They are only required by the AVX2 and NEON implementations of
  // av1_compute_stats. The required buffer size is derived from the maximum
  // LRU width and height allowed for Wiener filtering, i.e. 1.5 times
  // RESTORATION_UNITSIZE_MAX (see foreach_rest_unit_in_plane()), with the
  // width and height aligned to a multiple of 16 for the intrinsics.
  rsc.dgd_avg = NULL;
  rsc.src_avg = NULL;
#if HAVE_AVX2 || HAVE_NEON
  // The buffers allocated below are used during Wiener filter processing in
  // the low bit depth path, so allocate them only when the Wiener filter is
  // enabled and the low bit depth path is in use.
  if (!cpi->sf.lpf_sf.disable_wiener_filter && !highbd) {
    const int buf_size = sizeof(*cpi->pick_lr_ctxt.dgd_avg) * 6 *
                         RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX;
    CHECK_MEM_ERROR(cm, cpi->pick_lr_ctxt.dgd_avg,
                    (int16_t *)aom_memalign(32, buf_size));

    rsc.dgd_avg = cpi->pick_lr_ctxt.dgd_avg;
    // When the LRU width isn't a multiple of 16, the 256-bit load instructions
    // used in the AVX2 intrinsics can read beyond the valid LRU data. To
    // silence the resulting Valgrind warnings, this buffer is initialized with
    // zeros. The overhead of this initialization is negligible since it is
    // done once per frame.
    memset(rsc.dgd_avg, 0, buf_size);
    rsc.src_avg =
        rsc.dgd_avg + 3 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX;
    // Assert that the starting address of src_avg is always 32-byte aligned.
    assert(!((intptr_t)rsc.src_avg % 32));
  }
#endif

  // Initialize all planes, so that any planes we skip searching will still
  // have valid data
  for (int plane = 0; plane < num_planes; plane++) {
    cm->rst_info[plane].frame_restoration_type = RESTORE_NONE;
  }

  // Decide which planes to search
  int plane_start, plane_end;

  if (lpf_sf->disable_loop_restoration_luma) {
    plane_start = AOM_PLANE_U;
  } else {
    plane_start = AOM_PLANE_Y;
  }

  if (num_planes == 1 || lpf_sf->disable_loop_restoration_chroma) {
    plane_end = AOM_PLANE_Y;
  } else {
    plane_end = AOM_PLANE_V;
  }

  // Derive the flags to enable/disable loop restoration filters based on the
  // speed features 'disable_wiener_filter' and 'disable_sgr_filter'.
  bool disable_lr_filter[RESTORE_TYPES] = { false };
  av1_derive_flags_for_lr_processing(lpf_sf, disable_lr_filter);

  for (int plane = plane_start; plane <= plane_end; plane++) {
    const YV12_BUFFER_CONFIG *dgd = &cm->cur_frame->buf;
    const int is_uv = plane != AOM_PLANE_Y;
    int plane_w, plane_h;
    av1_get_upsampled_plane_size(cm, is_uv, &plane_w, &plane_h);
    av1_extend_frame(dgd->buffers[plane], plane_w, plane_h, dgd->strides[is_uv],
                     RESTORATION_BORDER, RESTORATION_BORDER, highbd);
  }

  double best_cost = DBL_MAX;
  int best_luma_unit_size = max_lr_unit_size;
  for (int luma_unit_size = max_lr_unit_size;
       luma_unit_size >= min_lr_unit_size; luma_unit_size >>= 1) {
    int64_t bits_this_size = 0;
    int64_t sse_this_size = 0;
    RestorationType best_rtype[MAX_MB_PLANE] = { RESTORE_NONE, RESTORE_NONE,
                                                 RESTORE_NONE };
    for (int plane = plane_start; plane <= plane_end; ++plane) {
      set_restoration_unit_size(cm, &cm->rst_info[plane], plane > 0,
                                luma_unit_size);
      init_rsc(src, &cpi->common, x, lpf_sf, plane,
               cpi->pick_lr_ctxt.rusi[plane], &cpi->trial_frame_rst, &rsc);

      restoration_search(cm, plane, &rsc, disable_lr_filter);

      const int plane_num_units = cm->rst_info[plane].num_rest_units;
      const RestorationType num_rtypes =
          (plane_num_units > 1) ? RESTORE_TYPES : RESTORE_SWITCHABLE_TYPES;
      double best_cost_this_plane = DBL_MAX;
      for (RestorationType r = 0; r < num_rtypes; ++r) {
        // Skip the restoration types disabled by the flags derived from the
        // speed features 'disable_wiener_filter' and 'disable_sgr_filter'.
        if (disable_lr_filter[r]) continue;

        double cost_this_plane = RDCOST_DBL_WITH_NATIVE_BD_DIST(
            x->rdmult, rsc.total_bits[r] >> 4, rsc.total_sse[r],
            cm->seq_params->bit_depth);

        if (cost_this_plane < best_cost_this_plane) {
          best_cost_this_plane = cost_this_plane;
          best_rtype[plane] = r;
        }
      }

      bits_this_size += rsc.total_bits[best_rtype[plane]];
      sse_this_size += rsc.total_sse[best_rtype[plane]];
    }

    double cost_this_size = RDCOST_DBL_WITH_NATIVE_BD_DIST(
        x->rdmult, bits_this_size >> 4, sse_this_size,
        cm->seq_params->bit_depth);

    if (cost_this_size < best_cost) {
      best_cost = cost_this_size;
      best_luma_unit_size = luma_unit_size;
      // Copy parameters out of the rusi struct, before we overwrite it at
      // the start of the next iteration
      bool all_none = true;
      for (int plane = plane_start; plane <= plane_end; ++plane) {
        cm->rst_info[plane].frame_restoration_type = best_rtype[plane];
        if (best_rtype[plane] != RESTORE_NONE) {
          all_none = false;
          const int plane_num_units = cm->rst_info[plane].num_rest_units;
          for (int u = 0; u < plane_num_units; ++u) {
            copy_unit_info(best_rtype[plane], &cpi->pick_lr_ctxt.rusi[plane][u],
                           &cm->rst_info[plane].unit_info[u]);
          }
        }
      }
      // Heuristic: If all best_rtype entries are RESTORE_NONE, this means we
      // couldn't find any good filters at this size. So we likely won't find
      // any good filters at a smaller size either, so skip
      if (all_none) {
        break;
      }
    } else {
      // Heuristic: If this size is worse than the previous (larger) size, then
      // the next size down will likely be even worse, so skip
      break;
    }
  }

  // Final fixup to set the correct unit size.
  // We set this for all planes, even ones we have skipped searching,
  // so that other code does not need to care which planes were and weren't
  // searched
  for (int plane = 0; plane < num_planes; ++plane) {
    set_restoration_unit_size(cm, &cm->rst_info[plane], plane > 0,
                              best_luma_unit_size);
  }

#if HAVE_AVX2 || HAVE_NEON
  // This condition must match the allocation condition above.
  if (!cpi->sf.lpf_sf.disable_wiener_filter && !highbd) {
    aom_free(cpi->pick_lr_ctxt.dgd_avg);
    cpi->pick_lr_ctxt.dgd_avg = NULL;
  }
#endif
  for (int plane = 0; plane < num_planes; plane++) {
    aom_free(cpi->pick_lr_ctxt.rusi[plane]);
    cpi->pick_lr_ctxt.rusi[plane] = NULL;
  }
}