/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include "av1/common/pred_common.h"
#include "av1/encoder/compound_type.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/motion_search_facade.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/reconinter_enc.h"
#include "av1/encoder/tx_search.h"

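// Signature shared by pick_interinter_wedge() and pick_interinter_seg() so
// the mask search can be dispatched through a function pointer. Returns the
// best modeled rd cost and reports the corresponding sse via *best_sse.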
typedef int64_t (*pick_interinter_mask_type)(
    const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
    const uint8_t *const p0, const uint8_t *const p1,
    const int16_t *const residual1, const int16_t *const diff10,
    uint64_t *best_sse);

// Checks if the characteristics of the current compound search match those
// of the given stored search record.
static INLINE int is_comp_rd_match(const AV1_COMP *const cpi,
                                   const MACROBLOCK *const x,
                                   const COMP_RD_STATS *st,
                                   const MB_MODE_INFO *const mi,
                                   int32_t *comp_rate, int64_t *comp_dist,
                                   int32_t *comp_model_rate,
                                   int64_t *comp_model_dist, int *comp_rs2) {
  // TODO(ranjit): Ensure that the compound type search always uses the
  // regular filter, and check whether the following check can be removed.
  // Check if the interp filter matches the previous case.
  if (st->filter.as_int != mi->interp_filters.as_int) return 0;

  const MACROBLOCKD *const xd = &x->e_mbd;
  // Match MV and reference indices
  for (int i = 0; i < 2; ++i) {
    if ((st->ref_frames[i] != mi->ref_frame[i]) ||
        (st->mv[i].as_int != mi->mv[i].as_int)) {
      return 0;
    }
    const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
    if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
  }

  // Store the stats for COMPOUND_AVERAGE and COMPOUND_DISTWTD
  for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
       comp_type++) {
    comp_rate[comp_type] = st->rate[comp_type];
    comp_dist[comp_type] = st->dist[comp_type];
    comp_model_rate[comp_type] = st->model_rate[comp_type];
    comp_model_dist[comp_type] = st->model_dist[comp_type];
    comp_rs2[comp_type] = st->comp_rs2[comp_type];
  }

  // For compound wedge/segment, reuse data only if NEWMV is not present in
  // either of the directions
  if ((!have_newmv_in_inter_mode(mi->mode) &&
       !have_newmv_in_inter_mode(st->mode)) ||
      (cpi->sf.inter_sf.disable_interinter_wedge_newmv_search)) {
    memcpy(&comp_rate[COMPOUND_WEDGE], &st->rate[COMPOUND_WEDGE],
           sizeof(comp_rate[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_dist[COMPOUND_WEDGE], &st->dist[COMPOUND_WEDGE],
           sizeof(comp_dist[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_model_rate[COMPOUND_WEDGE], &st->model_rate[COMPOUND_WEDGE],
           sizeof(comp_model_rate[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_model_dist[COMPOUND_WEDGE], &st->model_dist[COMPOUND_WEDGE],
           sizeof(comp_model_dist[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_rs2[COMPOUND_WEDGE], &st->comp_rs2[COMPOUND_WEDGE],
           sizeof(comp_rs2[COMPOUND_WEDGE]) * 2);
  }
  return 1;
}

// Checks if a similar compound type search case was evaluated earlier.
// If found, returns the relevant rd data via the output arguments.
static INLINE int find_comp_rd_in_stats(const AV1_COMP *const cpi,
                                        const MACROBLOCK *x,
                                        const MB_MODE_INFO *const mbmi,
                                        int32_t *comp_rate, int64_t *comp_dist,
                                        int32_t *comp_model_rate,
                                        int64_t *comp_model_dist, int *comp_rs2,
                                        int *match_index) {
  for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
    if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
                         comp_dist, comp_model_rate, comp_model_dist,
                         comp_rs2)) {
      *match_index = j;
      return 1;
    }
  }
  return 0;  // no match found
}

static INLINE bool enable_wedge_search(MACROBLOCK *const x,
                                       const AV1_COMP *const cpi) {
  // Enable wedge search if source variance and edge strength are above
  // the thresholds.
  return x->source_variance >
             cpi->sf.inter_sf.disable_wedge_search_var_thresh &&
         x->edge_strength > cpi->sf.inter_sf.disable_wedge_search_edge_thresh;
}

static INLINE bool enable_wedge_interinter_search(MACROBLOCK *const x,
                                                  const AV1_COMP *const cpi) {
  return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interinter_wedge &&
         !cpi->sf.inter_sf.disable_interinter_wedge;
}

static INLINE bool enable_wedge_interintra_search(MACROBLOCK *const x,
                                                  const AV1_COMP *const cpi) {
  return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interintra_wedge &&
         !cpi->sf.inter_sf.disable_wedge_interintra_search;
}

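// Estimates the sign of the wedge partition by measuring, for the top-left
// and bottom-right quadrants, which of the two predictors better matches the
// source block.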
static int8_t estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
                                  const BLOCK_SIZE bsize, const uint8_t *pred0,
                                  int stride0, const uint8_t *pred1,
                                  int stride1) {
  static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
    //                            4X4
    BLOCK_INVALID,
    // 4X8,        8X4,           8X8
    BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
    // 8X16,       16X8,          16X16
    BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
    // 16X32,      32X16,         32X32
    BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
    // 32X64,      64X32,         64X64
    BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
    // 64x128,     128x64,        128x128
    BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
    // 4X16,       16X4,          8X32
    BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
    // 32X8,       16X64,         64X16
    BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
  };
  const struct macroblock_plane *const p = &x->plane[0];
  const uint8_t *src = p->src.buf;
  int src_stride = p->src.stride;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int bw_by2 = bw >> 1;
  const int bh_by2 = bh >> 1;
  uint32_t esq[2][2];
  int64_t tl, br;

  const BLOCK_SIZE f_index = split_qtr[bsize];
  assert(f_index != BLOCK_INVALID);

  if (is_cur_buf_hbd(&x->e_mbd)) {
    pred0 = CONVERT_TO_BYTEPTR(pred0);
    pred1 = CONVERT_TO_BYTEPTR(pred1);
  }

  // Residual variance computation over relevant quadrants in order to
  // find TL + BR, TL = sum(1st,2nd,3rd) quadrants of (pred0 - pred1),
  // BR = sum(2nd,3rd,4th) quadrants of (pred1 - pred0)
  // The 2nd and 3rd quadrants cancel out in TL + BR
  // Hence TL + BR = 1st quadrant of (pred0-pred1) + 4th of (pred1-pred0)
  // TODO(nithya): Sign estimation assumes 45 degrees (1st and 4th quadrants)
  // for all codebooks; experiment with other quadrant combinations for
  // 0, 90 and 135 degrees also.
  cpi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
  cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
                          pred0 + bh_by2 * stride0 + bw_by2, stride0,
                          &esq[0][1]);
  cpi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
  cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
                          pred1 + bh_by2 * stride1 + bw_by2, stride1,
                          &esq[1][1]);

  tl = ((int64_t)esq[0][0]) - ((int64_t)esq[1][0]);
  br = ((int64_t)esq[1][1]) - ((int64_t)esq[0][1]);
  return (tl + br > 0);
}

// Choose the best wedge index and sign
static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
                          const BLOCK_SIZE bsize, const uint8_t *const p0,
                          const int16_t *const residual1,
                          const int16_t *const diff10,
                          int8_t *const best_wedge_sign,
                          int8_t *const best_wedge_index, uint64_t *best_sse) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const struct buf_2d *const src = &x->plane[0].src;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = bw * bh;
  assert(N >= 64);
  int rate;
  int64_t dist;
  int64_t rd, best_rd = INT64_MAX;
  int8_t wedge_index;
  int8_t wedge_sign;
  const int8_t wedge_types = get_wedge_types_lookup(bsize);
  const uint8_t *mask;
  uint64_t sse;
  const int hbd = is_cur_buf_hbd(xd);
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;

  DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]);  // src - pred0
#if CONFIG_AV1_HIGHBITDEPTH
  if (hbd) {
    aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
  }
#else
  (void)hbd;
  aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
#endif

  int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
                        (int64_t)aom_sum_squares_i16(residual1, N)) *
                       (1 << WEDGE_WEIGHT_BITS) / 2;
  int16_t *ds = residual0;

  av1_wedge_compute_delta_squares(ds, residual0, residual1, N);

  for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
    mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);

    wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);

    mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
    sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    // int rate2;
    // int64_t dist2;
    // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate2, &dist2);
    // printf("sse %"PRId64": legacy: %d %"PRId64", curvfit %d %"PRId64"\n",
    //        sse, rate, dist, rate2, dist2);
    // dist = dist2;
    // rate = rate2;

    rate += x->wedge_idx_cost[bsize][wedge_index];
    rd = RDCOST(x->rdmult, rate, dist);

    if (rd < best_rd) {
      *best_wedge_index = wedge_index;
      *best_wedge_sign = wedge_sign;
      best_rd = rd;
      *best_sse = sse;
    }
  }

  return best_rd -
         RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
}

// Choose the best wedge index for the specified sign
static int64_t pick_wedge_fixed_sign(
    const AV1_COMP *const cpi, const MACROBLOCK *const x,
    const BLOCK_SIZE bsize, const int16_t *const residual1,
    const int16_t *const diff10, const int8_t wedge_sign,
    int8_t *const best_wedge_index, uint64_t *best_sse) {
  const MACROBLOCKD *const xd = &x->e_mbd;

  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = bw * bh;
  assert(N >= 64);
  int rate;
  int64_t dist;
  int64_t rd, best_rd = INT64_MAX;
  int8_t wedge_index;
  const int8_t wedge_types = get_wedge_types_lookup(bsize);
  const uint8_t *mask;
  uint64_t sse;
  const int hbd = is_cur_buf_hbd(xd);
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
  for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
    mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
    sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    rate += x->wedge_idx_cost[bsize][wedge_index];
    rd = RDCOST(x->rdmult, rate, dist);

    if (rd < best_rd) {
      *best_wedge_index = wedge_index;
      best_rd = rd;
      *best_sse = sse;
    }
  }
  return best_rd -
         RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
}

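// Chooses the best wedge (index and sign) for inter-inter compound
// prediction and records the choice in mbmi->interinter_comp.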
static int64_t pick_interinter_wedge(
    const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
    const uint8_t *const p0, const uint8_t *const p1,
    const int16_t *const residual1, const int16_t *const diff10,
    uint64_t *best_sse) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int bw = block_size_wide[bsize];

  int64_t rd;
  int8_t wedge_index = -1;
  int8_t wedge_sign = 0;

  assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
  assert(cpi->common.seq_params.enable_masked_compound);

  if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
    wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
    rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
                               &wedge_index, best_sse);
  } else {
    rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
                    &wedge_index, best_sse);
  }

  mbmi->interinter_comp.wedge_sign = wedge_sign;
  mbmi->interinter_comp.wedge_index = wedge_index;
  return rd;
}

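// Chooses the best DIFFWTD mask type (including its inverse) for inter-inter
// compound prediction and records it in mbmi->interinter_comp.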
static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
                                   MACROBLOCK *const x, const BLOCK_SIZE bsize,
                                   const uint8_t *const p0,
                                   const uint8_t *const p1,
                                   const int16_t *const residual1,
                                   const int16_t *const diff10,
                                   uint64_t *best_sse) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = 1 << num_pels_log2_lookup[bsize];
  int rate;
  int64_t dist;
  DIFFWTD_MASK_TYPE cur_mask_type;
  int64_t best_rd = INT64_MAX;
  DIFFWTD_MASK_TYPE best_mask_type = 0;
  const int hbd = is_cur_buf_hbd(xd);
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
  DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
  uint8_t *tmp_mask[2] = { xd->seg_mask, seg_mask };
  // try each mask type and its inverse
  for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) {
    // build mask and inverse
    if (hbd)
      av1_build_compound_diffwtd_mask_highbd(
          tmp_mask[cur_mask_type], cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
          CONVERT_TO_BYTEPTR(p1), bw, bh, bw, xd->bd);
    else
      av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type,
                                      p0, bw, p1, bw, bh, bw);

    // compute rd for mask
    uint64_t sse = av1_wedge_sse_from_residuals(residual1, diff10,
                                                tmp_mask[cur_mask_type], N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    const int64_t rd0 = RDCOST(x->rdmult, rate, dist);

    if (rd0 < best_rd) {
      best_mask_type = cur_mask_type;
      best_rd = rd0;
      *best_sse = sse;
    }
  }
  mbmi->interinter_comp.mask_type = best_mask_type;
  if (best_mask_type == DIFFWTD_38_INV) {
    memcpy(xd->seg_mask, seg_mask, N * 2);
  }
  return best_rd;
}

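// Chooses the best wedge index for interintra prediction; the wedge sign is
// fixed at 0 for interintra. Records the choice in
// mbmi->interintra_wedge_index.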
static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
                                     const MACROBLOCK *const x,
                                     const BLOCK_SIZE bsize,
                                     const uint8_t *const p0,
                                     const uint8_t *const p1) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(av1_is_wedge_used(bsize));
  assert(cpi->common.seq_params.enable_interintra_compound);

  const struct buf_2d *const src = &x->plane[0].src;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]);  // src - pred1
  DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]);     // pred1 - pred0
#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(p1), bw, xd->bd);
    aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
    aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
  }
#else
  aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
  aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
#endif
  int8_t wedge_index = -1;
  uint64_t sse;
  int64_t rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0,
                                     &wedge_index, &sse);

  mbmi->interintra_wedge_index = wedge_index;
  return rd;
}

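// Builds the two single-reference predictors for the block and computes the
// residual and predictor-difference buffers consumed by the mask searches.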
static AOM_INLINE void get_inter_predictors_masked_compound(
    MACROBLOCK *x, const BLOCK_SIZE bsize, uint8_t **preds0, uint8_t **preds1,
    int16_t *residual1, int16_t *diff10, int *strides) {
  MACROBLOCKD *xd = &x->e_mbd;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  // get inter predictors to use for masked compound modes
  av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0, preds0,
                                                   strides);
  av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1, preds1,
                                                   strides);
  const struct buf_2d *const src = &x->plane[0].src;
#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(*preds1), bw, xd->bd);
    aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
                              bw, CONVERT_TO_BYTEPTR(*preds0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
                       bw);
    aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
  }
#else
  aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1, bw);
  aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
#endif
}

// Computes the rd cost for the given interintra mode and updates the best
// mode and rd so far if the cost is lower.
static INLINE void compute_best_interintra_mode(
    const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
    MACROBLOCK *const x, const int *const interintra_mode_cost,
    const BUFFER_SET *orig_dst, uint8_t *intrapred, const uint8_t *tmp_buf,
    INTERINTRA_MODE *best_interintra_mode, int64_t *best_interintra_rd,
    INTERINTRA_MODE interintra_mode, BLOCK_SIZE bsize) {
  const AV1_COMMON *const cm = &cpi->common;
  int rate, skip_txfm_sb;
  int64_t dist, skip_sse_sb;
  const int bw = block_size_wide[bsize];
  mbmi->interintra_mode = interintra_mode;
  int rmode = interintra_mode_cost[interintra_mode];
  av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                            intrapred, bw);
  av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
  model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](cpi, bsize, x, xd, 0, 0, &rate, &dist,
                                          &skip_txfm_sb, &skip_sse_sb, NULL,
                                          NULL, NULL);
  int64_t rd = RDCOST(x->rdmult, rate + rmode, dist);
  if (rd < *best_interintra_rd) {
    *best_interintra_rd = rd;
    *best_interintra_mode = mbmi->interintra_mode;
  }
}

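// Estimates the luma rd cost of the block with a uniform transform size
// search, adding the cost of signaling the skip flag. Returns INT64_MAX if
// the reference rd budget is negative or is exceeded during the search.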
static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
                                   MACROBLOCK *x, int64_t ref_best_rd,
                                   RD_STATS *rd_stats) {
  MACROBLOCKD *const xd = &x->e_mbd;
  if (ref_best_rd < 0) return INT64_MAX;
  av1_subtract_plane(x, bs, 0);
  x->rd_model = LOW_TXFM_RD;
  const int skip_trellis = (cpi->optimize_seg_arr[xd->mi[0]->segment_id] ==
                            NO_ESTIMATE_YRD_TRELLIS_OPT);
  const int64_t rd =
      av1_uniform_txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs,
                           max_txsize_rect_lookup[bs], FTXS_NONE, skip_trellis);
  x->rd_model = FULL_TXFM_RD;
  if (rd != INT64_MAX) {
    const int skip_ctx = av1_get_skip_context(xd);
    if (rd_stats->skip) {
      const int s1 = x->skip_cost[skip_ctx][1];
      rd_stats->rate = s1;
    } else {
      const int s0 = x->skip_cost[skip_ctx][0];
      rd_stats->rate += s0;
    }
  }
  return rd;
}

// Computes the rd_threshold for smooth interintra rd search.
static AOM_INLINE int64_t compute_rd_thresh(MACROBLOCK *const x,
                                            int total_mode_rate,
                                            int64_t ref_best_rd) {
  const int64_t rd_thresh = get_rd_thresh_from_best_rd(
      ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
      INTER_INTRA_RD_THRESH_SCALE);
  const int64_t mode_rd = RDCOST(x->rdmult, total_mode_rate, 0);
  return (rd_thresh - mode_rd);
}

// Computes the best wedge interintra mode
static AOM_INLINE int64_t compute_best_wedge_interintra(
    const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
    MACROBLOCK *const x, const int *const interintra_mode_cost,
    const BUFFER_SET *orig_dst, uint8_t *intrapred_, uint8_t *tmp_buf_,
    int *best_mode, int *best_wedge_index, BLOCK_SIZE bsize) {
  const AV1_COMMON *const cm = &cpi->common;
  const int bw = block_size_wide[bsize];
  int64_t best_interintra_rd_wedge = INT64_MAX;
  int64_t best_total_rd = INT64_MAX;
  uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
  for (INTERINTRA_MODE mode = 0; mode < INTERINTRA_MODES; ++mode) {
    mbmi->interintra_mode = mode;
    av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                              intrapred, bw);
    int64_t rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
    const int rate_overhead =
        interintra_mode_cost[mode] +
        x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index];
    const int64_t total_rd = rd + RDCOST(x->rdmult, rate_overhead, 0);
    if (total_rd < best_total_rd) {
      best_total_rd = total_rd;
      best_interintra_rd_wedge = rd;
      *best_mode = mbmi->interintra_mode;
      *best_wedge_index = mbmi->interintra_wedge_index;
    }
  }
  return best_interintra_rd_wedge;
}

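// Searches the smooth and wedge interintra modes for the current block and
// updates mbmi with the winner. Returns 0 on success and -1 if interintra
// should not be used (or no mode survives the rd thresholds).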
int av1_handle_inter_intra_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
                                BLOCK_SIZE bsize, MB_MODE_INFO *mbmi,
                                HandleInterModeArgs *args, int64_t ref_best_rd,
                                int *rate_mv, int *tmp_rate2,
                                const BUFFER_SET *orig_dst) {
  const int try_smooth_interintra = cpi->oxcf.enable_smooth_interintra &&
                                    !cpi->sf.inter_sf.disable_smooth_interintra;
  const int is_wedge_used = av1_is_wedge_used(bsize);
  const int try_wedge_interintra =
      is_wedge_used && enable_wedge_interintra_search(x, cpi);
  if (!try_smooth_interintra && !try_wedge_interintra) return -1;

  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  int64_t rd = INT64_MAX;
  const int bw = block_size_wide[bsize];
  DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
  DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
  uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
  uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
  const int *const interintra_mode_cost =
      x->interintra_mode_cost[size_group_lookup[bsize]];
  const int mi_row = xd->mi_row;
  const int mi_col = xd->mi_col;

  // Single reference inter prediction
  mbmi->ref_frame[1] = NONE_FRAME;
  xd->plane[0].dst.buf = tmp_buf;
  xd->plane[0].dst.stride = bw;
  av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
                                AOM_PLANE_Y, AOM_PLANE_Y);
  const int num_planes = av1_num_planes(cm);

  // Restore the buffers for intra prediction
  restore_dst_buf(xd, *orig_dst, num_planes);
  mbmi->ref_frame[1] = INTRA_FRAME;
  INTERINTRA_MODE best_interintra_mode =
      args->inter_intra_mode[mbmi->ref_frame[0]];

  // Compute smooth_interintra
  int64_t best_interintra_rd_nowedge = INT64_MAX;
  int best_mode_rate = INT_MAX;
  if (try_smooth_interintra) {
    mbmi->use_wedge_interintra = 0;
    int interintra_mode_reuse = 1;
    if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
        best_interintra_mode == INTERINTRA_MODES) {
      interintra_mode_reuse = 0;
      int64_t best_interintra_rd = INT64_MAX;
      for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
           ++cur_mode) {
        if ((!cpi->oxcf.enable_smooth_intra ||
             cpi->sf.intra_sf.disable_smooth_intra) &&
            cur_mode == II_SMOOTH_PRED)
          continue;
        compute_best_interintra_mode(cpi, mbmi, xd, x, interintra_mode_cost,
                                     orig_dst, intrapred, tmp_buf,
                                     &best_interintra_mode, &best_interintra_rd,
                                     cur_mode, bsize);
      }
      args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
    }
    assert(IMPLIES(!cpi->oxcf.enable_smooth_interintra ||
                       cpi->sf.inter_sf.disable_smooth_interintra,
                   best_interintra_mode != II_SMOOTH_PRED));
    // Recompute prediction if required
    if (interintra_mode_reuse || best_interintra_mode != INTERINTRA_MODES - 1) {
      mbmi->interintra_mode = best_interintra_mode;
      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                intrapred, bw);
      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
    }

    // Compute rd cost for best smooth_interintra
    RD_STATS rd_stats;
    const int rmode = interintra_mode_cost[best_interintra_mode] +
                      (is_wedge_used ? x->wedge_interintra_cost[bsize][0] : 0);
    const int total_mode_rate = rmode + *rate_mv;
    const int64_t rd_thresh =
        compute_rd_thresh(x, total_mode_rate, ref_best_rd);
    rd = estimate_yrd_for_sb(cpi, bsize, x, rd_thresh, &rd_stats);
    if (rd != INT64_MAX) {
      rd = RDCOST(x->rdmult, total_mode_rate + rd_stats.rate, rd_stats.dist);
    } else {
      return -1;
    }
    best_interintra_rd_nowedge = rd;
    best_mode_rate = rmode;
    // Return early if best_interintra_rd_nowedge is not good enough
    if (ref_best_rd < INT64_MAX &&
        (best_interintra_rd_nowedge >> INTER_INTRA_RD_THRESH_SHIFT) *
                INTER_INTRA_RD_THRESH_SCALE >
            ref_best_rd) {
      return -1;
    }
  }

  // Compute wedge interintra
  int64_t best_interintra_rd_wedge = INT64_MAX;
  if (try_wedge_interintra) {
    mbmi->use_wedge_interintra = 1;
    if (!cpi->sf.inter_sf.fast_interintra_wedge_search) {
      // Exhaustive search of all wedge and mode combinations.
      int best_mode = 0;
      int best_wedge_index = 0;
      best_interintra_rd_wedge = compute_best_wedge_interintra(
          cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred_,
          tmp_buf_, &best_mode, &best_wedge_index, bsize);
      mbmi->interintra_mode = best_mode;
      mbmi->interintra_wedge_index = best_wedge_index;
      if (best_mode != INTERINTRA_MODES - 1) {
        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                  intrapred, bw);
      }
    } else if (!try_smooth_interintra) {
      if (best_interintra_mode == INTERINTRA_MODES) {
        mbmi->interintra_mode = INTERINTRA_MODES - 1;
        best_interintra_mode = INTERINTRA_MODES - 1;
        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                  intrapred, bw);
        // Pick wedge mask based on INTERINTRA_MODES - 1
        best_interintra_rd_wedge =
            pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
        // Find the best interintra mode for the chosen wedge mask
        for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
             ++cur_mode) {
          compute_best_interintra_mode(
              cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred,
              tmp_buf, &best_interintra_mode, &best_interintra_rd_wedge,
              cur_mode, bsize);
        }
        args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
        mbmi->interintra_mode = best_interintra_mode;

        // Recompute prediction if required
        if (best_interintra_mode != INTERINTRA_MODES - 1) {
          av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                    intrapred, bw);
        }
      } else {
        // Pick wedge mask for the best interintra mode (reused)
        mbmi->interintra_mode = best_interintra_mode;
        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                  intrapred, bw);
        best_interintra_rd_wedge =
            pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
      }
    } else {
      // Pick wedge mask for the best interintra mode from smooth_interintra
      best_interintra_rd_wedge =
          pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
    }

    const int rate_overhead =
        interintra_mode_cost[mbmi->interintra_mode] +
        x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index] +
        x->wedge_interintra_cost[bsize][1];
    best_interintra_rd_wedge += RDCOST(x->rdmult, rate_overhead + *rate_mv, 0);

    const int_mv mv0 = mbmi->mv[0];
    int_mv tmp_mv = mv0;
    rd = INT64_MAX;
    int tmp_rate_mv = 0;
    // Refine motion vector for NEWMV case.
    if (have_newmv_in_inter_mode(mbmi->mode)) {
      int rate_sum, skip_txfm_sb;
      int64_t dist_sum, skip_sse_sb;
      // get negative of mask
      const uint8_t *mask =
          av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
      av1_compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, intrapred,
                                        mask, bw, &tmp_rate_mv, 0);
      if (mbmi->mv[0].as_int != tmp_mv.as_int) {
        mbmi->mv[0].as_int = tmp_mv.as_int;
        // Set ref_frame[1] to NONE_FRAME temporarily so that the intra
        // predictor is not calculated again in av1_enc_build_inter_predictor().
        mbmi->ref_frame[1] = NONE_FRAME;
        av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                      AOM_PLANE_Y, AOM_PLANE_Y);
        mbmi->ref_frame[1] = INTRA_FRAME;
        av1_combine_interintra(xd, bsize, 0, xd->plane[AOM_PLANE_Y].dst.buf,
                               xd->plane[AOM_PLANE_Y].dst.stride, intrapred,
                               bw);
        model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
            cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &skip_txfm_sb,
            &skip_sse_sb, NULL, NULL, NULL);
        rd =
            RDCOST(x->rdmult, tmp_rate_mv + rate_overhead + rate_sum, dist_sum);
      }
    }
    if (rd >= best_interintra_rd_wedge) {
      tmp_mv.as_int = mv0.as_int;
      tmp_rate_mv = *rate_mv;
      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
    }
    // Evaluate closer to true rd
    RD_STATS rd_stats;
    const int64_t mode_rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv, 0);
    const int64_t tmp_rd_thresh = best_interintra_rd_nowedge - mode_rd;
    rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
    if (rd != INT64_MAX) {
      rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv + rd_stats.rate,
                  rd_stats.dist);
    } else {
      if (best_interintra_rd_nowedge == INT64_MAX) return -1;
    }
    best_interintra_rd_wedge = rd;
    if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
      mbmi->mv[0].as_int = tmp_mv.as_int;
      *tmp_rate2 += tmp_rate_mv - *rate_mv;
      *rate_mv = tmp_rate_mv;
      best_mode_rate = rate_overhead;
    } else {
      mbmi->use_wedge_interintra = 0;
      mbmi->interintra_mode = best_interintra_mode;
      mbmi->mv[0].as_int = mv0.as_int;
      av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                    AOM_PLANE_Y, AOM_PLANE_Y);
    }
  }

  if (best_interintra_rd_nowedge == INT64_MAX &&
      best_interintra_rd_wedge == INT64_MAX) {
    return -1;
  }

  *tmp_rate2 += best_mode_rate;

  if (num_planes > 1) {
    av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                  AOM_PLANE_U, num_planes - 1);
  }
  return 0;
}

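// Allocates the compound type rd buffers without checking whether the
// allocations succeeded (hence the "no_check" suffix).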
static void alloc_compound_type_rd_buffers_no_check(
    CompoundTypeRdBuffers *const bufs) {
  bufs->pred0 =
      (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred0));
  bufs->pred1 =
      (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred1));
  bufs->residual1 =
      (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->residual1));
  bufs->diff10 =
      (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->diff10));
  bufs->tmp_best_mask_buf = (uint8_t *)aom_malloc(
      2 * MAX_SB_SQUARE * sizeof(*bufs->tmp_best_mask_buf));
}

// Computes the list of valid compound types to be evaluated and returns
// their count.
static INLINE int compute_valid_comp_types(
    MACROBLOCK *x, const AV1_COMP *const cpi, int *try_average_and_distwtd_comp,
    BLOCK_SIZE bsize, int masked_compound_used, int mode_search_mask,
    COMPOUND_TYPE *valid_comp_types) {
  const AV1_COMMON *cm = &cpi->common;
  int valid_type_count = 0;
  int comp_type, valid_check;
  int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };

  const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
  const int try_distwtd_comp =
      ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
       cm->seq_params.order_hint_info.enable_dist_wtd_comp == 1 &&
       cpi->sf.inter_sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
  *try_average_and_distwtd_comp = try_average_comp && try_distwtd_comp;

  // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
  for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
       comp_type++) {
    valid_check =
        (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
    if (!*try_average_and_distwtd_comp && valid_check &&
        is_interinter_compound_used(comp_type, bsize))
      valid_comp_types[valid_type_count++] = comp_type;
  }
  // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
  if (masked_compound_used) {
    // enable_masked_type[0] corresponds to COMPOUND_WEDGE
    // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
    enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
    enable_masked_type[1] = cpi->oxcf.enable_diff_wtd_comp;
    for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
         comp_type++) {
      if ((mode_search_mask & (1 << comp_type)) &&
          is_interinter_compound_used(comp_type, bsize) &&
          enable_masked_type[comp_type - COMPOUND_WEDGE])
        valid_comp_types[valid_type_count++] = comp_type;
    }
  }
  return valid_type_count;
}

// Calculates the signaling cost for each compound type
static INLINE void calc_masked_type_cost(MACROBLOCK *x, BLOCK_SIZE bsize,
                                         int comp_group_idx_ctx,
                                         int comp_index_ctx,
                                         int masked_compound_used,
                                         int *masked_type_cost) {
  av1_zero_array(masked_type_cost, COMPOUND_TYPES);
  // Account for group index cost when wedge and/or diffwtd prediction are
  // enabled
  if (masked_compound_used) {
    // Compound group index of average and distwtd is 0
    // Compound group index of wedge and diffwtd is 1
    masked_type_cost[COMPOUND_AVERAGE] +=
        x->comp_group_idx_cost[comp_group_idx_ctx][0];
    masked_type_cost[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_AVERAGE];
    masked_type_cost[COMPOUND_WEDGE] +=
        x->comp_group_idx_cost[comp_group_idx_ctx][1];
    masked_type_cost[COMPOUND_DIFFWTD] += masked_type_cost[COMPOUND_WEDGE];
  }

  // Compute the cost to signal compound index/type
  masked_type_cost[COMPOUND_AVERAGE] += x->comp_idx_cost[comp_index_ctx][1];
  masked_type_cost[COMPOUND_DISTWTD] += x->comp_idx_cost[comp_index_ctx][0];
  masked_type_cost[COMPOUND_WEDGE] += x->compound_type_cost[bsize][0];
  masked_type_cost[COMPOUND_DIFFWTD] += x->compound_type_cost[bsize][1];
}

// Updates mbmi structure with the relevant compound type info
static INLINE void update_mbmi_for_compound_type(MB_MODE_INFO *mbmi,
                                                 COMPOUND_TYPE cur_type) {
  mbmi->interinter_comp.type = cur_type;
  mbmi->comp_group_idx = (cur_type >= COMPOUND_WEDGE);
  mbmi->compound_idx = (cur_type != COMPOUND_DISTWTD);
}

// When a match is found, populates the compound type data, calculates the
// rd cost using the stored stats, and updates the mbmi appropriately.
static INLINE int populate_reuse_comp_type_data(
    const MACROBLOCK *x, MB_MODE_INFO *mbmi,
    BEST_COMP_TYPE_STATS *best_type_stats, int_mv *cur_mv, int32_t *comp_rate,
    int64_t *comp_dist, int *comp_rs2, int *rate_mv, int64_t *rd,
    int match_index) {
  const int winner_comp_type =
      x->comp_rd_stats[match_index].interinter_comp.type;
  if (comp_rate[winner_comp_type] == INT_MAX)
    return best_type_stats->best_compmode_interinter_cost;
  update_mbmi_for_compound_type(mbmi, winner_comp_type);
  mbmi->interinter_comp = x->comp_rd_stats[match_index].interinter_comp;
  *rd = RDCOST(
      x->rdmult,
      comp_rs2[winner_comp_type] + *rate_mv + comp_rate[winner_comp_type],
      comp_dist[winner_comp_type]);
  mbmi->mv[0].as_int = cur_mv[0].as_int;
  mbmi->mv[1].as_int = cur_mv[1].as_int;
  return comp_rs2[winner_comp_type];
}

// Updates rd cost and relevant compound type data for the best compound type
static INLINE void update_best_info(const MB_MODE_INFO *const mbmi, int64_t *rd,
                                    BEST_COMP_TYPE_STATS *best_type_stats,
                                    int64_t best_rd_cur,
                                    int64_t comp_model_rd_cur, int rs2) {
  *rd = best_rd_cur;
  best_type_stats->comp_best_model_rd = comp_model_rd_cur;
  best_type_stats->best_compound_data = mbmi->interinter_comp;
  best_type_stats->best_compmode_interinter_cost = rs2;
}

// Updates best_mv for masked compound types
static INLINE void update_mask_best_mv(const MB_MODE_INFO *const mbmi,
                                       int_mv *best_mv, int_mv *cur_mv,
                                       const COMPOUND_TYPE cur_type,
                                       int *best_tmp_rate_mv, int tmp_rate_mv,
                                       const SPEED_FEATURES *const sf) {
  if (cur_type == COMPOUND_WEDGE ||
      (sf->inter_sf.enable_interinter_diffwtd_newmv_search &&
       cur_type == COMPOUND_DIFFWTD)) {
    *best_tmp_rate_mv = tmp_rate_mv;
    best_mv[0].as_int = mbmi->mv[0].as_int;
    best_mv[1].as_int = mbmi->mv[1].as_int;
  } else {
    best_mv[0].as_int = cur_mv[0].as_int;
    best_mv[1].as_int = cur_mv[1].as_int;
  }
}

// Chooses the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD based on the
// modeled rd cost
static int find_best_avg_distwtd_comp_type(MACROBLOCK *x, int *comp_model_rate,
                                           int64_t *comp_model_dist,
                                           int rate_mv, int64_t *best_rd) {
  int64_t est_rd[2];
  est_rd[COMPOUND_AVERAGE] =
      RDCOST(x->rdmult, comp_model_rate[COMPOUND_AVERAGE] + rate_mv,
             comp_model_dist[COMPOUND_AVERAGE]);
  est_rd[COMPOUND_DISTWTD] =
      RDCOST(x->rdmult, comp_model_rate[COMPOUND_DISTWTD] + rate_mv,
             comp_model_dist[COMPOUND_DISTWTD]);
  int best_type = (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD])
                      ? COMPOUND_AVERAGE
                      : COMPOUND_DISTWTD;
  *best_rd = est_rd[best_type];
  return best_type;
}

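// Saves the current compound search stats (rates, distortions, MVs, mode
// info) into the record array so that later searches can reuse them,
// provided a free slot remains.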
static INLINE void save_comp_rd_search_stat(
    MACROBLOCK *x, const MB_MODE_INFO *const mbmi, const int32_t *comp_rate,
    const int64_t *comp_dist, const int32_t *comp_model_rate,
    const int64_t *comp_model_dist, const int_mv *cur_mv, const int *comp_rs2) {
  const int offset = x->comp_rd_stats_idx;
  if (offset < MAX_COMP_RD_STATS) {
    COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
    memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
    memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
    memcpy(rd_stats->model_rate, comp_model_rate, sizeof(rd_stats->model_rate));
    memcpy(rd_stats->model_dist, comp_model_dist, sizeof(rd_stats->model_dist));
    memcpy(rd_stats->comp_rs2, comp_rs2, sizeof(rd_stats->comp_rs2));
    memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
    memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
    rd_stats->mode = mbmi->mode;
    rd_stats->filter = mbmi->interp_filters;
    rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
    const MACROBLOCKD *const xd = &x->e_mbd;
    for (int i = 0; i < 2; ++i) {
      const WarpedMotionParams *const wm =
          &xd->global_motion[mbmi->ref_frame[i]];
      rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
    }
    memcpy(&rd_stats->interinter_comp, &mbmi->interinter_comp,
           sizeof(rd_stats->interinter_comp));
    ++x->comp_rd_stats_idx;
  }
}

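// Returns the rate cost of signaling the mask parameters for the given
// compound type (wedge index for COMPOUND_WEDGE, mask type for
// COMPOUND_DIFFWTD).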
static INLINE int get_interinter_compound_mask_rate(
    const MACROBLOCK *const x, const MB_MODE_INFO *const mbmi) {
  const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
  // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
  if (compound_type == COMPOUND_WEDGE) {
    return av1_is_wedge_used(mbmi->sb_type)
               ? av1_cost_literal(1) +
                     x->wedge_idx_cost[mbmi->sb_type]
                                      [mbmi->interinter_comp.wedge_index]
               : 0;
  } else {
    assert(compound_type == COMPOUND_DIFFWTD);
    return av1_cost_literal(1);
  }
}

// Takes a backup of rate, distortion and model_rd for future reuse
static INLINE void backup_stats(COMPOUND_TYPE cur_type, int32_t *comp_rate,
                                int64_t *comp_dist, int32_t *comp_model_rate,
                                int64_t *comp_model_dist, int rate_sum,
                                int64_t dist_sum, RD_STATS *rd_stats,
                                int *comp_rs2, int rs2) {
  comp_rate[cur_type] = rd_stats->rate;
  comp_dist[cur_type] = rd_stats->dist;
  comp_model_rate[cur_type] = rate_sum;
  comp_model_dist[cur_type] = dist_sum;
  comp_rs2[cur_type] = rs2;
}

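// Performs the rd search for a masked compound type (COMPOUND_WEDGE or
// COMPOUND_DIFFWTD): picks the best mask, optionally refines the motion
// vector, and returns the estimated rd cost (INT64_MAX if pruned).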
static int64_t masked_compound_type_rd(
    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
    const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
    int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
    uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
    int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
    int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
    int64_t *comp_model_dist, const int64_t comp_best_model_rd,
    int64_t *const comp_model_rd_cur, int *comp_rs2, int64_t ref_skip_rd) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int64_t best_rd_cur = INT64_MAX;
  int64_t rd = INT64_MAX;
  const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
  // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
  assert(compound_type == COMPOUND_WEDGE || compound_type == COMPOUND_DIFFWTD);
  int rate_sum, tmp_skip_txfm_sb;
  int64_t dist_sum, tmp_skip_sse_sb;
  pick_interinter_mask_type pick_interinter_mask[2] = { pick_interinter_wedge,
                                                        pick_interinter_seg };

  // TODO(any): Save pred and mask calculation as well into records. However
  // this may increase memory requirements as compound segment mask needs to be
  // stored in each record.
  if (*calc_pred_masked_compound) {
    get_inter_predictors_masked_compound(x, bsize, preds0, preds1, residual1,
                                         diff10, strides);
    *calc_pred_masked_compound = 0;
  }
  if (cpi->sf.inter_sf.prune_wedge_pred_diff_based &&
      compound_type == COMPOUND_WEDGE) {
    unsigned int sse;
    if (is_cur_buf_hbd(xd))
      (void)cpi->fn_ptr[bsize].vf(CONVERT_TO_BYTEPTR(*preds0), *strides,
                                  CONVERT_TO_BYTEPTR(*preds1), *strides, &sse);
    else
      (void)cpi->fn_ptr[bsize].vf(*preds0, *strides, *preds1, *strides, &sse);
    const unsigned int mse =
        ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
    // If two predictors are very similar, skip wedge compound mode search
    if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
  }
  // Function pointer to pick the appropriate mask
  // compound_type == COMPOUND_WEDGE, calls pick_interinter_wedge()
  // compound_type == COMPOUND_DIFFWTD, calls pick_interinter_seg()
  uint64_t cur_sse = UINT64_MAX;
  best_rd_cur = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
      cpi, x, bsize, *preds0, *preds1, residual1, diff10, &cur_sse);
  *rs2 += get_interinter_compound_mask_rate(x, mbmi);
  best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
  assert(cur_sse != UINT64_MAX);
  int64_t skip_rd_cur = RDCOST(x->rdmult, *rs2 + rate_mv, (cur_sse << 4));

  // Although the true rate_mv might differ after motion search, it is
  // unlikely to be the best mode considering the transform rd cost and other
  // mode overhead cost
  int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
  if (mode_rd > rd_thresh) {
    *comp_model_rd_cur = INT64_MAX;
    return INT64_MAX;
  }
1065 
1066   // Check if the mode is good enough based on skip rd
1067   // TODO(nithya): Handle wedge_newmv_search if extending for lower speed
1068   // setting
1069   if (cpi->sf.inter_sf.txfm_rd_gate_level) {
1070     int eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd_cur,
1071                                     cpi->sf.inter_sf.txfm_rd_gate_level, 1);
1072     if (!eval_txfm) {
1073       *comp_model_rd_cur = INT64_MAX;
1074       return INT64_MAX;
1075     }
1076   }
1077 
1078   // Compute cost if matching record not found, else, reuse data
1079   if (comp_rate[compound_type] == INT_MAX) {
1080     // Check whether new MV search for wedge is to be done
1081     int wedge_newmv_search =
1082         have_newmv_in_inter_mode(this_mode) &&
1083         (compound_type == COMPOUND_WEDGE) &&
1084         (!cpi->sf.inter_sf.disable_interinter_wedge_newmv_search);
1085     int diffwtd_newmv_search =
1086         cpi->sf.inter_sf.enable_interinter_diffwtd_newmv_search &&
1087         compound_type == COMPOUND_DIFFWTD &&
1088         have_newmv_in_inter_mode(this_mode);
1089 
1090     // Search for new MV if needed and build predictor
1091     if (wedge_newmv_search) {
1092       *out_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1093                                                            bsize, this_mode);
1094       const int mi_row = xd->mi_row;
1095       const int mi_col = xd->mi_col;
1096       av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
1097                                     AOM_PLANE_Y, AOM_PLANE_Y);
1098     } else if (diffwtd_newmv_search) {
1099       *out_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1100                                                            bsize, this_mode);
1101       // we need to update the mask according to the new motion vector
1102       CompoundTypeRdBuffers tmp_buf;
1103       int64_t tmp_rd = INT64_MAX;
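      // Scratch buffers for the re-derived DIFFWTD mask; the "_no_check"
      // variant presumably allocates without the usual allocation-failure
      // checks. The buffers are released before leaving this branch.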
      alloc_compound_type_rd_buffers_no_check(&tmp_buf);

      uint8_t *tmp_preds0[1] = { tmp_buf.pred0 };
      uint8_t *tmp_preds1[1] = { tmp_buf.pred1 };

      get_inter_predictors_masked_compound(x, bsize, tmp_preds0, tmp_preds1,
                                           tmp_buf.residual1, tmp_buf.diff10,
                                           strides);

      tmp_rd = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
          cpi, x, bsize, *tmp_preds0, *tmp_preds1, tmp_buf.residual1,
          tmp_buf.diff10, &cur_sse);
      // We can reuse rs2 here
      tmp_rd += RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);

      if (tmp_rd >= best_rd_cur) {
        // Restore the motion vectors
        mbmi->mv[0].as_int = cur_mv[0].as_int;
        mbmi->mv[1].as_int = cur_mv[1].as_int;
        *out_rate_mv = rate_mv;
        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
                                                 strides, preds1, strides);
      } else {
        // Build the final prediction using the updated MVs
        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, tmp_preds0,
                                                 strides, tmp_preds1, strides);
      }
      av1_release_compound_type_rd_buffers(&tmp_buf);
    } else {
      *out_rate_mv = rate_mv;
      av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                               preds1, strides);
    }
    // Get the RD cost from model RD
    model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
        cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
        &tmp_skip_sse_sb, NULL, NULL, NULL);
    rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
    *comp_model_rd_cur = rd;
    // Override with best if current is worse than best for new MV
    if (wedge_newmv_search) {
      if (rd >= best_rd_cur) {
        mbmi->mv[0].as_int = cur_mv[0].as_int;
        mbmi->mv[1].as_int = cur_mv[1].as_int;
        *out_rate_mv = rate_mv;
        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
                                                 strides, preds1, strides);
        *comp_model_rd_cur = best_rd_cur;
      }
    }
    if (cpi->sf.inter_sf.prune_comp_type_by_model_rd &&
        (*comp_model_rd_cur > comp_best_model_rd) &&
        comp_best_model_rd != INT64_MAX) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
    // Compute RD cost for the current type
    RD_STATS rd_stats;
    const int64_t tmp_mode_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
    const int64_t tmp_rd_thresh = rd_thresh - tmp_mode_rd;
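    // Pass only the remaining rd budget to the transform search so that
    // estimate_yrd_for_sb() can terminate early once the threshold is
    // exceeded.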
    rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
    if (rd != INT64_MAX) {
      rd =
          RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
      // Backup rate and distortion for future reuse
      backup_stats(compound_type, comp_rate, comp_dist, comp_model_rate,
                   comp_model_dist, rate_sum, dist_sum, &rd_stats, comp_rs2,
                   *rs2);
    }
  } else {
    // Reuse data as a matching record is found
    assert(comp_dist[compound_type] != INT64_MAX);
    // When disable_interinter_wedge_newmv_search is set, motion refinement is
    // disabled. Hence rate and distortion can be reused in this case as well
    assert(IMPLIES(have_newmv_in_inter_mode(this_mode),
                   cpi->sf.inter_sf.disable_interinter_wedge_newmv_search));
    assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
    assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
    *out_rate_mv = rate_mv;
    // Calculate RD cost based on stored stats
    rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
                comp_dist[compound_type]);
    // Recalculate model rdcost with the updated rate
    *comp_model_rd_cur =
        RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_model_rate[compound_type],
               comp_model_dist[compound_type]);
  }
  return rd;
}

// Scaling values used to gate wedge/compound-segment evaluation based on the
// best approximate rd so far
static int comp_type_rd_threshold_mul[3] = { 1, 11, 12 };
static int comp_type_rd_threshold_div[3] = { 3, 16, 16 };
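// These tables are indexed by the prune_comp_type_by_comp_avg speed-feature
// level; the ratio mul/div (1/3, 11/16, 12/16) scales the best rd so far to
// form the gating threshold used in av1_compound_type_rd() below.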

int av1_compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
                         BLOCK_SIZE bsize, int_mv *cur_mv, int mode_search_mask,
                         int masked_compound_used, const BUFFER_SET *orig_dst,
                         const BUFFER_SET *tmp_dst,
                         const CompoundTypeRdBuffers *buffers, int *rate_mv,
                         int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
                         int64_t ref_skip_rd, int *is_luma_interp_done,
                         int64_t rd_thresh) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const PREDICTION_MODE this_mode = mbmi->mode;
  const int bw = block_size_wide[bsize];
  int rs2;
  int_mv best_mv[2];
  int best_tmp_rate_mv = *rate_mv;
  BEST_COMP_TYPE_STATS best_type_stats;
  // Initializing BEST_COMP_TYPE_STATS
  best_type_stats.best_compound_data.type = COMPOUND_AVERAGE;
  best_type_stats.best_compmode_interinter_cost = 0;
  best_type_stats.comp_best_model_rd = INT64_MAX;

  uint8_t *preds0[1] = { buffers->pred0 };
  uint8_t *preds1[1] = { buffers->pred1 };
  int strides[1] = { bw };
  int tmp_rate_mv;
  const int num_pix = 1 << num_pels_log2_lookup[bsize];
  const int mask_len = 2 * num_pix * sizeof(uint8_t);
  COMPOUND_TYPE cur_type;
  // Local array to store the mask cost for different compound types
  int masked_type_cost[COMPOUND_TYPES];

  int calc_pred_masked_compound = 1;
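  // Per-compound-type caches of previously computed stats; INT_MAX /
  // INT64_MAX entries mean "no record available" and trigger a fresh
  // computation in the evaluation loop below.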
  int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
                                        INT64_MAX };
  int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
  int comp_rs2[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
  int32_t comp_model_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX,
                                              INT_MAX };
  int64_t comp_model_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
                                              INT64_MAX };
  int match_index = 0;
  const int match_found =
      find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rate,
                            comp_model_dist, comp_rs2, &match_index);
  best_mv[0].as_int = cur_mv[0].as_int;
  best_mv[1].as_int = cur_mv[1].as_int;
  *rd = INT64_MAX;
  int rate_sum, tmp_skip_txfm_sb;
  int64_t dist_sum, tmp_skip_sse_sb;

  // Local array to store the valid compound types to be evaluated in the core
  // loop
  COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
    COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
  };
  int valid_type_count = 0;
  int try_average_and_distwtd_comp = 0;
  // compute_valid_comp_types() returns the number of valid compound types to
  // be evaluated and populates them in the local array valid_comp_types[].
  // It also sets the flag 'try_average_and_distwtd_comp'
  valid_type_count = compute_valid_comp_types(
      x, cpi, &try_average_and_distwtd_comp, bsize, masked_compound_used,
      mode_search_mask, valid_comp_types);

  // The following context indices are independent of compound type
  const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
  const int comp_index_ctx = get_comp_index_context(cm, xd);

  // Populates masked_type_cost local array for the 4 compound types
  calc_masked_type_cost(x, bsize, comp_group_idx_ctx, comp_index_ctx,
                        masked_compound_used, masked_type_cost);

  int64_t comp_model_rd_cur = INT64_MAX;
  int64_t best_rd_cur = INT64_MAX;
  const int mi_row = xd->mi_row;
  const int mi_col = xd->mi_col;

  // If a match is found, calculate the rd cost using the
  // stored stats and update the mbmi appropriately.
  if (match_found && cpi->sf.inter_sf.reuse_compound_type_decision) {
    return populate_reuse_comp_type_data(x, mbmi, &best_type_stats, cur_mv,
                                         comp_rate, comp_dist, comp_rs2,
                                         rate_mv, rd, match_index);
  }
  // Special handling if both compound_average and compound_distwtd
  // are to be searched. In this case, first estimate between the two
  // modes and then call estimate_yrd_for_sb() only for the better of
  // the two.
  if (try_average_and_distwtd_comp) {
    int est_rate[2];
    int64_t est_dist[2], est_rd;
    COMPOUND_TYPE best_type;
    // Since modelled rate and dist are separately stored,
    // compute the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD
    // using the stored stats.
    if ((comp_model_rate[COMPOUND_AVERAGE] != INT_MAX) &&
        comp_model_rate[COMPOUND_DISTWTD] != INT_MAX) {
      // Choose the better of COMPOUND_AVERAGE and
      // COMPOUND_DISTWTD based on modeled cost.
      best_type = find_best_avg_distwtd_comp_type(
          x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
      update_mbmi_for_compound_type(mbmi, best_type);
      if (comp_rate[best_type] != INT_MAX)
        best_rd_cur = RDCOST(
            x->rdmult,
            masked_type_cost[best_type] + *rate_mv + comp_rate[best_type],
            comp_dist[best_type]);
      comp_model_rd_cur = est_rd;
      // Update stats for best compound type
      if (best_rd_cur < *rd) {
        update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
                         comp_model_rd_cur, masked_type_cost[best_type]);
      }
      restore_dst_buf(xd, *tmp_dst, 1);
    } else {
      int64_t sse_y[COMPOUND_DISTWTD + 1];
      // Calculate model_rd for COMPOUND_AVERAGE and COMPOUND_DISTWTD
      for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
           comp_type++) {
        update_mbmi_for_compound_type(mbmi, comp_type);
        av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                      AOM_PLANE_Y, AOM_PLANE_Y);
        model_rd_sb_fn[MODELRD_CURVFIT](
            cpi, bsize, x, xd, 0, 0, &est_rate[comp_type], &est_dist[comp_type],
            NULL, NULL, NULL, NULL, NULL);
        est_rate[comp_type] += masked_type_cost[comp_type];
        comp_model_rate[comp_type] = est_rate[comp_type];
        comp_model_dist[comp_type] = est_dist[comp_type];
        sse_y[comp_type] = x->pred_sse[xd->mi[0]->ref_frame[0]];
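        // x->pred_sse is assumed to be populated by the model RD call above;
        // it is used later to form the skip rd for the transform-search gate.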
        if (comp_type == COMPOUND_AVERAGE) {
          *is_luma_interp_done = 1;
          restore_dst_buf(xd, *tmp_dst, 1);
        }
      }
      // Choose the better of the two based on modeled cost and call
      // estimate_yrd_for_sb() for that one.
      best_type = find_best_avg_distwtd_comp_type(
          x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
      update_mbmi_for_compound_type(mbmi, best_type);
      if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *orig_dst, 1);
      rs2 = masked_type_cost[best_type];
      RD_STATS est_rd_stats;
      const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
      const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
      int64_t est_rd_ = INT64_MAX;
      int eval_txfm = 1;
      // Check if the mode is good enough based on skip rd
      if (cpi->sf.inter_sf.txfm_rd_gate_level) {
        int64_t skip_rd =
            RDCOST(x->rdmult, rs2 + *rate_mv, (sse_y[best_type] << 4));
        eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
                                    cpi->sf.inter_sf.txfm_rd_gate_level, 1);
      }
      // Evaluate further if skip rd is low enough
      if (eval_txfm) {
        est_rd_ =
            estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
      }

      if (est_rd_ != INT64_MAX) {
        best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
                             est_rd_stats.dist);
        // Backup rate and distortion for future reuse
        backup_stats(best_type, comp_rate, comp_dist, comp_model_rate,
                     comp_model_dist, est_rate[best_type], est_dist[best_type],
                     &est_rd_stats, comp_rs2, rs2);
        comp_model_rd_cur = est_rd;
      }
      if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
      // Update stats for best compound type
      if (best_rd_cur < *rd) {
        update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
                         comp_model_rd_cur, rs2);
      }
    }
  }

  // If COMPOUND_AVERAGE is not valid, use the spare buffer
  if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);

  // Loop over valid compound types
  for (int i = 0; i < valid_type_count; i++) {
    cur_type = valid_comp_types[i];
    comp_model_rd_cur = INT64_MAX;
    tmp_rate_mv = *rate_mv;
    best_rd_cur = INT64_MAX;

    // Case COMPOUND_AVERAGE and COMPOUND_DISTWTD
    if (cur_type < COMPOUND_WEDGE) {
      update_mbmi_for_compound_type(mbmi, cur_type);
      rs2 = masked_type_cost[cur_type];
      const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
      if (mode_rd < ref_best_rd) {
        // Compute the cost if no matching record is found; otherwise reuse
        // the stored stats (else branch below)
        if (comp_rate[cur_type] == INT_MAX) {
          av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                        AOM_PLANE_Y, AOM_PLANE_Y);
          if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;

          // Compute RD cost for the current type
          RD_STATS est_rd_stats;
          const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
          int64_t est_rd = INT64_MAX;
          int eval_txfm = 1;
          // Check if the mode is good enough based on skip rd
          if (cpi->sf.inter_sf.txfm_rd_gate_level) {
            int64_t sse_y = compute_sse_plane(x, xd, PLANE_TYPE_Y, bsize);
            int64_t skip_rd = RDCOST(x->rdmult, rs2 + *rate_mv, (sse_y << 4));
            eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
                                        cpi->sf.inter_sf.txfm_rd_gate_level, 1);
          }
          // Evaluate further if skip rd is low enough
          if (eval_txfm) {
            est_rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh,
                                         &est_rd_stats);
          }

          if (est_rd != INT64_MAX) {
            best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
                                 est_rd_stats.dist);
            model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
                cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
            comp_model_rd_cur =
                RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);

            // Backup rate and distortion for future reuse
            backup_stats(cur_type, comp_rate, comp_dist, comp_model_rate,
                         comp_model_dist, rate_sum, dist_sum, &est_rd_stats,
                         comp_rs2, rs2);
          }
        } else {
          // Calculate RD cost based on stored stats
          assert(comp_dist[cur_type] != INT64_MAX);
          best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
                               comp_dist[cur_type]);
          // Recalculate model rdcost with the updated rate
          comp_model_rd_cur =
              RDCOST(x->rdmult, rs2 + *rate_mv + comp_model_rate[cur_type],
                     comp_model_dist[cur_type]);
        }
      }
      // Use the spare buffer for the following compound type evaluation
      if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
    } else {
      // Handle masked compound types
      update_mbmi_for_compound_type(mbmi, cur_type);
      rs2 = masked_type_cost[cur_type];
      // Factors to control gating of compound type selection based on best
      // approximate rd so far
      const int max_comp_type_rd_threshold_mul =
          comp_type_rd_threshold_mul[cpi->sf.inter_sf
                                         .prune_comp_type_by_comp_avg];
      const int max_comp_type_rd_threshold_div =
          comp_type_rd_threshold_div[cpi->sf.inter_sf
                                         .prune_comp_type_by_comp_avg];
      // Evaluate COMPOUND_WEDGE / COMPOUND_DIFFWTD if the approximated cost is
      // within the threshold
      int64_t approx_rd = ((*rd / max_comp_type_rd_threshold_div) *
                           max_comp_type_rd_threshold_mul);
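      // For example, at prune_comp_type_by_comp_avg level 0 this is *rd / 3;
      // at levels 1 and 2 it is 11/16 and 12/16 of *rd, respectively.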

      if (approx_rd < ref_best_rd) {
        const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
        best_rd_cur = masked_compound_type_rd(
            cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
            &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
            strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
            comp_rate, comp_dist, comp_model_rate, comp_model_dist,
            best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2,
            ref_skip_rd);
      }
    }
    // Update stats for best compound type
    if (best_rd_cur < *rd) {
      update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
                       comp_model_rd_cur, rs2);
      if (masked_compound_used && cur_type >= COMPOUND_WEDGE) {
        memcpy(buffers->tmp_best_mask_buf, xd->seg_mask, mask_len);
        if (have_newmv_in_inter_mode(this_mode))
          update_mask_best_mv(mbmi, best_mv, cur_mv, cur_type,
                              &best_tmp_rate_mv, tmp_rate_mv, &cpi->sf);
      }
    }
    // Reset to the original MVs for the next iteration
    mbmi->mv[0].as_int = cur_mv[0].as_int;
    mbmi->mv[1].as_int = cur_mv[1].as_int;
  }
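  // Restore the overall best compound configuration. comp_group_idx signals
  // the masked-compound group (0 for AVERAGE/DISTWTD, 1 for WEDGE/DIFFWTD),
  // and compound_idx distinguishes plain averaging (1) from distance-weighted
  // blending (0) within the non-masked group.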
  if (mbmi->interinter_comp.type != best_type_stats.best_compound_data.type) {
    mbmi->comp_group_idx =
        (best_type_stats.best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
    mbmi->compound_idx =
        !(best_type_stats.best_compound_data.type == COMPOUND_DISTWTD);
    mbmi->interinter_comp = best_type_stats.best_compound_data;
    memcpy(xd->seg_mask, buffers->tmp_best_mask_buf, mask_len);
  }
  if (have_newmv_in_inter_mode(this_mode)) {
    mbmi->mv[0].as_int = best_mv[0].as_int;
    mbmi->mv[1].as_int = best_mv[1].as_int;
    if (mbmi->interinter_comp.type == COMPOUND_WEDGE) {
      rd_stats->rate += best_tmp_rate_mv - *rate_mv;
      *rate_mv = best_tmp_rate_mv;
    }
  }
  restore_dst_buf(xd, *orig_dst, 1);
  if (!match_found)
    save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rate,
                             comp_model_dist, cur_mv, comp_rs2);
  return best_type_stats.best_compmode_interinter_cost;
}