1 /*
2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <float.h>
13
14 #include "aom_ports/system_state.h"
15
16 #include "av1/common/enums.h"
17 #include "av1/common/reconinter.h"
18
19 #include "av1/encoder/encoder.h"
20 #include "av1/encoder/partition_model_weights.h"
21 #include "av1/encoder/partition_strategy.h"
22 #include "av1/encoder/rdopt.h"
23
24 // Performs a simple_motion_search with a single reference frame and extract
25 // the variance of residues. Here features is assumed to be a length 6 array.
26 // After this function is called, we will store the following in to features:
27 // features[0] = log(1 + dc_q**2/256)
28 // features[1] = log(1 + variance_of_residue)
29 // for i in [2, 3, 4, 5]:
30 // features[i] = log(1 + variance_of_residue_in_block[i]/variance_of_residue)
get_res_var_features(AV1_COMP * const cpi,MACROBLOCK * x,int mi_row,int mi_col,BLOCK_SIZE bsize,float * features)31 static void get_res_var_features(AV1_COMP *const cpi, MACROBLOCK *x, int mi_row,
32 int mi_col, BLOCK_SIZE bsize,
33 float *features) {
34 // TODO(chiyotsai@google.com): The data this model trained on did not also use
35 // SIMPLE_TRANSLATION to build the inter_predictor. Retraining and tuning the
36 // model with the correct data should give better performance.
37 assert(mi_size_wide[bsize] == mi_size_high[bsize]);
38
39 MACROBLOCKD *xd = &x->e_mbd;
40
41 // Perform a single motion search in Y_PLANE to make a prediction
42 const int use_subpixel = 0;
43
44 // Start getting the features
45 int f_idx = 0;
46
47 // Q_INDEX
48 const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
49 aom_clear_system_state();
50 features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
51
52 // VARIANCE
53 unsigned int sse = 0;
54 unsigned int var = 0;
55 const MV ref_mv_full = { .row = 0, .col = 0 };
56 av1_simple_motion_sse_var(cpi, x, mi_row, mi_col, bsize, ref_mv_full,
57 use_subpixel, &sse, &var);
58 aom_clear_system_state();
59 features[f_idx++] = logf(1.0f + (float)var);
60
61 // Regional
62 const uint8_t *src = x->plane[0].src.buf;
63 const int src_stride = x->plane[0].src.stride;
64 const uint8_t *dst = xd->plane[0].dst.buf;
65 const int dst_stride = xd->plane[0].dst.stride;
66 const int bw = block_size_wide[bsize];
67 const int bh = block_size_high[bsize];
68 const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
69 int r_idx = 0;
70 for (r_idx = 0; r_idx < 4; r_idx++) {
71 const int x_idx = (r_idx & 1) * bw / 2;
72 const int y_idx = (r_idx >> 1) * bh / 2;
73 const int src_offset = y_idx * src_stride + x_idx;
74 const int dst_offset = y_idx * dst_stride + x_idx;
75 const unsigned int sub_var = cpi->fn_ptr[subsize].vf(
76 src + src_offset, src_stride, dst + dst_offset, dst_stride, &sse);
77 aom_clear_system_state();
78 const float var_ratio = (1.0f + (float)sub_var) / (4.0f + (float)var);
79 features[f_idx++] = var_ratio;
80 }
81 }
82
// Predicts, for a square block, whether PARTITION_SPLIT is very likely, using
// a per-bsize NN applied to the residue-variance features produced by
// get_res_var_features.
// - score > split_only_thresh: PARTITION_NONE, PARTITION_HORZ and
//   PARTITION_VERT are disallowed and *do_rectangular_split is cleared.
// - When cpi->sf.simple_motion_search_split_only >= 2, additionally:
//   score < -split_only_thresh clears *do_square_split, and
//   score > 0.875 * split_only_thresh disallows PARTITION_NONE.
// The output flags are only ever cleared, never set. They are left untouched
// when no model applies: BLOCK_8X8 with CONFIG_DISABLE_FULL_PIXEL_SPLIT_8X8
// enabled, or an unexpected block size in release builds.
void av1_simple_motion_search_based_split(
    AV1_COMP *const cpi, MACROBLOCK *x, int mi_row, int mi_col,
    BLOCK_SIZE bsize, int *partition_none_allowed, int *partition_horz_allowed,
    int *partition_vert_allowed, int *do_rectangular_split,
    int *do_square_split) {
  const NN_CONFIG *nn_config = NULL;
  float split_only_thresh = 0.0f;
  if (bsize == BLOCK_128X128) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_128;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_128;
  } else if (bsize == BLOCK_64X64) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_64;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_64;
  } else if (bsize == BLOCK_32X32) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_32;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_32;
  } else if (bsize == BLOCK_16X16) {
    nn_config = &av1_simple_motion_search_based_split_nn_config_16;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_16;
  } else if (bsize == BLOCK_8X8) {
    // Disable BLOCK_8X8 for now
#if !CONFIG_DISABLE_FULL_PIXEL_SPLIT_8X8
    nn_config = &av1_simple_motion_search_based_split_nn_config_8;
    split_only_thresh = av1_simple_motion_search_based_split_thresh_8;
#endif
  } else {
    assert(0 && "Unexpected block size in simple_motion_based_split");
  }
  if (!nn_config) return;

  float features[6] = { 0 };
  float score = 0;
  get_res_var_features(cpi, x, mi_row, mi_col, bsize, features);
  av1_nn_predict(features, nn_config, &score);
  // Reset the FP/SIMD state after the NN evaluation before the float
  // comparisons below, matching the other av1_nn_predict call sites.
  aom_clear_system_state();

  if (score > split_only_thresh) {
    *partition_none_allowed = 0;
    *partition_horz_allowed = 0;
    *partition_vert_allowed = 0;
    *do_rectangular_split = 0;
  }
  if (cpi->sf.simple_motion_search_split_only >= 2) {
    if (score < -split_only_thresh) *do_square_split = 0;
    // For larger scores (>split_only_thresh), none and rectangular partitions
    // are skipped. As score reduces, possibility of split decreases. Hence
    // for near larger scores (.875 * split_only_thresh to split_only_thresh)
    // none partition is disabled, but rectangular partitions are evaluated
    // additionally.
    if (score > (split_only_thresh * 0.875)) *partition_none_allowed = 0;
  }
}
134
135 // Given a list of ref frames in refs, performs simple_motion_search on each of
136 // the refs and returns the ref with the smallest sse. Returns -1 if none of the
137 // ref in the list is available. Also stores the best sse and var in best_sse,
138 // best_var, respectively. If save_mv_code is -1, don't update mv_ref_fulls in
139 // pc_tree. If save_mv_code is between 0 and 3, update mv_ref_fulls under
140 // pc_tree->split[i]. If save_mv_code is 4, update mv_ref_fulls under pc_tree.
simple_motion_search_get_best_ref(AV1_COMP * const cpi,MACROBLOCK * x,PC_TREE * pc_tree,int mi_row,int mi_col,BLOCK_SIZE bsize,const int * const refs,int num_refs,int use_subpixel,int save_mv_code,unsigned int * best_sse,unsigned int * best_var)141 static int simple_motion_search_get_best_ref(
142 AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
143 int mi_col, BLOCK_SIZE bsize, const int *const refs, int num_refs,
144 int use_subpixel, int save_mv_code, unsigned int *best_sse,
145 unsigned int *best_var) {
146 // TODO(chiyotsai@google.com): The calculation of variance currently uses
147 // bsize, so we might take area outside of the image into account. We need to
148 // modify the SIMD functions to fix this later.
149 const AV1_COMMON *const cm = &cpi->common;
150 int best_ref = -1;
151
152 if (mi_col >= cm->mi_cols || mi_row >= cm->mi_rows) {
153 // If the whole block is outside of the image, set the var and sse to 0.
154 *best_var = 0;
155 *best_sse = 0;
156
157 return best_ref;
158 }
159
160 // Otherwise do loop through the reference frames and find the one with the
161 // minimum SSE
162 const MACROBLOCKD *xd = &x->e_mbd;
163 const MV *mv_ref_fulls = pc_tree->mv_ref_fulls;
164
165 const int num_planes = 1;
166
167 *best_sse = INT_MAX;
168
169 for (int ref_idx = 0; ref_idx < num_refs; ref_idx++) {
170 const int ref = refs[ref_idx];
171
172 if (cpi->ref_frame_flags & av1_ref_frame_flag_list[ref]) {
173 unsigned int curr_sse = 0, curr_var = 0;
174 av1_simple_motion_search(cpi, x, mi_row, mi_col, bsize, ref,
175 mv_ref_fulls[ref], num_planes, use_subpixel);
176 curr_var = cpi->fn_ptr[bsize].vf(
177 x->plane[0].src.buf, x->plane[0].src.stride, xd->plane[0].dst.buf,
178 xd->plane[0].dst.stride, &curr_sse);
179 if (curr_sse < *best_sse) {
180 *best_sse = curr_sse;
181 *best_var = curr_var;
182 best_ref = ref;
183 }
184
185 const int new_mv_row = x->best_mv.as_mv.row / 8;
186 const int new_mv_col = x->best_mv.as_mv.col / 8;
187 if (save_mv_code == 4) {
188 pc_tree->mv_ref_fulls[ref].row = new_mv_row;
189 pc_tree->mv_ref_fulls[ref].col = new_mv_col;
190 } else if (save_mv_code >= 0 && save_mv_code < 4) {
191 // Propagate the new motion vectors to a lower level
192 pc_tree->split[save_mv_code]->mv_ref_fulls[ref].row = new_mv_row;
193 pc_tree->split[save_mv_code]->mv_ref_fulls[ref].col = new_mv_col;
194 } else {
195 assert(save_mv_code == -1 &&
196 "Unknown code in simple_motion_search_get_best_ref.");
197 }
198 }
199 }
200
201 return best_ref;
202 }
203
204 // Performs fullpixel simple_motion_search with LAST_FRAME and ALTREF_FRAME on
205 // each subblock and extract the variance and sse of residues. Then store the
206 // var and sse from each partition subblock to features. The DC qindex is also
207 // stored in features.
208 // Here features is assumed to be a length 19 array.
209 // After this function is called, we will store the following to features:
210 // features[0:17] = var and sse from subblocks
211 // features[18] = DC q_index
simple_motion_search_prune_part_features(AV1_COMP * const cpi,MACROBLOCK * x,PC_TREE * pc_tree,int mi_row,int mi_col,BLOCK_SIZE bsize,float * features)212 static void simple_motion_search_prune_part_features(
213 AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
214 int mi_col, BLOCK_SIZE bsize, float *features) {
215 // TODO(chiyotsai@google.com): Cache the result of the motion search from the
216 // larger bsize.
217 const int w_mi = mi_size_wide[bsize];
218 const int h_mi = mi_size_high[bsize];
219 int f_idx = 0;
220 assert(mi_size_wide[bsize] == mi_size_high[bsize]);
221 assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[LAST_FRAME] ||
222 cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);
223
224 // Setting up motion search
225 const int ref_list[] = { LAST_FRAME, ALTREF_FRAME };
226 const int num_refs = 2;
227 const int use_subpixel = 1;
228
229 unsigned int int_features[FEATURE_SIZE_SMS_PRUNE_PART - 1];
230
231 // Doing whole block first to update the mv
232 simple_motion_search_get_best_ref(
233 cpi, x, pc_tree, mi_row, mi_col, bsize, ref_list, num_refs, use_subpixel,
234 4, &int_features[f_idx], &int_features[f_idx + 1]);
235 f_idx += 2;
236
237 // Split subblocks
238 BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
239 int r_idx = 0;
240 for (r_idx = 0; r_idx < 4; r_idx++) {
241 const int sub_mi_col = mi_col + (r_idx & 1) * w_mi / 2;
242 const int sub_mi_row = mi_row + (r_idx >> 1) * h_mi / 2;
243
244 simple_motion_search_get_best_ref(
245 cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
246 use_subpixel, r_idx, &int_features[f_idx], &int_features[f_idx + 1]);
247 f_idx += 2;
248 }
249
250 // Horz subblocks
251 subsize = get_partition_subsize(bsize, PARTITION_HORZ);
252 for (r_idx = 0; r_idx < 2; r_idx++) {
253 const int sub_mi_col = mi_col + 0;
254 const int sub_mi_row = mi_row + r_idx * h_mi / 2;
255
256 simple_motion_search_get_best_ref(
257 cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
258 use_subpixel, -1, &int_features[f_idx], &int_features[f_idx + 1]);
259
260 f_idx += 2;
261 }
262
263 // Vert subblock
264 subsize = get_partition_subsize(bsize, PARTITION_VERT);
265 for (r_idx = 0; r_idx < 2; r_idx++) {
266 const int sub_mi_col = mi_col + r_idx * w_mi / 2;
267 const int sub_mi_row = mi_row + 0;
268
269 simple_motion_search_get_best_ref(
270 cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
271 use_subpixel, -1, &int_features[f_idx], &int_features[f_idx + 1]);
272
273 f_idx += 2;
274 }
275
276 aom_clear_system_state();
277 for (int idx = 0; idx < f_idx; idx++) {
278 features[idx] = logf(1.0f + (float)int_features[idx]);
279 }
280
281 const MACROBLOCKD *xd = &x->e_mbd;
282 set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);
283
284 // Q_INDEX
285 const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
286 features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
287
288 // Neighbor stuff
289 const int has_above = !!xd->above_mbmi;
290 const int has_left = !!xd->left_mbmi;
291 const BLOCK_SIZE above_bsize = has_above ? xd->above_mbmi->sb_type : bsize;
292 const BLOCK_SIZE left_bsize = has_left ? xd->left_mbmi->sb_type : bsize;
293 features[f_idx++] = (float)has_above;
294 features[f_idx++] = (float)mi_size_wide_log2[above_bsize];
295 features[f_idx++] = (float)mi_size_high_log2[above_bsize];
296 features[f_idx++] = (float)has_left;
297 features[f_idx++] = (float)mi_size_wide_log2[left_bsize];
298 features[f_idx++] = (float)mi_size_high_log2[left_bsize];
299
300 assert(f_idx == FEATURE_SIZE_SMS_PRUNE_PART);
301 }
302
// Decides whether PARTITION_HORZ / PARTITION_VERT can be pruned, using a
// per-bsize NN on the features from simple_motion_search_prune_part_features.
// If a model exists for bsize and at least one of its HORZ/VERT prune
// thresholds is non-zero, then:
//   - features (FEATURE_SIZE_SMS_PRUNE_PART floats) receives the raw,
//     un-normalized features and *valid is set to 1;
//   - if rect pruning is enabled (speed feature on, inter frame, HORZ or VERT
//     still allowed, no superres scaling), *prune_horz / *prune_vert are set
//     to whether the predicted probability is <= the matching threshold.
// Otherwise nothing is written. partition_none_allowed, do_square_split,
// do_rectangular_split and the "only" thresholds are currently unused.
void av1_simple_motion_search_prune_part(
    AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    int mi_col, BLOCK_SIZE bsize, int *partition_none_allowed,
    int *partition_horz_allowed, int *partition_vert_allowed,
    int *do_square_split, int *do_rectangular_split, int *prune_horz,
    int *prune_vert, float *features, int *valid) {
  const AV1_COMMON *const cm = &cpi->common;

  // Per-bsize model parameters.
  const NN_CONFIG *model = NULL;
  const float *feat_mean = NULL;
  const float *feat_std = NULL;
  const float *rect_thresh = NULL;
  const float *only_thresh = NULL;

  switch (bsize) {
    case BLOCK_128X128:
      model = &av1_simple_motion_search_prune_part_nn_config_128;
      feat_mean = av1_simple_motion_search_prune_part_mean_128;
      feat_std = av1_simple_motion_search_prune_part_std_128;
      rect_thresh = av1_simple_motion_search_prune_part_prune_thresh_128;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_128;
      break;
    case BLOCK_64X64:
      model = &av1_simple_motion_search_prune_part_nn_config_64;
      feat_mean = av1_simple_motion_search_prune_part_mean_64;
      feat_std = av1_simple_motion_search_prune_part_std_64;
      rect_thresh = av1_simple_motion_search_prune_part_prune_thresh_64;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_64;
      break;
    case BLOCK_32X32:
      model = &av1_simple_motion_search_prune_part_nn_config_32;
      feat_mean = av1_simple_motion_search_prune_part_mean_32;
      feat_std = av1_simple_motion_search_prune_part_std_32;
      rect_thresh = av1_simple_motion_search_prune_part_prune_thresh_32;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_32;
      break;
    case BLOCK_16X16:
      model = &av1_simple_motion_search_prune_part_nn_config_16;
      feat_mean = av1_simple_motion_search_prune_part_mean_16;
      feat_std = av1_simple_motion_search_prune_part_std_16;
      rect_thresh = av1_simple_motion_search_prune_part_prune_thresh_16;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_16;
      break;
    case BLOCK_8X8:
      model = &av1_simple_motion_search_prune_part_nn_config_8;
      feat_mean = av1_simple_motion_search_prune_part_mean_8;
      feat_std = av1_simple_motion_search_prune_part_std_8;
      rect_thresh = av1_simple_motion_search_prune_part_prune_thresh_8;
      only_thresh = av1_simple_motion_search_prune_part_only_thresh_8;
      break;
    default:
      assert(0 && "Unexpected block size in simple_motion_prune_part");
      break;
  }

  // Bail out early: no model for this size, or no usable rect threshold.
  if (model == NULL) return;
  const int no_rect_thresh = rect_thresh[PARTITION_HORZ] == 0.0f &&
                             rect_thresh[PARTITION_VERT] == 0.0f;
  if (no_rect_thresh) return;
  if (bsize < BLOCK_8X8) return;

  // Extract and normalize the features.
  simple_motion_search_prune_part_features(cpi, x, pc_tree, mi_row, mi_col,
                                           bsize, features);
  *valid = 1;
  float normalized[FEATURE_SIZE_SMS_PRUNE_PART] = { 0.0f };
  for (int i = 0; i < FEATURE_SIZE_SMS_PRUNE_PART; ++i) {
    normalized[i] = (features[i] - feat_mean[i]) / feat_std[i];
  }

  // The 128x128 and 8x8 models only output PARTITION_TYPES classes.
  const int uses_ext_classes = bsize != BLOCK_128X128 && bsize != BLOCK_8X8;
  const int class_count =
      uses_ext_classes ? EXT_PARTITION_TYPES : PARTITION_TYPES;
  float logits[EXT_PARTITION_TYPES] = { 0.0f };
  float class_prob[EXT_PARTITION_TYPES] = { 0.0f };

  av1_nn_predict(normalized, model, logits);
  aom_clear_system_state();
  av1_nn_softmax(logits, class_prob, class_count);

  // Only prune rectangular partitions when it is enabled and meaningful.
  const int rect_prune_enabled =
      cpi->sf.simple_motion_search_prune_rect && !frame_is_intra_only(cm) &&
      (*partition_horz_allowed || *partition_vert_allowed) &&
      bsize >= BLOCK_8X8 && !av1_superres_scaled(cm);
  if (rect_prune_enabled) {
    *prune_horz = class_prob[PARTITION_HORZ] <= rect_thresh[PARTITION_HORZ];
    *prune_vert = class_prob[PARTITION_VERT] <= rect_thresh[PARTITION_VERT];
  }

  // Unused for now; referenced to keep compilers quiet.
  (void)only_thresh;
  (void)partition_none_allowed;
  (void)do_square_split;
  (void)do_rectangular_split;
}
394
395 // Early terminates PARTITION_NONE using simple_motion_search features and the
396 // rate, distortion, and rdcost of PARTITION_NONE. This is only called when:
397 // - The frame is a show frame
398 // - The frame is not intra only
399 // - The current bsize is > BLOCK_8X8
400 // - blk_row + blk_height/2 < total_rows and blk_col + blk_width/2 < total_cols
av1_simple_motion_search_early_term_none(AV1_COMP * const cpi,MACROBLOCK * x,PC_TREE * pc_tree,int mi_row,int mi_col,BLOCK_SIZE bsize,const RD_STATS * none_rdc,int * early_terminate,float * simple_motion_features,int * simple_motion_features_are_valid)401 void av1_simple_motion_search_early_term_none(
402 AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
403 int mi_col, BLOCK_SIZE bsize, const RD_STATS *none_rdc,
404 int *early_terminate, float *simple_motion_features,
405 int *simple_motion_features_are_valid) {
406 // TODO(chiyotsai@google.com): There are other features we can extract from
407 // PARTITION_NONE. Play with this later.
408 int f_idx = 0;
409 if (!*simple_motion_features_are_valid) {
410 simple_motion_search_prune_part_features(cpi, x, pc_tree, mi_row, mi_col,
411 bsize, simple_motion_features);
412 *simple_motion_features_are_valid = 1;
413 }
414 f_idx = 25;
415
416 simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->rate);
417 simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->dist);
418 simple_motion_features[f_idx++] = logf(1.0f + (float)none_rdc->rdcost);
419
420 assert(f_idx == FEATURE_SIZE_SMS_TERM_NONE);
421
422 const float *ml_mean = NULL;
423 const float *ml_std = NULL;
424 const float *ml_model = NULL;
425
426 if (bsize == BLOCK_128X128) {
427 ml_mean = av1_simple_motion_search_term_none_mean_128;
428 ml_std = av1_simple_motion_search_term_none_std_128;
429 ml_model = av1_simple_motion_search_term_none_model_128;
430 } else if (bsize == BLOCK_64X64) {
431 ml_mean = av1_simple_motion_search_term_none_mean_64;
432 ml_std = av1_simple_motion_search_term_none_std_64;
433 ml_model = av1_simple_motion_search_term_none_model_64;
434 } else if (bsize == BLOCK_32X32) {
435 ml_mean = av1_simple_motion_search_term_none_mean_32;
436 ml_std = av1_simple_motion_search_term_none_std_32;
437 ml_model = av1_simple_motion_search_term_none_model_32;
438 } else if (bsize == BLOCK_16X16) {
439 ml_mean = av1_simple_motion_search_term_none_mean_16;
440 ml_std = av1_simple_motion_search_term_none_std_16;
441 ml_model = av1_simple_motion_search_term_none_model_16;
442 } else {
443 assert(0 && "Unexpected block size in simple_motion_term_none");
444 }
445
446 if (ml_model) {
447 float score = 0.0f;
448 for (f_idx = 0; f_idx < FEATURE_SIZE_SMS_TERM_NONE; f_idx++) {
449 score += ml_model[f_idx] *
450 (simple_motion_features[f_idx] - ml_mean[f_idx]) / ml_std[f_idx];
451 }
452 score += ml_model[FEATURE_SIZE_SMS_TERM_NONE];
453
454 if (score >= 0.0f) {
455 *early_terminate = 1;
456 }
457 }
458 }
459
// Performs simple_motion_search with use_subpixel = 0 using LAST_FRAME and
// ALTREF_FRAME on the whole block and on each of its four PARTITION_SPLIT
// subblocks, keeping the sse and var of the best reference for each.
// Writes FEATURE_SIZE_FP_SMS_TERM_NONE - 3 (i.e. 17) floats to features:
//   features[0:2]   = log(1 + sse), log(1 + var) of the whole block
//   features[2:10]  = (log(1 + sse), log(1 + var)) of each split subblock
//   features[10]    = log(1 + dc_q**2/256)
//   features[11:14] = has_above, then the log2 mi width and height of the
//                     above block (bsize's dimensions when there is none)
//   features[14:17] = the same three values for the left block
// The remaining 3 slots are filled with PARTITION_NONE rd stats by
// av1_firstpass_simple_motion_search_early_term.
// Side effect: mvs are cached in pc_tree->mv_ref_fulls (whole block) and
// pc_tree->split[i]->mv_ref_fulls (split subblock i).
static void firstpass_simple_motion_search_features(
    AV1_COMP *const cpi, MACROBLOCK *x, PC_TREE *pc_tree, int mi_row,
    int mi_col, BLOCK_SIZE bsize, float *features) {
  assert(mi_size_wide[bsize] == mi_size_high[bsize]);
  assert(cpi->ref_frame_flags & av1_ref_frame_flag_list[LAST_FRAME] ||
         cpi->ref_frame_flags & av1_ref_frame_flag_list[ALTREF_FRAME]);

  // Setting up motion search
  const int ref_list[] = { LAST_FRAME, ALTREF_FRAME };
  const int num_refs = 2;
  const int use_subpixel = 0;

  // (sse, var) for the whole block plus the 4 split subblocks.
  unsigned int int_features[10] = { 0 };

  int f_idx = 0;
  // Doing whole block first to update the mv
  simple_motion_search_get_best_ref(
      cpi, x, pc_tree, mi_row, mi_col, bsize, ref_list, num_refs, use_subpixel,
      4, &int_features[f_idx], &int_features[f_idx + 1]);
  f_idx += 2;

  // Split subblocks
  const BLOCK_SIZE subsize = get_partition_subsize(bsize, PARTITION_SPLIT);
  const int w_mi = mi_size_wide[bsize];
  const int h_mi = mi_size_high[bsize];
  for (int r_idx = 0; r_idx < 4; r_idx++) {
    const int sub_mi_col = mi_col + (r_idx & 1) * w_mi / 2;
    const int sub_mi_row = mi_row + (r_idx >> 1) * h_mi / 2;

    simple_motion_search_get_best_ref(
        cpi, x, pc_tree, sub_mi_row, sub_mi_col, subsize, ref_list, num_refs,
        use_subpixel, r_idx, &int_features[f_idx], &int_features[f_idx + 1]);
    f_idx += 2;
  }

  aom_clear_system_state();
  for (int idx = 0; idx < f_idx; idx++) {
    features[idx] = logf(1.0f + (float)int_features[idx]);
  }

  const MACROBLOCKD *xd = &x->e_mbd;
  // Sets up xd for the whole block so above_mbmi / left_mbmi refer to its
  // neighbors.
  set_offsets_for_motion_search(cpi, x, mi_row, mi_col, bsize);

  // Q_INDEX
  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
  features[f_idx++] = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);

  // Neighbor stuff
  const int has_above = !!xd->above_mbmi;
  const int has_left = !!xd->left_mbmi;
  const BLOCK_SIZE above_bsize = has_above ? xd->above_mbmi->sb_type : bsize;
  const BLOCK_SIZE left_bsize = has_left ? xd->left_mbmi->sb_type : bsize;
  features[f_idx++] = (float)has_above;
  features[f_idx++] = (float)mi_size_wide_log2[above_bsize];
  features[f_idx++] = (float)mi_size_high_log2[above_bsize];
  features[f_idx++] = (float)has_left;
  features[f_idx++] = (float)mi_size_wide_log2[left_bsize];
  features[f_idx++] = (float)mi_size_high_log2[left_bsize];

  // The caller appends exactly three rd-stat features after these.
  assert(f_idx == FEATURE_SIZE_FP_SMS_TERM_NONE - 3);
}
519
// Decides whether PARTITION_SPLIT can be skipped after PARTITION_NONE has
// been evaluated. It builds FEATURE_SIZE_FP_SMS_TERM_NONE features: the
// simple motion features from firstpass_simple_motion_search_features plus
// log(1 + rate), log(1 + dist) and log(1 + rdcost) of PARTITION_NONE. The
// features are normalized with per-bsize mean/std and scored with a per-bsize
// NN. *do_square_split is cleared when score < thresh and is never set.
// Only BLOCK_32X32, BLOCK_16X16 and BLOCK_8X8 are supported; other sizes
// assert in debug builds and return without changes otherwise.
void av1_firstpass_simple_motion_search_early_term(AV1_COMP *const cpi,
                                                   MACROBLOCK *x,
                                                   PC_TREE *pc_tree, int mi_row,
                                                   int mi_col, BLOCK_SIZE bsize,
                                                   const RD_STATS *none_rdc,
                                                   int *do_square_split) {
  const NN_CONFIG *nn_config = NULL;
  float thresh = 0.0f;
  const float *ml_mean = NULL, *ml_std = NULL;
  if (bsize == BLOCK_32X32) {
    nn_config = &av1_fp_simple_motion_search_term_none_nn_config_32;
    ml_mean = av1_fp_simple_motion_search_term_none_mean_32;
    ml_std = av1_fp_simple_motion_search_term_none_std_32;
    thresh = av1_fp_simple_motion_search_term_none_thresh_32;
  } else if (bsize == BLOCK_16X16) {
    nn_config = &av1_fp_simple_motion_search_term_none_nn_config_16;
    ml_mean = av1_fp_simple_motion_search_term_none_mean_16;
    ml_std = av1_fp_simple_motion_search_term_none_std_16;
    thresh = av1_fp_simple_motion_search_term_none_thresh_16;
  } else if (bsize == BLOCK_8X8) {
    nn_config = &av1_fp_simple_motion_search_term_none_nn_config_8;
    ml_mean = av1_fp_simple_motion_search_term_none_mean_8;
    ml_std = av1_fp_simple_motion_search_term_none_std_8;
    thresh = av1_fp_simple_motion_search_term_none_thresh_8;
  } else {
    assert(0 &&
           "Unexpected bsize in firstpass_simple_motion_search_early_term");
    return;
  }

  float ml_features[FEATURE_SIZE_FP_SMS_TERM_NONE] = { 0.0f };

  firstpass_simple_motion_search_features(cpi, x, pc_tree, mi_row, mi_col,
                                          bsize, ml_features);

  // The three PARTITION_NONE rd stats occupy the last slots of ml_features.
  int f_idx = FEATURE_SIZE_FP_SMS_TERM_NONE - 3;
  aom_clear_system_state();
  ml_features[f_idx++] = logf(1.0f + (float)none_rdc->rate);
  ml_features[f_idx++] = logf(1.0f + (float)none_rdc->dist);
  ml_features[f_idx++] = logf(1.0f + (float)none_rdc->rdcost);
  assert(f_idx == FEATURE_SIZE_FP_SMS_TERM_NONE);

  for (int i = 0; i < FEATURE_SIZE_FP_SMS_TERM_NONE; i++) {
    ml_features[i] = (ml_features[i] - ml_mean[i]) / ml_std[i];
  }

  // Get probabilities
  float score = 0.0f;

  av1_nn_predict(ml_features, nn_config, &score);
  aom_clear_system_state();

  // Determine if we should prune square partitions.
  if (score < thresh) {
    *do_square_split = 0;
  }
}
575
// Extracts the FEATURE_SIZE_MAX_MIN_PART_PRED features used by
// av1_predict_max_partition for the superblock at (mi_row, mi_col). The
// superblock must be 128x128 (asserted). It is scanned as a grid of 16x16
// macroblocks. Each one gets av1_simple_motion_sse_var from a zero mv with
// use_subpixel = 0, and its sse and x->best_mv are collected; best_mv is
// presumably the result of that search (it is divided by 8 and truncated).
// features receives, in this order:
//   mean log(1 + sse), mean mv col, mean mv row, log(1 + dc_q**2/256),
//   max |mv col|, max |mv row|, max log(1 + sse),
//   min |mv col|, min |mv row|, min log(1 + sse),
//   var log(1 + sse), var mv col, var mv row
// where mean/var are taken over all macroblocks and var = E[v^2] - E[v]^2.
void av1_get_max_min_partition_features(AV1_COMP *const cpi, MACROBLOCK *x,
                                        int mi_row, int mi_col,
                                        float *features) {
  AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  const BLOCK_SIZE sb_size = cm->seq_params.sb_size;

  assert(sb_size == BLOCK_128X128);

  int f_idx = 0;

  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
  aom_clear_system_state();
  const float log_q_sq = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);

  // Perform full-pixel single motion search in Y plane of 16x16 mbs in the sb
  float sum_mv_row_sq = 0;
  float sum_mv_row = 0;
  float min_abs_mv_row = FLT_MAX;
  float max_abs_mv_row = 0;

  float sum_mv_col_sq = 0;
  float sum_mv_col = 0;
  float min_abs_mv_col = FLT_MAX;
  float max_abs_mv_col = 0;

  float sum_log_sse_sq = 0;
  float sum_log_sse = 0;
  float min_log_sse = FLT_MAX;
  float max_log_sse = 0;

  const BLOCK_SIZE mb_size = BLOCK_16X16;
  const int mb_rows = block_size_high[sb_size] / block_size_high[mb_size];
  const int mb_cols = block_size_wide[sb_size] / block_size_wide[mb_size];
  const int mb_in_mi_size_high_log2 = mi_size_high_log2[mb_size];
  const int mb_in_mi_size_wide_log2 = mi_size_wide_log2[mb_size];

  for (int mb_row = 0; mb_row < mb_rows; mb_row++) {
    for (int mb_col = 0; mb_col < mb_cols; mb_col++) {
      const int this_mi_row = mi_row + (mb_row << mb_in_mi_size_high_log2);
      const int this_mi_col = mi_col + (mb_col << mb_in_mi_size_wide_log2);
      unsigned int sse = 0;
      unsigned int var = 0;
      const MV ref_mv_full = { .row = 0, .col = 0 };

      av1_simple_motion_sse_var(cpi, x, this_mi_row, this_mi_col, mb_size,
                                ref_mv_full, 0, &sse, &var);

      aom_clear_system_state();
      const float mv_row = (float)(x->best_mv.as_mv.row / 8);
      const float mv_col = (float)(x->best_mv.as_mv.col / 8);
      const float log_sse = logf(1.0f + (float)sse);
      const float abs_mv_row = fabsf(mv_row);
      const float abs_mv_col = fabsf(mv_col);

      sum_mv_row_sq += mv_row * mv_row;
      sum_mv_row += mv_row;
      sum_mv_col_sq += mv_col * mv_col;
      sum_mv_col += mv_col;

      if (abs_mv_row < min_abs_mv_row) min_abs_mv_row = abs_mv_row;
      if (abs_mv_row > max_abs_mv_row) max_abs_mv_row = abs_mv_row;
      if (abs_mv_col < min_abs_mv_col) min_abs_mv_col = abs_mv_col;
      if (abs_mv_col > max_abs_mv_col) max_abs_mv_col = abs_mv_col;

      sum_log_sse_sq += log_sse * log_sse;
      sum_log_sse += log_sse;
      if (log_sse < min_log_sse) min_log_sse = log_sse;
      if (log_sse > max_log_sse) max_log_sse = log_sse;
    }
  }
  aom_clear_system_state();
  // Average over the number of macroblocks actually visited (64 for a
  // 128x128 superblock of 16x16 mbs) instead of a hard-coded count.
  const float num_mbs = (float)(mb_rows * mb_cols);

  const float avg_mv_row = sum_mv_row / num_mbs;
  const float var_mv_row = sum_mv_row_sq / num_mbs - avg_mv_row * avg_mv_row;

  const float avg_mv_col = sum_mv_col / num_mbs;
  const float var_mv_col = sum_mv_col_sq / num_mbs - avg_mv_col * avg_mv_col;

  const float avg_log_sse = sum_log_sse / num_mbs;
  const float var_log_sse =
      sum_log_sse_sq / num_mbs - avg_log_sse * avg_log_sse;

  // The order below must match the order the max-partition model expects.
  features[f_idx++] = avg_log_sse;
  features[f_idx++] = avg_mv_col;
  features[f_idx++] = avg_mv_row;
  features[f_idx++] = log_q_sq;
  features[f_idx++] = max_abs_mv_col;
  features[f_idx++] = max_abs_mv_row;
  features[f_idx++] = max_log_sse;
  features[f_idx++] = min_abs_mv_col;
  features[f_idx++] = min_abs_mv_row;
  features[f_idx++] = min_log_sse;
  features[f_idx++] = var_log_sse;
  features[f_idx++] = var_mv_col;
  features[f_idx++] = var_mv_row;

  assert(f_idx == FEATURE_SIZE_MAX_MIN_PART_PRED);
}
672
av1_predict_max_partition(AV1_COMP * const cpi,MACROBLOCK * const x,const float * features)673 BLOCK_SIZE av1_predict_max_partition(AV1_COMP *const cpi, MACROBLOCK *const x,
674 const float *features) {
675 float scores[MAX_NUM_CLASSES_MAX_MIN_PART_PRED] = { 0.0f },
676 probs[MAX_NUM_CLASSES_MAX_MIN_PART_PRED] = { 0.0f };
677 const NN_CONFIG *nn_config = &av1_max_part_pred_nn_config;
678
679 assert(cpi->sf.auto_max_partition_based_on_simple_motion != NOT_IN_USE);
680
681 aom_clear_system_state();
682 av1_nn_predict(features, nn_config, scores);
683 av1_nn_softmax(scores, probs, MAX_NUM_CLASSES_MAX_MIN_PART_PRED);
684
685 int result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1;
686 if (cpi->sf.auto_max_partition_based_on_simple_motion == DIRECT_PRED) {
687 result = 0;
688 float max_prob = probs[0];
689 for (int i = 1; i < MAX_NUM_CLASSES_MAX_MIN_PART_PRED; ++i) {
690 if (probs[i] > max_prob) {
691 max_prob = probs[i];
692 result = i;
693 }
694 }
695 } else if (cpi->sf.auto_max_partition_based_on_simple_motion ==
696 RELAXED_PRED) {
697 for (result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1; result >= 0;
698 --result) {
699 if (result < MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1) {
700 probs[result] += probs[result + 1];
701 }
702 if (probs[result] > 0.2) break;
703 }
704 } else if (cpi->sf.auto_max_partition_based_on_simple_motion == ADAPT_PRED) {
705 const BLOCK_SIZE sb_size = cpi->common.seq_params.sb_size;
706 MACROBLOCKD *const xd = &x->e_mbd;
707 // TODO(debargha): x->source_variance is unavailable at this point,
708 // so compute. The redundant recomputation later can be removed.
709 const unsigned int source_variance =
710 is_cur_buf_hbd(xd)
711 ? av1_high_get_sby_perpixel_variance(cpi, &x->plane[0].src, sb_size,
712 xd->bd)
713 : av1_get_sby_perpixel_variance(cpi, &x->plane[0].src, sb_size);
714 if (source_variance > 16) {
715 const double thresh = source_variance < 128 ? 0.05 : 0.1;
716 for (result = MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1; result >= 0;
717 --result) {
718 if (result < MAX_NUM_CLASSES_MAX_MIN_PART_PRED - 1) {
719 probs[result] += probs[result + 1];
720 }
721 if (probs[result] > thresh) break;
722 }
723 }
724 }
725
726 return (BLOCK_SIZE)((result + 2) * 3);
727 }
728