1 /*
2 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include "av1/common/pred_common.h"
13 #include "av1/encoder/compound_type.h"
14 #include "av1/encoder/encoder_alloc.h"
15 #include "av1/encoder/model_rd.h"
16 #include "av1/encoder/motion_search_facade.h"
17 #include "av1/encoder/rdopt_utils.h"
18 #include "av1/encoder/reconinter_enc.h"
19 #include "av1/encoder/tx_search.h"
20
21 typedef int64_t (*pick_interinter_mask_type)(
22 const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
23 const uint8_t *const p0, const uint8_t *const p1,
24 const int16_t *const residual1, const int16_t *const diff10,
25 uint64_t *best_sse);
26
27 // Checks if characteristics of search match
is_comp_rd_match(const AV1_COMP * const cpi,const MACROBLOCK * const x,const COMP_RD_STATS * st,const MB_MODE_INFO * const mi,int32_t * comp_rate,int64_t * comp_dist,int32_t * comp_model_rate,int64_t * comp_model_dist,int * comp_rs2)28 static INLINE int is_comp_rd_match(const AV1_COMP *const cpi,
29 const MACROBLOCK *const x,
30 const COMP_RD_STATS *st,
31 const MB_MODE_INFO *const mi,
32 int32_t *comp_rate, int64_t *comp_dist,
33 int32_t *comp_model_rate,
34 int64_t *comp_model_dist, int *comp_rs2) {
35 // TODO(ranjit): Ensure that compound type search use regular filter always
36 // and check if following check can be removed
37 // Check if interp filter matches with previous case
38 if (st->filter.as_int != mi->interp_filters.as_int) return 0;
39
40 const MACROBLOCKD *const xd = &x->e_mbd;
41 // Match MV and reference indices
42 for (int i = 0; i < 2; ++i) {
43 if ((st->ref_frames[i] != mi->ref_frame[i]) ||
44 (st->mv[i].as_int != mi->mv[i].as_int)) {
45 return 0;
46 }
47 const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
48 if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
49 }
50
51 int reuse_data[COMPOUND_TYPES] = { 1, 1, 0, 0 };
52 // For compound wedge, reuse data if newmv search is disabled when NEWMV is
53 // present or if NEWMV is not present in either of the directions
54 if ((!have_newmv_in_inter_mode(mi->mode) &&
55 !have_newmv_in_inter_mode(st->mode)) ||
56 (cpi->sf.inter_sf.disable_interinter_wedge_newmv_search))
57 reuse_data[COMPOUND_WEDGE] = 1;
58 // For compound diffwtd, reuse data if fast search is enabled (no newmv search
59 // when NEWMV is present) or if NEWMV is not present in either of the
60 // directions
61 if (cpi->sf.inter_sf.enable_fast_compound_mode_search ||
62 (!have_newmv_in_inter_mode(mi->mode) &&
63 !have_newmv_in_inter_mode(st->mode)))
64 reuse_data[COMPOUND_DIFFWTD] = 1;
65
66 // Store the stats for the different compound types
67 for (int comp_type = COMPOUND_AVERAGE; comp_type < COMPOUND_TYPES;
68 comp_type++) {
69 if (reuse_data[comp_type]) {
70 comp_rate[comp_type] = st->rate[comp_type];
71 comp_dist[comp_type] = st->dist[comp_type];
72 comp_model_rate[comp_type] = st->model_rate[comp_type];
73 comp_model_dist[comp_type] = st->model_dist[comp_type];
74 comp_rs2[comp_type] = st->comp_rs2[comp_type];
75 }
76 }
77 return 1;
78 }
79
80 // Checks if similar compound type search case is accounted earlier
81 // If found, returns relevant rd data
find_comp_rd_in_stats(const AV1_COMP * const cpi,const MACROBLOCK * x,const MB_MODE_INFO * const mbmi,int32_t * comp_rate,int64_t * comp_dist,int32_t * comp_model_rate,int64_t * comp_model_dist,int * comp_rs2,int * match_index)82 static INLINE int find_comp_rd_in_stats(const AV1_COMP *const cpi,
83 const MACROBLOCK *x,
84 const MB_MODE_INFO *const mbmi,
85 int32_t *comp_rate, int64_t *comp_dist,
86 int32_t *comp_model_rate,
87 int64_t *comp_model_dist, int *comp_rs2,
88 int *match_index) {
89 for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
90 if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
91 comp_dist, comp_model_rate, comp_model_dist,
92 comp_rs2)) {
93 *match_index = j;
94 return 1;
95 }
96 }
97 return 0; // no match result found
98 }
99
enable_wedge_search(MACROBLOCK * const x,const unsigned int disable_wedge_var_thresh)100 static INLINE bool enable_wedge_search(
101 MACROBLOCK *const x, const unsigned int disable_wedge_var_thresh) {
102 // Enable wedge search if source variance and edge strength are above
103 // the thresholds.
104 return x->source_variance > disable_wedge_var_thresh;
105 }
106
enable_wedge_interinter_search(MACROBLOCK * const x,const AV1_COMP * const cpi)107 static INLINE bool enable_wedge_interinter_search(MACROBLOCK *const x,
108 const AV1_COMP *const cpi) {
109 return enable_wedge_search(
110 x, cpi->sf.inter_sf.disable_interinter_wedge_var_thresh) &&
111 cpi->oxcf.comp_type_cfg.enable_interinter_wedge;
112 }
113
enable_wedge_interintra_search(MACROBLOCK * const x,const AV1_COMP * const cpi)114 static INLINE bool enable_wedge_interintra_search(MACROBLOCK *const x,
115 const AV1_COMP *const cpi) {
116 return enable_wedge_search(
117 x, cpi->sf.inter_sf.disable_interintra_wedge_var_thresh) &&
118 cpi->oxcf.comp_type_cfg.enable_interintra_wedge;
119 }
120
estimate_wedge_sign(const AV1_COMP * cpi,const MACROBLOCK * x,const BLOCK_SIZE bsize,const uint8_t * pred0,int stride0,const uint8_t * pred1,int stride1)121 static int8_t estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
122 const BLOCK_SIZE bsize, const uint8_t *pred0,
123 int stride0, const uint8_t *pred1,
124 int stride1) {
125 static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
126 // 4X4
127 BLOCK_INVALID,
128 // 4X8, 8X4, 8X8
129 BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
130 // 8X16, 16X8, 16X16
131 BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
132 // 16X32, 32X16, 32X32
133 BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
134 // 32X64, 64X32, 64X64
135 BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
136 // 64x128, 128x64, 128x128
137 BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
138 // 4X16, 16X4, 8X32
139 BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
140 // 32X8, 16X64, 64X16
141 BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
142 };
143 const struct macroblock_plane *const p = &x->plane[0];
144 const uint8_t *src = p->src.buf;
145 int src_stride = p->src.stride;
146 const int bw = block_size_wide[bsize];
147 const int bh = block_size_high[bsize];
148 const int bw_by2 = bw >> 1;
149 const int bh_by2 = bh >> 1;
150 uint32_t esq[2][2];
151 int64_t tl, br;
152
153 const BLOCK_SIZE f_index = split_qtr[bsize];
154 assert(f_index != BLOCK_INVALID);
155
156 if (is_cur_buf_hbd(&x->e_mbd)) {
157 pred0 = CONVERT_TO_BYTEPTR(pred0);
158 pred1 = CONVERT_TO_BYTEPTR(pred1);
159 }
160
161 // Residual variance computation over relevant quandrants in order to
162 // find TL + BR, TL = sum(1st,2nd,3rd) quadrants of (pred0 - pred1),
163 // BR = sum(2nd,3rd,4th) quadrants of (pred1 - pred0)
164 // The 2nd and 3rd quadrants cancel out in TL + BR
165 // Hence TL + BR = 1st quadrant of (pred0-pred1) + 4th of (pred1-pred0)
166 // TODO(nithya): Sign estimation assumes 45 degrees (1st and 4th quadrants)
167 // for all codebooks; experiment with other quadrant combinations for
168 // 0, 90 and 135 degrees also.
169 cpi->ppi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
170 cpi->ppi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
171 pred0 + bh_by2 * stride0 + bw_by2, stride0,
172 &esq[0][1]);
173 cpi->ppi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
174 cpi->ppi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
175 pred1 + bh_by2 * stride1 + bw_by2, stride0,
176 &esq[1][1]);
177
178 tl = ((int64_t)esq[0][0]) - ((int64_t)esq[1][0]);
179 br = ((int64_t)esq[1][1]) - ((int64_t)esq[0][1]);
180 return (tl + br > 0);
181 }
182
183 // Choose the best wedge index and sign
pick_wedge(const AV1_COMP * const cpi,const MACROBLOCK * const x,const BLOCK_SIZE bsize,const uint8_t * const p0,const int16_t * const residual1,const int16_t * const diff10,int8_t * const best_wedge_sign,int8_t * const best_wedge_index,uint64_t * best_sse)184 static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
185 const BLOCK_SIZE bsize, const uint8_t *const p0,
186 const int16_t *const residual1,
187 const int16_t *const diff10,
188 int8_t *const best_wedge_sign,
189 int8_t *const best_wedge_index, uint64_t *best_sse) {
190 const MACROBLOCKD *const xd = &x->e_mbd;
191 const struct buf_2d *const src = &x->plane[0].src;
192 const int bw = block_size_wide[bsize];
193 const int bh = block_size_high[bsize];
194 const int N = bw * bh;
195 assert(N >= 64);
196 int rate;
197 int64_t dist;
198 int64_t rd, best_rd = INT64_MAX;
199 int8_t wedge_index;
200 int8_t wedge_sign;
201 const int8_t wedge_types = get_wedge_types_lookup(bsize);
202 const uint8_t *mask;
203 uint64_t sse;
204 const int hbd = is_cur_buf_hbd(xd);
205 const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
206
207 DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]); // src - pred0
208 #if CONFIG_AV1_HIGHBITDEPTH
209 if (hbd) {
210 aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
211 CONVERT_TO_BYTEPTR(p0), bw);
212 } else {
213 aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
214 }
215 #else
216 (void)hbd;
217 aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
218 #endif
219
220 int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
221 (int64_t)aom_sum_squares_i16(residual1, N)) *
222 (1 << WEDGE_WEIGHT_BITS) / 2;
223 int16_t *ds = residual0;
224
225 av1_wedge_compute_delta_squares(ds, residual0, residual1, N);
226
227 for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
228 mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);
229
230 wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);
231
232 mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
233 sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
234 sse = ROUND_POWER_OF_TWO(sse, bd_round);
235
236 model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
237 &rate, &dist);
238 // int rate2;
239 // int64_t dist2;
240 // model_rd_with_curvfit(cpi, x, bsize, 0, sse, N, &rate2, &dist2);
241 // printf("sse %"PRId64": leagacy: %d %"PRId64", curvfit %d %"PRId64"\n",
242 // sse, rate, dist, rate2, dist2); dist = dist2;
243 // rate = rate2;
244
245 rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index];
246 rd = RDCOST(x->rdmult, rate, dist);
247
248 if (rd < best_rd) {
249 *best_wedge_index = wedge_index;
250 *best_wedge_sign = wedge_sign;
251 best_rd = rd;
252 *best_sse = sse;
253 }
254 }
255
256 return best_rd -
257 RDCOST(x->rdmult,
258 x->mode_costs.wedge_idx_cost[bsize][*best_wedge_index], 0);
259 }
260
261 // Choose the best wedge index the specified sign
pick_wedge_fixed_sign(const AV1_COMP * const cpi,const MACROBLOCK * const x,const BLOCK_SIZE bsize,const int16_t * const residual1,const int16_t * const diff10,const int8_t wedge_sign,int8_t * const best_wedge_index,uint64_t * best_sse)262 static int64_t pick_wedge_fixed_sign(
263 const AV1_COMP *const cpi, const MACROBLOCK *const x,
264 const BLOCK_SIZE bsize, const int16_t *const residual1,
265 const int16_t *const diff10, const int8_t wedge_sign,
266 int8_t *const best_wedge_index, uint64_t *best_sse) {
267 const MACROBLOCKD *const xd = &x->e_mbd;
268
269 const int bw = block_size_wide[bsize];
270 const int bh = block_size_high[bsize];
271 const int N = bw * bh;
272 assert(N >= 64);
273 int rate;
274 int64_t dist;
275 int64_t rd, best_rd = INT64_MAX;
276 int8_t wedge_index;
277 const int8_t wedge_types = get_wedge_types_lookup(bsize);
278 const uint8_t *mask;
279 uint64_t sse;
280 const int hbd = is_cur_buf_hbd(xd);
281 const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
282 for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
283 mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
284 sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
285 sse = ROUND_POWER_OF_TWO(sse, bd_round);
286
287 model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
288 &rate, &dist);
289 rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index];
290 rd = RDCOST(x->rdmult, rate, dist);
291
292 if (rd < best_rd) {
293 *best_wedge_index = wedge_index;
294 best_rd = rd;
295 *best_sse = sse;
296 }
297 }
298 return best_rd -
299 RDCOST(x->rdmult,
300 x->mode_costs.wedge_idx_cost[bsize][*best_wedge_index], 0);
301 }
302
// Selects the wedge mask (index and sign) for an inter-inter wedge compound
// and stores it in mbmi->interinter_comp.
// - With sf.inter_sf.fast_wedge_sign_estimate, the sign is estimated once from
//   the two bw-strided predictions and only the index is searched.
// - Otherwise, index and sign are searched jointly.
// Returns the model RD of the chosen mask (without the wedge-index rate) and
// reports its SSE through best_sse.
static int64_t pick_interinter_wedge(
    const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
    const uint8_t *const p0, const uint8_t *const p1,
    const int16_t *const residual1, const int16_t *const diff10,
    uint64_t *best_sse) {
  assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
  assert(cpi->common.seq_params->enable_masked_compound);

  int8_t chosen_index = -1;
  int8_t chosen_sign = 0;
  int64_t wedge_rd;
  if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
    // Both predictions are laid out with the block width as stride.
    const int pred_stride = block_size_wide[bsize];
    chosen_sign =
        estimate_wedge_sign(cpi, x, bsize, p0, pred_stride, p1, pred_stride);
    wedge_rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10,
                                     chosen_sign, &chosen_index, best_sse);
  } else {
    wedge_rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &chosen_sign,
                          &chosen_index, best_sse);
  }

  MB_MODE_INFO *const mbmi = x->e_mbd.mi[0];
  mbmi->interinter_comp.wedge_sign = chosen_sign;
  mbmi->interinter_comp.wedge_index = chosen_index;
  return wedge_rd;
}
332
// Selects the difference-weighted (DIFFWTD) compound mask type for the block.
//
// For each of the DIFFWTD_MASK_TYPES mask types:
// - the mask is built from the two bw-strided luma predictions p0/p1;
// - its SSE against residual1/diff10 is converted to a model RD.
//
// The cheapest type is stored in mbmi->interinter_comp.mask_type.
// - Mask type 0 is built directly into xd->seg_mask.
// - Mask type 1 is built into a local scratch buffer, which is copied into
//   xd->seg_mask when DIFFWTD_38_INV wins. This presumes DIFFWTD_38_INV is
//   mask type 1, so xd->seg_mask ends up holding the winner.
//
// Returns the best model RD and reports that mask's SSE through best_sse.
static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
                                   MACROBLOCK *const x, const BLOCK_SIZE bsize,
                                   const uint8_t *const p0,
                                   const uint8_t *const p1,
                                   const int16_t *const residual1,
                                   const int16_t *const diff10,
                                   uint64_t *best_sse) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const int w = block_size_wide[bsize];
  const int h = block_size_high[bsize];
  const int num_pels = 1 << num_pels_log2_lookup[bsize];
  const int is_hbd = is_cur_buf_hbd(xd);
  const int sse_shift = is_hbd ? (xd->bd - 8) * 2 : 0;
  DECLARE_ALIGNED(16, uint8_t, scratch_mask[2 * MAX_SB_SQUARE]);
  uint8_t *masks[2] = { xd->seg_mask, scratch_mask };
  DIFFWTD_MASK_TYPE winner = 0;
  int64_t winner_rd = INT64_MAX;

  for (DIFFWTD_MASK_TYPE type = 0; type < DIFFWTD_MASK_TYPES; type++) {
    uint8_t *const mask = masks[type];
#if CONFIG_AV1_HIGHBITDEPTH
    if (is_hbd) {
      av1_build_compound_diffwtd_mask_highbd(mask, type, CONVERT_TO_BYTEPTR(p0),
                                             w, CONVERT_TO_BYTEPTR(p1), w, h, w,
                                             xd->bd);
    } else {
      av1_build_compound_diffwtd_mask(mask, type, p0, w, p1, w, h, w);
    }
#else
    (void)is_hbd;
    av1_build_compound_diffwtd_mask(mask, type, p0, w, p1, w, h, w);
#endif  // CONFIG_AV1_HIGHBITDEPTH

    uint64_t sse =
        av1_wedge_sse_from_residuals(residual1, diff10, mask, num_pels);
    sse = ROUND_POWER_OF_TWO(sse, sse_shift);

    int rate;
    int64_t dist;
    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse,
                                                  num_pels, &rate, &dist);
    const int64_t rd = RDCOST(x->rdmult, rate, dist);
    if (rd < winner_rd) {
      winner = type;
      winner_rd = rd;
      *best_sse = sse;
    }
  }

  xd->mi[0]->interinter_comp.mask_type = winner;
  if (winner == DIFFWTD_38_INV) {
    // NOTE(review): copies 2 * num_pels bytes, although the mask built above
    // appears to span h * w == num_pels bytes; confirm whether the extra
    // half is needed.
    memcpy(xd->seg_mask, scratch_mask, num_pels * 2);
  }
  return winner_rd;
}
392
pick_interintra_wedge(const AV1_COMP * const cpi,const MACROBLOCK * const x,const BLOCK_SIZE bsize,const uint8_t * const p0,const uint8_t * const p1)393 static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
394 const MACROBLOCK *const x,
395 const BLOCK_SIZE bsize,
396 const uint8_t *const p0,
397 const uint8_t *const p1) {
398 const MACROBLOCKD *const xd = &x->e_mbd;
399 MB_MODE_INFO *const mbmi = xd->mi[0];
400 assert(av1_is_wedge_used(bsize));
401 assert(cpi->common.seq_params->enable_interintra_compound);
402
403 const struct buf_2d *const src = &x->plane[0].src;
404 const int bw = block_size_wide[bsize];
405 const int bh = block_size_high[bsize];
406 DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]); // src - pred1
407 DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]); // pred1 - pred0
408 #if CONFIG_AV1_HIGHBITDEPTH
409 if (is_cur_buf_hbd(xd)) {
410 aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
411 CONVERT_TO_BYTEPTR(p1), bw);
412 aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
413 CONVERT_TO_BYTEPTR(p0), bw);
414 } else {
415 aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
416 aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
417 }
418 #else
419 aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
420 aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
421 #endif
422 int8_t wedge_index = -1;
423 uint64_t sse;
424 int64_t rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0,
425 &wedge_index, &sse);
426
427 mbmi->interintra_wedge_index = wedge_index;
428 return rd;
429 }
430
get_inter_predictors_masked_compound(MACROBLOCK * x,const BLOCK_SIZE bsize,uint8_t ** preds0,uint8_t ** preds1,int16_t * residual1,int16_t * diff10,int * strides)431 static AOM_INLINE void get_inter_predictors_masked_compound(
432 MACROBLOCK *x, const BLOCK_SIZE bsize, uint8_t **preds0, uint8_t **preds1,
433 int16_t *residual1, int16_t *diff10, int *strides) {
434 MACROBLOCKD *xd = &x->e_mbd;
435 const int bw = block_size_wide[bsize];
436 const int bh = block_size_high[bsize];
437 // get inter predictors to use for masked compound modes
438 av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0, preds0,
439 strides);
440 av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1, preds1,
441 strides);
442 const struct buf_2d *const src = &x->plane[0].src;
443 #if CONFIG_AV1_HIGHBITDEPTH
444 if (is_cur_buf_hbd(xd)) {
445 aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
446 CONVERT_TO_BYTEPTR(*preds1), bw);
447 aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
448 bw, CONVERT_TO_BYTEPTR(*preds0), bw);
449 } else {
450 aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
451 bw);
452 aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
453 }
454 #else
455 aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1, bw);
456 aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
457 #endif
458 }
459
460 // Computes the rd cost for the given interintra mode and updates the best
compute_best_interintra_mode(const AV1_COMP * const cpi,MB_MODE_INFO * mbmi,MACROBLOCKD * xd,MACROBLOCK * const x,const int * const interintra_mode_cost,const BUFFER_SET * orig_dst,uint8_t * intrapred,const uint8_t * tmp_buf,INTERINTRA_MODE * best_interintra_mode,int64_t * best_interintra_rd,INTERINTRA_MODE interintra_mode,BLOCK_SIZE bsize)461 static INLINE void compute_best_interintra_mode(
462 const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
463 MACROBLOCK *const x, const int *const interintra_mode_cost,
464 const BUFFER_SET *orig_dst, uint8_t *intrapred, const uint8_t *tmp_buf,
465 INTERINTRA_MODE *best_interintra_mode, int64_t *best_interintra_rd,
466 INTERINTRA_MODE interintra_mode, BLOCK_SIZE bsize) {
467 const AV1_COMMON *const cm = &cpi->common;
468 int rate;
469 uint8_t skip_txfm_sb;
470 int64_t dist, skip_sse_sb;
471 const int bw = block_size_wide[bsize];
472 mbmi->interintra_mode = interintra_mode;
473 int rmode = interintra_mode_cost[interintra_mode];
474 av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
475 intrapred, bw);
476 av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
477 model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](cpi, bsize, x, xd, 0, 0, &rate, &dist,
478 &skip_txfm_sb, &skip_sse_sb, NULL,
479 NULL, NULL);
480 int64_t rd = RDCOST(x->rdmult, rate + rmode, dist);
481 if (rd < *best_interintra_rd) {
482 *best_interintra_rd = rd;
483 *best_interintra_mode = mbmi->interintra_mode;
484 }
485 }
486
// Estimates the luma RD cost of the current prediction with a transform-size
// search bounded by ref_best_rd.
// - A negative budget returns INT64_MAX immediately.
// - INT64_MAX from av1_estimate_txfm_yrd() is passed through unchanged.
// - Otherwise rd_stats->rate absorbs the skip-flag cost: it is replaced by the
//   skip=1 cost when the block is skipped, or incremented by the skip=0 cost.
//   The returned RD is av1_estimate_txfm_yrd()'s value and does not include
//   this skip-flag cost.
static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
                                   MACROBLOCK *x, int64_t ref_best_rd,
                                   RD_STATS *rd_stats) {
  if (ref_best_rd < 0) return INT64_MAX;

  av1_subtract_plane(x, bs, 0);
  const int64_t rd = av1_estimate_txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs,
                                           max_txsize_rect_lookup[bs]);
  if (rd == INT64_MAX) return rd;

  const int ctx = av1_get_skip_txfm_context(&x->e_mbd);
  if (rd_stats->skip_txfm) {
    rd_stats->rate = x->mode_costs.skip_txfm_cost[ctx][1];
  } else {
    rd_stats->rate += x->mode_costs.skip_txfm_cost[ctx][0];
  }
  return rd;
}
507
508 // Computes the rd_threshold for smooth interintra rd search.
compute_rd_thresh(MACROBLOCK * const x,int total_mode_rate,int64_t ref_best_rd)509 static AOM_INLINE int64_t compute_rd_thresh(MACROBLOCK *const x,
510 int total_mode_rate,
511 int64_t ref_best_rd) {
512 const int64_t rd_thresh = get_rd_thresh_from_best_rd(
513 ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
514 INTER_INTRA_RD_THRESH_SCALE);
515 const int64_t mode_rd = RDCOST(x->rdmult, total_mode_rate, 0);
516 return (rd_thresh - mode_rd);
517 }
518
519 // Computes the best wedge interintra mode
compute_best_wedge_interintra(const AV1_COMP * const cpi,MB_MODE_INFO * mbmi,MACROBLOCKD * xd,MACROBLOCK * const x,const int * const interintra_mode_cost,const BUFFER_SET * orig_dst,uint8_t * intrapred_,uint8_t * tmp_buf_,int * best_mode,int * best_wedge_index,BLOCK_SIZE bsize)520 static AOM_INLINE int64_t compute_best_wedge_interintra(
521 const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
522 MACROBLOCK *const x, const int *const interintra_mode_cost,
523 const BUFFER_SET *orig_dst, uint8_t *intrapred_, uint8_t *tmp_buf_,
524 int *best_mode, int *best_wedge_index, BLOCK_SIZE bsize) {
525 const AV1_COMMON *const cm = &cpi->common;
526 const int bw = block_size_wide[bsize];
527 int64_t best_interintra_rd_wedge = INT64_MAX;
528 int64_t best_total_rd = INT64_MAX;
529 uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
530 for (INTERINTRA_MODE mode = 0; mode < INTERINTRA_MODES; ++mode) {
531 mbmi->interintra_mode = mode;
532 av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
533 intrapred, bw);
534 int64_t rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
535 const int rate_overhead =
536 interintra_mode_cost[mode] +
537 x->mode_costs.wedge_idx_cost[bsize][mbmi->interintra_wedge_index];
538 const int64_t total_rd = rd + RDCOST(x->rdmult, rate_overhead, 0);
539 if (total_rd < best_total_rd) {
540 best_total_rd = total_rd;
541 best_interintra_rd_wedge = rd;
542 *best_mode = mbmi->interintra_mode;
543 *best_wedge_index = mbmi->interintra_wedge_index;
544 }
545 }
546 return best_interintra_rd_wedge;
547 }
548
// Evaluates smooth (non-wedge) inter-intra prediction for the block.
//
// Mode selection:
// - A previously found mode is reused when sf.inter_sf.reuse_inter_intra_mode
//   is set and *best_interintra_mode is already valid.
// - Otherwise every inter-intra mode is tried with a model RD. II_SMOOTH_PRED
//   is skipped when smooth intra is disabled. The winner is stored in
//   *best_interintra_mode and args->inter_intra_mode[ref_frame[0]].
//
// The chosen mode's blended prediction is then made current, and a
// transform-based luma RD is estimated.
//
// Outputs:
// - *best_rd: full RD, including mode and MV rate.
// - *best_mode_rate: mode signalling rate, excluding *rate_mv.
//
// Returns IGNORE_MODE when the estimate fails or the RD is too far above
// ref_best_rd, and 0 otherwise.
static int handle_smooth_inter_intra_mode(
    const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
    MB_MODE_INFO *mbmi, int64_t ref_best_rd, int *rate_mv,
    INTERINTRA_MODE *best_interintra_mode, int64_t *best_rd,
    int *best_mode_rate, const BUFFER_SET *orig_dst, uint8_t *tmp_buf,
    uint8_t *intrapred, HandleInterModeArgs *args) {
  MACROBLOCKD *xd = &x->e_mbd;
  const ModeCosts *mode_costs = &x->mode_costs;
  const int *const interintra_mode_cost =
      mode_costs->interintra_mode_cost[size_group_lookup[bsize]];
  const AV1_COMMON *const cm = &cpi->common;
  const int bw = block_size_wide[bsize];

  mbmi->use_wedge_interintra = 0;

  // Mode whose prediction the search below last built into intrapred and the
  // blended output. INTERINTRA_MODES means "nothing built here".
  INTERINTRA_MODE built_mode = INTERINTRA_MODES;
  if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
      *best_interintra_mode == INTERINTRA_MODES) {
    int64_t best_interintra_rd = INT64_MAX;
    for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
         ++cur_mode) {
      if ((!cpi->oxcf.intra_mode_cfg.enable_smooth_intra ||
           cpi->sf.intra_sf.disable_smooth_intra) &&
          cur_mode == II_SMOOTH_PRED) {
        continue;
      }
      compute_best_interintra_mode(
          cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred, tmp_buf,
          best_interintra_mode, &best_interintra_rd, cur_mode, bsize);
      built_mode = cur_mode;
    }
    args->inter_intra_mode[mbmi->ref_frame[0]] = *best_interintra_mode;
  }
  assert(IMPLIES(!cpi->oxcf.comp_type_cfg.enable_smooth_interintra,
                 *best_interintra_mode != II_SMOOTH_PRED));

  // Recompute the prediction unless the search above already finished on the
  // best mode. In that case mbmi->interintra_mode, intrapred and the blended
  // output are already the best mode's. The reuse decision must be taken from
  // state captured before the search, because the search overwrites
  // *best_interintra_mode.
  if (*best_interintra_mode != built_mode) {
    mbmi->interintra_mode = *best_interintra_mode;
    av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                              intrapred, bw);
    av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
  }

  // Compute rd cost for best smooth_interintra.
  RD_STATS rd_stats;
  const int is_wedge_used = av1_is_wedge_used(bsize);
  const int rmode =
      interintra_mode_cost[*best_interintra_mode] +
      (is_wedge_used ? mode_costs->wedge_interintra_cost[bsize][0] : 0);
  const int total_mode_rate = rmode + *rate_mv;
  const int64_t rd_thresh = compute_rd_thresh(x, total_mode_rate, ref_best_rd);
  int64_t rd = estimate_yrd_for_sb(cpi, bsize, x, rd_thresh, &rd_stats);
  if (rd == INT64_MAX) return IGNORE_MODE;
  rd = RDCOST(x->rdmult, total_mode_rate + rd_stats.rate, rd_stats.dist);

  *best_rd = rd;
  *best_mode_rate = rmode;
  // Return early if best rd not good enough.
  if (ref_best_rd < INT64_MAX &&
      (*best_rd >> INTER_INTRA_RD_THRESH_SHIFT) * INTER_INTRA_RD_THRESH_SCALE >
          ref_best_rd) {
    return IGNORE_MODE;
  }
  return 0;
}
615
// Evaluates wedge inter-intra prediction for the block.
//
// 1. Chooses the wedge index and, depending on the speed features, the
//    inter-intra mode:
//    - exhaustive search over every (mode, wedge) pair
//      (!fast_interintra_wedge_search);
//    - wedge only, for the mode picked by smooth inter-intra
//      (enable_smooth_interintra);
//    - wedge for a reused mode;
//    - wedge for the last mode, then the best mode under that wedge.
// 2. Adds the signalling cost (*rate_overhead: mode + wedge index + wedge
//    flag) and *rate_mv to *best_rd.
// 3. For NEWMV modes, refines the MV against the inverted wedge mask and keeps
//    the refinement only if its model RD beats *best_rd. Otherwise *tmp_mv /
//    *tmp_rate_mv fall back to the original MV and *rate_mv. mbmi->mv[0] is
//    not restored by this function.
// 4. Replaces *best_rd with a transform-based luma RD estimate, bounded by
//    best_rd_no_wedge. That value may be INT64_MAX when the estimate was
//    pruned but an earlier RD existed.
//
// Returns IGNORE_MODE when no usable RD exists, 0 otherwise.
static int handle_wedge_inter_intra_mode(
    const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
    MB_MODE_INFO *mbmi, int *rate_mv, INTERINTRA_MODE *best_interintra_mode,
    int64_t *best_rd, const BUFFER_SET *orig_dst, uint8_t *tmp_buf_,
    uint8_t *tmp_buf, uint8_t *intrapred_, uint8_t *intrapred,
    HandleInterModeArgs *args, int *tmp_rate_mv, int *rate_overhead,
    int_mv *tmp_mv, int64_t best_rd_no_wedge) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  const ModeCosts *costs = &x->mode_costs;
  const int *const ii_mode_cost =
      costs->interintra_mode_cost[size_group_lookup[bsize]];
  const int width = block_size_wide[bsize];

  mbmi->use_wedge_interintra = 1;

  if (!cpi->sf.inter_sf.fast_interintra_wedge_search) {
    // Exhaustive search of all wedge and mode combinations.
    int chosen_mode = 0;
    int chosen_wedge = 0;
    *best_rd = compute_best_wedge_interintra(cpi, mbmi, xd, x, ii_mode_cost,
                                             orig_dst, intrapred_, tmp_buf_,
                                             &chosen_mode, &chosen_wedge,
                                             bsize);
    mbmi->interintra_mode = chosen_mode;
    mbmi->interintra_wedge_index = chosen_wedge;
    // The search ended on mode INTERINTRA_MODES - 1, whose predictor is
    // already in intrapred; rebuild only for another winner.
    if (chosen_mode != INTERINTRA_MODES - 1) {
      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                intrapred, width);
    }
  } else if (cpi->oxcf.comp_type_cfg.enable_smooth_interintra) {
    // Pick wedge mask for the best interintra mode from smooth_interintra.
    // Assumes intrapred_ already holds that mode's predictor; verify against
    // the caller.
    *best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
  } else if (*best_interintra_mode != INTERINTRA_MODES) {
    // Pick wedge mask for the best interintra mode (reused).
    mbmi->interintra_mode = *best_interintra_mode;
    av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                              intrapred, width);
    *best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
  } else {
    // No mode known yet: pick the wedge mask using mode INTERINTRA_MODES - 1,
    // then find the best interintra mode for that mask.
    mbmi->interintra_mode = INTERINTRA_MODES - 1;
    *best_interintra_mode = INTERINTRA_MODES - 1;
    av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                              intrapred, width);
    *best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
    for (INTERINTRA_MODE m = 0; m < INTERINTRA_MODES; ++m) {
      compute_best_interintra_mode(cpi, mbmi, xd, x, ii_mode_cost, orig_dst,
                                   intrapred, tmp_buf, best_interintra_mode,
                                   best_rd, m, bsize);
    }
    args->inter_intra_mode[mbmi->ref_frame[0]] = *best_interintra_mode;
    mbmi->interintra_mode = *best_interintra_mode;
    // The loop ended on mode INTERINTRA_MODES - 1; rebuild the predictor for
    // any other winner.
    if (*best_interintra_mode != INTERINTRA_MODES - 1) {
      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                intrapred, width);
    }
  }

  *rate_overhead =
      ii_mode_cost[mbmi->interintra_mode] +
      costs->wedge_idx_cost[bsize][mbmi->interintra_wedge_index] +
      costs->wedge_interintra_cost[bsize][1];
  *best_rd += RDCOST(x->rdmult, *rate_overhead + *rate_mv, 0);

  // Refine the motion vector for NEWMV modes.
  const int_mv orig_mv = mbmi->mv[0];
  int64_t refined_rd = INT64_MAX;
  if (have_newmv_in_inter_mode(mbmi->mode)) {
    // Negative (sign 1) of the chosen wedge mask.
    const uint8_t *const inv_mask =
        av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
    av1_compound_single_motion_search(cpi, x, bsize, &tmp_mv->as_mv, intrapred,
                                      inv_mask, width, tmp_rate_mv, 0);
    if (tmp_mv->as_int != mbmi->mv[0].as_int) {
      mbmi->mv[0].as_int = tmp_mv->as_int;
      // Set ref_frame[1] to NONE_FRAME temporarily so that the intra
      // predictor is not calculated again in av1_enc_build_inter_predictor().
      mbmi->ref_frame[1] = NONE_FRAME;
      av1_enc_build_inter_predictor(cm, xd, xd->mi_row, xd->mi_col, orig_dst,
                                    bsize, AOM_PLANE_Y, AOM_PLANE_Y);
      mbmi->ref_frame[1] = INTRA_FRAME;
      av1_combine_interintra(xd, bsize, 0, xd->plane[AOM_PLANE_Y].dst.buf,
                             xd->plane[AOM_PLANE_Y].dst.stride, intrapred,
                             width);

      int model_rate;
      uint8_t model_skip;
      int64_t model_dist, model_sse;
      model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
          cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip,
          &model_sse, NULL, NULL, NULL);
      refined_rd = RDCOST(x->rdmult, *tmp_rate_mv + *rate_overhead + model_rate,
                          model_dist);
    }
  }
  if (refined_rd >= *best_rd) {
    // No improvement (or no refinement): fall back to the original MV and
    // rebuild the blended prediction from tmp_buf.
    tmp_mv->as_int = orig_mv.as_int;
    *tmp_rate_mv = *rate_mv;
    av1_combine_interintra(xd, bsize, 0, tmp_buf, width, intrapred, width);
  }

  // Evaluate closer to true rd.
  const int mode_rate = *rate_overhead + *tmp_rate_mv;
  const int64_t yrd_thresh =
      best_rd_no_wedge - RDCOST(x->rdmult, mode_rate, 0);
  RD_STATS yrd_stats;
  int64_t final_rd =
      estimate_yrd_for_sb(cpi, bsize, x, yrd_thresh, &yrd_stats);
  if (final_rd != INT64_MAX) {
    final_rd = RDCOST(x->rdmult, mode_rate + yrd_stats.rate, yrd_stats.dist);
  } else if (*best_rd == INT64_MAX) {
    return IGNORE_MODE;
  }
  *best_rd = final_rd;
  return 0;
}
738
// Evaluates inter-intra prediction for the current single-reference inter mode.
//
// The luma inter predictor is first built into a scratch buffer. Then smooth
// inter-intra (when enabled in the config) and wedge inter-intra (when wedge is
// usable for bsize and the wedge search is enabled) are evaluated, and the
// variant with the lower rd is kept. On success:
//  - mbmi holds the winning inter-intra configuration;
//  - *rate_mv and *tmp_rate2 are updated with the winner's mv/mode rates;
//  - the chroma inter predictors are built into orig_dst (when present).
// Returns 0 on success, or IGNORE_MODE when neither variant is usable.
int av1_handle_inter_intra_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
                                BLOCK_SIZE bsize, MB_MODE_INFO *mbmi,
                                HandleInterModeArgs *args, int64_t ref_best_rd,
                                int *rate_mv, int *tmp_rate2,
                                const BUFFER_SET *orig_dst) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  const int bw = block_size_wide[bsize];
  const int mi_row = xd->mi_row;
  const int mi_col = xd->mi_col;

  const int smooth_allowed = cpi->oxcf.comp_type_cfg.enable_smooth_interintra;
  const int wedge_allowed =
      av1_is_wedge_used(bsize) && enable_wedge_interintra_search(x, cpi);

  // Luma scratch storage (twice the sample count, presumably so it can hold
  // high-bitdepth samples); get_buf_by_bd() yields the pointer actually used.
  DECLARE_ALIGNED(16, uint8_t, inter_pred_store[2 * MAX_INTERINTRA_SB_SQUARE]);
  DECLARE_ALIGNED(16, uint8_t, intra_pred_store[2 * MAX_INTERINTRA_SB_SQUARE]);
  uint8_t *const inter_pred = get_buf_by_bd(xd, inter_pred_store);
  uint8_t *const intra_pred = get_buf_by_bd(xd, intra_pred_store);

  // Build the single-reference luma predictor into inter_pred.
  mbmi->ref_frame[1] = NONE_FRAME;
  xd->plane[0].dst.buf = inter_pred;
  xd->plane[0].dst.stride = bw;
  av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
                                AOM_PLANE_Y, AOM_PLANE_Y);
  const int num_planes = av1_num_planes(cm);

  // Point the destination buffers back at orig_dst and mark the second
  // reference as intra for the inter-intra search.
  restore_dst_buf(xd, *orig_dst, num_planes);
  mbmi->ref_frame[1] = INTRA_FRAME;
  INTERINTRA_MODE best_ii_mode = args->inter_intra_mode[mbmi->ref_frame[0]];

  // Smooth (non-wedge) inter-intra.
  int64_t smooth_rd = INT64_MAX;
  int mode_rate = INT_MAX;
  if (smooth_allowed &&
      handle_smooth_inter_intra_mode(cpi, x, bsize, mbmi, ref_best_rd, rate_mv,
                                     &best_ii_mode, &smooth_rd, &mode_rate,
                                     orig_dst, inter_pred, intra_pred,
                                     args) == IGNORE_MODE) {
    return IGNORE_MODE;
  }

  // Wedge inter-intra; the search may refine mv[0] into wedge_mv.
  const int_mv start_mv = mbmi->mv[0];
  int_mv wedge_mv = start_mv;
  int wedge_rate_mv = 0;
  int wedge_rate_overhead = 0;
  int64_t wedge_rd = INT64_MAX;
  if (wedge_allowed &&
      handle_wedge_inter_intra_mode(
          cpi, x, bsize, mbmi, rate_mv, &best_ii_mode, &wedge_rd, orig_dst,
          inter_pred_store, inter_pred, intra_pred_store, intra_pred, args,
          &wedge_rate_mv, &wedge_rate_overhead, &wedge_mv,
          smooth_rd) == IGNORE_MODE) {
    return IGNORE_MODE;
  }

  if (smooth_rd == INT64_MAX && wedge_rd == INT64_MAX) return IGNORE_MODE;

  if (wedge_rd < smooth_rd) {
    // Wedge wins: commit its (possibly refined) mv and replace the mv rate.
    mbmi->mv[0].as_int = wedge_mv.as_int;
    *tmp_rate2 += wedge_rate_mv - *rate_mv;
    *rate_mv = wedge_rate_mv;
    mode_rate = wedge_rate_overhead;
  } else if (smooth_allowed && wedge_allowed) {
    // Smooth wins, but the wedge evaluation overwrote mbmi and the luma
    // predictor, so restore the smooth configuration and rebuild luma.
    mbmi->use_wedge_interintra = 0;
    mbmi->interintra_mode = best_ii_mode;
    mbmi->mv[0].as_int = start_mv.as_int;
    av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                  AOM_PLANE_Y, AOM_PLANE_Y);
  }
  *tmp_rate2 += mode_rate;

  if (num_planes > 1) {
    av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                  AOM_PLANE_U, num_planes - 1);
  }
  return 0;
}
831
832 // Computes the valid compound_types to be evaluated
compute_valid_comp_types(MACROBLOCK * x,const AV1_COMP * const cpi,BLOCK_SIZE bsize,int masked_compound_used,int mode_search_mask,COMPOUND_TYPE * valid_comp_types)833 static INLINE int compute_valid_comp_types(MACROBLOCK *x,
834 const AV1_COMP *const cpi,
835 BLOCK_SIZE bsize,
836 int masked_compound_used,
837 int mode_search_mask,
838 COMPOUND_TYPE *valid_comp_types) {
839 const AV1_COMMON *cm = &cpi->common;
840 int valid_type_count = 0;
841 int comp_type, valid_check;
842 int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
843
844 const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
845 const int try_distwtd_comp =
846 ((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
847 cm->seq_params->order_hint_info.enable_dist_wtd_comp == 1 &&
848 cpi->sf.inter_sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
849
850 // Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
851 for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
852 comp_type++) {
853 valid_check =
854 (comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
855 if (valid_check && is_interinter_compound_used(comp_type, bsize))
856 valid_comp_types[valid_type_count++] = comp_type;
857 }
858 // Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
859 if (masked_compound_used) {
860 // enable_masked_type[0] corresponds to COMPOUND_WEDGE
861 // enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
862 enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
863 enable_masked_type[1] = cpi->oxcf.comp_type_cfg.enable_diff_wtd_comp;
864 for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
865 comp_type++) {
866 if ((mode_search_mask & (1 << comp_type)) &&
867 is_interinter_compound_used(comp_type, bsize) &&
868 enable_masked_type[comp_type - COMPOUND_WEDGE])
869 valid_comp_types[valid_type_count++] = comp_type;
870 }
871 }
872 return valid_type_count;
873 }
874
875 // Calculates the cost for compound type mask
calc_masked_type_cost(const ModeCosts * mode_costs,BLOCK_SIZE bsize,int comp_group_idx_ctx,int comp_index_ctx,int masked_compound_used,int * masked_type_cost)876 static INLINE void calc_masked_type_cost(
877 const ModeCosts *mode_costs, BLOCK_SIZE bsize, int comp_group_idx_ctx,
878 int comp_index_ctx, int masked_compound_used, int *masked_type_cost) {
879 av1_zero_array(masked_type_cost, COMPOUND_TYPES);
880 // Account for group index cost when wedge and/or diffwtd prediction are
881 // enabled
882 if (masked_compound_used) {
883 // Compound group index of average and distwtd is 0
884 // Compound group index of wedge and diffwtd is 1
885 masked_type_cost[COMPOUND_AVERAGE] +=
886 mode_costs->comp_group_idx_cost[comp_group_idx_ctx][0];
887 masked_type_cost[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_AVERAGE];
888 masked_type_cost[COMPOUND_WEDGE] +=
889 mode_costs->comp_group_idx_cost[comp_group_idx_ctx][1];
890 masked_type_cost[COMPOUND_DIFFWTD] += masked_type_cost[COMPOUND_WEDGE];
891 }
892
893 // Compute the cost to signal compound index/type
894 masked_type_cost[COMPOUND_AVERAGE] +=
895 mode_costs->comp_idx_cost[comp_index_ctx][1];
896 masked_type_cost[COMPOUND_DISTWTD] +=
897 mode_costs->comp_idx_cost[comp_index_ctx][0];
898 masked_type_cost[COMPOUND_WEDGE] += mode_costs->compound_type_cost[bsize][0];
899 masked_type_cost[COMPOUND_DIFFWTD] +=
900 mode_costs->compound_type_cost[bsize][1];
901 }
902
903 // Updates mbmi structure with the relevant compound type info
update_mbmi_for_compound_type(MB_MODE_INFO * mbmi,COMPOUND_TYPE cur_type)904 static INLINE void update_mbmi_for_compound_type(MB_MODE_INFO *mbmi,
905 COMPOUND_TYPE cur_type) {
906 mbmi->interinter_comp.type = cur_type;
907 mbmi->comp_group_idx = (cur_type >= COMPOUND_WEDGE);
908 mbmi->compound_idx = (cur_type != COMPOUND_DISTWTD);
909 }
910
911 // When match is found, populate the compound type data
912 // and calculate the rd cost using the stored stats and
913 // update the mbmi appropriately.
populate_reuse_comp_type_data(const MACROBLOCK * x,MB_MODE_INFO * mbmi,BEST_COMP_TYPE_STATS * best_type_stats,int_mv * cur_mv,int32_t * comp_rate,int64_t * comp_dist,int * comp_rs2,int * rate_mv,int64_t * rd,int match_index)914 static INLINE int populate_reuse_comp_type_data(
915 const MACROBLOCK *x, MB_MODE_INFO *mbmi,
916 BEST_COMP_TYPE_STATS *best_type_stats, int_mv *cur_mv, int32_t *comp_rate,
917 int64_t *comp_dist, int *comp_rs2, int *rate_mv, int64_t *rd,
918 int match_index) {
919 const int winner_comp_type =
920 x->comp_rd_stats[match_index].interinter_comp.type;
921 if (comp_rate[winner_comp_type] == INT_MAX)
922 return best_type_stats->best_compmode_interinter_cost;
923 update_mbmi_for_compound_type(mbmi, winner_comp_type);
924 mbmi->interinter_comp = x->comp_rd_stats[match_index].interinter_comp;
925 *rd = RDCOST(
926 x->rdmult,
927 comp_rs2[winner_comp_type] + *rate_mv + comp_rate[winner_comp_type],
928 comp_dist[winner_comp_type]);
929 mbmi->mv[0].as_int = cur_mv[0].as_int;
930 mbmi->mv[1].as_int = cur_mv[1].as_int;
931 return comp_rs2[winner_comp_type];
932 }
933
934 // Updates rd cost and relevant compound type data for the best compound type
update_best_info(const MB_MODE_INFO * const mbmi,int64_t * rd,BEST_COMP_TYPE_STATS * best_type_stats,int64_t best_rd_cur,int64_t comp_model_rd_cur,int rs2)935 static INLINE void update_best_info(const MB_MODE_INFO *const mbmi, int64_t *rd,
936 BEST_COMP_TYPE_STATS *best_type_stats,
937 int64_t best_rd_cur,
938 int64_t comp_model_rd_cur, int rs2) {
939 *rd = best_rd_cur;
940 best_type_stats->comp_best_model_rd = comp_model_rd_cur;
941 best_type_stats->best_compound_data = mbmi->interinter_comp;
942 best_type_stats->best_compmode_interinter_cost = rs2;
943 }
944
945 // Updates best_mv for masked compound types
update_mask_best_mv(const MB_MODE_INFO * const mbmi,int_mv * best_mv,int * best_tmp_rate_mv,int tmp_rate_mv)946 static INLINE void update_mask_best_mv(const MB_MODE_INFO *const mbmi,
947 int_mv *best_mv, int *best_tmp_rate_mv,
948 int tmp_rate_mv) {
949 *best_tmp_rate_mv = tmp_rate_mv;
950 best_mv[0].as_int = mbmi->mv[0].as_int;
951 best_mv[1].as_int = mbmi->mv[1].as_int;
952 }
953
save_comp_rd_search_stat(MACROBLOCK * x,const MB_MODE_INFO * const mbmi,const int32_t * comp_rate,const int64_t * comp_dist,const int32_t * comp_model_rate,const int64_t * comp_model_dist,const int_mv * cur_mv,const int * comp_rs2)954 static INLINE void save_comp_rd_search_stat(
955 MACROBLOCK *x, const MB_MODE_INFO *const mbmi, const int32_t *comp_rate,
956 const int64_t *comp_dist, const int32_t *comp_model_rate,
957 const int64_t *comp_model_dist, const int_mv *cur_mv, const int *comp_rs2) {
958 const int offset = x->comp_rd_stats_idx;
959 if (offset < MAX_COMP_RD_STATS) {
960 COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
961 memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
962 memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
963 memcpy(rd_stats->model_rate, comp_model_rate, sizeof(rd_stats->model_rate));
964 memcpy(rd_stats->model_dist, comp_model_dist, sizeof(rd_stats->model_dist));
965 memcpy(rd_stats->comp_rs2, comp_rs2, sizeof(rd_stats->comp_rs2));
966 memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
967 memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
968 rd_stats->mode = mbmi->mode;
969 rd_stats->filter = mbmi->interp_filters;
970 rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
971 const MACROBLOCKD *const xd = &x->e_mbd;
972 for (int i = 0; i < 2; ++i) {
973 const WarpedMotionParams *const wm =
974 &xd->global_motion[mbmi->ref_frame[i]];
975 rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
976 }
977 memcpy(&rd_stats->interinter_comp, &mbmi->interinter_comp,
978 sizeof(rd_stats->interinter_comp));
979 ++x->comp_rd_stats_idx;
980 }
981 }
982
get_interinter_compound_mask_rate(const ModeCosts * const mode_costs,const MB_MODE_INFO * const mbmi)983 static INLINE int get_interinter_compound_mask_rate(
984 const ModeCosts *const mode_costs, const MB_MODE_INFO *const mbmi) {
985 const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
986 // This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
987 if (compound_type == COMPOUND_WEDGE) {
988 return av1_is_wedge_used(mbmi->bsize)
989 ? av1_cost_literal(1) +
990 mode_costs
991 ->wedge_idx_cost[mbmi->bsize]
992 [mbmi->interinter_comp.wedge_index]
993 : 0;
994 } else {
995 assert(compound_type == COMPOUND_DIFFWTD);
996 return av1_cost_literal(1);
997 }
998 }
999
1000 // Takes a backup of rate, distortion and model_rd for future reuse
backup_stats(COMPOUND_TYPE cur_type,int32_t * comp_rate,int64_t * comp_dist,int32_t * comp_model_rate,int64_t * comp_model_dist,int rate_sum,int64_t dist_sum,RD_STATS * rd_stats,int * comp_rs2,int rs2)1001 static INLINE void backup_stats(COMPOUND_TYPE cur_type, int32_t *comp_rate,
1002 int64_t *comp_dist, int32_t *comp_model_rate,
1003 int64_t *comp_model_dist, int rate_sum,
1004 int64_t dist_sum, RD_STATS *rd_stats,
1005 int *comp_rs2, int rs2) {
1006 comp_rate[cur_type] = rd_stats->rate;
1007 comp_dist[cur_type] = rd_stats->dist;
1008 comp_model_rate[cur_type] = rate_sum;
1009 comp_model_dist[cur_type] = dist_sum;
1010 comp_rs2[cur_type] = rs2;
1011 }
1012
save_mask_search_results(const PREDICTION_MODE this_mode,const int reuse_level)1013 static INLINE int save_mask_search_results(const PREDICTION_MODE this_mode,
1014 const int reuse_level) {
1015 if (reuse_level || (this_mode == NEW_NEWMV))
1016 return 1;
1017 else
1018 return 0;
1019 }
1020
prune_mode_by_skip_rd(const AV1_COMP * const cpi,MACROBLOCK * x,MACROBLOCKD * xd,const BLOCK_SIZE bsize,int64_t ref_skip_rd,int mode_rate)1021 static INLINE int prune_mode_by_skip_rd(const AV1_COMP *const cpi,
1022 MACROBLOCK *x, MACROBLOCKD *xd,
1023 const BLOCK_SIZE bsize,
1024 int64_t ref_skip_rd, int mode_rate) {
1025 int eval_txfm = 1;
1026 // Check if the mode is good enough based on skip rd
1027 if (cpi->sf.inter_sf.txfm_rd_gate_level) {
1028 int64_t sse_y = compute_sse_plane(x, xd, PLANE_TYPE_Y, bsize);
1029 int64_t skip_rd = RDCOST(x->rdmult, mode_rate, (sse_y << 4));
1030 eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd,
1031 cpi->sf.inter_sf.txfm_rd_gate_level, 1);
1032 }
1033 return eval_txfm;
1034 }
1035
// Computes the rd cost of the masked compound type currently in
// mbmi->interinter_comp.type (COMPOUND_WEDGE or COMPOUND_DIFFWTD).
//
// Setup:
//  - If *calc_pred_masked_compound is set, get_inter_predictors_masked_compound()
//    fills preds0/preds1/residual1/diff10 and the flag is cleared.
//  - The mask is picked with pick_interinter_wedge() or pick_interinter_seg(),
//    and its rate is added to *rs2.
//
// Early rejection returns INT64_MAX and sets *comp_model_rd_cur to INT64_MAX
// when any of these holds:
//  - wedge only: the two predictors are very similar (low MSE);
//  - the mode rate alone exceeds rd_thresh;
//  - the skip-rd gate (txfm_rd_gate_level) rejects the candidate.
//
// Without stored stats for this type (comp_rate[type] == INT_MAX):
//  1. The final luma predictor is built. For wedge NEWMV modes this may follow
//     a joint MV refinement; the mv rate is reported via *out_rate_mv.
//  2. The model rd is stored in *comp_model_rd_cur.
//  3. The candidate may be pruned against comp_best_model_rd.
//  4. The luma rd is estimated, and successful estimates are backed up into
//     the comp_* arrays.
// With stored stats, the rd is recomputed from them.
//
// Returns the estimated rd, or INT64_MAX.
static int64_t masked_compound_type_rd(
    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
    const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
    int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
    uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
    int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
    int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
    int64_t *comp_model_dist, const int64_t comp_best_model_rd,
    int64_t *const comp_model_rd_cur, int *comp_rs2, int64_t ref_skip_rd) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
  // Only the two masked compound types are handled here.
  assert(compound_type == COMPOUND_WEDGE || compound_type == COMPOUND_DIFFWTD);
  const int has_newmv = have_newmv_in_inter_mode(this_mode);
  // Mask pickers, indexed by (compound_type - COMPOUND_WEDGE).
  const pick_interinter_mask_type mask_pickers[2] = { pick_interinter_wedge,
                                                      pick_interinter_seg };

  // TODO(any): Save pred and mask calculation as well into records. However
  // this may increase memory requirements as compound segment mask needs to be
  // stored in each record.
  if (*calc_pred_masked_compound) {
    get_inter_predictors_masked_compound(x, bsize, preds0, preds1, residual1,
                                         diff10, strides);
    *calc_pred_masked_compound = 0;
  }

  if (compound_type == COMPOUND_WEDGE) {
    // Skip the wedge search when the two predictors are very similar.
    const int hbd = is_cur_buf_hbd(xd);
    const uint8_t *const src0 = hbd ? CONVERT_TO_BYTEPTR(*preds0) : *preds0;
    const uint8_t *const src1 = hbd ? CONVERT_TO_BYTEPTR(*preds1) : *preds1;
    unsigned int sse;
    (void)cpi->ppi->fn_ptr[bsize].vf(src0, *strides, src1, *strides, &sse);
    const unsigned int mse =
        ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
    const int too_similar = mse < 8 || (!has_newmv && mse < 64);
    if (too_similar) {
      *comp_model_rd_cur = INT64_MAX;
      return INT64_MAX;
    }
  }

  // Pick the mask; the picker returns its rd and reports the SSE.
  uint64_t mask_sse = UINT64_MAX;
  int64_t mask_rd = mask_pickers[compound_type - COMPOUND_WEDGE](
      cpi, x, bsize, *preds0, *preds1, residual1, diff10, &mask_sse);
  *rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
  mask_rd += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
  assert(mask_sse != UINT64_MAX);
  const int64_t skip_rd_cur =
      RDCOST(x->rdmult, *rs2 + rate_mv, (mask_sse << 4));

  // Although the true rate_mv might be different after motion search, but it
  // is unlikely to be the best mode considering the transform rd cost and other
  // mode overhead cost
  const int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
  if (mode_rd > rd_thresh) {
    *comp_model_rd_cur = INT64_MAX;
    return INT64_MAX;
  }

  // Check if the mode is good enough based on skip rd
  // TODO(nithya): Handle wedge_newmv_search if extending for lower speed
  // setting
  if (cpi->sf.inter_sf.txfm_rd_gate_level &&
      !check_txfm_eval(x, bsize, ref_skip_rd, skip_rd_cur,
                       cpi->sf.inter_sf.txfm_rd_gate_level, 1)) {
    *comp_model_rd_cur = INT64_MAX;
    return INT64_MAX;
  }

  if (comp_rate[compound_type] != INT_MAX) {
    // A matching record exists: rebuild the rd from the stored stats.
    assert(comp_dist[compound_type] != INT64_MAX);
    // When disable_interinter_wedge_newmv_search is set, motion refinement is
    // disabled. Hence rate and distortion can be reused in this case as well
    assert(IMPLIES((has_newmv && (compound_type == COMPOUND_WEDGE)),
                   cpi->sf.inter_sf.disable_interinter_wedge_newmv_search));
    assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
    assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
    *out_rate_mv = rate_mv;
    const int64_t stored_rd =
        RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
               comp_dist[compound_type]);
    // Recalculate model rdcost with the updated rate
    *comp_model_rd_cur =
        RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_model_rate[compound_type],
               comp_model_dist[compound_type]);
    return stored_rd;
  }

  // No stored stats: build the final predictor (optionally refining the MVs
  // for wedge NEWMV modes) and evaluate it.
  const int refine_mv =
      has_newmv && compound_type == COMPOUND_WEDGE &&
      !cpi->sf.inter_sf.disable_interinter_wedge_newmv_search;
  if (refine_mv) {
    *out_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
                                                         bsize, this_mode);
    av1_enc_build_inter_predictor(cm, xd, xd->mi_row, xd->mi_col, ctx, bsize,
                                  AOM_PLANE_Y, AOM_PLANE_Y);
  } else {
    *out_rate_mv = rate_mv;
    av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                             preds1, strides);
  }

  int rate_sum;
  uint8_t tmp_skip_txfm_sb;
  int64_t dist_sum, tmp_skip_sse_sb;
  model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
      cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
      &tmp_skip_sse_sb, NULL, NULL, NULL);
  const int64_t model_rd =
      RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
  *comp_model_rd_cur = model_rd;

  // The refined MVs model no better than the mask search rd: revert to
  // cur_mv and the predictor built from the stored single-reference preds.
  if (refine_mv && model_rd >= mask_rd) {
    mbmi->mv[0].as_int = cur_mv[0].as_int;
    mbmi->mv[1].as_int = cur_mv[1].as_int;
    *out_rate_mv = rate_mv;
    av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                             preds1, strides);
    *comp_model_rd_cur = mask_rd;
  }

  if (cpi->sf.inter_sf.prune_comp_type_by_model_rd &&
      *comp_model_rd_cur > comp_best_model_rd &&
      comp_best_model_rd != INT64_MAX) {
    *comp_model_rd_cur = INT64_MAX;
    return INT64_MAX;
  }

  // Estimate the luma rd within the budget left after the mode overhead.
  RD_STATS rd_stats;
  const int64_t overhead_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
  int64_t rd =
      estimate_yrd_for_sb(cpi, bsize, x, rd_thresh - overhead_rd, &rd_stats);
  if (rd == INT64_MAX) return rd;

  rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
  // Backup rate and distortion for future reuse
  backup_stats(compound_type, comp_rate, comp_dist, comp_model_rate,
               comp_model_dist, rate_sum, dist_sum, &rd_stats, comp_rs2, *rs2);
  return rd;
}
1194
// Multiplier/divisor pairs used to scale the best approximate rd when gating
// the wedge and compound segment searches. These are read-only lookup tables,
// hence const.
static const int comp_type_rd_threshold_mul[3] = { 1, 11, 12 };
static const int comp_type_rd_threshold_div[3] = { 3, 16, 16 };
1199
av1_compound_type_rd(const AV1_COMP * const cpi,MACROBLOCK * x,HandleInterModeArgs * args,BLOCK_SIZE bsize,int_mv * cur_mv,int mode_search_mask,int masked_compound_used,const BUFFER_SET * orig_dst,const BUFFER_SET * tmp_dst,const CompoundTypeRdBuffers * buffers,int * rate_mv,int64_t * rd,RD_STATS * rd_stats,int64_t ref_best_rd,int64_t ref_skip_rd,int * is_luma_interp_done,int64_t rd_thresh)1200 int av1_compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
1201 HandleInterModeArgs *args, BLOCK_SIZE bsize,
1202 int_mv *cur_mv, int mode_search_mask,
1203 int masked_compound_used, const BUFFER_SET *orig_dst,
1204 const BUFFER_SET *tmp_dst,
1205 const CompoundTypeRdBuffers *buffers, int *rate_mv,
1206 int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
1207 int64_t ref_skip_rd, int *is_luma_interp_done,
1208 int64_t rd_thresh) {
1209 const AV1_COMMON *cm = &cpi->common;
1210 MACROBLOCKD *xd = &x->e_mbd;
1211 MB_MODE_INFO *mbmi = xd->mi[0];
1212 const PREDICTION_MODE this_mode = mbmi->mode;
1213 int ref_frame = av1_ref_frame_type(mbmi->ref_frame);
1214 const int bw = block_size_wide[bsize];
1215 int rs2;
1216 int_mv best_mv[2];
1217 int best_tmp_rate_mv = *rate_mv;
1218 BEST_COMP_TYPE_STATS best_type_stats;
1219 // Initializing BEST_COMP_TYPE_STATS
1220 best_type_stats.best_compound_data.type = COMPOUND_AVERAGE;
1221 best_type_stats.best_compmode_interinter_cost = 0;
1222 best_type_stats.comp_best_model_rd = INT64_MAX;
1223
1224 uint8_t *preds0[1] = { buffers->pred0 };
1225 uint8_t *preds1[1] = { buffers->pred1 };
1226 int strides[1] = { bw };
1227 int tmp_rate_mv;
1228 COMPOUND_TYPE cur_type;
1229 // Local array to store the mask cost for different compound types
1230 int masked_type_cost[COMPOUND_TYPES];
1231
1232 int calc_pred_masked_compound = 1;
1233 int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
1234 INT64_MAX };
1235 int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
1236 int comp_rs2[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
1237 int32_t comp_model_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX,
1238 INT_MAX };
1239 int64_t comp_model_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
1240 INT64_MAX };
1241 int match_index = 0;
1242 const int match_found =
1243 find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rate,
1244 comp_model_dist, comp_rs2, &match_index);
1245 best_mv[0].as_int = cur_mv[0].as_int;
1246 best_mv[1].as_int = cur_mv[1].as_int;
1247 *rd = INT64_MAX;
1248
1249 // Local array to store the valid compound types to be evaluated in the core
1250 // loop
1251 COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
1252 COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
1253 };
1254 int valid_type_count = 0;
1255 // compute_valid_comp_types() returns the number of valid compound types to be
1256 // evaluated and populates the same in the local array valid_comp_types[].
1257 // It also sets the flag 'try_average_and_distwtd_comp'
1258 valid_type_count = compute_valid_comp_types(
1259 x, cpi, bsize, masked_compound_used, mode_search_mask, valid_comp_types);
1260
1261 // The following context indices are independent of compound type
1262 const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
1263 const int comp_index_ctx = get_comp_index_context(cm, xd);
1264
1265 // Populates masked_type_cost local array for the 4 compound types
1266 calc_masked_type_cost(&x->mode_costs, bsize, comp_group_idx_ctx,
1267 comp_index_ctx, masked_compound_used, masked_type_cost);
1268
1269 int64_t comp_model_rd_cur = INT64_MAX;
1270 int64_t best_rd_cur = ref_best_rd;
1271 const int mi_row = xd->mi_row;
1272 const int mi_col = xd->mi_col;
1273
1274 // If the match is found, calculate the rd cost using the
1275 // stored stats and update the mbmi appropriately.
1276 if (match_found && cpi->sf.inter_sf.reuse_compound_type_decision) {
1277 return populate_reuse_comp_type_data(x, mbmi, &best_type_stats, cur_mv,
1278 comp_rate, comp_dist, comp_rs2,
1279 rate_mv, rd, match_index);
1280 }
1281
1282 // If COMPOUND_AVERAGE is not valid, use the spare buffer
1283 if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
1284
1285 // Loop over valid compound types
1286 for (int i = 0; i < valid_type_count; i++) {
1287 cur_type = valid_comp_types[i];
1288
1289 if (args->cmp_mode[ref_frame] == COMPOUND_AVERAGE) {
1290 if (cur_type == COMPOUND_WEDGE) continue;
1291 }
1292
1293 comp_model_rd_cur = INT64_MAX;
1294 tmp_rate_mv = *rate_mv;
1295 best_rd_cur = INT64_MAX;
1296 ref_best_rd = AOMMIN(ref_best_rd, *rd);
1297 update_mbmi_for_compound_type(mbmi, cur_type);
1298 rs2 = masked_type_cost[cur_type];
1299
1300 int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
1301 if (mode_rd >= ref_best_rd) continue;
1302
1303 // Case COMPOUND_AVERAGE and COMPOUND_DISTWTD
1304 if (cur_type < COMPOUND_WEDGE) {
1305 if (cpi->sf.inter_sf.enable_fast_compound_mode_search == 2) {
1306 int rate_sum;
1307 uint8_t tmp_skip_txfm_sb;
1308 int64_t dist_sum, tmp_skip_sse_sb;
1309
1310 // Reuse data if matching record is found
1311 if (comp_rate[cur_type] == INT_MAX) {
1312 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1313 AOM_PLANE_Y, AOM_PLANE_Y);
1314 if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
1315 // Compute RD cost for the current type
1316 RD_STATS est_rd_stats;
1317 const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
1318 int64_t est_rd = INT64_MAX;
1319 int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1320 rs2 + *rate_mv);
1321 // Evaluate further if skip rd is low enough
1322 if (eval_txfm) {
1323 est_rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh,
1324 &est_rd_stats);
1325 }
1326 if (est_rd != INT64_MAX) {
1327 best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
1328 est_rd_stats.dist);
1329 model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
1330 cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
1331 &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
1332 comp_model_rd_cur =
1333 RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
1334 // Backup rate and distortion for future reuse
1335 backup_stats(cur_type, comp_rate, comp_dist, comp_model_rate,
1336 comp_model_dist, rate_sum, dist_sum, &est_rd_stats,
1337 comp_rs2, rs2);
1338 }
1339 } else {
1340 // Calculate RD cost based on stored stats
1341 assert(comp_dist[cur_type] != INT64_MAX);
1342 best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
1343 comp_dist[cur_type]);
1344 // Recalculate model rdcost with the updated rate
1345 comp_model_rd_cur =
1346 RDCOST(x->rdmult, rs2 + *rate_mv + comp_model_rate[cur_type],
1347 comp_model_dist[cur_type]);
1348 }
1349 } else {
1350 tmp_rate_mv = *rate_mv;
1351 if (have_newmv_in_inter_mode(this_mode)) {
1352 InterPredParams inter_pred_params;
1353 av1_dist_wtd_comp_weight_assign(
1354 &cpi->common, mbmi, &inter_pred_params.conv_params.fwd_offset,
1355 &inter_pred_params.conv_params.bck_offset,
1356 &inter_pred_params.conv_params.use_dist_wtd_comp_avg, 1);
1357 int mask_value = inter_pred_params.conv_params.fwd_offset * 4;
1358 memset(xd->seg_mask, mask_value,
1359 sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
1360 tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1361 bsize, this_mode);
1362 }
1363 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1364 AOM_PLANE_Y, AOM_PLANE_Y);
1365 if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
1366
1367 int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1368 rs2 + *rate_mv);
1369 if (eval_txfm) {
1370 RD_STATS est_rd_stats;
1371 estimate_yrd_for_sb(cpi, bsize, x, INT64_MAX, &est_rd_stats);
1372
1373 best_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1374 est_rd_stats.dist);
1375 }
1376 }
1377
1378 // use spare buffer for following compound type try
1379 if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
1380 } else if (cur_type == COMPOUND_WEDGE) {
1381 int best_mask_index = 0;
1382 int best_wedge_sign = 0;
1383 int_mv tmp_mv[2] = { mbmi->mv[0], mbmi->mv[1] };
1384 int best_rs2 = 0;
1385 int best_rate_mv = *rate_mv;
1386 int wedge_mask_size = get_wedge_types_lookup(bsize);
1387 int need_mask_search = args->wedge_index == -1;
1388 int wedge_newmv_search =
1389 have_newmv_in_inter_mode(this_mode) &&
1390 !cpi->sf.inter_sf.disable_interinter_wedge_newmv_search;
1391
1392 if (need_mask_search && !wedge_newmv_search) {
1393 // short cut repeated single reference block build
1394 av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0,
1395 preds0, strides);
1396 av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1,
1397 preds1, strides);
1398 }
1399
1400 for (int wedge_mask = 0; wedge_mask < wedge_mask_size && need_mask_search;
1401 ++wedge_mask) {
1402 for (int wedge_sign = 0; wedge_sign < 2; ++wedge_sign) {
1403 tmp_rate_mv = *rate_mv;
1404 mbmi->interinter_comp.wedge_index = wedge_mask;
1405 mbmi->interinter_comp.wedge_sign = wedge_sign;
1406 rs2 = masked_type_cost[cur_type];
1407 rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1408
1409 mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
1410 if (mode_rd >= ref_best_rd / 2) continue;
1411
1412 if (wedge_newmv_search) {
1413 tmp_rate_mv = av1_interinter_compound_motion_search(
1414 cpi, x, cur_mv, bsize, this_mode);
1415 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst,
1416 bsize, AOM_PLANE_Y, AOM_PLANE_Y);
1417 } else {
1418 av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
1419 strides, preds1, strides);
1420 }
1421
1422 RD_STATS est_rd_stats;
1423 int64_t this_rd_cur = INT64_MAX;
1424 int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1425 rs2 + *rate_mv);
1426 if (eval_txfm) {
1427 this_rd_cur = estimate_yrd_for_sb(
1428 cpi, bsize, x, AOMMIN(best_rd_cur, ref_best_rd), &est_rd_stats);
1429 }
1430 if (this_rd_cur < INT64_MAX) {
1431 this_rd_cur =
1432 RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1433 est_rd_stats.dist);
1434 }
1435 if (this_rd_cur < best_rd_cur) {
1436 best_mask_index = wedge_mask;
1437 best_wedge_sign = wedge_sign;
1438 best_rd_cur = this_rd_cur;
1439 tmp_mv[0] = mbmi->mv[0];
1440 tmp_mv[1] = mbmi->mv[1];
1441 best_rate_mv = tmp_rate_mv;
1442 best_rs2 = rs2;
1443 }
1444 }
1445 // Consider the asymmetric partitions for oblique angle only if the
1446 // corresponding symmetric partition is the best so far.
1447 // Note: For horizontal and vertical types, both symmetric and
1448 // asymmetric partitions are always considered.
1449 if (cpi->sf.inter_sf.enable_fast_wedge_mask_search) {
1450 // The first 4 entries in wedge_codebook_16_heqw/hltw/hgtw[16]
1451 // correspond to symmetric partitions of the 4 oblique angles, the
1452 // next 4 entries correspond to the vertical/horizontal
1453 // symmetric/asymmetric partitions and the last 8 entries correspond
1454 // to the asymmetric partitions of oblique types.
1455 const int idx_before_asym_oblique = 7;
1456 const int last_oblique_sym_idx = 3;
1457 if (wedge_mask == idx_before_asym_oblique) {
1458 if (best_mask_index > last_oblique_sym_idx) {
1459 break;
1460 } else {
1461 // Asymmetric (Index-1) map for the corresponding oblique masks.
1462 // WEDGE_OBLIQUE27: sym - 0, asym - 8, 9
1463 // WEDGE_OBLIQUE63: sym - 1, asym - 12, 13
1464 // WEDGE_OBLIQUE117: sym - 2, asym - 14, 15
1465 // WEDGE_OBLIQUE153: sym - 3, asym - 10, 11
1466 const int asym_mask_idx[4] = { 7, 11, 13, 9 };
1467 wedge_mask = asym_mask_idx[best_mask_index];
1468 wedge_mask_size = wedge_mask + 3;
1469 }
1470 }
1471 }
1472 }
1473
1474 if (need_mask_search) {
1475 if (save_mask_search_results(
1476 this_mode, cpi->sf.inter_sf.reuse_mask_search_results)) {
1477 args->wedge_index = best_mask_index;
1478 args->wedge_sign = best_wedge_sign;
1479 }
1480 } else {
1481 mbmi->interinter_comp.wedge_index = args->wedge_index;
1482 mbmi->interinter_comp.wedge_sign = args->wedge_sign;
1483 rs2 = masked_type_cost[cur_type];
1484 rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1485
1486 if (wedge_newmv_search) {
1487 tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1488 bsize, this_mode);
1489 }
1490
1491 best_mask_index = args->wedge_index;
1492 best_wedge_sign = args->wedge_sign;
1493 tmp_mv[0] = mbmi->mv[0];
1494 tmp_mv[1] = mbmi->mv[1];
1495 best_rate_mv = tmp_rate_mv;
1496 best_rs2 = masked_type_cost[cur_type];
1497 best_rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1498 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1499 AOM_PLANE_Y, AOM_PLANE_Y);
1500 int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1501 best_rs2 + *rate_mv);
1502 if (eval_txfm) {
1503 RD_STATS est_rd_stats;
1504 estimate_yrd_for_sb(cpi, bsize, x, INT64_MAX, &est_rd_stats);
1505 best_rd_cur =
1506 RDCOST(x->rdmult, best_rs2 + tmp_rate_mv + est_rd_stats.rate,
1507 est_rd_stats.dist);
1508 }
1509 }
1510
1511 mbmi->interinter_comp.wedge_index = best_mask_index;
1512 mbmi->interinter_comp.wedge_sign = best_wedge_sign;
1513 mbmi->mv[0] = tmp_mv[0];
1514 mbmi->mv[1] = tmp_mv[1];
1515 tmp_rate_mv = best_rate_mv;
1516 rs2 = best_rs2;
1517 } else if (!cpi->sf.inter_sf.enable_fast_compound_mode_search &&
1518 cur_type == COMPOUND_DIFFWTD) {
1519 int_mv tmp_mv[2];
1520 int best_mask_index = 0;
1521 rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1522
1523 int need_mask_search = args->diffwtd_index == -1;
1524
1525 for (int mask_index = 0; mask_index < 2 && need_mask_search;
1526 ++mask_index) {
1527 tmp_rate_mv = *rate_mv;
1528 mbmi->interinter_comp.mask_type = mask_index;
1529 if (have_newmv_in_inter_mode(this_mode)) {
1530 // hard coded number for diff wtd
1531 int mask_value = mask_index == 0 ? 38 : 26;
1532 memset(xd->seg_mask, mask_value,
1533 sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
1534 tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1535 bsize, this_mode);
1536 }
1537 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1538 AOM_PLANE_Y, AOM_PLANE_Y);
1539 RD_STATS est_rd_stats;
1540 int64_t this_rd_cur = INT64_MAX;
1541 int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1542 rs2 + *rate_mv);
1543 if (eval_txfm) {
1544 this_rd_cur =
1545 estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
1546 }
1547 if (this_rd_cur < INT64_MAX) {
1548 this_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1549 est_rd_stats.dist);
1550 }
1551
1552 if (this_rd_cur < best_rd_cur) {
1553 best_rd_cur = this_rd_cur;
1554 best_mask_index = mbmi->interinter_comp.mask_type;
1555 tmp_mv[0] = mbmi->mv[0];
1556 tmp_mv[1] = mbmi->mv[1];
1557 }
1558 }
1559
1560 if (need_mask_search) {
1561 if (save_mask_search_results(this_mode, 0))
1562 args->diffwtd_index = best_mask_index;
1563 } else {
1564 mbmi->interinter_comp.mask_type = args->diffwtd_index;
1565 rs2 = masked_type_cost[cur_type];
1566 rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
1567
1568 int mask_value = mbmi->interinter_comp.mask_type == 0 ? 38 : 26;
1569 memset(xd->seg_mask, mask_value,
1570 sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
1571
1572 if (have_newmv_in_inter_mode(this_mode)) {
1573 tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
1574 bsize, this_mode);
1575 }
1576 best_mask_index = mbmi->interinter_comp.mask_type;
1577 tmp_mv[0] = mbmi->mv[0];
1578 tmp_mv[1] = mbmi->mv[1];
1579 av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
1580 AOM_PLANE_Y, AOM_PLANE_Y);
1581 RD_STATS est_rd_stats;
1582 int64_t this_rd_cur = INT64_MAX;
1583 int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
1584 rs2 + *rate_mv);
1585 if (eval_txfm) {
1586 this_rd_cur =
1587 estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
1588 }
1589 if (this_rd_cur < INT64_MAX) {
1590 best_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
1591 est_rd_stats.dist);
1592 }
1593 }
1594
1595 mbmi->interinter_comp.mask_type = best_mask_index;
1596 mbmi->mv[0] = tmp_mv[0];
1597 mbmi->mv[1] = tmp_mv[1];
1598 } else {
1599 // Handle masked compound types
1600 // Factors to control gating of compound type selection based on best
1601 // approximate rd so far
1602 const int max_comp_type_rd_threshold_mul =
1603 comp_type_rd_threshold_mul[cpi->sf.inter_sf
1604 .prune_comp_type_by_comp_avg];
1605 const int max_comp_type_rd_threshold_div =
1606 comp_type_rd_threshold_div[cpi->sf.inter_sf
1607 .prune_comp_type_by_comp_avg];
1608 // Evaluate COMPOUND_WEDGE / COMPOUND_DIFFWTD if approximated cost is
1609 // within threshold
1610 int64_t approx_rd = ((*rd / max_comp_type_rd_threshold_div) *
1611 max_comp_type_rd_threshold_mul);
1612
1613 if (approx_rd < ref_best_rd) {
1614 const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
1615 best_rd_cur = masked_compound_type_rd(
1616 cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
1617 &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
1618 strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
1619 comp_rate, comp_dist, comp_model_rate, comp_model_dist,
1620 best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2,
1621 ref_skip_rd);
1622 }
1623 }
1624
1625 // Update stats for best compound type
1626 if (best_rd_cur < *rd) {
1627 update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
1628 comp_model_rd_cur, rs2);
1629 if (have_newmv_in_inter_mode(this_mode))
1630 update_mask_best_mv(mbmi, best_mv, &best_tmp_rate_mv, tmp_rate_mv);
1631 }
1632 // reset to original mvs for next iteration
1633 mbmi->mv[0].as_int = cur_mv[0].as_int;
1634 mbmi->mv[1].as_int = cur_mv[1].as_int;
1635 }
1636
1637 mbmi->comp_group_idx =
1638 (best_type_stats.best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
1639 mbmi->compound_idx =
1640 !(best_type_stats.best_compound_data.type == COMPOUND_DISTWTD);
1641 mbmi->interinter_comp = best_type_stats.best_compound_data;
1642
1643 if (have_newmv_in_inter_mode(this_mode)) {
1644 mbmi->mv[0].as_int = best_mv[0].as_int;
1645 mbmi->mv[1].as_int = best_mv[1].as_int;
1646 rd_stats->rate += best_tmp_rate_mv - *rate_mv;
1647 *rate_mv = best_tmp_rate_mv;
1648 }
1649
1650 if (this_mode == NEW_NEWMV)
1651 args->cmp_mode[ref_frame] = mbmi->interinter_comp.type;
1652
1653 restore_dst_buf(xd, *orig_dst, 1);
1654 if (!match_found)
1655 save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rate,
1656 comp_model_dist, cur_mv, comp_rs2);
1657 return best_type_stats.best_compmode_interinter_cost;
1658 }
1659