/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include "av1/common/pred_common.h"
#include "av1/encoder/compound_type.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/motion_search_facade.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/reconinter_enc.h"
#include "av1/encoder/tx_search.h"

typedef int64_t (*pick_interinter_mask_type)(
    const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
    const uint8_t *const p0, const uint8_t *const p1,
    const int16_t *const residual1, const int16_t *const diff10,
    uint64_t *best_sse);

// Checks if the characteristics of the current search match a stored search
static INLINE int is_comp_rd_match(const AV1_COMP *const cpi,
                                   const MACROBLOCK *const x,
                                   const COMP_RD_STATS *st,
                                   const MB_MODE_INFO *const mi,
                                   int32_t *comp_rate, int64_t *comp_dist,
                                   int32_t *comp_model_rate,
                                   int64_t *comp_model_dist, int *comp_rs2) {
  // TODO(ranjit): Ensure that the compound type search always uses the regular
  // filter and check whether the following check can be removed.
  // Check if the interp filter matches the previous case
  if (st->filter.as_int != mi->interp_filters.as_int) return 0;

  const MACROBLOCKD *const xd = &x->e_mbd;
  // Match MV and reference indices
  for (int i = 0; i < 2; ++i) {
    if ((st->ref_frames[i] != mi->ref_frame[i]) ||
        (st->mv[i].as_int != mi->mv[i].as_int)) {
      return 0;
    }
    const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
    if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
  }

  // Store the stats for COMPOUND_AVERAGE and COMPOUND_DISTWTD
  for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
       comp_type++) {
    comp_rate[comp_type] = st->rate[comp_type];
    comp_dist[comp_type] = st->dist[comp_type];
    comp_model_rate[comp_type] = st->model_rate[comp_type];
    comp_model_dist[comp_type] = st->model_dist[comp_type];
    comp_rs2[comp_type] = st->comp_rs2[comp_type];
  }

  // For compound wedge/segment, reuse data only if NEWMV is not present in
  // either of the directions
  if ((!have_newmv_in_inter_mode(mi->mode) &&
       !have_newmv_in_inter_mode(st->mode)) ||
      (cpi->sf.inter_sf.disable_interinter_wedge_newmv_search)) {
    memcpy(&comp_rate[COMPOUND_WEDGE], &st->rate[COMPOUND_WEDGE],
           sizeof(comp_rate[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_dist[COMPOUND_WEDGE], &st->dist[COMPOUND_WEDGE],
           sizeof(comp_dist[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_model_rate[COMPOUND_WEDGE], &st->model_rate[COMPOUND_WEDGE],
           sizeof(comp_model_rate[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_model_dist[COMPOUND_WEDGE], &st->model_dist[COMPOUND_WEDGE],
           sizeof(comp_model_dist[COMPOUND_WEDGE]) * 2);
    memcpy(&comp_rs2[COMPOUND_WEDGE], &st->comp_rs2[COMPOUND_WEDGE],
           sizeof(comp_rs2[COMPOUND_WEDGE]) * 2);
  }
  return 1;
}

// Checks if a similar compound type search case was evaluated earlier.
// If a match is found, returns the relevant rd data.
static INLINE int find_comp_rd_in_stats(const AV1_COMP *const cpi,
                                        const MACROBLOCK *x,
                                        const MB_MODE_INFO *const mbmi,
                                        int32_t *comp_rate, int64_t *comp_dist,
                                        int32_t *comp_model_rate,
                                        int64_t *comp_model_dist, int *comp_rs2,
                                        int *match_index) {
  for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
    if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
                         comp_dist, comp_model_rate, comp_model_dist,
                         comp_rs2)) {
      *match_index = j;
      return 1;
    }
  }
  return 0;  // no match result found
}

static INLINE bool enable_wedge_search(MACROBLOCK *const x,
                                       const AV1_COMP *const cpi) {
  // Enable wedge search if source variance and edge strength are above
  // the thresholds.
  return x->source_variance >
             cpi->sf.inter_sf.disable_wedge_search_var_thresh &&
         x->edge_strength > cpi->sf.inter_sf.disable_wedge_search_edge_thresh;
}

static INLINE bool enable_wedge_interinter_search(MACROBLOCK *const x,
                                                  const AV1_COMP *const cpi) {
  return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interinter_wedge &&
         !cpi->sf.inter_sf.disable_interinter_wedge;
}

static INLINE bool enable_wedge_interintra_search(MACROBLOCK *const x,
                                                  const AV1_COMP *const cpi) {
  return enable_wedge_search(x, cpi) && cpi->oxcf.enable_interintra_wedge &&
         !cpi->sf.inter_sf.disable_wedge_interintra_search;
}

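// Estimates the sign of the wedge partition by comparing, for the top-left
// and bottom-right quadrants, how closely each of the two single-reference
// predictors (pred0, pred1) matches the source block.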
static int8_t estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
                                  const BLOCK_SIZE bsize, const uint8_t *pred0,
                                  int stride0, const uint8_t *pred1,
                                  int stride1) {
  static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
    // 4X4
    BLOCK_INVALID,
    // 4X8, 8X4, 8X8
    BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
    // 8X16, 16X8, 16X16
    BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
    // 16X32, 32X16, 32X32
    BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
    // 32X64, 64X32, 64X64
    BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
    // 64x128, 128x64, 128x128
    BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
    // 4X16, 16X4, 8X32
    BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
    // 32X8, 16X64, 64X16
    BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
  };
  const struct macroblock_plane *const p = &x->plane[0];
  const uint8_t *src = p->src.buf;
  int src_stride = p->src.stride;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int bw_by2 = bw >> 1;
  const int bh_by2 = bh >> 1;
  uint32_t esq[2][2];
  int64_t tl, br;

  const BLOCK_SIZE f_index = split_qtr[bsize];
  assert(f_index != BLOCK_INVALID);

  if (is_cur_buf_hbd(&x->e_mbd)) {
    pred0 = CONVERT_TO_BYTEPTR(pred0);
    pred1 = CONVERT_TO_BYTEPTR(pred1);
  }

  // Residual variance computation over the relevant quadrants in order to
  // find TL + BR, TL = sum(1st,2nd,3rd) quadrants of (pred0 - pred1),
  // BR = sum(2nd,3rd,4th) quadrants of (pred1 - pred0)
  // The 2nd and 3rd quadrants cancel out in TL + BR
  // Hence TL + BR = 1st quadrant of (pred0-pred1) + 4th of (pred1-pred0)
  // TODO(nithya): Sign estimation assumes 45 degrees (1st and 4th quadrants)
  // for all codebooks; experiment with other quadrant combinations for
  // 0, 90 and 135 degrees also.
  cpi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
  cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
                          pred0 + bh_by2 * stride0 + bw_by2, stride0,
                          &esq[0][1]);
  cpi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
  cpi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
                          pred1 + bh_by2 * stride1 + bw_by2, stride0,
                          &esq[1][1]);

  tl = ((int64_t)esq[0][0]) - ((int64_t)esq[1][0]);
  br = ((int64_t)esq[1][1]) - ((int64_t)esq[0][1]);
  return (tl + br > 0);
}

// Choose the best wedge index and sign
static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
                          const BLOCK_SIZE bsize, const uint8_t *const p0,
                          const int16_t *const residual1,
                          const int16_t *const diff10,
                          int8_t *const best_wedge_sign,
                          int8_t *const best_wedge_index, uint64_t *best_sse) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const struct buf_2d *const src = &x->plane[0].src;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = bw * bh;
  assert(N >= 64);
  int rate;
  int64_t dist;
  int64_t rd, best_rd = INT64_MAX;
  int8_t wedge_index;
  int8_t wedge_sign;
  const int8_t wedge_types = get_wedge_types_lookup(bsize);
  const uint8_t *mask;
  uint64_t sse;
  const int hbd = is_cur_buf_hbd(xd);
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;

  DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]);  // src - pred0
#if CONFIG_AV1_HIGHBITDEPTH
  if (hbd) {
    aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
  }
#else
  (void)hbd;
  aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
#endif

  int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
                        (int64_t)aom_sum_squares_i16(residual1, N)) *
                       (1 << WEDGE_WEIGHT_BITS) / 2;
  int16_t *ds = residual0;

  av1_wedge_compute_delta_squares(ds, residual0, residual1, N);

  for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
    mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);

    wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);

    mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
    sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    // int rate2;
    // int64_t dist2;
    // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate2, &dist2);
    // printf("sse %"PRId64": legacy: %d %"PRId64", curvfit %d %"PRId64"\n",
    // sse, rate, dist, rate2, dist2); dist = dist2;
    // rate = rate2;

    rate += x->wedge_idx_cost[bsize][wedge_index];
    rd = RDCOST(x->rdmult, rate, dist);

    if (rd < best_rd) {
      *best_wedge_index = wedge_index;
      *best_wedge_sign = wedge_sign;
      best_rd = rd;
      *best_sse = sse;
    }
  }

  return best_rd -
         RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
}

// Choose the best wedge index for the specified sign
static int64_t pick_wedge_fixed_sign(
    const AV1_COMP *const cpi, const MACROBLOCK *const x,
    const BLOCK_SIZE bsize, const int16_t *const residual1,
    const int16_t *const diff10, const int8_t wedge_sign,
    int8_t *const best_wedge_index, uint64_t *best_sse) {
  const MACROBLOCKD *const xd = &x->e_mbd;

  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = bw * bh;
  assert(N >= 64);
  int rate;
  int64_t dist;
  int64_t rd, best_rd = INT64_MAX;
  int8_t wedge_index;
  const int8_t wedge_types = get_wedge_types_lookup(bsize);
  const uint8_t *mask;
  uint64_t sse;
  const int hbd = is_cur_buf_hbd(xd);
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
  for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
    mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
    sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    rate += x->wedge_idx_cost[bsize][wedge_index];
    rd = RDCOST(x->rdmult, rate, dist);

    if (rd < best_rd) {
      *best_wedge_index = wedge_index;
      best_rd = rd;
      *best_sse = sse;
    }
  }
  return best_rd -
         RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
}

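// Selects the best inter-inter wedge mask for the block: either estimates the
// wedge sign quickly and searches only the wedge index, or searches both index
// and sign, depending on the fast_wedge_sign_estimate speed feature.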
static int64_t pick_interinter_wedge(
    const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
    const uint8_t *const p0, const uint8_t *const p1,
    const int16_t *const residual1, const int16_t *const diff10,
    uint64_t *best_sse) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int bw = block_size_wide[bsize];

  int64_t rd;
  int8_t wedge_index = -1;
  int8_t wedge_sign = 0;

  assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
  assert(cpi->common.seq_params.enable_masked_compound);

  if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
    wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
    rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
                               &wedge_index, best_sse);
  } else {
    rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
                    &wedge_index, best_sse);
  }

  mbmi->interinter_comp.wedge_sign = wedge_sign;
  mbmi->interinter_comp.wedge_index = wedge_index;
  return rd;
}

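// Selects the best DIFFWTD mask type for compound segment prediction by
// evaluating the model rd of each mask type (and its inverse) built from the
// difference of the two single-reference predictors.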
static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
                                   MACROBLOCK *const x, const BLOCK_SIZE bsize,
                                   const uint8_t *const p0,
                                   const uint8_t *const p1,
                                   const int16_t *const residual1,
                                   const int16_t *const diff10,
                                   uint64_t *best_sse) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = 1 << num_pels_log2_lookup[bsize];
  int rate;
  int64_t dist;
  DIFFWTD_MASK_TYPE cur_mask_type;
  int64_t best_rd = INT64_MAX;
  DIFFWTD_MASK_TYPE best_mask_type = 0;
  const int hbd = is_cur_buf_hbd(xd);
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
  DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
  uint8_t *tmp_mask[2] = { xd->seg_mask, seg_mask };
  // try each mask type and its inverse
  for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) {
    // build mask and inverse
    if (hbd)
      av1_build_compound_diffwtd_mask_highbd(
          tmp_mask[cur_mask_type], cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
          CONVERT_TO_BYTEPTR(p1), bw, bh, bw, xd->bd);
    else
      av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type,
                                      p0, bw, p1, bw, bh, bw);

    // compute rd for mask
    uint64_t sse = av1_wedge_sse_from_residuals(residual1, diff10,
                                                tmp_mask[cur_mask_type], N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    const int64_t rd0 = RDCOST(x->rdmult, rate, dist);

    if (rd0 < best_rd) {
      best_mask_type = cur_mask_type;
      best_rd = rd0;
      *best_sse = sse;
    }
  }
  mbmi->interinter_comp.mask_type = best_mask_type;
  if (best_mask_type == DIFFWTD_38_INV) {
    memcpy(xd->seg_mask, seg_mask, N * 2);
  }
  return best_rd;
}

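// Selects the best wedge index (with the wedge sign fixed to 0) for interintra
// prediction, using the residuals derived from the two predictors p0 and p1.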
static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
                                     const MACROBLOCK *const x,
                                     const BLOCK_SIZE bsize,
                                     const uint8_t *const p0,
                                     const uint8_t *const p1) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(av1_is_wedge_used(bsize));
  assert(cpi->common.seq_params.enable_interintra_compound);

  const struct buf_2d *const src = &x->plane[0].src;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]);  // src - pred1
  DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]);     // pred1 - pred0
#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(p1), bw, xd->bd);
    aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
    aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
  }
#else
  aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
  aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
#endif
  int8_t wedge_index = -1;
  uint64_t sse;
  int64_t rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0,
                                     &wedge_index, &sse);

  mbmi->interintra_wedge_index = wedge_index;
  return rd;
}

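// Builds the two single-reference inter predictors and the residual/difference
// buffers (residual1 = src - pred1, diff10 = pred1 - pred0) needed by the
// masked compound mask search.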
static AOM_INLINE void get_inter_predictors_masked_compound(
    MACROBLOCK *x, const BLOCK_SIZE bsize, uint8_t **preds0, uint8_t **preds1,
    int16_t *residual1, int16_t *diff10, int *strides) {
  MACROBLOCKD *xd = &x->e_mbd;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  // get inter predictors to use for masked compound modes
  av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0, preds0,
                                                   strides);
  av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1, preds1,
                                                   strides);
  const struct buf_2d *const src = &x->plane[0].src;
#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(*preds1), bw, xd->bd);
    aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
                              bw, CONVERT_TO_BYTEPTR(*preds0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
                       bw);
    aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
  }
#else
  aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1, bw);
  aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
#endif
}

// Computes the rd cost for the given interintra mode and updates the best
// mode and rd seen so far if the current rd is lower.
static INLINE void compute_best_interintra_mode(
    const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
    MACROBLOCK *const x, const int *const interintra_mode_cost,
    const BUFFER_SET *orig_dst, uint8_t *intrapred, const uint8_t *tmp_buf,
    INTERINTRA_MODE *best_interintra_mode, int64_t *best_interintra_rd,
    INTERINTRA_MODE interintra_mode, BLOCK_SIZE bsize) {
  const AV1_COMMON *const cm = &cpi->common;
  int rate, skip_txfm_sb;
  int64_t dist, skip_sse_sb;
  const int bw = block_size_wide[bsize];
  mbmi->interintra_mode = interintra_mode;
  int rmode = interintra_mode_cost[interintra_mode];
  av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                            intrapred, bw);
  av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
  model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](cpi, bsize, x, xd, 0, 0, &rate, &dist,
                                          &skip_txfm_sb, &skip_sse_sb, NULL,
                                          NULL, NULL);
  int64_t rd = RDCOST(x->rdmult, rate + rmode, dist);
  if (rd < *best_interintra_rd) {
    *best_interintra_rd = rd;
    *best_interintra_mode = mbmi->interintra_mode;
  }
}

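// Estimates the luma rd cost of the current prediction using a uniform
// transform size, adding the skip-flag signaling cost to the returned stats.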
static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
                                   MACROBLOCK *x, int64_t ref_best_rd,
                                   RD_STATS *rd_stats) {
  MACROBLOCKD *const xd = &x->e_mbd;
  if (ref_best_rd < 0) return INT64_MAX;
  av1_subtract_plane(x, bs, 0);
  x->rd_model = LOW_TXFM_RD;
  const int skip_trellis = (cpi->optimize_seg_arr[xd->mi[0]->segment_id] ==
                            NO_ESTIMATE_YRD_TRELLIS_OPT);
  const int64_t rd =
      av1_uniform_txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs,
                           max_txsize_rect_lookup[bs], FTXS_NONE, skip_trellis);
  x->rd_model = FULL_TXFM_RD;
  if (rd != INT64_MAX) {
    const int skip_ctx = av1_get_skip_context(xd);
    if (rd_stats->skip) {
      const int s1 = x->skip_cost[skip_ctx][1];
      rd_stats->rate = s1;
    } else {
      const int s0 = x->skip_cost[skip_ctx][0];
      rd_stats->rate += s0;
    }
  }
  return rd;
}

// Computes the rd_threshold for smooth interintra rd search.
static AOM_INLINE int64_t compute_rd_thresh(MACROBLOCK *const x,
                                            int total_mode_rate,
                                            int64_t ref_best_rd) {
  const int64_t rd_thresh = get_rd_thresh_from_best_rd(
      ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
      INTER_INTRA_RD_THRESH_SCALE);
  const int64_t mode_rd = RDCOST(x->rdmult, total_mode_rate, 0);
  return (rd_thresh - mode_rd);
}

// Computes the best wedge interintra mode
static AOM_INLINE int64_t compute_best_wedge_interintra(
    const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
    MACROBLOCK *const x, const int *const interintra_mode_cost,
    const BUFFER_SET *orig_dst, uint8_t *intrapred_, uint8_t *tmp_buf_,
    int *best_mode, int *best_wedge_index, BLOCK_SIZE bsize) {
  const AV1_COMMON *const cm = &cpi->common;
  const int bw = block_size_wide[bsize];
  int64_t best_interintra_rd_wedge = INT64_MAX;
  int64_t best_total_rd = INT64_MAX;
  uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
  for (INTERINTRA_MODE mode = 0; mode < INTERINTRA_MODES; ++mode) {
    mbmi->interintra_mode = mode;
    av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                              intrapred, bw);
    int64_t rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
    const int rate_overhead =
        interintra_mode_cost[mode] +
        x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index];
    const int64_t total_rd = rd + RDCOST(x->rdmult, rate_overhead, 0);
    if (total_rd < best_total_rd) {
      best_total_rd = total_rd;
      best_interintra_rd_wedge = rd;
      *best_mode = mbmi->interintra_mode;
      *best_wedge_index = mbmi->interintra_wedge_index;
    }
  }
  return best_interintra_rd_wedge;
}

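// Searches the interintra modes (smooth and/or wedge variants) for the current
// block, refines the motion vector for NEWMV cases, and updates mbmi, rate_mv
// and tmp_rate2 with the winner. Returns 0 on success and -1 when no
// interintra mode is evaluated or none is good enough.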
int av1_handle_inter_intra_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
                                BLOCK_SIZE bsize, MB_MODE_INFO *mbmi,
                                HandleInterModeArgs *args, int64_t ref_best_rd,
                                int *rate_mv, int *tmp_rate2,
                                const BUFFER_SET *orig_dst) {
  const int try_smooth_interintra = cpi->oxcf.enable_smooth_interintra &&
                                    !cpi->sf.inter_sf.disable_smooth_interintra;
  const int is_wedge_used = av1_is_wedge_used(bsize);
  const int try_wedge_interintra =
      is_wedge_used && enable_wedge_interintra_search(x, cpi);
  if (!try_smooth_interintra && !try_wedge_interintra) return -1;

  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  int64_t rd = INT64_MAX;
  const int bw = block_size_wide[bsize];
  DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
  DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
  uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
  uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
  const int *const interintra_mode_cost =
      x->interintra_mode_cost[size_group_lookup[bsize]];
  const int mi_row = xd->mi_row;
  const int mi_col = xd->mi_col;

  // Single reference inter prediction
  mbmi->ref_frame[1] = NONE_FRAME;
  xd->plane[0].dst.buf = tmp_buf;
  xd->plane[0].dst.stride = bw;
  av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
                                AOM_PLANE_Y, AOM_PLANE_Y);
  const int num_planes = av1_num_planes(cm);

  // Restore the buffers for intra prediction
  restore_dst_buf(xd, *orig_dst, num_planes);
  mbmi->ref_frame[1] = INTRA_FRAME;
  INTERINTRA_MODE best_interintra_mode =
      args->inter_intra_mode[mbmi->ref_frame[0]];

  // Compute smooth_interintra
  int64_t best_interintra_rd_nowedge = INT64_MAX;
  int best_mode_rate = INT_MAX;
  if (try_smooth_interintra) {
    mbmi->use_wedge_interintra = 0;
    int interintra_mode_reuse = 1;
    if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
        best_interintra_mode == INTERINTRA_MODES) {
      interintra_mode_reuse = 0;
      int64_t best_interintra_rd = INT64_MAX;
      for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
           ++cur_mode) {
        if ((!cpi->oxcf.enable_smooth_intra ||
             cpi->sf.intra_sf.disable_smooth_intra) &&
            cur_mode == II_SMOOTH_PRED)
          continue;
        compute_best_interintra_mode(cpi, mbmi, xd, x, interintra_mode_cost,
                                     orig_dst, intrapred, tmp_buf,
                                     &best_interintra_mode, &best_interintra_rd,
                                     cur_mode, bsize);
      }
      args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
    }
    assert(IMPLIES(!cpi->oxcf.enable_smooth_interintra ||
                       cpi->sf.inter_sf.disable_smooth_interintra,
                   best_interintra_mode != II_SMOOTH_PRED));
    // Recompute prediction if required
    if (interintra_mode_reuse || best_interintra_mode != INTERINTRA_MODES - 1) {
      mbmi->interintra_mode = best_interintra_mode;
      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                intrapred, bw);
      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
    }

    // Compute rd cost for best smooth_interintra
    RD_STATS rd_stats;
    const int rmode = interintra_mode_cost[best_interintra_mode] +
                      (is_wedge_used ? x->wedge_interintra_cost[bsize][0] : 0);
    const int total_mode_rate = rmode + *rate_mv;
    const int64_t rd_thresh =
        compute_rd_thresh(x, total_mode_rate, ref_best_rd);
    rd = estimate_yrd_for_sb(cpi, bsize, x, rd_thresh, &rd_stats);
    if (rd != INT64_MAX) {
      rd = RDCOST(x->rdmult, total_mode_rate + rd_stats.rate, rd_stats.dist);
    } else {
      return -1;
    }
    best_interintra_rd_nowedge = rd;
    best_mode_rate = rmode;
    // Return early if best_interintra_rd_nowedge is not good enough
    if (ref_best_rd < INT64_MAX &&
        (best_interintra_rd_nowedge >> INTER_INTRA_RD_THRESH_SHIFT) *
                INTER_INTRA_RD_THRESH_SCALE >
            ref_best_rd) {
      return -1;
    }
  }

  // Compute wedge interintra
  int64_t best_interintra_rd_wedge = INT64_MAX;
  if (try_wedge_interintra) {
    mbmi->use_wedge_interintra = 1;
    if (!cpi->sf.inter_sf.fast_interintra_wedge_search) {
      // Exhaustive search of all wedge and mode combinations.
      int best_mode = 0;
      int best_wedge_index = 0;
      best_interintra_rd_wedge = compute_best_wedge_interintra(
          cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred_,
          tmp_buf_, &best_mode, &best_wedge_index, bsize);
      mbmi->interintra_mode = best_mode;
      mbmi->interintra_wedge_index = best_wedge_index;
      if (best_mode != INTERINTRA_MODES - 1) {
        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                  intrapred, bw);
      }
    } else if (!try_smooth_interintra) {
      if (best_interintra_mode == INTERINTRA_MODES) {
        mbmi->interintra_mode = INTERINTRA_MODES - 1;
        best_interintra_mode = INTERINTRA_MODES - 1;
        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                  intrapred, bw);
        // Pick wedge mask based on INTERINTRA_MODES - 1
        best_interintra_rd_wedge =
            pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
        // Find the best interintra mode for the chosen wedge mask
        for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
             ++cur_mode) {
          compute_best_interintra_mode(
              cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred,
              tmp_buf, &best_interintra_mode, &best_interintra_rd_wedge,
              cur_mode, bsize);
        }
        args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
        mbmi->interintra_mode = best_interintra_mode;

        // Recompute prediction if required
        if (best_interintra_mode != INTERINTRA_MODES - 1) {
          av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                    intrapred, bw);
        }
      } else {
        // Pick wedge mask for the best interintra mode (reused)
        mbmi->interintra_mode = best_interintra_mode;
        av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                  intrapred, bw);
        best_interintra_rd_wedge =
            pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
      }
    } else {
      // Pick wedge mask for the best interintra mode from smooth_interintra
      best_interintra_rd_wedge =
          pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
    }

    const int rate_overhead =
        interintra_mode_cost[mbmi->interintra_mode] +
        x->wedge_idx_cost[bsize][mbmi->interintra_wedge_index] +
        x->wedge_interintra_cost[bsize][1];
    best_interintra_rd_wedge += RDCOST(x->rdmult, rate_overhead + *rate_mv, 0);

    const int_mv mv0 = mbmi->mv[0];
    int_mv tmp_mv = mv0;
    rd = INT64_MAX;
    int tmp_rate_mv = 0;
    // Refine motion vector for NEWMV case.
    if (have_newmv_in_inter_mode(mbmi->mode)) {
      int rate_sum, skip_txfm_sb;
      int64_t dist_sum, skip_sse_sb;
      // Get the inverse of the mask
      const uint8_t *mask =
          av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
      av1_compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, intrapred,
                                        mask, bw, &tmp_rate_mv, 0);
      if (mbmi->mv[0].as_int != tmp_mv.as_int) {
        mbmi->mv[0].as_int = tmp_mv.as_int;
        // Set ref_frame[1] to NONE_FRAME temporarily so that the intra
        // predictor is not calculated again in av1_enc_build_inter_predictor().
        mbmi->ref_frame[1] = NONE_FRAME;
        av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                      AOM_PLANE_Y, AOM_PLANE_Y);
        mbmi->ref_frame[1] = INTRA_FRAME;
        av1_combine_interintra(xd, bsize, 0, xd->plane[AOM_PLANE_Y].dst.buf,
                               xd->plane[AOM_PLANE_Y].dst.stride, intrapred,
                               bw);
        model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
            cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &skip_txfm_sb,
            &skip_sse_sb, NULL, NULL, NULL);
        rd =
            RDCOST(x->rdmult, tmp_rate_mv + rate_overhead + rate_sum, dist_sum);
      }
    }
    if (rd >= best_interintra_rd_wedge) {
      tmp_mv.as_int = mv0.as_int;
      tmp_rate_mv = *rate_mv;
      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
    }
    // Evaluate closer to true rd
    RD_STATS rd_stats;
    const int64_t mode_rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv, 0);
    const int64_t tmp_rd_thresh = best_interintra_rd_nowedge - mode_rd;
    rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
    if (rd != INT64_MAX) {
      rd = RDCOST(x->rdmult, rate_overhead + tmp_rate_mv + rd_stats.rate,
                  rd_stats.dist);
    } else {
      if (best_interintra_rd_nowedge == INT64_MAX) return -1;
    }
    best_interintra_rd_wedge = rd;
    if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
      mbmi->mv[0].as_int = tmp_mv.as_int;
      *tmp_rate2 += tmp_rate_mv - *rate_mv;
      *rate_mv = tmp_rate_mv;
      best_mode_rate = rate_overhead;
    } else {
      mbmi->use_wedge_interintra = 0;
      mbmi->interintra_mode = best_interintra_mode;
      mbmi->mv[0].as_int = mv0.as_int;
      av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                    AOM_PLANE_Y, AOM_PLANE_Y);
    }
  }

  if (best_interintra_rd_nowedge == INT64_MAX &&
      best_interintra_rd_wedge == INT64_MAX) {
    return -1;
  }

  *tmp_rate2 += best_mode_rate;

  if (num_planes > 1) {
    av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                  AOM_PLANE_U, num_planes - 1);
  }
  return 0;
}

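// Allocates the prediction, residual and mask buffers used during the compound
// type rd search without checking the allocations for failure; the buffers are
// released later with av1_release_compound_type_rd_buffers().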
static void alloc_compound_type_rd_buffers_no_check(
    CompoundTypeRdBuffers *const bufs) {
  bufs->pred0 =
      (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred0));
  bufs->pred1 =
      (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred1));
  bufs->residual1 =
      (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->residual1));
  bufs->diff10 =
      (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->diff10));
  bufs->tmp_best_mask_buf = (uint8_t *)aom_malloc(
      2 * MAX_SB_SQUARE * sizeof(*bufs->tmp_best_mask_buf));
}

// Computes the valid compound_types to be evaluated
static INLINE int compute_valid_comp_types(
    MACROBLOCK *x, const AV1_COMP *const cpi, int *try_average_and_distwtd_comp,
    BLOCK_SIZE bsize, int masked_compound_used, int mode_search_mask,
    COMPOUND_TYPE *valid_comp_types) {
  const AV1_COMMON *cm = &cpi->common;
  int valid_type_count = 0;
  int comp_type, valid_check;
  int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };

  const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
  const int try_distwtd_comp =
      ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
       cm->seq_params.order_hint_info.enable_dist_wtd_comp == 1 &&
       cpi->sf.inter_sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
  *try_average_and_distwtd_comp = try_average_comp && try_distwtd_comp;

  // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
  for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
       comp_type++) {
    valid_check =
        (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
    if (!*try_average_and_distwtd_comp && valid_check &&
        is_interinter_compound_used(comp_type, bsize))
      valid_comp_types[valid_type_count++] = comp_type;
  }
  // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
  if (masked_compound_used) {
    // enable_masked_type[0] corresponds to COMPOUND_WEDGE
    // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
    enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
    enable_masked_type[1] = cpi->oxcf.enable_diff_wtd_comp;
    for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
         comp_type++) {
      if ((mode_search_mask & (1 << comp_type)) &&
          is_interinter_compound_used(comp_type, bsize) &&
          enable_masked_type[comp_type - COMPOUND_WEDGE])
        valid_comp_types[valid_type_count++] = comp_type;
    }
  }
  return valid_type_count;
}

// Calculates the cost for compound type mask
static INLINE void calc_masked_type_cost(MACROBLOCK *x, BLOCK_SIZE bsize,
                                         int comp_group_idx_ctx,
                                         int comp_index_ctx,
                                         int masked_compound_used,
                                         int *masked_type_cost) {
  av1_zero_array(masked_type_cost, COMPOUND_TYPES);
  // Account for group index cost when wedge and/or diffwtd prediction are
  // enabled
  if (masked_compound_used) {
    // Compound group index of average and distwtd is 0
    // Compound group index of wedge and diffwtd is 1
    masked_type_cost[COMPOUND_AVERAGE] +=
        x->comp_group_idx_cost[comp_group_idx_ctx][0];
    masked_type_cost[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_AVERAGE];
    masked_type_cost[COMPOUND_WEDGE] +=
        x->comp_group_idx_cost[comp_group_idx_ctx][1];
    masked_type_cost[COMPOUND_DIFFWTD] += masked_type_cost[COMPOUND_WEDGE];
  }

  // Compute the cost to signal compound index/type
  masked_type_cost[COMPOUND_AVERAGE] += x->comp_idx_cost[comp_index_ctx][1];
  masked_type_cost[COMPOUND_DISTWTD] += x->comp_idx_cost[comp_index_ctx][0];
  masked_type_cost[COMPOUND_WEDGE] += x->compound_type_cost[bsize][0];
  masked_type_cost[COMPOUND_DIFFWTD] += x->compound_type_cost[bsize][1];
}

// Updates mbmi structure with the relevant compound type info
static INLINE void update_mbmi_for_compound_type(MB_MODE_INFO *mbmi,
                                                 COMPOUND_TYPE cur_type) {
  mbmi->interinter_comp.type = cur_type;
  mbmi->comp_group_idx = (cur_type >= COMPOUND_WEDGE);
  mbmi->compound_idx = (cur_type != COMPOUND_DISTWTD);
}

// When a match is found, populates the compound type data, calculates the
// rd cost using the stored stats, and updates the mbmi appropriately.
static INLINE int populate_reuse_comp_type_data(
    const MACROBLOCK *x, MB_MODE_INFO *mbmi,
    BEST_COMP_TYPE_STATS *best_type_stats, int_mv *cur_mv, int32_t *comp_rate,
    int64_t *comp_dist, int *comp_rs2, int *rate_mv, int64_t *rd,
    int match_index) {
  const int winner_comp_type =
      x->comp_rd_stats[match_index].interinter_comp.type;
  if (comp_rate[winner_comp_type] == INT_MAX)
    return best_type_stats->best_compmode_interinter_cost;
  update_mbmi_for_compound_type(mbmi, winner_comp_type);
  mbmi->interinter_comp = x->comp_rd_stats[match_index].interinter_comp;
  *rd = RDCOST(
      x->rdmult,
      comp_rs2[winner_comp_type] + *rate_mv + comp_rate[winner_comp_type],
      comp_dist[winner_comp_type]);
  mbmi->mv[0].as_int = cur_mv[0].as_int;
  mbmi->mv[1].as_int = cur_mv[1].as_int;
  return comp_rs2[winner_comp_type];
}

// Updates rd cost and relevant compound type data for the best compound type
static INLINE void update_best_info(const MB_MODE_INFO *const mbmi, int64_t *rd,
                                    BEST_COMP_TYPE_STATS *best_type_stats,
                                    int64_t best_rd_cur,
                                    int64_t comp_model_rd_cur, int rs2) {
  *rd = best_rd_cur;
  best_type_stats->comp_best_model_rd = comp_model_rd_cur;
  best_type_stats->best_compound_data = mbmi->interinter_comp;
  best_type_stats->best_compmode_interinter_cost = rs2;
}

// Updates best_mv for masked compound types
static INLINE void update_mask_best_mv(const MB_MODE_INFO *const mbmi,
                                       int_mv *best_mv, int_mv *cur_mv,
                                       const COMPOUND_TYPE cur_type,
                                       int *best_tmp_rate_mv, int tmp_rate_mv,
                                       const SPEED_FEATURES *const sf) {
  if (cur_type == COMPOUND_WEDGE ||
      (sf->inter_sf.enable_interinter_diffwtd_newmv_search &&
       cur_type == COMPOUND_DIFFWTD)) {
    *best_tmp_rate_mv = tmp_rate_mv;
    best_mv[0].as_int = mbmi->mv[0].as_int;
    best_mv[1].as_int = mbmi->mv[1].as_int;
  } else {
    best_mv[0].as_int = cur_mv[0].as_int;
    best_mv[1].as_int = cur_mv[1].as_int;
  }
}

// Chooses the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD based on
// modeled cost.
static int find_best_avg_distwtd_comp_type(MACROBLOCK *x, int *comp_model_rate,
                                            int64_t *comp_model_dist,
                                            int rate_mv, int64_t *best_rd) {
  int64_t est_rd[2];
  est_rd[COMPOUND_AVERAGE] =
      RDCOST(x->rdmult, comp_model_rate[COMPOUND_AVERAGE] + rate_mv,
             comp_model_dist[COMPOUND_AVERAGE]);
  est_rd[COMPOUND_DISTWTD] =
      RDCOST(x->rdmult, comp_model_rate[COMPOUND_DISTWTD] + rate_mv,
             comp_model_dist[COMPOUND_DISTWTD]);
  int best_type = (est_rd[COMPOUND_AVERAGE] <= est_rd[COMPOUND_DISTWTD])
                      ? COMPOUND_AVERAGE
                      : COMPOUND_DISTWTD;
  *best_rd = est_rd[best_type];
  return best_type;
}

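// Saves the current compound search stats (rates, distortions, model rd, MVs,
// reference frames and mask info) into x->comp_rd_stats so that later searches
// with matching characteristics can reuse them.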
static INLINE void save_comp_rd_search_stat(
    MACROBLOCK *x, const MB_MODE_INFO *const mbmi, const int32_t *comp_rate,
    const int64_t *comp_dist, const int32_t *comp_model_rate,
    const int64_t *comp_model_dist, const int_mv *cur_mv, const int *comp_rs2) {
  const int offset = x->comp_rd_stats_idx;
  if (offset < MAX_COMP_RD_STATS) {
    COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
    memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
    memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
    memcpy(rd_stats->model_rate, comp_model_rate, sizeof(rd_stats->model_rate));
    memcpy(rd_stats->model_dist, comp_model_dist, sizeof(rd_stats->model_dist));
    memcpy(rd_stats->comp_rs2, comp_rs2, sizeof(rd_stats->comp_rs2));
    memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
    memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
    rd_stats->mode = mbmi->mode;
    rd_stats->filter = mbmi->interp_filters;
    rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
    const MACROBLOCKD *const xd = &x->e_mbd;
    for (int i = 0; i < 2; ++i) {
      const WarpedMotionParams *const wm =
          &xd->global_motion[mbmi->ref_frame[i]];
      rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
    }
    memcpy(&rd_stats->interinter_comp, &mbmi->interinter_comp,
           sizeof(rd_stats->interinter_comp));
    ++x->comp_rd_stats_idx;
  }
}

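// Returns the rate needed to signal the mask of the current masked compound
// type: a one-bit literal plus the wedge index cost for COMPOUND_WEDGE, or a
// single one-bit literal for COMPOUND_DIFFWTD.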
static INLINE int get_interinter_compound_mask_rate(
    const MACROBLOCK *const x, const MB_MODE_INFO *const mbmi) {
  const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
  // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
  if (compound_type == COMPOUND_WEDGE) {
    return av1_is_wedge_used(mbmi->sb_type)
               ? av1_cost_literal(1) +
                     x->wedge_idx_cost[mbmi->sb_type]
                                      [mbmi->interinter_comp.wedge_index]
               : 0;
  } else {
    assert(compound_type == COMPOUND_DIFFWTD);
    return av1_cost_literal(1);
  }
}

// Takes a backup of rate, distortion and model_rd for future reuse
static INLINE void backup_stats(COMPOUND_TYPE cur_type, int32_t *comp_rate,
                                int64_t *comp_dist, int32_t *comp_model_rate,
                                int64_t *comp_model_dist, int rate_sum,
                                int64_t dist_sum, RD_STATS *rd_stats,
                                int *comp_rs2, int rs2) {
  comp_rate[cur_type] = rd_stats->rate;
  comp_dist[cur_type] = rd_stats->dist;
  comp_model_rate[cur_type] = rate_sum;
  comp_model_dist[cur_type] = dist_sum;
  comp_rs2[cur_type] = rs2;
}

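// Computes the rd cost for COMPOUND_WEDGE or COMPOUND_DIFFWTD: picks the best
// mask, optionally refines the motion vector for NEWMV modes, applies the
// model-rd and skip-rd based pruning, and reuses previously stored stats when
// a matching record exists.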
static int64_t masked_compound_type_rd(
    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
    const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
    int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
    uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
    int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
    int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
    int64_t *comp_model_dist, const int64_t comp_best_model_rd,
    int64_t *const comp_model_rd_cur, int *comp_rs2, int64_t ref_skip_rd) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int64_t best_rd_cur = INT64_MAX;
  int64_t rd = INT64_MAX;
  const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
  // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
  assert(compound_type == COMPOUND_WEDGE || compound_type == COMPOUND_DIFFWTD);
  int rate_sum, tmp_skip_txfm_sb;
  int64_t dist_sum, tmp_skip_sse_sb;
  pick_interinter_mask_type pick_interinter_mask[2] = { pick_interinter_wedge,
                                                        pick_interinter_seg };

  // TODO(any): Save pred and mask calculation as well into records. However
  // this may increase memory requirements as compound segment mask needs to be
  // stored in each record.
  if (*calc_pred_masked_compound) {
    get_inter_predictors_masked_compound(x, bsize, preds0, preds1, residual1,
                                         diff10, strides);
    *calc_pred_masked_compound = 0;
  }
  if (cpi->sf.inter_sf.prune_wedge_pred_diff_based &&
      compound_type == COMPOUND_WEDGE) {
    unsigned int sse;
    if (is_cur_buf_hbd(xd))
      (void)cpi->fn_ptr[bsize].vf(CONVERT_TO_BYTEPTR(*preds0), *strides,
                                  CONVERT_TO_BYTEPTR(*preds1), *strides, &sse);
    else
      (void)cpi->fn_ptr[bsize].vf(*preds0, *strides, *preds1, *strides, &sse);
    const unsigned int mse =
        ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
    // If two predictors are very similar, skip wedge compound mode search
    if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
  }
  // Function pointer to pick the appropriate mask
  // compound_type == COMPOUND_WEDGE, calls pick_interinter_wedge()
  // compound_type == COMPOUND_DIFFWTD, calls pick_interinter_seg()
  uint64_t cur_sse = UINT64_MAX;
  best_rd_cur = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
      cpi, x, bsize, *preds0, *preds1, residual1, diff10, &cur_sse);
  *rs2 += get_interinter_compound_mask_rate(x, mbmi);
  best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
  assert(cur_sse != UINT64_MAX);
  int64_t skip_rd_cur = RDCOST(x->rdmult, *rs2 + rate_mv, (cur_sse << 4));

  // Although the true rate_mv might be different after motion search, it is
  // unlikely to change which mode is best once the transform rd cost and other
  // mode overhead costs are considered.
  int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
  if (mode_rd > rd_thresh) {
    *comp_model_rd_cur = INT64_MAX;
    return INT64_MAX;
  }

  // Check if the mode is good enough based on skip rd
  // TODO(nithya): Handle wedge_newmv_search if extending for lower speed
  // setting
  if (cpi->sf.inter_sf.txfm_rd_gate_level) {
    int eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd_cur,
                                    cpi->sf.inter_sf.txfm_rd_gate_level, 1);
    if (!eval_txfm) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
  }

  // Compute the cost if a matching record is not found; otherwise reuse data
  if (comp_rate[compound_type] == INT_MAX) {
    // Check whether new MV search for wedge is to be done
    int wedge_newmv_search =
        have_newmv_in_inter_mode(this_mode) &&
        (compound_type == COMPOUND_WEDGE) &&
        (!cpi->sf.inter_sf.disable_interinter_wedge_newmv_search);
    int diffwtd_newmv_search =
        cpi->sf.inter_sf.enable_interinter_diffwtd_newmv_search &&
        compound_type == COMPOUND_DIFFWTD &&
        have_newmv_in_inter_mode(this_mode);

    // Search for new MV if needed and build predictor
    if (wedge_newmv_search) {
      *out_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
                                                           bsize, this_mode);
      const int mi_row = xd->mi_row;
      const int mi_col = xd->mi_col;
      av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
                                    AOM_PLANE_Y, AOM_PLANE_Y);
    } else if (diffwtd_newmv_search) {
      *out_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
                                                           bsize, this_mode);
      // we need to update the mask according to the new motion vector
      CompoundTypeRdBuffers tmp_buf;
      int64_t tmp_rd = INT64_MAX;
      alloc_compound_type_rd_buffers_no_check(&tmp_buf);

      uint8_t *tmp_preds0[1] = { tmp_buf.pred0 };
      uint8_t *tmp_preds1[1] = { tmp_buf.pred1 };

      get_inter_predictors_masked_compound(x, bsize, tmp_preds0, tmp_preds1,
                                           tmp_buf.residual1, tmp_buf.diff10,
                                           strides);

      tmp_rd = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
          cpi, x, bsize, *tmp_preds0, *tmp_preds1, tmp_buf.residual1,
          tmp_buf.diff10, &cur_sse);
      // we can reuse rs2 here
      tmp_rd += RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);

      if (tmp_rd >= best_rd_cur) {
        // restore the motion vector
        mbmi->mv[0].as_int = cur_mv[0].as_int;
        mbmi->mv[1].as_int = cur_mv[1].as_int;
        *out_rate_mv = rate_mv;
        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
                                                 strides, preds1, strides);
      } else {
        // build the final prediction using the updated mv
        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, tmp_preds0,
                                                 strides, tmp_preds1, strides);
      }
      av1_release_compound_type_rd_buffers(&tmp_buf);
    } else {
      *out_rate_mv = rate_mv;
      av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                               preds1, strides);
    }
    // Get the RD cost from model RD
    model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
        cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
        &tmp_skip_sse_sb, NULL, NULL, NULL);
    rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
    *comp_model_rd_cur = rd;
    // Revert to the pre-search prediction if the refined MV gives a worse
    // model rd
    if (wedge_newmv_search) {
      if (rd >= best_rd_cur) {
        mbmi->mv[0].as_int = cur_mv[0].as_int;
        mbmi->mv[1].as_int = cur_mv[1].as_int;
        *out_rate_mv = rate_mv;
        av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
                                                 strides, preds1, strides);
        *comp_model_rd_cur = best_rd_cur;
      }
    }
    if (cpi->sf.inter_sf.prune_comp_type_by_model_rd &&
        (*comp_model_rd_cur > comp_best_model_rd) &&
        comp_best_model_rd != INT64_MAX) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
    // Compute RD cost for the current type
    RD_STATS rd_stats;
    const int64_t tmp_mode_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
    const int64_t tmp_rd_thresh = rd_thresh - tmp_mode_rd;
    rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
    if (rd != INT64_MAX) {
      rd =
          RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
      // Backup rate and distortion for future reuse
      backup_stats(compound_type, comp_rate, comp_dist, comp_model_rate,
                   comp_model_dist, rate_sum, dist_sum, &rd_stats, comp_rs2,
                   *rs2);
    }
  } else {
    // Reuse data as matching record is found
    assert(comp_dist[compound_type] != INT64_MAX);
    // When disable_interinter_wedge_newmv_search is set, motion refinement is
    // disabled. Hence rate and distortion can be reused in this case as well
    assert(IMPLIES(have_newmv_in_inter_mode(this_mode),
                   cpi->sf.inter_sf.disable_interinter_wedge_newmv_search));
    assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
    assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
    *out_rate_mv = rate_mv;
    // Calculate RD cost based on stored stats
    rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
                comp_dist[compound_type]);
    // Recalculate model rdcost with the updated rate
    *comp_model_rd_cur =
        RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_model_rate[compound_type],
               comp_model_dist[compound_type]);
  }
  return rd;
}

// Scaling values to be used for gating wedge/compound segment based on best
// approximate rd
static int comp_type_rd_threshold_mul[3] = { 1, 11, 12 };
static int comp_type_rd_threshold_div[3] = { 3, 16, 16 };

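// Evaluates the valid compound types for the current block (average,
// distance-weighted, wedge, diffwtd), reusing previously stored search stats
// when a matching record is found, and updates mbmi, *rd and the best motion
// vectors with the winning type.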
av1_compound_type_rd(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int_mv * cur_mv,int mode_search_mask,int masked_compound_used,const BUFFER_SET * orig_dst,const BUFFER_SET * tmp_dst,const CompoundTypeRdBuffers * buffers,int * rate_mv,int64_t * rd,RD_STATS * rd_stats,int64_t ref_best_rd,int64_t ref_skip_rd,int * is_luma_interp_done,int64_t rd_thresh)1199 int av1_compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
1200 BLOCK_SIZE bsize, int_mv *cur_mv, int mode_search_mask,
1201 int masked_compound_used, const BUFFER_SET *orig_dst,
1202 const BUFFER_SET *tmp_dst,
1203 const CompoundTypeRdBuffers *buffers, int *rate_mv,
1204 int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
1205 int64_t ref_skip_rd, int *is_luma_interp_done,
1206 int64_t rd_thresh) {
1207 const AV1_COMMON *cm = &cpi->common;
1208 MACROBLOCKD *xd = &x->e_mbd;
1209 MB_MODE_INFO *mbmi = xd->mi[0];
1210 const PREDICTION_MODE this_mode = mbmi->mode;
1211 const int bw = block_size_wide[bsize];
1212 int rs2;
1213 int_mv best_mv[2];
1214 int best_tmp_rate_mv = *rate_mv;
1215 BEST_COMP_TYPE_STATS best_type_stats;
1216 // Initializing BEST_COMP_TYPE_STATS
1217 best_type_stats.best_compound_data.type = COMPOUND_AVERAGE;
1218 best_type_stats.best_compmode_interinter_cost = 0;
1219 best_type_stats.comp_best_model_rd = INT64_MAX;
1220
1221 uint8_t *preds0[1] = { buffers->pred0 };
1222 uint8_t *preds1[1] = { buffers->pred1 };
1223 int strides[1] = { bw };
1224 int tmp_rate_mv;
1225 const int num_pix = 1 << num_pels_log2_lookup[bsize];
1226 const int mask_len = 2 * num_pix * sizeof(uint8_t);
1227 COMPOUND_TYPE cur_type;
1228 // Local array to store the mask cost for different compound types
1229 int masked_type_cost[COMPOUND_TYPES];
1230
1231 int calc_pred_masked_compound = 1;
1232 int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
1233 INT64_MAX };
1234 int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
1235 int comp_rs2[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
1236 int32_t comp_model_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX,
1237 INT_MAX };
1238 int64_t comp_model_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
1239 INT64_MAX };
1240 int match_index = 0;
1241 const int match_found =
1242 find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rate,
1243 comp_model_dist, comp_rs2, &match_index);
1244 best_mv[0].as_int = cur_mv[0].as_int;
1245 best_mv[1].as_int = cur_mv[1].as_int;
1246 *rd = INT64_MAX;
1247 int rate_sum, tmp_skip_txfm_sb;
1248 int64_t dist_sum, tmp_skip_sse_sb;
1249
1250 // Local array to store the valid compound types to be evaluated in the core
1251 // loop
1252 COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
1253 COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
1254 };
1255 int valid_type_count = 0;
1256 int try_average_and_distwtd_comp = 0;
1257   // compute_valid_comp_types() returns the number of valid compound types to
1258   // be evaluated, populates them in the local array valid_comp_types[], and
1259   // sets the flag 'try_average_and_distwtd_comp'.
1260 valid_type_count = compute_valid_comp_types(
1261 x, cpi, &try_average_and_distwtd_comp, bsize, masked_compound_used,
1262 mode_search_mask, valid_comp_types);
1263
1264 // The following context indices are independent of compound type
1265 const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
1266 const int comp_index_ctx = get_comp_index_context(cm, xd);
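  // These contexts determine the signaling cost of comp_group_idx (which
  // separates average/distwtd from wedge/diffwtd) and compound_idx, which
  // calc_masked_type_cost() folds into masked_type_cost[].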
1267
1268   // Populate the local masked_type_cost array for the 4 compound types
1269 calc_masked_type_cost(x, bsize, comp_group_idx_ctx, comp_index_ctx,
1270 masked_compound_used, masked_type_cost);
1271
1272 int64_t comp_model_rd_cur = INT64_MAX;
1273 int64_t best_rd_cur = INT64_MAX;
1274 const int mi_row = xd->mi_row;
1275 const int mi_col = xd->mi_col;
1276
1277   // If a matching record was found, compute the rd cost from the stored stats
1278   // and update mbmi accordingly.
1279 if (match_found && cpi->sf.inter_sf.reuse_compound_type_decision) {
1280 return populate_reuse_comp_type_data(x, mbmi, &best_type_stats, cur_mv,
1281 comp_rate, comp_dist, comp_rs2,
1282 rate_mv, rd, match_index);
1283 }
1284   // Special handling when both COMPOUND_AVERAGE and COMPOUND_DISTWTD are
1285   // to be searched: first pick the better of the two based on the modeled
1286   // rd estimate and then call estimate_yrd_for_sb() only for the chosen
1287   // type.
1288 if (try_average_and_distwtd_comp) {
1289 int est_rate[2];
1290 int64_t est_dist[2], est_rd;
1291 COMPOUND_TYPE best_type;
1292     // Since the modeled rate and distortion are stored separately, pick
1293     // the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD using the
1294     // stored stats.
1295 if ((comp_model_rate[COMPOUND_AVERAGE] != INT_MAX) &&
1296 comp_model_rate[COMPOUND_DISTWTD] != INT_MAX) {
1297       // Choose the better of COMPOUND_AVERAGE and COMPOUND_DISTWTD based
1298       // on the modeled cost.
1299 best_type = find_best_avg_distwtd_comp_type(
1300 x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
1301 update_mbmi_for_compound_type(mbmi, best_type);
1302 if (comp_rate[best_type] != INT_MAX)
1303 best_rd_cur = RDCOST(
1304 x->rdmult,
1305 masked_type_cost[best_type] + *rate_mv + comp_rate[best_type],
1306 comp_dist[best_type]);
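      // If the actual rate for the chosen type was not cached, best_rd_cur
      // stays INT64_MAX and the best-type update below is skipped.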
1307 comp_model_rd_cur = est_rd;
1308 // Update stats for best compound type
1309 if (best_rd_cur < *rd) {
1310 update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
1311 comp_model_rd_cur, masked_type_cost[best_type]);
1312 }
1313 restore_dst_buf(xd, *tmp_dst, 1);
1314 } else {
1315 int64_t sse_y[COMPOUND_DISTWTD + 1];
1316 // Calculate model_rd for COMPOUND_AVERAGE and COMPOUND_DISTWTD
1317 for (int comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
1318 comp_type++) {
1319 update_mbmi_for_compound_type(mbmi, comp_type);
1320 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1321 AOM_PLANE_Y, AOM_PLANE_Y);
1322 model_rd_sb_fn[MODELRD_CURVFIT](
1323 cpi, bsize, x, xd, 0, 0, &est_rate[comp_type], &est_dist[comp_type],
1324 NULL, NULL, NULL, NULL, NULL);
1325 est_rate[comp_type] += masked_type_cost[comp_type];
1326 comp_model_rate[comp_type] = est_rate[comp_type];
1327 comp_model_dist[comp_type] = est_dist[comp_type];
1328 sse_y[comp_type] = x->pred_sse[xd->mi[0]->ref_frame[0]];
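        // Keep the COMPOUND_AVERAGE prediction in orig_dst so the caller can
        // reuse it (signaled via *is_luma_interp_done), and switch to the spare
        // tmp_dst buffer so that building COMPOUND_DISTWTD does not overwrite it.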
1329 if (comp_type == COMPOUND_AVERAGE) {
1330 *is_luma_interp_done = 1;
1331 restore_dst_buf(xd, *tmp_dst, 1);
1332 }
1333 }
1334 // Choose the better of the two based on modeled cost and call
1335 // estimate_yrd_for_sb() for that one.
1336 best_type = find_best_avg_distwtd_comp_type(
1337 x, comp_model_rate, comp_model_dist, *rate_mv, &est_rd);
1338 update_mbmi_for_compound_type(mbmi, best_type);
1339 if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *orig_dst, 1);
1340 rs2 = masked_type_cost[best_type];
1341 RD_STATS est_rd_stats;
1342 const int64_t mode_rd = RDCOST(x->rdmult, rs2 + *rate_mv, 0);
1343 const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
1344 int64_t est_rd_ = INT64_MAX;
1345 int eval_txfm = 1;
1346 // Check if the mode is good enough based on skip rd
1347 if (cpi->sf.inter_sf.txfm_rd_gate_level) {
1348 int64_t skip_rd =
1349 RDCOST(x->rdmult, rs2 + *rate_mv, (sse_y[best_type] << 4));
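        // sse_y is scaled (<< 4) to the distortion scale expected by RDCOST;
        // check_txfm_eval() compares this skip rd against ref_skip_rd to decide
        // whether the transform search below is worth running.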
1350 eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
1351 cpi->sf.inter_sf.txfm_rd_gate_level, 1);
1352 }
1353 // Evaluate further if skip rd is low enough
1354 if (eval_txfm) {
1355 est_rd_ =
1356 estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &est_rd_stats);
1357 }
1358
1359 if (est_rd_ != INT64_MAX) {
1360 best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
1361 est_rd_stats.dist);
1362         // Back up the rate and distortion for future reuse
1363 backup_stats(best_type, comp_rate, comp_dist, comp_model_rate,
1364 comp_model_dist, est_rate[best_type], est_dist[best_type],
1365 &est_rd_stats, comp_rs2, rs2);
1366 comp_model_rd_cur = est_rd;
1367 }
1368 if (best_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
1369 // Update stats for best compound type
1370 if (best_rd_cur < *rd) {
1371 update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
1372 comp_model_rd_cur, rs2);
1373 }
1374 }
1375 }
1376
1377 // If COMPOUND_AVERAGE is not valid, use the spare buffer
1378 if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
1379
1380 // Loop over valid compound types
1381 for (int i = 0; i < valid_type_count; i++) {
1382 cur_type = valid_comp_types[i];
1383 comp_model_rd_cur = INT64_MAX;
1384 tmp_rate_mv = *rate_mv;
1385 best_rd_cur = INT64_MAX;
1386
1387 // Case COMPOUND_AVERAGE and COMPOUND_DISTWTD
1388 if (cur_type < COMPOUND_WEDGE) {
1389 update_mbmi_for_compound_type(mbmi, cur_type);
1390 rs2 = masked_type_cost[cur_type];
1391 const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
1392 if (mode_rd < ref_best_rd) {
1393         // Reuse the stored stats if available; otherwise compute the cost afresh
1394 if (comp_rate[cur_type] == INT_MAX) {
1395 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1396 AOM_PLANE_Y, AOM_PLANE_Y);
1397 if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
1398
1399 // Compute RD cost for the current type
1400 RD_STATS est_rd_stats;
1401 const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
1402 int64_t est_rd = INT64_MAX;
1403 int eval_txfm = 1;
1404 // Check if the mode is good enough based on skip rd
1405 if (cpi->sf.inter_sf.txfm_rd_gate_level) {
1406 int64_t sse_y = compute_sse_plane(x, xd, PLANE_TYPE_Y, bsize);
1407 int64_t skip_rd = RDCOST(x->rdmult, rs2 + *rate_mv, (sse_y << 4));
1408 eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
1409 cpi->sf.inter_sf.txfm_rd_gate_level, 1);
1410 }
1411 // Evaluate further if skip rd is low enough
1412 if (eval_txfm) {
1413 est_rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh,
1414 &est_rd_stats);
1415 }
1416
1417 if (est_rd != INT64_MAX) {
1418 best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
1419 est_rd_stats.dist);
1420 model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
1421 cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
1422 &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
1423 comp_model_rd_cur =
1424 RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
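            // Both the estimated rd (from the transform search) and the modeled
            // rd are recorded; the modeled rd is tracked in comp_best_model_rd
            // and used for model-rd based pruning in masked_compound_type_rd().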
1425
1426             // Back up the rate and distortion for future reuse
1427 backup_stats(cur_type, comp_rate, comp_dist, comp_model_rate,
1428 comp_model_dist, rate_sum, dist_sum, &est_rd_stats,
1429 comp_rs2, rs2);
1430 }
1431 } else {
1432 // Calculate RD cost based on stored stats
1433 assert(comp_dist[cur_type] != INT64_MAX);
1434 best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
1435 comp_dist[cur_type]);
1436 // Recalculate model rdcost with the updated rate
1437 comp_model_rd_cur =
1438 RDCOST(x->rdmult, rs2 + *rate_mv + comp_model_rate[cur_type],
1439 comp_model_dist[cur_type]);
1440 }
1441 }
1442       // Use the spare buffer for the next compound type to be evaluated
1443 if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
1444 } else {
1445 // Handle masked compound types
1446 update_mbmi_for_compound_type(mbmi, cur_type);
1447 rs2 = masked_type_cost[cur_type];
1448       // Factors controlling the gating of masked compound type selection
1449       // based on the best approximate rd so far
1450 const int max_comp_type_rd_threshold_mul =
1451 comp_type_rd_threshold_mul[cpi->sf.inter_sf
1452 .prune_comp_type_by_comp_avg];
1453 const int max_comp_type_rd_threshold_div =
1454 comp_type_rd_threshold_div[cpi->sf.inter_sf
1455 .prune_comp_type_by_comp_avg];
1456       // Evaluate COMPOUND_WEDGE / COMPOUND_DIFFWTD only if the approximate
1457       // cost is within the threshold
1458 int64_t approx_rd = ((*rd / max_comp_type_rd_threshold_div) *
1459 max_comp_type_rd_threshold_mul);
1460
1461 if (approx_rd < ref_best_rd) {
1462 const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
1463 best_rd_cur = masked_compound_type_rd(
1464 cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
1465 &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
1466 strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
1467 comp_rate, comp_dist, comp_model_rate, comp_model_dist,
1468 best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2,
1469 ref_skip_rd);
1470 }
1471 }
1472 // Update stats for best compound type
1473 if (best_rd_cur < *rd) {
1474 update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
1475 comp_model_rd_cur, rs2);
1476 if (masked_compound_used && cur_type >= COMPOUND_WEDGE) {
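        // Back up the best mask (and, for NEWMV modes, the refined MVs and MV
        // rate) so they can be restored if this masked type remains the overall
        // best after the loop.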
1477 memcpy(buffers->tmp_best_mask_buf, xd->seg_mask, mask_len);
1478 if (have_newmv_in_inter_mode(this_mode))
1479 update_mask_best_mv(mbmi, best_mv, cur_mv, cur_type,
1480 &best_tmp_rate_mv, tmp_rate_mv, &cpi->sf);
1481 }
1482 }
1483     // Reset to the original MVs for the next iteration
1484 mbmi->mv[0].as_int = cur_mv[0].as_int;
1485 mbmi->mv[1].as_int = cur_mv[1].as_int;
1486 }
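  // If the best compound type differs from the one currently held in mbmi (the
  // last type evaluated in the loop), restore the best type's signaling fields
  // and its mask.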
1487 if (mbmi->interinter_comp.type != best_type_stats.best_compound_data.type) {
1488 mbmi->comp_group_idx =
1489 (best_type_stats.best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
1490 mbmi->compound_idx =
1491 !(best_type_stats.best_compound_data.type == COMPOUND_DISTWTD);
1492 mbmi->interinter_comp = best_type_stats.best_compound_data;
1493 memcpy(xd->seg_mask, buffers->tmp_best_mask_buf, mask_len);
1494 }
1495 if (have_newmv_in_inter_mode(this_mode)) {
1496 mbmi->mv[0].as_int = best_mv[0].as_int;
1497 mbmi->mv[1].as_int = best_mv[1].as_int;
1498 if (mbmi->interinter_comp.type == COMPOUND_WEDGE) {
1499 rd_stats->rate += best_tmp_rate_mv - *rate_mv;
1500 *rate_mv = best_tmp_rate_mv;
1501 }
1502 }
1503 restore_dst_buf(xd, *orig_dst, 1);
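  // Cache the stats from this search so that later evaluations with the same
  // references, MVs and interpolation filters can reuse them (see
  // is_comp_rd_match()).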
1504 if (!match_found)
1505 save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rate,
1506 comp_model_dist, cur_mv, comp_rs2);
1507 return best_type_stats.best_compmode_interinter_cost;
1508 }
1509