/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include "av1/common/cfl.h"
#include "av1/common/reconintra.h"
#include "av1/encoder/block.h"
#include "av1/encoder/hybrid_fwd_txfm.h"
#include "av1/common/idct.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/random.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/sorting_network.h"
#include "av1/encoder/tx_prune_model_weights.h"
#include "av1/encoder/tx_search.h"
#include "av1/encoder/txb_rdopt.h"

#define PROB_THRESH_OFFSET_TX_TYPE 100

struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  RD_STATS rd_stats;
  int64_t current_rd;
  int64_t best_rd;
  int exit_early;
  int incomplete_exit;
  FAST_TX_SEARCH_MODE ftxs_mode;
  int skip_trellis;
};

typedef struct {
  int64_t rd;
  int txb_entropy_ctx;
  TX_TYPE tx_type;
} TxCandidateInfo;

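// Thresholds used by predict_skip_txfm() to decide whether the residual can
// be coded as all-zero, indexed as [bit-depth index (8/10/12)][block size].
// An entry is in units of 1/128 of a quantization step: a candidate
// coefficient is scaled by 128 (<< 7) before being compared against
// entry * quant_step, so e.g. an entry of 64 tolerates coefficients up to
// half a step.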
// origin_threshold * 128 / 100
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};

// lookup table for predict_skip_txfm
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
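// E.g. for BLOCK_8X32, max_txsize_rect_lookup gives TX_8X32 (height 32 > 16),
// so the entry is clamped to AOMMIN(max_txsize_lookup[BLOCK_8X32], TX_16X16)
// = TX_8X8.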
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};

// look-up table for sqrt of number of pixels in a transform block
// rounded up to the nearest integer.
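// E.g. TX_8X16 covers 128 pixels and sqrt(128) ~= 11.3, so its entry is 12.
// Entries for sizes with a 64-sample dimension treat that dimension as 32,
// presumably because coefficients beyond the top-left 32x32 of such
// transforms are zeroed out.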
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
                                                     12, 12, 23, 23, 32, 32, 8,
                                                     8,  16, 16, 23, 23 };

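// Hashes the luma residual of the whole block for the mb_rd_record cache.
// The block size is packed into the low 5 bits so that identical residuals
// of different block sizes hash to different keys.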
static inline uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  const int16_t *diff = x->plane[0].src_diff;
  const uint32_t hash =
      av1_get_crc32c_value(&x->txfm_search_info.mb_rd_record->crc_calculator,
                           (uint8_t *)diff, 2 * rows * cols);
  return (hash << 5) + bsize;
}

static inline int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
                                      const int64_t ref_best_rd,
                                      const uint32_t hash) {
  int32_t match_index = -1;
  if (ref_best_rd != INT64_MAX) {
    for (int i = 0; i < mb_rd_record->num; ++i) {
      const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
      // If there is a match in the mb_rd_record, fetch the RD decision and
      // terminate early.
      if (mb_rd_record->mb_rd_info[index].hash_value == hash) {
        match_index = index;
        break;
      }
    }
  }
  return match_index;
}

static inline void fetch_mb_rd_info(int n4, const MB_RD_INFO *const mb_rd_info,
                                    RD_STATS *const rd_stats,
                                    MACROBLOCK *const x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  mbmi->tx_size = mb_rd_info->tx_size;
  memcpy(x->txfm_search_info.blk_skip, mb_rd_info->blk_skip,
         sizeof(mb_rd_info->blk_skip[0]) * n4);
  av1_copy(mbmi->inter_tx_size, mb_rd_info->inter_tx_size);
  av1_copy_array(xd->tx_type_map, mb_rd_info->tx_type_map, n4);
  *rd_stats = mb_rd_info->rd_stats;
}

int64_t av1_pixel_diff_dist(const MACROBLOCK *x, int plane, int blk_row,
                            int blk_col, const BLOCK_SIZE plane_bsize,
                            const BLOCK_SIZE tx_bsize,
                            unsigned int *block_mse_q8) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse =
      aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
  if (block_mse_q8 != NULL) {
    if (visible_cols > 0 && visible_rows > 0)
      *block_mse_q8 =
          (unsigned int)((256 * sse) / (visible_cols * visible_rows));
    else
      *block_mse_q8 = UINT_MAX;
  }
  return sse;
}

// Computes the residual block's SSE and mean on all visible 4x4s in the
// transform block
static inline int64_t pixel_diff_stats(
    MACROBLOCK *x, int plane, int blk_row, int blk_col,
    const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
    unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse = 0;
  int sum = 0;
  sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
  if (visible_cols > 0 && visible_rows > 0) {
    double norm_factor = 1.0 / (visible_cols * visible_rows);
    int sign_sum = sum > 0 ? 1 : -1;
    // Conversion to transform domain
    *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
    *per_px_mean = sign_sum * (*per_px_mean);
    *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
    *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
  } else {
    *block_mse_q8 = UINT_MAX;
  }
  return sse;
}

// Uses simple features on top of DCT coefficients to quickly predict
// whether optimal RD decision is to skip encoding the residual.
// The sse value is stored in dist.
static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
                             int reduced_tx_set) {
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const MACROBLOCKD *xd = &x->e_mbd;
  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);

  *dist = av1_pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);

  const int64_t mse = *dist / bw / bh;
  // Normalized quantizer takes the transform upscaling factor (8 for tx size
  // smaller than 32) into account.
  const int16_t normalized_dc_q = dc_q >> 3;
  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
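  // E.g. with dc_q = 160 (which includes the 8x transform upscaling), the
  // pixel-domain step is 160 >> 3 = 20 and mse_thresh = 20 * 20 / 8 = 50.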
194 // For faster early skip decision, use dist to compare against threshold so
195 // that quality risk is less for the skip=1 decision. Otherwise, use mse
196 // since the fwd_txfm coeff checks will take care of quality
197 // TODO(any): Use dist to return 0 when skip_txfm_level is 1
198 int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
199 // Predict not to skip when error is larger than threshold.
200 if (pred_err > mse_thresh) return 0;
201 // Return as skip otherwise for aggressive early skip
202 else if (txfm_params->skip_txfm_level >= 2)
203 return 1;
204
205 const int max_tx_size = max_predict_sf_tx_size[bsize];
206 const int tx_h = tx_size_high[max_tx_size];
207 const int tx_w = tx_size_wide[max_tx_size];
208 DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
209 TxfmParam param;
210 param.tx_type = DCT_DCT;
211 param.tx_size = max_tx_size;
212 param.bd = xd->bd;
213 param.is_hbd = is_cur_buf_hbd(xd);
214 param.lossless = 0;
215 param.tx_set_type = av1_get_ext_tx_set_type(
216 param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
217 const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
218 const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
219 const int16_t *src_diff = x->plane[0].src_diff;
220 const int n_coeff = tx_w * tx_h;
221 const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
222 const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
223 const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
224 for (int row = 0; row < bh; row += tx_h) {
225 for (int col = 0; col < bw; col += tx_w) {
226 av1_fwd_txfm(src_diff + col, coefs, bw, ¶m);
227 // Operating on TX domain, not pixels; we want the QTX quantizers
228 const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
229 if (dc_coef >= dc_thresh) return 0;
230 for (int i = 1; i < n_coeff; ++i) {
231 const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
232 if (ac_coef >= ac_thresh) return 0;
233 }
234 }
235 src_diff += tx_h * bw;
236 }
237 return 1;
238 }
239
240 // Used to set proper context for early termination with skip = 1.
static inline void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
                                 BLOCK_SIZE bsize, int64_t dist) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int n4 = bsize_to_num_blk(bsize);
  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
  memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
  mbmi->tx_size = tx_size;
  for (int i = 0; i < n4; ++i)
    set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1);
  rd_stats->skip_txfm = 1;
  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
  rd_stats->dist = rd_stats->sse = (dist << 4);
  // Though the decision is to mark the block as skip based on luma stats,
  // it is possible that the block becomes non-skip after chroma rd. In
  // addition, intermediate non-skip costs calculated by the caller will be
  // incorrect if the rate is set to zero (i.e., if zero_blk_rate is not
  // accounted for). Hence an intermediate rate is populated here to code the
  // luma tx blocks as skip; the caller then sets the final rate according to
  // the final rd decision (i.e., skip vs non-skip). The rate populated here
  // corresponds to coding all the tx blocks in the current block with
  // zero_blk_rate (based on the maximum possible tx size). E.g., for a
  // 128x128 block, the rate would be 4 * zero_blk_rate, where zero_blk_rate
  // corresponds to coding one 64x64 tx block as 'all zeros'.
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
  ENTROPY_CONTEXT *ta = ctxa;
  ENTROPY_CONTEXT *tl = ctxl;
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  TXB_CTX txb_ctx;
  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  rd_stats->rate = zero_blk_rate *
                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
}

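// Saves the current RD result into mb_rd_record, which acts as a FIFO ring
// buffer: once full, the oldest entry (at index_start) is overwritten.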
static inline void save_mb_rd_info(int n4, uint32_t hash,
                                   const MACROBLOCK *const x,
                                   const RD_STATS *const rd_stats,
                                   MB_RD_RECORD *mb_rd_record) {
  int index;
  if (mb_rd_record->num < RD_RECORD_BUFFER_LEN) {
    index =
        (mb_rd_record->index_start + mb_rd_record->num) % RD_RECORD_BUFFER_LEN;
    ++mb_rd_record->num;
  } else {
    index = mb_rd_record->index_start;
    mb_rd_record->index_start =
        (mb_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
  }
  MB_RD_INFO *const mb_rd_info = &mb_rd_record->mb_rd_info[index];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  mb_rd_info->hash_value = hash;
  mb_rd_info->tx_size = mbmi->tx_size;
  memcpy(mb_rd_info->blk_skip, x->txfm_search_info.blk_skip,
         sizeof(mb_rd_info->blk_skip[0]) * n4);
  av1_copy(mb_rd_info->inter_tx_size, mbmi->inter_tx_size);
  av1_copy_array(mb_rd_info->tx_type_map, xd->tx_type_map, n4);
  mb_rd_info->rd_stats = *rd_stats;
}

static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
                                 const SPEED_FEATURES *sf,
                                 int tx_size_search_method) {
  if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;

  if (sf->tx_sf.tx_size_search_lgr_block) {
    if (mi_width > mi_size_wide[BLOCK_64X64] ||
        mi_height > mi_size_high[BLOCK_64X64])
      return MAX_VARTX_DEPTH;
  }

  if (is_inter) {
    return (mi_height != mi_width)
               ? sf->tx_sf.inter_tx_size_search_init_depth_rect
               : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
  } else {
    return (mi_height != mi_width)
               ? sf->tx_sf.intra_tx_size_search_init_depth_rect
               : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
  }
}

static inline void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode);

// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
// 0: Do not collect any RD stats
// 1: Collect RD stats for transform units
// 2: Collect RD stats for partition units
#if CONFIG_COLLECT_RD_STATS

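// Computes the energy of the prediction error in each cell of a 4x4 grid
// over the block and projects it into normalized horizontal (hordist[]) and
// vertical (verdist[]) marginal distributions.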
static inline void get_energy_distribution_fine(
    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
    double *verdist) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
    // functions for the 16 (very small) sub-blocks of this block.
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    assert(bw <= 32);
    assert(bh <= 32);
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    if (cpi->common.seq_params->use_highbitdepth) {
      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] +=
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
        }
    } else {
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
        }
    }
  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
    const int f_index =
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[1]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[2]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[3]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[5]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[6]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[7]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[9]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[10]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[11]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[13]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[14]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[15]);
  }

  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
                 esq[12] + esq[13] + esq[14] + esq[15];
  if (total > 0) {
    const double e_recip = 1.0 / total;
    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
    if (need_4th) {
      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
    }
    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
    if (need_4th) {
      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
    }
  } else {
    hordist[0] = verdist[0] = 0.25;
    hordist[1] = verdist[1] = 0.25;
    hordist[2] = verdist[2] = 0.25;
    if (need_4th) {
      hordist[3] = verdist[3] = 0.25;
    }
  }
}

static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int err = diff[j * stride + i];
      sum += err * err;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += abs(diff[j * stride + i]);
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static inline void get_2x2_normalized_sses_and_sads(
    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
    int src_stride, const uint8_t *const dst, int dst_stride,
    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
    double *const sad_norm_arr) {
  const BLOCK_SIZE tx_bsize_half =
      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
  if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
    const int half_width = block_size_wide[tx_bsize] / 2;
    const int half_height = block_size_high[tx_bsize] / 2;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const int16_t *const this_src_diff =
            src_diff + row * half_height * diff_stride + col * half_width;
        if (sse_norm_arr) {
          sse_norm_arr[row * 2 + col] =
              get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
        }
        if (sad_norm_arr) {
          sad_norm_arr[row * 2 + col] =
              get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
        }
      }
    }
  } else {  // use function pointers to calculate stats
    const int half_width = block_size_wide[tx_bsize_half];
    const int half_height = block_size_high[tx_bsize_half];
    const int num_samples_half = half_width * half_height;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const uint8_t *const this_src =
            src + row * half_height * src_stride + col * half_width;
        const uint8_t *const this_dst =
            dst + row * half_height * dst_stride + col * half_width;

        if (sse_norm_arr) {
          unsigned int this_sse;
          cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
                                             dst_stride, &this_sse);
          sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
        }

        if (sad_norm_arr) {
          const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
              this_src, src_stride, this_dst, dst_stride);
          sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
        }
      }
    }
  }
}

#if CONFIG_COLLECT_RD_STATS == 1
static double get_mean(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += diff[j * stride + i];
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}
static inline void PrintTransformUnitStats(
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    TX_TYPE tx_type, int64_t rd) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  // Generate small sample to restrict output size.
  static unsigned int seed = 21743;
  if (lcg_rand16(&seed) % 256 > 0) return;

  const char output_file[] = "tu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  const int txw = tx_size_wide[tx_size];
  const int txh = tx_size_high[tx_size];
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int num_samples = txw * txh;

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;

  fprintf(fout, "%g %g", rate_norm, dist_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src =
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  unsigned int sse;
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm = (double)sad / num_samples;

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *const src_diff =
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
  const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];

  fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
          tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);

  const double mean = get_mean(src_diff, diff_stride, txw, txh);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
                               1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  fprintf(fout, " %d %" PRId64, x->rdmult, rd);

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if CONFIG_COLLECT_RD_STATS >= 2
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const MACROBLOCKD *xd = &x->e_mbd;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  int64_t total_sse = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const struct macroblock_plane *const p = &x->plane[plane];
    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE bs =
        get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
    unsigned int sse;

    if (plane) continue;

    cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
                            pd->dst.stride, &sse);
    total_sse += sse;
  }
  total_sse <<= 4;
  return total_sse;
}

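// Estimates the residue rate and distortion of a block from the per-bsize
// linear model fitted in tile_data: the rate is derived from
// (sse - dist_mean) / (a * sse + b). Returns 1 if the model is ready, 0
// otherwise.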
static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (md->ready) {
    if (sse < md->dist_mean) {
      *est_residue_cost = 0;
      *est_dist = sse;
    } else {
      *est_dist = (int64_t)round(md->dist_mean);
      const double est_ld = md->a * sse + md->b;
      // Clamp estimated rate cost by INT_MAX / 2.
      // TODO(angiebird@google.com): find better solution than clamping.
      if (fabs(est_ld) < 1e-2) {
        *est_residue_cost = INT_MAX / 2;
      } else {
        double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
        if (est_residue_cost_dbl < 0) {
          *est_residue_cost = 0;
        } else {
          *est_residue_cost =
              (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
        }
      }
      if (*est_residue_cost <= 0) {
        *est_residue_cost = 0;
        *est_dist = sse;
      }
    }
    return 1;
  }
  return 0;
}

static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
                                   const uint8_t *dst8, int dst_stride, int w,
                                   int h) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_diff_mean(const uint8_t *src, int src_stride,
                            const uint8_t *dst, int dst_stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static inline void PrintPredictionUnitStats(const AV1_COMP *const cpi,
                                            const TileDataEnc *tile_data,
                                            MACROBLOCK *x,
                                            const RD_STATS *const rd_stats,
                                            BLOCK_SIZE plane_bsize) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
      (tile_data == NULL ||
       !tile_data->inter_mode_rd_models[plane_bsize].ready))
    return;
  (void)tile_data;
  // Generate small sample to restrict output size.
  static unsigned int seed = 95014;

  if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
      1)
    return;

  const char output_file[] = "pu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];
  const int diff_stride = block_size_wide[plane_bsize];
  int bw, bh;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
                     &bh);
  const int num_samples = bw * bh;
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int shift = (xd->bd - 8);

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;
  const double rdcost_norm =
      (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;

  fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src = p->src.buf;
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst = pd->dst.buf;
  const int16_t *const src_diff = p->src_diff;

  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm =
      (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  if (shift) {
    for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
    for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rdcost_norm =
      (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
          model_rdcost_norm);

  double mean;
  if (is_cur_buf_hbd(xd)) {
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh);
  } else {
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
  }
  mean /= (1 << shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    const int64_t overall_sse = get_sse(cpi, x);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
                      &est_dist);
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
    const double est_dist_norm = (double)est_dist / num_samples;
    const double est_rdcost_norm =
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
            est_rdcost_norm);
  }

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS >= 2
#endif  // CONFIG_COLLECT_RD_STATS

static inline void inverse_transform_block_facade(MACROBLOCK *const x,
                                                  int plane, int block,
                                                  int blk_row, int blk_col,
                                                  int eob, int reduced_tx_set) {
  if (!eob) return;
  struct macroblock_plane *const p = &x->plane[plane];
  MACROBLOCKD *const xd = &x->e_mbd;
  tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
                                          tx_size, reduced_tx_set);

  struct macroblockd_plane *const pd = &xd->plane[plane];
  const int dst_stride = pd->dst.stride;
  uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                              dst_stride, eob, reduced_tx_set);
}

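// Reconstructs an intra transform block in place so that transform blocks
// coded later in the same plane block can use it as a prediction reference;
// it is skipped for the last (bottom-right) transform block, whose pixels
// no subsequent block depends on.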
static inline void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx, int skip_trellis,
                               TX_TYPE best_tx_type, int do_quant,
                               int *rate_cost, uint16_t best_eob) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  if (!is_inter && best_eob &&
      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
    // if the quantized coefficients are stored in the dqcoeff buffer, we don't
    // need to do transform and quantization again.
    if (do_quant) {
      TxfmParam txfm_param_intra;
      QUANT_PARAM quant_param_intra;
      av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
      av1_setup_quant(tx_size, !skip_trellis,
                      skip_trellis
                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
                                                    : AV1_XFORM_QUANT_FP)
                          : AV1_XFORM_QUANT_FP,
                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
                        &quant_param_intra);
      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
                      &txfm_param_intra, &quant_param_intra);
      if (quant_param_intra.use_optimize_b) {
        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                       rate_cost);
      }
    }

    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
                                   x->plane[plane].eobs[block],
                                   cm->features.reduced_tx_set_used);

    // This may happen because of hash collision. The eob stored in the hash
    // table is non-zero, but the real eob is zero. We need to make sure tx_type
    // is DCT_DCT in this case.
    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
        best_tx_type != DCT_DCT) {
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    }
  }
}

static unsigned pixel_dist_visible_only(
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    const int src_stride, const uint8_t *dst, const int dst_stride,
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    int visible_cols) {
  unsigned sse;

  if (txb_rows == visible_rows && txb_cols == visible_cols) {
    cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
    return sse;
  }

#if CONFIG_AV1_HIGHBITDEPTH
  const MACROBLOCKD *xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
                                             visible_cols, visible_rows);
    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
  }
#else
  (void)x;
#endif
  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
                         visible_rows);
  return sse;
}

// Compute the pixel domain distortion from src and dst on all visible 4x4s
// in the transform block.
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
                           int plane, const uint8_t *src, const int src_stride,
                           const uint8_t *dst, const int dst_stride,
                           int blk_row, int blk_col,
                           const BLOCK_SIZE plane_bsize,
                           const BLOCK_SIZE tx_bsize) {
  int txb_rows, txb_cols, visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;

  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
  assert(visible_rows > 0);
  assert(visible_cols > 0);

  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
                                         dst_stride, tx_bsize, txb_rows,
                                         txb_cols, visible_rows, visible_cols);

  return sse;
}

static inline int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
                                           int plane, BLOCK_SIZE plane_bsize,
                                           int block, int blk_row, int blk_col,
                                           TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const uint16_t eob = p->eobs[block];
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const int bsw = block_size_wide[tx_bsize];
  const int bsh = block_size_high[tx_bsize];
  const int src_stride = x->plane[plane].src.stride;
  const int dst_stride = xd->plane[plane].dst.stride;
  // Scale the transform block index to pixel unit.
  const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
  const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
  const uint8_t *src = &x->plane[plane].src.buf[src_idx];
  const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
  const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);

  assert(cpi != NULL);
  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);

  uint8_t *recon;
  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);

#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    recon = CONVERT_TO_BYTEPTR(recon16);
    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
  } else {
    recon = (uint8_t *)recon16;
    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
  }
#else
  recon = (uint8_t *)recon16;
  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
#endif

  const PLANE_TYPE plane_type = get_plane_type(plane);
  TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                                    cpi->common.features.reduced_tx_set_used);
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
                              MAX_TX_SIZE, eob,
                              cpi->common.features.reduced_tx_set_used);

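  // Scale the SSE up by 16 to match the Q2^2 fixed-point distortion domain
  // described in dist_block_tx_domain() below.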
  return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
                         blk_row, blk_col, plane_bsize, tx_bsize);
}

// pruning thresholds for prune_txk_type and prune_txk_type_separ
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100

// R-D costs are sorted in ascending order.
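// (A simple insertion sort; txk[] is permuted alongside rds[].)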
static inline void sort_rd(int64_t rds[], int txk[], int len) {
  int i, j, k;

  for (i = 1; i <= len - 1; ++i) {
    for (j = 0; j < i; ++j) {
      if (rds[j] > rds[i]) {
        int64_t temprd;
        int tempi;

        temprd = rds[i];
        tempi = txk[i];

        for (k = i; k > j; k--) {
          rds[k] = rds[k - 1];
          txk[k] = txk[k - 1];
        }

        rds[j] = temprd;
        txk[j] = tempi;
        break;
      }
    }
  }
}

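// Computes the transform-domain error weighted by the quantization matrix:
// both the coeff/dqcoeff difference and the reference coefficients are scaled
// by the per-position qmatrix weight before squaring. *ssz returns the
// weighted energy of the original coefficients.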
static inline int64_t av1_block_error_qm(
    const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size,
    const qm_val_t *qmatrix, const int16_t *scan, int64_t *ssz, int bd) {
  int i;
  int64_t error = 0, sqcoeff = 0;
  int shift = 2 * (bd - 8);
  int rounding = (1 << shift) >> 1;

  for (i = 0; i < block_size; i++) {
    int64_t weight = qmatrix[scan[i]];
    int64_t dd = coeff[i] - dqcoeff[i];
    dd *= weight;
    int64_t cc = coeff[i];
    cc *= weight;
    // The ranges of coeff and dqcoeff are
    //  bd8 : 18 bits (including sign)
    //  bd10: 20 bits (including sign)
    //  bd12: 22 bits (including sign)
    // As AOM_QM_BITS is 5, the intermediate quantities in the calculation
    // below should fit in 54 bits, thus no overflow should happen.
    error += (dd * dd + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
    sqcoeff += (cc * cc + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
  }

  error = (error + rounding) >> shift;
  sqcoeff = (sqcoeff + rounding) >> shift;

  *ssz = sqcoeff;
  return error;
}

static inline void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
                                        TX_SIZE tx_size,
                                        const qm_val_t *qmatrix,
                                        const int16_t *scan, int64_t *out_dist,
                                        int64_t *out_sse) {
  const struct macroblock_plane *const p = &x->plane[plane];
  // Transform domain distortion computation is more efficient as it does
  // not involve an inverse transform, but it is less accurate.
  const int buffer_length = av1_get_max_eob(tx_size);
  int64_t this_sse;
  // TX-domain results need to shift down to Q2/D10 to match pixel
  // domain distortion values which are in Q2^2
  int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
  const int block_offset = BLOCK_OFFSET(block);
  tran_low_t *const coeff = p->coeff + block_offset;
  tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
#if CONFIG_AV1_HIGHBITDEPTH
  MACROBLOCKD *const xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
      *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length,
                                         &this_sse, xd->bd);
    } else {
      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
                                     scan, &this_sse, xd->bd);
    }
  } else {
#endif
    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
      *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
    } else {
      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
                                     scan, &this_sse, 8);
    }
#if CONFIG_AV1_HIGHBITDEPTH
  }
#endif

  *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
  *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
}

static uint16_t prune_txk_type_separ(
    const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, TX_SIZE tx_size,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
    int16_t allowed_tx_mask, int prune_factor, const TXB_CTX *const txb_ctx,
    int reduced_tx_set_used, int64_t ref_best_rd, int num_sel) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;

  int idx;

  int64_t rds_v[4];
  int64_t rds_h[4];
  int idx_v[4] = { 0, 1, 2, 3 };
  int idx_h[4] = { 0, 1, 2, 3 };
  int skip_v[4] = { 0 };
  int skip_h[4] = { 0 };
  const int idx_map[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };
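  // idx_map forms a 4x4 grid: each row fixes the vertical 1D transform
  // (DCT, ADST, FLIPADST, identity) while the column selects the horizontal
  // one, so a step of 1 changes the horizontal transform and a step of 4
  // changes the vertical transform.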

  const int sel_pattern_v[16] = {
    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
  };
  const int sel_pattern_h[16] = {
    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
  };
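  // Together, sel_pattern_v/sel_pattern_h enumerate (vertical rank,
  // horizontal rank) pairs roughly in order of increasing combined rank, so
  // the most promising combinations are visited first.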

  QUANT_PARAM quant_param;
  TxfmParam txfm_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);
  int tx_type;
  // Use EXT_TX_SET_ALL16 to ensure we can try tx types even outside the
  // ext_tx_set of the current block; this function should only be called for
  // transform sizes up to 16x16.
  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
  txfm_param.tx_set_type = EXT_TX_SET_ALL16;

  int rate_cost = 0;
  int64_t dist = 0, sse = 0;
  // evaluate horizontal with vertical DCT
  for (idx = 0; idx < 4; ++idx) {
    tx_type = idx_map[idx];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
      skip_h[idx] = 1;
    }
  }
  sort_rd(rds_h, idx_h, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
  }

  if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;

  // evaluate vertical with the best horizontal chosen
  rds_v[0] = rds_h[0];
  int start_v = 1, end_v = 4;
  const int *idx_map_v = idx_map + idx_h[0];

  for (idx = start_v; idx < end_v; ++idx) {
    tx_type = idx_map_v[idx_v[idx] * 4];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
      skip_v[idx] = 1;
    }
  }
  sort_rd(rds_v, idx_v, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
  }

  // combine rd_h and rd_v to prune tx candidates
  int i_v, i_h;
  int64_t rds[16];
  int num_cand = 0, last = TX_TYPES - 1;

  for (int i = 0; i < 16; i++) {
    i_v = sel_pattern_v[i];
    i_h = sel_pattern_h[i];
    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
        skip_v[idx_v[i_v]]) {
      txk_map[last] = tx_type;
      last--;
    } else {
      txk_map[num_cand] = tx_type;
      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
      if (rds[num_cand] == 0) rds[num_cand] = 1;
      num_cand++;
    }
  }
  sort_rd(rds, txk_map, num_cand);

  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
  num_sel = AOMMIN(num_sel, num_cand);

  for (int i = 1; i < num_sel; i++) {
    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[i]);
    else
      break;
  }
  return prune;
}

static uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, TX_SIZE tx_size, int blk_row,
                               int blk_col, BLOCK_SIZE plane_bsize,
                               int *txk_map, uint16_t allowed_tx_mask,
                               int prune_factor, const TXB_CTX *const txb_ctx,
                               int reduced_tx_set_used) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  int tx_type;

  int64_t rds[TX_TYPES];

  int num_cand = 0;
  int last = TX_TYPES - 1;

  TxfmParam txfm_param;
  QUANT_PARAM quant_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);

  for (int idx = 0; idx < TX_TYPES; idx++) {
    tx_type = idx;
    int rate_cost = 0;
    int64_t dist = 0, sse = 0;
    if (!(allowed_tx_mask & (1 << tx_type))) {
      txk_map[last] = tx_type;
      last--;
      continue;
    }
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    // do txfm and quantization
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);
    // estimate rate cost
    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);
    // tx domain dist
    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    txk_map[num_cand] = tx_type;
    rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
    if (rds[num_cand] == 0) rds[num_cand] = 1;
    num_cand++;
  }

  if (num_cand == 0) return (uint16_t)0xFFFF;

  sort_rd(rds, txk_map, num_cand);
  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
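  // In the prune mask a set bit means the tx type is pruned; start with
  // everything pruned except the best candidate, then un-prune runners-up
  // whose RD cost is close enough to the best one.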

  // 0 < prune_factor <= 1000 controls aggressiveness
  int64_t factor = 0;
  for (int idx = 1; idx < num_cand; idx++) {
    factor = 1000 * (rds[idx] - rds[0]) / rds[0];
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[idx]);
    else
      break;
  }
  return prune;
}

// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average
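// A NULL entry corresponds to a transform size (any with a dimension of 32
// or more) for which the 2D prune model is not used.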
static const float *prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
             0.09778f, 0.11780f },
  // TX_8X8
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
             0.10803f, 0.14124f },
  // TX_16X16
  (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
             0.10168f, 0.12585f },
  // TX_8X4
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
             0.10583f, 0.13123f },
  // TX_8X16
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
             0.10730f, 0.14221f },
  // TX_16X8
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
             0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
             0.10242f, 0.12878f },
  // TX_16X4
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
             0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};
1391
get_adaptive_thresholds(TX_SIZE tx_size,TxSetType tx_set_type,TX_TYPE_PRUNE_MODE prune_2d_txfm_mode)1392 static inline float get_adaptive_thresholds(
1393 TX_SIZE tx_size, TxSetType tx_set_type,
1394 TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
1395 const int prune_aggr_table[5][2] = {
1396 { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
1397 };
1398 int pruning_aggressiveness = 0;
1399 if (tx_set_type == EXT_TX_SET_ALL16)
1400 pruning_aggressiveness =
1401 prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0];
1402 else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
1403 pruning_aggressiveness =
1404 prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1];
1405
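// Example: with EXT_TX_SET_ALL16 and TX_TYPE_PRUNE_3 (assuming consecutive
// TX_TYPE_PRUNE_* enum values), pruning_aggressiveness is 9, so the threshold
// at index 9 is returned, which per the calibration comment above prunes
// roughly 10 TX types on average.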
1406 return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
1407 }
1408
1409 static inline void get_energy_distribution_finer(const int16_t *diff,
1410 int stride, int bw, int bh,
1411 float *hordist,
1412 float *verdist) {
1413 // First compute downscaled block energy values (esq); downscale factors
1414 // are defined by w_shift and h_shift.
1415 unsigned int esq[256];
1416 const int w_shift = bw <= 8 ? 0 : 1;
1417 const int h_shift = bh <= 8 ? 0 : 1;
1418 const int esq_w = bw >> w_shift;
1419 const int esq_h = bh >> h_shift;
1420 const int esq_sz = esq_w * esq_h;
1421 int i, j;
1422 memset(esq, 0, esq_sz * sizeof(esq[0]));
1423 if (w_shift) {
1424 for (i = 0; i < bh; i++) {
1425 unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1426 const int16_t *cur_diff_row = diff + i * stride;
1427 for (j = 0; j < bw; j += 2) {
1428 cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
1429 cur_diff_row[j + 1] * cur_diff_row[j + 1]);
1430 }
1431 }
1432 } else {
1433 for (i = 0; i < bh; i++) {
1434 unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1435 const int16_t *cur_diff_row = diff + i * stride;
1436 for (j = 0; j < bw; j++) {
1437 cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
1438 }
1439 }
1440 }
1441
1442 uint64_t total = 0;
1443 for (i = 0; i < esq_sz; i++) total += esq[i];
1444
1445 // Output hordist and verdist arrays are normalized 1D projections of esq
1446 if (total == 0) {
1447 float hor_val = 1.0f / esq_w;
1448 for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
1449 float ver_val = 1.0f / esq_h;
1450 for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
1451 return;
1452 }
1453
1454 const float e_recip = 1.0f / (float)total;
1455 memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
1456 memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
1457 const unsigned int *cur_esq_row;
1458 for (i = 0; i < esq_h - 1; i++) {
1459 cur_esq_row = esq + i * esq_w;
1460 for (j = 0; j < esq_w - 1; j++) {
1461 hordist[j] += (float)cur_esq_row[j];
1462 verdist[i] += (float)cur_esq_row[j];
1463 }
1464 verdist[i] += (float)cur_esq_row[j];
1465 }
1466 cur_esq_row = esq + i * esq_w;
1467 for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
1468
1469 for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
1470 for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
1471 }
1472
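// The tx type masks used below hold one bit per TX_TYPE (bit i corresponds to
// tx type i in the TX_TYPE enum), so a uint16_t covers all 16 types; e.g.
// (uint16_t)1 << DCT_DCT has only the DCT_DCT bit set.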
1473 static inline bool check_bit_mask(uint16_t mask, int val) {
1474 return mask & (1 << val);
1475 }
1476
1477 static inline void set_bit_mask(uint16_t *mask, int val) {
1478 *mask |= (1 << val);
1479 }
1480
1481 static inline void unset_bit_mask(uint16_t *mask, int val) {
1482 *mask &= ~(1 << val);
1483 }
1484
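// Prunes 2D tx types using two small neural networks: one scores the 4
// horizontal 1D transform classes and the other the 4 vertical classes from
// the residual's energy distribution. The 16 joint scores are formed as
// products of the two, softmax-normalized, and thresholded to build the
// allowed-tx-type bit mask and the sorted txk_map search order.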
1485 static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
1486 int blk_row, int blk_col, TxSetType tx_set_type,
1487 TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
1488 uint16_t *allowed_tx_mask) {
1489 // This table is used because the search order is different from the enum
1490 // order.
1491 static const int tx_type_table_2D[16] = {
1492 DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
1493 ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
1494 FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1495 H_DCT, H_ADST, H_FLIPADST, IDTX
1496 };
1497 if (tx_set_type != EXT_TX_SET_ALL16 &&
1498 tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
1499 return;
1500 #if CONFIG_NN_V2
1501 NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1502 NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1503 #else
1504 const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1505 const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1506 #endif
1507 if (!nn_config_hor || !nn_config_ver) return; // Model not established yet.
1508
1509 float hfeatures[16], vfeatures[16];
1510 float hscores[4], vscores[4];
1511 float scores_2D_raw[16];
1512 const int bw = tx_size_wide[tx_size];
1513 const int bh = tx_size_high[tx_size];
1514 const int hfeatures_num = bw <= 8 ? bw : bw / 2;
1515 const int vfeatures_num = bh <= 8 ? bh : bh / 2;
1516 assert(hfeatures_num <= 16);
1517 assert(vfeatures_num <= 16);
1518
1519 const struct macroblock_plane *const p = &x->plane[0];
1520 const int diff_stride = block_size_wide[bsize];
1521 const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1522 get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
1523 vfeatures);
1524
1525 av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
1526 &hfeatures[hfeatures_num - 1],
1527 &vfeatures[vfeatures_num - 1]);
1528
1529 #if CONFIG_NN_V2
1530 av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
1531 av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
1532 #else
1533 av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
1534 av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
1535 #endif
1536
1537 for (int i = 0; i < 4; i++) {
1538 float *cur_scores_2D = scores_2D_raw + i * 4;
1539 cur_scores_2D[0] = vscores[i] * hscores[0];
1540 cur_scores_2D[1] = vscores[i] * hscores[1];
1541 cur_scores_2D[2] = vscores[i] * hscores[2];
1542 cur_scores_2D[3] = vscores[i] * hscores[3];
1543 }
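// scores_2D_raw[4 * i + j] is the joint (unnormalized) score of the tx type
// combining vertical 1D class i with horizontal 1D class j, matching the
// row/column layout of tx_type_table_2D above.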
1544
1545 assert(TX_TYPES == 16);
1546 // This version of the function only works when there are at most 16 classes.
1547 // If that ever changes, the optimization will need to be reworked, or
1548 // av1_nn_softmax used instead.
1549 av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);
1550
1551 const float score_thresh =
1552 get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);
1553
1554 // Always keep the TX type with the highest score; prune all others with
1555 // scores below score_thresh.
1556 int max_score_i = 0;
1557 float max_score = 0.0f;
1558 uint16_t allow_bitmask = 0;
1559 float sum_score = 0.0;
1560 // Calculate the sum of the allowed tx type scores and populate the allow
1561 // bit mask based on score_thresh and allowed_tx_mask.
1562 int allow_count = 0;
1563 int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1564 TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1565 TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1566 TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1567 TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
1568 TX_TYPE_INVALID };
1569 float scores_2D[16] = {
1570 -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
1571 };
1572 for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
1573 const int allow_tx_type =
1574 check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
1575 if (!allow_tx_type) {
1576 continue;
1577 }
1578 if (scores_2D_raw[tx_idx] > max_score) {
1579 max_score = scores_2D_raw[tx_idx];
1580 max_score_i = tx_idx;
1581 }
1582 if (scores_2D_raw[tx_idx] >= score_thresh) {
1583 // Set allow mask based on score_thresh
1584 set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);
1585
1586 // Accumulate score of allowed tx type
1587 sum_score += scores_2D_raw[tx_idx];
1588
1589 scores_2D[allow_count] = scores_2D_raw[tx_idx];
1590 tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
1591 allow_count += 1;
1592 }
1593 }
1594 if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
1595 // If even the tx_type with max score is pruned, this means that no other
1596 // tx_type is feasible. When this happens, we force enable max_score_i and
1597 // end the search.
1598 set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
1599 memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
1600 *allowed_tx_mask = allow_bitmask;
1601 return;
1602 }
1603
1604 // Sort tx type probability of all types
1605 if (allow_count <= 8) {
1606 av1_sort_fi32_8(scores_2D, tx_type_allowed);
1607 } else {
1608 av1_sort_fi32_16(scores_2D, tx_type_allowed);
1609 }
1610
1611 // Enable more pruning based on tx type probability and number of allowed tx
1612 // types
1613 if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
1614 float temp_score = 0.0;
1615 float score_ratio = 0.0;
1616 int tx_idx, tx_count = 0;
1617 const float inv_sum_score = 100 / sum_score;
1618 // Get allowed tx types based on sorted probability score and tx count
1619 for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
1620 // Stop allowing more tx types once the cumulative probability exceeds
1621 // 30% and at least 2 tx types have already been allowed.
1622 if (score_ratio > 30.0 && tx_count >= 2) break;
1623
1624 assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
1625 // Calculate cumulative probability
1626 temp_score += scores_2D[tx_idx];
1627
1628 // Calculate percentage of cumulative probability of allowed tx type
1629 score_ratio = temp_score * inv_sum_score;
1630 tx_count++;
1631 }
1632 // Set remaining tx types as pruned
1633 for (; tx_idx < allow_count; tx_idx++)
1634 unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
1635 }
1636
1637 memcpy(txk_map, tx_type_allowed, sizeof(tx_type_table_2D));
1638 *allowed_tx_mask = allow_bitmask;
1639 }
1640
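// Returns the standard deviation given the mean and the sum of squares over
// 'num' samples: sqrt(E[x^2] - mean^2), clamped to 0 to guard against a
// slightly negative difference caused by floating-point rounding.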
1641 static float get_dev(float mean, double x2_sum, int num) {
1642 const float e_x2 = (float)(x2_sum / num);
1643 const float diff = e_x2 - mean * mean;
1644 const float dev = (diff > 0) ? sqrtf(diff) : 0;
1645 return dev;
1646 }
1647
1648 // Writes the features required by the ML model to predict tx split based on
1649 // mean and standard deviation values of the block and sub-blocks.
1650 // Returns the number of elements written to the output array, which is at
1651 // most 12 currently. Hence the 'features' buffer should be able to hold at
1652 // least 12 elements.
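// Expected layout (e.g. for a square block split into 4 sub-blocks):
//   features[0..1] = mean and deviation of the whole block
//   features[2..9] = (mean, deviation) pair of each sub-block
//   features[10]   = deviation of the sub-block means
//   features[11]   = mean of the sub-block deviations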
1653 static inline int get_mean_dev_features(const int16_t *data, int stride, int bw,
1654 int bh, float *features) {
1655 const int16_t *const data_ptr = &data[0];
1656 const int subh = (bh >= bw) ? (bh >> 1) : bh;
1657 const int subw = (bw >= bh) ? (bw >> 1) : bw;
1658 const int num = bw * bh;
1659 const int sub_num = subw * subh;
1660 int feature_idx = 2;
1661 int total_x_sum = 0;
1662 int64_t total_x2_sum = 0;
1663 int num_sub_blks = 0;
1664 double mean2_sum = 0.0f;
1665 float dev_sum = 0.0f;
1666
1667 for (int row = 0; row < bh; row += subh) {
1668 for (int col = 0; col < bw; col += subw) {
1669 int x_sum;
1670 int64_t x2_sum;
1671 // TODO(any): Write a SIMD version. Clear registers.
1672 aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
1673 &x_sum, &x2_sum);
1674 total_x_sum += x_sum;
1675 total_x2_sum += x2_sum;
1676
1677 const float mean = (float)x_sum / sub_num;
1678 const float dev = get_dev(mean, (double)x2_sum, sub_num);
1679 features[feature_idx++] = mean;
1680 features[feature_idx++] = dev;
1681 mean2_sum += (double)(mean * mean);
1682 dev_sum += dev;
1683 num_sub_blks++;
1684 }
1685 }
1686
1687 const float lvl0_mean = (float)total_x_sum / num;
1688 features[0] = lvl0_mean;
1689 features[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);
1690
1691 // Deviation of means.
1692 features[feature_idx++] = get_dev(lvl0_mean, mean2_sum, num_sub_blks);
1693 // Mean of deviations.
1694 features[feature_idx++] = dev_sum / num_sub_blks;
1695
1696 return feature_idx;
1697 }
1698
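// Scores whether the current transform block should be split using a NN
// model. Returns the raw model score scaled by 10000 and clamped to +/-80000,
// or -1 when no model exists for this tx_size. As used in select_tx_block(),
// scores below a negative threshold disable the split search, i.e. higher
// scores favor splitting.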
1699 static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
1700 int blk_col, TX_SIZE tx_size) {
1701 const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
1702 if (!nn_config) return -1;
1703
1704 const int diff_stride = block_size_wide[bsize];
1705 const int16_t *diff =
1706 x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1707 const int bw = tx_size_wide[tx_size];
1708 const int bh = tx_size_high[tx_size];
1709
1710 float features[64] = { 0.0f };
1711 get_mean_dev_features(diff, diff_stride, bw, bh, features);
1712
1713 float score = 0.0f;
1714 av1_nn_predict(features, nn_config, 1, &score);
1715
1716 int int_score = (int)(score * 10000);
1717 return clamp(int_score, -80000, 80000);
1718 }
1719
1720 static inline uint16_t get_tx_mask(
1721 const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block, int blk_row,
1722 int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1723 const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
1724 int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
1725 const AV1_COMMON *cm = &cpi->common;
1726 MACROBLOCKD *xd = &x->e_mbd;
1727 MB_MODE_INFO *mbmi = xd->mi[0];
1728 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
1729 const int is_inter = is_inter_block(mbmi);
1730 const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
1731 // If txk_allowed == TX_TYPES, more than one tx type is allowed; otherwise
1732 // (txk_allowed < TX_TYPES), only that specific tx type is allowed.
1733 TX_TYPE txk_allowed = TX_TYPES;
1734
1735 const FRAME_UPDATE_TYPE update_type =
1736 get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
1737 int use_actual_frame_probs = 1;
1738 const int *tx_type_probs;
1739 #if CONFIG_FPMT_TEST
1740 use_actual_frame_probs =
1741 (cpi->ppi->fpmt_unit_test_cfg == PARALLEL_SIMULATION_ENCODE) ? 0 : 1;
1742 if (!use_actual_frame_probs) {
1743 tx_type_probs =
1744 (int *)cpi->ppi->temp_frame_probs.tx_type_probs[update_type][tx_size];
1745 }
1746 #endif
1747 if (use_actual_frame_probs) {
1748 tx_type_probs = cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size];
1749 }
1750
1751 if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
1752 (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) {
1753 txk_allowed =
1754 get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
1755 } else if (is_inter &&
1756 txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) {
1757 if (tx_type_probs[DEFAULT_INTER_TX_TYPE] >
1758 txfm_params->default_inter_tx_type_prob_thresh) {
1759 txk_allowed = DEFAULT_INTER_TX_TYPE;
1760 } else {
1761 int force_tx_type = 0;
1762 int max_prob = 0;
1763 const int tx_type_prob_threshold =
1764 txfm_params->default_inter_tx_type_prob_thresh +
1765 PROB_THRESH_OFFSET_TX_TYPE;
1766 for (int i = 1; i < TX_TYPES; i++) { // find maximum probability.
1767 if (tx_type_probs[i] > max_prob) {
1768 max_prob = tx_type_probs[i];
1769 force_tx_type = i;
1770 }
1771 }
1772 if (max_prob > tx_type_prob_threshold) // force tx type with max prob.
1773 txk_allowed = force_tx_type;
1774 else if (x->rd_model == LOW_TXFM_RD) {
1775 if (plane == 0) txk_allowed = DCT_DCT;
1776 }
1777 }
1778 } else if (x->rd_model == LOW_TXFM_RD) {
1779 if (plane == 0) txk_allowed = DCT_DCT;
1780 }
1781
1782 const TxSetType tx_set_type = av1_get_ext_tx_set_type(
1783 tx_size, is_inter, cm->features.reduced_tx_set_used);
1784
1785 TX_TYPE uv_tx_type = DCT_DCT;
1786 if (plane) {
1787 // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
1788 uv_tx_type = txk_allowed =
1789 av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
1790 cm->features.reduced_tx_set_used);
1791 }
1792 PREDICTION_MODE intra_dir =
1793 mbmi->filter_intra_mode_info.use_filter_intra
1794 ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
1795 : mbmi->mode;
1796 uint16_t ext_tx_used_flag =
1797 cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 &&
1798 tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
1799 ? av1_reduced_intra_tx_used_flag[intra_dir]
1800 : av1_ext_tx_used_flag[tx_set_type];
1801
1802 if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2)
1803 ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir];
1804
1805 if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
1806 ext_tx_used_flag == 0x0001 ||
1807 (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
1808 (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
1809 txk_allowed = DCT_DCT;
1810 }
1811
1812 if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
1813 ext_tx_used_flag &= DCT_ADST_TX_MASK;
1814
1815 uint16_t allowed_tx_mask = 0; // 1: allow; 0: skip.
1816 if (txk_allowed < TX_TYPES) {
1817 allowed_tx_mask = 1 << txk_allowed;
1818 allowed_tx_mask &= ext_tx_used_flag;
1819 } else if (fast_tx_search) {
1820 allowed_tx_mask = 0x0c01; // V_DCT, H_DCT, DCT_DCT
1821 allowed_tx_mask &= ext_tx_used_flag;
1822 } else {
1823 assert(plane == 0);
1824 allowed_tx_mask = ext_tx_used_flag;
1825 int num_allowed = 0;
1826 int i;
1827
1828 if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
1829 static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
1830 { 10, 17, 17, 10, 17, 17, 17 } };
1831 const int thresh =
1832 thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
1833 [update_type];
1834 uint16_t prune = 0;
1835 int max_prob = -1;
1836 int max_idx = 0;
1837 for (i = 0; i < TX_TYPES; i++) {
1838 if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
1839 max_prob = tx_type_probs[i];
1840 max_idx = i;
1841 }
1842 if (tx_type_probs[i] < thresh) prune |= (1 << i);
1843 }
1844 if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
1845 allowed_tx_mask &= (~prune);
1846 }
1847 for (i = 0; i < TX_TYPES; i++) {
1848 if (allowed_tx_mask & (1 << i)) num_allowed++;
1849 }
1850 assert(num_allowed > 0);
1851
1852 if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
1853 int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
1854 int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
1855 if (num_allowed <= 7) {
1856 const uint16_t prune =
1857 prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
1858 plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
1859 cm->features.reduced_tx_set_used);
1860 allowed_tx_mask &= (~prune);
1861 } else {
1862 const int num_sel = (num_allowed * mf + 50) / 100;
1863 const uint16_t prune = prune_txk_type_separ(
1864 cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
1865 txk_map, allowed_tx_mask, pf, txb_ctx,
1866 cm->features.reduced_tx_set_used, ref_best_rd, num_sel);
1867
1868 allowed_tx_mask &= (~prune);
1869 }
1870 } else {
1871 assert(num_allowed > 0);
1872 int allowed_tx_count =
1873 (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
1874 // !fast_tx_search && txk_end != txk_start && plane == 0
1875 if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
1876 num_allowed > allowed_tx_count) {
1877 prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
1878 txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
1879 }
1880 }
1881 }
1882
1883 // Need to have at least one transform type allowed.
1884 if (allowed_tx_mask == 0) {
1885 txk_allowed = (plane ? uv_tx_type : DCT_DCT);
1886 allowed_tx_mask = (1 << txk_allowed);
1887 }
1888
1889 assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
1890 *allowed_txk_types = txk_allowed;
1891 return allowed_tx_mask;
1892 }
1893
1894 #if CONFIG_RD_DEBUG
1895 static inline void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
1896 int txb_coeff_cost) {
1897 rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;
1898 }
1899 #endif
1900
1901 static inline int cost_coeffs(MACROBLOCK *x, int plane, int block,
1902 TX_SIZE tx_size, const TX_TYPE tx_type,
1903 const TXB_CTX *const txb_ctx,
1904 int reduced_tx_set_used) {
1905 #if TXCOEFF_COST_TIMER
1906 struct aom_usec_timer timer;
1907 aom_usec_timer_start(&timer);
1908 #endif
1909 const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
1910 txb_ctx, reduced_tx_set_used);
1911 #if TXCOEFF_COST_TIMER
1912 AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
1913 aom_usec_timer_mark(&timer);
1914 const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
1915 tmp_cm->txcoeff_cost_timer += elapsed_time;
1916 ++tmp_cm->txcoeff_cost_count;
1917 #endif
1918 return cost;
1919 }
1920
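// Decides whether trellis coefficient optimization (optimize_b) can be
// skipped for this block based on its SATD: trellis is skipped when
//   satd > coeff_opt_satd_threshold * qstep * sqrt(num_pixels)
// (with satd normalized to 8-bit depth), and quant_param is reconfigured to
// match that decision.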
1921 static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
1922 QUANT_PARAM *quant_param, int plane,
1923 int block, TX_SIZE tx_size,
1924 int quant_b_adapt, int qstep,
1925 unsigned int coeff_opt_satd_threshold,
1926 int skip_trellis, int dc_only_blk) {
1927 if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX))
1928 return skip_trellis;
1929
1930 const struct macroblock_plane *const p = &x->plane[plane];
1931 const int block_offset = BLOCK_OFFSET(block);
1932 tran_low_t *const coeff_ptr = p->coeff + block_offset;
1933 const int n_coeffs = av1_get_max_eob(tx_size);
1934 const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
1935 int satd = (dc_only_blk) ? abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs);
1936 satd = RIGHT_SIGNED_SHIFT(satd, shift);
1937 satd >>= (x->e_mbd.bd - 8);
1938
1939 const int skip_block_trellis =
1940 ((uint64_t)satd >
1941 (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]);
1942
1943 av1_setup_quant(
1944 tx_size, !skip_block_trellis,
1945 skip_block_trellis
1946 ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP)
1947 : AV1_XFORM_QUANT_FP,
1948 quant_b_adapt, quant_param);
1949
1950 return skip_block_trellis;
1951 }
1952
1953 // Predict DC only blocks if the residual variance is below a qstep-based
1954 // threshold. For such blocks, transform type search is bypassed.
1955 static inline void predict_dc_only_block(
1956 MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1957 int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
1958 int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
1959 int *dc_only_blk) {
1960 MACROBLOCKD *xd = &x->e_mbd;
1961 MB_MODE_INFO *mbmi = xd->mi[0];
1962 const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
1963 const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
1964 uint64_t block_var = UINT64_MAX;
1965 const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
1966 *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
1967 txsize_to_bsize[tx_size], block_mse_q8,
1968 per_px_mean, &block_var);
1969 assert((*block_mse_q8) != UINT_MAX);
1970 uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
1971 if (is_cur_buf_hbd(xd))
1972 block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);
1973
1974 if (block_var >= var_threshold) return;
1975 const unsigned int predict_dc_level = x->txfm_search_params.predict_dc_level;
1976 assert(predict_dc_level != 0);
1977
1978 // Predict a skip block if the residual mean and variance are less than
1979 // qstep-based thresholds.
1980 if ((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) {
1981 // If the normalized mean of residual block is less than the dc qstep and
1982 // the normalized block variance is less than ac qstep, then the block is
1983 // assumed to be a skip block and its rdcost is updated accordingly.
1984 best_rd_stats->skip_txfm = 1;
1985
1986 x->plane[plane].eobs[block] = 0;
1987
1988 if (is_cur_buf_hbd(xd))
1989 *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);
1990
1991 best_rd_stats->dist = (*block_sse) << 4;
1992 best_rd_stats->sse = best_rd_stats->dist;
1993
1994 ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
1995 ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
1996 av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
1997 ENTROPY_CONTEXT *ta = ctxa;
1998 ENTROPY_CONTEXT *tl = ctxl;
1999 const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2000 TXB_CTX txb_ctx_tmp;
2001 const PLANE_TYPE plane_type = get_plane_type(plane);
2002 get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
2003 const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
2004 .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
2005 best_rd_stats->rate = zero_blk_rate;
2006
2007 best_rd_stats->rdcost =
2008 RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);
2009
2010 x->plane[plane].txb_entropy_ctx[block] = 0;
2011 } else if (predict_dc_level > 1) {
2012 // Predict DC only blocks based on residual variance.
2013 // For chroma plane, this prediction is disabled for intra blocks.
2014 if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
2015 }
2016 }
2017
2018 // Search for the best transform type for a given transform block.
2019 // This function can be used for both inter and intra, both luma and chroma.
2020 static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
2021 int block, int blk_row, int blk_col,
2022 BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2023 const TXB_CTX *const txb_ctx,
2024 FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
2025 int64_t ref_best_rd, RD_STATS *best_rd_stats) {
2026 const AV1_COMMON *cm = &cpi->common;
2027 MACROBLOCKD *xd = &x->e_mbd;
2028 MB_MODE_INFO *mbmi = xd->mi[0];
2029 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2030 int64_t best_rd = INT64_MAX;
2031 uint16_t best_eob = 0;
2032 TX_TYPE best_tx_type = DCT_DCT;
2033 int rate_cost = 0;
2034 struct macroblock_plane *const p = &x->plane[plane];
2035 tran_low_t *orig_dqcoeff = p->dqcoeff;
2036 tran_low_t *best_dqcoeff = x->dqcoeff_buf;
2037 const int tx_type_map_idx =
2038 plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
2039 av1_invalid_rd_stats(best_rd_stats);
2040
2041 skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
2042 DRY_RUN_NORMAL);
2043
2044 uint8_t best_txb_ctx = 0;
2045 // txk_allowed = TX_TYPES: >1 tx types are allowed
2046 // txk_allowed < TX_TYPES: only that specific tx type is allowed.
2047 TX_TYPE txk_allowed = TX_TYPES;
2048 int txk_map[TX_TYPES] = {
2049 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
2050 };
2051 const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2052 const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2053
2054 const uint8_t txw = tx_size_wide[tx_size];
2055 const uint8_t txh = tx_size_high[tx_size];
2056 int64_t block_sse;
2057 unsigned int block_mse_q8;
2058 int dc_only_blk = 0;
2059 const bool predict_dc_block =
2060 txfm_params->predict_dc_level >= 1 && txw != 64 && txh != 64;
2061 int64_t per_px_mean = INT64_MAX;
2062 if (predict_dc_block) {
2063 predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
2064 blk_col, best_rd_stats, &block_sse, &block_mse_q8,
2065 &per_px_mean, &dc_only_blk);
2066 if (best_rd_stats->skip_txfm == 1) {
2067 const TX_TYPE tx_type = DCT_DCT;
2068 if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2069 return;
2070 }
2071 } else {
2072 block_sse = av1_pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
2073 txsize_to_bsize[tx_size], &block_mse_q8);
2074 assert(block_mse_q8 != UINT_MAX);
2075 }
2076
2077 // Bit mask to indicate which transform types are allowed in the RD search.
2078 uint16_t tx_mask;
2079
2080 // Use DCT_DCT transform for DC only block.
2081 if (dc_only_blk || cpi->sf.rt_sf.dct_only_palette_nonrd == 1)
2082 tx_mask = 1 << DCT_DCT;
2083 else
2084 tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
2085 tx_size, txb_ctx, ftxs_mode, ref_best_rd,
2086 &txk_allowed, txk_map);
2087 const uint16_t allowed_tx_mask = tx_mask;
2088
2089 if (is_cur_buf_hbd(xd)) {
2090 block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
2091 block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
2092 }
2093 block_sse *= 16;
2094 // Use mse / qstep^2 based threshold logic to decide whether to perform R-D
2095 // optimization of coeffs. For smaller residuals, coeff optimization tends
2096 // to be helpful; for larger residuals, R-D optimization may not be
2097 // effective.
2098 // TODO(any): Experiment with variance and mean based thresholds
2099 const int perform_block_coeff_opt =
2100 ((uint64_t)block_mse_q8 <=
2101 (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
2102 skip_trellis |= !perform_block_coeff_opt;
2103
2104 // Flag to indicate whether distortion should be calculated in the transform
2105 // domain while iterating through transform type candidates.
2106 // Transform domain distortion is accurate for higher residuals.
2107 // TODO(any): Experiment with variance and mean based thresholds
2108 int use_transform_domain_distortion =
2109 (txfm_params->use_transform_domain_distortion > 0) &&
2110 (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
2111 // Any 64-pt transforms only preserves half the coefficients.
2112 // Therefore transform domain distortion is not valid for these
2113 // transform sizes.
2114 (txsize_sqr_up_map[tx_size] != TX_64X64) &&
2115 // Use pixel domain distortion for DC only blocks
2116 !dc_only_blk;
2117 // Flag to indicate if an extra calculation of distortion in the pixel domain
2118 // should be performed at the end, after the best transform type has been
2119 // decided.
2120 int calc_pixel_domain_distortion_final =
2121 txfm_params->use_transform_domain_distortion == 1 &&
2122 use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
2123 if (calc_pixel_domain_distortion_final &&
2124 (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
2125 calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
2126
2127 const uint16_t *eobs_ptr = x->plane[plane].eobs;
2128
2129 TxfmParam txfm_param;
2130 QUANT_PARAM quant_param;
2131 int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
2132 av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
2133 av1_setup_quant(tx_size, !skip_trellis,
2134 skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
2135 : AV1_XFORM_QUANT_FP)
2136 : AV1_XFORM_QUANT_FP,
2137 cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
2138
2139 // Iterate through all transform type candidates.
2140 for (int idx = 0; idx < TX_TYPES; ++idx) {
2141 const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
2142 if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type))
2143 continue;
2144 txfm_param.tx_type = tx_type;
2145 if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
2146 av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
2147 &quant_param);
2148 }
2149 if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2150 RD_STATS this_rd_stats;
2151 av1_invalid_rd_stats(&this_rd_stats);
2152
2153 if (!dc_only_blk)
2154 av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
2155 else
2156 av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
2157
2158 skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
2159 x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
2160 qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);
2161
2162 av1_quant(x, plane, block, &txfm_param, &quant_param);
2163
2164 // Calculate rate cost of quantized coefficients.
2165 if (quant_param.use_optimize_b) {
2166 // TODO(aomedia:3209): update Trellis quantization to take into account
2167 // quantization matrices.
2168 av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
2169 &rate_cost);
2170 } else {
2171 rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
2172 cm->features.reduced_tx_set_used);
2173 }
2174
2175 // If rd cost based on coeff rate alone is already more than best_rd,
2176 // terminate early.
2177 if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
2178
2179 // Calculate distortion.
2180 if (eobs_ptr[block] == 0) {
2181 // When eob is 0, pixel domain distortion is more efficient and accurate.
2182 this_rd_stats.dist = this_rd_stats.sse = block_sse;
2183 } else if (dc_only_blk) {
2184 this_rd_stats.sse = block_sse;
2185 this_rd_stats.dist = dist_block_px_domain(
2186 cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2187 } else if (use_transform_domain_distortion) {
2188 const SCAN_ORDER *const scan_order =
2189 get_scan(txfm_param.tx_size, txfm_param.tx_type);
2190 dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2191 scan_order->scan, &this_rd_stats.dist,
2192 &this_rd_stats.sse);
2193 } else {
2194 int64_t sse_diff = INT64_MAX;
2195 // high_energy threshold assumes that every pixel within a txfm block
2196 // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
2197 // for 8 bit.
2198 const int64_t high_energy_thresh =
2199 ((int64_t)128 * 128 * tx_size_2d[tx_size]);
2200 const int is_high_energy = (block_sse >= high_energy_thresh);
2201 if (tx_size == TX_64X64 || is_high_energy) {
2202 // Because 3 out of 4 quadrants of transform coefficients are forced to
2203 // zero, the inverse transform has a tendency to overflow. sse_diff
2204 // is effectively the energy of those 3 quadrants, here we use it
2205 // to decide if we should do pixel domain distortion. If the energy
2206 // is mostly in first quadrant, then it is unlikely that we have
2207 // overflow issue in inverse transform.
2208 const SCAN_ORDER *const scan_order =
2209 get_scan(txfm_param.tx_size, txfm_param.tx_type);
2210 dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
2211 scan_order->scan, &this_rd_stats.dist,
2212 &this_rd_stats.sse);
2213 sse_diff = block_sse - this_rd_stats.sse;
2214 }
2215 if (tx_size != TX_64X64 || !is_high_energy ||
2216 (sse_diff * 2) < this_rd_stats.sse) {
2217 const int64_t tx_domain_dist = this_rd_stats.dist;
2218 this_rd_stats.dist = dist_block_px_domain(
2219 cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2220 // For high energy blocks, occasionally, the pixel domain distortion
2221 // can be artificially low due to clamping at reconstruction stage
2222 // even when inverse transform output is hugely different from the
2223 // actual residue.
2224 if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
2225 this_rd_stats.dist = tx_domain_dist;
2226 } else {
2227 assert(sse_diff < INT64_MAX);
2228 this_rd_stats.dist += sse_diff;
2229 }
2230 this_rd_stats.sse = block_sse;
2231 }
2232
2233 this_rd_stats.rate = rate_cost;
2234
2235 const int64_t rd =
2236 RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2237
2238 if (rd < best_rd) {
2239 best_rd = rd;
2240 *best_rd_stats = this_rd_stats;
2241 best_tx_type = tx_type;
2242 best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
2243 best_eob = x->plane[plane].eobs[block];
2244 // Swap dqcoeff buffers
2245 tran_low_t *const tmp_dqcoeff = best_dqcoeff;
2246 best_dqcoeff = p->dqcoeff;
2247 p->dqcoeff = tmp_dqcoeff;
2248 }
2249
2250 #if CONFIG_COLLECT_RD_STATS == 1
2251 if (plane == 0) {
2252 PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
2253 plane_bsize, tx_size, tx_type, rd);
2254 }
2255 #endif // CONFIG_COLLECT_RD_STATS == 1
2256
2257 #if COLLECT_TX_SIZE_DATA
2258 // Generate small sample to restrict output size.
2259 static unsigned int seed = 21743;
2260 if (lcg_rand16(&seed) % 200 == 0) {
2261 FILE *fp = NULL;
2262
2263 if (within_border) {
2264 fp = fopen(av1_tx_size_data_output_file, "a");
2265 }
2266
2267 if (fp) {
2268 // Transform info and RD
2269 const int txb_w = tx_size_wide[tx_size];
2270 const int txb_h = tx_size_high[tx_size];
2271
2272 // Residue signal.
2273 const int diff_stride = block_size_wide[plane_bsize];
2274 struct macroblock_plane *const p = &x->plane[plane];
2275 const int16_t *src_diff =
2276 &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
2277
2278 for (int r = 0; r < txb_h; ++r) {
2279 for (int c = 0; c < txb_w; ++c) {
2280 fprintf(fp, "%d,", src_diff[c]);
2281 }
2282 src_diff += diff_stride;
2283 }
2284
2285 fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
2286 fprintf(fp, "\n");
2287 fclose(fp);
2288 }
2289 }
2290 #endif // COLLECT_TX_SIZE_DATA
2291
2292 // If the current best RD cost is much worse than the reference RD cost,
2293 // terminate early.
2294 if (cpi->sf.tx_sf.adaptive_txb_search_level) {
2295 if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
2296 ref_best_rd) {
2297 break;
2298 }
2299 }
2300
2301 // Terminate transform type search if the block has been quantized to
2302 // all zero.
2303 if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
2304 }
2305
2306 assert(best_rd != INT64_MAX);
2307
2308 best_rd_stats->skip_txfm = best_eob == 0;
2309 if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2310 x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
2311 x->plane[plane].eobs[block] = best_eob;
2312 skip_trellis = skip_trellis_based_on_satd[best_tx_type];
2313
2314 // Point dqcoeff to the quantized coefficients corresponding to the best
2315 // transform type, then we can skip transform and quantization, e.g. in the
2316 // final pixel domain distortion calculation and recon_intra().
2317 p->dqcoeff = best_dqcoeff;
2318
2319 if (calc_pixel_domain_distortion_final && best_eob) {
2320 best_rd_stats->dist = dist_block_px_domain(
2321 cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2322 best_rd_stats->sse = block_sse;
2323 }
2324
2325 // Intra mode needs decoded pixels such that the next transform block
2326 // can use them for prediction.
2327 recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2328 txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
2329 p->dqcoeff = orig_dqcoeff;
2330 }
2331
2332 // Pick transform type for a luma transform block of tx_size. Note this function
2333 // is used only for inter-predicted blocks.
2334 static inline void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
2335 TX_SIZE tx_size, int blk_row, int blk_col,
2336 int block, int plane_bsize, TXB_CTX *txb_ctx,
2337 RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode,
2338 int64_t ref_rdcost) {
2339 assert(is_inter_block(x->e_mbd.mi[0]));
2340 RD_STATS this_rd_stats;
2341 const int skip_trellis = 0;
2342 search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
2343 txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats);
2344
2345 av1_merge_rd_stats(rd_stats, &this_rd_stats);
2346 }
2347
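// Evaluates coding the current block as a single transform block without
// splitting: searches the best tx type, forces an all-zero (skip) block when
// that is cheaper in RD terms, and records the resulting RD cost, entropy
// context and tx type in 'no_split'.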
2348 static inline void try_tx_block_no_split(
2349 const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2350 TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
2351 const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
2352 int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
2353 FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) {
2354 MACROBLOCKD *const xd = &x->e_mbd;
2355 MB_MODE_INFO *const mbmi = xd->mi[0];
2356 struct macroblock_plane *const p = &x->plane[0];
2357 const int bw = mi_size_wide[plane_bsize];
2358 const ENTROPY_CONTEXT *const pta = ta + blk_col;
2359 const ENTROPY_CONTEXT *const ptl = tl + blk_row;
2360 const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2361 TXB_CTX txb_ctx;
2362 get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
2363 const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
2364 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
2365 rd_stats->zero_rate = zero_blk_rate;
2366 const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
2367 mbmi->inter_tx_size[index] = tx_size;
2368 tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
2369 rd_stats, ftxs_mode, ref_best_rd);
2370 assert(rd_stats->rate < INT_MAX);
2371
2372 const int pick_skip_txfm =
2373 !xd->lossless[mbmi->segment_id] &&
2374 (rd_stats->skip_txfm == 1 ||
2375 RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
2376 RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
2377 if (pick_skip_txfm) {
2378 #if CONFIG_RD_DEBUG
2379 update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate);
2380 #endif // CONFIG_RD_DEBUG
2381 rd_stats->rate = zero_blk_rate;
2382 rd_stats->dist = rd_stats->sse;
2383 p->eobs[block] = 0;
2384 update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
2385 }
2386 rd_stats->skip_txfm = pick_skip_txfm;
2387 set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2388 pick_skip_txfm);
2389
2390 if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
2391 rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];
2392
2393 no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
2394 no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
2395 no_split->tx_type =
2396 xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
2397 }
2398
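// Evaluates splitting the current transform block into sub-blocks of size
// sub_tx_size_map[tx_size] (four, or two for some rectangular sizes),
// recursing via select_tx_block(). The accumulated rate starts with the
// partition-split signaling cost, and the search aborts with
// rdcost == INT64_MAX as soon as ref_best_rd is exceeded.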
2399 static inline void try_tx_block_split(
2400 const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2401 TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2402 ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2403 int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
2404 FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) {
2405 assert(tx_size < TX_SIZES_ALL);
2406 MACROBLOCKD *const xd = &x->e_mbd;
2407 const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
2408 const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
2409 const int txb_width = tx_size_wide_unit[tx_size];
2410 const int txb_height = tx_size_high_unit[tx_size];
2411 // Transform size after splitting current block.
2412 const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
2413 const int sub_txb_width = tx_size_wide_unit[sub_txs];
2414 const int sub_txb_height = tx_size_high_unit[sub_txs];
2415 const int sub_step = sub_txb_width * sub_txb_height;
2416 const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
2417 assert(nblks > 0);
2418 av1_init_rd_stats(split_rd_stats);
2419 split_rd_stats->rate =
2420 x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];
2421
2422 for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
2423 const int offsetr = blk_row + r;
2424 if (offsetr >= max_blocks_high) break;
2425 for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
2426 assert(blk_idx < 4);
2427 const int offsetc = blk_col + c;
2428 if (offsetc >= max_blocks_wide) continue;
2429
2430 RD_STATS this_rd_stats;
2431 int this_cost_valid = 1;
2432 select_tx_block(cpi, x, offsetr, offsetc, block, sub_txs, depth + 1,
2433 plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
2434 no_split_rd / nblks, ref_best_rd - split_rd_stats->rdcost,
2435 &this_cost_valid, ftxs_mode);
2436 if (!this_cost_valid) {
2437 split_rd_stats->rdcost = INT64_MAX;
2438 return;
2439 }
2440 av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
2441 split_rd_stats->rdcost =
2442 RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
2443 if (split_rd_stats->rdcost > ref_best_rd) {
2444 split_rd_stats->rdcost = INT64_MAX;
2445 return;
2446 }
2447 block += sub_step;
2448 }
2449 }
2450 }
2451
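// Returns the (biased) variance E[x^2] - mean^2. Unlike get_dev(), no sqrt
// and no clamping is applied, so a slightly negative result is possible due
// to floating-point rounding.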
2452 static float get_var(float mean, double x2_sum, int num) {
2453 const float e_x2 = (float)(x2_sum / num);
2454 const float diff = e_x2 - mean * mean;
2455 return diff;
2456 }
2457
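// Computes, over the (up to 4) sub-blocks plus the whole block, the deviation
// of the block means and the variance of the block variances. These feed the
// tx split/no-split pruning decision in prune_tx_split_no_split().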
2458 static inline void get_blk_var_dev(const int16_t *data, int stride, int bw,
2459 int bh, float *dev_of_mean,
2460 float *var_of_vars) {
2461 const int16_t *const data_ptr = &data[0];
2462 const int subh = (bh >= bw) ? (bh >> 1) : bh;
2463 const int subw = (bw >= bh) ? (bw >> 1) : bw;
2464 const int num = bw * bh;
2465 const int sub_num = subw * subh;
2466 int total_x_sum = 0;
2467 int64_t total_x2_sum = 0;
2468 int blk_idx = 0;
2469 float var_sum = 0.0f;
2470 float mean_sum = 0.0f;
2471 double var2_sum = 0.0f;
2472 double mean2_sum = 0.0f;
2473
2474 for (int row = 0; row < bh; row += subh) {
2475 for (int col = 0; col < bw; col += subw) {
2476 int x_sum;
2477 int64_t x2_sum;
2478 aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
2479 &x_sum, &x2_sum);
2480 total_x_sum += x_sum;
2481 total_x2_sum += x2_sum;
2482
2483 const float mean = (float)x_sum / sub_num;
2484 const float var = get_var(mean, (double)x2_sum, sub_num);
2485 mean_sum += mean;
2486 mean2_sum += (double)(mean * mean);
2487 var_sum += var;
2488 var2_sum += var * var;
2489 blk_idx++;
2490 }
2491 }
2492
2493 const float lvl0_mean = (float)total_x_sum / num;
2494 const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
2495 mean_sum += lvl0_mean;
2496 mean2_sum += (double)(lvl0_mean * lvl0_mean);
2497 var_sum += block_var;
2498 var2_sum += block_var * block_var;
2499 const float av_mean = mean_sum / 5;
2500
2501 if (blk_idx > 1) {
2502 // Deviation of means.
2503 *dev_of_mean = get_dev(av_mean, mean2_sum, (blk_idx + 1));
2504 // Variance of variances.
2505 const float mean_var = var_sum / (blk_idx + 1);
2506 *var_of_vars = get_var(mean_var, var2_sum, (blk_idx + 1));
2507 }
2508 }
2509
2510 static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
2511 int blk_row, int blk_col, TX_SIZE tx_size,
2512 int *try_no_split, int *try_split,
2513 int pruning_level) {
2514 const int diff_stride = block_size_wide[bsize];
2515 const int16_t *diff =
2516 x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
2517 const int bw = tx_size_wide[tx_size];
2518 const int bh = tx_size_high[tx_size];
2519 float dev_of_means = 0.0f;
2520 float var_of_vars = 0.0f;
2521
2522 // This function calculates the deviation of means and the variance of pixel
2523 // variances of the block as well as its sub-blocks.
2524 get_blk_var_dev(diff, diff_stride, bw, bh, &dev_of_means, &var_of_vars);
2525 const int dc_q = x->plane[0].dequant_QTX[0] >> 3;
2526 const int ac_q = x->plane[0].dequant_QTX[1] >> 3;
2527 const int no_split_thresh_scales[4] = { 0, 24, 8, 8 };
2528 const int no_split_thresh_scale = no_split_thresh_scales[pruning_level];
2529 const int split_thresh_scales[4] = { 0, 24, 10, 8 };
2530 const int split_thresh_scale = split_thresh_scales[pruning_level];
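// Example: pruning_level 2 selects no_split_thresh_scale 8 and
// split_thresh_scale 10; larger scales make both pruning conditions below
// harder to satisfy, i.e. prune more conservatively.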
2531
2532 if ((dev_of_means <= dc_q) &&
2533 (split_thresh_scale * var_of_vars <= ac_q * ac_q)) {
2534 *try_split = 0;
2535 }
2536 if ((dev_of_means > no_split_thresh_scale * dc_q) &&
2537 (var_of_vars > no_split_thresh_scale * ac_q * ac_q)) {
2538 *try_no_split = 0;
2539 }
2540 }
2541
2542 // Search for the best transform partition (recursive) and type for a given
2543 // inter-predicted luma block. The selected transform partition and types are
2544 // saved in xd->mi[0]; the corresponding RD stats are saved in rd_stats.
2545 static inline void select_tx_block(
2546 const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2547 TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2548 ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2549 RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
2550 int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode) {
2551 assert(tx_size < TX_SIZES_ALL);
2552 av1_init_rd_stats(rd_stats);
2553 if (ref_best_rd < 0) {
2554 *is_cost_valid = 0;
2555 return;
2556 }
2557
2558 MACROBLOCKD *const xd = &x->e_mbd;
2559 assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
2560 blk_col < max_block_wide(xd, plane_bsize, 0));
2561 MB_MODE_INFO *const mbmi = xd->mi[0];
2562 const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
2563 mbmi->bsize, tx_size);
2564 struct macroblock_plane *const p = &x->plane[0];
2565
2566 int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
2567 txsize_sqr_up_map[tx_size] != TX_64X64) &&
2568 (cpi->oxcf.txfm_cfg.enable_rect_tx ||
2569 tx_size_wide[tx_size] == tx_size_high[tx_size]);
2570 int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
2571 TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
2572
2573 // Prune tx_split and no-split based on sub-block properties.
2574 if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
2575 cpi->sf.tx_sf.prune_tx_size_level > 0) {
2576 prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
2577 &try_no_split, &try_split,
2578 cpi->sf.tx_sf.prune_tx_size_level);
2579 }
2580
2581 if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) {
2582 if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0;
2583 }
2584
2585 // Try using current block as a single transform block without split.
2586 if (try_no_split) {
2587 try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2588 plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
2589 ftxs_mode, &no_split);
2590
2591 // Speed features for early termination.
2592 const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
2593 if (search_level) {
2594 if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
2595 *is_cost_valid = 0;
2596 return;
2597 }
2598 if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
2599 try_split = 0;
2600 }
2601 }
2602 if (cpi->sf.tx_sf.txb_split_cap) {
2603 if (p->eobs[block] == 0) try_split = 0;
2604 }
2605 }
2606
2607 // ML based speed feature to skip searching for split transform blocks.
2608 if (x->e_mbd.bd == 8 && try_split &&
2609 !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
2610 const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
2611 if (threshold >= 0) {
2612 const int split_score =
2613 ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
2614 if (split_score < -threshold) try_split = 0;
2615 }
2616 }
2617
2618 RD_STATS split_rd_stats;
2619 split_rd_stats.rdcost = INT64_MAX;
2620 // Try splitting current block into smaller transform blocks.
2621 if (try_split) {
2622 try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2623 plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
2624 AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
2625 &split_rd_stats);
2626 }
2627
2628 if (no_split.rd < split_rd_stats.rdcost) {
2629 ENTROPY_CONTEXT *pta = ta + blk_col;
2630 ENTROPY_CONTEXT *ptl = tl + blk_row;
2631 p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
2632 av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
2633 txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
2634 tx_size);
2635 for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
2636 for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
2637 const int index =
2638 av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
2639 mbmi->inter_tx_size[index] = tx_size;
2640 }
2641 }
2642 mbmi->tx_size = tx_size;
2643 update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
2644 const int bw = mi_size_wide[plane_bsize];
2645 set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2646 rd_stats->skip_txfm);
2647 } else {
2648 *rd_stats = split_rd_stats;
2649 if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
2650 }
2651 }
2652
2653 static inline void choose_largest_tx_size(const AV1_COMP *const cpi,
2654 MACROBLOCK *x, RD_STATS *rd_stats,
2655 int64_t ref_best_rd, BLOCK_SIZE bs) {
2656 MACROBLOCKD *const xd = &x->e_mbd;
2657 MB_MODE_INFO *const mbmi = xd->mi[0];
2658 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2659 mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2660
2661 // If tx64 is not enabled, we need to go down to the next available size
2662 if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
2663 static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
2664 TX_4X4, // 4x4 transform
2665 TX_8X8, // 8x8 transform
2666 TX_16X16, // 16x16 transform
2667 TX_32X32, // 32x32 transform
2668 TX_32X32, // 64x64 transform
2669 TX_4X8, // 4x8 transform
2670 TX_8X4, // 8x4 transform
2671 TX_8X16, // 8x16 transform
2672 TX_16X8, // 16x8 transform
2673 TX_16X32, // 16x32 transform
2674 TX_32X16, // 32x16 transform
2675 TX_32X32, // 32x64 transform
2676 TX_32X32, // 64x32 transform
2677 TX_4X16, // 4x16 transform
2678 TX_16X4, // 16x4 transform
2679 TX_8X32, // 8x32 transform
2680 TX_32X8, // 32x8 transform
2681 TX_16X32, // 16x64 transform
2682 TX_32X16, // 64x16 transform
2683 };
2684 mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
2685 } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
2686 !cpi->oxcf.txfm_cfg.enable_rect_tx) {
2687 static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
2688 TX_4X4, // 4x4 transform
2689 TX_8X8, // 8x8 transform
2690 TX_16X16, // 16x16 transform
2691 TX_32X32, // 32x32 transform
2692 TX_64X64, // 64x64 transform
2693 TX_4X4, // 4x8 transform
2694 TX_4X4, // 8x4 transform
2695 TX_8X8, // 8x16 transform
2696 TX_8X8, // 16x8 transform
2697 TX_16X16, // 16x32 transform
2698 TX_16X16, // 32x16 transform
2699 TX_32X32, // 32x64 transform
2700 TX_32X32, // 64x32 transform
2701 TX_4X4, // 4x16 transform
2702 TX_4X4, // 16x4 transform
2703 TX_8X8, // 8x32 transform
2704 TX_8X8, // 32x8 transform
2705 TX_16X16, // 16x64 transform
2706 TX_16X16, // 64x16 transform
2707 };
2708 mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
2709 } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
2710 !cpi->oxcf.txfm_cfg.enable_rect_tx) {
2711 static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
2712 TX_4X4, // 4x4 transform
2713 TX_8X8, // 8x8 transform
2714 TX_16X16, // 16x16 transform
2715 TX_32X32, // 32x32 transform
2716 TX_32X32, // 64x64 transform
2717 TX_4X4, // 4x8 transform
2718 TX_4X4, // 8x4 transform
2719 TX_8X8, // 8x16 transform
2720 TX_8X8, // 16x8 transform
2721 TX_16X16, // 16x32 transform
2722 TX_16X16, // 32x16 transform
2723 TX_32X32, // 32x64 transform
2724 TX_32X32, // 64x32 transform
2725 TX_4X4, // 4x16 transform
2726 TX_4X4, // 16x4 transform
2727 TX_8X8, // 8x32 transform
2728 TX_8X8, // 32x8 transform
2729 TX_16X16, // 16x64 transform
2730 TX_16X16, // 64x16 transform
2731 };
2732
2733 mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
2734 }
2735
2736 const int skip_ctx = av1_get_skip_txfm_context(xd);
2737 const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
2738 const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
2739 // The skip rd cost is used only for inter blocks.
2740 const int64_t skip_txfm_rd =
2741 is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
2742 const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0);
2743 const int skip_trellis = 0;
2744 av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2745 AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
2746 mbmi->tx_size, FTXS_NONE, skip_trellis);
2747 }
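
// The AOMMIN(no_skip_txfm_rd, skip_txfm_rd) seed above hands the per-block
// search a starting budget equal to the cheaper of the two header choices, so
// that early exit can fire as soon as the accumulated cost passes ref_best_rd.
// A minimal standalone sketch of that seeding; toy_rdcost() is a simplified
// stand-in for the RDCOST macro, not the libaom definition:
#if 0
#include <stdint.h>

// Toy Lagrangian cost: lambda * rate + distortion, with illustrative
// fixed-point constants.
static int64_t toy_rdcost(int64_t rdmult, int rate, int64_t dist) {
  return ((rdmult * rate) >> 9) + dist * 128;
}

static int64_t header_cost_seed(int64_t rdmult, int skip_rate,
                                int no_skip_rate, int is_inter) {
  // Intra blocks are always coded as non-skip, so skip is never a choice.
  const int64_t skip_rd =
      is_inter ? toy_rdcost(rdmult, skip_rate, 0) : INT64_MAX;
  const int64_t no_skip_rd = toy_rdcost(rdmult, no_skip_rate, 0);
  return skip_rd < no_skip_rd ? skip_rd : no_skip_rd;
}
#endif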
2748
2749 static inline void choose_smallest_tx_size(const AV1_COMP *const cpi,
2750 MACROBLOCK *x, RD_STATS *rd_stats,
2751 int64_t ref_best_rd, BLOCK_SIZE bs) {
2752 MACROBLOCKD *const xd = &x->e_mbd;
2753 MB_MODE_INFO *const mbmi = xd->mi[0];
2754
2755 mbmi->tx_size = TX_4X4;
2756 // TODO(any): Pass this_rd based on the skip/non-skip cost.
2757 const int skip_trellis = 0;
2758 av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
2759 FTXS_NONE, skip_trellis);
2760 }
2761
2762 #if !CONFIG_REALTIME_ONLY
2763 static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row,
2764 int blk_col, BLOCK_SIZE bsize,
2765 TX_SIZE tx_size) {
2766 const MACROBLOCKD *const xd = &x->e_mbd;
2767 const MB_MODE_INFO *const mbmi = xd->mi[0];
2768
2769 // Disable the NN-model-based pruning logic for the following cases:
2770 // 1) Lossless coding, as only the 4x4 transform is evaluated in this case.
2771 // 2) When the transform and current block sizes do not match, as the
2772 // features are obtained over the current block.
2773 // 3) When the operating bit-depth is not 8-bit, as the input features are
2774 // not scaled according to bit-depth.
2775 if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize ||
2776 xd->bd != 8)
2777 return;
2778
2779 // Currently, NN-model-based pruning is supported only when the largest
2780 // transform size is 8x8.
2781 if (tx_size != TX_8X8) return;
2782
2783 // The neural network model is a sequential net trained with an SGD
2784 // optimizer. Its speed/quality trade-off could be further improved through
2785 // the following experiments:
2786 // 1) Retrain the model with balanced data across different learning rates
2787 // and optimizers.
2788 // 2) Add features based on the statistics of the top and left pixels, to
2789 // capture the accuracy of the reconstructed neighbouring pixels for the 4x4
2790 // blocks numbered 1, 2, 3 within the 8x8 block, the source variance of the
2791 // 4x4 sub-blocks, etc.
2792 // 3) Train models for transform sizes other than 8x8.
2793 const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8;
2794 const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8;
2795
2796 float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f };
2797 const int diff_stride = block_size_wide[bsize];
2798
2799 const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride +
2800 MI_SIZE * blk_col;
2801 const int bw = tx_size_wide[tx_size];
2802 const int bh = tx_size_high[tx_size];
2803
2804 int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features);
2805
2806 features[feature_idx++] = log1pf((float)x->source_variance);
2807
2808 const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
2809 const float log_dc_q_square = log1pf((float)(dc_q * dc_q) / 256.0f);
2810 features[feature_idx++] = log_dc_q_square;
2811 assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES);
2812 for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) {
2813 features[i] = (features[i] - av1_intra_tx_split_8x8_mean[i]) /
2814 av1_intra_tx_split_8x8_std[i];
2815 }
2816
2817 float score;
2818 av1_nn_predict(features, nn_config, 1, &score);
2819
2820 TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2821 if (score <= intra_tx_prune_thresh[0])
2822 txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT;
2823 else if (score > intra_tx_prune_thresh[1])
2824 txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST;
2825 }
2826 #endif // !CONFIG_REALTIME_ONLY
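
// The pruning above is a two-sided threshold on a single NN score computed
// from whitened block features: a low score stops the search from descending
// (the split depths are pruned), while a high score drops the largest size.
// A standalone sketch under made-up names (normalize_features(),
// ToyPruneDecision), not the libaom API:
#if 0
enum ToyPruneDecision { TOY_PRUNE_NONE, TOY_PRUNE_SPLIT, TOY_PRUNE_LARGEST };

// Z-score whitening, mirroring the mean/std normalization applied to the
// features before the network is evaluated.
static void normalize_features(float *f, const float *mean, const float *std,
                               int n) {
  for (int i = 0; i < n; i++) f[i] = (f[i] - mean[i]) / std[i];
}

static enum ToyPruneDecision two_sided_prune(float score, float lo, float hi) {
  if (score <= lo) return TOY_PRUNE_SPLIT;   // confident: skip deeper depths
  if (score > hi) return TOY_PRUNE_LARGEST;  // confident: skip the largest size
  return TOY_PRUNE_NONE;                     // uncertain: evaluate both
}
#endif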
2827
2828 /*!\brief Transform type search for luma macroblock with fixed transform size.
2829 *
2830 * \ingroup transform_search
2831 * Search for the best transform type and return the transform coefficient RD
2832 * cost of the current luma macroblock with the given uniform transform size.
2833 *
2834 * \param[in] x Pointer to structure holding the data for the
2835 current encoding macroblock
2836 * \param[in] cpi Top-level encoder structure
2837 * \param[in] rd_stats Pointer to struct to keep track of the RD stats
2838 * \param[in] ref_best_rd Best RD cost seen for this block so far
2839 * \param[in] bs Size of the current macroblock
2840 * \param[in] tx_size The given transform size
2841 * \param[in] ftxs_mode Transform search mode specifying desired speed
2842 and quality tradeoff
2843 * \param[in] skip_trellis Binary flag indicating if trellis optimization
2844 should be skipped
2845 * \return An int64_t value that is the best RD cost found.
2846 */
2847 static int64_t uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
2848 RD_STATS *rd_stats, int64_t ref_best_rd,
2849 BLOCK_SIZE bs, TX_SIZE tx_size,
2850 FAST_TX_SEARCH_MODE ftxs_mode,
2851 int skip_trellis) {
2852 assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
2853 MACROBLOCKD *const xd = &x->e_mbd;
2854 MB_MODE_INFO *const mbmi = xd->mi[0];
2855 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2856 const ModeCosts *mode_costs = &x->mode_costs;
2857 const int is_inter = is_inter_block(mbmi);
2858 const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
2859 block_signals_txsize(mbmi->bsize);
2860 int tx_size_rate = 0;
2861 if (tx_select) {
2862 const int ctx = txfm_partition_context(
2863 xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
2864 tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0]
2865 : tx_size_cost(x, bs, tx_size);
2866 }
2867 const int skip_ctx = av1_get_skip_txfm_context(xd);
2868 const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
2869 const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
2870 const int64_t skip_txfm_rd =
2871 is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
2872 const int64_t no_this_rd =
2873 RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
2874
2875 mbmi->tx_size = tx_size;
2876 av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
2877 AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
2878 tx_size, ftxs_mode, skip_trellis);
2879 if (rd_stats->rate == INT_MAX) return INT64_MAX;
2880
2881 int64_t rd;
2882 // rd_stats->rate should include all the rate except the skip/non-skip cost,
2883 // as that is accounted for in the caller functions after the rd evaluation
2884 // of all planes. However, the decision here should still consider the
2885 // skip/non-skip header cost.
2886 if (rd_stats->skip_txfm && is_inter) {
2887 rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2888 } else {
2889 // Intra blocks are always signalled as non-skip
2890 rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
2891 rd_stats->dist);
2892 rd_stats->rate += tx_size_rate;
2893 }
2894 // Check if forcing the block to skip transform leads to smaller RD cost.
2895 if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
2896 int64_t temp_skip_txfm_rd =
2897 RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
2898 if (temp_skip_txfm_rd <= rd) {
2899 rd = temp_skip_txfm_rd;
2900 rd_stats->rate = 0;
2901 rd_stats->dist = rd_stats->sse;
2902 rd_stats->skip_txfm = 1;
2903 }
2904 }
2905
2906 return rd;
2907 }
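
// The tail of uniform_txfm_yrd() is the canonical "force skip" check: for an
// inter block, compare the cost of coding the coefficients against simply
// signalling skip (rate = skip flag only, distortion = SSE) and keep the
// cheaper option. A standalone sketch; ToyRdStats and toy_rdcost() are
// made-up stand-ins, not the libaom RD_STATS or RDCOST:
#if 0
#include <stdint.h>

typedef struct {
  int rate;
  int64_t dist;
  int64_t sse;
  int skip_txfm;
} ToyRdStats;

static int64_t toy_rdcost(int64_t rdmult, int rate, int64_t dist) {
  return ((rdmult * rate) >> 9) + dist * 128;
}

// Returns the final rd after optionally forcing the block to skip.
static int64_t maybe_force_skip(int64_t rdmult, int skip_rate, int64_t rd,
                                ToyRdStats *s) {
  const int64_t skip_rd = toy_rdcost(rdmult, skip_rate, s->sse);
  if (skip_rd <= rd) {
    s->rate = 0;       // only the skip flag is coded
    s->dist = s->sse;  // the prediction error is left uncoded
    s->skip_txfm = 1;
    return skip_rd;
  }
  return rd;
}
#endif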
2908
2909 // Search for the best uniform transform size and type for the current coding block.
2910 static inline void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
2911 MACROBLOCK *x,
2912 RD_STATS *rd_stats,
2913 int64_t ref_best_rd,
2914 BLOCK_SIZE bs) {
2915 av1_invalid_rd_stats(rd_stats);
2916
2917 MACROBLOCKD *const xd = &x->e_mbd;
2918 MB_MODE_INFO *const mbmi = xd->mi[0];
2919 TxfmSearchParams *const txfm_params = &x->txfm_search_params;
2920 const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
2921 const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
2922 int start_tx;
2923 // The split depth can be at most MAX_TX_DEPTH, so init_depth controls how
2924 // many levels of splitting are allowed during the RD search.
2925 int init_depth;
2926
2927 if (tx_select) {
2928 start_tx = max_rect_tx_size;
2929 init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
2930 is_inter_block(mbmi), &cpi->sf,
2931 txfm_params->tx_size_search_method);
2932 if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 &&
2933 txsize_sqr_up_map[start_tx] == TX_64X64) {
2934 start_tx = sub_tx_size_map[start_tx];
2935 }
2936 } else {
2937 const TX_SIZE chosen_tx_size =
2938 tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2939 start_tx = chosen_tx_size;
2940 init_depth = MAX_TX_DEPTH;
2941 }
2942
2943 const int skip_trellis = 0;
2944 uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
2945 uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
2946 TX_SIZE best_tx_size = max_rect_tx_size;
2947 int64_t best_rd = INT64_MAX;
2948 const int num_blks = bsize_to_num_blk(bs);
2949 x->rd_model = FULL_TXFM_RD;
2950 int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
2951 TxfmSearchInfo *txfm_info = &x->txfm_search_info;
2952 for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
2953 depth++, tx_size = sub_tx_size_map[tx_size]) {
2954 if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
2955 txsize_sqr_up_map[tx_size] == TX_64X64) ||
2956 (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
2957 tx_size_wide[tx_size] != tx_size_high[tx_size])) {
2958 continue;
2959 }
2960
2961 #if !CONFIG_REALTIME_ONLY
2962 if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_SPLIT) break;
2963
2964 // Set the flag to enable evaluation of the NN classifier that prunes
2965 // transform depths. As the features are derived from the intra residual of
2966 // the largest transform, the NN model is evaluated only for that transform
2967 // size.
2968 txfm_params->enable_nn_prune_intra_tx_depths =
2969 (cpi->sf.tx_sf.prune_intra_tx_depths_using_nn && tx_size == start_tx);
2970 #endif
2971
2972 RD_STATS this_rd_stats;
2973 // When the speed feature use_rd_based_breakout_for_intra_tx_search is
2974 // enabled, use the known minimum best_rd for early termination.
2975 const int64_t rd_thresh =
2976 cpi->sf.tx_sf.use_rd_based_breakout_for_intra_tx_search
2977 ? AOMMIN(ref_best_rd, best_rd)
2978 : ref_best_rd;
2979 rd[depth] = uniform_txfm_yrd(cpi, x, &this_rd_stats, rd_thresh, bs, tx_size,
2980 FTXS_NONE, skip_trellis);
2981 if (rd[depth] < best_rd) {
2982 av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks);
2983 av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
2984 best_tx_size = tx_size;
2985 best_rd = rd[depth];
2986 *rd_stats = this_rd_stats;
2987 }
2988 if (tx_size == TX_4X4) break;
2989 // If three depths are being searched, prune the smallest size for
2990 // low-contrast blocks, based on the rd results of the first two depths.
2991 if (depth > init_depth && depth != MAX_TX_DEPTH &&
2992 x->source_variance < 256) {
2993 if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
2994 }
2995 }
2996
2997 if (rd_stats->rate != INT_MAX) {
2998 mbmi->tx_size = best_tx_size;
2999 av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
3000 av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks);
3001 }
3002
3003 #if !CONFIG_REALTIME_ONLY
3004 // Reset the flags to avoid any unintentional evaluation of NN model and
3005 // consumption of prune depths.
3006 txfm_params->enable_nn_prune_intra_tx_depths = false;
3007 txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_NONE;
3008 #endif
3009 }
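
// The depth loop above walks from the largest rectangular transform toward
// TX_4X4 via sub_tx_size_map, keeping the best rd and breaking early when rd
// stops improving (the real loop additionally gates that break on block
// variance and search depth). The shape of the search, with a toy sub-size
// table and a caller-supplied evaluation callback:
#if 0
#include <stdint.h>

enum { TOY_TX_4X4, TOY_TX_8X8, TOY_TX_16X16, TOY_TX_32X32, TOY_TX_SIZES };
// Toy stand-in for sub_tx_size_map: each size maps to the next smaller one.
static const int toy_sub_tx[TOY_TX_SIZES] = { TOY_TX_4X4, TOY_TX_4X4,
                                              TOY_TX_8X8, TOY_TX_16X16 };

// eval(size) plays the role of uniform_txfm_yrd(): it scores one transform
// size and returns the rd cost.
static int toy_tx_size_search(int start_tx, int max_depth,
                              int64_t (*eval)(int)) {
  int best_size = start_tx;
  int64_t best_rd = INT64_MAX, prev_rd = INT64_MAX;
  for (int size = start_tx, depth = 0; depth <= max_depth;
       ++depth, size = toy_sub_tx[size]) {
    const int64_t rd = eval(size);
    if (rd < best_rd) {
      best_rd = rd;
      best_size = size;
    }
    if (size == TOY_TX_4X4) break;                    // nothing smaller left
    if (prev_rd != INT64_MAX && rd > prev_rd) break;  // rd got worse: stop
    prev_rd = rd;
  }
  return best_size;
}
#endif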
3010
3011 // Search for the best transform type for the given transform block in the
3012 // given plane/channel, and calculate the corresponding RD cost.
3013 static inline void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
3014 BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
3015 void *arg) {
3016 struct rdcost_block_args *args = arg;
3017 if (args->exit_early) {
3018 args->incomplete_exit = 1;
3019 return;
3020 }
3021
3022 MACROBLOCK *const x = args->x;
3023 MACROBLOCKD *const xd = &x->e_mbd;
3024 const int is_inter = is_inter_block(xd->mi[0]);
3025 const AV1_COMP *cpi = args->cpi;
3026 ENTROPY_CONTEXT *a = args->t_above + blk_col;
3027 ENTROPY_CONTEXT *l = args->t_left + blk_row;
3028 const AV1_COMMON *cm = &cpi->common;
3029 RD_STATS this_rd_stats;
3030 av1_init_rd_stats(&this_rd_stats);
3031
3032 if (!is_inter) {
3033 av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
3034 av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
3035 #if !CONFIG_REALTIME_ONLY
3036 const TxfmSearchParams *const txfm_params = &x->txfm_search_params;
3037 if (txfm_params->enable_nn_prune_intra_tx_depths) {
3038 ml_predict_intra_tx_depth_prune(x, blk_row, blk_col, plane_bsize,
3039 tx_size);
3040 if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_LARGEST) {
3041 av1_invalid_rd_stats(&args->rd_stats);
3042 args->exit_early = 1;
3043 return;
3044 }
3045 }
3046 #endif
3047 }
3048
3049 TXB_CTX txb_ctx;
3050 get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
3051 search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3052 &txb_ctx, args->ftxs_mode, args->skip_trellis,
3053 args->best_rd - args->current_rd, &this_rd_stats);
3054
3055 #if !CONFIG_REALTIME_ONLY
3056 if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
3057 assert(!is_inter || plane_bsize < BLOCK_8X8);
3058 cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
3059 }
3060 #endif
3061
3062 #if CONFIG_RD_DEBUG
3063 update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate);
3064 #endif // CONFIG_RD_DEBUG
3065 av1_set_txb_context(x, plane, block, tx_size, a, l);
3066
3067 const int blk_idx =
3068 blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
3069
3070 TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3071 if (plane == 0)
3072 set_blk_skip(txfm_info->blk_skip, plane, blk_idx,
3073 x->plane[plane].eobs[block] == 0);
3074 else
3075 set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);
3076
3077 int64_t rd;
3078 if (is_inter) {
3079 const int64_t no_skip_txfm_rd =
3080 RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3081 const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3082 rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
3083 this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
3084 } else {
3085 // Signal non-skip_txfm for Intra blocks
3086 rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3087 this_rd_stats.skip_txfm = 0;
3088 }
3089
3090 av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
3091
3092 args->current_rd += rd;
3093 if (args->current_rd > args->best_rd) args->exit_early = 1;
3094 }
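
// Each block_rd_txfm() call folds its block cost into args->current_rd and
// trips exit_early once the total exceeds the best rd seen so far; for inter
// blocks the per-block cost is the cheaper of coding the coefficients or
// skipping them. A standalone sketch of that accumulation (toy names):
#if 0
#include <stdint.h>

typedef struct {
  int64_t current_rd;
  int64_t best_rd;
  int exit_early;
} ToyArgs;

// Mirrors the tail of block_rd_txfm(): add one block's cost and update the
// early-exit state shared across the per-block visitor calls.
static void toy_accumulate_block(ToyArgs *args, int64_t no_skip_rd,
                                 int64_t skip_rd, int is_inter) {
  const int64_t rd =
      is_inter ? (no_skip_rd < skip_rd ? no_skip_rd : skip_rd) : no_skip_rd;
  args->current_rd += rd;
  if (args->current_rd > args->best_rd) args->exit_early = 1;
}
#endif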
3095
3096 int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3097 RD_STATS *rd_stats, int64_t ref_best_rd,
3098 BLOCK_SIZE bs, TX_SIZE tx_size) {
3099 MACROBLOCKD *const xd = &x->e_mbd;
3100 MB_MODE_INFO *const mbmi = xd->mi[0];
3101 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3102 const ModeCosts *mode_costs = &x->mode_costs;
3103 const int is_inter = is_inter_block(mbmi);
3104 const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
3105 block_signals_txsize(mbmi->bsize);
3106 int tx_size_rate = 0;
3107 if (tx_select) {
3108 const int ctx = txfm_partition_context(
3109 xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
3110 tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
3111 }
3112 const int skip_ctx = av1_get_skip_txfm_context(xd);
3113 const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
3114 const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
3115 const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
3116 const int64_t no_this_rd =
3117 RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
3118 mbmi->tx_size = tx_size;
3119
3120 const uint8_t txw_unit = tx_size_wide_unit[tx_size];
3121 const uint8_t txh_unit = tx_size_high_unit[tx_size];
3122 const int step = txw_unit * txh_unit;
3123 const int max_blocks_wide = max_block_wide(xd, bs, 0);
3124 const int max_blocks_high = max_block_high(xd, bs, 0);
3125
3126 struct rdcost_block_args args;
3127 av1_zero(args);
3128 args.x = x;
3129 args.cpi = cpi;
3130 args.best_rd = ref_best_rd;
3131 args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
3132 av1_init_rd_stats(&args.rd_stats);
3133 av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
3134 int i = 0;
3135 for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
3136 blk_row += txh_unit) {
3137 for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
3138 RD_STATS this_rd_stats;
3139 av1_init_rd_stats(&this_rd_stats);
3140
3141 if (args.exit_early) {
3142 args.incomplete_exit = 1;
3143 break;
3144 }
3145
3146 ENTROPY_CONTEXT *a = args.t_above + blk_col;
3147 ENTROPY_CONTEXT *l = args.t_left + blk_row;
3148 TXB_CTX txb_ctx;
3149 get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);
3150
3151 TxfmParam txfm_param;
3152 QUANT_PARAM quant_param;
3153 av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
3154 av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);
3155
3156 av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
3157 av1_quant(x, 0, i, &txfm_param, &quant_param);
3158
3159 this_rd_stats.rate =
3160 cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);
3161
3162 const SCAN_ORDER *const scan_order =
3163 get_scan(txfm_param.tx_size, txfm_param.tx_type);
3164 dist_block_tx_domain(x, 0, i, tx_size, quant_param.qmatrix,
3165 scan_order->scan, &this_rd_stats.dist,
3166 &this_rd_stats.sse);
3167
3168 const int64_t no_skip_txfm_rd =
3169 RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3170 const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3171
3172 this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];
3173
3174 av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
3175 args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);
3176
3177 if (args.current_rd > ref_best_rd) {
3178 args.exit_early = 1;
3179 break;
3180 }
3181
3182 av1_set_txb_context(x, 0, i, tx_size, a, l);
3183 i += step;
3184 }
3185 }
3186
3187 if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);
3188
3189 *rd_stats = args.rd_stats;
3190 if (rd_stats->rate == INT_MAX) return INT64_MAX;
3191
3192 int64_t rd;
3193 // rd_stats->rate should include all the rate except the skip/non-skip cost,
3194 // as that is accounted for in the caller functions after the rd evaluation
3195 // of all planes. However, the decision here should still consider the
3196 // skip/non-skip header cost.
3197 if (rd_stats->skip_txfm && is_inter) {
3198 rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3199 } else {
3200 // Intra blocks are always signalled as non-skip
3201 rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
3202 rd_stats->dist);
3203 rd_stats->rate += tx_size_rate;
3204 }
3205 // Check if forcing the block to skip transform leads to smaller RD cost.
3206 if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
3207 int64_t temp_skip_txfm_rd =
3208 RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3209 if (temp_skip_txfm_rd <= rd) {
3210 rd = temp_skip_txfm_rd;
3211 rd_stats->rate = 0;
3212 rd_stats->dist = rd_stats->sse;
3213 rd_stats->skip_txfm = 1;
3214 }
3215 }
3216
3217 return rd;
3218 }
3219
3220 // Search for the best transform type for a luma inter-predicted block, given
3221 // the transform block partitions.
3222 // This function is used only when some speed features are enabled.
3223 static inline void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
3224 int blk_col, int block, TX_SIZE tx_size,
3225 BLOCK_SIZE plane_bsize, int depth,
3226 ENTROPY_CONTEXT *above_ctx,
3227 ENTROPY_CONTEXT *left_ctx,
3228 TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
3229 int64_t ref_best_rd, RD_STATS *rd_stats,
3230 FAST_TX_SEARCH_MODE ftxs_mode) {
3231 assert(tx_size < TX_SIZES_ALL);
3232 MACROBLOCKD *const xd = &x->e_mbd;
3233 MB_MODE_INFO *const mbmi = xd->mi[0];
3234 assert(is_inter_block(mbmi));
3235 const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
3236 const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
3237
3238 if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
3239
3240 const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
3241 plane_bsize, blk_row, blk_col)];
3242 const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
3243 mbmi->bsize, tx_size);
3244
3245 av1_init_rd_stats(rd_stats);
3246 if (tx_size == plane_tx_size) {
3247 ENTROPY_CONTEXT *ta = above_ctx + blk_col;
3248 ENTROPY_CONTEXT *tl = left_ctx + blk_row;
3249 const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
3250 TXB_CTX txb_ctx;
3251 get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
3252
3253 const int zero_blk_rate =
3254 x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
3255 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
3256 rd_stats->zero_rate = zero_blk_rate;
3257 tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
3258 rd_stats, ftxs_mode, ref_best_rd);
3259 const int mi_width = mi_size_wide[plane_bsize];
3260 TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3261 if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
3262 RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
3263 rd_stats->skip_txfm == 1) {
3264 rd_stats->rate = zero_blk_rate;
3265 rd_stats->dist = rd_stats->sse;
3266 rd_stats->skip_txfm = 1;
3267 set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1);
3268 x->plane[0].eobs[block] = 0;
3269 x->plane[0].txb_entropy_ctx[block] = 0;
3270 update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
3271 } else {
3272 rd_stats->skip_txfm = 0;
3273 set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0);
3274 }
3275 if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3276 rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
3277 av1_set_txb_context(x, 0, block, tx_size, ta, tl);
3278 txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
3279 tx_size);
3280 } else {
3281 const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
3282 const int txb_width = tx_size_wide_unit[sub_txs];
3283 const int txb_height = tx_size_high_unit[sub_txs];
3284 const int step = txb_height * txb_width;
3285 const int row_end =
3286 AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row);
3287 const int col_end =
3288 AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col);
3289 RD_STATS pn_rd_stats;
3290 int64_t this_rd = 0;
3291 assert(txb_width > 0 && txb_height > 0);
3292
3293 for (int row = 0; row < row_end; row += txb_height) {
3294 const int offsetr = blk_row + row;
3295 for (int col = 0; col < col_end; col += txb_width) {
3296 const int offsetc = blk_col + col;
3297
3298 av1_init_rd_stats(&pn_rd_stats);
3299 tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
3300 depth + 1, above_ctx, left_ctx, tx_above, tx_left,
3301 ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
3302 if (pn_rd_stats.rate == INT_MAX) {
3303 av1_invalid_rd_stats(rd_stats);
3304 return;
3305 }
3306 av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3307 this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
3308 block += step;
3309 }
3310 }
3311
3312 if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3313 rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
3314 }
3315 }
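
// tx_block_yrd() walks the transform partition tree: when the current size
// matches the size already chosen for this position it scores the leaf,
// otherwise it recurses into the sub-blocks of the next smaller size and sums
// their costs, charging the partition-signalling bit at each non-leaf. A
// skeletal sketch of that recursion (toy names, square sizes only):
#if 0
#include <stdint.h>

// chosen(row, col) is the size already selected for that position;
// leaf_rd(size, row, col) scores one transform block.
static int64_t toy_tree_walk(int size, int row, int col,
                             int (*chosen)(int, int),
                             int64_t (*leaf_rd)(int, int, int)) {
  if (size <= 4 || size == chosen(row, col)) return leaf_rd(size, row, col);
  const int half = size / 2;  // descend into the four quadrants
  int64_t rd = 0;
  for (int r = 0; r < 2; r++)
    for (int c = 0; c < 2; c++)
      rd += toy_tree_walk(half, row + r * half, col + c * half, chosen,
                          leaf_rd);
  return rd;
}
#endif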
3316
3317 // Search for the transform type with transform sizes already decided for an
3318 // inter-predicted luma block, used only when some speed features are enabled.
3319 // Return value 0: early termination triggered, no valid rd cost available;
3320 // 1: rd cost values are valid.
3321 static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3322 RD_STATS *rd_stats, BLOCK_SIZE bsize,
3323 int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
3324 if (ref_best_rd < 0) {
3325 av1_invalid_rd_stats(rd_stats);
3326 return 0;
3327 }
3328
3329 av1_init_rd_stats(rd_stats);
3330
3331 MACROBLOCKD *const xd = &x->e_mbd;
3332 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3333 const struct macroblockd_plane *const pd = &xd->plane[0];
3334 const int mi_width = mi_size_wide[bsize];
3335 const int mi_height = mi_size_high[bsize];
3336 const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
3337 const int bh = tx_size_high_unit[max_tx_size];
3338 const int bw = tx_size_wide_unit[max_tx_size];
3339 const int step = bw * bh;
3340 const int init_depth = get_search_init_depth(
3341 mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3342 ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3343 ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3344 TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3345 TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3346 av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3347 memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3348 memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3349
3350 int64_t this_rd = 0;
3351 for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
3352 for (int idx = 0; idx < mi_width; idx += bw) {
3353 RD_STATS pn_rd_stats;
3354 av1_init_rd_stats(&pn_rd_stats);
3355 tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
3356 ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
3357 &pn_rd_stats, ftxs_mode);
3358 if (pn_rd_stats.rate == INT_MAX) {
3359 av1_invalid_rd_stats(rd_stats);
3360 return 0;
3361 }
3362 av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3363 this_rd +=
3364 AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
3365 RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
3366 block += step;
3367 }
3368 }
3369
3370 const int skip_ctx = av1_get_skip_txfm_context(xd);
3371 const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3372 const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3373 const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3374 this_rd =
3375 RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
3376 if (skip_txfm_rd < this_rd) {
3377 this_rd = skip_txfm_rd;
3378 rd_stats->rate = 0;
3379 rd_stats->dist = rd_stats->sse;
3380 rd_stats->skip_txfm = 1;
3381 }
3382
3383 const int is_cost_valid = this_rd <= ref_best_rd;
3384 if (!is_cost_valid) {
3385 // reset cost value
3386 av1_invalid_rd_stats(rd_stats);
3387 }
3388 return is_cost_valid;
3389 }
3390
3391 // Search for the best transform size and type for current inter-predicted
3392 // luma block with recursive transform block partitioning. The obtained
3393 // transform selection will be saved in xd->mi[0], the corresponding RD stats
3394 // will be saved in rd_stats. The returned value is the corresponding RD cost.
3395 static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
3396 RD_STATS *rd_stats, BLOCK_SIZE bsize,
3397 int64_t ref_best_rd) {
3398 MACROBLOCKD *const xd = &x->e_mbd;
3399 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3400 assert(is_inter_block(xd->mi[0]));
3401 assert(bsize < BLOCK_SIZES_ALL);
3402 const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
3403 int64_t rd_thresh = ref_best_rd;
3404 if (rd_thresh == 0) {
3405 av1_invalid_rd_stats(rd_stats);
3406 return INT64_MAX;
3407 }
3408 if (fast_tx_search && rd_thresh < INT64_MAX) {
3409 if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
3410 }
3411 assert(rd_thresh > 0);
3412 const FAST_TX_SEARCH_MODE ftxs_mode =
3413 fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
3414 const struct macroblockd_plane *const pd = &xd->plane[0];
3415 assert(bsize < BLOCK_SIZES_ALL);
3416 const int mi_width = mi_size_wide[bsize];
3417 const int mi_height = mi_size_high[bsize];
3418 ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3419 ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3420 TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3421 TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3422 av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3423 memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3424 memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3425 const int init_depth = get_search_init_depth(
3426 mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3427 const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
3428 const int bh = tx_size_high_unit[max_tx_size];
3429 const int bw = tx_size_wide_unit[max_tx_size];
3430 const int step = bw * bh;
3431 const int skip_ctx = av1_get_skip_txfm_context(xd);
3432 const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3433 const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3434 int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
3435 int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
3436 int block = 0;
3437
3438 av1_init_rd_stats(rd_stats);
3439 for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
3440 for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
3441 const int64_t best_rd_sofar =
3442 (rd_thresh == INT64_MAX)
3443 ? INT64_MAX
3444 : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
3445 int is_cost_valid = 1;
3446 RD_STATS pn_rd_stats;
3447 // Search for the best transform block size and type for the sub-block.
3448 select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
3449 ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
3450 best_rd_sofar, &is_cost_valid, ftxs_mode);
3451 if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
3452 av1_invalid_rd_stats(rd_stats);
3453 return INT64_MAX;
3454 }
3455 av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3456 skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3457 no_skip_txfm_rd =
3458 RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3459 block += step;
3460 }
3461 }
3462
3463 if (rd_stats->rate == INT_MAX) return INT64_MAX;
3464
3465 rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);
3466
3467 // If fast_tx_search is true, only DCT and 1D DCT were tested in the
3468 // select_tx_block() calls above. Do a better search for tx type with
3469 // tx sizes already decided.
3470 if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
3471 if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
3472 return INT64_MAX;
3473 }
3474
3475 int64_t final_rd;
3476 if (rd_stats->skip_txfm) {
3477 final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3478 } else {
3479 final_rd =
3480 RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3481 if (!xd->lossless[xd->mi[0]->segment_id]) {
3482 final_rd =
3483 AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
3484 }
3485 }
3486
3487 return final_rd;
3488 }
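
// When a fast transform search is used above, the reference threshold is
// loosened by 1/8 (12.5%) so the coarse DCT-only pass does not reject
// candidates the refinement pass could still rescue; the guard keeps the
// addition from overflowing INT64_MAX. A standalone sketch of that step:
#if 0
#include <stdint.h>

static int64_t loosen_threshold(int64_t rd_thresh) {
  // Add 12.5% slack, but only when the sum cannot overflow.
  if (rd_thresh < INT64_MAX && INT64_MAX - rd_thresh > (rd_thresh >> 3))
    rd_thresh += rd_thresh >> 3;
  return rd_thresh;  // e.g., loosen_threshold(800) == 900
}
#endif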
3489
3490 // Return 1 to terminate transform search early. The decision is made based on
3491 // the comparison with the reference RD cost and the model-estimated RD cost.
3492 static inline int model_based_tx_search_prune(const AV1_COMP *cpi,
3493 MACROBLOCK *x, BLOCK_SIZE bsize,
3494 int64_t ref_best_rd) {
3495 const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
3496 assert(level >= 0 && level <= 2);
3497 int model_rate;
3498 int64_t model_dist;
3499 uint8_t model_skip;
3500 MACROBLOCKD *const xd = &x->e_mbd;
3501 model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
3502 cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
3503 NULL, NULL, NULL);
3504 if (model_skip) return 0;
3505 const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
3506 // TODO(debargha, urvang): Improve the model and make the check below
3507 // tighter.
3508 static const int prune_factor_by8[] = { 3, 5 };
3509 const int factor = prune_factor_by8[level - 1];
3510 return ((model_rd * factor) >> 3) > ref_best_rd;
3511 }
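
// The pruning above compares a scaled model estimate against the best rd
// seen: with prune_factor_by8[] = { 3, 5 }, level 1 prunes when 3/8 of the
// model rd already beats ref_best_rd and level 2 when 5/8 does, so higher
// levels prune more aggressively. A standalone sketch of the decision:
#if 0
#include <stdint.h>

static int toy_model_prune(int64_t model_rd, int64_t ref_best_rd, int level) {
  static const int prune_factor_by8[] = { 3, 5 };  // levels 1 and 2
  if (level < 1 || level > 2) return 0;            // pruning disabled
  return ((model_rd * prune_factor_by8[level - 1]) >> 3) > ref_best_rd;
}
#endif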
3512
3513 void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3514 RD_STATS *rd_stats, BLOCK_SIZE bsize,
3515 int64_t ref_best_rd) {
3516 MACROBLOCKD *const xd = &x->e_mbd;
3517 const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3518 assert(is_inter_block(xd->mi[0]));
3519
3520 av1_invalid_rd_stats(rd_stats);
3521
3522 // If the modeled RD cost is much worse than the best so far, terminate early.
3523 if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
3524 ref_best_rd != INT64_MAX) {
3525 if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
3526 }
3527
3528 // Hashing-based speed feature. If the hash of the prediction residue block is
3529 // found in the hash table, reuse the previous search results and terminate early.
3530 uint32_t hash = 0;
3531 MB_RD_RECORD *mb_rd_record = NULL;
3532 const int mi_row = x->e_mbd.mi_row;
3533 const int mi_col = x->e_mbd.mi_col;
3534 const int within_border =
3535 mi_row >= xd->tile.mi_row_start &&
3536 (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
3537 mi_col >= xd->tile.mi_col_start &&
3538 (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
3539 const int is_mb_rd_hash_enabled =
3540 (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
3541 const int n4 = bsize_to_num_blk(bsize);
3542 if (is_mb_rd_hash_enabled) {
3543 hash = get_block_residue_hash(x, bsize);
3544 mb_rd_record = x->txfm_search_info.mb_rd_record;
3545 const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3546 if (match_index != -1) {
3547 MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
3548 fetch_mb_rd_info(n4, mb_rd_info, rd_stats, x);
3549 return;
3550 }
3551 }
3552
3553 // If we predict that skip is the optimal RD decision, set the respective
3554 // context and terminate early.
3555 int64_t dist;
3556 if (txfm_params->skip_txfm_level &&
3557 predict_skip_txfm(x, bsize, &dist,
3558 cpi->common.features.reduced_tx_set_used)) {
3559 set_skip_txfm(x, rd_stats, bsize, dist);
3560 // Save the RD search results into mb_rd_record.
3561 if (is_mb_rd_hash_enabled)
3562 save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3563 return;
3564 }
3565 #if CONFIG_SPEED_STATS
3566 ++x->txfm_search_info.tx_search_count;
3567 #endif // CONFIG_SPEED_STATS
3568
3569 const int64_t rd =
3570 select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd);
3571
3572 if (rd == INT64_MAX) {
3573 // We should always find at least one candidate unless ref_best_rd is less
3574 // than INT64_MAX (in which case, all the calls to select_tx_size_and_type()
3575 // might have failed to find something better).
3576 assert(ref_best_rd != INT64_MAX);
3577 av1_invalid_rd_stats(rd_stats);
3578 return;
3579 }
3580
3581 // Save the RD search results into mb_rd_record.
3582 if (is_mb_rd_hash_enabled) {
3583 assert(mb_rd_record != NULL);
3584 save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3585 }
3586 }
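
// The mb_rd_record fast path above is plain memoization: key the search by a
// hash of the prediction residue, probe a small record buffer, and either
// reuse a hit or run the full search and store the outcome. A generic sketch
// of that flow; ToyRdCache and its helpers are made-up names, not the libaom
// MB_RD_RECORD API:
#if 0
#include <stdint.h>

#define TOY_CACHE_LEN 8

typedef struct {
  uint32_t keys[TOY_CACHE_LEN];
  int64_t rd[TOY_CACHE_LEN];
  int num;
} ToyRdCache;

static int toy_cache_lookup(const ToyRdCache *c, uint32_t key, int64_t *rd) {
  for (int i = 0; i < c->num; i++) {
    if (c->keys[i] == key) {
      *rd = c->rd[i];
      return 1;  // hit: reuse the stored search result
    }
  }
  return 0;  // miss: caller runs the search and then stores it
}

static void toy_cache_store(ToyRdCache *c, uint32_t key, int64_t rd) {
  const int slot = c->num < TOY_CACHE_LEN ? c->num++ : 0;  // crude eviction
  c->keys[slot] = key;
  c->rd[slot] = rd;
}
#endif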
3587
3588 void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3589 RD_STATS *rd_stats, BLOCK_SIZE bs,
3590 int64_t ref_best_rd) {
3591 MACROBLOCKD *const xd = &x->e_mbd;
3592 MB_MODE_INFO *const mbmi = xd->mi[0];
3593 const TxfmSearchParams *tx_params = &x->txfm_search_params;
3594 assert(bs == mbmi->bsize);
3595 const int is_inter = is_inter_block(mbmi);
3596 const int mi_row = xd->mi_row;
3597 const int mi_col = xd->mi_col;
3598
3599 av1_init_rd_stats(rd_stats);
3600
3601 // Hashing-based speed feature for inter blocks. If the hash of the residue
3602 // block is found in the table, reuse the previously saved search results and
3603 // terminate early.
3604 uint32_t hash = 0;
3605 MB_RD_RECORD *mb_rd_record = NULL;
3606 const int num_blks = bsize_to_num_blk(bs);
3607 if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
3608 const int within_border =
3609 mi_row >= xd->tile.mi_row_start &&
3610 (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
3611 mi_col >= xd->tile.mi_col_start &&
3612 (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
3613 if (within_border) {
3614 hash = get_block_residue_hash(x, bs);
3615 mb_rd_record = x->txfm_search_info.mb_rd_record;
3616 const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3617 if (match_index != -1) {
3618 MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
3619 fetch_mb_rd_info(num_blks, mb_rd_info, rd_stats, x);
3620 return;
3621 }
3622 }
3623 }
3624
3625 // If we predict that skip is the optimal RD decision, set the respective
3626 // context and terminate early.
3627 int64_t dist;
3628 if (tx_params->skip_txfm_level && is_inter &&
3629 !xd->lossless[mbmi->segment_id] &&
3630 predict_skip_txfm(x, bs, &dist,
3631 cpi->common.features.reduced_tx_set_used)) {
3632 // Populate rd_stats according to the skip decision.
3633 set_skip_txfm(x, rd_stats, bs, dist);
3634 // Save the RD search results into mb_rd_record.
3635 if (mb_rd_record) {
3636 save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3637 }
3638 return;
3639 }
3640
3641 if (xd->lossless[mbmi->segment_id]) {
3642 // Lossless mode can only pick the smallest (4x4) transform size.
3643 choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3644 } else if (tx_params->tx_size_search_method == USE_LARGESTALL) {
3645 choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3646 } else {
3647 choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
3648 }
3649
3650 // Save the RD search results into mb_rd_record for possible reuse in future.
3651 if (mb_rd_record) {
3652 save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3653 }
3654 }
3655
3656 int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
3657 BLOCK_SIZE bsize, int64_t ref_best_rd) {
3658 av1_init_rd_stats(rd_stats);
3659 if (ref_best_rd < 0) return 0;
3660 if (!x->e_mbd.is_chroma_ref) return 1;
3661
3662 MACROBLOCKD *const xd = &x->e_mbd;
3663 MB_MODE_INFO *const mbmi = xd->mi[0];
3664 struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
3665 const int is_inter = is_inter_block(mbmi);
3666 int64_t this_rd = 0, skip_txfm_rd = 0;
3667 const BLOCK_SIZE plane_bsize =
3668 get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
3669
3670 if (is_inter) {
3671 for (int plane = 1; plane < MAX_MB_PLANE; ++plane)
3672 av1_subtract_plane(x, plane_bsize, plane);
3673 }
3674
3675 const int skip_trellis = 0;
3676 const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
3677 int is_cost_valid = 1;
3678 for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
3679 RD_STATS this_rd_stats;
3680 int64_t chroma_ref_best_rd = ref_best_rd;
3681 // For inter blocks, the refined ref_best_rd is used for early exit.
3682 // For intra blocks, even if the current rd crosses ref_best_rd, early exit
3683 // is not advisable, since the current rd is also used for gating subsequent
3684 // modes (say, angular modes).
3685 // TODO(any): Extend the early exit mechanism for intra modes as well
3686 if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
3687 chroma_ref_best_rd != INT64_MAX)
3688 chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
3689 av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
3690 plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis);
3691 if (this_rd_stats.rate == INT_MAX) {
3692 is_cost_valid = 0;
3693 break;
3694 }
3695 av1_merge_rd_stats(rd_stats, &this_rd_stats);
3696 this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
3697 skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
3698 if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
3699 is_cost_valid = 0;
3700 break;
3701 }
3702 }
3703
3704 if (!is_cost_valid) {
3705 // reset cost value
3706 av1_invalid_rd_stats(rd_stats);
3707 }
3708
3709 return is_cost_valid;
3710 }
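
// The per-plane gating above shrinks the chroma budget by the cost already
// committed on previous planes: the remaining budget is ref_best_rd minus the
// cheaper of the accumulated rd and the skip rd, so a plane whose cost alone
// exceeds that remainder can be abandoned immediately. A standalone sketch:
#if 0
#include <stdint.h>

static int64_t remaining_chroma_budget(int64_t ref_best_rd, int64_t this_rd,
                                       int64_t skip_rd) {
  if (ref_best_rd == INT64_MAX) return INT64_MAX;  // no budget to enforce
  const int64_t committed = this_rd < skip_rd ? this_rd : skip_rd;
  return ref_best_rd - committed;
}
#endif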
3711
3712 void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
3713 RD_STATS *rd_stats, int64_t ref_best_rd,
3714 int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
3715 TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
3716 int skip_trellis) {
3717 assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));
3718
3719 if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
3720 txsize_sqr_up_map[tx_size] == TX_64X64) {
3721 av1_invalid_rd_stats(rd_stats);
3722 return;
3723 }
3724
3725 if (current_rd > ref_best_rd) {
3726 av1_invalid_rd_stats(rd_stats);
3727 return;
3728 }
3729
3730 MACROBLOCKD *const xd = &x->e_mbd;
3731 const struct macroblockd_plane *const pd = &xd->plane[plane];
3732 struct rdcost_block_args args;
3733 av1_zero(args);
3734 args.x = x;
3735 args.cpi = cpi;
3736 args.best_rd = ref_best_rd;
3737 args.current_rd = current_rd;
3738 args.ftxs_mode = ftxs_mode;
3739 args.skip_trellis = skip_trellis;
3740 av1_init_rd_stats(&args.rd_stats);
3741
3742 av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
3743 av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
3744 &args);
3745
3746 MB_MODE_INFO *const mbmi = xd->mi[0];
3747 const int is_inter = is_inter_block(mbmi);
3748 const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
3749
3750 if (invalid_rd) {
3751 av1_invalid_rd_stats(rd_stats);
3752 } else {
3753 *rd_stats = args.rd_stats;
3754 }
3755 }
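
// av1_txfm_rd_in_plane() drives the search through a visitor:
// av1_foreach_transformed_block_in_plane() invokes block_rd_txfm() once per
// transform block while the shared rdcost_block_args carries the running rd
// and the early-exit flags between calls. A minimal sketch of that callback
// pattern (toy names):
#if 0
typedef void (*toy_block_visitor)(int row, int col, void *arg);

// Visits every transform block of a rows x cols plane in raster order; "arg"
// is the shared accumulator threaded through each call.
static void toy_foreach_block(int rows, int cols, int step,
                              toy_block_visitor visit, void *arg) {
  for (int r = 0; r < rows; r += step)
    for (int c = 0; c < cols; c += step) visit(r, c, arg);
}
#endif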
3756
3757 int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
3758 RD_STATS *rd_stats, RD_STATS *rd_stats_y,
3759 RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
3760 MACROBLOCKD *const xd = &x->e_mbd;
3761 TxfmSearchParams *txfm_params = &x->txfm_search_params;
3762 const int skip_ctx = av1_get_skip_txfm_context(xd);
3763 const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0],
3764 x->mode_costs.skip_txfm_cost[skip_ctx][1] };
3765 const int64_t min_header_rate =
3766 mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]);
3767 // Account for the minimum of the skip and non-skip rd:
3768 // eventually one of them will be added to mode_rate.
3769 const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
3770 if (min_header_rd_possible > ref_best_rd) {
3771 av1_invalid_rd_stats(rd_stats_y);
3772 return 0;
3773 }
3774
3775 const AV1_COMMON *cm = &cpi->common;
3776 MB_MODE_INFO *const mbmi = xd->mi[0];
3777 const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
3778 const int64_t rd_thresh =
3779 ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
3780 av1_init_rd_stats(rd_stats);
3781 av1_init_rd_stats(rd_stats_y);
3782 rd_stats->rate = mode_rate;
3783
3784 // cost and distortion
3785 av1_subtract_plane(x, bsize, 0);
3786 if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
3787 !xd->lossless[mbmi->segment_id]) {
3788 av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
3789 #if CONFIG_COLLECT_RD_STATS == 2
3790 PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
3791 #endif // CONFIG_COLLECT_RD_STATS == 2
3792 } else {
3793 av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
3794 memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
3795 for (int i = 0; i < xd->height * xd->width; ++i)
3796 set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm);
3797 }
3798
3799 if (rd_stats_y->rate == INT_MAX) return 0;
3800
3801 av1_merge_rd_stats(rd_stats, rd_stats_y);
3802
3803 const int64_t non_skip_txfm_rdcosty =
3804 RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist);
3805 const int64_t skip_txfm_rdcosty =
3806 RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse);
3807 const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
3808 if (min_rdcosty > ref_best_rd) return 0;
3809
3810 av1_init_rd_stats(rd_stats_uv);
3811 const int num_planes = av1_num_planes(cm);
3812 if (num_planes > 1) {
3813 int64_t ref_best_chroma_rd = ref_best_rd;
3814 // Calculate best rd cost possible for chroma
3815 if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
3816 (ref_best_chroma_rd != INT64_MAX)) {
3817 ref_best_chroma_rd = (ref_best_chroma_rd -
3818 AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty));
3819 }
3820 const int is_cost_valid_uv =
3821 av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
3822 if (!is_cost_valid_uv) return 0;
3823 av1_merge_rd_stats(rd_stats, rd_stats_uv);
3824 }
3825
3826 int choose_skip_txfm = rd_stats->skip_txfm;
3827 if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
3828 const int64_t rdcost_no_skip_txfm = RDCOST(
3829 x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0],
3830 rd_stats->dist);
3831 const int64_t rdcost_skip_txfm =
3832 RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse);
3833 if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
3834 }
3835 if (choose_skip_txfm) {
3836 rd_stats_y->rate = 0;
3837 rd_stats_uv->rate = 0;
3838 rd_stats->rate = mode_rate + skip_txfm_cost[1];
3839 rd_stats->dist = rd_stats->sse;
3840 rd_stats_y->dist = rd_stats_y->sse;
3841 rd_stats_uv->dist = rd_stats_uv->sse;
3842 mbmi->skip_txfm = 1;
3843 if (rd_stats->skip_txfm) {
3844 const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
3845 if (tmprd > ref_best_rd) return 0;
3846 }
3847 } else {
3848 rd_stats->rate += skip_txfm_cost[0];
3849 mbmi->skip_txfm = 0;
3850 }
3851
3852 return 1;
3853 }
3854