/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include "av1/common/cfl.h"
#include "av1/common/reconintra.h"
#include "av1/encoder/block.h"
#include "av1/encoder/hybrid_fwd_txfm.h"
#include "av1/common/idct.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/random.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/sorting_network.h"
#include "av1/encoder/tx_prune_model_weights.h"
#include "av1/encoder/tx_search.h"
#include "av1/encoder/txb_rdopt.h"

#define PROB_THRESH_OFFSET_TX_TYPE 100

struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  RD_STATS rd_stats;
  int64_t current_rd;
  int64_t best_rd;
  int exit_early;
  int incomplete_exit;
  FAST_TX_SEARCH_MODE ftxs_mode;
  int skip_trellis;
};

typedef struct {
  int64_t rd;
  int txb_entropy_ctx;
  TX_TYPE tx_type;
} TxCandidateInfo;

// origin_threshold * 128 / 100
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};

// lookup table for predict_skip_txfm
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};

// look-up table for sqrt of number of pixels in a transform block
// rounded up to the nearest integer.
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
                                                     12, 12, 23, 23, 32, 32, 8,
                                                     8,  16, 16, 23, 23 };
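// For example, TX_8X16 covers 128 pixels and sqrt(128) ~= 11.31, which rounds
// up to 12. Note (assumption, inferred from the entries): sizes with a
// 64-point dimension appear to use the 32-coefficient effective area, e.g.
// TX_16X64 is listed as 23, matching sqrt(16 * 32) ~= 22.6.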

static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  const int16_t *diff = x->plane[0].src_diff;
  const uint32_t hash =
      av1_get_crc32c_value(&x->txfm_search_info.mb_rd_record->crc_calculator,
                           (uint8_t *)diff, 2 * rows * cols);
  return (hash << 5) + bsize;
}
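// A sketch of the key layout above (assuming BLOCK_SIZES_ALL fits in 5 bits):
//   key = (crc32c(diff) << 5) | bsize
// The low 5 bits of (hash << 5) are zero, so "+ bsize" behaves like a bitwise
// OR, and two blocks with identical residue but different block sizes map to
// different keys.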

static INLINE int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
                                      const int64_t ref_best_rd,
                                      const uint32_t hash) {
  int32_t match_index = -1;
  if (ref_best_rd != INT64_MAX) {
    for (int i = 0; i < mb_rd_record->num; ++i) {
      const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
      // If there is a match in the mb_rd_record, fetch the RD decision and
      // terminate early.
      if (mb_rd_record->mb_rd_info[index].hash_value == hash) {
        match_index = index;
        break;
      }
    }
  }
  return match_index;
}

static AOM_INLINE void fetch_mb_rd_info(int n4,
                                        const MB_RD_INFO *const mb_rd_info,
                                        RD_STATS *const rd_stats,
                                        MACROBLOCK *const x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  mbmi->tx_size = mb_rd_info->tx_size;
  memcpy(x->txfm_search_info.blk_skip, mb_rd_info->blk_skip,
         sizeof(mb_rd_info->blk_skip[0]) * n4);
  av1_copy(mbmi->inter_tx_size, mb_rd_info->inter_tx_size);
  av1_copy_array(xd->tx_type_map, mb_rd_info->tx_type_map, n4);
  *rd_stats = mb_rd_info->rd_stats;
}

// Compute the pixel domain distortion from diff on all visible 4x4s in the
// transform block.
static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
                                      int blk_row, int blk_col,
                                      const BLOCK_SIZE plane_bsize,
                                      const BLOCK_SIZE tx_bsize,
                                      unsigned int *block_mse_q8) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse =
      aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
  if (block_mse_q8 != NULL) {
    if (visible_cols > 0 && visible_rows > 0)
      *block_mse_q8 =
          (unsigned int)((256 * sse) / (visible_cols * visible_rows));
    else
      *block_mse_q8 = UINT_MAX;
  }
  return sse;
}
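// block_mse_q8 above is the per-pixel MSE in Q8 precision: e.g. sse = 1024
// over an 8x8 visible area gives (256 * 1024) / 64 = 4096, i.e. an MSE of 16.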

// Computes the residual block's SSE and mean on all visible 4x4s in the
// transform block.
static INLINE int64_t pixel_diff_stats(
    MACROBLOCK *x, int plane, int blk_row, int blk_col,
    const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
    unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse = 0;
  int sum = 0;
  sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
  if (visible_cols > 0 && visible_rows > 0) {
    double norm_factor = 1.0 / (visible_cols * visible_rows);
    int sign_sum = sum > 0 ? 1 : -1;
    // Conversion to transform domain
    *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
    *per_px_mean = sign_sum * (*per_px_mean);
    *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
    *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
  } else {
    *block_mse_q8 = UINT_MAX;
  }
  return sse;
}

// Uses simple features on top of DCT coefficients to quickly predict
// whether the optimal RD decision is to skip encoding the residual.
// The sse value is stored in dist.
static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
                             int reduced_tx_set) {
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const MACROBLOCKD *xd = &x->e_mbd;
  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);

  *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);

  const int64_t mse = *dist / bw / bh;
  // Normalized quantizer takes the transform upscaling factor (8 for tx size
  // smaller than 32) into account.
  const int16_t normalized_dc_q = dc_q >> 3;
  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
  // For a faster early skip decision, use dist to compare against the
  // threshold, so that the quality risk of the skip=1 decision is lower.
  // Otherwise, use mse, since the fwd_txfm coefficient checks will take care
  // of quality.
  // TODO(any): Use dist to return 0 when skip_txfm_level is 1
  int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
  // Predict not to skip when error is larger than threshold.
  if (pred_err > mse_thresh) return 0;
  // Return as skip otherwise for aggressive early skip
  else if (txfm_params->skip_txfm_level >= 2)
    return 1;

  const int max_tx_size = max_predict_sf_tx_size[bsize];
  const int tx_h = tx_size_high[max_tx_size];
  const int tx_w = tx_size_wide[max_tx_size];
  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
  TxfmParam param;
  param.tx_type = DCT_DCT;
  param.tx_size = max_tx_size;
  param.bd = xd->bd;
  param.is_hbd = is_cur_buf_hbd(xd);
  param.lossless = 0;
  param.tx_set_type = av1_get_ext_tx_set_type(
      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
  const int16_t *src_diff = x->plane[0].src_diff;
  const int n_coeff = tx_w * tx_h;
  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
  for (int row = 0; row < bh; row += tx_h) {
    for (int col = 0; col < bw; col += tx_w) {
      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
      // Operating on TX domain, not pixels; we want the QTX quantizers
      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
      if (dc_coef >= dc_thresh) return 0;
      for (int i = 1; i < n_coeff; ++i) {
        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
        if (ac_coef >= ac_thresh) return 0;
      }
    }
    src_diff += tx_h * bw;
  }
  return 1;
}
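// Illustrative numbers for the first threshold above (assumption, not taken
// from the source): with dc_q = 80, normalized_dc_q = 80 >> 3 = 10 and
// mse_thresh = 10 * 10 / 8 = 12, so a block whose per-pixel MSE (or total
// dist, when skip_txfm_level >= 2) exceeds 12 is predicted not to skip.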

// Used to set proper context for early termination with skip = 1.
static AOM_INLINE void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
                                     int bsize, int64_t dist) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int n4 = bsize_to_num_blk(bsize);
  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
  memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
  mbmi->tx_size = tx_size;
  for (int i = 0; i < n4; ++i)
    set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1);
  rd_stats->skip_txfm = 1;
  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
  rd_stats->dist = rd_stats->sse = (dist << 4);
  // Though the decision is to mark the block as skip based on luma stats,
  // it is possible that the block becomes non-skip after chroma rd. In
  // addition, the intermediate non-skip costs calculated by the caller would
  // be incorrect if the rate were set to zero (i.e., if zero_blk_rate were
  // not accounted for). Hence an intermediate rate is populated here to code
  // the luma tx blocks as skip; the caller then sets the final rate according
  // to the final rd decision (i.e., skip vs. non-skip). The rate populated
  // here corresponds to coding all the tx blocks with zero_blk_rate (based on
  // the maximum tx size possible) in the current block. E.g., for a 128x128
  // block, the rate would be 4 * zero_blk_rate, where zero_blk_rate
  // corresponds to coding one 64x64 tx block as 'all zeros'.
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
  ENTROPY_CONTEXT *ta = ctxa;
  ENTROPY_CONTEXT *tl = ctxl;
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  TXB_CTX txb_ctx;
  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  rd_stats->rate = zero_blk_rate *
                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
}

static AOM_INLINE void save_mb_rd_info(int n4, uint32_t hash,
                                       const MACROBLOCK *const x,
                                       const RD_STATS *const rd_stats,
                                       MB_RD_RECORD *mb_rd_record) {
  int index;
  if (mb_rd_record->num < RD_RECORD_BUFFER_LEN) {
    index =
        (mb_rd_record->index_start + mb_rd_record->num) % RD_RECORD_BUFFER_LEN;
    ++mb_rd_record->num;
  } else {
    index = mb_rd_record->index_start;
    mb_rd_record->index_start =
        (mb_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
  }
  MB_RD_INFO *const mb_rd_info = &mb_rd_record->mb_rd_info[index];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  mb_rd_info->hash_value = hash;
  mb_rd_info->tx_size = mbmi->tx_size;
  memcpy(mb_rd_info->blk_skip, x->txfm_search_info.blk_skip,
         sizeof(mb_rd_info->blk_skip[0]) * n4);
  av1_copy(mb_rd_info->inter_tx_size, mbmi->inter_tx_size);
  av1_copy_array(mb_rd_info->tx_type_map, xd->tx_type_map, n4);
  mb_rd_info->rd_stats = *rd_stats;
}
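// A note on the ring buffer above: while the record is not full, new entries
// are appended at (index_start + num) % RD_RECORD_BUFFER_LEN; once full, the
// oldest entry at index_start is overwritten and index_start advances, so
// find_mb_rd_info() always scans records from oldest to newest.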

static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
                                 const SPEED_FEATURES *sf,
                                 int tx_size_search_method) {
  if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;

  if (sf->tx_sf.tx_size_search_lgr_block) {
    if (mi_width > mi_size_wide[BLOCK_64X64] ||
        mi_height > mi_size_high[BLOCK_64X64])
      return MAX_VARTX_DEPTH;
  }

  if (is_inter) {
    return (mi_height != mi_width)
               ? sf->tx_sf.inter_tx_size_search_init_depth_rect
               : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
  } else {
    return (mi_height != mi_width)
               ? sf->tx_sf.intra_tx_size_search_init_depth_rect
               : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
  }
}

static AOM_INLINE void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode);

// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
// 0: Do not collect any RD stats
// 1: Collect RD stats for transform units
// 2: Collect RD stats for partition units
#if CONFIG_COLLECT_RD_STATS

static AOM_INLINE void get_energy_distribution_fine(
    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
    double *verdist) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
    // functions for the 16 (very small) sub-blocks of this block.
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    assert(bw <= 32);
    assert(bh <= 32);
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    if (cpi->common.seq_params->use_highbitdepth) {
      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] +=
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
        }
    } else {
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
        }
    }
  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
    const int f_index =
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[1]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[2]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[3]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[5]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[6]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[7]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[9]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[10]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[11]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[13]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[14]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[15]);
  }

  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
                 esq[12] + esq[13] + esq[14] + esq[15];
  if (total > 0) {
    const double e_recip = 1.0 / total;
    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
    if (need_4th) {
      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
    }
    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
    if (need_4th) {
      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
    }
  } else {
    hordist[0] = verdist[0] = 0.25;
    hordist[1] = verdist[1] = 0.25;
    hordist[2] = verdist[2] = 0.25;
    if (need_4th) {
      hordist[3] = verdist[3] = 0.25;
    }
  }
}
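// In the row-major 4x4 esq grid above, hordist[i] sums column i across the
// four rows and verdist[i] sums row i, each normalized by the total energy,
// giving the horizontal and vertical distribution of residual energy over
// the block.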

static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int err = diff[j * stride + i];
      sum += err * err;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += abs(diff[j * stride + i]);
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static AOM_INLINE void get_2x2_normalized_sses_and_sads(
    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
    int src_stride, const uint8_t *const dst, int dst_stride,
    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
    double *const sad_norm_arr) {
  const BLOCK_SIZE tx_bsize_half =
      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
  if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
    const int half_width = block_size_wide[tx_bsize] / 2;
    const int half_height = block_size_high[tx_bsize] / 2;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const int16_t *const this_src_diff =
            src_diff + row * half_height * diff_stride + col * half_width;
        if (sse_norm_arr) {
          sse_norm_arr[row * 2 + col] =
              get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
        }
        if (sad_norm_arr) {
          sad_norm_arr[row * 2 + col] =
              get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
        }
      }
    }
  } else {  // use function pointers to calculate stats
    const int half_width = block_size_wide[tx_bsize_half];
    const int half_height = block_size_high[tx_bsize_half];
    const int num_samples_half = half_width * half_height;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const uint8_t *const this_src =
            src + row * half_height * src_stride + col * half_width;
        const uint8_t *const this_dst =
            dst + row * half_height * dst_stride + col * half_width;

        if (sse_norm_arr) {
          unsigned int this_sse;
          cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
                                             dst_stride, &this_sse);
          sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
        }

        if (sad_norm_arr) {
          const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
              this_src, src_stride, this_dst, dst_stride);
          sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
        }
      }
    }
  }
}

#if CONFIG_COLLECT_RD_STATS == 1
static double get_mean(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += diff[j * stride + i];
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}
static AOM_INLINE void PrintTransformUnitStats(
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    TX_TYPE tx_type, int64_t rd) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  // Generate small sample to restrict output size.
  static unsigned int seed = 21743;
  if (lcg_rand16(&seed) % 256 > 0) return;

  const char output_file[] = "tu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  const int txw = tx_size_wide[tx_size];
  const int txh = tx_size_high[tx_size];
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int num_samples = txw * txh;

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;

  fprintf(fout, "%g %g", rate_norm, dist_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src =
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  unsigned int sse;
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm = (double)sad / num_samples;

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *const src_diff =
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
  const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];

  fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
          tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);

  const double mean = get_mean(src_diff, diff_stride, txw, txh);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
                               1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  fprintf(fout, " %d %" PRId64, x->rdmult, rd);

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if CONFIG_COLLECT_RD_STATS >= 2
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const MACROBLOCKD *xd = &x->e_mbd;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  int64_t total_sse = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const struct macroblock_plane *const p = &x->plane[plane];
    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE bs =
        get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
    unsigned int sse;

    if (x->skip_chroma_rd && plane) continue;

    cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
                            pd->dst.stride, &sse);
    total_sse += sse;
  }
  total_sse <<= 4;
  return total_sse;
}

static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (md->ready) {
    if (sse < md->dist_mean) {
      *est_residue_cost = 0;
      *est_dist = sse;
    } else {
      *est_dist = (int64_t)round(md->dist_mean);
      const double est_ld = md->a * sse + md->b;
      // Clamp estimated rate cost by INT_MAX / 2.
      // TODO(angiebird@google.com): find better solution than clamping.
      if (fabs(est_ld) < 1e-2) {
        *est_residue_cost = INT_MAX / 2;
      } else {
        double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
        if (est_residue_cost_dbl < 0) {
          *est_residue_cost = 0;
        } else {
          *est_residue_cost =
              (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
        }
      }
      if (*est_residue_cost <= 0) {
        *est_residue_cost = 0;
        *est_dist = sse;
      }
    }
    return 1;
  }
  return 0;
}

static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
                                   const uint8_t *dst8, int dst_stride, int w,
                                   int h) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_diff_mean(const uint8_t *src, int src_stride,
                            const uint8_t *dst, int dst_stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static AOM_INLINE void PrintPredictionUnitStats(const AV1_COMP *const cpi,
                                                const TileDataEnc *tile_data,
                                                MACROBLOCK *x,
                                                const RD_STATS *const rd_stats,
                                                BLOCK_SIZE plane_bsize) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
      (tile_data == NULL ||
       !tile_data->inter_mode_rd_models[plane_bsize].ready))
    return;
  (void)tile_data;
  // Generate small sample to restrict output size.
  static unsigned int seed = 95014;

  if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
      1)
    return;

  const char output_file[] = "pu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];
  const int diff_stride = block_size_wide[plane_bsize];
  int bw, bh;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
                     &bh);
  const int num_samples = bw * bh;
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int shift = (xd->bd - 8);

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;
  const double rdcost_norm =
      (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;

  fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src = p->src.buf;
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst = pd->dst.buf;
  const int16_t *const src_diff = p->src_diff;

  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm =
      (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  if (shift) {
    for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
    for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rdcost_norm =
      (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
          model_rdcost_norm);

  double mean;
  if (is_cur_buf_hbd(xd)) {
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh);
  } else {
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
  }
  mean /= (1 << shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    const int64_t overall_sse = get_sse(cpi, x);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
                      &est_dist);
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
    const double est_dist_norm = (double)est_dist / num_samples;
    const double est_rdcost_norm =
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
            est_rdcost_norm);
  }

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS >= 2
#endif  // CONFIG_COLLECT_RD_STATS

static AOM_INLINE void inverse_transform_block_facade(MACROBLOCK *const x,
                                                      int plane, int block,
                                                      int blk_row, int blk_col,
                                                      int eob,
                                                      int reduced_tx_set) {
  if (!eob) return;
  struct macroblock_plane *const p = &x->plane[plane];
  MACROBLOCKD *const xd = &x->e_mbd;
  tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
                                          tx_size, reduced_tx_set);

  struct macroblockd_plane *const pd = &xd->plane[plane];
  const int dst_stride = pd->dst.stride;
  uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                              dst_stride, eob, reduced_tx_set);
}

static INLINE void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx, int skip_trellis,
                               TX_TYPE best_tx_type, int do_quant,
                               int *rate_cost, uint16_t best_eob) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  if (!is_inter && best_eob &&
      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
    // If the quantized coefficients are stored in the dqcoeff buffer, we don't
    // need to do transform and quantization again.
    if (do_quant) {
      TxfmParam txfm_param_intra;
      QUANT_PARAM quant_param_intra;
      av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
      av1_setup_quant(tx_size, !skip_trellis,
                      skip_trellis
                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
                                                    : AV1_XFORM_QUANT_FP)
                          : AV1_XFORM_QUANT_FP,
                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
                        &quant_param_intra);
      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
                      &txfm_param_intra, &quant_param_intra);
      if (quant_param_intra.use_optimize_b) {
        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                       rate_cost);
      }
    }

    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
                                   x->plane[plane].eobs[block],
                                   cm->features.reduced_tx_set_used);

    // This may happen because of hash collision. The eob stored in the hash
    // table is non-zero, but the real eob is zero. We need to make sure tx_type
    // is DCT_DCT in this case.
    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
        best_tx_type != DCT_DCT) {
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    }
  }
}
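// Note (reading of the condition above): recon_intra() reconstructs the block
// because intra prediction of subsequent tx blocks in the same partition
// reads these reconstructed pixels; the early-out skips the work when no
// block below or to the right remains to be predicted from this one.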

static unsigned pixel_dist_visible_only(
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    const int src_stride, const uint8_t *dst, const int dst_stride,
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    int visible_cols) {
  unsigned sse;

  if (txb_rows == visible_rows && txb_cols == visible_cols) {
    cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
    return sse;
  }

#if CONFIG_AV1_HIGHBITDEPTH
  const MACROBLOCKD *xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
                                             visible_cols, visible_rows);
    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
  }
#else
  (void)x;
#endif
  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
                         visible_rows);
  return sse;
}

// Compute the pixel domain distortion from src and dst on all visible 4x4s
// in the transform block.
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
                           int plane, const uint8_t *src, const int src_stride,
                           const uint8_t *dst, const int dst_stride,
                           int blk_row, int blk_col,
                           const BLOCK_SIZE plane_bsize,
                           const BLOCK_SIZE tx_bsize) {
  int txb_rows, txb_cols, visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;

  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
  assert(visible_rows > 0);
  assert(visible_cols > 0);

  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
                                         dst_stride, tx_bsize, txb_rows,
                                         txb_cols, visible_rows, visible_cols);

  return sse;
}

static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
                                           int plane, BLOCK_SIZE plane_bsize,
                                           int block, int blk_row, int blk_col,
                                           TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const uint16_t eob = p->eobs[block];
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const int bsw = block_size_wide[tx_bsize];
  const int bsh = block_size_high[tx_bsize];
  const int src_stride = x->plane[plane].src.stride;
  const int dst_stride = xd->plane[plane].dst.stride;
  // Scale the transform block index to pixel unit.
  const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
  const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
  const uint8_t *src = &x->plane[plane].src.buf[src_idx];
  const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
  const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);

  assert(cpi != NULL);
  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);

  uint8_t *recon;
  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);

#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    recon = CONVERT_TO_BYTEPTR(recon16);
    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
  } else {
    recon = (uint8_t *)recon16;
    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
  }
#else
  recon = (uint8_t *)recon16;
  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
#endif

  const PLANE_TYPE plane_type = get_plane_type(plane);
  TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                                    cpi->common.features.reduced_tx_set_used);
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
                              MAX_TX_SIZE, eob,
                              cpi->common.features.reduced_tx_set_used);

  return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
                         blk_row, blk_col, plane_bsize, tx_bsize);
}
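// The factor of 16 above keeps the pixel-domain distortion on the same scale
// as the shifted tx-domain distortion produced by dist_block_tx_domain()
// below, so the two estimates are directly comparable during the search.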

// pruning thresholds for prune_txk_type and prune_txk_type_separ
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100

// R-D costs are sorted in ascending order.
static INLINE void sort_rd(int64_t rds[], int txk[], int len) {
  int i, j, k;

  for (i = 1; i <= len - 1; ++i) {
    for (j = 0; j < i; ++j) {
      if (rds[j] > rds[i]) {
        int64_t temprd;
        int tempi;

        temprd = rds[i];
        tempi = txk[i];

        for (k = i; k > j; k--) {
          rds[k] = rds[k - 1];
          txk[k] = txk[k - 1];
        }

        rds[j] = temprd;
        txk[j] = tempi;
        break;
      }
    }
  }
}
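// Example: rds = {50, 20, 40}, txk = {0, 1, 2} sorts to rds = {20, 40, 50},
// txk = {1, 2, 0}; txk[] is permuted in lock-step so the tx-type ranking
// follows the sorted costs.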

static INLINE int64_t av1_block_error_qm(const tran_low_t *coeff,
                                         const tran_low_t *dqcoeff,
                                         intptr_t block_size,
                                         const qm_val_t *qmatrix,
                                         const int16_t *scan, int64_t *ssz) {
  int i;
  int64_t error = 0, sqcoeff = 0;

  for (i = 0; i < block_size; i++) {
    int64_t weight = qmatrix[scan[i]];
    int64_t dd = coeff[i] - dqcoeff[i];
    dd *= weight;
    int64_t cc = coeff[i];
    cc *= weight;
    // The ranges of coeff and dqcoeff are
    //  bd8 : 18 bits (including sign)
    //  bd10: 20 bits (including sign)
    //  bd12: 22 bits (including sign)
    // As AOM_QM_BITS is 5, the intermediate quantities in the calculation
    // below should fit in 54 bits, thus no overflow should happen.
    error += (dd * dd + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
    sqcoeff += (cc * cc + (1 << (2 * AOM_QM_BITS - 1))) >> (2 * AOM_QM_BITS);
  }

  *ssz = sqcoeff;
  return error;
}
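// Worked example of the rounding above (illustrative, assuming a flat qmatrix
// weight of 32, i.e. unit weight in Q5 since AOM_QM_BITS is 5): dd * dd then
// carries a factor of 32 * 32 = 1024, and adding 1 << 9 before shifting right
// by 10 rounds the weighted squared error to the nearest unweighted value.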

static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
                                        TX_SIZE tx_size,
                                        const qm_val_t *qmatrix,
                                        const int16_t *scan, int64_t *out_dist,
                                        int64_t *out_sse) {
  const struct macroblock_plane *const p = &x->plane[plane];
  // Transform domain distortion computation is more efficient as it does
  // not involve an inverse transform, but it is less accurate.
  const int buffer_length = av1_get_max_eob(tx_size);
  int64_t this_sse;
  // TX-domain results need to shift down to Q2/D10 to match pixel
  // domain distortion values which are in Q2^2
  int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
  const int block_offset = BLOCK_OFFSET(block);
  tran_low_t *const coeff = p->coeff + block_offset;
  tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
#if CONFIG_AV1_HIGHBITDEPTH
  MACROBLOCKD *const xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    // TODO(veluca): handle use_qm_dist_metric for HBD too.
    *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
                                       xd->bd);
  } else {
#endif
    if (qmatrix == NULL || !x->txfm_search_params.use_qm_dist_metric) {
      *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
    } else {
      *out_dist = av1_block_error_qm(coeff, dqcoeff, buffer_length, qmatrix,
                                     scan, &this_sse);
    }
#if CONFIG_AV1_HIGHBITDEPTH
  }
#endif

  *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
  *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
}

uint16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                              int block, TX_SIZE tx_size, int blk_row,
                              int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
                              int16_t allowed_tx_mask, int prune_factor,
                              const TXB_CTX *const txb_ctx,
                              int reduced_tx_set_used, int64_t ref_best_rd,
                              int num_sel) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;

  int idx;

  int64_t rds_v[4];
  int64_t rds_h[4];
  int idx_v[4] = { 0, 1, 2, 3 };
  int idx_h[4] = { 0, 1, 2, 3 };
  int skip_v[4] = { 0 };
  int skip_h[4] = { 0 };
  const int idx_map[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };

  const int sel_pattern_v[16] = {
    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
  };
  const int sel_pattern_h[16] = {
    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
  };

  QUANT_PARAM quant_param;
  TxfmParam txfm_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);
  int tx_type;
  // Use the full 16-type set to ensure we can try tx types even outside the
  // ext_tx_set of the current block; this function should only be called for
  // tx sizes up to 16x16.
  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
  txfm_param.tx_set_type = EXT_TX_SET_ALL16;

  int rate_cost = 0;
  int64_t dist = 0, sse = 0;
  // evaluate horizontal with vertical DCT
  for (idx = 0; idx < 4; ++idx) {
    tx_type = idx_map[idx];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
      skip_h[idx] = 1;
    }
  }
  sort_rd(rds_h, idx_h, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
  }

  if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;

  // evaluate vertical with the best horizontal chosen
  rds_v[0] = rds_h[0];
  int start_v = 1, end_v = 4;
  const int *idx_map_v = idx_map + idx_h[0];

  for (idx = start_v; idx < end_v; ++idx) {
    tx_type = idx_map_v[idx_v[idx] * 4];
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
      skip_v[idx] = 1;
    }
  }
  sort_rd(rds_v, idx_v, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
  }

  // combine rd_h and rd_v to prune tx candidates
  int i_v, i_h;
  int64_t rds[16];
  int num_cand = 0, last = TX_TYPES - 1;

  for (int i = 0; i < 16; i++) {
    i_v = sel_pattern_v[i];
    i_h = sel_pattern_h[i];
    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
        skip_v[idx_v[i_v]]) {
      txk_map[last] = tx_type;
      last--;
    } else {
      txk_map[num_cand] = tx_type;
      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
      if (rds[num_cand] == 0) rds[num_cand] = 1;
      num_cand++;
    }
  }
  sort_rd(rds, txk_map, num_cand);

  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
  num_sel = AOMMIN(num_sel, num_cand);

  for (int i = 1; i < num_sel; i++) {
    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[i]);
    else
      break;
  }
  return prune;
}
1260 
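// Prunes transform types by explicitly evaluating every type allowed by
// allowed_tx_mask: each candidate is transformed and quantized, its rate is
// estimated with a Laplacian model, and its distortion is measured in the
// transform domain. Candidates are sorted by RD cost into txk_map, and a bit
// mask of pruned types is returned (a set bit means the type is pruned). The
// best candidate is always kept, as is any candidate whose RD cost is within
// prune_factor / 1000 of the best; e.g. prune_factor = 300 keeps candidates
// within 30% of the best RD cost.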
uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                        int block, TX_SIZE tx_size, int blk_row, int blk_col,
                        BLOCK_SIZE plane_bsize, int *txk_map,
                        uint16_t allowed_tx_mask, int prune_factor,
                        const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  int tx_type;

  int64_t rds[TX_TYPES];

  int num_cand = 0;
  int last = TX_TYPES - 1;

  TxfmParam txfm_param;
  QUANT_PARAM quant_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);

  for (int idx = 0; idx < TX_TYPES; idx++) {
    tx_type = idx;
    int rate_cost = 0;
    int64_t dist = 0, sse = 0;
    if (!(allowed_tx_mask & (1 << tx_type))) {
      txk_map[last] = tx_type;
      last--;
      continue;
    }
    txfm_param.tx_type = tx_type;

    av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                      &quant_param);

    // Do txfm and quantization.
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);
    // Estimate rate cost.
    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);
    // Transform domain distortion.
    const SCAN_ORDER *const scan_order =
        get_scan(txfm_param.tx_size, txfm_param.tx_type);
    dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                         scan_order->scan, &dist, &sse);

    txk_map[num_cand] = tx_type;
    rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
    if (rds[num_cand] == 0) rds[num_cand] = 1;
    num_cand++;
  }

  if (num_cand == 0) return (uint16_t)0xFFFF;

  sort_rd(rds, txk_map, num_cand);
  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));

  // 0 < prune_factor <= 1000 controls aggressiveness.
  int64_t factor = 0;
  for (int idx = 1; idx < num_cand; idx++) {
    factor = 1000 * (rds[idx] - rds[0]) / rds[0];
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[idx]);
    else
      break;
  }
  return prune;
}

// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average.
static const float *prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
             0.09778f, 0.11780f },
  // TX_8X8
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
             0.10803f, 0.14124f },
  // TX_16X16
  (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
             0.10168f, 0.12585f },
  // TX_8X4
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
             0.10583f, 0.13123f },
  // TX_8X16
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
             0.10730f, 0.14221f },
  // TX_16X8
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
             0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
             0.10242f, 0.12878f },
  // TX_16X4
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
             0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};

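// Maps the pruning mode and tx set type to a pruning aggressiveness level and
// returns the corresponding probability threshold from
// prune_2D_adaptive_thresholds for the given tx size.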
static INLINE float get_adaptive_thresholds(
    TX_SIZE tx_size, TxSetType tx_set_type,
    TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
  const int prune_aggr_table[5][2] = {
    { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
  };
  int pruning_aggressiveness = 0;
  if (tx_set_type == EXT_TX_SET_ALL16)
    pruning_aggressiveness =
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0];
  else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
    pruning_aggressiveness =
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1];

  return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
}

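// Computes the normalized horizontal (hordist) and vertical (verdist) energy
// distributions of the residual block. Blocks wider or taller than 8 are
// downscaled by 2 in that dimension first, so at most 16 projection bins are
// produced per direction; only the first esq_w - 1 (resp. esq_h - 1) entries
// of each output array are written.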
static AOM_INLINE void get_energy_distribution_finer(const int16_t *diff,
                                                     int stride, int bw, int bh,
                                                     float *hordist,
                                                     float *verdist) {
  // First compute downscaled block energy values (esq); downscale factors
  // are defined by w_shift and h_shift.
  unsigned int esq[256];
  const int w_shift = bw <= 8 ? 0 : 1;
  const int h_shift = bh <= 8 ? 0 : 1;
  const int esq_w = bw >> w_shift;
  const int esq_h = bh >> h_shift;
  const int esq_sz = esq_w * esq_h;
  int i, j;
  memset(esq, 0, esq_sz * sizeof(esq[0]));
  if (w_shift) {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j += 2) {
        cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
                                cur_diff_row[j + 1] * cur_diff_row[j + 1]);
      }
    }
  } else {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j++) {
        cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
      }
    }
  }

  uint64_t total = 0;
  for (i = 0; i < esq_sz; i++) total += esq[i];

  // Output hordist and verdist arrays are normalized 1D projections of esq.
  if (total == 0) {
    float hor_val = 1.0f / esq_w;
    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
    float ver_val = 1.0f / esq_h;
    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
    return;
  }

  const float e_recip = 1.0f / (float)total;
  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
  const unsigned int *cur_esq_row;
  for (i = 0; i < esq_h - 1; i++) {
    cur_esq_row = esq + i * esq_w;
    for (j = 0; j < esq_w - 1; j++) {
      hordist[j] += (float)cur_esq_row[j];
      verdist[i] += (float)cur_esq_row[j];
    }
    verdist[i] += (float)cur_esq_row[j];
  }
  cur_esq_row = esq + i * esq_w;
  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];

  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
}

static AOM_INLINE bool check_bit_mask(uint16_t mask, int val) {
  return mask & (1 << val);
}

static AOM_INLINE void set_bit_mask(uint16_t *mask, int val) {
  *mask |= (1 << val);
}

static AOM_INLINE void unset_bit_mask(uint16_t *mask, int val) {
  *mask &= ~(1 << val);
}

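// Prunes 2D transform types with two small neural networks that predict, from
// the residual energy distribution and horizontal/vertical correlation, a
// probability score for each of the four 1D transforms in each direction. The
// 16 products of vertical and horizontal scores are softmax-normalized and
// compared against an adaptive threshold; surviving types are written to
// txk_map in descending score order and *allowed_tx_mask is updated in place.
// For TX_TYPE_PRUNE_4 and above, types beyond 30% cumulative probability
// (once at least two types are kept) are additionally pruned.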
static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                        int blk_row, int blk_col, TxSetType tx_set_type,
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
                        uint16_t *allowed_tx_mask) {
  // This table is used because the search order is different from the enum
  // order.
  static const int tx_type_table_2D[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };
  if (tx_set_type != EXT_TX_SET_ALL16 &&
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
    return;
#if CONFIG_NN_V2
  NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#else
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#endif
  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.

  float hfeatures[16], vfeatures[16];
  float hscores[4], vscores[4];
  float scores_2D_raw[16];
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
  assert(hfeatures_num <= 16);
  assert(vfeatures_num <= 16);

  const struct macroblock_plane *const p = &x->plane[0];
  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
                                vfeatures);

  av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
                                  &hfeatures[hfeatures_num - 1],
                                  &vfeatures[vfeatures_num - 1]);

#if CONFIG_NN_V2
  av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
  av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
#else
  av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
  av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
#endif

  for (int i = 0; i < 4; i++) {
    float *cur_scores_2D = scores_2D_raw + i * 4;
    cur_scores_2D[0] = vscores[i] * hscores[0];
    cur_scores_2D[1] = vscores[i] * hscores[1];
    cur_scores_2D[2] = vscores[i] * hscores[2];
    cur_scores_2D[3] = vscores[i] * hscores[3];
  }

  assert(TX_TYPES == 16);
  // This version of the function only works when there are at most 16 classes.
  // So we will need to change the optimization or use av1_nn_softmax instead if
  // this ever gets changed.
  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);

  const float score_thresh =
      get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);

  // Always keep the TX type with the highest score, prune all others with
  // score below score_thresh.
  int max_score_i = 0;
  float max_score = 0.0f;
  uint16_t allow_bitmask = 0;
  float sum_score = 0.0;
  // Calculate the sum of the allowed tx type scores, and populate the allow
  // bit mask based on score_thresh and allowed_tx_mask.
  int allow_count = 0;
  int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID };
  float scores_2D[16] = {
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
  };
  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
    const int allow_tx_type =
        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
    if (!allow_tx_type) {
      continue;
    }
    if (scores_2D_raw[tx_idx] > max_score) {
      max_score = scores_2D_raw[tx_idx];
      max_score_i = tx_idx;
    }
    if (scores_2D_raw[tx_idx] >= score_thresh) {
      // Set allow mask based on score_thresh.
      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);

      // Accumulate score of allowed tx type.
      sum_score += scores_2D_raw[tx_idx];

      scores_2D[allow_count] = scores_2D_raw[tx_idx];
      tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
      allow_count += 1;
    }
  }
  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
    // If even the tx_type with max score is pruned, this means that no other
    // tx_type is feasible. When this happens, we force enable max_score_i and
    // end the search.
    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
    *allowed_tx_mask = allow_bitmask;
    return;
  }

  // Sort tx type probability of all types.
  if (allow_count <= 8) {
    av1_sort_fi32_8(scores_2D, tx_type_allowed);
  } else {
    av1_sort_fi32_16(scores_2D, tx_type_allowed);
  }

  // Enable more pruning based on tx type probability and number of allowed tx
  // types.
  if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
    float temp_score = 0.0;
    float score_ratio = 0.0;
    int tx_idx, tx_count = 0;
    const float inv_sum_score = 100 / sum_score;
    // Get allowed tx types based on sorted probability score and tx count.
    for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
      // Stop once the cumulative probability exceeds 30% and at least two
      // tx types have been kept; the remaining types are pruned below.
      if (score_ratio > 30.0 && tx_count >= 2) break;

      assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
      // Calculate cumulative probability.
      temp_score += scores_2D[tx_idx];

      // Calculate percentage of cumulative probability of allowed tx type.
      score_ratio = temp_score * inv_sum_score;
      tx_count++;
    }
    // Set remaining tx types as pruned.
    for (; tx_idx < allow_count; tx_idx++)
      unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
  }

  memcpy(txk_map, tx_type_allowed, sizeof(tx_type_table_2D));
  *allowed_tx_mask = allow_bitmask;
}

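// Returns the standard deviation given the mean and the sum of squared
// values: sqrt(E[x^2] - mean^2), clamped to zero if floating point rounding
// makes the variance negative.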
static float get_dev(float mean, double x2_sum, int num) {
  const float e_x2 = (float)(x2_sum / num);
  const float diff = e_x2 - mean * mean;
  const float dev = (diff > 0) ? sqrtf(diff) : 0;
  return dev;
}

// Writes the features required by the ML model to predict tx split, based on
// the mean and standard deviation values of the block and its sub-blocks.
// Returns the number of elements written to the output array, which is at most
// 12 currently; hence the 'features' buffer should be able to accommodate at
// least 12 elements.
static AOM_INLINE int get_mean_dev_features(const int16_t *data, int stride,
                                            int bw, int bh, float *features) {
  const int16_t *const data_ptr = &data[0];
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
  const int num = bw * bh;
  const int sub_num = subw * subh;
  int feature_idx = 2;
  int total_x_sum = 0;
  int64_t total_x2_sum = 0;
  int num_sub_blks = 0;
  double mean2_sum = 0.0f;
  float dev_sum = 0.0f;

  for (int row = 0; row < bh; row += subh) {
    for (int col = 0; col < bw; col += subw) {
      int x_sum;
      int64_t x2_sum;
      // TODO(any): Write a SIMD version. Clear registers.
      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
                          &x_sum, &x2_sum);
      total_x_sum += x_sum;
      total_x2_sum += x2_sum;

      const float mean = (float)x_sum / sub_num;
      const float dev = get_dev(mean, (double)x2_sum, sub_num);
      features[feature_idx++] = mean;
      features[feature_idx++] = dev;
      mean2_sum += (double)(mean * mean);
      dev_sum += dev;
      num_sub_blks++;
    }
  }

  const float lvl0_mean = (float)total_x_sum / num;
  features[0] = lvl0_mean;
  features[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);

  // Deviation of means.
  features[feature_idx++] = get_dev(lvl0_mean, mean2_sum, num_sub_blks);
  // Mean of deviations.
  features[feature_idx++] = dev_sum / num_sub_blks;

  return feature_idx;
}

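// Scores the split-vs-no-split decision for a tx block with a neural network
// fed by the mean/deviation features above. Returns the score scaled by 10000
// and clamped to [-80000, 80000], or -1 if no model exists for this tx size.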
static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
                               int blk_col, TX_SIZE tx_size) {
  const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
  if (!nn_config) return -1;

  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff =
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];

  float features[64] = { 0.0f };
  get_mean_dev_features(diff, diff_stride, bw, bh, features);

  float score = 0.0f;
  av1_nn_predict(features, nn_config, 1, &score);

  int int_score = (int)(score * 10000);
  return clamp(int_score, -80000, 80000);
}

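// Computes the bit mask of transform types to evaluate in the RD search for
// this block (1 = evaluate, 0 = skip), after applying the various speed
// features: default/forced tx types, reduced tx sets, DCT-only modes,
// frame-level tx type statistics, and the estimation based (prune_txk_type*)
// and ML based (prune_tx_2D) pruning. On return, *allowed_txk_types is either
// the single allowed type, or TX_TYPES if more than one type remains.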
static INLINE uint16_t
get_tx_mask(const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block,
            int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
            const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
            int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int is_inter = is_inter_block(mbmi);
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
  // If txk_allowed == TX_TYPES, more than one tx type is allowed; otherwise
  // (txk_allowed < TX_TYPES) only that specific tx type is allowed.
  TX_TYPE txk_allowed = TX_TYPES;

  const FRAME_UPDATE_TYPE update_type =
      get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
  int use_actual_frame_probs = 1;
  const int *tx_type_probs;
#if CONFIG_FPMT_TEST
  use_actual_frame_probs =
      (cpi->ppi->fpmt_unit_test_cfg == PARALLEL_SIMULATION_ENCODE) ? 0 : 1;
  if (!use_actual_frame_probs) {
    tx_type_probs =
        (int *)cpi->ppi->temp_frame_probs.tx_type_probs[update_type][tx_size];
  }
#endif
  if (use_actual_frame_probs) {
    tx_type_probs = cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size];
  }

  if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
      (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) {
    txk_allowed =
        get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
  } else if (is_inter &&
             txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) {
    if (tx_type_probs[DEFAULT_INTER_TX_TYPE] >
        txfm_params->default_inter_tx_type_prob_thresh) {
      txk_allowed = DEFAULT_INTER_TX_TYPE;
    } else {
      int force_tx_type = 0;
      int max_prob = 0;
      const int tx_type_prob_threshold =
          txfm_params->default_inter_tx_type_prob_thresh +
          PROB_THRESH_OFFSET_TX_TYPE;
      for (int i = 1; i < TX_TYPES; i++) {  // find maximum probability.
        if (tx_type_probs[i] > max_prob) {
          max_prob = tx_type_probs[i];
          force_tx_type = i;
        }
      }
      if (max_prob > tx_type_prob_threshold)  // force tx type with max prob.
        txk_allowed = force_tx_type;
      else if (x->rd_model == LOW_TXFM_RD) {
        if (plane == 0) txk_allowed = DCT_DCT;
      }
    }
  } else if (x->rd_model == LOW_TXFM_RD) {
    if (plane == 0) txk_allowed = DCT_DCT;
  }

  const TxSetType tx_set_type = av1_get_ext_tx_set_type(
      tx_size, is_inter, cm->features.reduced_tx_set_used);

  TX_TYPE uv_tx_type = DCT_DCT;
  if (plane) {
    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y.
    uv_tx_type = txk_allowed =
        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
                        cm->features.reduced_tx_set_used);
  }
  PREDICTION_MODE intra_dir =
      mbmi->filter_intra_mode_info.use_filter_intra
          ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
          : mbmi->mode;
  uint16_t ext_tx_used_flag =
      cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 &&
              tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
          ? av1_reduced_intra_tx_used_flag[intra_dir]
          : av1_ext_tx_used_flag[tx_set_type];

  if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2)
    ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir];

  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
      ext_tx_used_flag == 0x0001 ||
      (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
      (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
    txk_allowed = DCT_DCT;
  }

  if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
    ext_tx_used_flag &= DCT_ADST_TX_MASK;

  uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
  if (txk_allowed < TX_TYPES) {
    allowed_tx_mask = 1 << txk_allowed;
    allowed_tx_mask &= ext_tx_used_flag;
  } else if (fast_tx_search) {
    allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
    allowed_tx_mask &= ext_tx_used_flag;
  } else {
    assert(plane == 0);
    allowed_tx_mask = ext_tx_used_flag;
    int num_allowed = 0;
    int i;

    if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
      static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
                                            { 10, 17, 17, 10, 17, 17, 17 } };
      const int thresh =
          thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
                    [update_type];
      uint16_t prune = 0;
      int max_prob = -1;
      int max_idx = 0;
      for (i = 0; i < TX_TYPES; i++) {
        if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
          max_prob = tx_type_probs[i];
          max_idx = i;
        }
        if (tx_type_probs[i] < thresh) prune |= (1 << i);
      }
      if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
      allowed_tx_mask &= (~prune);
    }
    for (i = 0; i < TX_TYPES; i++) {
      if (allowed_tx_mask & (1 << i)) num_allowed++;
    }
    assert(num_allowed > 0);

    if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
      int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
      int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
      if (num_allowed <= 7) {
        const uint16_t prune =
            prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
                           plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
                           cm->features.reduced_tx_set_used);
        allowed_tx_mask &= (~prune);
      } else {
        const int num_sel = (num_allowed * mf + 50) / 100;
        const uint16_t prune = prune_txk_type_separ(
            cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
            txk_map, allowed_tx_mask, pf, txb_ctx,
            cm->features.reduced_tx_set_used, ref_best_rd, num_sel);

        allowed_tx_mask &= (~prune);
      }
    } else {
      assert(num_allowed > 0);
      int allowed_tx_count =
          (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
      // !fast_tx_search && txk_end != txk_start && plane == 0
      if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
          num_allowed > allowed_tx_count) {
        prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
                    txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
      }
    }
  }

  // Need to have at least one transform type allowed.
  if (allowed_tx_mask == 0) {
    txk_allowed = (plane ? uv_tx_type : DCT_DCT);
    allowed_tx_mask = (1 << txk_allowed);
  }

  assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
  *allowed_txk_types = txk_allowed;
  return allowed_tx_mask;
}

#if CONFIG_RD_DEBUG
static INLINE void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
                                         int txb_coeff_cost) {
  rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;
}
#endif

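// Thin wrapper around av1_cost_coeffs_txb that, when TXCOEFF_COST_TIMER is
// enabled, also accumulates the time spent in coefficient costing.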
static INLINE int cost_coeffs(MACROBLOCK *x, int plane, int block,
                              TX_SIZE tx_size, const TX_TYPE tx_type,
                              const TXB_CTX *const txb_ctx,
                              int reduced_tx_set_used) {
#if TXCOEFF_COST_TIMER
  struct aom_usec_timer timer;
  aom_usec_timer_start(&timer);
#endif
  const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
                                       txb_ctx, reduced_tx_set_used);
#if TXCOEFF_COST_TIMER
  AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
  aom_usec_timer_mark(&timer);
  const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
  tmp_cm->txcoeff_cost_timer += elapsed_time;
  ++tmp_cm->txcoeff_cost_count;
#endif
  return cost;
}

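// Decides whether trellis optimization of the coefficients should be skipped
// for this block, based on whether the bit-depth normalized SATD of the
// transform coefficients exceeds coeff_opt_satd_threshold * qstep *
// sqrt(num_pixels). Also reconfigures quant_param accordingly. Returns
// nonzero if trellis should be skipped.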
static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
                                          QUANT_PARAM *quant_param, int plane,
                                          int block, TX_SIZE tx_size,
                                          int quant_b_adapt, int qstep,
                                          unsigned int coeff_opt_satd_threshold,
                                          int skip_trellis, int dc_only_blk) {
  if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX))
    return skip_trellis;

  const struct macroblock_plane *const p = &x->plane[plane];
  const int block_offset = BLOCK_OFFSET(block);
  tran_low_t *const coeff_ptr = p->coeff + block_offset;
  const int n_coeffs = av1_get_max_eob(tx_size);
  const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
  int satd = (dc_only_blk) ? abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs);
  satd = RIGHT_SIGNED_SHIFT(satd, shift);
  satd >>= (x->e_mbd.bd - 8);

  const int skip_block_trellis =
      ((uint64_t)satd >
       (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]);

  av1_setup_quant(
      tx_size, !skip_block_trellis,
      skip_block_trellis
          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP)
          : AV1_XFORM_QUANT_FP,
      quant_b_adapt, quant_param);

  return skip_block_trellis;
}

// Predict DC-only blocks if the residual variance is below a qstep-based
// threshold. For such blocks, transform type search is bypassed.
static INLINE void predict_dc_only_block(
    MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
    int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
    int *dc_only_blk) {
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
  uint64_t block_var = UINT64_MAX;
  const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
  *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
                                txsize_to_bsize[tx_size], block_mse_q8,
                                per_px_mean, &block_var);
  assert((*block_mse_q8) != UINT_MAX);
  uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
  if (is_cur_buf_hbd(xd))
    block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);

  if (block_var >= var_threshold) return;
  const unsigned int predict_dc_level = x->txfm_search_params.predict_dc_level;
  assert(predict_dc_level != 0);

  // Prediction of skip block if residual mean and variance are less
  // than qstep based thresholds.
  if ((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) {
    // If the normalized mean of the residual block is less than the dc qstep
    // and the normalized block variance is less than the ac qstep, then the
    // block is assumed to be a skip block and its rdcost is updated
    // accordingly.
    best_rd_stats->skip_txfm = 1;

    x->plane[plane].eobs[block] = 0;

    if (is_cur_buf_hbd(xd))
      *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);

    best_rd_stats->dist = (*block_sse) << 4;
    best_rd_stats->sse = best_rd_stats->dist;

    ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
    ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
    av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
    ENTROPY_CONTEXT *ta = ctxa;
    ENTROPY_CONTEXT *tl = ctxl;
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    TXB_CTX txb_ctx_tmp;
    const PLANE_TYPE plane_type = get_plane_type(plane);
    get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
    const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
                                  .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
    best_rd_stats->rate = zero_blk_rate;

    best_rd_stats->rdcost =
        RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);

    x->plane[plane].txb_entropy_ctx[block] = 0;
  } else if (predict_dc_level > 1) {
    // Predict DC only blocks based on residual variance.
    // For chroma plane, this prediction is disabled for intra blocks.
    if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
  }
}

// Search for the best transform type for a given transform block.
// This function can be used for both inter and intra, both luma and chroma.
static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                           int block, int blk_row, int blk_col,
                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                           const TXB_CTX *const txb_ctx,
                           FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
                           int64_t ref_best_rd, RD_STATS *best_rd_stats) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  int64_t best_rd = INT64_MAX;
  uint16_t best_eob = 0;
  TX_TYPE best_tx_type = DCT_DCT;
  int rate_cost = 0;
  // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
  // of the best tx_type.
  DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
  struct macroblock_plane *const p = &x->plane[plane];
  tran_low_t *orig_dqcoeff = p->dqcoeff;
  tran_low_t *best_dqcoeff = this_dqcoeff;
  const int tx_type_map_idx =
      plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
  av1_invalid_rd_stats(best_rd_stats);

  skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
                                   DRY_RUN_NORMAL);

  uint8_t best_txb_ctx = 0;
  // txk_allowed = TX_TYPES: >1 tx types are allowed
  // txk_allowed < TX_TYPES: only that specific tx type is allowed.
  TX_TYPE txk_allowed = TX_TYPES;
  int txk_map[TX_TYPES] = {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  };
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;

  const uint8_t txw = tx_size_wide[tx_size];
  const uint8_t txh = tx_size_high[tx_size];
  int64_t block_sse;
  unsigned int block_mse_q8;
  int dc_only_blk = 0;
  const bool predict_dc_block =
      txfm_params->predict_dc_level >= 1 && txw != 64 && txh != 64;
  int64_t per_px_mean = INT64_MAX;
  if (predict_dc_block) {
    predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
                          blk_col, best_rd_stats, &block_sse, &block_mse_q8,
                          &per_px_mean, &dc_only_blk);
    if (best_rd_stats->skip_txfm == 1) {
      const TX_TYPE tx_type = DCT_DCT;
      if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
      return;
    }
  } else {
    block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
                                txsize_to_bsize[tx_size], &block_mse_q8);
    assert(block_mse_q8 != UINT_MAX);
  }

  // Bit mask to indicate which transform types are allowed in the RD search.
  uint16_t tx_mask;

  // Use the DCT_DCT transform for DC only blocks.
  if (dc_only_blk)
    tx_mask = 1 << DCT_DCT;
  else
    tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
                          tx_size, txb_ctx, ftxs_mode, ref_best_rd,
                          &txk_allowed, txk_map);
  const uint16_t allowed_tx_mask = tx_mask;

  if (is_cur_buf_hbd(xd)) {
    block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
    block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
  }
  block_sse *= 16;
  // Use an mse / qstep^2 based threshold to decide whether to apply R-D
  // optimization to the coefficients. Coefficient optimization tends to help
  // for smaller residuals, but may not be effective for larger ones.
  // TODO(any): Experiment with variance and mean based thresholds.
  const int perform_block_coeff_opt =
      ((uint64_t)block_mse_q8 <=
       (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
  skip_trellis |= !perform_block_coeff_opt;

  // Flag to indicate whether distortion should be calculated in the transform
  // domain while iterating through the transform type candidates.
  // Transform domain distortion is accurate for higher residuals.
  // TODO(any): Experiment with variance and mean based thresholds.
  int use_transform_domain_distortion =
      (txfm_params->use_transform_domain_distortion > 0) &&
      (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
      // Any 64-pt transform only preserves half the coefficients.
      // Therefore transform domain distortion is not valid for these
      // transform sizes.
      (txsize_sqr_up_map[tx_size] != TX_64X64) &&
      // Use pixel domain distortion for DC only blocks.
      !dc_only_blk;
  // Flag to indicate if an extra calculation of distortion in the pixel domain
  // should be performed at the end, after the best transform type has been
  // decided.
  int calc_pixel_domain_distortion_final =
      txfm_params->use_transform_domain_distortion == 1 &&
      use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
  if (calc_pixel_domain_distortion_final &&
      (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
    calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;

  const uint16_t *eobs_ptr = x->plane[plane].eobs;

  TxfmParam txfm_param;
  QUANT_PARAM quant_param;
  int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, !skip_trellis,
                  skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
                                                         : AV1_XFORM_QUANT_FP)
                               : AV1_XFORM_QUANT_FP,
                  cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);

  // Iterate through all transform type candidates.
  for (int idx = 0; idx < TX_TYPES; ++idx) {
    const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
    if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type))
      continue;
    txfm_param.tx_type = tx_type;
    if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
                        &quant_param);
    }
    if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
    RD_STATS this_rd_stats;
    av1_invalid_rd_stats(&this_rd_stats);

    if (!dc_only_blk)
      av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
    else
      av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);

    skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
        x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
        qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);

    av1_quant(x, plane, block, &txfm_param, &quant_param);

    // Calculate rate cost of quantized coefficients.
    if (quant_param.use_optimize_b) {
      // TODO(aomedia:3209): update Trellis quantization to take into account
      // quantization matrices.
      av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
                     &rate_cost);
    } else {
      rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
                              cm->features.reduced_tx_set_used);
    }

    // If rd cost based on coeff rate alone is already more than best_rd,
    // terminate early.
    if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;

    // Calculate distortion.
    if (eobs_ptr[block] == 0) {
      // When eob is 0, pixel domain distortion is more efficient and accurate.
      this_rd_stats.dist = this_rd_stats.sse = block_sse;
    } else if (dc_only_blk) {
      this_rd_stats.sse = block_sse;
      this_rd_stats.dist = dist_block_px_domain(
          cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
    } else if (use_transform_domain_distortion) {
      const SCAN_ORDER *const scan_order =
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
      dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                           scan_order->scan, &this_rd_stats.dist,
                           &this_rd_stats.sse);
    } else {
      int64_t sse_diff = INT64_MAX;
      // high_energy threshold assumes that every pixel within a txfm block
      // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
      // for 8 bit.
      const int64_t high_energy_thresh =
          ((int64_t)128 * 128 * tx_size_2d[tx_size]);
      const int is_high_energy = (block_sse >= high_energy_thresh);
      if (tx_size == TX_64X64 || is_high_energy) {
        // Because 3 out of 4 quadrants of the transform coefficients are
        // forced to zero, the inverse transform has a tendency to overflow.
        // sse_diff is effectively the energy of those 3 quadrants; here we use
        // it to decide whether to do pixel domain distortion. If the energy is
        // mostly in the first quadrant, it is unlikely that the inverse
        // transform will overflow.
        const SCAN_ORDER *const scan_order =
            get_scan(txfm_param.tx_size, txfm_param.tx_type);
        dist_block_tx_domain(x, plane, block, tx_size, quant_param.qmatrix,
                             scan_order->scan, &this_rd_stats.dist,
                             &this_rd_stats.sse);
        sse_diff = block_sse - this_rd_stats.sse;
      }
      if (tx_size != TX_64X64 || !is_high_energy ||
          (sse_diff * 2) < this_rd_stats.sse) {
        const int64_t tx_domain_dist = this_rd_stats.dist;
        this_rd_stats.dist = dist_block_px_domain(
            cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
        // For high energy blocks, occasionally, the pixel domain distortion
        // can be artificially low due to clamping at the reconstruction stage
        // even when the inverse transform output is hugely different from the
        // actual residue.
        if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
          this_rd_stats.dist = tx_domain_dist;
      } else {
        assert(sse_diff < INT64_MAX);
        this_rd_stats.dist += sse_diff;
      }
      this_rd_stats.sse = block_sse;
    }

    this_rd_stats.rate = rate_cost;

    const int64_t rd =
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);

    if (rd < best_rd) {
      best_rd = rd;
      *best_rd_stats = this_rd_stats;
      best_tx_type = tx_type;
      best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
      best_eob = x->plane[plane].eobs[block];
      // Swap dqcoeff buffers.
      tran_low_t *const tmp_dqcoeff = best_dqcoeff;
      best_dqcoeff = p->dqcoeff;
      p->dqcoeff = tmp_dqcoeff;
    }

#if CONFIG_COLLECT_RD_STATS == 1
    if (plane == 0) {
      PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
                              plane_bsize, tx_size, tx_type, rd);
    }
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if COLLECT_TX_SIZE_DATA
    // Generate small sample to restrict output size.
    static unsigned int seed = 21743;
    if (lcg_rand16(&seed) % 200 == 0) {
      FILE *fp = NULL;

      if (within_border) {
        fp = fopen(av1_tx_size_data_output_file, "a");
      }

      if (fp) {
        // Transform info and RD
        const int txb_w = tx_size_wide[tx_size];
        const int txb_h = tx_size_high[tx_size];

        // Residue signal.
        const int diff_stride = block_size_wide[plane_bsize];
        struct macroblock_plane *const p = &x->plane[plane];
        const int16_t *src_diff =
            &p->src_diff[(blk_row * diff_stride + blk_col) * 4];

        for (int r = 0; r < txb_h; ++r) {
          for (int c = 0; c < txb_w; ++c) {
            fprintf(fp, "%d,", src_diff[c]);
          }
          src_diff += diff_stride;
        }

        fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
        fprintf(fp, "\n");
        fclose(fp);
      }
    }
#endif  // COLLECT_TX_SIZE_DATA

    // If the current best RD cost is much worse than the reference RD cost,
    // terminate early.
    if (cpi->sf.tx_sf.adaptive_txb_search_level) {
      if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
          ref_best_rd) {
        break;
      }
    }

    // Terminate transform type search if the block has been quantized to
    // all zero.
    if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
  }

  assert(best_rd != INT64_MAX);

  best_rd_stats->skip_txfm = best_eob == 0;
  if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
  x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
  x->plane[plane].eobs[block] = best_eob;
  skip_trellis = skip_trellis_based_on_satd[best_tx_type];

  // Point dqcoeff to the quantized coefficients corresponding to the best
  // transform type, then we can skip transform and quantization, e.g. in the
  // final pixel domain distortion calculation and recon_intra().
  p->dqcoeff = best_dqcoeff;

  if (calc_pixel_domain_distortion_final && best_eob) {
    best_rd_stats->dist = dist_block_px_domain(
        cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
    best_rd_stats->sse = block_sse;
  }

  // Intra mode needs decoded pixels such that the next transform block
  // can use them for prediction.
  recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
              txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
  p->dqcoeff = orig_dqcoeff;
}

// Pick transform type for a luma transform block of tx_size. Note this
// function is used only for inter-predicted blocks.
static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
                                  TX_SIZE tx_size, int blk_row, int blk_col,
                                  int block, int plane_bsize, TXB_CTX *txb_ctx,
                                  RD_STATS *rd_stats,
                                  FAST_TX_SEARCH_MODE ftxs_mode,
                                  int64_t ref_rdcost) {
  assert(is_inter_block(x->e_mbd.mi[0]));
  RD_STATS this_rd_stats;
  const int skip_trellis = 0;
  search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
                 txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats);

  av1_merge_rd_stats(rd_stats, &this_rd_stats);
}

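// Evaluates keeping the current transform block at tx_size without splitting:
// searches the best tx type, compares the result against coding the block as
// all-zero (skip), adds the txfm-partition signaling cost, and records the
// resulting RD cost, entropy context and tx type in *no_split.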
static AOM_INLINE void try_tx_block_no_split(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
    const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
    int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
    FAST_TX_SEARCH_MODE ftxs_mode, TxCandidateInfo *no_split) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  struct macroblock_plane *const p = &x->plane[0];
  const int bw = mi_size_wide[plane_bsize];
  const ENTROPY_CONTEXT *const pta = ta + blk_col;
  const ENTROPY_CONTEXT *const ptl = tl + blk_row;
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  rd_stats->zero_rate = zero_blk_rate;
  const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
  mbmi->inter_tx_size[index] = tx_size;
  tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
             rd_stats, ftxs_mode, ref_best_rd);
  assert(rd_stats->rate < INT_MAX);

  const int pick_skip_txfm =
      !xd->lossless[mbmi->segment_id] &&
      (rd_stats->skip_txfm == 1 ||
       RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
           RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
  if (pick_skip_txfm) {
#if CONFIG_RD_DEBUG
    update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate);
#endif  // CONFIG_RD_DEBUG
    rd_stats->rate = zero_blk_rate;
    rd_stats->dist = rd_stats->sse;
    p->eobs[block] = 0;
    update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
  }
  rd_stats->skip_txfm = pick_skip_txfm;
  set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
               pick_skip_txfm);

  if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
    rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];

  no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
  no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
  no_split->tx_type =
      xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
}

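// Evaluates splitting the current transform block into its sub-transform
// blocks, recursively selecting the best partitioning for each via
// select_tx_block. Accumulates the sub-block RD stats into *split_rd_stats
// and bails out with rdcost = INT64_MAX as soon as the accumulated cost
// exceeds ref_best_rd.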
try_tx_block_split(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,ENTROPY_CONTEXT * ta,ENTROPY_CONTEXT * tl,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,int txfm_partition_ctx,int64_t no_split_rd,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode,RD_STATS * split_rd_stats)2402 static AOM_INLINE void try_tx_block_split(
2403     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2404     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2405     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2406     int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
2407     FAST_TX_SEARCH_MODE ftxs_mode, RD_STATS *split_rd_stats) {
2408   assert(tx_size < TX_SIZES_ALL);
2409   MACROBLOCKD *const xd = &x->e_mbd;
2410   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
2411   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
2412   const int txb_width = tx_size_wide_unit[tx_size];
2413   const int txb_height = tx_size_high_unit[tx_size];
2414   // Transform size after splitting current block.
2415   const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
2416   const int sub_txb_width = tx_size_wide_unit[sub_txs];
2417   const int sub_txb_height = tx_size_high_unit[sub_txs];
2418   const int sub_step = sub_txb_width * sub_txb_height;
2419   const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
2420   assert(nblks > 0);
2421   av1_init_rd_stats(split_rd_stats);
2422   split_rd_stats->rate =
2423       x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];
2424 
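  // Recursively search each sub-block. Every recursive call gets an equal
  // share (no_split_rd / nblks) of the no-split RD cost as its reference, and
  // the accumulated split cost is checked against ref_best_rd so that the
  // loop terminates as soon as splitting can no longer win.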
  for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
    const int offsetr = blk_row + r;
    if (offsetr >= max_blocks_high) break;
    for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
      assert(blk_idx < 4);
      const int offsetc = blk_col + c;
      if (offsetc >= max_blocks_wide) continue;

      RD_STATS this_rd_stats;
      int this_cost_valid = 1;
      select_tx_block(cpi, x, offsetr, offsetc, block, sub_txs, depth + 1,
                      plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
                      no_split_rd / nblks, ref_best_rd - split_rd_stats->rdcost,
                      &this_cost_valid, ftxs_mode);
      if (!this_cost_valid) {
        split_rd_stats->rdcost = INT64_MAX;
        return;
      }
      av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
      split_rd_stats->rdcost =
          RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
      if (split_rd_stats->rdcost > ref_best_rd) {
        split_rd_stats->rdcost = INT64_MAX;
        return;
      }
      block += sub_step;
    }
  }
}

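// Returns the variance from the raw moments: var(x) = E[x^2] - E[x]^2, where
// x2_sum is the sum of squared samples over 'num' pixels and 'mean' is E[x].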
static float get_var(float mean, double x2_sum, int num) {
  const float e_x2 = (float)(x2_sum / num);
  const float diff = e_x2 - mean * mean;
  return diff;
}

static AOM_INLINE void get_blk_var_dev(const int16_t *data, int stride, int bw,
                                       int bh, float *dev_of_mean,
                                       float *var_of_vars) {
  const int16_t *const data_ptr = &data[0];
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
  const int num = bw * bh;
  const int sub_num = subw * subh;
  int total_x_sum = 0;
  int64_t total_x2_sum = 0;
  int blk_idx = 0;
  float var_sum = 0.0f;
  float mean_sum = 0.0f;
  double var2_sum = 0.0f;
  double mean2_sum = 0.0f;

  for (int row = 0; row < bh; row += subh) {
    for (int col = 0; col < bw; col += subw) {
      int x_sum;
      int64_t x2_sum;
      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
                          &x_sum, &x2_sum);
      total_x_sum += x_sum;
      total_x2_sum += x2_sum;

      const float mean = (float)x_sum / sub_num;
      const float var = get_var(mean, (double)x2_sum, sub_num);
      mean_sum += mean;
      mean2_sum += (double)(mean * mean);
      var_sum += var;
      var2_sum += var * var;
      blk_idx++;
    }
  }

  const float lvl0_mean = (float)total_x_sum / num;
  const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
  mean_sum += lvl0_mean;
  mean2_sum += (double)(lvl0_mean * lvl0_mean);
  var_sum += block_var;
  var2_sum += block_var * block_var;
  const float av_mean = mean_sum / 5;

  if (blk_idx > 1) {
    // Deviation of means.
    *dev_of_mean = get_dev(av_mean, mean2_sum, (blk_idx + 1));
    // Variance of variances.
    const float mean_var = var_sum / (blk_idx + 1);
    *var_of_vars = get_var(mean_var, var2_sum, (blk_idx + 1));
  }
}

static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
                                    int blk_row, int blk_col, TX_SIZE tx_size,
                                    int *try_no_split, int *try_split,
                                    int pruning_level) {
  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff =
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  float dev_of_means = 0.0f;
  float var_of_vars = 0.0f;

  // This function calculates the deviation of means, and the variance of pixel
  // variances of the block as well as its sub-blocks.
  get_blk_var_dev(diff, diff_stride, bw, bh, &dev_of_means, &var_of_vars);
  const int dc_q = x->plane[0].dequant_QTX[0] >> 3;
  const int ac_q = x->plane[0].dequant_QTX[1] >> 3;
  const int no_split_thresh_scales[4] = { 0, 24, 8, 8 };
  const int no_split_thresh_scale = no_split_thresh_scales[pruning_level];
  const int split_thresh_scales[4] = { 0, 24, 10, 8 };
  const int split_thresh_scale = split_thresh_scales[pruning_level];

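  // Heuristic interpretation of the checks below: a residual whose sub-block
  // means barely deviate and whose per-sub-block variances are nearly uniform
  // (relative to the quantizer step sizes) is unlikely to benefit from
  // splitting, so the split search is skipped; conversely, a strongly
  // non-uniform residual makes the no-split option unlikely to win, so that
  // search is skipped instead.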
  if ((dev_of_means <= dc_q) &&
      (split_thresh_scale * var_of_vars <= ac_q * ac_q)) {
    *try_split = 0;
  }
  if ((dev_of_means > no_split_thresh_scale * dc_q) &&
      (var_of_vars > no_split_thresh_scale * ac_q * ac_q)) {
    *try_no_split = 0;
  }
}

// Search for the best transform partition (recursive) and type for a given
// inter-predicted luma block. The obtained transform selection will be saved
// in xd->mi[0], and the corresponding RD stats will be saved in rd_stats.
static AOM_INLINE void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode) {
  assert(tx_size < TX_SIZES_ALL);
  av1_init_rd_stats(rd_stats);
  if (ref_best_rd < 0) {
    *is_cost_valid = 0;
    return;
  }

  MACROBLOCKD *const xd = &x->e_mbd;
  assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
         blk_col < max_block_wide(xd, plane_bsize, 0));
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                         mbmi->bsize, tx_size);
  struct macroblock_plane *const p = &x->plane[0];

  int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
                      txsize_sqr_up_map[tx_size] != TX_64X64) &&
                     (cpi->oxcf.txfm_cfg.enable_rect_tx ||
                      tx_size_wide[tx_size] == tx_size_high[tx_size]);
  int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
  TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };

  // Prune tx_split and no-split based on sub-block properties.
  if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
      cpi->sf.tx_sf.prune_tx_size_level > 0) {
    prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
                            &try_no_split, &try_split,
                            cpi->sf.tx_sf.prune_tx_size_level);
  }

  if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) {
    if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0;
  }

  // Try using the current block as a single transform block without split.
  if (try_no_split) {
    try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
                          plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
                          ftxs_mode, &no_split);

    // Speed features for early termination.
    const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
    if (search_level) {
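      // Both checks compare against a discounted no-split cost: the
      // subtracted term is a 2^-(1 + search_level) (resp. 2^-(2 +
      // search_level)) fraction of no_split.rd. If even the discounted cost
      // exceeds ref_best_rd the whole search is abandoned, and if it exceeds
      // the parent level's RD the split search is skipped.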
      if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
        *is_cost_valid = 0;
        return;
      }
      if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
        try_split = 0;
      }
    }
    if (cpi->sf.tx_sf.txb_split_cap) {
      if (p->eobs[block] == 0) try_split = 0;
    }
  }

  // ML based speed feature to skip searching for split transform blocks.
  if (x->e_mbd.bd == 8 && try_split &&
      !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
    const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
    if (threshold >= 0) {
      const int split_score =
          ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
      if (split_score < -threshold) try_split = 0;
    }
  }

  RD_STATS split_rd_stats;
  split_rd_stats.rdcost = INT64_MAX;
  // Try splitting the current block into smaller transform blocks.
  if (try_split) {
    try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
                       plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
                       AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
                       &split_rd_stats);
  }

  if (no_split.rd < split_rd_stats.rdcost) {
    ENTROPY_CONTEXT *pta = ta + blk_col;
    ENTROPY_CONTEXT *ptl = tl + blk_row;
    p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
    av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                          tx_size);
    for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
      for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
        const int index =
            av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
        mbmi->inter_tx_size[index] = tx_size;
      }
    }
    mbmi->tx_size = tx_size;
    update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
    const int bw = mi_size_wide[plane_bsize];
    set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
                 rd_stats->skip_txfm);
  } else {
    *rd_stats = split_rd_stats;
    if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
  }
}

static AOM_INLINE void choose_largest_tx_size(const AV1_COMP *const cpi,
                                              MACROBLOCK *x, RD_STATS *rd_stats,
                                              int64_t ref_best_rd,
                                              BLOCK_SIZE bs) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);

  // If tx64 is not enabled, we need to go down to the next available size.
  if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
    static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
      TX_4X4,    // 4x4 transform
      TX_8X8,    // 8x8 transform
      TX_16X16,  // 16x16 transform
      TX_32X32,  // 32x32 transform
      TX_32X32,  // 64x64 transform
      TX_4X8,    // 4x8 transform
      TX_8X4,    // 8x4 transform
      TX_8X16,   // 8x16 transform
      TX_16X8,   // 16x8 transform
      TX_16X32,  // 16x32 transform
      TX_32X16,  // 32x16 transform
      TX_32X32,  // 32x64 transform
      TX_32X32,  // 64x32 transform
      TX_4X16,   // 4x16 transform
      TX_16X4,   // 16x4 transform
      TX_8X32,   // 8x32 transform
      TX_32X8,   // 32x8 transform
      TX_16X32,  // 16x64 transform
      TX_32X16,  // 64x16 transform
    };
    mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
  } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
    static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
      TX_4X4,    // 4x4 transform
      TX_8X8,    // 8x8 transform
      TX_16X16,  // 16x16 transform
      TX_32X32,  // 32x32 transform
      TX_64X64,  // 64x64 transform
      TX_4X4,    // 4x8 transform
      TX_4X4,    // 8x4 transform
      TX_8X8,    // 8x16 transform
      TX_8X8,    // 16x8 transform
      TX_16X16,  // 16x32 transform
      TX_16X16,  // 32x16 transform
      TX_32X32,  // 32x64 transform
      TX_32X32,  // 64x32 transform
      TX_4X4,    // 4x16 transform
      TX_4X4,    // 16x4 transform
      TX_8X8,    // 8x32 transform
      TX_8X8,    // 32x8 transform
      TX_16X16,  // 16x64 transform
      TX_16X16,  // 64x16 transform
    };
    mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
  } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
             !cpi->oxcf.txfm_cfg.enable_rect_tx) {
    static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
      TX_4X4,    // 4x4 transform
      TX_8X8,    // 8x8 transform
      TX_16X16,  // 16x16 transform
      TX_32X32,  // 32x32 transform
      TX_32X32,  // 64x64 transform
      TX_4X4,    // 4x8 transform
      TX_4X4,    // 8x4 transform
      TX_8X8,    // 8x16 transform
      TX_8X8,    // 16x8 transform
      TX_16X16,  // 16x32 transform
      TX_16X16,  // 32x16 transform
      TX_32X32,  // 32x64 transform
      TX_32X32,  // 64x32 transform
      TX_4X4,    // 4x16 transform
      TX_4X4,    // 16x4 transform
      TX_8X8,    // 8x32 transform
      TX_8X8,    // 32x8 transform
      TX_16X16,  // 16x64 transform
      TX_16X16,  // 64x16 transform
    };

    mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
  }

  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
  // The skip RD cost is used only for inter blocks.
  const int64_t skip_txfm_rd =
      is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
  const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0);
  const int skip_trellis = 0;
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
                       AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
                       mbmi->tx_size, FTXS_NONE, skip_trellis);
}

static AOM_INLINE void choose_smallest_tx_size(const AV1_COMP *const cpi,
                                               MACROBLOCK *x,
                                               RD_STATS *rd_stats,
                                               int64_t ref_best_rd,
                                               BLOCK_SIZE bs) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];

  mbmi->tx_size = TX_4X4;
  // TODO(any) : Pass this_rd based on skip/non-skip cost
  const int skip_trellis = 0;
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
                       FTXS_NONE, skip_trellis);
}

#if !CONFIG_REALTIME_ONLY
static void ml_predict_intra_tx_depth_prune(MACROBLOCK *x, int blk_row,
                                            int blk_col, BLOCK_SIZE bsize,
                                            TX_SIZE tx_size) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];

  // Disable the pruning logic using the NN model for the following cases:
  // 1) Lossless coding, as only the 4x4 transform is evaluated in this case.
  // 2) When the transform and current block sizes do not match, as the
  // features are obtained over the current block.
  // 3) When the operating bit-depth is not 8-bit, as the input features are
  // not scaled according to bit-depth.
  if (xd->lossless[mbmi->segment_id] || txsize_to_bsize[tx_size] != bsize ||
      xd->bd != 8)
    return;

  // Currently, NN model based pruning is supported only when the largest
  // transform size is 8x8.
  if (tx_size != TX_8X8) return;

  // The neural network model is a sequential neural net trained using the SGD
  // optimizer. The model could be further improved in terms of speed/quality
  // by considering the following experiments:
  // 1) Generate the ML model by training with balanced data for different
  // learning rates and optimizers.
  // 2) Experiment with the ML model by adding features related to the
  // statistics of top and left pixels to capture the accuracy of reconstructed
  // neighbouring pixels for 4x4 blocks numbered 1, 2, 3 in the 8x8 block, the
  // source variance of 4x4 sub-blocks, etc.
  // 3) Generate ML models for transform blocks other than 8x8.
  const NN_CONFIG *const nn_config = &av1_intra_tx_split_nnconfig_8x8;
  const float *const intra_tx_prune_thresh = av1_intra_tx_prune_nn_thresh_8x8;

  float features[NUM_INTRA_TX_SPLIT_FEATURES] = { 0.0f };
  const int diff_stride = block_size_wide[bsize];

  const int16_t *diff = x->plane[0].src_diff + MI_SIZE * blk_row * diff_stride +
                        MI_SIZE * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];

  int feature_idx = get_mean_dev_features(diff, diff_stride, bw, bh, features);

  features[feature_idx++] = logf(1.0f + (float)x->source_variance);

  const int dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd) >> (xd->bd - 8);
  const float log_dc_q_square = logf(1.0f + (float)(dc_q * dc_q) / 256.0f);
  features[feature_idx++] = log_dc_q_square;
  assert(feature_idx == NUM_INTRA_TX_SPLIT_FEATURES);
  for (int i = 0; i < NUM_INTRA_TX_SPLIT_FEATURES; i++) {
    features[i] = (features[i] - av1_intra_tx_split_8x8_mean[i]) /
                  av1_intra_tx_split_8x8_std[i];
  }

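  // Feed the normalized features to the NN. A low score indicates that the
  // residual favours the largest transform (prune the split depths), while a
  // high score favours splitting (prune the largest depth).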
  float score;
  av1_nn_predict(features, nn_config, 1, &score);

  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
  if (score <= intra_tx_prune_thresh[0])
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_SPLIT;
  else if (score > intra_tx_prune_thresh[1])
    txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_LARGEST;
}
#endif  // !CONFIG_REALTIME_ONLY

// Search for the best uniform transform size and type for the current coding
// block.
static AOM_INLINE void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
                                                   MACROBLOCK *x,
                                                   RD_STATS *rd_stats,
                                                   int64_t ref_best_rd,
                                                   BLOCK_SIZE bs) {
  av1_invalid_rd_stats(rd_stats);

  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  TxfmSearchParams *const txfm_params = &x->txfm_search_params;
  const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
  int start_tx;
  // The split depth can be at most MAX_TX_DEPTH, so init_depth controls how
  // many levels of splitting are allowed during the RD search.
  int init_depth;

  if (tx_select) {
    start_tx = max_rect_tx_size;
    init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
                                       is_inter_block(mbmi), &cpi->sf,
                                       txfm_params->tx_size_search_method);
    if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 &&
        txsize_sqr_up_map[start_tx] == TX_64X64) {
      start_tx = sub_tx_size_map[start_tx];
    }
  } else {
    const TX_SIZE chosen_tx_size =
        tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
    start_tx = chosen_tx_size;
    init_depth = MAX_TX_DEPTH;
  }

  const int skip_trellis = 0;
  uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  TX_SIZE best_tx_size = max_rect_tx_size;
  int64_t best_rd = INT64_MAX;
  const int num_blks = bsize_to_num_blk(bs);
  x->rd_model = FULL_TXFM_RD;
  int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
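  // Walk the transform sizes from start_tx downwards, halving via
  // sub_tx_size_map at each depth, and keep the size with the lowest RD cost.
  // The loop stops at TX_4X4, and for low-contrast blocks a deeper (smaller)
  // size that fails to improve on the previous depth ends the search early.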
  for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
       depth++, tx_size = sub_tx_size_map[tx_size]) {
    if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
         txsize_sqr_up_map[tx_size] == TX_64X64) ||
        (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
         tx_size_wide[tx_size] != tx_size_high[tx_size])) {
      continue;
    }

#if !CONFIG_REALTIME_ONLY
    if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_SPLIT) break;

    // Set the flag to enable the evaluation of the NN classifier to prune
    // transform depths. As the features are based on the intra residual
    // information of the largest transform, the evaluation of the NN model is
    // enabled only for this case.
    txfm_params->enable_nn_prune_intra_tx_depths =
        (cpi->sf.tx_sf.prune_intra_tx_depths_using_nn && tx_size == start_tx);
#endif

    RD_STATS this_rd_stats;
    rd[depth] = av1_uniform_txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs,
                                     tx_size, FTXS_NONE, skip_trellis);
    if (rd[depth] < best_rd) {
      av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks);
      av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
      best_tx_size = tx_size;
      best_rd = rd[depth];
      *rd_stats = this_rd_stats;
    }
    if (tx_size == TX_4X4) break;
    // If we are searching three depths, prune the smallest size depending
    // on the RD results for the first two depths for low-contrast blocks.
    if (depth > init_depth && depth != MAX_TX_DEPTH &&
        x->source_variance < 256) {
      if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
    }
  }

  if (rd_stats->rate != INT_MAX) {
    mbmi->tx_size = best_tx_size;
    av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
    av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks);
  }

#if !CONFIG_REALTIME_ONLY
  // Reset the flags to avoid any unintentional evaluation of the NN model and
  // consumption of prune depths.
  txfm_params->enable_nn_prune_intra_tx_depths = false;
  txfm_params->nn_prune_depths_for_intra_tx = TX_PRUNE_NONE;
#endif
}

// Search for the best transform type for the given transform block in the
// given plane/channel, and calculate the corresponding RD cost.
static AOM_INLINE void block_rd_txfm(int plane, int block, int blk_row,
                                     int blk_col, BLOCK_SIZE plane_bsize,
                                     TX_SIZE tx_size, void *arg) {
  struct rdcost_block_args *args = arg;
  if (args->exit_early) {
    args->incomplete_exit = 1;
    return;
  }

  MACROBLOCK *const x = args->x;
  MACROBLOCKD *const xd = &x->e_mbd;
  const int is_inter = is_inter_block(xd->mi[0]);
  const AV1_COMP *cpi = args->cpi;
  ENTROPY_CONTEXT *a = args->t_above + blk_col;
  ENTROPY_CONTEXT *l = args->t_left + blk_row;
  const AV1_COMMON *cm = &cpi->common;
  RD_STATS this_rd_stats;
  av1_init_rd_stats(&this_rd_stats);

  if (!is_inter) {
    av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
    av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
#if !CONFIG_REALTIME_ONLY
    const TxfmSearchParams *const txfm_params = &x->txfm_search_params;
    if (txfm_params->enable_nn_prune_intra_tx_depths) {
      ml_predict_intra_tx_depth_prune(x, blk_row, blk_col, plane_bsize,
                                      tx_size);
      if (txfm_params->nn_prune_depths_for_intra_tx == TX_PRUNE_LARGEST) {
        av1_invalid_rd_stats(&args->rd_stats);
        args->exit_early = 1;
        return;
      }
    }
#endif
  }

  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
  search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                 &txb_ctx, args->ftxs_mode, args->skip_trellis,
                 args->best_rd - args->current_rd, &this_rd_stats);

  if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
    assert(!is_inter || plane_bsize < BLOCK_8X8);
    cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
  }

#if CONFIG_RD_DEBUG
  update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate);
#endif  // CONFIG_RD_DEBUG
  av1_set_txb_context(x, plane, block, tx_size, a, l);

  const int blk_idx =
      blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;

  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
  if (plane == 0)
    set_blk_skip(txfm_info->blk_skip, plane, blk_idx,
                 x->plane[plane].eobs[block] == 0);
  else
    set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);

  int64_t rd;
  if (is_inter) {
    const int64_t no_skip_txfm_rd =
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
    const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
    rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
    this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
  } else {
    // Signal non-skip_txfm for intra blocks.
    rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
    this_rd_stats.skip_txfm = 0;
  }

  av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);

  args->current_rd += rd;
  if (args->current_rd > args->best_rd) args->exit_early = 1;
}

int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                              RD_STATS *rd_stats, int64_t ref_best_rd,
                              BLOCK_SIZE bs, TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const ModeCosts *mode_costs = &x->mode_costs;
  const int is_inter = is_inter_block(mbmi);
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
                        block_signals_txsize(mbmi->bsize);
  int tx_size_rate = 0;
  if (tx_select) {
    const int ctx = txfm_partition_context(
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
    tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
  }
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
  const int64_t no_this_rd =
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
  mbmi->tx_size = tx_size;

  const uint8_t txw_unit = tx_size_wide_unit[tx_size];
  const uint8_t txh_unit = tx_size_high_unit[tx_size];
  const int step = txw_unit * txh_unit;
  const int max_blocks_wide = max_block_wide(xd, bs, 0);
  const int max_blocks_high = max_block_high(xd, bs, 0);

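  // Fast estimate: every transform block is coded with DCT_DCT only, with the
  // rate taken from the coefficient costs and the distortion computed in the
  // transform domain, instead of running a full per-block transform type
  // search.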
  struct rdcost_block_args args;
  av1_zero(args);
  args.x = x;
  args.cpi = cpi;
  args.best_rd = ref_best_rd;
  args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
  av1_init_rd_stats(&args.rd_stats);
  av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
  int i = 0;
  for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
       blk_row += txh_unit) {
    for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
      RD_STATS this_rd_stats;
      av1_init_rd_stats(&this_rd_stats);

      if (args.exit_early) {
        args.incomplete_exit = 1;
        break;
      }

      ENTROPY_CONTEXT *a = args.t_above + blk_col;
      ENTROPY_CONTEXT *l = args.t_left + blk_row;
      TXB_CTX txb_ctx;
      get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);

      TxfmParam txfm_param;
      QUANT_PARAM quant_param;
      av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
      av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);

      av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
      av1_quant(x, 0, i, &txfm_param, &quant_param);

      this_rd_stats.rate =
          cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);

      const SCAN_ORDER *const scan_order =
          get_scan(txfm_param.tx_size, txfm_param.tx_type);
      dist_block_tx_domain(x, 0, i, tx_size, quant_param.qmatrix,
                           scan_order->scan, &this_rd_stats.dist,
                           &this_rd_stats.sse);

      const int64_t no_skip_txfm_rd =
          RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
      const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);

      this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];

      av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
      args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);

      if (args.current_rd > ref_best_rd) {
        args.exit_early = 1;
        break;
      }

      av1_set_txb_context(x, 0, i, tx_size, a, l);
      i += step;
    }
  }

  if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);

  *rd_stats = args.rd_stats;
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  int64_t rd;
  // rd_stats->rate should include all the rate except the skip/non-skip cost,
  // as the same is accounted for in the caller functions after RD evaluation
  // of all planes. However, the decisions should be made after considering the
  // skip/non-skip header cost.
  if (rd_stats->skip_txfm && is_inter) {
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
  } else {
    // Intra blocks are always signalled as non-skip.
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
                rd_stats->dist);
    rd_stats->rate += tx_size_rate;
  }
  // Check if forcing the block to skip transform leads to a smaller RD cost.
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
    int64_t temp_skip_txfm_rd =
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
    if (temp_skip_txfm_rd <= rd) {
      rd = temp_skip_txfm_rd;
      rd_stats->rate = 0;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
    }
  }

  return rd;
}

int64_t av1_uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                             RD_STATS *rd_stats, int64_t ref_best_rd,
                             BLOCK_SIZE bs, TX_SIZE tx_size,
                             FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis) {
  assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const ModeCosts *mode_costs = &x->mode_costs;
  const int is_inter = is_inter_block(mbmi);
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
                        block_signals_txsize(mbmi->bsize);
  int tx_size_rate = 0;
  if (tx_select) {
    const int ctx = txfm_partition_context(
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
    tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0]
                            : tx_size_cost(x, bs, tx_size);
  }
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
  const int64_t skip_txfm_rd =
      is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
  const int64_t no_this_rd =
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);

  mbmi->tx_size = tx_size;
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
                       AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
                       tx_size, ftxs_mode, skip_trellis);
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  int64_t rd;
  // rd_stats->rate should include all the rate except the skip/non-skip cost,
  // as the same is accounted for in the caller functions after RD evaluation
  // of all planes. However, the decisions should be made after considering the
  // skip/non-skip header cost.
  if (rd_stats->skip_txfm && is_inter) {
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
  } else {
    // Intra blocks are always signalled as non-skip.
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
                rd_stats->dist);
    rd_stats->rate += tx_size_rate;
  }
  // Check if forcing the block to skip transform leads to a smaller RD cost.
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
    int64_t temp_skip_txfm_rd =
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
    if (temp_skip_txfm_rd <= rd) {
      rd = temp_skip_txfm_rd;
      rd_stats->rate = 0;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
    }
  }

  return rd;
}

// Search for the best transform type for a luma inter-predicted block, given
// the transform block partitions.
// This function is used only when some speed features are enabled.
static AOM_INLINE void tx_block_yrd(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, BLOCK_SIZE plane_bsize, int depth,
    ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
    TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, int64_t ref_best_rd,
    RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode) {
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(is_inter_block(mbmi));
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);

  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;

  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
      plane_bsize, blk_row, blk_col)];
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                         mbmi->bsize, tx_size);

  av1_init_rd_stats(rd_stats);
  if (tx_size == plane_tx_size) {
    ENTROPY_CONTEXT *ta = above_ctx + blk_col;
    ENTROPY_CONTEXT *tl = left_ctx + blk_row;
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    TXB_CTX txb_ctx;
    get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);

    const int zero_blk_rate =
        x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
            .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
    rd_stats->zero_rate = zero_blk_rate;
    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
               rd_stats, ftxs_mode, ref_best_rd);
    const int mi_width = mi_size_wide[plane_bsize];
    TxfmSearchInfo *txfm_info = &x->txfm_search_info;
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
        rd_stats->skip_txfm == 1) {
      rd_stats->rate = zero_blk_rate;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1);
      x->plane[0].eobs[block] = 0;
      x->plane[0].txb_entropy_ctx[block] = 0;
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    } else {
      rd_stats->skip_txfm = 0;
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0);
    }
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
    av1_set_txb_context(x, 0, block, tx_size, ta, tl);
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                          tx_size);
  } else {
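    // The stored partitioning indicates a split: recurse into the
    // sub-transform blocks (clamped to the visible area), accumulate their RD
    // stats, and charge the cost of signalling the split at the end.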
    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
    const int txb_width = tx_size_wide_unit[sub_txs];
    const int txb_height = tx_size_high_unit[sub_txs];
    const int step = txb_height * txb_width;
    const int row_end =
        AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row);
    const int col_end =
        AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col);
    RD_STATS pn_rd_stats;
    int64_t this_rd = 0;
    assert(txb_width > 0 && txb_height > 0);

    for (int row = 0; row < row_end; row += txb_height) {
      const int offsetr = blk_row + row;
      for (int col = 0; col < col_end; col += txb_width) {
        const int offsetc = blk_col + col;

        av1_init_rd_stats(&pn_rd_stats);
        tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
                     depth + 1, above_ctx, left_ctx, tx_above, tx_left,
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
        if (pn_rd_stats.rate == INT_MAX) {
          av1_invalid_rd_stats(rd_stats);
          return;
        }
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
        block += step;
      }
    }

    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
  }
}

// Search for the transform type with the transform sizes already decided for
// an inter-predicted luma partition block. Used only when some speed features
// are enabled.
// Return value 0: early termination triggered, no valid RD cost available;
//              1: RD cost values are valid.
static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
  if (ref_best_rd < 0) {
    av1_invalid_rd_stats(rd_stats);
    return 0;
  }

  av1_init_rd_stats(rd_stats);

  MACROBLOCKD *const xd = &x->e_mbd;
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const struct macroblockd_plane *const pd = &xd->plane[0];
  const int mi_width = mi_size_wide[bsize];
  const int mi_height = mi_size_high[bsize];
  const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
  const int bh = tx_size_high_unit[max_tx_size];
  const int bw = tx_size_wide_unit[max_tx_size];
  const int step = bw * bh;
  const int init_depth = get_search_init_depth(
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);

  int64_t this_rd = 0;
  for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
    for (int idx = 0; idx < mi_width; idx += bw) {
      RD_STATS pn_rd_stats;
      av1_init_rd_stats(&pn_rd_stats);
      tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
                   ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
                   &pn_rd_stats, ftxs_mode);
      if (pn_rd_stats.rate == INT_MAX) {
        av1_invalid_rd_stats(rd_stats);
        return 0;
      }
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
      this_rd +=
          AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
                 RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
      block += step;
    }
  }

  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
  this_rd =
      RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
  if (skip_txfm_rd < this_rd) {
    this_rd = skip_txfm_rd;
    rd_stats->rate = 0;
    rd_stats->dist = rd_stats->sse;
    rd_stats->skip_txfm = 1;
  }

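  // The refined result is only usable when it does not exceed the reference
  // budget; otherwise the stats are invalidated so the caller can discard
  // this candidate.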
  const int is_cost_valid = this_rd <= ref_best_rd;
  if (!is_cost_valid) {
    // Reset the cost value.
    av1_invalid_rd_stats(rd_stats);
  }
  return is_cost_valid;
}

// Search for the best transform size and type for the current inter-predicted
// luma block with recursive transform block partitioning. The obtained
// transform selection will be saved in xd->mi[0], and the corresponding RD
// stats will be saved in rd_stats. The returned value is the corresponding RD
// cost.
static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                       int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  assert(is_inter_block(xd->mi[0]));
  assert(bsize < BLOCK_SIZES_ALL);
  const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
  int64_t rd_thresh = ref_best_rd;
  if (rd_thresh == 0) {
    av1_invalid_rd_stats(rd_stats);
    return INT64_MAX;
  }
  if (fast_tx_search && rd_thresh < INT64_MAX) {
    if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
  }
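  // Note: for the fast search the RD budget is widened by 12.5%
  // (rd_thresh >> 3, with an overflow guard), presumably so that candidates
  // are not rejected on the basis of the restricted transform set alone
  // before the refinement stage below gets a chance to improve them.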
  assert(rd_thresh > 0);
  const FAST_TX_SEARCH_MODE ftxs_mode =
      fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
  const struct macroblockd_plane *const pd = &xd->plane[0];
  assert(bsize < BLOCK_SIZES_ALL);
  const int mi_width = mi_size_wide[bsize];
  const int mi_height = mi_size_high[bsize];
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
  const int init_depth = get_search_init_depth(
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
  const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
  const int bh = tx_size_high_unit[max_tx_size];
  const int bw = tx_size_wide_unit[max_tx_size];
  const int step = bw * bh;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
  int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
  int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
  int block = 0;

  av1_init_rd_stats(rd_stats);
  for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
    for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
      const int64_t best_rd_sofar =
          (rd_thresh == INT64_MAX)
              ? INT64_MAX
              : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
      int is_cost_valid = 1;
      RD_STATS pn_rd_stats;
      // Search for the best transform block size and type for the sub-block.
      select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
                      ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
                      best_rd_sofar, &is_cost_valid, ftxs_mode);
      if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
        av1_invalid_rd_stats(rd_stats);
        return INT64_MAX;
      }
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
      skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
      no_skip_txfm_rd =
          RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
      block += step;
    }
  }

  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);

  // If fast_tx_search is true, only DCT and 1D DCT were tested in the
  // transform block search above. Do a better search for the tx type with the
  // tx sizes already decided.
  if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
    if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
      return INT64_MAX;
  }

  int64_t final_rd;
  if (rd_stats->skip_txfm) {
    final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
  } else {
    final_rd =
        RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
    if (!xd->lossless[xd->mi[0]->segment_id]) {
      final_rd =
          AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
    }
  }

  return final_rd;
}

// Return 1 to terminate the transform search early. The decision is made based
// on the comparison of the reference RD cost with the model-estimated RD cost.
static AOM_INLINE int model_based_tx_search_prune(const AV1_COMP *cpi,
                                                  MACROBLOCK *x,
                                                  BLOCK_SIZE bsize,
                                                  int64_t ref_best_rd) {
  const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
  assert(level >= 0 && level <= 2);
  int model_rate;
  int64_t model_dist;
  uint8_t model_skip;
  MACROBLOCKD *const xd = &x->e_mbd;
  model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
      cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
      NULL, NULL, NULL);
  if (model_skip) return 0;
  const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
  // TODO(debargha, urvang): Improve the model and make the check below
  // tighter.
  static const int prune_factor_by8[] = { 3, 5 };
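  // Prune when a fraction (3/8 for level 1, 5/8 for level 2) of the
  // model-estimated RD cost already exceeds the best RD cost seen so far.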
  const int factor = prune_factor_by8[level - 1];
  return ((model_rd * factor) >> 3) > ref_best_rd;
}

void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                         RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                         int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  assert(is_inter_block(xd->mi[0]));

  av1_invalid_rd_stats(rd_stats);

  // If the modeled RD cost is a lot worse than the best so far, terminate
  // early.
  if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
      ref_best_rd != INT64_MAX) {
    if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
  }

  // Hashing based speed feature. If the hash of the prediction residue block
  // is found in the hash table, use previous search results and terminate
  // early.
  uint32_t hash = 0;
  MB_RD_RECORD *mb_rd_record = NULL;
  const int mi_row = x->e_mbd.mi_row;
  const int mi_col = x->e_mbd.mi_col;
  const int within_border =
      mi_row >= xd->tile.mi_row_start &&
      (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
      mi_col >= xd->tile.mi_col_start &&
      (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
  const int is_mb_rd_hash_enabled =
      (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
  const int n4 = bsize_to_num_blk(bsize);
  if (is_mb_rd_hash_enabled) {
    hash = get_block_residue_hash(x, bsize);
    mb_rd_record = x->txfm_search_info.mb_rd_record;
    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
    if (match_index != -1) {
      MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
      fetch_mb_rd_info(n4, mb_rd_info, rd_stats, x);
      return;
    }
  }

  // If we predict that skip is the optimal RD decision, set the respective
  // context and terminate early.
  int64_t dist;
  if (txfm_params->skip_txfm_level &&
      predict_skip_txfm(x, bsize, &dist,
                        cpi->common.features.reduced_tx_set_used)) {
    set_skip_txfm(x, rd_stats, bsize, dist);
    // Save the RD search results into mb_rd_record.
    if (is_mb_rd_hash_enabled)
      save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
    return;
  }
#if CONFIG_SPEED_STATS
  ++x->txfm_search_info.tx_search_count;
#endif  // CONFIG_SPEED_STATS

  const int64_t rd =
      select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd);

  if (rd == INT64_MAX) {
    // We should always find at least one candidate unless ref_best_rd is less
    // than INT64_MAX (in which case, all the calls to select_tx_size_and_type
    // might have failed to find something better).
    assert(ref_best_rd != INT64_MAX);
    av1_invalid_rd_stats(rd_stats);
    return;
  }

  // Save the RD search results into mb_rd_record.
  if (is_mb_rd_hash_enabled) {
    assert(mb_rd_record != NULL);
    save_mb_rd_info(n4, hash, x, rd_stats, mb_rd_record);
  }
}

void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bs,
                                       int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TxfmSearchParams *tx_params = &x->txfm_search_params;
  assert(bs == mbmi->bsize);
  const int is_inter = is_inter_block(mbmi);
  const int mi_row = xd->mi_row;
  const int mi_col = xd->mi_col;

  av1_init_rd_stats(rd_stats);

  // Hashing-based speed feature for inter blocks. If the hash of the residue
  // block is found in the table, use previously saved search results and
  // terminate early.
  uint32_t hash = 0;
  MB_RD_RECORD *mb_rd_record = NULL;
  const int num_blks = bsize_to_num_blk(bs);
  if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
    const int within_border =
        mi_row >= xd->tile.mi_row_start &&
        (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
        mi_col >= xd->tile.mi_col_start &&
        (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
    if (within_border) {
      hash = get_block_residue_hash(x, bs);
      mb_rd_record = x->txfm_search_info.mb_rd_record;
      const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
      if (match_index != -1) {
        MB_RD_INFO *mb_rd_info = &mb_rd_record->mb_rd_info[match_index];
        fetch_mb_rd_info(num_blks, mb_rd_info, rd_stats, x);
        return;
      }
    }
  }

  // If skip is predicted to be the optimal RD decision, set the respective
  // context and terminate early.
  int64_t dist;
  if (tx_params->skip_txfm_level && is_inter &&
      !xd->lossless[mbmi->segment_id] &&
      predict_skip_txfm(x, bs, &dist,
                        cpi->common.features.reduced_tx_set_used)) {
    // Populate rd_stats according to the skip decision.
    set_skip_txfm(x, rd_stats, bs, dist);
    // Save the RD search results into mb_rd_record.
    if (mb_rd_record) {
      save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
    }
    return;
  }

  if (xd->lossless[mbmi->segment_id]) {
    // Lossless mode can only pick the smallest (4x4) transform size.
    choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
  } else if (tx_params->tx_size_search_method == USE_LARGESTALL) {
    choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
  } else {
    choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
  }

  // Save the RD search results into mb_rd_record for possible reuse in the
  // future.
  if (mb_rd_record) {
    save_mb_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
  }
}
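
// The ladder above picks one of three uniform-tx-size strategies:
//   lossless segment -> choose_smallest_tx_size()    (4x4 is the only option)
//   USE_LARGESTALL   -> choose_largest_tx_size()     (skip the size search)
//   otherwise        -> choose_tx_size_type_from_rd() (full RD-based search)
// Illustrative call, mirroring how av1_txfm_search() uses it later in this
// file:
//
//   av1_pick_uniform_tx_size_type_yrd(cpi, x, &rd_stats_y, bsize, rd_thresh);
//   if (rd_stats_y.rate == INT_MAX) return 0;  // failed to beat the threshold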

int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
                  BLOCK_SIZE bsize, int64_t ref_best_rd) {
  av1_init_rd_stats(rd_stats);
  if (ref_best_rd < 0) return 0;
  if (!x->e_mbd.is_chroma_ref) return 1;

  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
  const int is_inter = is_inter_block(mbmi);
  int64_t this_rd = 0, skip_txfm_rd = 0;
  const BLOCK_SIZE plane_bsize =
      get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);

  if (is_inter) {
    for (int plane = 1; plane < MAX_MB_PLANE; ++plane)
      av1_subtract_plane(x, plane_bsize, plane);
  }

  const int skip_trellis = 0;
  const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
  int is_cost_valid = 1;
  for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
    RD_STATS this_rd_stats;
    int64_t chroma_ref_best_rd = ref_best_rd;
    // For inter blocks, the refined ref_best_rd is used for early exit.
    // For intra blocks, even though the current rd crosses ref_best_rd, early
    // exit is not recommended, as the current rd is also used for gating
    // subsequent modes (say, the angular modes).
    // TODO(any): Extend the early exit mechanism to intra modes as well.
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
        chroma_ref_best_rd != INT64_MAX)
      chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
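    // Worked example of the gating above (numbers illustrative): with
    // ref_best_rd = 1000 and a running cost so far of
    // AOMMIN(this_rd, skip_txfm_rd) = 300, the remaining chroma budget is
    // 1000 - 300 = 700; a plane search exceeding that budget cannot change
    // the final mode decision and may exit early.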
    av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
                         plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis);
    if (this_rd_stats.rate == INT_MAX) {
      is_cost_valid = 0;
      break;
    }
    av1_merge_rd_stats(rd_stats, &this_rd_stats);
    this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
    skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
    if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
      is_cost_valid = 0;
      break;
    }
  }

  if (!is_cost_valid) {
    // Reset the cost values to mark this search as invalid.
    av1_invalid_rd_stats(rd_stats);
  }

  return is_cost_valid;
}
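
// Illustrative use of the av1_txfm_uvrd() return convention (a sketch; the
// real callers live in the mode-decision code): 0 means rd_stats was
// invalidated because chroma alone already exceeds the RD budget.
//
//   RD_STATS rd_stats_uv;
//   if (!av1_txfm_uvrd(cpi, x, &rd_stats_uv, bsize, ref_best_rd)) {
//     // Prune this mode.
//   }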

void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
                          RD_STATS *rd_stats, int64_t ref_best_rd,
                          int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
                          int skip_trellis) {
  assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));

  if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
      txsize_sqr_up_map[tx_size] == TX_64X64) {
    av1_invalid_rd_stats(rd_stats);
    return;
  }

  if (current_rd > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats);
    return;
  }
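  // Note: both early returns above only invalidate the output rd_stats; they
  // leave the encoder state in x untouched, so a caller may retry with, say,
  // a different tx_size or a larger RD budget.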

  MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  struct rdcost_block_args args;
  av1_zero(args);
  args.x = x;
  args.cpi = cpi;
  args.best_rd = ref_best_rd;
  args.current_rd = current_rd;
  args.ftxs_mode = ftxs_mode;
  args.skip_trellis = skip_trellis;
  av1_init_rd_stats(&args.rd_stats);

  av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
  av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
                                         &args);

  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;

  if (invalid_rd) {
    av1_invalid_rd_stats(rd_stats);
  } else {
    *rd_stats = args.rd_stats;
  }
}
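
// av1_foreach_transformed_block_in_plane() above drives block_rd_txfm() as a
// callback over every transform block in the plane, accumulating rate and
// distortion into args.rd_stats. A minimal sketch of that visit-and-accumulate
// shape, with hypothetical names (BlkVisitor, visit_all, and add_rate are
// illustrative, not libaom API):
//
//   typedef void (*BlkVisitor)(int blk_idx, void *ctx);
//   static void visit_all(int num_blks, BlkVisitor fn, void *ctx) {
//     for (int i = 0; i < num_blks; ++i) fn(i, ctx);  // walk every block
//   }
//   static void add_rate(int blk_idx, void *ctx) {
//     (void)blk_idx;
//     ((RD_STATS *)ctx)->rate += 1;  // stand-in for per-block RD costing
//   }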

int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0],
                                  x->mode_costs.skip_txfm_cost[skip_ctx][1] };
  const int64_t min_header_rate =
      mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]);
  // Account for the minimum possible skip / non-skip header rd; eventually
  // one of the two flag costs will be added to mode_rate.
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
  if (min_header_rd_possible > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    return 0;
  }
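  // Worked example (illustrative numbers): with mode_rate = 500 and the
  // cheaper skip flag costing 20 in rate units, even a zero-distortion
  // outcome costs at least RDCOST(x->rdmult, 520, 0). If that floor already
  // exceeds ref_best_rd, no transform search can win, so we bail out before
  // doing any work.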

  const AV1_COMMON *cm = &cpi->common;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
  const int64_t rd_thresh =
      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  rd_stats->rate = mode_rate;

  // Luma transform search: compute the coefficient cost and distortion.
  av1_subtract_plane(x, bsize, 0);
  if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
      !xd->lossless[mbmi->segment_id]) {
    av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
#if CONFIG_COLLECT_RD_STATS == 2
    PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 2
  } else {
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
    for (int i = 0; i < xd->height * xd->width; ++i)
      set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm);
  }

  if (rd_stats_y->rate == INT_MAX) return 0;

  av1_merge_rd_stats(rd_stats, rd_stats_y);

  const int64_t non_skip_txfm_rdcosty =
      RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist);
  const int64_t skip_txfm_rdcosty =
      RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse);
  const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
  if (min_rdcosty > ref_best_rd) return 0;
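  // The two costs above are lower bounds on the final block cost: the block
  // will be coded either with coefficients (non_skip_txfm_rdcosty; the rate
  // includes the skip flag coded as 0) or as all-skip (skip_txfm_rdcosty;
  // the distortion collapses to sse). If even the smaller bound exceeds
  // ref_best_rd, this mode cannot win and the chroma search is skipped.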

  av1_init_rd_stats(rd_stats_uv);
  const int num_planes = av1_num_planes(cm);
  if (num_planes > 1) {
    int64_t ref_best_chroma_rd = ref_best_rd;
    // Calculate the best rd cost possible for chroma.
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
        (ref_best_chroma_rd != INT64_MAX)) {
      ref_best_chroma_rd = (ref_best_chroma_rd -
                            AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty));
    }
    const int is_cost_valid_uv =
        av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
    if (!is_cost_valid_uv) return 0;
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
  }

  int choose_skip_txfm = rd_stats->skip_txfm;
  if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
    const int64_t rdcost_no_skip_txfm = RDCOST(
        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0],
        rd_stats->dist);
    const int64_t rdcost_skip_txfm =
        RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse);
    if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
  }
  if (choose_skip_txfm) {
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    rd_stats->rate = mode_rate + skip_txfm_cost[1];
    rd_stats->dist = rd_stats->sse;
    rd_stats_y->dist = rd_stats_y->sse;
    rd_stats_uv->dist = rd_stats_uv->sse;
    mbmi->skip_txfm = 1;
    if (rd_stats->skip_txfm) {
      const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      if (tmprd > ref_best_rd) return 0;
    }
  } else {
    rd_stats->rate += skip_txfm_cost[0];
    mbmi->skip_txfm = 0;
  }

  return 1;
}
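
// Illustrative caller contract for av1_txfm_search() (a sketch; the real
// callers live in the inter/intra mode-decision code): a return of 0 means
// the mode was pruned and the rd stats should not be trusted.
//
//   RD_STATS rd, rd_y, rd_uv;
//   if (av1_txfm_search(cpi, x, bsize, &rd, &rd_y, &rd_uv, mode_rate,
//                       ref_best_rd)) {
//     const int64_t total_rd = RDCOST(x->rdmult, rd.rate, rd.dist);
//     // Compare total_rd against the best candidate found so far.
//   }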