/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include "av1/common/cfl.h"
#include "av1/common/reconintra.h"
#include "av1/encoder/block.h"
#include "av1/encoder/hybrid_fwd_txfm.h"
#include "av1/common/idct.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/random.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/sorting_network.h"
#include "av1/encoder/tx_prune_model_weights.h"
#include "av1/encoder/tx_search.h"
#include "av1/encoder/txb_rdopt.h"

#define PROB_THRESH_OFFSET_TX_TYPE 100

struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  RD_STATS rd_stats;
  int64_t current_rd;
  int64_t best_rd;
  int exit_early;
  int incomplete_exit;
  FAST_TX_SEARCH_MODE ftxs_mode;
  int skip_trellis;
};

typedef struct {
  int64_t rd;
  int txb_entropy_ctx;
  TX_TYPE tx_type;
} TxCandidateInfo;

typedef struct {
  int leaf;
  int8_t children[4];
} RD_RECORD_IDX_NODE;

typedef struct tx_size_rd_info_node {
  TXB_RD_INFO *rd_info_array;  // Points to array of size TX_TYPES.
  struct tx_size_rd_info_node *children[4];
} TXB_RD_INFO_NODE;

// origin_threshold * 128 / 100
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};
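
// Note: predict_skip_txfm() below compares coefficients as
//   (abs(coef) << 7) >= skip_pred_threshold * quantizer,
// which (up to the integer rounding of the table entries) is equivalent to
//   abs(coef) / quantizer >= origin_threshold / 100,
// since each entry stores origin_threshold * 128 / 100.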

// lookup table for predict_skip_txfm
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};

// look-up table for sqrt of number of pixels in a transform block
// rounded up to the nearest integer.
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
                                                     12, 12, 23, 23, 32, 32, 8,
                                                     8,  16, 16, 23, 23 };
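
// For example, TX_8X16 covers 128 pixels and sqrt(128) ~= 11.31, stored as 12;
// TX_16X32 covers 512 pixels and sqrt(512) ~= 22.63, stored as 23.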

static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record,
                                const uint32_t hash) {
  // Linear search through the circular buffer to find a matching hash.
  for (int i = cur_record->index_start - 1; i >= 0; i--) {
    if (cur_record->hash_vals[i] == hash) return i;
  }
  for (int i = cur_record->num - 1; i >= cur_record->index_start; i--) {
    if (cur_record->hash_vals[i] == hash) return i;
  }
  int index;
  // If not found, add new RD info into the buffer and return its index.
  if (cur_record->num < TX_SIZE_RD_RECORD_BUFFER_LEN) {
    index = (cur_record->index_start + cur_record->num) %
            TX_SIZE_RD_RECORD_BUFFER_LEN;
    cur_record->num++;
  } else {
    index = cur_record->index_start;
    cur_record->index_start =
        (cur_record->index_start + 1) % TX_SIZE_RD_RECORD_BUFFER_LEN;
  }

  cur_record->hash_vals[index] = hash;
  av1_zero(cur_record->tx_rd_info[index]);
  return index;
}
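
// The record is a FIFO of the most recently seen residuals: once the buffer
// holds TX_SIZE_RD_RECORD_BUFFER_LEN entries, the oldest one (at index_start)
// is evicted and index_start advances, so the lookup above scans from newest
// to oldest entry.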

static const RD_RECORD_IDX_NODE rd_record_tree_8x8[] = {
  { 1, { 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_8x16[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0, 0, 0, 0 } },
  { 1, { 0, 0, 0, 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_16x8[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0 } },
  { 1, { 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_16x16[] = {
  { 0, { 1, 2, 3, 4 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_1_2[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 5, 6 } },
  { 0, { 7, 8, 9, 10 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_2_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 7, 8 } },
  { 0, { 5, 6, 9, 10 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_sqr[] = {
  { 0, { 1, 2, 3, 4 } },     { 0, { 5, 6, 9, 10 } },    { 0, { 7, 8, 11, 12 } },
  { 0, { 13, 14, 17, 18 } }, { 0, { 15, 16, 19, 20 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_64x128[] = {
  { 0, { 2, 3, 4, 5 } },     { 0, { 6, 7, 8, 9 } },
  { 0, { 10, 11, 14, 15 } }, { 0, { 12, 13, 16, 17 } },
  { 0, { 18, 19, 22, 23 } }, { 0, { 20, 21, 24, 25 } },
  { 0, { 26, 27, 30, 31 } }, { 0, { 28, 29, 32, 33 } },
  { 0, { 34, 35, 38, 39 } }, { 0, { 36, 37, 40, 41 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_128x64[] = {
  { 0, { 2, 3, 6, 7 } },     { 0, { 4, 5, 8, 9 } },
  { 0, { 10, 11, 18, 19 } }, { 0, { 12, 13, 20, 21 } },
  { 0, { 14, 15, 22, 23 } }, { 0, { 16, 17, 24, 25 } },
  { 0, { 26, 27, 34, 35 } }, { 0, { 28, 29, 36, 37 } },
  { 0, { 30, 31, 38, 39 } }, { 0, { 32, 33, 40, 41 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_128x128[] = {
  { 0, { 4, 5, 8, 9 } },     { 0, { 6, 7, 10, 11 } },
  { 0, { 12, 13, 16, 17 } }, { 0, { 14, 15, 18, 19 } },
  { 0, { 20, 21, 28, 29 } }, { 0, { 22, 23, 30, 31 } },
  { 0, { 24, 25, 32, 33 } }, { 0, { 26, 27, 34, 35 } },
  { 0, { 36, 37, 44, 45 } }, { 0, { 38, 39, 46, 47 } },
  { 0, { 40, 41, 48, 49 } }, { 0, { 42, 43, 50, 51 } },
  { 0, { 52, 53, 60, 61 } }, { 0, { 54, 55, 62, 63 } },
  { 0, { 56, 57, 64, 65 } }, { 0, { 58, 59, 66, 67 } },
  { 0, { 68, 69, 76, 77 } }, { 0, { 70, 71, 78, 79 } },
  { 0, { 72, 73, 80, 81 } }, { 0, { 74, 75, 82, 83 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_1_4[] = {
  { 0, { 1, -1, 2, -1 } },
  { 0, { 3, 4, -1, -1 } },
  { 0, { 5, 6, -1, -1 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_4_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, -1, -1 } },
  { 0, { 5, 6, -1, -1 } },
};

static const RD_RECORD_IDX_NODE *rd_record_tree[BLOCK_SIZES_ALL] = {
  NULL,                    // BLOCK_4X4
  NULL,                    // BLOCK_4X8
  NULL,                    // BLOCK_8X4
  rd_record_tree_8x8,      // BLOCK_8X8
  rd_record_tree_8x16,     // BLOCK_8X16
  rd_record_tree_16x8,     // BLOCK_16X8
  rd_record_tree_16x16,    // BLOCK_16X16
  rd_record_tree_1_2,      // BLOCK_16X32
  rd_record_tree_2_1,      // BLOCK_32X16
  rd_record_tree_sqr,      // BLOCK_32X32
  rd_record_tree_1_2,      // BLOCK_32X64
  rd_record_tree_2_1,      // BLOCK_64X32
  rd_record_tree_sqr,      // BLOCK_64X64
  rd_record_tree_64x128,   // BLOCK_64X128
  rd_record_tree_128x64,   // BLOCK_128X64
  rd_record_tree_128x128,  // BLOCK_128X128
  NULL,                    // BLOCK_4X16
  NULL,                    // BLOCK_16X4
  rd_record_tree_1_4,      // BLOCK_8X32
  rd_record_tree_4_1,      // BLOCK_32X8
  rd_record_tree_1_4,      // BLOCK_16X64
  rd_record_tree_4_1,      // BLOCK_64X16
};
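
// Each rd_record_tree_* table above lists only the internal nodes of the TX
// size search quadtree for one block size. Child values index into the
// breadth-first node array built by find_tx_size_rd_records() below; values
// of 0 and -1 both denote "no child" (see init_rd_record_tree()).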

static const int rd_record_tree_size[BLOCK_SIZES_ALL] = {
  0,                                                            // BLOCK_4X4
  0,                                                            // BLOCK_4X8
  0,                                                            // BLOCK_8X4
  sizeof(rd_record_tree_8x8) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X8
  sizeof(rd_record_tree_8x16) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_8X16
  sizeof(rd_record_tree_16x8) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_16X8
  sizeof(rd_record_tree_16x16) / sizeof(RD_RECORD_IDX_NODE),    // BLOCK_16X16
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X32
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X16
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X32
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X64
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X32
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X64
  sizeof(rd_record_tree_64x128) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_64X128
  sizeof(rd_record_tree_128x64) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_128X64
  sizeof(rd_record_tree_128x128) / sizeof(RD_RECORD_IDX_NODE),  // BLOCK_128X128
  0,                                                            // BLOCK_4X16
  0,                                                            // BLOCK_16X4
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X32
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X8
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X64
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X16
};

static INLINE void init_rd_record_tree(TXB_RD_INFO_NODE *tree,
                                       BLOCK_SIZE bsize) {
  const RD_RECORD_IDX_NODE *rd_record = rd_record_tree[bsize];
  const int size = rd_record_tree_size[bsize];
  for (int i = 0; i < size; ++i) {
    if (rd_record[i].leaf) {
      av1_zero(tree[i].children);
    } else {
      for (int j = 0; j < 4; ++j) {
        const int8_t idx = rd_record[i].children[j];
        tree[i].children[j] = idx > 0 ? &tree[idx] : NULL;
      }
    }
  }
}

// Go through all TX blocks that could be used in TX size search, compute
// residual hash values for them and find matching RD info that stores previous
// RD search results for these TX blocks. The idea is to prevent repeated
// rate/distortion computations that happen because of the combination of
// partition and TX size search. The resulting RD info records are returned in
// the form of a quadtree for easier access in actual TX size search.
static int find_tx_size_rd_records(MACROBLOCK *x, BLOCK_SIZE bsize,
                                   TXB_RD_INFO_NODE *dst_rd_info) {
  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
  TXB_RD_RECORD *rd_records_table[4] = {
    txfm_info->txb_rd_records->txb_rd_record_8X8,
    txfm_info->txb_rd_records->txb_rd_record_16X16,
    txfm_info->txb_rd_records->txb_rd_record_32X32,
    txfm_info->txb_rd_records->txb_rd_record_64X64
  };
  const TX_SIZE max_square_tx_size = max_txsize_lookup[bsize];
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];

  // Hashing is performed only for square TX sizes larger than TX_4X4.
  if (max_square_tx_size < TX_8X8) return 0;
  const int diff_stride = bw;
  const struct macroblock_plane *const p = &x->plane[0];
  const int16_t *diff = &p->src_diff[0];
  init_rd_record_tree(dst_rd_info, bsize);
  // Coordinates of the top-left corner of current block within the superblock
  // measured in pixels:
  const int mi_row = x->e_mbd.mi_row;
  const int mi_col = x->e_mbd.mi_col;
  const int mi_row_in_sb = (mi_row % MAX_MIB_SIZE) << MI_SIZE_LOG2;
  const int mi_col_in_sb = (mi_col % MAX_MIB_SIZE) << MI_SIZE_LOG2;
  int cur_rd_info_idx = 0;
  int cur_tx_depth = 0;
  TX_SIZE cur_tx_size = max_txsize_rect_lookup[bsize];
  while (cur_tx_depth <= MAX_VARTX_DEPTH) {
    const int cur_tx_bw = tx_size_wide[cur_tx_size];
    const int cur_tx_bh = tx_size_high[cur_tx_size];
    if (cur_tx_bw < 8 || cur_tx_bh < 8) break;
    const TX_SIZE next_tx_size = sub_tx_size_map[cur_tx_size];
    const int tx_size_idx = cur_tx_size - TX_8X8;
    for (int row = 0; row < bh; row += cur_tx_bh) {
      for (int col = 0; col < bw; col += cur_tx_bw) {
        if (cur_tx_bw != cur_tx_bh) {
          // Use dummy nodes for all rectangular transforms within the
          // TX size search tree.
          dst_rd_info[cur_rd_info_idx].rd_info_array = NULL;
        } else {
          // Get spatial location of this TX block within the superblock
          // (measured in cur_tx_bsize units).
          const int row_in_sb = (mi_row_in_sb + row) / cur_tx_bh;
          const int col_in_sb = (mi_col_in_sb + col) / cur_tx_bw;

          int16_t hash_data[MAX_SB_SQUARE];
          int16_t *cur_hash_row = hash_data;
          const int16_t *cur_diff_row = diff + row * diff_stride + col;
          for (int i = 0; i < cur_tx_bh; i++) {
            memcpy(cur_hash_row, cur_diff_row, sizeof(*hash_data) * cur_tx_bw);
            cur_hash_row += cur_tx_bw;
            cur_diff_row += diff_stride;
          }
          const int hash = av1_get_crc32c_value(
              &txfm_info->txb_rd_records->mb_rd_record.crc_calculator,
              (uint8_t *)hash_data, 2 * cur_tx_bw * cur_tx_bh);
          // Find corresponding RD info based on the hash value.
          const int record_idx =
              row_in_sb * (MAX_MIB_SIZE >> (tx_size_idx + 1)) + col_in_sb;
          TXB_RD_RECORD *records = &rd_records_table[tx_size_idx][record_idx];
          int idx = find_tx_size_rd_info(records, hash);
          dst_rd_info[cur_rd_info_idx].rd_info_array =
              &records->tx_rd_info[idx];
        }
        ++cur_rd_info_idx;
      }
    }
    cur_tx_size = next_tx_size;
    ++cur_tx_depth;
  }
  return 1;
}
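
// Worked example (assuming MAX_MIB_SIZE == 32, i.e. 128x128 superblocks):
// for cur_tx_size == TX_16X16, tx_size_idx is TX_16X16 - TX_8X8 == 1, so
// records come from txb_rd_record_16X16, and the record stride
// MAX_MIB_SIZE >> (tx_size_idx + 1) == 8 matches the 8x8 grid of 16x16-pixel
// blocks inside a superblock.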

static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  const int16_t *diff = x->plane[0].src_diff;
  const uint32_t hash = av1_get_crc32c_value(
      &x->txfm_search_info.txb_rd_records->mb_rd_record.crc_calculator,
      (uint8_t *)diff, 2 * rows * cols);
  return (hash << 5) + bsize;
}
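
// The block size is packed into the low 5 bits of the hash so that identical
// residuals of different block sizes cannot collide (BLOCK_SIZES_ALL == 22
// fits in 5 bits).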

static INLINE int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
                                      const int64_t ref_best_rd,
                                      const uint32_t hash) {
  int32_t match_index = -1;
  if (ref_best_rd != INT64_MAX) {
    for (int i = 0; i < mb_rd_record->num; ++i) {
      const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
      // If there is a match in the tx_rd_record, fetch the RD decision and
      // terminate early.
      if (mb_rd_record->tx_rd_info[index].hash_value == hash) {
        match_index = index;
        break;
      }
    }
  }
  return match_index;
}
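
// A hit in the MB-level record means an identical residual was searched
// before, so the caller can restore the complete transform search result via
// fetch_tx_rd_info() below instead of redoing the search.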

static AOM_INLINE void fetch_tx_rd_info(int n4,
                                        const MB_RD_INFO *const tx_rd_info,
                                        RD_STATS *const rd_stats,
                                        MACROBLOCK *const x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  mbmi->tx_size = tx_rd_info->tx_size;
  memcpy(x->txfm_search_info.blk_skip, tx_rd_info->blk_skip,
         sizeof(tx_rd_info->blk_skip[0]) * n4);
  av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
  av1_copy_array(xd->tx_type_map, tx_rd_info->tx_type_map, n4);
  *rd_stats = tx_rd_info->rd_stats;
}

// Compute the pixel domain distortion from diff on all visible 4x4s in the
// transform block.
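// If block_mse_q8 is non-NULL, it receives the mean squared error over the
// visible area in Q8, i.e. (256 * sse) / num_visible_pixels.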
static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
                                      int blk_row, int blk_col,
                                      const BLOCK_SIZE plane_bsize,
                                      const BLOCK_SIZE tx_bsize,
                                      unsigned int *block_mse_q8) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse =
      aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
  if (block_mse_q8 != NULL) {
    if (visible_cols > 0 && visible_rows > 0)
      *block_mse_q8 =
          (unsigned int)((256 * sse) / (visible_cols * visible_rows));
    else
      *block_mse_q8 = UINT_MAX;
  }
  return sse;
}

// Computes the residual block's SSE and mean on all visible 4x4s in the
// transform block.
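// per_px_mean is returned in Q7 (the << 7 below is labeled a conversion to the
// transform domain), and block_var receives the sum of squared deviations from
// the mean over the visible area (i.e. pixel count times the variance).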
static INLINE int64_t pixel_diff_stats(
    MACROBLOCK *x, int plane, int blk_row, int blk_col,
    const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
    unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;

  diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
  uint64_t sse = 0;
  int sum = 0;
  sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
  if (visible_cols > 0 && visible_rows > 0) {
    double norm_factor = 1.0 / (visible_cols * visible_rows);
    int sign_sum = sum > 0 ? 1 : -1;
    // Conversion to transform domain
    *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
    *per_px_mean = sign_sum * (*per_px_mean);
    *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
    *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
  } else {
    *block_mse_q8 = UINT_MAX;
  }
  return sse;
}

// Uses simple features on top of DCT coefficients to quickly predict
// whether the optimal RD decision is to skip encoding the residual.
// The sse value is stored in dist.
static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
                             int reduced_tx_set) {
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const MACROBLOCKD *xd = &x->e_mbd;
  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);

  *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);

  const int64_t mse = *dist / bw / bh;
  // Normalized quantizer takes the transform upscaling factor (8 for tx size
  // smaller than 32) into account.
  const int16_t normalized_dc_q = dc_q >> 3;
  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
  // For a faster early skip decision, compare dist against the threshold so
  // that the quality risk of the skip=1 decision is lower. Otherwise, use mse,
  // since the fwd_txfm coefficient checks below will take care of quality.
  // TODO(any): Use dist to return 0 when skip_txfm_level is 1
  int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
  // Predict not to skip when the error is larger than the threshold.
  if (pred_err > mse_thresh) return 0;
  // Otherwise, predict skip immediately for the aggressive early-skip levels.
  else if (txfm_params->skip_txfm_level >= 2)
    return 1;

  const int max_tx_size = max_predict_sf_tx_size[bsize];
  const int tx_h = tx_size_high[max_tx_size];
  const int tx_w = tx_size_wide[max_tx_size];
  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
  TxfmParam param;
  param.tx_type = DCT_DCT;
  param.tx_size = max_tx_size;
  param.bd = xd->bd;
  param.is_hbd = is_cur_buf_hbd(xd);
  param.lossless = 0;
  param.tx_set_type = av1_get_ext_tx_set_type(
      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
  const int16_t *src_diff = x->plane[0].src_diff;
  const int n_coeff = tx_w * tx_h;
  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
  for (int row = 0; row < bh; row += tx_h) {
    for (int col = 0; col < bw; col += tx_w) {
      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
      // Operating on TX domain, not pixels; we want the QTX quantizers
      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
      if (dc_coef >= dc_thresh) return 0;
      for (int i = 1; i < n_coeff; ++i) {
        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
        if (ac_coef >= ac_thresh) return 0;
      }
    }
    src_diff += tx_h * bw;
  }
  return 1;
}

// Used to set proper context for early termination with skip = 1.
static AOM_INLINE void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
                                     int bsize, int64_t dist) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int n4 = bsize_to_num_blk(bsize);
  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
  memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
  mbmi->tx_size = tx_size;
  for (int i = 0; i < n4; ++i)
    set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1);
  rd_stats->skip_txfm = 1;
  if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
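  // Distortion in the RD search carries 4 fractional bits relative to raw SSE
  // (see get_sse() below, which applies the same scale by 16), hence the << 4.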
  rd_stats->dist = rd_stats->sse = (dist << 4);
  // Though the decision is to mark the block as skip based on luma stats,
  // it is possible that the block becomes non-skip after the chroma RD search.
  // In addition, the intermediate non-skip costs calculated by the caller
  // would be incorrect if the rate were set to zero (i.e., if zero_blk_rate
  // were not accounted for). Hence an intermediate rate is populated here to
  // code the luma tx blocks as skip; the caller sets the final rate according
  // to the final RD decision (i.e., skip vs non-skip). The rate populated here
  // corresponds to coding all the tx blocks in the current block with
  // zero_blk_rate (based on the maximum possible tx size). E.g., for a 128x128
  // block, the rate would be 4 * zero_blk_rate, where zero_blk_rate
  // corresponds to coding one 64x64 tx block as 'all zeros'.
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
  ENTROPY_CONTEXT *ta = ctxa;
  ENTROPY_CONTEXT *tl = ctxl;
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  TXB_CTX txb_ctx;
  get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
  const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  rd_stats->rate = zero_blk_rate *
                   (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
                   (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
}

static AOM_INLINE void save_tx_rd_info(int n4, uint32_t hash,
                                       const MACROBLOCK *const x,
                                       const RD_STATS *const rd_stats,
                                       MB_RD_RECORD *tx_rd_record) {
  int index;
  if (tx_rd_record->num < RD_RECORD_BUFFER_LEN) {
    index =
        (tx_rd_record->index_start + tx_rd_record->num) % RD_RECORD_BUFFER_LEN;
    ++tx_rd_record->num;
  } else {
    index = tx_rd_record->index_start;
    tx_rd_record->index_start =
        (tx_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
  }
  MB_RD_INFO *const tx_rd_info = &tx_rd_record->tx_rd_info[index];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  tx_rd_info->hash_value = hash;
  tx_rd_info->tx_size = mbmi->tx_size;
  memcpy(tx_rd_info->blk_skip, x->txfm_search_info.blk_skip,
         sizeof(tx_rd_info->blk_skip[0]) * n4);
  av1_copy(tx_rd_info->inter_tx_size, mbmi->inter_tx_size);
  av1_copy_array(tx_rd_info->tx_type_map, xd->tx_type_map, n4);
  tx_rd_info->rd_stats = *rd_stats;
}

static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
                                 const SPEED_FEATURES *sf,
                                 int tx_size_search_method) {
  if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;

  if (sf->tx_sf.tx_size_search_lgr_block) {
    if (mi_width > mi_size_wide[BLOCK_64X64] ||
        mi_height > mi_size_high[BLOCK_64X64])
      return MAX_VARTX_DEPTH;
  }

  if (is_inter) {
    return (mi_height != mi_width)
               ? sf->tx_sf.inter_tx_size_search_init_depth_rect
               : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
  } else {
    return (mi_height != mi_width)
               ? sf->tx_sf.intra_tx_size_search_init_depth_rect
               : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
  }
}

static AOM_INLINE void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
    TXB_RD_INFO_NODE *rd_info_node);

// NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
// 0: Do not collect any RD stats
// 1: Collect RD stats for transform units
// 2: Collect RD stats for partition units
#if CONFIG_COLLECT_RD_STATS

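// Computes how the residual energy (SSE between src and dst) is distributed
// across a 4x4 grid of sub-blocks: hordist[k] is the energy fraction of
// sub-block column k and verdist[k] the fraction of sub-block row k.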
static AOM_INLINE void get_energy_distribution_fine(
    const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
    const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
    double *verdist) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
    // functions for the 16 (very small) sub-blocks of this block.
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    assert(bw <= 32);
    assert(bh <= 32);
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    if (cpi->common.seq_params->use_highbitdepth) {
      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] +=
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
        }
    } else {
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
        }
    }
  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
    const int f_index =
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[1]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[2]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[3]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[5]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[6]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[7]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[9]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[10]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[11]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->ppi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4,
                                 dst_stride, &esq[13]);
    cpi->ppi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2,
                                 dst_stride, &esq[14]);
    cpi->ppi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                                 dst_stride, &esq[15]);
  }

  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
                 esq[12] + esq[13] + esq[14] + esq[15];
  if (total > 0) {
    const double e_recip = 1.0 / total;
    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
    if (need_4th) {
      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
    }
    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
    if (need_4th) {
      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
    }
  } else {
    hordist[0] = verdist[0] = 0.25;
    hordist[1] = verdist[1] = 0.25;
    hordist[2] = verdist[2] = 0.25;
    if (need_4th) {
      hordist[3] = verdist[3] = 0.25;
    }
  }
}

static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int err = diff[j * stride + i];
      sum += err * err;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += abs(diff[j * stride + i]);
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static AOM_INLINE void get_2x2_normalized_sses_and_sads(
    const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
    int src_stride, const uint8_t *const dst, int dst_stride,
    const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
    double *const sad_norm_arr) {
  const BLOCK_SIZE tx_bsize_half =
      get_partition_subsize(tx_bsize, PARTITION_SPLIT);
  if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
    const int half_width = block_size_wide[tx_bsize] / 2;
    const int half_height = block_size_high[tx_bsize] / 2;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const int16_t *const this_src_diff =
            src_diff + row * half_height * diff_stride + col * half_width;
        if (sse_norm_arr) {
          sse_norm_arr[row * 2 + col] =
              get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
        }
        if (sad_norm_arr) {
          sad_norm_arr[row * 2 + col] =
              get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
        }
      }
    }
  } else {  // use function pointers to calculate stats
    const int half_width = block_size_wide[tx_bsize_half];
    const int half_height = block_size_high[tx_bsize_half];
    const int num_samples_half = half_width * half_height;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const uint8_t *const this_src =
            src + row * half_height * src_stride + col * half_width;
        const uint8_t *const this_dst =
            dst + row * half_height * dst_stride + col * half_width;

        if (sse_norm_arr) {
          unsigned int this_sse;
          cpi->ppi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
                                             dst_stride, &this_sse);
          sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
        }

        if (sad_norm_arr) {
          const unsigned int this_sad = cpi->ppi->fn_ptr[tx_bsize_half].sdf(
              this_src, src_stride, this_dst, dst_stride);
          sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
        }
      }
    }
  }
}

#if CONFIG_COLLECT_RD_STATS == 1
static double get_mean(const int16_t *diff, int stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      sum += diff[j * stride + i];
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}
static AOM_INLINE void PrintTransformUnitStats(
    const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
    int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    TX_TYPE tx_type, int64_t rd) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  // Generate small sample to restrict output size.
  static unsigned int seed = 21743;
  if (lcg_rand16(&seed) % 256 > 0) return;

  const char output_file[] = "tu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  const int txw = tx_size_wide[tx_size];
  const int txh = tx_size_high[tx_size];
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int num_samples = txw * txh;

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;

  fprintf(fout, "%g %g", rate_norm, dist_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src =
      &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  unsigned int sse;
  cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm = (double)sad / num_samples;

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *const src_diff =
      &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
  const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];

  fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
          tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);

  const double mean = get_mean(src_diff, diff_stride, txw, txh);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
                               1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  fprintf(fout, " %d %" PRId64, x->rdmult, rd);

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS == 1

#if CONFIG_COLLECT_RD_STATS >= 2
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const MACROBLOCKD *xd = &x->e_mbd;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  int64_t total_sse = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const struct macroblock_plane *const p = &x->plane[plane];
    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE bs =
        get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
    unsigned int sse;

    if (x->skip_chroma_rd && plane) continue;

    cpi->ppi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf,
                            pd->dst.stride, &sse);
    total_sse += sse;
  }
  total_sse <<= 4;
  return total_sse;
}

static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (md->ready) {
    if (sse < md->dist_mean) {
      *est_residue_cost = 0;
      *est_dist = sse;
    } else {
      *est_dist = (int64_t)round(md->dist_mean);
      const double est_ld = md->a * sse + md->b;
      // Clamp estimated rate cost by INT_MAX / 2.
      // TODO(angiebird@google.com): find better solution than clamping.
      if (fabs(est_ld) < 1e-2) {
        *est_residue_cost = INT_MAX / 2;
      } else {
        double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
        if (est_residue_cost_dbl < 0) {
          *est_residue_cost = 0;
        } else {
          *est_residue_cost =
              (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
        }
      }
      if (*est_residue_cost <= 0) {
        *est_residue_cost = 0;
        *est_dist = sse;
      }
    }
    return 1;
  }
  return 0;
}

static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
                                   const uint8_t *dst8, int dst_stride, int w,
                                   int h) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static double get_diff_mean(const uint8_t *src, int src_stride,
                            const uint8_t *dst, int dst_stride, int w, int h) {
  double sum = 0.0;
  for (int j = 0; j < h; ++j) {
    for (int i = 0; i < w; ++i) {
      const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
      sum += diff;
    }
  }
  assert(w > 0 && h > 0);
  return sum / (w * h);
}

static AOM_INLINE void PrintPredictionUnitStats(const AV1_COMP *const cpi,
                                                const TileDataEnc *tile_data,
                                                MACROBLOCK *x,
                                                const RD_STATS *const rd_stats,
                                                BLOCK_SIZE plane_bsize) {
  if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
      (tile_data == NULL ||
       !tile_data->inter_mode_rd_models[plane_bsize].ready))
    return;
  (void)tile_data;
  // Generate small sample to restrict output size.
  static unsigned int seed = 95014;

  if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
      1)
    return;

  const char output_file[] = "pu_stats.txt";
  FILE *fout = fopen(output_file, "a");
  if (!fout) return;

  MACROBLOCKD *const xd = &x->e_mbd;
  const int plane = 0;
  struct macroblock_plane *const p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];
  const int diff_stride = block_size_wide[plane_bsize];
  int bw, bh;
  get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
                     &bh);
  const int num_samples = bw * bh;
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int q_step = p->dequant_QTX[1] >> dequant_shift;
  const int shift = (xd->bd - 8);

  const double rate_norm = (double)rd_stats->rate / num_samples;
  const double dist_norm = (double)rd_stats->dist / num_samples;
  const double rdcost_norm =
      (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;

  fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);

  const int src_stride = p->src.stride;
  const uint8_t *const src = p->src.buf;
  const int dst_stride = pd->dst.stride;
  const uint8_t *const dst = pd->dst.buf;
  const int16_t *const src_diff = p->src_diff;

  int64_t sse = calculate_sse(xd, p, pd, bw, bh);
  const double sse_norm = (double)sse / num_samples;

  const unsigned int sad =
      cpi->ppi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
  const double sad_norm =
      (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);

  fprintf(fout, " %g %g", sse_norm, sad_norm);

  double sse_norm_arr[4], sad_norm_arr[4];
  get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
                                   dst_stride, src_diff, diff_stride,
                                   sse_norm_arr, sad_norm_arr);
  if (shift) {
    for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
    for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sse_norm_arr[i]);
  }
  for (int i = 0; i < 4; ++i) {
    fprintf(fout, " %g", sad_norm_arr[i]);
  }

  fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);

  int model_rate;
  int64_t model_dist;
  model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
                                   &model_rate, &model_dist);
  const double model_rdcost_norm =
      (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
  const double model_rate_norm = (double)model_rate / num_samples;
  const double model_dist_norm = (double)model_dist / num_samples;
  fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
          model_rdcost_norm);

  double mean;
  if (is_cur_buf_hbd(xd)) {
    mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
                                pd->dst.stride, bw, bh);
  } else {
    mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                         bw, bh);
  }
  mean /= (1 << shift);
  float hor_corr, vert_corr;
  av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
                                  &vert_corr);
  fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);

  double hdist[4] = { 0 }, vdist[4] = { 0 };
  get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
                               dst_stride, 1, hdist, vdist);
  fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
          hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);

  if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
    assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
    const int64_t overall_sse = get_sse(cpi, x);
    int est_residue_cost = 0;
    int64_t est_dist = 0;
    get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
                      &est_dist);
    const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
    const double est_dist_norm = (double)est_dist / num_samples;
    const double est_rdcost_norm =
        (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
    fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
            est_rdcost_norm);
  }

  fprintf(fout, "\n");
  fclose(fout);
}
#endif  // CONFIG_COLLECT_RD_STATS >= 2
#endif  // CONFIG_COLLECT_RD_STATS

static AOM_INLINE void inverse_transform_block_facade(MACROBLOCK *const x,
                                                      int plane, int block,
                                                      int blk_row, int blk_col,
                                                      int eob,
                                                      int reduced_tx_set) {
  if (!eob) return;
  struct macroblock_plane *const p = &x->plane[plane];
  MACROBLOCKD *const xd = &x->e_mbd;
  tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
                                          tx_size, reduced_tx_set);

  struct macroblockd_plane *const pd = &xd->plane[plane];
  const int dst_stride = pd->dst.stride;
  uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                              dst_stride, eob, reduced_tx_set);
}

static INLINE void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                               int block, int blk_row, int blk_col,
                               BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                               const TXB_CTX *const txb_ctx, int skip_trellis,
                               TX_TYPE best_tx_type, int do_quant,
                               int *rate_cost, uint16_t best_eob) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  if (!is_inter && best_eob &&
      (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
       blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
    // If the quantized coefficients are already stored in the dqcoeff buffer,
    // we don't need to do the transform and quantization again.
    if (do_quant) {
      TxfmParam txfm_param_intra;
      QUANT_PARAM quant_param_intra;
      av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
      av1_setup_quant(tx_size, !skip_trellis,
                      skip_trellis
                          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
                                                    : AV1_XFORM_QUANT_FP)
                          : AV1_XFORM_QUANT_FP,
                      cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
      av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
                        &quant_param_intra);
      av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
                      &txfm_param_intra, &quant_param_intra);
      if (quant_param_intra.use_optimize_b) {
        av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
                       rate_cost);
      }
    }

    inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
                                   x->plane[plane].eobs[block],
                                   cm->features.reduced_tx_set_used);

    // This may happen because of a hash collision: the eob stored in the hash
    // table is non-zero, but the real eob is zero. We need to make sure tx_type
    // is DCT_DCT in this case.
    if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
        best_tx_type != DCT_DCT) {
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    }
  }
}

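// Returns the SSE over only the visible (inside-frame) part of the transform
// block. The per-block-size variance kernels require full block dimensions,
// so clipped blocks fall back to the odd-size SSE helpers.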
static unsigned pixel_dist_visible_only(
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    const int src_stride, const uint8_t *dst, const int dst_stride,
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    int visible_cols) {
  unsigned sse;

  if (txb_rows == visible_rows && txb_cols == visible_cols) {
    cpi->ppi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
    return sse;
  }

#if CONFIG_AV1_HIGHBITDEPTH
  const MACROBLOCKD *xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd)) {
    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
                                             visible_cols, visible_rows);
    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
  }
#else
  (void)x;
#endif
  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
                         visible_rows);
  return sse;
}
1205 
// Compute the pixel domain distortion from src and dst on all visible 4x4s
// in the transform block.
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
                           int plane, const uint8_t *src, const int src_stride,
                           const uint8_t *dst, const int dst_stride,
                           int blk_row, int blk_col,
                           const BLOCK_SIZE plane_bsize,
                           const BLOCK_SIZE tx_bsize) {
  int txb_rows, txb_cols, visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;

  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
  assert(visible_rows > 0);
  assert(visible_cols > 0);

  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
                                         dst_stride, tx_bsize, txb_rows,
                                         txb_cols, visible_rows, visible_cols);

  return sse;
}

static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
                                           int plane, BLOCK_SIZE plane_bsize,
                                           int block, int blk_row, int blk_col,
                                           TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblock_plane *const p = &x->plane[plane];
  const uint16_t eob = p->eobs[block];
  const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
  const int bsw = block_size_wide[tx_bsize];
  const int bsh = block_size_high[tx_bsize];
  const int src_stride = x->plane[plane].src.stride;
  const int dst_stride = xd->plane[plane].dst.stride;
  // Scale the transform block index to pixel units.
  const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
  const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
  const uint8_t *src = &x->plane[plane].src.buf[src_idx];
  const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
  const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);

  assert(cpi != NULL);
  assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);

  uint8_t *recon;
  DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);

#if CONFIG_AV1_HIGHBITDEPTH
  if (is_cur_buf_hbd(xd)) {
    recon = CONVERT_TO_BYTEPTR(recon16);
    aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
                             CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
  } else {
    recon = (uint8_t *)recon16;
    aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
  }
#else
  recon = (uint8_t *)recon16;
  aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
#endif

  const PLANE_TYPE plane_type = get_plane_type(plane);
  TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
                                    cpi->common.features.reduced_tx_set_used);
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
                              MAX_TX_SIZE, eob,
                              cpi->common.features.reduced_tx_set_used);

  return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
                         blk_row, blk_col, plane_bsize, tx_bsize);
}

static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
                                   int blk_col, BLOCK_SIZE plane_bsize,
                                   TX_SIZE tx_size) {
  int16_t tmp_data[64 * 64];
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;
  const int16_t *cur_diff_row = diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int txb_w = tx_size_wide[tx_size];
  const int txb_h = tx_size_high[tx_size];
  uint8_t *hash_data = (uint8_t *)cur_diff_row;
  if (txb_w != diff_stride) {
    int16_t *cur_hash_row = tmp_data;
    for (int i = 0; i < txb_h; i++) {
      memcpy(cur_hash_row, cur_diff_row, sizeof(*diff) * txb_w);
      cur_hash_row += txb_w;
      cur_diff_row += diff_stride;
    }
    hash_data = (uint8_t *)tmp_data;
  }
  CRC32C *crc =
      &x->txfm_search_info.txb_rd_records->mb_rd_record.crc_calculator;
  const uint32_t hash = av1_get_crc32c_value(crc, hash_data, 2 * txb_w * txb_h);
  return (hash << 5) + tx_size;
}
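
// Note on the hash layout: the low 5 bits hold the TX_SIZE enum value and the
// CRC32C of the residual occupies the bits above it, i.e.
//   hash_for(tx_size) = (crc32c(residual) << 5) + tx_size.
// Five bits suffice because TX_SIZES_ALL (19) < 32.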

// pruning thresholds for prune_txk_type and prune_txk_type_separ
static const int prune_factors[5] = { 200, 200, 120, 80, 40 };  // scale 1000
static const int mul_factors[5] = { 80, 80, 70, 50, 30 };       // scale 100
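
// For illustration: prune_txk_type() keeps every tx type whose estimated RD
// cost is within prune_factor/1000 of the best candidate, so a prune_factor
// of 200 keeps candidates within 20% of the best RD cost
// (1000 * (rd - rd_best) / rd_best < 200). mul_factors is used by
// get_tx_mask() to size the candidate list for prune_txk_type_separ(): with
// mul_factor 70, num_sel = (num_allowed * 70 + 50) / 100, i.e. roughly 70% of
// the allowed candidates.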

static INLINE int is_intra_hash_match(const AV1_COMP *cpi, MACROBLOCK *x,
                                      int plane, int blk_row, int blk_col,
                                      BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                                      const TXB_CTX *const txb_ctx,
                                      TXB_RD_INFO **intra_txb_rd_info,
                                      const int tx_type_map_idx,
                                      uint16_t *cur_joint_ctx) {
  MACROBLOCKD *xd = &x->e_mbd;
  TxfmSearchInfo *txfm_info = &x->txfm_search_info;
  assert(cpi->sf.tx_sf.use_intra_txb_hash &&
         frame_is_intra_only(&cpi->common) && !is_inter_block(xd->mi[0]) &&
         plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size]);
  const uint32_t intra_hash =
      get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
  const int intra_hash_idx = find_tx_size_rd_info(
      &txfm_info->txb_rd_records->txb_rd_record_intra, intra_hash);
  *intra_txb_rd_info = &txfm_info->txb_rd_records->txb_rd_record_intra
                            .tx_rd_info[intra_hash_idx];
  *cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
  if ((*intra_txb_rd_info)->entropy_context == *cur_joint_ctx &&
      txfm_info->txb_rd_records->txb_rd_record_intra.tx_rd_info[intra_hash_idx]
          .valid) {
    xd->tx_type_map[tx_type_map_idx] = (*intra_txb_rd_info)->tx_type;
    const TX_TYPE ref_tx_type =
        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
                        cpi->common.features.reduced_tx_set_used);
    return (ref_tx_type == (*intra_txb_rd_info)->tx_type);
  }
  return 0;
}

// R-D costs are sorted in ascending order.
static INLINE void sort_rd(int64_t rds[], int txk[], int len) {
  int i, j, k;

  for (i = 1; i <= len - 1; ++i) {
    for (j = 0; j < i; ++j) {
      if (rds[j] > rds[i]) {
        int64_t temprd;
        int tempi;

        temprd = rds[i];
        tempi = txk[i];

        for (k = i; k > j; k--) {
          rds[k] = rds[k - 1];
          txk[k] = txk[k - 1];
        }

        rds[j] = temprd;
        txk[j] = tempi;
        break;
      }
    }
  }
}
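
// Worked example for sort_rd(): rds = { 40, 10, 30, 20 } with
// txk = { 0, 1, 2, 3 } is sorted in place (a simple insertion sort) to
// rds = { 10, 20, 30, 40 } and txk = { 1, 3, 2, 0 }; the tx type ids always
// travel together with their RD costs.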

static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
                                        TX_SIZE tx_size, int64_t *out_dist,
                                        int64_t *out_sse) {
  const struct macroblock_plane *const p = &x->plane[plane];
  // Transform domain distortion computation is more efficient as it does
  // not involve an inverse transform, but it is less accurate.
  const int buffer_length = av1_get_max_eob(tx_size);
  int64_t this_sse;
  // TX-domain results need to shift down to Q2/D10 to match pixel
  // domain distortion values which are in Q2^2
  int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
  const int block_offset = BLOCK_OFFSET(block);
  tran_low_t *const coeff = p->coeff + block_offset;
  tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
#if CONFIG_AV1_HIGHBITDEPTH
  MACROBLOCKD *const xd = &x->e_mbd;
  if (is_cur_buf_hbd(xd))
    *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
                                       xd->bd);
  else
#endif
    *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);

  *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
  *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
}

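// Separable tx type pruning. Rather than measuring all 16 2D combinations,
// first evaluate the 4 horizontal tx types with the vertical transform fixed
// to DCT (rds_h), then the 4 vertical tx types paired with the best
// horizontal one (rds_v), and approximate each 2D candidate's cost as
// rds_v[i_v] + rds_h[i_h]. This needs at most 7 transform+quant passes
// instead of 16.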
uint16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                              int block, TX_SIZE tx_size, int blk_row,
                              int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
                              int16_t allowed_tx_mask, int prune_factor,
                              const TXB_CTX *const txb_ctx,
                              int reduced_tx_set_used, int64_t ref_best_rd,
                              int num_sel) {
  const AV1_COMMON *cm = &cpi->common;

  int idx;

  int64_t rds_v[4];
  int64_t rds_h[4];
  int idx_v[4] = { 0, 1, 2, 3 };
  int idx_h[4] = { 0, 1, 2, 3 };
  int skip_v[4] = { 0 };
  int skip_h[4] = { 0 };
  const int idx_map[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };

  const int sel_pattern_v[16] = {
    0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
  };
  const int sel_pattern_h[16] = {
    0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
  };

  QUANT_PARAM quant_param;
  TxfmParam txfm_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);
  int tx_type;
  // Use EXT_TX_SET_ALL16 to ensure we can try tx types even outside the
  // ext_tx_set of the current block.
  // This function should only be called for transform sizes up to 16x16.
  assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
  txfm_param.tx_set_type = EXT_TX_SET_ALL16;

  int rate_cost = 0;
  int64_t dist = 0, sse = 0;
  // Evaluate horizontal tx types with the vertical transform fixed to DCT.
  for (idx = 0; idx < 4; ++idx) {
    tx_type = idx_map[idx];
    txfm_param.tx_type = tx_type;

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
      skip_h[idx] = 1;
    }
  }
  sort_rd(rds_h, idx_h, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
  }

  if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;

  // Evaluate vertical tx types with the best horizontal tx type chosen above.
  rds_v[0] = rds_h[0];
  int start_v = 1, end_v = 4;
  const int *idx_map_v = idx_map + idx_h[0];

  for (idx = start_v; idx < end_v; ++idx) {
    tx_type = idx_map_v[idx_v[idx] * 4];
    txfm_param.tx_type = tx_type;

    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);

    dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);

    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);

    rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);

    if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
      skip_v[idx] = 1;
    }
  }
  sort_rd(rds_v, idx_v, 4);
  for (idx = 1; idx < 4; idx++) {
    if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
  }

  // Combine rd_h and rd_v to prune the 2D tx candidates.
  int i_v, i_h;
  int64_t rds[16];
  int num_cand = 0, last = TX_TYPES - 1;

  for (int i = 0; i < 16; i++) {
    i_v = sel_pattern_v[i];
    i_h = sel_pattern_h[i];
    tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
    if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
        skip_v[idx_v[i_v]]) {
      txk_map[last] = tx_type;
      last--;
    } else {
      txk_map[num_cand] = tx_type;
      rds[num_cand] = rds_v[i_v] + rds_h[i_h];
      if (rds[num_cand] == 0) rds[num_cand] = 1;
      num_cand++;
    }
  }
  sort_rd(rds, txk_map, num_cand);

  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
  num_sel = AOMMIN(num_sel, num_cand);

  for (int i = 1; i < num_sel; i++) {
    int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[i]);
    else
      break;
  }
  return prune;
}

uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                        int block, TX_SIZE tx_size, int blk_row, int blk_col,
                        BLOCK_SIZE plane_bsize, int *txk_map,
                        uint16_t allowed_tx_mask, int prune_factor,
                        const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
  const AV1_COMMON *cm = &cpi->common;
  int tx_type;

  int64_t rds[TX_TYPES];

  int num_cand = 0;
  int last = TX_TYPES - 1;

  TxfmParam txfm_param;
  QUANT_PARAM quant_param;
  av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
  av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
                  &quant_param);

  for (int idx = 0; idx < TX_TYPES; idx++) {
    tx_type = idx;
    int rate_cost = 0;
    int64_t dist = 0, sse = 0;
    if (!(allowed_tx_mask & (1 << tx_type))) {
      txk_map[last] = tx_type;
      last--;
      continue;
    }
    txfm_param.tx_type = tx_type;

    // Do transform and quantization.
    av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
                    &quant_param);
    // Estimate the rate cost.
    rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
                                              txb_ctx, reduced_tx_set_used, 0);
    // Transform domain distortion.
    dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);

    txk_map[num_cand] = tx_type;
    rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
    if (rds[num_cand] == 0) rds[num_cand] = 1;
    num_cand++;
  }

  if (num_cand == 0) return (uint16_t)0xFFFF;

  sort_rd(rds, txk_map, num_cand);
  uint16_t prune = (uint16_t)(~(1 << txk_map[0]));

  // 0 < prune_factor <= 1000 controls aggressiveness.
  int64_t factor = 0;
  for (int idx = 1; idx < num_cand; idx++) {
    factor = 1000 * (rds[idx] - rds[0]) / rds[0];
    if (factor < (int64_t)prune_factor)
      prune &= ~(1 << txk_map[idx]);
    else
      break;
  }
  return prune;
}

// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average; i.e., selecting the threshold with index i
// will lead to pruning i+1 TX types on average.
static const float *prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
             0.09778f, 0.11780f },
  // TX_8X8
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
             0.10803f, 0.14124f },
  // TX_16X16
  (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
             0.10168f, 0.12585f },
  // TX_8X4
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
             0.10583f, 0.13123f },
  // TX_8X16
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
             0.10730f, 0.14221f },
  // TX_16X8
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
             0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
             0.10242f, 0.12878f },
  // TX_16X4
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
             0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};

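// For illustration: with prune_2d_txfm_mode == TX_TYPE_PRUNE_1 and
// tx_set_type == EXT_TX_SET_ALL16, pruning_aggressiveness below becomes
// prune_aggr_table[0][0] == 4, so a TX_8X8 block uses
// prune_2D_adaptive_thresholds[TX_8X8][4] == 0.01697f, which per the
// calibration note above prunes about 5 TX types on average.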
static INLINE float get_adaptive_thresholds(
    TX_SIZE tx_size, TxSetType tx_set_type,
    TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
  const int prune_aggr_table[5][2] = {
    { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
  };
  int pruning_aggressiveness = 0;
  if (tx_set_type == EXT_TX_SET_ALL16)
    pruning_aggressiveness =
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0];
  else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
    pruning_aggressiveness =
        prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1];

  return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
}

static AOM_INLINE void get_energy_distribution_finer(const int16_t *diff,
                                                     int stride, int bw, int bh,
                                                     float *hordist,
                                                     float *verdist) {
  // First compute downscaled block energy values (esq); downscale factors
  // are defined by w_shift and h_shift.
  unsigned int esq[256];
  const int w_shift = bw <= 8 ? 0 : 1;
  const int h_shift = bh <= 8 ? 0 : 1;
  const int esq_w = bw >> w_shift;
  const int esq_h = bh >> h_shift;
  const int esq_sz = esq_w * esq_h;
  int i, j;
  memset(esq, 0, esq_sz * sizeof(esq[0]));
  if (w_shift) {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j += 2) {
        cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
                                cur_diff_row[j + 1] * cur_diff_row[j + 1]);
      }
    }
  } else {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j++) {
        cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
      }
    }
  }

  uint64_t total = 0;
  for (i = 0; i < esq_sz; i++) total += esq[i];

  // Output hordist and verdist arrays are normalized 1D projections of esq
  if (total == 0) {
    float hor_val = 1.0f / esq_w;
    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
    float ver_val = 1.0f / esq_h;
    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
    return;
  }

  const float e_recip = 1.0f / (float)total;
  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
  const unsigned int *cur_esq_row;
  for (i = 0; i < esq_h - 1; i++) {
    cur_esq_row = esq + i * esq_w;
    for (j = 0; j < esq_w - 1; j++) {
      hordist[j] += (float)cur_esq_row[j];
      verdist[i] += (float)cur_esq_row[j];
    }
    verdist[i] += (float)cur_esq_row[j];
  }
  cur_esq_row = esq + i * esq_w;
  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];

  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
}
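
// Note that hordist[] and verdist[] hold only esq_w - 1 and esq_h - 1
// entries: since a full normalized projection sums to 1, the last column/row
// fraction is implied by the others.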

static AOM_INLINE bool check_bit_mask(uint16_t mask, int val) {
  return mask & (1 << val);
}

static AOM_INLINE void set_bit_mask(uint16_t *mask, int val) {
  *mask |= (1 << val);
}

static AOM_INLINE void unset_bit_mask(uint16_t *mask, int val) {
  *mask &= ~(1 << val);
}
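
// Example: the fast-search mask 0x0c01 used in get_tx_mask() below has bits
// 0, 10 and 11 set, so check_bit_mask(0x0c01, DCT_DCT),
// check_bit_mask(0x0c01, V_DCT) and check_bit_mask(0x0c01, H_DCT) are true
// and every other tx type is skipped.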

static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                        int blk_row, int blk_col, TxSetType tx_set_type,
                        TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
                        uint16_t *allowed_tx_mask) {
  // This table is used because the search order is different from the enum
  // order.
  static const int tx_type_table_2D[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };
  if (tx_set_type != EXT_TX_SET_ALL16 &&
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
    return;
#if CONFIG_NN_V2
  NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#else
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
#endif
  if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.

  float hfeatures[16], vfeatures[16];
  float hscores[4], vscores[4];
  float scores_2D_raw[16];
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
  assert(hfeatures_num <= 16);
  assert(vfeatures_num <= 16);

  const struct macroblock_plane *const p = &x->plane[0];
  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
                                vfeatures);

  av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
                                  &hfeatures[hfeatures_num - 1],
                                  &vfeatures[vfeatures_num - 1]);

#if CONFIG_NN_V2
  av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
  av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
#else
  av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
  av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
#endif

  for (int i = 0; i < 4; i++) {
    float *cur_scores_2D = scores_2D_raw + i * 4;
    cur_scores_2D[0] = vscores[i] * hscores[0];
    cur_scores_2D[1] = vscores[i] * hscores[1];
    cur_scores_2D[2] = vscores[i] * hscores[2];
    cur_scores_2D[3] = vscores[i] * hscores[3];
  }

  assert(TX_TYPES == 16);
  // This version of the function only works when there are at most 16 classes.
  // So we will need to change the optimization or use av1_nn_softmax instead
  // if this ever gets changed.
  av1_nn_fast_softmax_16(scores_2D_raw, scores_2D_raw);

  const float score_thresh =
      get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);

  // Always keep the TX type with the highest score, prune all others with
  // score below score_thresh.
  int max_score_i = 0;
  float max_score = 0.0f;
  uint16_t allow_bitmask = 0;
  float sum_score = 0.0;
  // Calculate the sum of the allowed tx type scores and populate the allow
  // bit mask based on score_thresh and allowed_tx_mask.
  int allow_count = 0;
  int tx_type_allowed[16] = { TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID, TX_TYPE_INVALID, TX_TYPE_INVALID,
                              TX_TYPE_INVALID };
  float scores_2D[16] = {
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
  };
  for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
    const int allow_tx_type =
        check_bit_mask(*allowed_tx_mask, tx_type_table_2D[tx_idx]);
    if (!allow_tx_type) {
      continue;
    }
    if (scores_2D_raw[tx_idx] > max_score) {
      max_score = scores_2D_raw[tx_idx];
      max_score_i = tx_idx;
    }
    if (scores_2D_raw[tx_idx] >= score_thresh) {
      // Set the allow mask based on score_thresh.
      set_bit_mask(&allow_bitmask, tx_type_table_2D[tx_idx]);

      // Accumulate the score of the allowed tx type.
      sum_score += scores_2D_raw[tx_idx];

      scores_2D[allow_count] = scores_2D_raw[tx_idx];
      tx_type_allowed[allow_count] = tx_type_table_2D[tx_idx];
      allow_count += 1;
    }
  }
  if (!check_bit_mask(allow_bitmask, tx_type_table_2D[max_score_i])) {
    // If even the tx_type with the max score is pruned, no other tx_type is
    // feasible. When this happens, we force enable max_score_i and end the
    // search.
    set_bit_mask(&allow_bitmask, tx_type_table_2D[max_score_i]);
    memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
    *allowed_tx_mask = allow_bitmask;
    return;
  }

  // Sort the probabilities of all allowed tx types.
  if (allow_count <= 8) {
    av1_sort_fi32_8(scores_2D, tx_type_allowed);
  } else {
    av1_sort_fi32_16(scores_2D, tx_type_allowed);
  }

  // Enable more pruning based on tx type probability and the number of
  // allowed tx types.
  if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
    float temp_score = 0.0;
    float score_ratio = 0.0;
    int tx_idx, tx_count = 0;
    const float inv_sum_score = 100 / sum_score;
    // Get allowed tx types based on sorted probability score and tx count.
    for (tx_idx = 0; tx_idx < allow_count; tx_idx++) {
      // Stop once the cumulative probability exceeds 30% and at least two
      // tx types have been allowed.
      if (score_ratio > 30.0 && tx_count >= 2) break;

      assert(check_bit_mask(allow_bitmask, tx_type_allowed[tx_idx]));
      // Calculate the cumulative probability.
      temp_score += scores_2D[tx_idx];

      // Calculate the percentage of cumulative probability of the allowed
      // tx types.
      score_ratio = temp_score * inv_sum_score;
      tx_count++;
    }
    // Mark the remaining tx types as pruned.
    for (; tx_idx < allow_count; tx_idx++)
      unset_bit_mask(&allow_bitmask, tx_type_allowed[tx_idx]);
  }

  memcpy(txk_map, tx_type_allowed, sizeof(tx_type_table_2D));
  *allowed_tx_mask = allow_bitmask;
}

static float get_dev(float mean, double x2_sum, int num) {
  const float e_x2 = (float)(x2_sum / num);
  const float diff = e_x2 - mean * mean;
  const float dev = (diff > 0) ? sqrtf(diff) : 0;
  return dev;
}
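
// get_dev() computes the standard deviation via the identity
// Var(X) = E[X^2] - (E[X])^2, clamping at zero to guard against small
// negative values caused by floating point rounding.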

// Feature used by the model to predict tx split: the mean and standard
// deviation values of the block and sub-blocks.
static AOM_INLINE void get_mean_dev_features(const int16_t *data, int stride,
                                             int bw, int bh, float *feature) {
  const int16_t *const data_ptr = &data[0];
  const int subh = (bh >= bw) ? (bh >> 1) : bh;
  const int subw = (bw >= bh) ? (bw >> 1) : bw;
  const int num = bw * bh;
  const int sub_num = subw * subh;
  int feature_idx = 2;
  int total_x_sum = 0;
  int64_t total_x2_sum = 0;
  int blk_idx = 0;
  double mean2_sum = 0.0f;
  float dev_sum = 0.0f;

  for (int row = 0; row < bh; row += subh) {
    for (int col = 0; col < bw; col += subw) {
      int x_sum;
      int64_t x2_sum;
      // TODO(any): Write a SIMD version. Clear registers.
      aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
                          &x_sum, &x2_sum);
      total_x_sum += x_sum;
      total_x2_sum += x2_sum;

      const float mean = (float)x_sum / sub_num;
      const float dev = get_dev(mean, (double)x2_sum, sub_num);
      feature[feature_idx++] = mean;
      feature[feature_idx++] = dev;
      mean2_sum += (double)(mean * mean);
      dev_sum += dev;
      blk_idx++;
    }
  }

  const float lvl0_mean = (float)total_x_sum / num;
  feature[0] = lvl0_mean;
  feature[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);

  if (blk_idx > 1) {
    // Deviation of means.
    feature[feature_idx++] = get_dev(lvl0_mean, mean2_sum, blk_idx);
    // Mean of deviations.
    feature[feature_idx++] = dev_sum / blk_idx;
  }
}
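
// Resulting feature layout, as assembled above: feature[0] and feature[1]
// hold the mean and deviation of the whole block; each sub-block then
// contributes a (mean, deviation) pair starting at feature[2]; finally, when
// there is more than one sub-block, the deviation of the sub-block means and
// the mean of the sub-block deviations are appended.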

static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
                               int blk_col, TX_SIZE tx_size) {
  const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
  if (!nn_config) return -1;

  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff =
      x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];

  float features[64] = { 0.0f };
  get_mean_dev_features(diff, diff_stride, bw, bh, features);

  float score = 0.0f;
  av1_nn_predict(features, nn_config, 1, &score);

  int int_score = (int)(score * 10000);
  return clamp(int_score, -80000, 80000);
}

static INLINE uint16_t
get_tx_mask(const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block,
            int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
            const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
            int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int is_inter = is_inter_block(mbmi);
  const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
  // If txk_allowed == TX_TYPES, more than one tx type is allowed; otherwise,
  // if txk_allowed < TX_TYPES, only that specific tx type is allowed.
  TX_TYPE txk_allowed = TX_TYPES;

  const FRAME_UPDATE_TYPE update_type =
      get_frame_update_type(&cpi->ppi->gf_group, cpi->gf_frame_index);
  const int *tx_type_probs =
      cpi->ppi->frame_probs.tx_type_probs[update_type][tx_size];

  if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
      (is_inter && txfm_params->default_inter_tx_type_prob_thresh == 0)) {
    txk_allowed =
        get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
  } else if (is_inter &&
             txfm_params->default_inter_tx_type_prob_thresh != INT_MAX) {
    if (tx_type_probs[DEFAULT_INTER_TX_TYPE] >
        txfm_params->default_inter_tx_type_prob_thresh) {
      txk_allowed = DEFAULT_INTER_TX_TYPE;
    } else {
      int force_tx_type = 0;
      int max_prob = 0;
      const int tx_type_prob_threshold =
          txfm_params->default_inter_tx_type_prob_thresh +
          PROB_THRESH_OFFSET_TX_TYPE;
      for (int i = 1; i < TX_TYPES; i++) {  // Find the maximum probability.
        if (tx_type_probs[i] > max_prob) {
          max_prob = tx_type_probs[i];
          force_tx_type = i;
        }
      }
      if (max_prob > tx_type_prob_threshold)  // Force tx type with max prob.
        txk_allowed = force_tx_type;
      else if (x->rd_model == LOW_TXFM_RD) {
        if (plane == 0) txk_allowed = DCT_DCT;
      }
    }
  } else if (x->rd_model == LOW_TXFM_RD) {
    if (plane == 0) txk_allowed = DCT_DCT;
  }

  const TxSetType tx_set_type = av1_get_ext_tx_set_type(
      tx_size, is_inter, cm->features.reduced_tx_set_used);

  TX_TYPE uv_tx_type = DCT_DCT;
  if (plane) {
    // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y.
    uv_tx_type = txk_allowed =
        av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
                        cm->features.reduced_tx_set_used);
  }
  PREDICTION_MODE intra_dir =
      mbmi->filter_intra_mode_info.use_filter_intra
          ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
          : mbmi->mode;
  uint16_t ext_tx_used_flag =
      cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset != 0 &&
              tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
          ? av1_reduced_intra_tx_used_flag[intra_dir]
          : av1_ext_tx_used_flag[tx_set_type];

  if (cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset == 2)
    ext_tx_used_flag &= av1_derived_intra_tx_used_flag[intra_dir];

  if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
      ext_tx_used_flag == 0x0001 ||
      (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
      (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
    txk_allowed = DCT_DCT;
  }

  if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
    ext_tx_used_flag &= DCT_ADST_TX_MASK;

  uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
  if (txk_allowed < TX_TYPES) {
    allowed_tx_mask = 1 << txk_allowed;
    allowed_tx_mask &= ext_tx_used_flag;
  } else if (fast_tx_search) {
    allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
    allowed_tx_mask &= ext_tx_used_flag;
  } else {
    assert(plane == 0);
    allowed_tx_mask = ext_tx_used_flag;
    int num_allowed = 0;
    int i;

    if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
      static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
                                            { 10, 17, 17, 10, 17, 17, 17 } };
      const int thresh =
          thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
                    [update_type];
      uint16_t prune = 0;
      int max_prob = -1;
      int max_idx = 0;
      for (i = 0; i < TX_TYPES; i++) {
        if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
          max_prob = tx_type_probs[i];
          max_idx = i;
        }
        if (tx_type_probs[i] < thresh) prune |= (1 << i);
      }
      if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
      allowed_tx_mask &= (~prune);
    }
    for (i = 0; i < TX_TYPES; i++) {
      if (allowed_tx_mask & (1 << i)) num_allowed++;
    }
    assert(num_allowed > 0);

    if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
      int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
      int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
      if (num_allowed <= 7) {
        const uint16_t prune =
            prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
                           plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
                           cm->features.reduced_tx_set_used);
        allowed_tx_mask &= (~prune);
      } else {
        const int num_sel = (num_allowed * mf + 50) / 100;
        const uint16_t prune = prune_txk_type_separ(
            cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
            txk_map, allowed_tx_mask, pf, txb_ctx,
            cm->features.reduced_tx_set_used, ref_best_rd, num_sel);

        allowed_tx_mask &= (~prune);
      }
    } else {
      assert(num_allowed > 0);
      int allowed_tx_count =
          (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
      // !fast_tx_search && txk_end != txk_start && plane == 0
      if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
          num_allowed > allowed_tx_count) {
        prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
                    txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
      }
    }
  }

  // Need to have at least one transform type allowed.
  if (allowed_tx_mask == 0) {
    txk_allowed = (plane ? uv_tx_type : DCT_DCT);
    allowed_tx_mask = (1 << txk_allowed);
  }

  assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
  *allowed_txk_types = txk_allowed;
  return allowed_tx_mask;
}

#if CONFIG_RD_DEBUG
static INLINE void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
                                         int txb_coeff_cost) {
  rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;
}
#endif

static INLINE int cost_coeffs(MACROBLOCK *x, int plane, int block,
                              TX_SIZE tx_size, const TX_TYPE tx_type,
                              const TXB_CTX *const txb_ctx,
                              int reduced_tx_set_used) {
#if TXCOEFF_COST_TIMER
  struct aom_usec_timer timer;
  aom_usec_timer_start(&timer);
#endif
  const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
                                       txb_ctx, reduced_tx_set_used);
#if TXCOEFF_COST_TIMER
  AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
  aom_usec_timer_mark(&timer);
  const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
  tmp_cm->txcoeff_cost_timer += elapsed_time;
  ++tmp_cm->txcoeff_cost_count;
#endif
  return cost;
}

static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
                                          QUANT_PARAM *quant_param, int plane,
                                          int block, TX_SIZE tx_size,
                                          int quant_b_adapt, int qstep,
                                          unsigned int coeff_opt_satd_threshold,
                                          int skip_trellis, int dc_only_blk) {
  if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX))
    return skip_trellis;

  const struct macroblock_plane *const p = &x->plane[plane];
  const int block_offset = BLOCK_OFFSET(block);
  tran_low_t *const coeff_ptr = p->coeff + block_offset;
  const int n_coeffs = av1_get_max_eob(tx_size);
  const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
  int satd = (dc_only_blk) ? abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs);
  satd = RIGHT_SIGNED_SHIFT(satd, shift);
  satd >>= (x->e_mbd.bd - 8);

  const int skip_block_trellis =
      ((uint64_t)satd >
       (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]);

  av1_setup_quant(
      tx_size, !skip_block_trellis,
      skip_block_trellis
          ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP)
          : AV1_XFORM_QUANT_FP,
      quant_b_adapt, quant_param);

  return skip_block_trellis;
}
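
// In effect, trellis optimization is skipped when
//   satd > coeff_opt_satd_threshold * qstep * sqrt(num_pixels),
// i.e. for high-energy blocks where per-coefficient optimization is unlikely
// to pay off; sqrt_tx_pixels_2d[] supplies the rounded sqrt(num_pixels).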

// Predict DC only blocks if the residual variance is below a qstep based
// threshold. For such blocks, transform type search is bypassed.
static INLINE void predict_dc_only_block(
    MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
    int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
    int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
    int *dc_only_blk) {
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
  const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
  uint64_t block_var = UINT64_MAX;
  const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
  *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
                                txsize_to_bsize[tx_size], block_mse_q8,
                                per_px_mean, &block_var);
  assert((*block_mse_q8) != UINT_MAX);
  uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
  if (is_cur_buf_hbd(xd))
    block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);
  // Early prediction of skip block if residual mean and variance are less
  // than qstep based thresholds.
  if (((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) &&
      (block_var < var_threshold)) {
    // If the normalized mean of the residual block is less than the dc qstep
    // and the normalized block variance is less than the ac qstep, then the
    // block is assumed to be a skip block and its rdcost is updated
    // accordingly.
    best_rd_stats->skip_txfm = 1;

    x->plane[plane].eobs[block] = 0;

    if (is_cur_buf_hbd(xd))
      *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);

    best_rd_stats->dist = (*block_sse) << 4;
    best_rd_stats->sse = best_rd_stats->dist;

    ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
    ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
    av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
    ENTROPY_CONTEXT *ta = ctxa;
    ENTROPY_CONTEXT *tl = ctxl;
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    TXB_CTX txb_ctx_tmp;
    const PLANE_TYPE plane_type = get_plane_type(plane);
    get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
    const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
                                  .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
    best_rd_stats->rate = zero_blk_rate;

    best_rd_stats->rdcost =
        RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);

    x->plane[plane].txb_entropy_ctx[block] = 0;
  } else if (block_var < var_threshold) {
    // Predict DC only blocks based on residual variance.
    // For chroma planes, this early prediction is disabled for intra blocks.
    if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
  }
}
2256 
2257 // Search for the best transform type for a given transform block.
2258 // This function can be used for both inter and intra, both luma and chroma.
search_tx_type(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,FAST_TX_SEARCH_MODE ftxs_mode,int skip_trellis,int64_t ref_best_rd,RD_STATS * best_rd_stats)2259 static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
2260                            int block, int blk_row, int blk_col,
2261                            BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2262                            const TXB_CTX *const txb_ctx,
2263                            FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
2264                            int64_t ref_best_rd, RD_STATS *best_rd_stats) {
2265   const AV1_COMMON *cm = &cpi->common;
2266   MACROBLOCKD *xd = &x->e_mbd;
2267   MB_MODE_INFO *mbmi = xd->mi[0];
2268   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2269   int64_t best_rd = INT64_MAX;
2270   uint16_t best_eob = 0;
2271   TX_TYPE best_tx_type = DCT_DCT;
2272   int rate_cost = 0;
2273   // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
2274   // of the best tx_type
2275   DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
2276   struct macroblock_plane *const p = &x->plane[plane];
2277   tran_low_t *orig_dqcoeff = p->dqcoeff;
2278   tran_low_t *best_dqcoeff = this_dqcoeff;
2279   const int tx_type_map_idx =
2280       plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
2281   av1_invalid_rd_stats(best_rd_stats);
2282 
2283   skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
2284                                    DRY_RUN_NORMAL);
2285 
2286   // Hashing based speed feature for intra block. If the hash of the residue
2287   // is found in the hash table, use the previous RD search results stored in
2288   // the table and terminate early.
2289   TXB_RD_INFO *intra_txb_rd_info = NULL;
2290   uint16_t cur_joint_ctx = 0;
2291   const int is_inter = is_inter_block(mbmi);
2292   const int use_intra_txb_hash =
2293       cpi->sf.tx_sf.use_intra_txb_hash && frame_is_intra_only(cm) &&
2294       !is_inter && plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size];
2295   if (use_intra_txb_hash) {
2296     const int mi_row = xd->mi_row;
2297     const int mi_col = xd->mi_col;
2298     const int within_border =
2299         mi_row >= xd->tile.mi_row_start &&
2300         (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
2301         mi_col >= xd->tile.mi_col_start &&
2302         (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
2303     if (within_border &&
2304         is_intra_hash_match(cpi, x, plane, blk_row, blk_col, plane_bsize,
2305                             tx_size, txb_ctx, &intra_txb_rd_info,
2306                             tx_type_map_idx, &cur_joint_ctx)) {
2307       best_rd_stats->rate = intra_txb_rd_info->rate;
2308       best_rd_stats->dist = intra_txb_rd_info->dist;
2309       best_rd_stats->sse = intra_txb_rd_info->sse;
2310       best_rd_stats->skip_txfm = intra_txb_rd_info->eob == 0;
2311       x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
2312       x->plane[plane].txb_entropy_ctx[block] =
2313           intra_txb_rd_info->txb_entropy_ctx;
2314       best_eob = intra_txb_rd_info->eob;
2315       best_tx_type = intra_txb_rd_info->tx_type;
2316       skip_trellis |= !intra_txb_rd_info->perform_block_coeff_opt;
2317       update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2318       recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2319                   txb_ctx, skip_trellis, best_tx_type, 1, &rate_cost, best_eob);
2320       p->dqcoeff = orig_dqcoeff;
2321       return;
2322     }
2323   }
2324 
2325   uint8_t best_txb_ctx = 0;
2326   // txk_allowed = TX_TYPES: >1 tx types are allowed
2327   // txk_allowed < TX_TYPES: only that specific tx type is allowed.
2328   TX_TYPE txk_allowed = TX_TYPES;
2329   int txk_map[TX_TYPES] = {
2330     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
2331   };
2332   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2333   const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2334 
2335   const uint8_t txw = tx_size_wide[tx_size];
2336   const uint8_t txh = tx_size_high[tx_size];
2337   int64_t block_sse;
2338   unsigned int block_mse_q8;
2339   int dc_only_blk = 0;
2340   const bool predict_dc_block =
2341       txfm_params->predict_dc_level && txw != 64 && txh != 64;
2342   int64_t per_px_mean = INT64_MAX;
2343   if (predict_dc_block) {
2344     predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
2345                           blk_col, best_rd_stats, &block_sse, &block_mse_q8,
2346                           &per_px_mean, &dc_only_blk);
2347     if (best_rd_stats->skip_txfm == 1) {
2348       const TX_TYPE tx_type = DCT_DCT;
2349       if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2350       return;
2351     }
2352   } else {
2353     block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
2354                                 txsize_to_bsize[tx_size], &block_mse_q8);
2355     assert(block_mse_q8 != UINT_MAX);
2356   }
2357 
2358   // Bit mask to indicate which transform types are allowed in the RD search.
2359   uint16_t tx_mask;
2360 
2361   // Use DCT_DCT transform for DC only block.
2362   if (dc_only_blk)
2363     tx_mask = 1 << DCT_DCT;
2364   else
2365     tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
2366                           tx_size, txb_ctx, ftxs_mode, ref_best_rd,
2367                           &txk_allowed, txk_map);
2368   const uint16_t allowed_tx_mask = tx_mask;
2369 
2370   if (is_cur_buf_hbd(xd)) {
2371     block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
2372     block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
2373   }
2374   block_sse *= 16;
2375   // Use mse / qstep^2 based threshold logic to take decision of R-D
2376   // optimization of coeffs. For smaller residuals, coeff optimization
2377   // would be helpful. For larger residuals, R-D optimization may not be
2378   // effective.
2379   // TODO(any): Experiment with variance and mean based thresholds
2380   const int perform_block_coeff_opt =
2381       ((uint64_t)block_mse_q8 <=
2382        (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
2383   skip_trellis |= !perform_block_coeff_opt;
2384 
2385   // Flag to indicate if distortion should be calculated in transform domain or
2386   // not during iterating through transform type candidates.
2387   // Transform domain distortion is accurate for higher residuals.
2388   // TODO(any): Experiment with variance and mean based thresholds
2389   int use_transform_domain_distortion =
2390       (txfm_params->use_transform_domain_distortion > 0) &&
2391       (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
2392       // Any 64-pt transforms only preserves half the coefficients.
2393       // Therefore transform domain distortion is not valid for these
2394       // transform sizes.
2395       (txsize_sqr_up_map[tx_size] != TX_64X64) &&
2396       // Use pixel domain distortion for DC only blocks
2397       !dc_only_blk;
2398   // Flag to indicate if an extra calculation of distortion in the pixel domain
2399   // should be performed at the end, after the best transform type has been
2400   // decided.
2401   int calc_pixel_domain_distortion_final =
2402       txfm_params->use_transform_domain_distortion == 1 &&
2403       use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
2404   if (calc_pixel_domain_distortion_final &&
2405       (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
2406     calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
2407 
2408   const uint16_t *eobs_ptr = x->plane[plane].eobs;
2409 
2410   TxfmParam txfm_param;
2411   QUANT_PARAM quant_param;
2412   int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
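  // When trellis is enabled, quantize with the fast FP quantizer and refine
  // each candidate with av1_optimize_b(); when trellis is skipped, the
  // compile-time switch USE_B_QUANT_NO_TRELLIS selects the B quantizer.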
2413   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
2414   av1_setup_quant(tx_size, !skip_trellis,
2415                   skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
2416                                                          : AV1_XFORM_QUANT_FP)
2417                                : AV1_XFORM_QUANT_FP,
2418                   cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
2419 
2420   // Iterate through all transform type candidates.
2421   for (int idx = 0; idx < TX_TYPES; ++idx) {
2422     const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
2423     if (tx_type == TX_TYPE_INVALID || !check_bit_mask(allowed_tx_mask, tx_type))
2424       continue;
2425     txfm_param.tx_type = tx_type;
2426     if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
2427       av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
2428                         &quant_param);
2429     }
2430     if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2431     RD_STATS this_rd_stats;
2432     av1_invalid_rd_stats(&this_rd_stats);
2433 
2434     if (!dc_only_blk)
2435       av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
2436     else
2437       av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
2438 
2439     skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
2440         x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
2441         qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);
2442 
2443     av1_quant(x, plane, block, &txfm_param, &quant_param);
2444 
2445     // Calculate rate cost of quantized coefficients.
2446     if (quant_param.use_optimize_b) {
2447       av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
2448                      &rate_cost);
2449     } else {
2450       rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
2451                               cm->features.reduced_tx_set_used);
2452     }
2453 
2454     // If the rd cost based on the coeff rate alone is already more than
2455     // best_rd, terminate early.
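    // With zero distortion, RDCOST(rdmult, rate_cost, 0) is a lower bound on
    // this candidate's final RD cost, so exceeding best_rd disqualifies it.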
2456     if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
2457 
2458     // Calculate distortion.
2459     if (eobs_ptr[block] == 0) {
2460       // When eob is 0, pixel domain distortion is more efficient and accurate.
2461       this_rd_stats.dist = this_rd_stats.sse = block_sse;
2462     } else if (dc_only_blk) {
2463       this_rd_stats.sse = block_sse;
2464       this_rd_stats.dist = dist_block_px_domain(
2465           cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2466     } else if (use_transform_domain_distortion) {
2467       dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
2468                            &this_rd_stats.sse);
2469     } else {
2470       int64_t sse_diff = INT64_MAX;
2471       // high_energy threshold assumes that every pixel within a txfm block
2472       // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
2473       // for 8-bit input.
2474       const int64_t high_energy_thresh =
2475           ((int64_t)128 * 128 * tx_size_2d[tx_size]);
2476       const int is_high_energy = (block_sse >= high_energy_thresh);
2477       if (tx_size == TX_64X64 || is_high_energy) {
2478         // Because 3 out of 4 quadrants of the transform coefficients are
2479         // forced to zero, the inverse transform has a tendency to overflow.
2480         // sse_diff is effectively the energy of those 3 quadrants; here we
2481         // use it to decide whether to do pixel domain distortion. If the
2482         // energy is mostly in the first quadrant, then it is unlikely that
2483         // the inverse transform has an overflow issue.
2484         dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
2485                              &this_rd_stats.sse);
2486         sse_diff = block_sse - this_rd_stats.sse;
2487       }
2488       if (tx_size != TX_64X64 || !is_high_energy ||
2489           (sse_diff * 2) < this_rd_stats.sse) {
2490         const int64_t tx_domain_dist = this_rd_stats.dist;
2491         this_rd_stats.dist = dist_block_px_domain(
2492             cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2493         // For high energy blocks, the pixel domain distortion can
2494         // occasionally be artificially low due to clamping at the
2495         // reconstruction stage, even when the inverse transform output is
2496         // hugely different from the actual residue.
2497         if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
2498           this_rd_stats.dist = tx_domain_dist;
2499       } else {
2500         assert(sse_diff < INT64_MAX);
2501         this_rd_stats.dist += sse_diff;
2502       }
2503       this_rd_stats.sse = block_sse;
2504     }
2505 
2506     this_rd_stats.rate = rate_cost;
2507 
2508     const int64_t rd =
2509         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2510 
2511     if (rd < best_rd) {
2512       best_rd = rd;
2513       *best_rd_stats = this_rd_stats;
2514       best_tx_type = tx_type;
2515       best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
2516       best_eob = x->plane[plane].eobs[block];
2517       // Swap dqcoeff buffers
2518       tran_low_t *const tmp_dqcoeff = best_dqcoeff;
2519       best_dqcoeff = p->dqcoeff;
2520       p->dqcoeff = tmp_dqcoeff;
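      // The pointer swap retains the best candidate's dequantized
      // coefficients without a copy; p->dqcoeff is restored to orig_dqcoeff
      // before this function returns.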
2521     }
2522 
2523 #if CONFIG_COLLECT_RD_STATS == 1
2524     if (plane == 0) {
2525       PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
2526                               plane_bsize, tx_size, tx_type, rd);
2527     }
2528 #endif  // CONFIG_COLLECT_RD_STATS == 1
2529 
2530 #if COLLECT_TX_SIZE_DATA
2531     // Generate a small sample to restrict output size.
2532     static unsigned int seed = 21743;
2533     if (lcg_rand16(&seed) % 200 == 0) {
2534       FILE *fp = NULL;
2535 
2536       if (within_border) {
2537         fp = fopen(av1_tx_size_data_output_file, "a");
2538       }
2539 
2540       if (fp) {
2541         // Transform info and RD
2542         const int txb_w = tx_size_wide[tx_size];
2543         const int txb_h = tx_size_high[tx_size];
2544 
2545         // Residue signal.
2546         const int diff_stride = block_size_wide[plane_bsize];
2547         struct macroblock_plane *const p = &x->plane[plane];
2548         const int16_t *src_diff =
2549             &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
2550 
2551         for (int r = 0; r < txb_h; ++r) {
2552           for (int c = 0; c < txb_w; ++c) {
2553             fprintf(fp, "%d,", src_diff[c]);
2554           }
2555           src_diff += diff_stride;
2556         }
2557 
2558         fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
2559         fprintf(fp, "\n");
2560         fclose(fp);
2561       }
2562     }
2563 #endif  // COLLECT_TX_SIZE_DATA
2564 
2565     // If the current best RD cost is much worse than the reference RD cost,
2566     // terminate early.
2567     if (cpi->sf.tx_sf.adaptive_txb_search_level) {
2568       if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
2569           ref_best_rd) {
2570         break;
2571       }
2572     }
2573 
2574     // Terminate the transform type search if the block has been quantized
2575     // to all zeros.
2576     if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
2577   }
2578 
2579   assert(best_rd != INT64_MAX);
2580 
2581   best_rd_stats->skip_txfm = best_eob == 0;
2582   if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2583   x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
2584   x->plane[plane].eobs[block] = best_eob;
2585   skip_trellis = skip_trellis_based_on_satd[best_tx_type];
2586 
2587   // Point dqcoeff to the quantized coefficients of the best transform type,
2588   // so that the transform and quantization can be skipped, e.g. in the
2589   // final pixel domain distortion calculation and in recon_intra().
2590   p->dqcoeff = best_dqcoeff;
2591 
2592   if (calc_pixel_domain_distortion_final && best_eob) {
2593     best_rd_stats->dist = dist_block_px_domain(
2594         cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2595     best_rd_stats->sse = block_sse;
2596   }
2597 
2598   if (intra_txb_rd_info != NULL) {
2599     intra_txb_rd_info->valid = 1;
2600     intra_txb_rd_info->entropy_context = cur_joint_ctx;
2601     intra_txb_rd_info->rate = best_rd_stats->rate;
2602     intra_txb_rd_info->dist = best_rd_stats->dist;
2603     intra_txb_rd_info->sse = best_rd_stats->sse;
2604     intra_txb_rd_info->eob = best_eob;
2605     intra_txb_rd_info->txb_entropy_ctx = best_txb_ctx;
2606     intra_txb_rd_info->perform_block_coeff_opt = perform_block_coeff_opt;
2607     if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
2608   }
2609 
2610   // Intra mode needs decoded pixels so that the next transform block
2611   // can use them for prediction.
2612   recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2613               txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
2614   p->dqcoeff = orig_dqcoeff;
2615 }
2616 
2617 // Pick the transform type for a luma transform block of tx_size. Note that
2618 // this function is used only for inter-predicted blocks.
2619 static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
2620                                   TX_SIZE tx_size, int blk_row, int blk_col,
2621                                   int block, int plane_bsize, TXB_CTX *txb_ctx,
2622                                   RD_STATS *rd_stats,
2623                                   FAST_TX_SEARCH_MODE ftxs_mode,
2624                                   int64_t ref_rdcost,
2625                                   TXB_RD_INFO *rd_info_array) {
2626   const struct macroblock_plane *const p = &x->plane[0];
2627   const uint16_t cur_joint_ctx =
2628       (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
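  // The two entropy contexts are packed into a single 16-bit key, so a cached
  // result is reused only when both contexts match exactly.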
2629   MACROBLOCKD *xd = &x->e_mbd;
2630   assert(is_inter_block(xd->mi[0]));
2631   const int tx_type_map_idx = blk_row * xd->tx_type_map_stride + blk_col;
2632   // Look up the RD cost and terminate early if we have already processed
2633   // exactly the same residue with exactly the same entropy context.
2634   if (rd_info_array != NULL && rd_info_array->valid &&
2635       rd_info_array->entropy_context == cur_joint_ctx) {
2636     xd->tx_type_map[tx_type_map_idx] = rd_info_array->tx_type;
2637     const TX_TYPE ref_tx_type =
2638         av1_get_tx_type(&x->e_mbd, get_plane_type(0), blk_row, blk_col, tx_size,
2639                         cpi->common.features.reduced_tx_set_used);
2640     if (ref_tx_type == rd_info_array->tx_type) {
2641       rd_stats->rate += rd_info_array->rate;
2642       rd_stats->dist += rd_info_array->dist;
2643       rd_stats->sse += rd_info_array->sse;
2644       rd_stats->skip_txfm &= rd_info_array->eob == 0;
2645       p->eobs[block] = rd_info_array->eob;
2646       p->txb_entropy_ctx[block] = rd_info_array->txb_entropy_ctx;
2647       return;
2648     }
2649   }
2650 
2651   RD_STATS this_rd_stats;
2652   const int skip_trellis = 0;
2653   search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
2654                  txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats);
2655 
2656   av1_merge_rd_stats(rd_stats, &this_rd_stats);
2657 
2658   // Save RD results for possible future reuse.
2659   if (rd_info_array != NULL) {
2660     rd_info_array->valid = 1;
2661     rd_info_array->entropy_context = cur_joint_ctx;
2662     rd_info_array->rate = this_rd_stats.rate;
2663     rd_info_array->dist = this_rd_stats.dist;
2664     rd_info_array->sse = this_rd_stats.sse;
2665     rd_info_array->eob = p->eobs[block];
2666     rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
2667     rd_info_array->tx_type = xd->tx_type_map[tx_type_map_idx];
2668   }
2669 }
2670 
2671 static AOM_INLINE void try_tx_block_no_split(
2672     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2673     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
2674     const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
2675     int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
2676     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
2677     TxCandidateInfo *no_split) {
2678   MACROBLOCKD *const xd = &x->e_mbd;
2679   MB_MODE_INFO *const mbmi = xd->mi[0];
2680   struct macroblock_plane *const p = &x->plane[0];
2681   const int bw = mi_size_wide[plane_bsize];
2682   const ENTROPY_CONTEXT *const pta = ta + blk_col;
2683   const ENTROPY_CONTEXT *const ptl = tl + blk_row;
2684   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2685   TXB_CTX txb_ctx;
2686   get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
2687   const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
2688                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
2689   rd_stats->zero_rate = zero_blk_rate;
2690   const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
2691   mbmi->inter_tx_size[index] = tx_size;
2692   tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
2693              rd_stats, ftxs_mode, ref_best_rd,
2694              rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
2695   assert(rd_stats->rate < INT_MAX);
2696 
2697   const int pick_skip_txfm =
2698       !xd->lossless[mbmi->segment_id] &&
2699       (rd_stats->skip_txfm == 1 ||
2700        RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
2701            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
2702   if (pick_skip_txfm) {
2703 #if CONFIG_RD_DEBUG
2704     update_txb_coeff_cost(rd_stats, 0, zero_blk_rate - rd_stats->rate);
2705 #endif  // CONFIG_RD_DEBUG
2706     rd_stats->rate = zero_blk_rate;
2707     rd_stats->dist = rd_stats->sse;
2708     p->eobs[block] = 0;
2709     update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
2710   }
2711   rd_stats->skip_txfm = pick_skip_txfm;
2712   set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2713                pick_skip_txfm);
2714 
2715   if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
2716     rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];
2717 
2718   no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
2719   no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
2720   no_split->tx_type =
2721       xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
2722 }
2723 
2724 static AOM_INLINE void try_tx_block_split(
2725     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2726     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2727     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2728     int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
2729     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
2730     RD_STATS *split_rd_stats) {
2731   assert(tx_size < TX_SIZES_ALL);
2732   MACROBLOCKD *const xd = &x->e_mbd;
2733   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
2734   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
2735   const int txb_width = tx_size_wide_unit[tx_size];
2736   const int txb_height = tx_size_high_unit[tx_size];
2737   // Transform size after splitting the current block.
2738   const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
2739   const int sub_txb_width = tx_size_wide_unit[sub_txs];
2740   const int sub_txb_height = tx_size_high_unit[sub_txs];
2741   const int sub_step = sub_txb_width * sub_txb_height;
2742   const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
2743   assert(nblks > 0);
2744   av1_init_rd_stats(split_rd_stats);
2745   split_rd_stats->rate =
2746       x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];
2747 
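  // Each sub-block gets an equal share (no_split_rd / nblks) of the parent's
  // no-split RD as its previous-level reference for early termination.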
2748   for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
2749     const int offsetr = blk_row + r;
2750     if (offsetr >= max_blocks_high) break;
2751     for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
2752       assert(blk_idx < 4);
2753       const int offsetc = blk_col + c;
2754       if (offsetc >= max_blocks_wide) continue;
2755 
2756       RD_STATS this_rd_stats;
2757       int this_cost_valid = 1;
2758       select_tx_block(
2759           cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta,
2760           tl, tx_above, tx_left, &this_rd_stats, no_split_rd / nblks,
2761           ref_best_rd - split_rd_stats->rdcost, &this_cost_valid, ftxs_mode,
2762           (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
2763       if (!this_cost_valid) {
2764         split_rd_stats->rdcost = INT64_MAX;
2765         return;
2766       }
2767       av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
2768       split_rd_stats->rdcost =
2769           RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
2770       if (split_rd_stats->rdcost > ref_best_rd) {
2771         split_rd_stats->rdcost = INT64_MAX;
2772         return;
2773       }
2774       block += sub_step;
2775     }
2776   }
2777 }
2778 
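// Computes the sample variance as E[x^2] - E[x]^2. For example, for the
// samples {1, 3}: E[x^2] = 5 and the mean is 2, so the variance is 5 - 4 = 1.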
2779 static float get_var(float mean, double x2_sum, int num) {
2780   const float e_x2 = (float)(x2_sum / num);
2781   const float diff = e_x2 - mean * mean;
2782   return diff;
2783 }
2784 
2785 static AOM_INLINE void get_blk_var_dev(const int16_t *data, int stride, int bw,
2786                                        int bh, float *dev_of_mean,
2787                                        float *var_of_vars) {
2788   const int16_t *const data_ptr = &data[0];
2789   const int subh = (bh >= bw) ? (bh >> 1) : bh;
2790   const int subw = (bw >= bh) ? (bw >> 1) : bw;
2791   const int num = bw * bh;
2792   const int sub_num = subw * subh;
2793   int total_x_sum = 0;
2794   int64_t total_x2_sum = 0;
2795   int blk_idx = 0;
2796   float var_sum = 0.0f;
2797   float mean_sum = 0.0f;
2798   double var2_sum = 0.0f;
2799   double mean2_sum = 0.0f;
2800 
2801   for (int row = 0; row < bh; row += subh) {
2802     for (int col = 0; col < bw; col += subw) {
2803       int x_sum;
2804       int64_t x2_sum;
2805       aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
2806                           &x_sum, &x2_sum);
2807       total_x_sum += x_sum;
2808       total_x2_sum += x2_sum;
2809 
2810       const float mean = (float)x_sum / sub_num;
2811       const float var = get_var(mean, (double)x2_sum, sub_num);
2812       mean_sum += mean;
2813       mean2_sum += (double)(mean * mean);
2814       var_sum += var;
2815       var2_sum += var * var;
2816       blk_idx++;
2817     }
2818   }
2819 
2820   const float lvl0_mean = (float)total_x_sum / num;
2821   const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
2822   mean_sum += lvl0_mean;
2823   mean2_sum += (double)(lvl0_mean * lvl0_mean);
2824   var_sum += block_var;
2825   var2_sum += block_var * block_var;
2826   const float av_mean = mean_sum / 5;
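  // For a square block, mean_sum holds five entries at this point: the four
  // sub-block means plus the full-block mean.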
2827 
2828   if (blk_idx > 1) {
2829     // Deviation of means.
2830     *dev_of_mean = get_dev(av_mean, mean2_sum, (blk_idx + 1));
2831     // Variance of variances.
2832     const float mean_var = var_sum / (blk_idx + 1);
2833     *var_of_vars = get_var(mean_var, var2_sum, (blk_idx + 1));
2834   }
2835 }
2836 
2837 static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
2838                                     int blk_row, int blk_col, TX_SIZE tx_size,
2839                                     int *try_no_split, int *try_split,
2840                                     int pruning_level) {
2841   const int diff_stride = block_size_wide[bsize];
2842   const int16_t *diff =
2843       x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
2844   const int bw = tx_size_wide[tx_size];
2845   const int bh = tx_size_high[tx_size];
2846   float dev_of_means = 0.0f;
2847   float var_of_vars = 0.0f;
2848 
2849   // This function calculates the deviation of the means and the variance of
2850   // the pixel variances of the block as well as its sub-blocks.
2851   get_blk_var_dev(diff, diff_stride, bw, bh, &dev_of_means, &var_of_vars);
2852   const int dc_q = x->plane[0].dequant_QTX[0] >> 3;
2853   const int ac_q = x->plane[0].dequant_QTX[1] >> 3;
2854   const int no_split_thresh_scales[4] = { 0, 24, 8, 8 };
2855   const int no_split_thresh_scale = no_split_thresh_scales[pruning_level];
2856   const int split_thresh_scales[4] = { 0, 24, 10, 8 };
2857   const int split_thresh_scale = split_thresh_scales[pruning_level];
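  // A small deviation of the sub-block means together with a small variance
  // of the sub-block variances indicates a homogeneous residual, where
  // splitting is unlikely to help; the opposite pattern suggests that the
  // no-split option can be pruned.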
2858 
2859   if ((dev_of_means <= dc_q) &&
2860       (split_thresh_scale * var_of_vars <= ac_q * ac_q)) {
2861     *try_split = 0;
2862   }
2863   if ((dev_of_means > no_split_thresh_scale * dc_q) &&
2864       (var_of_vars > no_split_thresh_scale * ac_q * ac_q)) {
2865     *try_no_split = 0;
2866   }
2867 }
2868 
2869 // Recursively search for the best transform partition and type for a given
2870 // inter-predicted luma block. The obtained transform selection will be saved
2871 // in xd->mi[0], and the corresponding RD stats will be saved in rd_stats.
2872 static AOM_INLINE void select_tx_block(
2873     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2874     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2875     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2876     RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
2877     int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
2878     TXB_RD_INFO_NODE *rd_info_node) {
2879   assert(tx_size < TX_SIZES_ALL);
2880   av1_init_rd_stats(rd_stats);
2881   if (ref_best_rd < 0) {
2882     *is_cost_valid = 0;
2883     return;
2884   }
2885 
2886   MACROBLOCKD *const xd = &x->e_mbd;
2887   assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
2888          blk_col < max_block_wide(xd, plane_bsize, 0));
2889   MB_MODE_INFO *const mbmi = xd->mi[0];
2890   const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
2891                                          mbmi->bsize, tx_size);
2892   struct macroblock_plane *const p = &x->plane[0];
2893 
2894   int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
2895                       txsize_sqr_up_map[tx_size] != TX_64X64) &&
2896                      (cpi->oxcf.txfm_cfg.enable_rect_tx ||
2897                       tx_size_wide[tx_size] == tx_size_high[tx_size]);
2898   int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
2899   TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
2900 
2901   // Prune tx_split and no-split based on sub-block properties.
2902   if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
2903       cpi->sf.tx_sf.prune_tx_size_level > 0) {
2904     prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
2905                             &try_no_split, &try_split,
2906                             cpi->sf.tx_sf.prune_tx_size_level);
2907   }
2908 
2909   if (cpi->sf.rt_sf.skip_tx_no_split_var_based_partition) {
2910     if (x->try_merge_partition && try_split && p->eobs[block]) try_no_split = 0;
2911   }
2912 
2913   // Try using the current block as a single transform block without split.
2914   if (try_no_split) {
2915     try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2916                           plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
2917                           ftxs_mode, rd_info_node, &no_split);
2918 
2919     // Speed features for early termination.
2920     const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
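    // rd - (rd >> n) is roughly rd * (1 - 2^-n), i.e. the comparisons below
    // apply a tolerance margin that shrinks as the search level increases.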
2921     if (search_level) {
2922       if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
2923         *is_cost_valid = 0;
2924         return;
2925       }
2926       if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
2927         try_split = 0;
2928       }
2929     }
2930     if (cpi->sf.tx_sf.txb_split_cap) {
2931       if (p->eobs[block] == 0) try_split = 0;
2932     }
2933   }
2934 
2935   // ML-based speed feature to skip searching for split transform blocks.
2936   if (x->e_mbd.bd == 8 && try_split &&
2937       !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
2938     const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
2939     if (threshold >= 0) {
2940       const int split_score =
2941           ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
2942       if (split_score < -threshold) try_split = 0;
2943     }
2944   }
2945 
2946   RD_STATS split_rd_stats;
2947   split_rd_stats.rdcost = INT64_MAX;
2948   // Try splitting the current block into smaller transform blocks.
2949   if (try_split) {
2950     try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2951                        plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
2952                        AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
2953                        rd_info_node, &split_rd_stats);
2954   }
2955 
2956   if (no_split.rd < split_rd_stats.rdcost) {
2957     ENTROPY_CONTEXT *pta = ta + blk_col;
2958     ENTROPY_CONTEXT *ptl = tl + blk_row;
2959     p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
2960     av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
2961     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
2962                           tx_size);
2963     for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
2964       for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
2965         const int index =
2966             av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
2967         mbmi->inter_tx_size[index] = tx_size;
2968       }
2969     }
2970     mbmi->tx_size = tx_size;
2971     update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
2972     const int bw = mi_size_wide[plane_bsize];
2973     set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2974                  rd_stats->skip_txfm);
2975   } else {
2976     *rd_stats = split_rd_stats;
2977     if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
2978   }
2979 }
2980 
2981 static AOM_INLINE void choose_largest_tx_size(const AV1_COMP *const cpi,
2982                                               MACROBLOCK *x, RD_STATS *rd_stats,
2983                                               int64_t ref_best_rd,
2984                                               BLOCK_SIZE bs) {
2985   MACROBLOCKD *const xd = &x->e_mbd;
2986   MB_MODE_INFO *const mbmi = xd->mi[0];
2987   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2988   mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2989 
2990   // If tx64 is not enabled, we need to go down to the next available size.
2991   if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
2992     static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
2993       TX_4X4,    // 4x4 transform
2994       TX_8X8,    // 8x8 transform
2995       TX_16X16,  // 16x16 transform
2996       TX_32X32,  // 32x32 transform
2997       TX_32X32,  // 64x64 transform
2998       TX_4X8,    // 4x8 transform
2999       TX_8X4,    // 8x4 transform
3000       TX_8X16,   // 8x16 transform
3001       TX_16X8,   // 16x8 transform
3002       TX_16X32,  // 16x32 transform
3003       TX_32X16,  // 32x16 transform
3004       TX_32X32,  // 32x64 transform
3005       TX_32X32,  // 64x32 transform
3006       TX_4X16,   // 4x16 transform
3007       TX_16X4,   // 16x4 transform
3008       TX_8X32,   // 8x32 transform
3009       TX_32X8,   // 32x8 transform
3010       TX_16X32,  // 16x64 transform
3011       TX_32X16,  // 64x16 transform
3012     };
3013     mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
3014   } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
3015              !cpi->oxcf.txfm_cfg.enable_rect_tx) {
3016     static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
3017       TX_4X4,    // 4x4 transform
3018       TX_8X8,    // 8x8 transform
3019       TX_16X16,  // 16x16 transform
3020       TX_32X32,  // 32x32 transform
3021       TX_64X64,  // 64x64 transform
3022       TX_4X4,    // 4x8 transform
3023       TX_4X4,    // 8x4 transform
3024       TX_8X8,    // 8x16 transform
3025       TX_8X8,    // 16x8 transform
3026       TX_16X16,  // 16x32 transform
3027       TX_16X16,  // 32x16 transform
3028       TX_32X32,  // 32x64 transform
3029       TX_32X32,  // 64x32 transform
3030       TX_4X4,    // 4x16 transform
3031       TX_4X4,    // 16x4 transform
3032       TX_8X8,    // 8x32 transform
3033       TX_8X8,    // 32x8 transform
3034       TX_16X16,  // 16x64 transform
3035       TX_16X16,  // 64x16 transform
3036     };
3037     mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
3038   } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
3039              !cpi->oxcf.txfm_cfg.enable_rect_tx) {
3040     static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
3041       TX_4X4,    // 4x4 transform
3042       TX_8X8,    // 8x8 transform
3043       TX_16X16,  // 16x16 transform
3044       TX_32X32,  // 32x32 transform
3045       TX_32X32,  // 64x64 transform
3046       TX_4X4,    // 4x8 transform
3047       TX_4X4,    // 8x4 transform
3048       TX_8X8,    // 8x16 transform
3049       TX_8X8,    // 16x8 transform
3050       TX_16X16,  // 16x32 transform
3051       TX_16X16,  // 32x16 transform
3052       TX_32X32,  // 32x64 transform
3053       TX_32X32,  // 64x32 transform
3054       TX_4X4,    // 4x16 transform
3055       TX_4X4,    // 16x4 transform
3056       TX_8X8,    // 8x32 transform
3057       TX_8X8,    // 32x8 transform
3058       TX_16X16,  // 16x64 transform
3059       TX_16X16,  // 64x16 transform
3060     };
3061 
3062     mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
3063   }
3064 
3065   const int skip_ctx = av1_get_skip_txfm_context(xd);
3066   const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3067   const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3068   // The skip RD cost is used only for inter blocks.
3069   const int64_t skip_txfm_rd =
3070       is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
3071   const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0);
3072   const int skip_trellis = 0;
3073   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
3074                        AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
3075                        mbmi->tx_size, FTXS_NONE, skip_trellis);
3076 }
3077 
3078 static AOM_INLINE void choose_smallest_tx_size(const AV1_COMP *const cpi,
3079                                                MACROBLOCK *x,
3080                                                RD_STATS *rd_stats,
3081                                                int64_t ref_best_rd,
3082                                                BLOCK_SIZE bs) {
3083   MACROBLOCKD *const xd = &x->e_mbd;
3084   MB_MODE_INFO *const mbmi = xd->mi[0];
3085 
3086   mbmi->tx_size = TX_4X4;
3087   // TODO(any): Pass this_rd based on skip/non-skip cost.
3088   const int skip_trellis = 0;
3089   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
3090                        FTXS_NONE, skip_trellis);
3091 }
3092 
3093 // Search for the best uniform transform size and type for the current coding block.
3094 static AOM_INLINE void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
3095                                                    MACROBLOCK *x,
3096                                                    RD_STATS *rd_stats,
3097                                                    int64_t ref_best_rd,
3098                                                    BLOCK_SIZE bs) {
3099   av1_invalid_rd_stats(rd_stats);
3100 
3101   MACROBLOCKD *const xd = &x->e_mbd;
3102   MB_MODE_INFO *const mbmi = xd->mi[0];
3103   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3104   const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
3105   const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
3106   int start_tx;
3107   // The split depth can be at most MAX_TX_DEPTH, so init_depth controls
3108   // how many levels of splitting are allowed during the RD search.
3109   int init_depth;
3110 
3111   if (tx_select) {
3112     start_tx = max_rect_tx_size;
3113     init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
3114                                        is_inter_block(mbmi), &cpi->sf,
3115                                        txfm_params->tx_size_search_method);
3116     if (init_depth == MAX_TX_DEPTH && !cpi->oxcf.txfm_cfg.enable_tx64 &&
3117         txsize_sqr_up_map[start_tx] == TX_64X64) {
3118       start_tx = sub_tx_size_map[start_tx];
3119     }
3120   } else {
3121     const TX_SIZE chosen_tx_size =
3122         tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
3123     start_tx = chosen_tx_size;
3124     init_depth = MAX_TX_DEPTH;
3125   }
3126 
3127   const int skip_trellis = 0;
3128   uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
3129   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
3130   TX_SIZE best_tx_size = max_rect_tx_size;
3131   int64_t best_rd = INT64_MAX;
3132   const int num_blks = bsize_to_num_blk(bs);
3133   x->rd_model = FULL_TXFM_RD;
3134   int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
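  // One RD entry per searchable depth (MAX_TX_DEPTH + 1 entries in total).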
3135   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3136   for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
3137        depth++, tx_size = sub_tx_size_map[tx_size]) {
3138     if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
3139          txsize_sqr_up_map[tx_size] == TX_64X64) ||
3140         (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
3141          tx_size_wide[tx_size] != tx_size_high[tx_size])) {
3142       continue;
3143     }
3144 
3145     RD_STATS this_rd_stats;
3146     rd[depth] = av1_uniform_txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs,
3147                                      tx_size, FTXS_NONE, skip_trellis);
3148     if (rd[depth] < best_rd) {
3149       av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks);
3150       av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
3151       best_tx_size = tx_size;
3152       best_rd = rd[depth];
3153       *rd_stats = this_rd_stats;
3154     }
3155     if (tx_size == TX_4X4) break;
3156     // If we are searching three depths, prune the smallest size for
3157     // low-contrast blocks depending on the rd results of the first two depths.
3158     if (depth > init_depth && depth != MAX_TX_DEPTH &&
3159         x->source_variance < 256) {
3160       if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
3161     }
3162   }
3163 
3164   if (rd_stats->rate != INT_MAX) {
3165     mbmi->tx_size = best_tx_size;
3166     av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
3167     av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks);
3168   }
3169 }
3170 
3171 // Search for the best transform type for the given transform block in the
3172 // given plane/channel, and calculate the corresponding RD cost.
3173 static AOM_INLINE void block_rd_txfm(int plane, int block, int blk_row,
3174                                      int blk_col, BLOCK_SIZE plane_bsize,
3175                                      TX_SIZE tx_size, void *arg) {
3176   struct rdcost_block_args *args = arg;
3177   if (args->exit_early) {
3178     args->incomplete_exit = 1;
3179     return;
3180   }
3181 
3182   MACROBLOCK *const x = args->x;
3183   MACROBLOCKD *const xd = &x->e_mbd;
3184   const int is_inter = is_inter_block(xd->mi[0]);
3185   const AV1_COMP *cpi = args->cpi;
3186   ENTROPY_CONTEXT *a = args->t_above + blk_col;
3187   ENTROPY_CONTEXT *l = args->t_left + blk_row;
3188   const AV1_COMMON *cm = &cpi->common;
3189   RD_STATS this_rd_stats;
3190   av1_init_rd_stats(&this_rd_stats);
3191 
3192   if (!is_inter) {
3193     av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
3194     av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
3195   }
3196 
3197   TXB_CTX txb_ctx;
3198   get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
3199   search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3200                  &txb_ctx, args->ftxs_mode, args->skip_trellis,
3201                  args->best_rd - args->current_rd, &this_rd_stats);
3202 
3203   if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
3204     assert(!is_inter || plane_bsize < BLOCK_8X8);
3205     cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
3206   }
3207 
3208 #if CONFIG_RD_DEBUG
3209   update_txb_coeff_cost(&this_rd_stats, plane, this_rd_stats.rate);
3210 #endif  // CONFIG_RD_DEBUG
3211   av1_set_txb_context(x, plane, block, tx_size, a, l);
3212 
3213   const int blk_idx =
3214       blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
3215 
3216   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3217   if (plane == 0)
3218     set_blk_skip(txfm_info->blk_skip, plane, blk_idx,
3219                  x->plane[plane].eobs[block] == 0);
3220   else
3221     set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);
3222 
3223   int64_t rd;
3224   if (is_inter) {
3225     const int64_t no_skip_txfm_rd =
3226         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3227     const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3228     rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
3229     this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
3230   } else {
3231     // Signal non-skip_txfm for intra blocks.
3232     rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3233     this_rd_stats.skip_txfm = 0;
3234   }
3235 
3236   av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
3237 
3238   args->current_rd += rd;
3239   if (args->current_rd > args->best_rd) args->exit_early = 1;
3240 }
3241 
3242 int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3243                               RD_STATS *rd_stats, int64_t ref_best_rd,
3244                               BLOCK_SIZE bs, TX_SIZE tx_size) {
3245   MACROBLOCKD *const xd = &x->e_mbd;
3246   MB_MODE_INFO *const mbmi = xd->mi[0];
3247   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3248   const ModeCosts *mode_costs = &x->mode_costs;
3249   const int is_inter = is_inter_block(mbmi);
3250   const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
3251                         block_signals_txsize(mbmi->bsize);
3252   int tx_size_rate = 0;
3253   if (tx_select) {
3254     const int ctx = txfm_partition_context(
3255         xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
3256     tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
3257   }
3258   const int skip_ctx = av1_get_skip_txfm_context(xd);
3259   const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
3260   const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
3261   const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
3262   const int64_t no_this_rd =
3263       RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
3264   mbmi->tx_size = tx_size;
3265 
3266   const uint8_t txw_unit = tx_size_wide_unit[tx_size];
3267   const uint8_t txh_unit = tx_size_high_unit[tx_size];
3268   const int step = txw_unit * txh_unit;
3269   const int max_blocks_wide = max_block_wide(xd, bs, 0);
3270   const int max_blocks_high = max_block_high(xd, bs, 0);
3271 
3272   struct rdcost_block_args args;
3273   av1_zero(args);
3274   args.x = x;
3275   args.cpi = cpi;
3276   args.best_rd = ref_best_rd;
3277   args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
3278   av1_init_rd_stats(&args.rd_stats);
3279   av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
3280   int i = 0;
3281   for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
3282        blk_row += txh_unit) {
3283     for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
3284       RD_STATS this_rd_stats;
3285       av1_init_rd_stats(&this_rd_stats);
3286 
3287       if (args.exit_early) {
3288         args.incomplete_exit = 1;
3289         break;
3290       }
3291 
3292       ENTROPY_CONTEXT *a = args.t_above + blk_col;
3293       ENTROPY_CONTEXT *l = args.t_left + blk_row;
3294       TXB_CTX txb_ctx;
3295       get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);
3296 
3297       TxfmParam txfm_param;
3298       QUANT_PARAM quant_param;
3299       av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
3300       av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);
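      // This estimation path uses DCT_DCT only, the B quantizer without
      // trellis, and transform-domain distortion, trading accuracy for speed.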
3301 
3302       av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
3303       av1_quant(x, 0, i, &txfm_param, &quant_param);
3304 
3305       this_rd_stats.rate =
3306           cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);
3307 
3308       dist_block_tx_domain(x, 0, i, tx_size, &this_rd_stats.dist,
3309                            &this_rd_stats.sse);
3310 
3311       const int64_t no_skip_txfm_rd =
3312           RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3313       const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3314 
3315       this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];
3316 
3317       av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
3318       args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);
3319 
3320       if (args.current_rd > ref_best_rd) {
3321         args.exit_early = 1;
3322         break;
3323       }
3324 
3325       av1_set_txb_context(x, 0, i, tx_size, a, l);
3326       i += step;
3327     }
3328   }
3329 
3330   if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);
3331 
3332   *rd_stats = args.rd_stats;
3333   if (rd_stats->rate == INT_MAX) return INT64_MAX;
3334 
3335   int64_t rd;
3336   // rd_stats->rate should include all the rate except the skip/non-skip cost,
3337   // as that is accounted for in the caller functions after rd evaluation of
3338   // all planes. However, the decisions should be made after considering the
3339   // skip/non-skip header cost.
3340   if (rd_stats->skip_txfm && is_inter) {
3341     rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3342   } else {
3343     // Intra blocks are always signalled as non-skip
3344     rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
3345                 rd_stats->dist);
3346     rd_stats->rate += tx_size_rate;
3347   }
3348   // Check if forcing the block to skip transform leads to a smaller RD cost.
3349   if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
3350     int64_t temp_skip_txfm_rd =
3351         RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3352     if (temp_skip_txfm_rd <= rd) {
3353       rd = temp_skip_txfm_rd;
3354       rd_stats->rate = 0;
3355       rd_stats->dist = rd_stats->sse;
3356       rd_stats->skip_txfm = 1;
3357     }
3358   }
3359 
3360   return rd;
3361 }
3362 
3363 int64_t av1_uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3364                              RD_STATS *rd_stats, int64_t ref_best_rd,
3365                              BLOCK_SIZE bs, TX_SIZE tx_size,
3366                              FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis) {
3367   assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
3368   MACROBLOCKD *const xd = &x->e_mbd;
3369   MB_MODE_INFO *const mbmi = xd->mi[0];
3370   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3371   const ModeCosts *mode_costs = &x->mode_costs;
3372   const int is_inter = is_inter_block(mbmi);
3373   const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
3374                         block_signals_txsize(mbmi->bsize);
3375   int tx_size_rate = 0;
3376   if (tx_select) {
3377     const int ctx = txfm_partition_context(
3378         xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
3379     tx_size_rate = is_inter ? mode_costs->txfm_partition_cost[ctx][0]
3380                             : tx_size_cost(x, bs, tx_size);
3381   }
3382   const int skip_ctx = av1_get_skip_txfm_context(xd);
3383   const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
3384   const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
3385   const int64_t skip_txfm_rd =
3386       is_inter ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
3387   const int64_t no_this_rd =
3388       RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
3389 
3390   mbmi->tx_size = tx_size;
3391   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
3392                        AOMMIN(no_this_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
3393                        tx_size, ftxs_mode, skip_trellis);
3394   if (rd_stats->rate == INT_MAX) return INT64_MAX;
3395 
3396   int64_t rd;
3397   // rd_stats->rate should include all the rate except the skip/non-skip cost,
3398   // as that is accounted for in the caller functions after rd evaluation of
3399   // all planes. However, the decisions should be made after considering the
3400   // skip/non-skip header cost.
3401   if (rd_stats->skip_txfm && is_inter) {
3402     rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3403   } else {
3404     // Intra blocks are always signalled as non-skip
3405     rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
3406                 rd_stats->dist);
3407     rd_stats->rate += tx_size_rate;
3408   }
3409   // Check if forcing the block to skip transform leads to a smaller RD cost.
3410   if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
3411     int64_t temp_skip_txfm_rd =
3412         RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3413     if (temp_skip_txfm_rd <= rd) {
3414       rd = temp_skip_txfm_rd;
3415       rd_stats->rate = 0;
3416       rd_stats->dist = rd_stats->sse;
3417       rd_stats->skip_txfm = 1;
3418     }
3419   }
3420 
3421   return rd;
3422 }
3423 
3424 // Search for the best transform type for a luma inter-predicted block, given
3425 // the transform block partitions.
3426 // This function is used only when some speed features are enabled.
3427 static AOM_INLINE void tx_block_yrd(
3428     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
3429     TX_SIZE tx_size, BLOCK_SIZE plane_bsize, int depth,
3430     ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
3431     TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, int64_t ref_best_rd,
3432     RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode) {
3433   assert(tx_size < TX_SIZES_ALL);
3434   MACROBLOCKD *const xd = &x->e_mbd;
3435   MB_MODE_INFO *const mbmi = xd->mi[0];
3436   assert(is_inter_block(mbmi));
3437   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
3438   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
3439 
3440   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
3441 
3442   const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
3443       plane_bsize, blk_row, blk_col)];
3444   const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
3445                                          mbmi->bsize, tx_size);
3446 
3447   av1_init_rd_stats(rd_stats);
3448   if (tx_size == plane_tx_size) {
3449     ENTROPY_CONTEXT *ta = above_ctx + blk_col;
3450     ENTROPY_CONTEXT *tl = left_ctx + blk_row;
3451     const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
3452     TXB_CTX txb_ctx;
3453     get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);
3454 
3455     const int zero_blk_rate =
3456         x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
3457             .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
3458     rd_stats->zero_rate = zero_blk_rate;
3459     tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
3460                rd_stats, ftxs_mode, ref_best_rd, NULL);
3461     const int mi_width = mi_size_wide[plane_bsize];
3462     TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3463     if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
3464             RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
3465         rd_stats->skip_txfm == 1) {
3466       rd_stats->rate = zero_blk_rate;
3467       rd_stats->dist = rd_stats->sse;
3468       rd_stats->skip_txfm = 1;
3469       set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1);
3470       x->plane[0].eobs[block] = 0;
3471       x->plane[0].txb_entropy_ctx[block] = 0;
3472       update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
3473     } else {
3474       rd_stats->skip_txfm = 0;
3475       set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0);
3476     }
3477     if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3478       rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
3479     av1_set_txb_context(x, 0, block, tx_size, ta, tl);
3480     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
3481                           tx_size);
3482   } else {
3483     const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
3484     const int txb_width = tx_size_wide_unit[sub_txs];
3485     const int txb_height = tx_size_high_unit[sub_txs];
3486     const int step = txb_height * txb_width;
3487     const int row_end =
3488         AOMMIN(tx_size_high_unit[tx_size], max_blocks_high - blk_row);
3489     const int col_end =
3490         AOMMIN(tx_size_wide_unit[tx_size], max_blocks_wide - blk_col);
3491     RD_STATS pn_rd_stats;
3492     int64_t this_rd = 0;
3493     assert(txb_width > 0 && txb_height > 0);
3494 
3495     for (int row = 0; row < row_end; row += txb_height) {
3496       const int offsetr = blk_row + row;
3497       for (int col = 0; col < col_end; col += txb_width) {
3498         const int offsetc = blk_col + col;
3499 
3500         av1_init_rd_stats(&pn_rd_stats);
3501         tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
3502                      depth + 1, above_ctx, left_ctx, tx_above, tx_left,
3503                      ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
3504         if (pn_rd_stats.rate == INT_MAX) {
3505           av1_invalid_rd_stats(rd_stats);
3506           return;
3507         }
3508         av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3509         this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
3510         block += step;
3511       }
3512     }
3513 
3514     if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
3515       rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
3516   }
3517 }
3518 
3519 // Search for the tx type with the tx sizes already decided for an
3520 // inter-predicted luma partition block. Used only when some speed features are enabled.
3521 // Return value 0: early termination triggered, no valid rd cost available;
3522 //              1: rd cost values are valid.
3523 static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3524                            RD_STATS *rd_stats, BLOCK_SIZE bsize,
3525                            int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
3526   if (ref_best_rd < 0) {
3527     av1_invalid_rd_stats(rd_stats);
3528     return 0;
3529   }
3530 
3531   av1_init_rd_stats(rd_stats);
3532 
3533   MACROBLOCKD *const xd = &x->e_mbd;
3534   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3535   const struct macroblockd_plane *const pd = &xd->plane[0];
3536   const int mi_width = mi_size_wide[bsize];
3537   const int mi_height = mi_size_high[bsize];
3538   const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
3539   const int bh = tx_size_high_unit[max_tx_size];
3540   const int bw = tx_size_wide_unit[max_tx_size];
3541   const int step = bw * bh;
3542   const int init_depth = get_search_init_depth(
3543       mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3544   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3545   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3546   TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3547   TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3548   av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3549   memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3550   memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3551 
3552   int64_t this_rd = 0;
3553   for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
3554     for (int idx = 0; idx < mi_width; idx += bw) {
3555       RD_STATS pn_rd_stats;
3556       av1_init_rd_stats(&pn_rd_stats);
3557       tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
3558                    ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
3559                    &pn_rd_stats, ftxs_mode);
3560       if (pn_rd_stats.rate == INT_MAX) {
3561         av1_invalid_rd_stats(rd_stats);
3562         return 0;
3563       }
3564       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3565       this_rd +=
3566           AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
3567                  RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
3568       block += step;
3569     }
3570   }
3571 
3572   const int skip_ctx = av1_get_skip_txfm_context(xd);
3573   const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3574   const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3575   const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3576   this_rd =
3577       RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
3578   if (skip_txfm_rd < this_rd) {
3579     this_rd = skip_txfm_rd;
3580     rd_stats->rate = 0;
3581     rd_stats->dist = rd_stats->sse;
3582     rd_stats->skip_txfm = 1;
3583   }
3584 
3585   const int is_cost_valid = this_rd <= ref_best_rd;
3586   if (!is_cost_valid) {
3587     // reset cost value
3588     av1_invalid_rd_stats(rd_stats);
3589   }
3590   return is_cost_valid;
3591 }
3592 
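// A minimal sketch (not part of the encoder) of the block-level skip decision
// made at the end of inter_block_yrd() above: the cost of coding all
// coefficients plus a skip_txfm = 0 flag is weighed against coding a
// skip_txfm = 1 flag alone and accepting the full prediction SSE as the
// distortion. Parameter names here are illustrative.
static AOM_INLINE int64_t example_skip_txfm_decision(
    int rdmult, int coeff_rate, int64_t dist, int64_t sse,
    int no_skip_flag_rate, int skip_flag_rate, int *chose_skip) {
  const int64_t no_skip_rd =
      RDCOST(rdmult, coeff_rate + no_skip_flag_rate, dist);
  // With skip, no residue is coded, so the distortion is the full SSE.
  const int64_t skip_rd = RDCOST(rdmult, skip_flag_rate, sse);
  *chose_skip = (skip_rd < no_skip_rd);
  return AOMMIN(skip_rd, no_skip_rd);
}
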
3593 // Search for the best transform size and type for the current inter-predicted
3594 // luma block with recursive transform block partitioning. The chosen
3595 // transform selection is saved in xd->mi[0], and the corresponding RD stats
3596 // in rd_stats. The return value is the corresponding RD cost.
3597 static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
3598                                        RD_STATS *rd_stats, BLOCK_SIZE bsize,
3599                                        int64_t ref_best_rd,
3600                                        TXB_RD_INFO_NODE *rd_info_tree) {
3601   MACROBLOCKD *const xd = &x->e_mbd;
3602   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3603   assert(is_inter_block(xd->mi[0]));
3604   assert(bsize < BLOCK_SIZES_ALL);
3605   const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
3606   int64_t rd_thresh = ref_best_rd;
3607   if (rd_thresh == 0) {
3608     av1_invalid_rd_stats(rd_stats);
3609     return INT64_MAX;
3610   }
3611   if (fast_tx_search && rd_thresh < INT64_MAX) {
3612     if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
3613   }
3614   assert(rd_thresh > 0);
3615   const FAST_TX_SEARCH_MODE ftxs_mode =
3616       fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
3617   const struct macroblockd_plane *const pd = &xd->plane[0];
3618   assert(bsize < BLOCK_SIZES_ALL);
3619   const int mi_width = mi_size_wide[bsize];
3620   const int mi_height = mi_size_high[bsize];
3621   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3622   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3623   TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3624   TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3625   av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3626   memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3627   memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3628   const int init_depth = get_search_init_depth(
3629       mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3630   const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
3631   const int bh = tx_size_high_unit[max_tx_size];
3632   const int bw = tx_size_wide_unit[max_tx_size];
3633   const int step = bw * bh;
3634   const int skip_ctx = av1_get_skip_txfm_context(xd);
3635   const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3636   const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3637   int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
3638   int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
3639   int block = 0;
3640 
3641   av1_init_rd_stats(rd_stats);
3642   for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
3643     for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
3644       const int64_t best_rd_sofar =
3645           (rd_thresh == INT64_MAX)
3646               ? INT64_MAX
3647               : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
3648       int is_cost_valid = 1;
3649       RD_STATS pn_rd_stats;
3650       // Search for the best transform block size and type for the sub-block.
3651       select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
3652                       ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
3653                       best_rd_sofar, &is_cost_valid, ftxs_mode, rd_info_tree);
3654       if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
3655         av1_invalid_rd_stats(rd_stats);
3656         return INT64_MAX;
3657       }
3658       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3659       skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3660       no_skip_txfm_rd =
3661           RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3662       block += step;
3663       if (rd_info_tree != NULL) rd_info_tree += 1;
3664     }
3665   }
3666 
3667   if (rd_stats->rate == INT_MAX) return INT64_MAX;
3668 
3669   rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);
3670 
3671   // If fast_tx_search is true, only DCT and 1D DCT were tested in the
3672   // select_tx_block() calls above. Do a better search for the tx type with
3673   // the tx sizes already decided.
3674   if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
3675     if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
3676       return INT64_MAX;
3677   }
3678 
3679   int64_t final_rd;
3680   if (rd_stats->skip_txfm) {
3681     final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
3682   } else {
3683     final_rd =
3684         RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
3685     if (!xd->lossless[xd->mi[0]->segment_id]) {
3686       final_rd =
3687           AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
3688     }
3689   }
3690 
3691   return final_rd;
3692 }
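
// A standalone sketch of the threshold loosening near the top of
// select_tx_size_and_type(): for fast tx search, the reference budget is
// widened by 1/8 (12.5%) so that a slightly worse fast result is not rejected
// outright, with a guard against int64_t overflow. For example, a threshold
// of 800 becomes 900.
static AOM_INLINE int64_t example_loosen_rd_thresh(int64_t rd_thresh) {
  if (rd_thresh < INT64_MAX && INT64_MAX - rd_thresh > (rd_thresh >> 3))
    rd_thresh += (rd_thresh >> 3);
  return rd_thresh;
}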
3693 
3694 // Return 1 to terminate the transform search early. The decision is based on
3695 // a comparison between the reference RD cost and the model-estimated RD cost.
3696 static AOM_INLINE int model_based_tx_search_prune(const AV1_COMP *cpi,
3697                                                   MACROBLOCK *x,
3698                                                   BLOCK_SIZE bsize,
3699                                                   int64_t ref_best_rd) {
3700   const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
3701   assert(level >= 0 && level <= 2);
3702   int model_rate;
3703   int64_t model_dist;
3704   int model_skip;
3705   MACROBLOCKD *const xd = &x->e_mbd;
3706   model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
3707       cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
3708       NULL, NULL, NULL);
3709   if (model_skip) return 0;
3710   const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
3711   // TODO(debargha, urvang): Improve the model and make the check below
3712   // tighter.
3713   static const int prune_factor_by8[] = { 3, 5 };
3714   const int factor = prune_factor_by8[level - 1];
3715   return ((model_rd * factor) >> 3) > ref_best_rd;
3716 }
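
// The check above, restated: pruning happens when model_rd * factor / 8
// exceeds ref_best_rd, i.e. when the modeled cost is more than 8/3 (about
// 2.67x, level 1) or 8/5 (1.6x, level 2) of the best cost seen so far. A
// self-contained sketch with illustrative numbers:
static AOM_INLINE int example_model_prune(int64_t model_rd,
                                          int64_t ref_best_rd, int level) {
  static const int prune_factor_by8[] = { 3, 5 };
  const int factor = prune_factor_by8[level - 1];
  // E.g. model_rd = 2000, ref_best_rd = 1000, level 2:
  // (2000 * 5) >> 3 = 1250 > 1000, so the search is pruned.
  return ((model_rd * factor) >> 3) > ref_best_rd;
}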
3717 
3718 void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3719                                          RD_STATS *rd_stats, BLOCK_SIZE bsize,
3720                                          int64_t ref_best_rd) {
3721   MACROBLOCKD *const xd = &x->e_mbd;
3722   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3723   assert(is_inter_block(xd->mi[0]));
3724 
3725   av1_invalid_rd_stats(rd_stats);
3726 
3727   // If the modeled RD cost is much worse than the best so far, terminate early.
3728   if (cpi->sf.tx_sf.model_based_prune_tx_search_level &&
3729       ref_best_rd != INT64_MAX) {
3730     if (model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) return;
3731   }
3732 
3733   // Hashing-based speed feature. If the hash of the prediction residue block
3734   // is found in the hash table, reuse previous search results and exit early.
3735   uint32_t hash = 0;
3736   MB_RD_RECORD *mb_rd_record = NULL;
3737   const int mi_row = x->e_mbd.mi_row;
3738   const int mi_col = x->e_mbd.mi_col;
3739   const int within_border =
3740       mi_row >= xd->tile.mi_row_start &&
3741       (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
3742       mi_col >= xd->tile.mi_col_start &&
3743       (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);
3744   const int is_mb_rd_hash_enabled =
3745       (within_border && cpi->sf.rd_sf.use_mb_rd_hash);
3746   const int n4 = bsize_to_num_blk(bsize);
3747   if (is_mb_rd_hash_enabled) {
3748     hash = get_block_residue_hash(x, bsize);
3749     mb_rd_record = &x->txfm_search_info.txb_rd_records->mb_rd_record;
3750     const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3751     if (match_index != -1) {
3752       MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
3753       fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
3754       return;
3755     }
3756   }
3757 
3758   // If we predict that skip is the optimal RD decision, set the respective
3759   // context and terminate early.
3760   int64_t dist;
3761   if (txfm_params->skip_txfm_level &&
3762       predict_skip_txfm(x, bsize, &dist,
3763                         cpi->common.features.reduced_tx_set_used)) {
3764     set_skip_txfm(x, rd_stats, bsize, dist);
3765     // Save the RD search results into tx_rd_record.
3766     if (is_mb_rd_hash_enabled)
3767       save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3768     return;
3769   }
3770 #if CONFIG_SPEED_STATS
3771   ++x->txfm_search_info.tx_search_count;
3772 #endif  // CONFIG_SPEED_STATS
3773 
3774   // Pre-compute transform-block-level residue hashes and find existing RD
3775   // records (or add new ones) storing rate and distortion values, so they can
3776   // be reused to speed up the TX size/type search.
3777   TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
3778   int found_rd_info = 0;
3779   if (ref_best_rd != INT64_MAX && within_border &&
3780       cpi->sf.tx_sf.use_inter_txb_hash) {
3781     found_rd_info = find_tx_size_rd_records(x, bsize, matched_rd_info);
3782   }
3783 
3784   const int64_t rd =
3785       select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd,
3786                               found_rd_info ? matched_rd_info : NULL);
3787 
3788   if (rd == INT64_MAX) {
3789     // We should always find at least one candidate unless ref_best_rd is less
3790     // than INT64_MAX (in which case, all the select_tx_block() calls may have
3791     // failed to find something better).
3792     assert(ref_best_rd != INT64_MAX);
3793     av1_invalid_rd_stats(rd_stats);
3794     return;
3795   }
3796 
3797   // Save the RD search results into tx_rd_record.
3798   if (is_mb_rd_hash_enabled) {
3799     assert(mb_rd_record != NULL);
3800     save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
3801   }
3802 }
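
// The hashing speed feature above follows a classic memoization pattern:
// hash the prediction residue, look the hash up in a small record table, and
// either reuse the stored result or run the full search and save it. A
// minimal, self-contained sketch with illustrative types (the real records
// additionally store the chosen tx sizes/types and entropy contexts):
typedef struct {
  uint32_t hash;
  int64_t rd;
} ExampleRdRecord;

static AOM_INLINE int example_find_record(const ExampleRdRecord *records,
                                          int num_records, uint32_t hash) {
  for (int i = 0; i < num_records; ++i)
    if (records[i].hash == hash) return i;  // Hit: reuse the stored result.
  return -1;  // Miss: run the full search, then append a new record.
}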
3803 
3804 void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3805                                        RD_STATS *rd_stats, BLOCK_SIZE bs,
3806                                        int64_t ref_best_rd) {
3807   MACROBLOCKD *const xd = &x->e_mbd;
3808   MB_MODE_INFO *const mbmi = xd->mi[0];
3809   const TxfmSearchParams *tx_params = &x->txfm_search_params;
3810   assert(bs == mbmi->bsize);
3811   const int is_inter = is_inter_block(mbmi);
3812   const int mi_row = xd->mi_row;
3813   const int mi_col = xd->mi_col;
3814 
3815   av1_init_rd_stats(rd_stats);
3816 
3817   // Hashing-based speed feature for inter blocks. If the hash of the residue
3818   // block is found in the table, use previously saved search results and
3819   // terminate early.
3820   uint32_t hash = 0;
3821   MB_RD_RECORD *mb_rd_record = NULL;
3822   const int num_blks = bsize_to_num_blk(bs);
3823   if (is_inter && cpi->sf.rd_sf.use_mb_rd_hash) {
3824     const int within_border =
3825         mi_row >= xd->tile.mi_row_start &&
3826         (mi_row + mi_size_high[bs] < xd->tile.mi_row_end) &&
3827         mi_col >= xd->tile.mi_col_start &&
3828         (mi_col + mi_size_wide[bs] < xd->tile.mi_col_end);
3829     if (within_border) {
3830       hash = get_block_residue_hash(x, bs);
3831       mb_rd_record = &x->txfm_search_info.txb_rd_records->mb_rd_record;
3832       const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
3833       if (match_index != -1) {
3834         MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[match_index];
3835         fetch_tx_rd_info(num_blks, tx_rd_info, rd_stats, x);
3836         return;
3837       }
3838     }
3839   }
3840 
3841   // If we predict that skip is the optimal RD decision, set the respective
3842   // context and terminate early.
3843   int64_t dist;
3844   if (tx_params->skip_txfm_level && is_inter &&
3845       !xd->lossless[mbmi->segment_id] &&
3846       predict_skip_txfm(x, bs, &dist,
3847                         cpi->common.features.reduced_tx_set_used)) {
3848     // Populate rd_stats according to the skip decision.
3849     set_skip_txfm(x, rd_stats, bs, dist);
3850     // Save the RD search results into tx_rd_record.
3851     if (mb_rd_record) {
3852       save_tx_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3853     }
3854     return;
3855   }
3856 
3857   if (xd->lossless[mbmi->segment_id]) {
3858     // Lossless mode can only pick the smallest (4x4) transform size.
3859     choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3860   } else if (tx_params->tx_size_search_method == USE_LARGESTALL) {
3861     choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3862   } else {
3863     choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
3864   }
3865 
3866   // Save the RD search results into tx_rd_record for possible future reuse.
3867   if (mb_rd_record) {
3868     save_tx_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
3869   }
3870 }
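
// In summary, the uniform tx size picker above reduces to a three-way
// dispatch:
//   lossless segment -> choose_smallest_tx_size() (4x4 only)
//   USE_LARGESTALL   -> choose_largest_tx_size() (no size search)
//   otherwise        -> choose_tx_size_type_from_rd() (full RD search over
//                       transform sizes and types)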
3871 
3872 int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
3873                   BLOCK_SIZE bsize, int64_t ref_best_rd) {
3874   av1_init_rd_stats(rd_stats);
3875   if (ref_best_rd < 0) return 0;
3876   if (!x->e_mbd.is_chroma_ref) return 1;
3877 
3878   MACROBLOCKD *const xd = &x->e_mbd;
3879   MB_MODE_INFO *const mbmi = xd->mi[0];
3880   struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
3881   const int is_inter = is_inter_block(mbmi);
3882   int64_t this_rd = 0, skip_txfm_rd = 0;
3883   const BLOCK_SIZE plane_bsize =
3884       get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
3885 
3886   if (is_inter) {
3887     for (int plane = 1; plane < MAX_MB_PLANE; ++plane)
3888       av1_subtract_plane(x, plane_bsize, plane);
3889   }
3890 
3891   const int skip_trellis = 0;
3892   const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
3893   int is_cost_valid = 1;
3894   for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
3895     RD_STATS this_rd_stats;
3896     int64_t chroma_ref_best_rd = ref_best_rd;
3897     // For inter blocks, a refined ref_best_rd is used for early exit.
3898     // For intra blocks, even when the current rd crosses ref_best_rd, early
3899     // exit is not recommended, as the current rd is also used for gating
3900     // subsequent modes (say, angular modes).
3901     // TODO(any): Extend the early exit mechanism to intra modes as well.
3902     if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
3903         chroma_ref_best_rd != INT64_MAX)
3904       chroma_ref_best_rd = ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
3905     av1_txfm_rd_in_plane(x, cpi, &this_rd_stats, chroma_ref_best_rd, 0, plane,
3906                          plane_bsize, uv_tx_size, FTXS_NONE, skip_trellis);
3907     if (this_rd_stats.rate == INT_MAX) {
3908       is_cost_valid = 0;
3909       break;
3910     }
3911     av1_merge_rd_stats(rd_stats, &this_rd_stats);
3912     this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
3913     skip_txfm_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
3914     if (AOMMIN(this_rd, skip_txfm_rd) > ref_best_rd) {
3915       is_cost_valid = 0;
3916       break;
3917     }
3918   }
3919 
3920   if (!is_cost_valid) {
3921     // reset cost value
3922     av1_invalid_rd_stats(rd_stats);
3923   }
3924 
3925   return is_cost_valid;
3926 }
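
// A minimal sketch (illustrative names) of the chroma gating used above:
// when best-rd-based gating is enabled for inter blocks, each chroma plane
// only gets the RD budget left after subtracting the cheaper of the
// cumulative coded cost and the skip (SSE-only) cost from the reference.
static AOM_INLINE int64_t example_chroma_rd_budget(int64_t ref_best_rd,
                                                   int64_t this_rd,
                                                   int64_t skip_txfm_rd) {
  if (ref_best_rd == INT64_MAX) return INT64_MAX;
  // E.g. ref_best_rd = 5000, this_rd = 1200, skip_txfm_rd = 900:
  // the remaining budget for the next plane is 5000 - 900 = 4100.
  return ref_best_rd - AOMMIN(this_rd, skip_txfm_rd);
}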
3927 
3928 void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
3929                           RD_STATS *rd_stats, int64_t ref_best_rd,
3930                           int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
3931                           TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
3932                           int skip_trellis) {
3933   assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));
3934 
3935   if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
3936       txsize_sqr_up_map[tx_size] == TX_64X64) {
3937     av1_invalid_rd_stats(rd_stats);
3938     return;
3939   }
3940 
3941   if (current_rd > ref_best_rd) {
3942     av1_invalid_rd_stats(rd_stats);
3943     return;
3944   }
3945 
3946   MACROBLOCKD *const xd = &x->e_mbd;
3947   const struct macroblockd_plane *const pd = &xd->plane[plane];
3948   struct rdcost_block_args args;
3949   av1_zero(args);
3950   args.x = x;
3951   args.cpi = cpi;
3952   args.best_rd = ref_best_rd;
3953   args.current_rd = current_rd;
3954   args.ftxs_mode = ftxs_mode;
3955   args.skip_trellis = skip_trellis;
3956   av1_init_rd_stats(&args.rd_stats);
3957 
3958   av1_get_entropy_contexts(plane_bsize, pd, args.t_above, args.t_left);
3959   av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
3960                                          &args);
3961 
3962   MB_MODE_INFO *const mbmi = xd->mi[0];
3963   const int is_inter = is_inter_block(mbmi);
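  // For inter blocks, an early exit only invalidates the stats when it fired
  // with transform blocks still unvisited (incomplete_exit); for intra
  // blocks, any early exit makes the accumulated stats unusable.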
3964   const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
3965 
3966   if (invalid_rd) {
3967     av1_invalid_rd_stats(rd_stats);
3968   } else {
3969     *rd_stats = args.rd_stats;
3970   }
3971 }
3972 
3973 int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
3974                     RD_STATS *rd_stats, RD_STATS *rd_stats_y,
3975                     RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
3976   MACROBLOCKD *const xd = &x->e_mbd;
3977   TxfmSearchParams *txfm_params = &x->txfm_search_params;
3978   const int skip_ctx = av1_get_skip_txfm_context(xd);
3979   const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0],
3980                                   x->mode_costs.skip_txfm_cost[skip_ctx][1] };
3981   const int64_t min_header_rate =
3982       mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]);
3983   // Account for the cheaper of the skip and non-skip header rates;
3984   // eventually one of the two will be added to mode_rate.
3985   const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
3986   if (min_header_rd_possible > ref_best_rd) {
3987     av1_invalid_rd_stats(rd_stats_y);
3988     return 0;
3989   }
3990 
3991   const AV1_COMMON *cm = &cpi->common;
3992   MB_MODE_INFO *const mbmi = xd->mi[0];
3993   const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
3994   const int64_t rd_thresh =
3995       ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
3996   av1_init_rd_stats(rd_stats);
3997   av1_init_rd_stats(rd_stats_y);
3998   rd_stats->rate = mode_rate;
3999 
4000   // cost and distortion
4001   av1_subtract_plane(x, bsize, 0);
4002   if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
4003       !xd->lossless[mbmi->segment_id]) {
4004     av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
4005 #if CONFIG_COLLECT_RD_STATS == 2
4006     PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
4007 #endif  // CONFIG_COLLECT_RD_STATS == 2
4008   } else {
4009     av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
4010     memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
4011     for (int i = 0; i < xd->height * xd->width; ++i)
4012       set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm);
4013   }
4014 
4015   if (rd_stats_y->rate == INT_MAX) return 0;
4016 
4017   av1_merge_rd_stats(rd_stats, rd_stats_y);
4018 
4019   const int64_t non_skip_txfm_rdcosty =
4020       RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist);
4021   const int64_t skip_txfm_rdcosty =
4022       RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse);
4023   const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
4024   if (min_rdcosty > ref_best_rd) return 0;
4025 
4026   av1_init_rd_stats(rd_stats_uv);
4027   const int num_planes = av1_num_planes(cm);
4028   if (num_planes > 1) {
4029     int64_t ref_best_chroma_rd = ref_best_rd;
4030     // Calculate best rd cost possible for chroma
4031     if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
4032         (ref_best_chroma_rd != INT64_MAX)) {
4033       ref_best_chroma_rd = (ref_best_chroma_rd -
4034                             AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty));
4035     }
4036     const int is_cost_valid_uv =
4037         av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
4038     if (!is_cost_valid_uv) return 0;
4039     av1_merge_rd_stats(rd_stats, rd_stats_uv);
4040   }
4041 
4042   int choose_skip_txfm = rd_stats->skip_txfm;
4043   if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
4044     const int64_t rdcost_no_skip_txfm = RDCOST(
4045         x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0],
4046         rd_stats->dist);
4047     const int64_t rdcost_skip_txfm =
4048         RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse);
4049     if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
4050   }
4051   if (choose_skip_txfm) {
4052     rd_stats_y->rate = 0;
4053     rd_stats_uv->rate = 0;
4054     rd_stats->rate = mode_rate + skip_txfm_cost[1];
4055     rd_stats->dist = rd_stats->sse;
4056     rd_stats_y->dist = rd_stats_y->sse;
4057     rd_stats_uv->dist = rd_stats_uv->sse;
4058     mbmi->skip_txfm = 1;
4059     if (rd_stats->skip_txfm) {
4060       const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
4061       if (tmprd > ref_best_rd) return 0;
4062     }
4063   } else {
4064     rd_stats->rate += skip_txfm_cost[0];
4065     mbmi->skip_txfm = 0;
4066   }
4067 
4068   return 1;
4069 }
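
// A standalone sketch of the header-rate early exit at the top of
// av1_txfm_search(): before any transform is searched, mode_rate plus the
// cheaper of the two skip_txfm flag costs gives a lower bound on the final
// header RD cost (distortion 0 is the most optimistic case); if even that
// bound exceeds ref_best_rd, the search cannot win. Names are illustrative.
static AOM_INLINE int example_header_rd_early_exit(int rdmult, int mode_rate,
                                                   int skip_cost0,
                                                   int skip_cost1,
                                                   int64_t ref_best_rd) {
  const int64_t min_header_rate = mode_rate + AOMMIN(skip_cost0, skip_cost1);
  const int64_t min_header_rd = RDCOST(rdmult, min_header_rate, 0);
  return min_header_rd > ref_best_rd;  // 1 -> terminate early.
}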
4070