1 /*
2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include "config/aom_config.h"
13
14 #include "av1/encoder/encodemv.h"
15 #if !CONFIG_REALTIME_ONLY
16 #include "av1/encoder/misc_model_weights.h"
17 #endif // !CONFIG_REALTIME_ONLY
18 #include "av1/encoder/mv_prec.h"
19
20 #if !CONFIG_REALTIME_ONLY
get_ref_mv_for_mv_stats(const MB_MODE_INFO * mbmi,const MB_MODE_INFO_EXT_FRAME * mbmi_ext_frame,int ref_idx)21 static AOM_INLINE int_mv get_ref_mv_for_mv_stats(
22 const MB_MODE_INFO *mbmi, const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame,
23 int ref_idx) {
24 int ref_mv_idx = mbmi->ref_mv_idx;
25 if (mbmi->mode == NEAR_NEWMV || mbmi->mode == NEW_NEARMV) {
26 assert(has_second_ref(mbmi));
27 ref_mv_idx += 1;
28 }
29
30 const MV_REFERENCE_FRAME *ref_frames = mbmi->ref_frame;
31 const int8_t ref_frame_type = av1_ref_frame_type(ref_frames);
32 const CANDIDATE_MV *curr_ref_mv_stack = mbmi_ext_frame->ref_mv_stack;
33
34 if (ref_frames[1] > INTRA_FRAME) {
35 assert(ref_idx == 0 || ref_idx == 1);
36 return ref_idx ? curr_ref_mv_stack[ref_mv_idx].comp_mv
37 : curr_ref_mv_stack[ref_mv_idx].this_mv;
38 }
39
40 assert(ref_idx == 0);
41 return ref_mv_idx < mbmi_ext_frame->ref_mv_count
42 ? curr_ref_mv_stack[ref_mv_idx].this_mv
43 : mbmi_ext_frame->global_mvs[ref_frame_type];
44 }
45
get_symbol_cost(const aom_cdf_prob * cdf,int symbol)46 static AOM_INLINE int get_symbol_cost(const aom_cdf_prob *cdf, int symbol) {
47 const aom_cdf_prob cur_cdf = AOM_ICDF(cdf[symbol]);
48 const aom_cdf_prob prev_cdf = symbol ? AOM_ICDF(cdf[symbol - 1]) : 0;
49 const aom_cdf_prob p15 = AOMMAX(cur_cdf - prev_cdf, EC_MIN_PROB);
50
51 return av1_cost_symbol(p15);
52 }
53
keep_one_comp_stat(MV_STATS * mv_stats,int comp_val,int comp_idx,const AV1_COMP * cpi,int * rates)54 static AOM_INLINE int keep_one_comp_stat(MV_STATS *mv_stats, int comp_val,
55 int comp_idx, const AV1_COMP *cpi,
56 int *rates) {
57 assert(comp_val != 0 && "mv component should not have zero value!");
58 const int sign = comp_val < 0;
59 const int mag = sign ? -comp_val : comp_val;
60 const int mag_minus_1 = mag - 1;
61 int offset;
62 const int mv_class = av1_get_mv_class(mag_minus_1, &offset);
63 const int int_part = offset >> 3; // int mv data
64 const int frac_part = (offset >> 1) & 3; // fractional mv data
65 const int high_part = offset & 1; // high precision mv data
66 const int use_hp = cpi->common.features.allow_high_precision_mv;
67 int r_idx = 0;
68
69 const MACROBLOCK *const x = &cpi->td.mb;
70 const MACROBLOCKD *const xd = &x->e_mbd;
71 FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
72 nmv_context *nmvc = &ec_ctx->nmvc;
73 nmv_component *mvcomp_ctx = nmvc->comps;
74 nmv_component *cur_mvcomp_ctx = &mvcomp_ctx[comp_idx];
75 aom_cdf_prob *sign_cdf = cur_mvcomp_ctx->sign_cdf;
76 aom_cdf_prob *class_cdf = cur_mvcomp_ctx->classes_cdf;
77 aom_cdf_prob *class0_cdf = cur_mvcomp_ctx->class0_cdf;
78 aom_cdf_prob(*bits_cdf)[3] = cur_mvcomp_ctx->bits_cdf;
79 aom_cdf_prob *frac_part_cdf = mv_class
80 ? (cur_mvcomp_ctx->fp_cdf)
81 : (cur_mvcomp_ctx->class0_fp_cdf[int_part]);
82 aom_cdf_prob *high_part_cdf =
83 mv_class ? (cur_mvcomp_ctx->hp_cdf) : (cur_mvcomp_ctx->class0_hp_cdf);
84
85 const int sign_rate = get_symbol_cost(sign_cdf, sign);
86 rates[r_idx++] = sign_rate;
87 update_cdf(sign_cdf, sign, 2);
88
89 const int class_rate = get_symbol_cost(class_cdf, mv_class);
90 rates[r_idx++] = class_rate;
91 update_cdf(class_cdf, mv_class, MV_CLASSES);
92
93 int int_bit_rate = 0;
94 if (mv_class == MV_CLASS_0) {
95 int_bit_rate = get_symbol_cost(class0_cdf, int_part);
96 update_cdf(class0_cdf, int_part, CLASS0_SIZE);
97 } else {
98 const int n = mv_class + CLASS0_BITS - 1; // number of bits
99 for (int i = 0; i < n; ++i) {
100 int_bit_rate += get_symbol_cost(bits_cdf[i], (int_part >> i) & 1);
101 update_cdf(bits_cdf[i], (int_part >> i) & 1, 2);
102 }
103 }
104 rates[r_idx++] = int_bit_rate;
105 const int frac_part_rate = get_symbol_cost(frac_part_cdf, frac_part);
106 rates[r_idx++] = frac_part_rate;
107 update_cdf(frac_part_cdf, frac_part, MV_FP_SIZE);
108 const int high_part_rate =
109 use_hp ? get_symbol_cost(high_part_cdf, high_part) : 0;
110 if (use_hp) {
111 update_cdf(high_part_cdf, high_part, 2);
112 }
113 rates[r_idx++] = high_part_rate;
114
115 mv_stats->last_bit_zero += !high_part;
116 mv_stats->last_bit_nonzero += high_part;
117 const int total_rate =
118 (sign_rate + class_rate + int_bit_rate + frac_part_rate + high_part_rate);
119 return total_rate;
120 }
121
keep_one_mv_stat(MV_STATS * mv_stats,const MV * ref_mv,const MV * cur_mv,const AV1_COMP * cpi)122 static AOM_INLINE void keep_one_mv_stat(MV_STATS *mv_stats, const MV *ref_mv,
123 const MV *cur_mv, const AV1_COMP *cpi) {
124 const MACROBLOCK *const x = &cpi->td.mb;
125 const MACROBLOCKD *const xd = &x->e_mbd;
126 FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
127 nmv_context *nmvc = &ec_ctx->nmvc;
128 aom_cdf_prob *joint_cdf = nmvc->joints_cdf;
129 const int use_hp = cpi->common.features.allow_high_precision_mv;
130
131 const MV diff = { cur_mv->row - ref_mv->row, cur_mv->col - ref_mv->col };
132 const int mv_joint = av1_get_mv_joint(&diff);
133 // TODO(chiyotsai@google.com): Estimate hp_diff when we are using lp
134 const MV hp_diff = diff;
135 const int hp_mv_joint = av1_get_mv_joint(&hp_diff);
136 const MV truncated_diff = { (diff.row / 2) * 2, (diff.col / 2) * 2 };
137 const MV lp_diff = use_hp ? truncated_diff : diff;
138 const int lp_mv_joint = av1_get_mv_joint(&lp_diff);
139
140 const int mv_joint_rate = get_symbol_cost(joint_cdf, mv_joint);
141 const int hp_mv_joint_rate = get_symbol_cost(joint_cdf, hp_mv_joint);
142 const int lp_mv_joint_rate = get_symbol_cost(joint_cdf, lp_mv_joint);
143
144 update_cdf(joint_cdf, mv_joint, MV_JOINTS);
145
146 mv_stats->total_mv_rate += mv_joint_rate;
147 mv_stats->hp_total_mv_rate += hp_mv_joint_rate;
148 mv_stats->lp_total_mv_rate += lp_mv_joint_rate;
149 mv_stats->mv_joint_count[mv_joint]++;
150
151 for (int comp_idx = 0; comp_idx < 2; comp_idx++) {
152 const int comp_val = comp_idx ? diff.col : diff.row;
153 const int hp_comp_val = comp_idx ? hp_diff.col : hp_diff.row;
154 const int lp_comp_val = comp_idx ? lp_diff.col : lp_diff.row;
155 int rates[5];
156 av1_zero_array(rates, 5);
157
158 const int comp_rate =
159 comp_val ? keep_one_comp_stat(mv_stats, comp_val, comp_idx, cpi, rates)
160 : 0;
161 // TODO(chiyotsai@google.com): Properly get hp rate when use_hp is false
162 const int hp_rate =
163 hp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] + rates[4] : 0;
164 const int lp_rate =
165 lp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] : 0;
166
167 mv_stats->total_mv_rate += comp_rate;
168 mv_stats->hp_total_mv_rate += hp_rate;
169 mv_stats->lp_total_mv_rate += lp_rate;
170 }
171 }
172
collect_mv_stats_b(MV_STATS * mv_stats,const AV1_COMP * cpi,int mi_row,int mi_col)173 static AOM_INLINE void collect_mv_stats_b(MV_STATS *mv_stats,
174 const AV1_COMP *cpi, int mi_row,
175 int mi_col) {
176 const AV1_COMMON *cm = &cpi->common;
177 const CommonModeInfoParams *const mi_params = &cm->mi_params;
178
179 if (mi_row >= mi_params->mi_rows || mi_col >= mi_params->mi_cols) {
180 return;
181 }
182
183 const MB_MODE_INFO *mbmi =
184 mi_params->mi_grid_base[mi_row * mi_params->mi_stride + mi_col];
185 const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame =
186 cpi->mbmi_ext_info.frame_base +
187 get_mi_ext_idx(mi_row, mi_col, cm->mi_params.mi_alloc_bsize,
188 cpi->mbmi_ext_info.stride);
189
190 if (!is_inter_block(mbmi)) {
191 mv_stats->intra_count++;
192 return;
193 }
194 mv_stats->inter_count++;
195
196 const PREDICTION_MODE mode = mbmi->mode;
197 const int is_compound = has_second_ref(mbmi);
198
199 if (mode == NEWMV || mode == NEW_NEWMV) {
200 // All mvs are new
201 for (int ref_idx = 0; ref_idx < 1 + is_compound; ++ref_idx) {
202 const MV ref_mv =
203 get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv;
204 const MV cur_mv = mbmi->mv[ref_idx].as_mv;
205 keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi);
206 }
207 } else if (mode == NEAREST_NEWMV || mode == NEAR_NEWMV ||
208 mode == NEW_NEARESTMV || mode == NEW_NEARMV) {
209 // has exactly one new_mv
210 mv_stats->default_mvs += 1;
211
212 const int ref_idx = (mode == NEAREST_NEWMV || mode == NEAR_NEWMV);
213 const MV ref_mv =
214 get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv;
215 const MV cur_mv = mbmi->mv[ref_idx].as_mv;
216
217 keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi);
218 } else {
219 // No new_mv
220 mv_stats->default_mvs += 1 + is_compound;
221 }
222
223 // Add texture information
224 const BLOCK_SIZE bsize = mbmi->bsize;
225 const int num_rows = block_size_high[bsize];
226 const int num_cols = block_size_wide[bsize];
227 const int y_stride = cpi->source->y_stride;
228 const int px_row = 4 * mi_row, px_col = 4 * mi_col;
229 const int buf_is_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
230 const int bd = cm->seq_params->bit_depth;
231 if (buf_is_hbd) {
232 uint16_t *source_buf =
233 CONVERT_TO_SHORTPTR(cpi->source->y_buffer) + px_row * y_stride + px_col;
234 for (int row = 0; row < num_rows - 1; row++) {
235 for (int col = 0; col < num_cols - 1; col++) {
236 const int offset = row * y_stride + col;
237 const int horz_diff =
238 abs(source_buf[offset + 1] - source_buf[offset]) >> (bd - 8);
239 const int vert_diff =
240 abs(source_buf[offset + y_stride] - source_buf[offset]) >> (bd - 8);
241 mv_stats->horz_text += horz_diff;
242 mv_stats->vert_text += vert_diff;
243 mv_stats->diag_text += horz_diff * vert_diff;
244 }
245 }
246 } else {
247 uint8_t *source_buf = cpi->source->y_buffer + px_row * y_stride + px_col;
248 for (int row = 0; row < num_rows - 1; row++) {
249 for (int col = 0; col < num_cols - 1; col++) {
250 const int offset = row * y_stride + col;
251 const int horz_diff = abs(source_buf[offset + 1] - source_buf[offset]);
252 const int vert_diff =
253 abs(source_buf[offset + y_stride] - source_buf[offset]);
254 mv_stats->horz_text += horz_diff;
255 mv_stats->vert_text += vert_diff;
256 mv_stats->diag_text += horz_diff * vert_diff;
257 }
258 }
259 }
260 }
261
262 // Split block
collect_mv_stats_sb(MV_STATS * mv_stats,const AV1_COMP * cpi,int mi_row,int mi_col,BLOCK_SIZE bsize)263 static AOM_INLINE void collect_mv_stats_sb(MV_STATS *mv_stats,
264 const AV1_COMP *cpi, int mi_row,
265 int mi_col, BLOCK_SIZE bsize) {
266 assert(bsize < BLOCK_SIZES_ALL);
267 const AV1_COMMON *cm = &cpi->common;
268
269 if (mi_row >= cm->mi_params.mi_rows || mi_col >= cm->mi_params.mi_cols)
270 return;
271
272 const PARTITION_TYPE partition = get_partition(cm, mi_row, mi_col, bsize);
273 const BLOCK_SIZE subsize = get_partition_subsize(bsize, partition);
274
275 const int hbs = mi_size_wide[bsize] / 2;
276 const int qbs = mi_size_wide[bsize] / 4;
277 switch (partition) {
278 case PARTITION_NONE:
279 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
280 break;
281 case PARTITION_HORZ:
282 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
283 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
284 break;
285 case PARTITION_VERT:
286 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
287 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
288 break;
289 case PARTITION_SPLIT:
290 collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, subsize);
291 collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col + hbs, subsize);
292 collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col, subsize);
293 collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col + hbs, subsize);
294 break;
295 case PARTITION_HORZ_A:
296 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
297 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
298 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
299 break;
300 case PARTITION_HORZ_B:
301 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
302 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
303 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs);
304 break;
305 case PARTITION_VERT_A:
306 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
307 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
308 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
309 break;
310 case PARTITION_VERT_B:
311 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
312 collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
313 collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs);
314 break;
315 case PARTITION_HORZ_4:
316 for (int i = 0; i < 4; ++i) {
317 const int this_mi_row = mi_row + i * qbs;
318 collect_mv_stats_b(mv_stats, cpi, this_mi_row, mi_col);
319 }
320 break;
321 case PARTITION_VERT_4:
322 for (int i = 0; i < 4; ++i) {
323 const int this_mi_col = mi_col + i * qbs;
324 collect_mv_stats_b(mv_stats, cpi, mi_row, this_mi_col);
325 }
326 break;
327 default: assert(0);
328 }
329 }
330
collect_mv_stats_tile(MV_STATS * mv_stats,const AV1_COMP * cpi,const TileInfo * tile_info)331 static AOM_INLINE void collect_mv_stats_tile(MV_STATS *mv_stats,
332 const AV1_COMP *cpi,
333 const TileInfo *tile_info) {
334 const AV1_COMMON *cm = &cpi->common;
335 const int mi_row_start = tile_info->mi_row_start;
336 const int mi_row_end = tile_info->mi_row_end;
337 const int mi_col_start = tile_info->mi_col_start;
338 const int mi_col_end = tile_info->mi_col_end;
339 const int sb_size_mi = cm->seq_params->mib_size;
340 BLOCK_SIZE sb_size = cm->seq_params->sb_size;
341 for (int mi_row = mi_row_start; mi_row < mi_row_end; mi_row += sb_size_mi) {
342 for (int mi_col = mi_col_start; mi_col < mi_col_end; mi_col += sb_size_mi) {
343 collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, sb_size);
344 }
345 }
346 }
347
// Collects mv statistics over all tiles of the current frame.
//
// Before a tile is processed its context (tile_data[].tctx) is reset to a
// copy of the frame context *cm->fc and installed as the macroblock's
// tile_ctx, so the CDF adaptation done while collecting starts from the
// frame context in every tile. Afterwards the stats are tagged with
// `current_q` and the frame's order hint, and marked valid.
//
// The stats live in cpi->mv_stats when CONFIG_FRAME_PARALLEL_ENCODE is set
// and in cpi->ppi->mv_stats otherwise.
void av1_collect_mv_stats(AV1_COMP *cpi, int current_q) {
#if CONFIG_FRAME_PARALLEL_ENCODE
  MV_STATS *mv_stats = &cpi->mv_stats;
#else
  MV_STATS *mv_stats = &cpi->ppi->mv_stats;
#endif
  const AV1_COMMON *cm = &cpi->common;
  const int num_tile_cols = cm->tiles.cols;
  const int num_tile_rows = cm->tiles.rows;

  for (int row = 0; row < num_tile_rows; row++) {
    TileInfo tile_info;
    av1_tile_set_row(&tile_info, cm, row);
    for (int col = 0; col < num_tile_cols; col++) {
      av1_tile_set_col(&tile_info, cm, col);

      // Start this tile from a fresh copy of the frame context.
      TileDataEnc *const this_tile = &cpi->tile_data[row * num_tile_cols + col];
      this_tile->tctx = *cm->fc;
      cpi->td.mb.e_mbd.tile_ctx = &this_tile->tctx;

      collect_mv_stats_tile(mv_stats, cpi, &tile_info);
    }
  }

  mv_stats->q = current_q;
  mv_stats->order = cm->current_frame.order_hint;
  mv_stats->valid = 1;
}
375
get_smart_mv_prec(AV1_COMP * cpi,const MV_STATS * mv_stats,int current_q)376 static AOM_INLINE int get_smart_mv_prec(AV1_COMP *cpi, const MV_STATS *mv_stats,
377 int current_q) {
378 const AV1_COMMON *cm = &cpi->common;
379 const int order_hint = cpi->common.current_frame.order_hint;
380 const int order_diff = order_hint - mv_stats->order;
381 const float area = (float)(cm->width * cm->height);
382 float features[MV_PREC_FEATURE_SIZE] = {
383 (float)current_q,
384 (float)mv_stats->q,
385 (float)order_diff,
386 mv_stats->inter_count / area,
387 mv_stats->intra_count / area,
388 mv_stats->default_mvs / area,
389 mv_stats->mv_joint_count[0] / area,
390 mv_stats->mv_joint_count[1] / area,
391 mv_stats->mv_joint_count[2] / area,
392 mv_stats->mv_joint_count[3] / area,
393 mv_stats->last_bit_zero / area,
394 mv_stats->last_bit_nonzero / area,
395 mv_stats->total_mv_rate / area,
396 mv_stats->hp_total_mv_rate / area,
397 mv_stats->lp_total_mv_rate / area,
398 mv_stats->horz_text / area,
399 mv_stats->vert_text / area,
400 mv_stats->diag_text / area,
401 };
402
403 for (int f_idx = 0; f_idx < MV_PREC_FEATURE_SIZE; f_idx++) {
404 features[f_idx] =
405 (features[f_idx] - av1_mv_prec_mean[f_idx]) / av1_mv_prec_std[f_idx];
406 }
407 float score = 0.0f;
408
409 av1_nn_predict(features, &av1_mv_prec_dnn_config, 1, &score);
410
411 const int use_high_hp = score >= 0.0f;
412 return use_high_hp;
413 }
414 #endif // !CONFIG_REALTIME_ONLY
415
// Chooses whether the current frame uses high precision mvs and applies the
// choice through av1_set_high_precision_mv().
//
// The default is high precision whenever qindex is below
// HIGH_PRECISION_MV_QTHRESH. The speed feature high_precision_mv_usage can
// override it:
//   QTR_ONLY      never use high precision.
//   LAST_MV_DATA  (non-realtime builds only) if the frame allows smart mv
//                 selection and the collected mv statistics are valid, let
//                 get_smart_mv_prec() decide.
void av1_pick_and_set_high_precision_mv(AV1_COMP *cpi, int qindex) {
  int use_hp = qindex < HIGH_PRECISION_MV_QTHRESH;
#if !CONFIG_REALTIME_ONLY
  // Use one MV_STATS instance for both the validity check and the
  // prediction, selected the same way as in av1_collect_mv_stats(), so the
  // `valid` flag always describes the data that is consumed.
#if CONFIG_FRAME_PARALLEL_ENCODE
  const MV_STATS *const mv_stats = &cpi->mv_stats;
#else
  const MV_STATS *const mv_stats = &cpi->ppi->mv_stats;
#endif
#endif  // !CONFIG_REALTIME_ONLY

  if (cpi->sf.hl_sf.high_precision_mv_usage == QTR_ONLY) {
    use_hp = 0;
  }
#if !CONFIG_REALTIME_ONLY
  else if (cpi->sf.hl_sf.high_precision_mv_usage == LAST_MV_DATA &&
           av1_frame_allows_smart_mv(cpi) && mv_stats->valid) {
    use_hp = get_smart_mv_prec(cpi, mv_stats, qindex);
  }
#endif  // !CONFIG_REALTIME_ONLY

  av1_set_high_precision_mv(cpi, use_hp,
                            cpi->common.features.cur_frame_force_integer_mv);
}
436