• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/encoder/intra_mode_search.h"
13 #include "av1/encoder/model_rd.h"
14 #include "av1/encoder/palette.h"
15 #include "av1/common/pred_common.h"
16 #include "av1/common/reconintra.h"
17 #include "av1/encoder/tx_search.h"
18 
19 static const PREDICTION_MODE intra_rd_search_mode_order[INTRA_MODES] = {
20   DC_PRED,       H_PRED,        V_PRED,    SMOOTH_PRED, PAETH_PRED,
21   SMOOTH_V_PRED, SMOOTH_H_PRED, D135_PRED, D203_PRED,   D157_PRED,
22   D67_PRED,      D113_PRED,     D45_PRED,
23 };
24 
25 static const UV_PREDICTION_MODE uv_rd_search_mode_order[UV_INTRA_MODES] = {
26   UV_DC_PRED,     UV_CFL_PRED,   UV_H_PRED,        UV_V_PRED,
27   UV_SMOOTH_PRED, UV_PAETH_PRED, UV_SMOOTH_V_PRED, UV_SMOOTH_H_PRED,
28   UV_D135_PRED,   UV_D203_PRED,  UV_D157_PRED,     UV_D67_PRED,
29   UV_D113_PRED,   UV_D45_PRED,
30 };
31 
32 #define BINS 32
33 static const float intra_hog_model_bias[DIRECTIONAL_MODES] = {
34   0.450578f,  0.695518f,  -0.717944f, -0.639894f,
35   -0.602019f, -0.453454f, 0.055857f,  -0.465480f,
36 };
37 
// Weights of the linear HOG model used to prune directional intra modes.
// Layout is row-major with BINS consecutive weights per directional mode:
// the weights of mode i occupy indices [i * BINS, i * BINS + BINS - 1], and the
// score of mode i is
//   intra_hog_model_bias[i] + sum_j weights[i * BINS + j] * hist[j]
// where hist is the normalized gradient-orientation histogram.
// The values are presumably fitted offline; they must not be edited by hand.
// Note that the 6-values-per-line layout below does not align with the
// per-mode groups of BINS values.
38 static const float intra_hog_model_weights[BINS * DIRECTIONAL_MODES] = {
39   -3.076402f, -3.757063f, -3.275266f, -3.180665f, -3.452105f, -3.216593f,
40   -2.871212f, -3.134296f, -1.822324f, -2.401411f, -1.541016f, -1.195322f,
41   -0.434156f, 0.322868f,  2.260546f,  3.368715f,  3.989290f,  3.308487f,
42   2.277893f,  0.923793f,  0.026412f,  -0.385174f, -0.718622f, -1.408867f,
43   -1.050558f, -2.323941f, -2.225827f, -2.585453f, -3.054283f, -2.875087f,
44   -2.985709f, -3.447155f, 3.758139f,  3.204353f,  2.170998f,  0.826587f,
45   -0.269665f, -0.702068f, -1.085776f, -2.175249f, -1.623180f, -2.975142f,
46   -2.779629f, -3.190799f, -3.521900f, -3.375480f, -3.319355f, -3.897389f,
47   -3.172334f, -3.594528f, -2.879132f, -2.547777f, -2.921023f, -2.281844f,
48   -1.818988f, -2.041771f, -0.618268f, -1.396458f, -0.567153f, -0.285868f,
49   -0.088058f, 0.753494f,  2.092413f,  3.215266f,  -3.300277f, -2.748658f,
50   -2.315784f, -2.423671f, -2.257283f, -2.269583f, -2.196660f, -2.301076f,
51   -2.646516f, -2.271319f, -2.254366f, -2.300102f, -2.217960f, -2.473300f,
52   -2.116866f, -2.528246f, -3.314712f, -1.701010f, -0.589040f, -0.088077f,
53   0.813112f,  1.702213f,  2.653045f,  3.351749f,  3.243554f,  3.199409f,
54   2.437856f,  1.468854f,  0.533039f,  -0.099065f, -0.622643f, -2.200732f,
55   -4.228861f, -2.875263f, -1.273956f, -0.433280f, 0.803771f,  1.975043f,
56   3.179528f,  3.939064f,  3.454379f,  3.689386f,  3.116411f,  1.970991f,
57   0.798406f,  -0.628514f, -1.252546f, -2.825176f, -4.090178f, -3.777448f,
58   -3.227314f, -3.479403f, -3.320569f, -3.159372f, -2.729202f, -2.722341f,
59   -3.054913f, -2.742923f, -2.612703f, -2.662632f, -2.907314f, -3.117794f,
60   -3.102660f, -3.970972f, -4.891357f, -3.935582f, -3.347758f, -2.721924f,
61   -2.219011f, -1.702391f, -0.866529f, -0.153743f, 0.107733f,  1.416882f,
62   2.572884f,  3.607755f,  3.974820f,  3.997783f,  2.970459f,  0.791687f,
63   -1.478921f, -1.228154f, -1.216955f, -1.765932f, -1.951003f, -1.985301f,
64   -1.975881f, -1.985593f, -2.422371f, -2.419978f, -2.531288f, -2.951853f,
65   -3.071380f, -3.277027f, -3.373539f, -4.462010f, -0.967888f, 0.805524f,
66   2.794130f,  3.685984f,  3.745195f,  3.252444f,  2.316108f,  1.399146f,
67   -0.136519f, -0.162811f, -1.004357f, -1.667911f, -1.964662f, -2.937579f,
68   -3.019533f, -3.942766f, -5.102767f, -3.882073f, -3.532027f, -3.451956f,
69   -2.944015f, -2.643064f, -2.529872f, -2.077290f, -2.809965f, -1.803734f,
70   -1.783593f, -1.662585f, -1.415484f, -1.392673f, -0.788794f, -1.204819f,
71   -1.998864f, -1.182102f, -0.892110f, -1.317415f, -1.359112f, -1.522867f,
72   -1.468552f, -1.779072f, -2.332959f, -2.160346f, -2.329387f, -2.631259f,
73   -2.744936f, -3.052494f, -2.787363f, -3.442548f, -4.245075f, -3.032172f,
74   -2.061609f, -1.768116f, -1.286072f, -0.706587f, -0.192413f, 0.386938f,
75   0.716997f,  1.481393f,  2.216702f,  2.737986f,  3.109809f,  3.226084f,
76   2.490098f,  -0.095827f, -3.864816f, -3.507248f, -3.128925f, -2.908251f,
77   -2.883836f, -2.881411f, -2.524377f, -2.624478f, -2.399573f, -2.367718f,
78   -1.918255f, -1.926277f, -1.694584f, -1.723790f, -0.966491f, -1.183115f,
79   -1.430687f, 0.872896f,  2.766550f,  3.610080f,  3.578041f,  3.334928f,
80   2.586680f,  1.895721f,  1.122195f,  0.488519f,  -0.140689f, -0.799076f,
81   -1.222860f, -1.502437f, -1.900969f, -3.206816f,
82 };
83 
generate_hog(const uint8_t * src,int stride,int rows,int cols,float * hist)84 static void generate_hog(const uint8_t *src, int stride, int rows, int cols,
85                          float *hist) {
86   const float step = (float)PI / BINS;
87   float total = 0.1f;
88   src += stride;
89   for (int r = 1; r < rows - 1; ++r) {
90     for (int c = 1; c < cols - 1; ++c) {
91       const uint8_t *above = &src[c - stride];
92       const uint8_t *below = &src[c + stride];
93       const uint8_t *left = &src[c - 1];
94       const uint8_t *right = &src[c + 1];
95       // Calculate gradient using Sobel fitlers.
96       const int dx = (right[-stride] + 2 * right[0] + right[stride]) -
97                      (left[-stride] + 2 * left[0] + left[stride]);
98       const int dy = (below[-1] + 2 * below[0] + below[1]) -
99                      (above[-1] + 2 * above[0] + above[1]);
100       if (dx == 0 && dy == 0) continue;
101       const int temp = abs(dx) + abs(dy);
102       if (!temp) continue;
103       total += temp;
104       if (dx == 0) {
105         hist[0] += temp / 2;
106         hist[BINS - 1] += temp / 2;
107       } else {
108         const float angle = atanf(dy * 1.0f / dx);
109         int idx = (int)roundf(angle / step) + BINS / 2;
110         idx = AOMMIN(idx, BINS - 1);
111         idx = AOMMAX(idx, 0);
112         hist[idx] += temp;
113       }
114     }
115     src += stride;
116   }
117 
118   for (int i = 0; i < BINS; ++i) hist[i] /= total;
119 }
120 
generate_hog_hbd(const uint8_t * src8,int stride,int rows,int cols,float * hist)121 static void generate_hog_hbd(const uint8_t *src8, int stride, int rows,
122                              int cols, float *hist) {
123   const float step = (float)PI / BINS;
124   float total = 0.1f;
125   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
126   src += stride;
127   for (int r = 1; r < rows - 1; ++r) {
128     for (int c = 1; c < cols - 1; ++c) {
129       const uint16_t *above = &src[c - stride];
130       const uint16_t *below = &src[c + stride];
131       const uint16_t *left = &src[c - 1];
132       const uint16_t *right = &src[c + 1];
133       // Calculate gradient using Sobel fitlers.
134       const int dx = (right[-stride] + 2 * right[0] + right[stride]) -
135                      (left[-stride] + 2 * left[0] + left[stride]);
136       const int dy = (below[-1] + 2 * below[0] + below[1]) -
137                      (above[-1] + 2 * above[0] + above[1]);
138       if (dx == 0 && dy == 0) continue;
139       const int temp = abs(dx) + abs(dy);
140       if (!temp) continue;
141       total += temp;
142       if (dx == 0) {
143         hist[0] += temp / 2;
144         hist[BINS - 1] += temp / 2;
145       } else {
146         const float angle = atanf(dy * 1.0f / dx);
147         int idx = (int)roundf(angle / step) + BINS / 2;
148         idx = AOMMIN(idx, BINS - 1);
149         idx = AOMMAX(idx, 0);
150         hist[idx] += temp;
151       }
152     }
153     src += stride;
154   }
155 
156   for (int i = 0; i < BINS; ++i) hist[i] /= total;
157 }
158 
// Flags directional luma intra modes that a linear model on the block's
// histogram of oriented gradients (HOG) considers unlikely.
//
// For each of the DIRECTIONAL_MODES modes i, the score
//   intra_hog_model_bias[i] + dot(intra_hog_model_weights[i * BINS ...], hist)
// is computed; if it is below `th`, directional_mode_skip_mask[i + 1] is set
// to 1. Mask entries are only ever set, never cleared. Only the part of the
// block that lies inside the frame is analysed.
static void prune_intra_mode_with_hog(const MACROBLOCK *x, BLOCK_SIZE bsize,
                                      float th,
                                      uint8_t *directional_mode_skip_mask) {
  // NOTE(review): presumably resets FPU/SIMD state around float math.
  aom_clear_system_state();

  const MACROBLOCKD *xd = &x->e_mbd;
  // A negative mb_to_{bottom,right}_edge means the block crosses that frame
  // border; the >> 3 presumably converts 1/8-pel units to pixels.
  int rows = block_size_high[bsize];
  if (xd->mb_to_bottom_edge < 0) rows += xd->mb_to_bottom_edge >> 3;
  int cols = block_size_wide[bsize];
  if (xd->mb_to_right_edge < 0) cols += xd->mb_to_right_edge >> 3;

  const uint8_t *src = x->plane[0].src.buf;
  const int src_stride = x->plane[0].src.stride;
  float hist[BINS] = { 0.0f };
  void (*const hog_fn)(const uint8_t *, int, int, int, float *) =
      is_cur_buf_hbd(xd) ? generate_hog_hbd : generate_hog;
  hog_fn(src, src_stride, rows, cols, hist);

  const float *mode_weights = intra_hog_model_weights;
  for (int mode_idx = 0; mode_idx < DIRECTIONAL_MODES;
       ++mode_idx, mode_weights += BINS) {
    float score = intra_hog_model_bias[mode_idx];
    for (int bin = 0; bin < BINS; ++bin) score += mode_weights[bin] * hist[bin];
    // Mask slot 0 is not a directional mode; directional mode mode_idx lives
    // at mode_idx + 1 (presumably mirroring V_PRED == 1 in PREDICTION_MODE).
    if (score < th) directional_mode_skip_mask[mode_idx + 1] = 1;
  }

  aom_clear_system_state();
}

#undef BINS
193 
194 // Model based RD estimation for luma intra blocks.
intra_model_yrd(const AV1_COMP * const cpi,MACROBLOCK * const x,BLOCK_SIZE bsize,int mode_cost)195 static int64_t intra_model_yrd(const AV1_COMP *const cpi, MACROBLOCK *const x,
196                                BLOCK_SIZE bsize, int mode_cost) {
197   const AV1_COMMON *cm = &cpi->common;
198   MACROBLOCKD *const xd = &x->e_mbd;
199   MB_MODE_INFO *const mbmi = xd->mi[0];
200   assert(!is_inter_block(mbmi));
201   RD_STATS this_rd_stats;
202   int row, col;
203   int64_t temp_sse, this_rd;
204   TX_SIZE tx_size = tx_size_from_tx_mode(bsize, x->tx_mode_search_type);
205   const int stepr = tx_size_high_unit[tx_size];
206   const int stepc = tx_size_wide_unit[tx_size];
207   const int max_blocks_wide = max_block_wide(xd, bsize, 0);
208   const int max_blocks_high = max_block_high(xd, bsize, 0);
209   mbmi->tx_size = tx_size;
210   // Prediction.
211   for (row = 0; row < max_blocks_high; row += stepr) {
212     for (col = 0; col < max_blocks_wide; col += stepc) {
213       av1_predict_intra_block_facade(cm, xd, 0, col, row, tx_size);
214     }
215   }
216   // RD estimation.
217   model_rd_sb_fn[cpi->sf.rt_sf.use_simple_rd_model ? MODELRD_LEGACY
218                                                    : MODELRD_TYPE_INTRA](
219       cpi, bsize, x, xd, 0, 0, &this_rd_stats.rate, &this_rd_stats.dist,
220       &this_rd_stats.skip, &temp_sse, NULL, NULL, NULL);
221   if (av1_is_directional_mode(mbmi->mode) && av1_use_angle_delta(bsize)) {
222     mode_cost +=
223         x->angle_delta_cost[mbmi->mode - V_PRED]
224                            [MAX_ANGLE_DELTA + mbmi->angle_delta[PLANE_TYPE_Y]];
225   }
226   if (mbmi->mode == DC_PRED &&
227       av1_filter_intra_allowed_bsize(cm, mbmi->sb_type)) {
228     if (mbmi->filter_intra_mode_info.use_filter_intra) {
229       const int mode = mbmi->filter_intra_mode_info.filter_intra_mode;
230       mode_cost += x->filter_intra_cost[mbmi->sb_type][1] +
231                    x->filter_intra_mode_cost[mode];
232     } else {
233       mode_cost += x->filter_intra_cost[mbmi->sb_type][0];
234     }
235   }
236   this_rd =
237       RDCOST(x->rdmult, this_rd_stats.rate + mode_cost, this_rd_stats.dist);
238   return this_rd;
239 }
240 
241 // Update the intra model yrd and prune the current mode if the new estimate
242 // y_rd > 1.5 * best_model_rd.
model_intra_yrd_and_prune(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int mode_info_cost,int64_t * best_model_rd)243 static AOM_INLINE int model_intra_yrd_and_prune(const AV1_COMP *const cpi,
244                                                 MACROBLOCK *x, BLOCK_SIZE bsize,
245                                                 int mode_info_cost,
246                                                 int64_t *best_model_rd) {
247   const int64_t this_model_rd = intra_model_yrd(cpi, x, bsize, mode_info_cost);
248   if (*best_model_rd != INT64_MAX &&
249       this_model_rd > *best_model_rd + (*best_model_rd >> 1)) {
250     return 1;
251   } else if (this_model_rd < *best_model_rd) {
252     *best_model_rd = this_model_rd;
253   }
254   return 0;
255 }
256 
257 // Run RD calculation with given luma intra prediction angle., and return
258 // the RD cost. Update the best mode info. if the RD cost is the best so far.
calc_rd_given_intra_angle(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int mode_cost,int64_t best_rd_in,int8_t angle_delta,int max_angle_delta,int * rate,RD_STATS * rd_stats,int * best_angle_delta,TX_SIZE * best_tx_size,int64_t * best_rd,int64_t * best_model_rd,uint8_t * best_tx_type_map,uint8_t * best_blk_skip,int skip_model_rd)259 static int64_t calc_rd_given_intra_angle(
260     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mode_cost,
261     int64_t best_rd_in, int8_t angle_delta, int max_angle_delta, int *rate,
262     RD_STATS *rd_stats, int *best_angle_delta, TX_SIZE *best_tx_size,
263     int64_t *best_rd, int64_t *best_model_rd, uint8_t *best_tx_type_map,
264     uint8_t *best_blk_skip, int skip_model_rd) {
265   RD_STATS tokenonly_rd_stats;
266   int64_t this_rd;
267   MACROBLOCKD *xd = &x->e_mbd;
268   MB_MODE_INFO *mbmi = xd->mi[0];
269   const int n4 = bsize_to_num_blk(bsize);
270   assert(!is_inter_block(mbmi));
271   mbmi->angle_delta[PLANE_TYPE_Y] = angle_delta;
272   if (!skip_model_rd) {
273     if (model_intra_yrd_and_prune(cpi, x, bsize, mode_cost, best_model_rd)) {
274       return INT64_MAX;
275     }
276   }
277   av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
278                                     best_rd_in);
279   if (tokenonly_rd_stats.rate == INT_MAX) return INT64_MAX;
280 
281   int this_rate =
282       mode_cost + tokenonly_rd_stats.rate +
283       x->angle_delta_cost[mbmi->mode - V_PRED][max_angle_delta + angle_delta];
284   this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
285 
286   if (this_rd < *best_rd) {
287     memcpy(best_blk_skip, x->blk_skip, sizeof(best_blk_skip[0]) * n4);
288     av1_copy_array(best_tx_type_map, xd->tx_type_map, n4);
289     *best_rd = this_rd;
290     *best_angle_delta = mbmi->angle_delta[PLANE_TYPE_Y];
291     *best_tx_size = mbmi->tx_size;
292     *rate = this_rate;
293     rd_stats->rate = tokenonly_rd_stats.rate;
294     rd_stats->dist = tokenonly_rd_stats.dist;
295     rd_stats->skip = tokenonly_rd_stats.skip;
296   }
297   return this_rd;
298 }
299 
write_uniform_cost(int n,int v)300 static INLINE int write_uniform_cost(int n, int v) {
301   const int l = get_unsigned_bits(n);
302   const int m = (1 << l) - n;
303   if (l == 0) return 0;
304   if (v < m)
305     return av1_cost_literal(l - 1);
306   else
307     return av1_cost_literal(l);
308 }
309 
310 // Return the rate cost for luma prediction mode info. of intra blocks.
intra_mode_info_cost_y(const AV1_COMP * cpi,const MACROBLOCK * x,const MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int mode_cost)311 static int intra_mode_info_cost_y(const AV1_COMP *cpi, const MACROBLOCK *x,
312                                   const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
313                                   int mode_cost) {
314   int total_rate = mode_cost;
315   const int use_palette = mbmi->palette_mode_info.palette_size[0] > 0;
316   const int use_filter_intra = mbmi->filter_intra_mode_info.use_filter_intra;
317   const int use_intrabc = mbmi->use_intrabc;
318   // Can only activate one mode.
319   assert(((mbmi->mode != DC_PRED) + use_palette + use_intrabc +
320           use_filter_intra) <= 1);
321   const int try_palette = av1_allow_palette(
322       cpi->common.features.allow_screen_content_tools, mbmi->sb_type);
323   if (try_palette && mbmi->mode == DC_PRED) {
324     const MACROBLOCKD *xd = &x->e_mbd;
325     const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
326     const int mode_ctx = av1_get_palette_mode_ctx(xd);
327     total_rate += x->palette_y_mode_cost[bsize_ctx][mode_ctx][use_palette];
328     if (use_palette) {
329       const uint8_t *const color_map = xd->plane[0].color_index_map;
330       int block_width, block_height, rows, cols;
331       av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
332                                &cols);
333       const int plt_size = mbmi->palette_mode_info.palette_size[0];
334       int palette_mode_cost =
335           x->palette_y_size_cost[bsize_ctx][plt_size - PALETTE_MIN_SIZE] +
336           write_uniform_cost(plt_size, color_map[0]);
337       uint16_t color_cache[2 * PALETTE_MAX_SIZE];
338       const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
339       palette_mode_cost +=
340           av1_palette_color_cost_y(&mbmi->palette_mode_info, color_cache,
341                                    n_cache, cpi->common.seq_params.bit_depth);
342       palette_mode_cost +=
343           av1_cost_color_map(x, 0, bsize, mbmi->tx_size, PALETTE_MAP);
344       total_rate += palette_mode_cost;
345     }
346   }
347   if (av1_filter_intra_allowed(&cpi->common, mbmi)) {
348     total_rate += x->filter_intra_cost[mbmi->sb_type][use_filter_intra];
349     if (use_filter_intra) {
350       total_rate += x->filter_intra_mode_cost[mbmi->filter_intra_mode_info
351                                                   .filter_intra_mode];
352     }
353   }
354   if (av1_is_directional_mode(mbmi->mode)) {
355     if (av1_use_angle_delta(bsize)) {
356       total_rate += x->angle_delta_cost[mbmi->mode - V_PRED]
357                                        [MAX_ANGLE_DELTA +
358                                         mbmi->angle_delta[PLANE_TYPE_Y]];
359     }
360   }
361   if (av1_allow_intrabc(&cpi->common))
362     total_rate += x->intrabc_cost[use_intrabc];
363   return total_rate;
364 }
365 
366 // Return the rate cost for chroma prediction mode info. of intra blocks.
intra_mode_info_cost_uv(const AV1_COMP * cpi,const MACROBLOCK * x,const MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int mode_cost)367 static int intra_mode_info_cost_uv(const AV1_COMP *cpi, const MACROBLOCK *x,
368                                    const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
369                                    int mode_cost) {
370   int total_rate = mode_cost;
371   const int use_palette = mbmi->palette_mode_info.palette_size[1] > 0;
372   const UV_PREDICTION_MODE mode = mbmi->uv_mode;
373   // Can only activate one mode.
374   assert(((mode != UV_DC_PRED) + use_palette + mbmi->use_intrabc) <= 1);
375 
376   const int try_palette = av1_allow_palette(
377       cpi->common.features.allow_screen_content_tools, mbmi->sb_type);
378   if (try_palette && mode == UV_DC_PRED) {
379     const PALETTE_MODE_INFO *pmi = &mbmi->palette_mode_info;
380     total_rate +=
381         x->palette_uv_mode_cost[pmi->palette_size[0] > 0][use_palette];
382     if (use_palette) {
383       const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
384       const int plt_size = pmi->palette_size[1];
385       const MACROBLOCKD *xd = &x->e_mbd;
386       const uint8_t *const color_map = xd->plane[1].color_index_map;
387       int palette_mode_cost =
388           x->palette_uv_size_cost[bsize_ctx][plt_size - PALETTE_MIN_SIZE] +
389           write_uniform_cost(plt_size, color_map[0]);
390       uint16_t color_cache[2 * PALETTE_MAX_SIZE];
391       const int n_cache = av1_get_palette_cache(xd, 1, color_cache);
392       palette_mode_cost += av1_palette_color_cost_uv(
393           pmi, color_cache, n_cache, cpi->common.seq_params.bit_depth);
394       palette_mode_cost +=
395           av1_cost_color_map(x, 1, bsize, mbmi->tx_size, PALETTE_MAP);
396       total_rate += palette_mode_cost;
397     }
398   }
399   if (av1_is_directional_mode(get_uv_mode(mode))) {
400     if (av1_use_angle_delta(bsize)) {
401       total_rate +=
402           x->angle_delta_cost[mode - V_PRED][mbmi->angle_delta[PLANE_TYPE_UV] +
403                                              MAX_ANGLE_DELTA];
404     }
405   }
406   return total_rate;
407 }
408 
409 // Return 1 if an filter intra mode is selected; return 0 otherwise.
rd_pick_filter_intra_sby(const AV1_COMP * const cpi,MACROBLOCK * x,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,BLOCK_SIZE bsize,int mode_cost,int64_t * best_rd,int64_t * best_model_rd,PICK_MODE_CONTEXT * ctx)410 static int rd_pick_filter_intra_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
411                                     int *rate, int *rate_tokenonly,
412                                     int64_t *distortion, int *skippable,
413                                     BLOCK_SIZE bsize, int mode_cost,
414                                     int64_t *best_rd, int64_t *best_model_rd,
415                                     PICK_MODE_CONTEXT *ctx) {
416   MACROBLOCKD *const xd = &x->e_mbd;
417   MB_MODE_INFO *mbmi = xd->mi[0];
418   int filter_intra_selected_flag = 0;
419   FILTER_INTRA_MODE mode;
420   TX_SIZE best_tx_size = TX_8X8;
421   FILTER_INTRA_MODE_INFO filter_intra_mode_info;
422   uint8_t best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
423   (void)ctx;
424   av1_zero(filter_intra_mode_info);
425   mbmi->filter_intra_mode_info.use_filter_intra = 1;
426   mbmi->mode = DC_PRED;
427   mbmi->palette_mode_info.palette_size[0] = 0;
428 
429   for (mode = 0; mode < FILTER_INTRA_MODES; ++mode) {
430     int64_t this_rd;
431     RD_STATS tokenonly_rd_stats;
432     mbmi->filter_intra_mode_info.filter_intra_mode = mode;
433 
434     if (model_intra_yrd_and_prune(cpi, x, bsize, mode_cost, best_model_rd)) {
435       continue;
436     }
437     av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
438                                       *best_rd);
439     if (tokenonly_rd_stats.rate == INT_MAX) continue;
440     const int this_rate =
441         tokenonly_rd_stats.rate +
442         intra_mode_info_cost_y(cpi, x, mbmi, bsize, mode_cost);
443     this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
444 
445     // Collect mode stats for multiwinner mode processing
446     const int txfm_search_done = 1;
447     store_winner_mode_stats(
448         &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, this_rd,
449         cpi->sf.winner_mode_sf.enable_multiwinner_mode_process,
450         txfm_search_done);
451     if (this_rd < *best_rd) {
452       *best_rd = this_rd;
453       best_tx_size = mbmi->tx_size;
454       filter_intra_mode_info = mbmi->filter_intra_mode_info;
455       av1_copy_array(best_tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
456       memcpy(ctx->blk_skip, x->blk_skip,
457              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
458       *rate = this_rate;
459       *rate_tokenonly = tokenonly_rd_stats.rate;
460       *distortion = tokenonly_rd_stats.dist;
461       *skippable = tokenonly_rd_stats.skip;
462       filter_intra_selected_flag = 1;
463     }
464   }
465 
466   if (filter_intra_selected_flag) {
467     mbmi->mode = DC_PRED;
468     mbmi->tx_size = best_tx_size;
469     mbmi->filter_intra_mode_info = filter_intra_mode_info;
470     av1_copy_array(ctx->tx_type_map, best_tx_type_map, ctx->num_4x4_blk);
471     return 1;
472   } else {
473     return 0;
474   }
475 }
476 
// Builds a histogram of the 8-bit samples of a rows x cols block at `src`
// (row pitch `stride`) and returns the number of distinct sample values.
// val_count must hold 256 ints; it is cleared first and, on return,
// val_count[v] is the number of occurrences of value v.
int av1_count_colors(const uint8_t *src, int stride, int rows, int cols,
                     int *val_count) {
  const int num_values = 1 << 8;
  memset(val_count, 0, num_values * sizeof(*val_count));
  for (int r = 0; r < rows; ++r) {
    const uint8_t *const row = src + r * stride;
    for (int c = 0; c < cols; ++c) {
      const int value = row[c];
      assert(value < num_values);
      ++val_count[value];
    }
  }
  int n_colors = 0;
  for (int v = 0; v < num_values; ++v) n_colors += (val_count[v] != 0);
  return n_colors;
}
494 
// High-bitdepth variant of av1_count_colors(): `src8` is a handle to 16-bit
// samples (converted with CONVERT_TO_SHORTPTR) of at most 12 bits.
// val_count must hold (1 << bit_depth) ints; it is cleared and filled with
// per-value occurrence counts. Returns the number of distinct sample values,
// or 0 if a sample does not fit in bit_depth bits. In that case val_count is
// only partially filled; debug builds assert instead.
int av1_count_colors_highbd(const uint8_t *src8, int stride, int rows, int cols,
                            int bit_depth, int *val_count) {
  assert(bit_depth <= 12);
  const int num_values = 1 << bit_depth;
  const uint16_t *const src = CONVERT_TO_SHORTPTR(src8);
  memset(val_count, 0, num_values * sizeof(*val_count));
  for (int r = 0; r < rows; ++r) {
    const uint16_t *const row = src + r * stride;
    for (int c = 0; c < cols; ++c) {
      const int value = row[c];
      assert(value < num_values);
      // Release builds: never index past the end of val_count.
      if (value >= num_values) return 0;
      ++val_count[value];
    }
  }
  int n_colors = 0;
  for (int v = 0; v < num_values; ++v) n_colors += (val_count[v] != 0);
  return n_colors;
}
515 
516 // Extends 'color_map' array from 'orig_width x orig_height' to 'new_width x
517 // new_height'. Extra rows and columns are filled in by copying last valid
518 // row/column.
extend_palette_color_map(uint8_t * const color_map,int orig_width,int orig_height,int new_width,int new_height)519 static AOM_INLINE void extend_palette_color_map(uint8_t *const color_map,
520                                                 int orig_width, int orig_height,
521                                                 int new_width, int new_height) {
522   int j;
523   assert(new_width >= orig_width);
524   assert(new_height >= orig_height);
525   if (new_width == orig_width && new_height == orig_height) return;
526 
527   for (j = orig_height - 1; j >= 0; --j) {
528     memmove(color_map + j * new_width, color_map + j * orig_width, orig_width);
529     // Copy last column to extra columns.
530     memset(color_map + j * new_width + orig_width,
531            color_map[j * new_width + orig_width - 1], new_width - orig_width);
532   }
533   // Copy last row to extra rows.
534   for (j = orig_height; j < new_height; ++j) {
535     memcpy(color_map + j * new_width, color_map + (orig_height - 1) * new_width,
536            new_width);
537   }
538 }
539 
540 // Bias toward using colors in the cache.
541 // TODO(huisu): Try other schemes to improve compression.
optimize_palette_colors(uint16_t * color_cache,int n_cache,int n_colors,int stride,int * centroids)542 static AOM_INLINE void optimize_palette_colors(uint16_t *color_cache,
543                                                int n_cache, int n_colors,
544                                                int stride, int *centroids) {
545   if (n_cache <= 0) return;
546   for (int i = 0; i < n_colors * stride; i += stride) {
547     int min_diff = abs(centroids[i] - (int)color_cache[0]);
548     int idx = 0;
549     for (int j = 1; j < n_cache; ++j) {
550       const int this_diff = abs(centroids[i] - color_cache[j]);
551       if (this_diff < min_diff) {
552         min_diff = this_diff;
553         idx = j;
554       }
555     }
556     if (min_diff <= 1) centroids[i] = color_cache[idx];
557   }
558 }
559 
560 // Given the base colors as specified in centroids[], calculate the RD cost
561 // of palette mode.
palette_rd_y(const AV1_COMP * const cpi,MACROBLOCK * x,MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int dc_mode_cost,const int * data,int * centroids,int n,uint16_t * color_cache,int n_cache,MB_MODE_INFO * best_mbmi,uint8_t * best_palette_color_map,int64_t * best_rd,int64_t * best_model_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,int * beat_best_rd,PICK_MODE_CONTEXT * ctx,uint8_t * blk_skip,uint8_t * tx_type_map,int * beat_best_pallette_rd)562 static AOM_INLINE void palette_rd_y(
563     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
564     BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *centroids, int n,
565     uint16_t *color_cache, int n_cache, MB_MODE_INFO *best_mbmi,
566     uint8_t *best_palette_color_map, int64_t *best_rd, int64_t *best_model_rd,
567     int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
568     int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip,
569     uint8_t *tx_type_map, int *beat_best_pallette_rd) {
570   optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
571   const int num_unique_colors = av1_remove_duplicates(centroids, n);
572   if (num_unique_colors < PALETTE_MIN_SIZE) {
573     // Too few unique colors to create a palette. And DC_PRED will work
574     // well for that case anyway. So skip.
575     return;
576   }
577   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
578   if (cpi->common.seq_params.use_highbitdepth) {
579     for (int i = 0; i < num_unique_colors; ++i) {
580       pmi->palette_colors[i] = clip_pixel_highbd(
581           (int)centroids[i], cpi->common.seq_params.bit_depth);
582     }
583   } else {
584     for (int i = 0; i < num_unique_colors; ++i) {
585       pmi->palette_colors[i] = clip_pixel(centroids[i]);
586     }
587   }
588   pmi->palette_size[0] = num_unique_colors;
589   MACROBLOCKD *const xd = &x->e_mbd;
590   uint8_t *const color_map = xd->plane[0].color_index_map;
591   int block_width, block_height, rows, cols;
592   av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
593                            &cols);
594   av1_calc_indices(data, centroids, color_map, rows * cols, num_unique_colors,
595                    1);
596   extend_palette_color_map(color_map, cols, rows, block_width, block_height);
597 
598   const int palette_mode_cost =
599       intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost);
600   if (model_intra_yrd_and_prune(cpi, x, bsize, palette_mode_cost,
601                                 best_model_rd)) {
602     return;
603   }
604 
605   RD_STATS tokenonly_rd_stats;
606   av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
607                                     *best_rd);
608   if (tokenonly_rd_stats.rate == INT_MAX) return;
609   int this_rate = tokenonly_rd_stats.rate + palette_mode_cost;
610   int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
611   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
612     tokenonly_rd_stats.rate -= tx_size_cost(x, bsize, mbmi->tx_size);
613   }
614   // Collect mode stats for multiwinner mode processing
615   const int txfm_search_done = 1;
616   store_winner_mode_stats(
617       &cpi->common, x, mbmi, NULL, NULL, NULL, THR_DC, color_map, bsize,
618       this_rd, cpi->sf.winner_mode_sf.enable_multiwinner_mode_process,
619       txfm_search_done);
620   if (this_rd < *best_rd) {
621     *best_rd = this_rd;
622     // Setting beat_best_rd flag because current mode rd is better than best_rd.
623     // This flag need to be updated only for palette evaluation in key frames
624     if (beat_best_rd) *beat_best_rd = 1;
625     memcpy(best_palette_color_map, color_map,
626            block_width * block_height * sizeof(color_map[0]));
627     *best_mbmi = *mbmi;
628     memcpy(blk_skip, x->blk_skip, sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
629     av1_copy_array(tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
630     if (rate) *rate = this_rate;
631     if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
632     if (distortion) *distortion = tokenonly_rd_stats.dist;
633     if (skippable) *skippable = tokenonly_rd_stats.skip;
634     if (beat_best_pallette_rd) *beat_best_pallette_rd = 1;
635   }
636 }
637 
perform_top_color_coarse_palette_search(const AV1_COMP * const cpi,MACROBLOCK * x,MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int dc_mode_cost,const int * data,const int * const top_colors,int start_n,int end_n,int step_size,uint16_t * color_cache,int n_cache,MB_MODE_INFO * best_mbmi,uint8_t * best_palette_color_map,int64_t * best_rd,int64_t * best_model_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,int * beat_best_rd,PICK_MODE_CONTEXT * ctx,uint8_t * best_blk_skip,uint8_t * tx_type_map)638 static AOM_INLINE int perform_top_color_coarse_palette_search(
639     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
640     BLOCK_SIZE bsize, int dc_mode_cost, const int *data,
641     const int *const top_colors, int start_n, int end_n, int step_size,
642     uint16_t *color_cache, int n_cache, MB_MODE_INFO *best_mbmi,
643     uint8_t *best_palette_color_map, int64_t *best_rd, int64_t *best_model_rd,
644     int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
645     int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
646     uint8_t *tx_type_map) {
647   int centroids[PALETTE_MAX_SIZE];
648   int n = start_n;
649   int top_color_winner = end_n + 1;
650   while (1) {
651     int beat_best_pallette_rd = 0;
652     for (int i = 0; i < n; ++i) centroids[i] = top_colors[i];
653     palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
654                  color_cache, n_cache, best_mbmi, best_palette_color_map,
655                  best_rd, best_model_rd, rate, rate_tokenonly, distortion,
656                  skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
657                  &beat_best_pallette_rd);
658     // Break if current palette colors is not winning
659     if (beat_best_pallette_rd) top_color_winner = n;
660     n += step_size;
661     if (n > end_n) break;
662   }
663   return top_color_winner;
664 }
665 
perform_k_means_coarse_palette_search(const AV1_COMP * const cpi,MACROBLOCK * x,MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int dc_mode_cost,const int * data,int lb,int ub,int start_n,int end_n,int step_size,uint16_t * color_cache,int n_cache,MB_MODE_INFO * best_mbmi,uint8_t * best_palette_color_map,int64_t * best_rd,int64_t * best_model_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,int * beat_best_rd,PICK_MODE_CONTEXT * ctx,uint8_t * best_blk_skip,uint8_t * tx_type_map,uint8_t * color_map,int data_points)666 static AOM_INLINE int perform_k_means_coarse_palette_search(
667     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
668     BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int lb, int ub,
669     int start_n, int end_n, int step_size, uint16_t *color_cache, int n_cache,
670     MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
671     int64_t *best_model_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
672     int *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
673     uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
674     int data_points) {
675   int centroids[PALETTE_MAX_SIZE];
676   const int max_itr = 50;
677   int n = start_n;
678   int k_means_winner = end_n + 1;
679   while (1) {
680     int beat_best_pallette_rd = 0;
681     for (int i = 0; i < n; ++i) {
682       centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
683     }
684     av1_k_means(data, centroids, color_map, data_points, n, 1, max_itr);
685     palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
686                  color_cache, n_cache, best_mbmi, best_palette_color_map,
687                  best_rd, best_model_rd, rate, rate_tokenonly, distortion,
688                  skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
689                  &beat_best_pallette_rd);
690     // Break if current palette colors is not winning
691     if (beat_best_pallette_rd) k_means_winner = n;
692     n += step_size;
693     if (n > end_n) break;
694   }
695   return k_means_winner;
696 }
697 
698 // Perform palette search for top colors from minimum palette colors (/maximum)
699 // with a step-size of 1 (/-1)
perform_top_color_palette_search(const AV1_COMP * const cpi,MACROBLOCK * x,MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int dc_mode_cost,const int * data,int * top_colors,int start_n,int end_n,int step_size,uint16_t * color_cache,int n_cache,MB_MODE_INFO * best_mbmi,uint8_t * best_palette_color_map,int64_t * best_rd,int64_t * best_model_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,int * beat_best_rd,PICK_MODE_CONTEXT * ctx,uint8_t * best_blk_skip,uint8_t * tx_type_map)700 static AOM_INLINE int perform_top_color_palette_search(
701     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
702     BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int *top_colors,
703     int start_n, int end_n, int step_size, uint16_t *color_cache, int n_cache,
704     MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
705     int64_t *best_model_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
706     int *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
707     uint8_t *best_blk_skip, uint8_t *tx_type_map) {
708   int centroids[PALETTE_MAX_SIZE];
709   int n = start_n;
710   assert((step_size == -1) || (step_size == 1) || (step_size == 0) ||
711          (step_size == 2));
712   assert(IMPLIES(step_size == -1, start_n > end_n));
713   assert(IMPLIES(step_size == 1, start_n < end_n));
714   while (1) {
715     int beat_best_pallette_rd = 0;
716     for (int i = 0; i < n; ++i) centroids[i] = top_colors[i];
717     palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
718                  color_cache, n_cache, best_mbmi, best_palette_color_map,
719                  best_rd, best_model_rd, rate, rate_tokenonly, distortion,
720                  skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
721                  &beat_best_pallette_rd);
722     // Break if current palette colors is not winning
723     if ((cpi->sf.intra_sf.prune_palette_search_level == 2) &&
724         !beat_best_pallette_rd)
725       return n;
726     n += step_size;
727     if (n == end_n) break;
728   }
729   return n;
730 }
731 // Perform k-means based palette search from minimum palette colors (/maximum)
732 // with a step-size of 1 (/-1)
perform_k_means_palette_search(const AV1_COMP * const cpi,MACROBLOCK * x,MB_MODE_INFO * mbmi,BLOCK_SIZE bsize,int dc_mode_cost,const int * data,int lb,int ub,int start_n,int end_n,int step_size,uint16_t * color_cache,int n_cache,MB_MODE_INFO * best_mbmi,uint8_t * best_palette_color_map,int64_t * best_rd,int64_t * best_model_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,int * beat_best_rd,PICK_MODE_CONTEXT * ctx,uint8_t * best_blk_skip,uint8_t * tx_type_map,uint8_t * color_map,int data_points)733 static AOM_INLINE int perform_k_means_palette_search(
734     const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
735     BLOCK_SIZE bsize, int dc_mode_cost, const int *data, int lb, int ub,
736     int start_n, int end_n, int step_size, uint16_t *color_cache, int n_cache,
737     MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
738     int64_t *best_model_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
739     int *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
740     uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
741     int data_points) {
742   int centroids[PALETTE_MAX_SIZE];
743   const int max_itr = 50;
744   int n = start_n;
745   assert((step_size == -1) || (step_size == 1) || (step_size == 0) ||
746          (step_size == 2));
747   assert(IMPLIES(step_size == -1, start_n > end_n));
748   assert(IMPLIES(step_size == 1, start_n < end_n));
749   while (1) {
750     int beat_best_pallette_rd = 0;
751     for (int i = 0; i < n; ++i) {
752       centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
753     }
754     av1_k_means(data, centroids, color_map, data_points, n, 1, max_itr);
755     palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
756                  color_cache, n_cache, best_mbmi, best_palette_color_map,
757                  best_rd, best_model_rd, rate, rate_tokenonly, distortion,
758                  skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
759                  &beat_best_pallette_rd);
760     // Break if current palette colors is not winning
761     if ((cpi->sf.intra_sf.prune_palette_search_level == 2) &&
762         !beat_best_pallette_rd)
763       return n;
764     n += step_size;
765     if (n == end_n) break;
766   }
767   return n;
768 }
769 
770 #define START_N_STAGE2(x)                         \
771   ((x == PALETTE_MIN_SIZE) ? PALETTE_MIN_SIZE + 1 \
772                            : AOMMAX(x - 1, PALETTE_MIN_SIZE));
773 #define END_N_STAGE2(x, end_n) \
774   ((x == end_n) ? x - 1 : AOMMIN(x + 1, PALETTE_MAX_SIZE));
775 
update_start_end_stage_2(int * start_n_stage2,int * end_n_stage2,int * step_size_stage2,int winner,int end_n)776 static AOM_INLINE void update_start_end_stage_2(int *start_n_stage2,
777                                                 int *end_n_stage2,
778                                                 int *step_size_stage2,
779                                                 int winner, int end_n) {
780   *start_n_stage2 = START_N_STAGE2(winner);
781   *end_n_stage2 = END_N_STAGE2(winner, end_n);
782   *step_size_stage2 = *end_n_stage2 - *start_n_stage2;
783 }
784 
// The start index and step size below are chosen so that, once the coarse
// search has found a winner, the neighbor search only evaluates candidates
// the coarse search has not already covered. Examples:
// 1) 8 colors (end_n = 8): candidates are 2,3,4,5,6,7,8. start_n is 2 and the
// step size is 3, so the coarse search evaluates 2, 5 and 8. If the winner is
// 5, its neighbors 4 and 6 are evaluated; if it is 2, only 3 is evaluated; if
// it is 8, only 7 is evaluated.
// 2) 7 colors (end_n = 7): candidates are 2,3,4,5,6,7. If start_n were 2 (as
// for 8 colors), the step size would also have to be 2 to cover all
// candidates: the coarse search would evaluate 2, 4 and 6, and 3 would be a
// neighbor of both 2 and 4. Instead, with start_n = 3 and step_size = 3, the
// coarse search evaluates 3 and 6, and each winner has its own neighbors
// (3: 2,4 or 6: 5,7).
798 
799 // start index for coarse palette search for dominant colors and k-means
800 static const uint8_t start_n_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
801                                                                     3, 3, 2,
802                                                                     3, 3, 2 };
803 // step size for coarse palette search for dominant colors and k-means
804 static const uint8_t step_size_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
805                                                                       3, 3, 3,
806                                                                       3, 3, 3 };
807 
// Searches luma palette modes for the current intra block.
//
// The distinct source colors are counted first. When there are between 2 and
// 64 of them, the function does the following:
// - Copies the source pixels into x->palette_buffer->kmeans_data_buf and
//   records their minimum and maximum.
// - Picks up to PALETTE_MAX_SIZE most frequent colors.
// - Evaluates, via palette_rd_y(), palettes built from those dominant colors
//   and from k-means clustering.
//
// The palette sizes tried depend on the search mode:
// - prune_palette_search_level == 1 with more than PALETTE_MIN_SIZE colors:
//   a coarse search over a subset of palette sizes, followed by a search of
//   the coarse winner's neighbors. This is done for each palette source.
// - Otherwise: a search from the largest palette size downwards. If it
//   stopped above PALETTE_MIN_SIZE, an upward search covers the sizes it
//   skipped. With exactly two colors, the extremes are used directly as the
//   k-means palette.
//
// Finally, if the best mode uses a luma palette, its color map is copied into
// the block's color index map, and *best_mbmi is copied into xd->mi[0].
static void rd_pick_palette_intra_sby(
    const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
    int dc_mode_cost, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
    int64_t *best_rd, int64_t *best_model_rd, int *rate, int *rate_tokenonly,
    int64_t *distortion, int *skippable, int *beat_best_rd,
    PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip, uint8_t *tx_type_map) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(!is_inter_block(mbmi));
  assert(av1_allow_palette(cpi->common.features.allow_screen_content_tools,
                           bsize));

  const int stride = x->plane[0].src.stride;
  const uint8_t *const src_buf = x->plane[0].src.buf;
  int blk_w, blk_h, num_rows, num_cols;
  av1_get_block_dimensions(bsize, 0, xd, &blk_w, &blk_h, &num_rows,
                           &num_cols);
  const SequenceHeader *const seq_params = &cpi->common.seq_params;
  const int use_hbd = seq_params->use_highbitdepth;
  const int bd = seq_params->bit_depth;

  int histogram[1 << 12];  // Maximum (1 << 12) color levels.
  const int num_colors =
      use_hbd ? av1_count_colors_highbd(src_buf, stride, num_rows, num_cols,
                                        bd, histogram)
              : av1_count_colors(src_buf, stride, num_rows, num_cols,
                                 histogram);

  uint8_t *const color_map = xd->plane[0].color_index_map;
  if (num_colors > 1 && num_colors <= 64) {
    int *const data = x->palette_buffer->kmeans_data_buf;
    const int num_points = num_rows * num_cols;

    // Copy the source block into data[] (row-major, stride removed) while
    // tracking the smallest (lo) and largest (hi) sample value.
    int lo, hi;
    if (use_hbd) {
      const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_buf);
      lo = hi = src16[0];
      for (int r = 0; r < num_rows; ++r) {
        const uint16_t *const src_row = src16 + r * stride;
        int *const dst_row = data + r * num_cols;
        for (int c = 0; c < num_cols; ++c) {
          const int val = src_row[c];
          dst_row[c] = val;
          lo = AOMMIN(lo, val);
          hi = AOMMAX(hi, val);
        }
      }
    } else {
      lo = hi = src_buf[0];
      for (int r = 0; r < num_rows; ++r) {
        const uint8_t *const src_row = src_buf + r * stride;
        int *const dst_row = data + r * num_cols;
        for (int c = 0; c < num_cols; ++c) {
          const int val = src_row[c];
          dst_row[c] = val;
          lo = AOMMIN(lo, val);
          hi = AOMMAX(hi, val);
        }
      }
    }

    mbmi->mode = DC_PRED;
    mbmi->filter_intra_mode_info.use_filter_intra = 0;

    uint16_t color_cache[2 * PALETTE_MAX_SIZE];
    const int n_cache = av1_get_palette_cache(xd, 0, color_cache);

    // Largest palette size worth trying for this block.
    const int max_n = AOMMIN(num_colors, PALETTE_MAX_SIZE);

    // Pick the max_n most frequent colors in decreasing order of frequency
    // (ties go to the smaller value). Each pick is cleared from the histogram
    // so it cannot be chosen twice.
    int top_colors[PALETTE_MAX_SIZE] = { 0 };
    for (int k = 0; k < max_n; ++k) {
      int best_count = 0;
      for (int level = 0; level < (1 << bd); ++level) {
        if (histogram[level] > best_count) {
          best_count = histogram[level];
          top_colors[k] = level;
        }
      }
      assert(best_count > 0);
      histogram[top_colors[k]] = 0;
    }

    // TODO(huisu@google.com): Try to avoid duplicate computation in cases
    // where the dominant colors and the k-means results are similar.
    const int coarse_to_fine =
        cpi->sf.intra_sf.prune_palette_search_level == 1 &&
        num_colors > PALETTE_MIN_SIZE;
    if (coarse_to_fine) {
      assert(PALETTE_MAX_SIZE == 8);
      assert(PALETTE_MIN_SIZE == 2);
      // The coarse grid depends only on the largest palette size.
      const int coarse_start = start_n_lookup_table[max_n];
      const int coarse_step = step_size_lookup_table[max_n];

      // Dominant colors: coarse search, then the winner's neighbors.
      const int top_winner = perform_top_color_coarse_palette_search(
          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, coarse_start,
          max_n, coarse_step, color_cache, n_cache, best_mbmi,
          best_palette_color_map, best_rd, best_model_rd, rate, rate_tokenonly,
          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
      if (top_winner <= max_n) {
        int fine_start, fine_end, fine_step;
        update_start_end_stage_2(&fine_start, &fine_end, &fine_step,
                                 top_winner, max_n);
        perform_top_color_palette_search(
            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, fine_start,
            fine_end + fine_step, fine_step, color_cache, n_cache, best_mbmi,
            best_palette_color_map, best_rd, best_model_rd, rate,
            rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
            best_blk_skip, tx_type_map);
      }

      // K-means clustering: coarse search, then the winner's neighbors.
      const int km_winner = perform_k_means_coarse_palette_search(
          cpi, x, mbmi, bsize, dc_mode_cost, data, lo, hi, coarse_start, max_n,
          coarse_step, color_cache, n_cache, best_mbmi, best_palette_color_map,
          best_rd, best_model_rd, rate, rate_tokenonly, distortion, skippable,
          beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
          num_points);
      if (km_winner <= max_n) {
        int fine_start, fine_end, fine_step;
        update_start_end_stage_2(&fine_start, &fine_end, &fine_step,
                                 km_winner, max_n);
        perform_k_means_palette_search(
            cpi, x, mbmi, bsize, dc_mode_cost, data, lo, hi, fine_start,
            fine_end + fine_step, fine_step, color_cache, n_cache, best_mbmi,
            best_palette_color_map, best_rd, best_model_rd, rate,
            rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
            best_blk_skip, tx_type_map, color_map, num_points);
      }
      // Coarse-to-fine path done; the common epilogue follows.
    } else {
      // Dominant colors: from max_n down to PALETTE_MIN_SIZE. If the search
      // stopped early above PALETTE_MIN_SIZE, cover the skipped sizes from
      // PALETTE_MIN_SIZE upwards.
      const int top_stop = perform_top_color_palette_search(
          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, max_n,
          PALETTE_MIN_SIZE - 1, -1, color_cache, n_cache, best_mbmi,
          best_palette_color_map, best_rd, best_model_rd, rate, rate_tokenonly,
          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map);
      if (top_stop > PALETTE_MIN_SIZE) {
        perform_top_color_palette_search(
            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors,
            PALETTE_MIN_SIZE, top_stop, 1, color_cache, n_cache, best_mbmi,
            best_palette_color_map, best_rd, best_model_rd, rate,
            rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
            best_blk_skip, tx_type_map);
      }

      // K-means clustering.
      if (num_colors == PALETTE_MIN_SIZE) {
        // Special case: the two colors present are the centroids themselves.
        assert(num_colors == 2);
        int centroids[PALETTE_MAX_SIZE];
        centroids[0] = lo;
        centroids[1] = hi;
        palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids,
                     num_colors, color_cache, n_cache, best_mbmi,
                     best_palette_color_map, best_rd, best_model_rd, rate,
                     rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
                     best_blk_skip, tx_type_map, NULL);
      } else {
        const int km_stop = perform_k_means_palette_search(
            cpi, x, mbmi, bsize, dc_mode_cost, data, lo, hi, max_n,
            PALETTE_MIN_SIZE - 1, -1, color_cache, n_cache, best_mbmi,
            best_palette_color_map, best_rd, best_model_rd, rate,
            rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
            best_blk_skip, tx_type_map, color_map, num_points);
        if (km_stop > PALETTE_MIN_SIZE) {
          perform_k_means_palette_search(
              cpi, x, mbmi, bsize, dc_mode_cost, data, lo, hi,
              PALETTE_MIN_SIZE, km_stop, 1, color_cache, n_cache, best_mbmi,
              best_palette_color_map, best_rd, best_model_rd, rate,
              rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
              best_blk_skip, tx_type_map, color_map, num_points);
        }
      }
    }
  }

  if (best_mbmi->palette_mode_info.palette_size[0] > 0) {
    memcpy(color_map, best_palette_color_map,
           blk_w * blk_h * sizeof(best_palette_color_map[0]));
  }
  *mbmi = *best_mbmi;
}
1005 
rd_pick_palette_intra_sbuv(const AV1_COMP * const cpi,MACROBLOCK * x,int dc_mode_cost,uint8_t * best_palette_color_map,MB_MODE_INFO * const best_mbmi,int64_t * best_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable)1006 static AOM_INLINE void rd_pick_palette_intra_sbuv(
1007     const AV1_COMP *const cpi, MACROBLOCK *x, int dc_mode_cost,
1008     uint8_t *best_palette_color_map, MB_MODE_INFO *const best_mbmi,
1009     int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
1010     int *skippable) {
1011   MACROBLOCKD *const xd = &x->e_mbd;
1012   MB_MODE_INFO *const mbmi = xd->mi[0];
1013   assert(!is_inter_block(mbmi));
1014   assert(av1_allow_palette(cpi->common.features.allow_screen_content_tools,
1015                            mbmi->sb_type));
1016   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
1017   const BLOCK_SIZE bsize = mbmi->sb_type;
1018   const SequenceHeader *const seq_params = &cpi->common.seq_params;
1019   int this_rate;
1020   int64_t this_rd;
1021   int colors_u, colors_v, colors;
1022   const int src_stride = x->plane[1].src.stride;
1023   const uint8_t *const src_u = x->plane[1].src.buf;
1024   const uint8_t *const src_v = x->plane[2].src.buf;
1025   uint8_t *const color_map = xd->plane[1].color_index_map;
1026   RD_STATS tokenonly_rd_stats;
1027   int plane_block_width, plane_block_height, rows, cols;
1028   av1_get_block_dimensions(bsize, 1, xd, &plane_block_width,
1029                            &plane_block_height, &rows, &cols);
1030 
1031   mbmi->uv_mode = UV_DC_PRED;
1032 
1033   int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
1034   if (seq_params->use_highbitdepth) {
1035     colors_u = av1_count_colors_highbd(src_u, src_stride, rows, cols,
1036                                        seq_params->bit_depth, count_buf);
1037     colors_v = av1_count_colors_highbd(src_v, src_stride, rows, cols,
1038                                        seq_params->bit_depth, count_buf);
1039   } else {
1040     colors_u = av1_count_colors(src_u, src_stride, rows, cols, count_buf);
1041     colors_v = av1_count_colors(src_v, src_stride, rows, cols, count_buf);
1042   }
1043 
1044   uint16_t color_cache[2 * PALETTE_MAX_SIZE];
1045   const int n_cache = av1_get_palette_cache(xd, 1, color_cache);
1046 
1047   colors = colors_u > colors_v ? colors_u : colors_v;
1048   if (colors > 1 && colors <= 64) {
1049     int r, c, n, i, j;
1050     const int max_itr = 50;
1051     int lb_u, ub_u, val_u;
1052     int lb_v, ub_v, val_v;
1053     int *const data = x->palette_buffer->kmeans_data_buf;
1054     int centroids[2 * PALETTE_MAX_SIZE];
1055 
1056     uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
1057     uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v);
1058     if (seq_params->use_highbitdepth) {
1059       lb_u = src_u16[0];
1060       ub_u = src_u16[0];
1061       lb_v = src_v16[0];
1062       ub_v = src_v16[0];
1063     } else {
1064       lb_u = src_u[0];
1065       ub_u = src_u[0];
1066       lb_v = src_v[0];
1067       ub_v = src_v[0];
1068     }
1069 
1070     for (r = 0; r < rows; ++r) {
1071       for (c = 0; c < cols; ++c) {
1072         if (seq_params->use_highbitdepth) {
1073           val_u = src_u16[r * src_stride + c];
1074           val_v = src_v16[r * src_stride + c];
1075           data[(r * cols + c) * 2] = val_u;
1076           data[(r * cols + c) * 2 + 1] = val_v;
1077         } else {
1078           val_u = src_u[r * src_stride + c];
1079           val_v = src_v[r * src_stride + c];
1080           data[(r * cols + c) * 2] = val_u;
1081           data[(r * cols + c) * 2 + 1] = val_v;
1082         }
1083         if (val_u < lb_u)
1084           lb_u = val_u;
1085         else if (val_u > ub_u)
1086           ub_u = val_u;
1087         if (val_v < lb_v)
1088           lb_v = val_v;
1089         else if (val_v > ub_v)
1090           ub_v = val_v;
1091       }
1092     }
1093 
1094     for (n = colors > PALETTE_MAX_SIZE ? PALETTE_MAX_SIZE : colors; n >= 2;
1095          --n) {
1096       for (i = 0; i < n; ++i) {
1097         centroids[i * 2] = lb_u + (2 * i + 1) * (ub_u - lb_u) / n / 2;
1098         centroids[i * 2 + 1] = lb_v + (2 * i + 1) * (ub_v - lb_v) / n / 2;
1099       }
1100       av1_k_means(data, centroids, color_map, rows * cols, n, 2, max_itr);
1101       optimize_palette_colors(color_cache, n_cache, n, 2, centroids);
1102       // Sort the U channel colors in ascending order.
1103       for (i = 0; i < 2 * (n - 1); i += 2) {
1104         int min_idx = i;
1105         int min_val = centroids[i];
1106         for (j = i + 2; j < 2 * n; j += 2)
1107           if (centroids[j] < min_val) min_val = centroids[j], min_idx = j;
1108         if (min_idx != i) {
1109           int temp_u = centroids[i], temp_v = centroids[i + 1];
1110           centroids[i] = centroids[min_idx];
1111           centroids[i + 1] = centroids[min_idx + 1];
1112           centroids[min_idx] = temp_u, centroids[min_idx + 1] = temp_v;
1113         }
1114       }
1115       av1_calc_indices(data, centroids, color_map, rows * cols, n, 2);
1116       extend_palette_color_map(color_map, cols, rows, plane_block_width,
1117                                plane_block_height);
1118       pmi->palette_size[1] = n;
1119       for (i = 1; i < 3; ++i) {
1120         for (j = 0; j < n; ++j) {
1121           if (seq_params->use_highbitdepth)
1122             pmi->palette_colors[i * PALETTE_MAX_SIZE + j] = clip_pixel_highbd(
1123                 (int)centroids[j * 2 + i - 1], seq_params->bit_depth);
1124           else
1125             pmi->palette_colors[i * PALETTE_MAX_SIZE + j] =
1126                 clip_pixel((int)centroids[j * 2 + i - 1]);
1127         }
1128       }
1129 
1130       av1_txfm_uvrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
1131       if (tokenonly_rd_stats.rate == INT_MAX) continue;
1132       this_rate = tokenonly_rd_stats.rate +
1133                   intra_mode_info_cost_uv(cpi, x, mbmi, bsize, dc_mode_cost);
1134       this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
1135       if (this_rd < *best_rd) {
1136         *best_rd = this_rd;
1137         *best_mbmi = *mbmi;
1138         memcpy(best_palette_color_map, color_map,
1139                plane_block_width * plane_block_height *
1140                    sizeof(best_palette_color_map[0]));
1141         *rate = this_rate;
1142         *distortion = tokenonly_rd_stats.dist;
1143         *rate_tokenonly = tokenonly_rd_stats.rate;
1144         *skippable = tokenonly_rd_stats.skip;
1145       }
1146     }
1147   }
1148   if (best_mbmi->palette_mode_info.palette_size[1] > 0) {
1149     memcpy(color_map, best_palette_color_map,
1150            plane_block_width * plane_block_height *
1151                sizeof(best_palette_color_map[0]));
1152   }
1153 }
1154 
// Rebuilds the chroma palette color index map (xd->plane[1].color_index_map)
// from the palette already stored in xd->mi[0]->palette_mode_info.
//
// The interleaved (U, V) source samples are copied into
// x->palette_buffer->kmeans_data_buf. Each sample is then assigned to a
// palette entry with av1_calc_indices(). Finally, the map is extended to the
// full plane block size with extend_palette_color_map().
void av1_restore_uv_color_map(const AV1_COMP *const cpi, MACROBLOCK *x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  const int use_hbd = cpi->common.seq_params.use_highbitdepth;
  // U and V are read with the U plane's stride.
  const int stride = x->plane[1].src.stride;
  const uint8_t *const src_u = x->plane[1].src.buf;
  const uint8_t *const src_v = x->plane[2].src.buf;
  const uint16_t *const u16 = CONVERT_TO_SHORTPTR(src_u);
  const uint16_t *const v16 = CONVERT_TO_SHORTPTR(src_v);
  int *const data = x->palette_buffer->kmeans_data_buf;
  uint8_t *const color_map = xd->plane[1].color_index_map;
  int blk_w, blk_h, num_rows, num_cols;
  av1_get_block_dimensions(mbmi->sb_type, 1, xd, &blk_w, &blk_h, &num_rows,
                           &num_cols);

  // Write (U, V) pairs in row-major order, stride removed.
  int *dst = data;
  for (int r = 0; r < num_rows; ++r) {
    const int row_offset = r * stride;
    for (int c = 0; c < num_cols; ++c) {
      const int offset = row_offset + c;
      *dst++ = use_hbd ? u16[offset] : src_u[offset];
      *dst++ = use_hbd ? v16[offset] : src_v[offset];
    }
  }

  // Rebuild interleaved (U, V) centroids from palette slots 1 (U) and 2 (V).
  const int n = pmi->palette_size[1];
  int centroids[2 * PALETTE_MAX_SIZE];
  for (int k = 0; k < n; ++k) {
    centroids[2 * k] = pmi->palette_colors[PALETTE_MAX_SIZE + k];
    centroids[2 * k + 1] = pmi->palette_colors[2 * PALETTE_MAX_SIZE + k];
  }

  av1_calc_indices(data, centroids, color_map, num_rows * num_cols, n, 2);
  extend_palette_color_map(color_map, num_cols, num_rows, blk_w, blk_h);
}
1196 
choose_intra_uv_mode(const AV1_COMP * const cpi,MACROBLOCK * const x,BLOCK_SIZE bsize,TX_SIZE max_tx_size,int * rate_uv,int * rate_uv_tokenonly,int64_t * dist_uv,int * skip_uv,UV_PREDICTION_MODE * mode_uv)1197 static AOM_INLINE void choose_intra_uv_mode(
1198     const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
1199     TX_SIZE max_tx_size, int *rate_uv, int *rate_uv_tokenonly, int64_t *dist_uv,
1200     int *skip_uv, UV_PREDICTION_MODE *mode_uv) {
1201   const AV1_COMMON *const cm = &cpi->common;
1202   MACROBLOCKD *xd = &x->e_mbd;
1203   MB_MODE_INFO *mbmi = xd->mi[0];
1204   // Use an estimated rd for uv_intra based on DC_PRED if the
1205   // appropriate speed flag is set.
1206   init_sbuv_mode(mbmi);
1207   if (!xd->is_chroma_ref) {
1208     *rate_uv = 0;
1209     *rate_uv_tokenonly = 0;
1210     *dist_uv = 0;
1211     *skip_uv = 1;
1212     *mode_uv = UV_DC_PRED;
1213     return;
1214   }
1215 
1216   // Only store reconstructed luma when there's chroma RDO. When there's no
1217   // chroma RDO, the reconstructed luma will be stored in encode_superblock().
1218   xd->cfl.store_y = store_cfl_required_rdo(cm, x);
1219   if (xd->cfl.store_y) {
1220     // Restore reconstructed luma values.
1221     av1_encode_intra_block_plane(cpi, x, mbmi->sb_type, AOM_PLANE_Y,
1222                                  DRY_RUN_NORMAL,
1223                                  cpi->optimize_seg_arr[mbmi->segment_id]);
1224     xd->cfl.store_y = 0;
1225   }
1226   av1_rd_pick_intra_sbuv_mode(cpi, x, rate_uv, rate_uv_tokenonly, dist_uv,
1227                               skip_uv, bsize, max_tx_size);
1228   *mode_uv = mbmi->uv_mode;
1229 }
1230 
1231 // Run RD calculation with given chroma intra prediction angle., and return
1232 // the RD cost. Update the best mode info. if the RD cost is the best so far.
pick_intra_angle_routine_sbuv(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int rate_overhead,int64_t best_rd_in,int * rate,RD_STATS * rd_stats,int * best_angle_delta,int64_t * best_rd)1233 static int64_t pick_intra_angle_routine_sbuv(
1234     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
1235     int rate_overhead, int64_t best_rd_in, int *rate, RD_STATS *rd_stats,
1236     int *best_angle_delta, int64_t *best_rd) {
1237   MB_MODE_INFO *mbmi = x->e_mbd.mi[0];
1238   assert(!is_inter_block(mbmi));
1239   int this_rate;
1240   int64_t this_rd;
1241   RD_STATS tokenonly_rd_stats;
1242 
1243   if (!av1_txfm_uvrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd_in))
1244     return INT64_MAX;
1245   this_rate = tokenonly_rd_stats.rate +
1246               intra_mode_info_cost_uv(cpi, x, mbmi, bsize, rate_overhead);
1247   this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
1248   if (this_rd < *best_rd) {
1249     *best_rd = this_rd;
1250     *best_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
1251     *rate = this_rate;
1252     rd_stats->rate = tokenonly_rd_stats.rate;
1253     rd_stats->dist = tokenonly_rd_stats.dist;
1254     rd_stats->skip = tokenonly_rd_stats.skip;
1255   }
1256   return this_rd;
1257 }
1258 
1259 // With given chroma directional intra prediction mode, pick the best angle
1260 // delta. Return true if a RD cost that is smaller than the input one is found.
rd_pick_intra_angle_sbuv(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int rate_overhead,int64_t best_rd,int * rate,RD_STATS * rd_stats)1261 static int rd_pick_intra_angle_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x,
1262                                     BLOCK_SIZE bsize, int rate_overhead,
1263                                     int64_t best_rd, int *rate,
1264                                     RD_STATS *rd_stats) {
1265   MACROBLOCKD *const xd = &x->e_mbd;
1266   MB_MODE_INFO *mbmi = xd->mi[0];
1267   assert(!is_inter_block(mbmi));
1268   int i, angle_delta, best_angle_delta = 0;
1269   int64_t this_rd, best_rd_in, rd_cost[2 * (MAX_ANGLE_DELTA + 2)];
1270 
1271   rd_stats->rate = INT_MAX;
1272   rd_stats->skip = 0;
1273   rd_stats->dist = INT64_MAX;
1274   for (i = 0; i < 2 * (MAX_ANGLE_DELTA + 2); ++i) rd_cost[i] = INT64_MAX;
1275 
1276   for (angle_delta = 0; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
1277     for (i = 0; i < 2; ++i) {
1278       best_rd_in = (best_rd == INT64_MAX)
1279                        ? INT64_MAX
1280                        : (best_rd + (best_rd >> ((angle_delta == 0) ? 3 : 5)));
1281       mbmi->angle_delta[PLANE_TYPE_UV] = (1 - 2 * i) * angle_delta;
1282       this_rd = pick_intra_angle_routine_sbuv(cpi, x, bsize, rate_overhead,
1283                                               best_rd_in, rate, rd_stats,
1284                                               &best_angle_delta, &best_rd);
1285       rd_cost[2 * angle_delta + i] = this_rd;
1286       if (angle_delta == 0) {
1287         if (this_rd == INT64_MAX) return 0;
1288         rd_cost[1] = this_rd;
1289         break;
1290       }
1291     }
1292   }
1293 
1294   assert(best_rd != INT64_MAX);
1295   for (angle_delta = 1; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
1296     int64_t rd_thresh;
1297     for (i = 0; i < 2; ++i) {
1298       int skip_search = 0;
1299       rd_thresh = best_rd + (best_rd >> 5);
1300       if (rd_cost[2 * (angle_delta + 1) + i] > rd_thresh &&
1301           rd_cost[2 * (angle_delta - 1) + i] > rd_thresh)
1302         skip_search = 1;
1303       if (!skip_search) {
1304         mbmi->angle_delta[PLANE_TYPE_UV] = (1 - 2 * i) * angle_delta;
1305         pick_intra_angle_routine_sbuv(cpi, x, bsize, rate_overhead, best_rd,
1306                                       rate, rd_stats, &best_angle_delta,
1307                                       &best_rd);
1308       }
1309     }
1310   }
1311 
1312   mbmi->angle_delta[PLANE_TYPE_UV] = best_angle_delta;
1313   return rd_stats->rate != INT_MAX;
1314 }
1315 
// Builds the joint CfL sign index (the value stored in mbmi->cfl_alpha_signs)
// from the sign `a` of the plane being searched and the sign `b` of the other
// plane: for plane == CFL_PRED_U, `a` is the U sign and `b` the V sign; for the
// V plane the roles are swapped. Every parameter is parenthesized so that
// expression arguments expand with the intended precedence.
#define PLANE_SIGN_TO_JOINT_SIGN(plane, a, b)   \
  ((plane) == CFL_PRED_U ? (a) * CFL_SIGNS + (b) - 1 \
                         : (b) * CFL_SIGNS + (a) - 1)
// Searches CfL alpha parameters for the current block and stores the winner in
// mbmi->cfl_alpha_idx / mbmi->cfl_alpha_signs.
//
// Each chroma plane is searched independently. For every non-zero sign of the
// plane, alpha magnitudes c = 0, 1, ... are tried. For each joint sign, the best
// magnitude and RD cost of that plane are remembered in best_c / best_rd_uv.
// A (U, V) combination is accepted only when, for the same joint sign, both
// planes have a recorded cost and their sum plus the UV_CFL_PRED mode cost
// (mode_rd) is below best_rd.
//
// Parameters:
//   x       - encoder block state. xd->cfl.use_dc_pred_cache is set to 1 for the
//             search, and the cache flags are cleared before returning.
//   cpi     - encoder instance (speed features, enable flags).
//   tx_size - chroma transform size. Outside lossless mode it must match the
//             chroma plane block size (asserted).
//   best_rd - RD cost to beat.
//
// Returns the alpha signaling rate (cfl_cost for U plus V) of the chosen
// combination. Returns INT_MAX if no combination beat best_rd; in that case
// cfl_alpha_idx and cfl_alpha_signs are both reset to 0.
static int cfl_rd_pick_alpha(MACROBLOCK *const x, const AV1_COMP *const cpi,
                             TX_SIZE tx_size, int64_t best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const MACROBLOCKD_PLANE *pd = &xd->plane[AOM_PLANE_U];
  const BLOCK_SIZE plane_bsize =
      get_plane_block_size(mbmi->sb_type, pd->subsampling_x, pd->subsampling_y);

  assert(is_cfl_allowed(xd) && cpi->oxcf.enable_cfl_intra);
  assert(plane_bsize < BLOCK_SIZES_ALL);
  if (!xd->lossless[mbmi->segment_id]) {
    assert(block_size_wide[plane_bsize] == tx_size_wide[tx_size]);
    assert(block_size_high[plane_bsize] == tx_size_high[tx_size]);
  }

  xd->cfl.use_dc_pred_cache = 1;
  // RD cost of signaling UV_CFL_PRED itself (rate only, zero distortion).
  const int64_t mode_rd =
      RDCOST(x->rdmult,
             x->intra_uv_mode_cost[CFL_ALLOWED][mbmi->mode][UV_CFL_PRED], 0);
  // Per joint sign and per plane: best RD cost and best alpha magnitude.
  int64_t best_rd_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
  int best_c[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
#if CONFIG_DEBUG
  // Token rate of each best_rd_uv entry; used to fill xd->cfl.rate, which the
  // caller checks in a debug assertion.
  int best_rate_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
#endif  // CONFIG_DEBUG

  const int skip_trellis = 0;
  // Step 1: seed every joint sign with the cost of alpha == 0 on each plane.
  for (int plane = 0; plane < CFL_PRED_PLANES; plane++) {
    RD_STATS rd_stats;
    av1_init_rd_stats(&rd_stats);
    for (int joint_sign = 0; joint_sign < CFL_JOINT_SIGNS; joint_sign++) {
      best_rd_uv[joint_sign][plane] = INT64_MAX;
      best_c[joint_sign][plane] = 0;
    }
    // Collect RD stats for an alpha value of zero in this plane.
    // Skip i == CFL_SIGN_ZERO as (0, 0) is invalid.
    for (int i = CFL_SIGN_NEG; i < CFL_SIGNS; i++) {
      const int8_t joint_sign =
          PLANE_SIGN_TO_JOINT_SIGN(plane, CFL_SIGN_ZERO, i);
      // The transform RD is computed only for i == CFL_SIGN_NEG and reused for
      // CFL_SIGN_POS; only the alpha signaling cost differs between the two.
      // (Presumably this is valid because a zero alpha on this plane makes its
      // prediction independent of the other plane's sign.)
      if (i == CFL_SIGN_NEG) {
        mbmi->cfl_alpha_idx = 0;
        mbmi->cfl_alpha_signs = joint_sign;
        // NOTE(review): plane + 1 presumably maps CFL_PRED_U/V to
        // AOM_PLANE_U/V.
        av1_txfm_rd_in_plane(
            x, cpi, &rd_stats, best_rd, 0, plane + 1, plane_bsize, tx_size,
            cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, skip_trellis);
        if (rd_stats.rate == INT_MAX) break;
      }
      const int alpha_rate = x->cfl_cost[joint_sign][plane][0];
      best_rd_uv[joint_sign][plane] =
          RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
#if CONFIG_DEBUG
      best_rate_uv[joint_sign][plane] = rd_stats.rate;
#endif  // CONFIG_DEBUG
    }
  }

  int8_t best_joint_sign = -1;

  // Step 2: for each plane and each non-zero sign of that plane, try alpha
  // magnitudes c in increasing order.
  for (int plane = 0; plane < CFL_PRED_PLANES; plane++) {
    for (int pn_sign = CFL_SIGN_NEG; pn_sign < CFL_SIGNS; pn_sign++) {
      // Each magnitude that improved some per-plane best adds 2 to progress.
      // After c == 2, the search stops once progress falls below c.
      int progress = 0;
      for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
        int flag = 0;
        RD_STATS rd_stats;
        if (c > 2 && progress < c) break;
        av1_init_rd_stats(&rd_stats);
        for (int i = 0; i < CFL_SIGNS; i++) {
          const int8_t joint_sign = PLANE_SIGN_TO_JOINT_SIGN(plane, pn_sign, i);
          // The transform RD of this plane is computed once per magnitude
          // (i == 0) and reused for every sign of the other plane.
          if (i == 0) {
            // Both halves of the index are set to c; only this plane
            // (plane + 1) is evaluated below.
            mbmi->cfl_alpha_idx = (c << CFL_ALPHABET_SIZE_LOG2) + c;
            mbmi->cfl_alpha_signs = joint_sign;
            av1_txfm_rd_in_plane(
                x, cpi, &rd_stats, best_rd, 0, plane + 1, plane_bsize, tx_size,
                cpi->sf.rd_sf.use_fast_coef_costing, FTXS_NONE, skip_trellis);
            if (rd_stats.rate == INT_MAX) break;
          }
          const int alpha_rate = x->cfl_cost[joint_sign][plane][c];
          int64_t this_rd =
              RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
          if (this_rd >= best_rd_uv[joint_sign][plane]) continue;
          best_rd_uv[joint_sign][plane] = this_rd;
          best_c[joint_sign][plane] = c;
#if CONFIG_DEBUG
          best_rate_uv[joint_sign][plane] = rd_stats.rate;
#endif  // CONFIG_DEBUG
          flag = 2;
          // A full candidate needs a recorded cost for the other plane too.
          if (best_rd_uv[joint_sign][!plane] == INT64_MAX) continue;
          this_rd += mode_rd + best_rd_uv[joint_sign][!plane];
          if (this_rd >= best_rd) continue;
          best_rd = this_rd;
          best_joint_sign = joint_sign;
        }
        progress += flag;
      }
    }
  }

  // Step 3: commit the winner (if any) and compute its alpha signaling rate.
  int best_rate_overhead = INT_MAX;
  uint8_t ind = 0;
  if (best_joint_sign >= 0) {
    const int u = best_c[best_joint_sign][CFL_PRED_U];
    const int v = best_c[best_joint_sign][CFL_PRED_V];
    ind = (u << CFL_ALPHABET_SIZE_LOG2) + v;
    best_rate_overhead = x->cfl_cost[best_joint_sign][CFL_PRED_U][u] +
                         x->cfl_cost[best_joint_sign][CFL_PRED_V][v];
#if CONFIG_DEBUG
    xd->cfl.rate = x->intra_uv_mode_cost[CFL_ALLOWED][mbmi->mode][UV_CFL_PRED] +
                   best_rate_overhead +
                   best_rate_uv[best_joint_sign][CFL_PRED_U] +
                   best_rate_uv[best_joint_sign][CFL_PRED_V];
#endif  // CONFIG_DEBUG
  } else {
    best_joint_sign = 0;
  }

  mbmi->cfl_alpha_idx = ind;
  mbmi->cfl_alpha_signs = best_joint_sign;
  // Leave the DC prediction cache disabled and invalidated for later users.
  xd->cfl.use_dc_pred_cache = 0;
  xd->cfl.dc_pred_is_cached[0] = 0;
  xd->cfl.dc_pred_is_cached[1] = 0;
  return best_rate_overhead;
}
1439 
// Chooses the chroma intra prediction for the current block. Every enabled
// UV mode in uv_rd_search_mode_order is tried (CfL with an alpha search,
// directional modes with an angle-delta search when allowed), followed by
// chroma palette when palette is allowed. On return *mbmi holds the winner's
// mode info, and *rate, *rate_tokenonly, *distortion and *skippable describe
// it. Returns the winning RD cost.
int64_t av1_rd_pick_intra_sbuv_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
                                    int *rate, int *rate_tokenonly,
                                    int64_t *distortion, int *skippable,
                                    BLOCK_SIZE bsize, TX_SIZE max_tx_size) {
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  assert(!is_inter_block(mbmi));
  MB_MODE_INFO best_mbmi = *mbmi;
  int64_t best_rd = INT64_MAX;

  for (int idx = 0; idx < UV_INTRA_MODES; ++idx) {
    const UV_PREDICTION_MODE uv_mode = uv_rd_search_mode_order[idx];

    // Respect the speed-feature mode mask and the encoder enable flags.
    const int allowed_by_mask =
        cpi->sf.intra_sf.intra_uv_mode_mask[txsize_sqr_up_map[max_tx_size]] &
        (1 << uv_mode);
    if (!allowed_by_mask) continue;
    const int is_smooth =
        uv_mode >= UV_SMOOTH_PRED && uv_mode <= UV_SMOOTH_H_PRED;
    if (is_smooth && !cpi->oxcf.enable_smooth_intra) continue;
    if (uv_mode == UV_PAETH_PRED && !cpi->oxcf.enable_paeth_intra) continue;

    const int is_directional_mode =
        av1_is_directional_mode(get_uv_mode(uv_mode));
    mbmi->uv_mode = uv_mode;

    int cfl_alpha_rate = 0;
    if (uv_mode == UV_CFL_PRED) {
      if (!is_cfl_allowed(xd) || !cpi->oxcf.enable_cfl_intra) continue;
      assert(!is_directional_mode);
      const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
      cfl_alpha_rate = cfl_rd_pick_alpha(x, cpi, uv_tx_size, best_rd);
      if (cfl_alpha_rate == INT_MAX) continue;
    }

    mbmi->angle_delta[PLANE_TYPE_UV] = 0;
    RD_STATS tokenonly_rd_stats;
    int found;
    if (is_directional_mode && av1_use_angle_delta(mbmi->sb_type) &&
        cpi->oxcf.enable_angle_delta) {
      const int rate_overhead =
          x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][uv_mode];
      int angle_search_rate;  // Recomputed below; only the stats are used.
      found = rd_pick_intra_angle_sbuv(cpi, x, bsize, rate_overhead, best_rd,
                                       &angle_search_rate, &tokenonly_rd_stats);
    } else {
      found = av1_txfm_uvrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd);
    }
    if (!found) continue;

    const int mode_cost =
        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][uv_mode] +
        cfl_alpha_rate;
    const int this_rate = tokenonly_rd_stats.rate +
                          intra_mode_info_cost_uv(cpi, x, mbmi, bsize, mode_cost);
    if (uv_mode == UV_CFL_PRED) {
      assert(is_cfl_allowed(xd) && cpi->oxcf.enable_cfl_intra);
#if CONFIG_DEBUG
      if (!xd->lossless[mbmi->segment_id])
        assert(xd->cfl.rate == tokenonly_rd_stats.rate + mode_cost);
#endif  // CONFIG_DEBUG
    }

    const int64_t this_rd =
        RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
    if (this_rd >= best_rd) continue;

    best_mbmi = *mbmi;
    best_rd = this_rd;
    *rate = this_rate;
    *rate_tokenonly = tokenonly_rd_stats.rate;
    *distortion = tokenonly_rd_stats.dist;
    *skippable = tokenonly_rd_stats.skip;
  }

  if (cpi->oxcf.enable_palette &&
      av1_allow_palette(cpi->common.features.allow_screen_content_tools,
                        mbmi->sb_type)) {
    rd_pick_palette_intra_sbuv(
        cpi, x,
        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][UV_DC_PRED],
        x->palette_buffer->best_palette_color_map, &best_mbmi, &best_rd, rate,
        rate_tokenonly, distortion, skippable);
  }

  *mbmi = best_mbmi;
  // Some mode must have been selected.
  assert(best_rd < INT64_MAX);
  return best_rd;
}
1528 
// Evaluates luma palette coding for the block, with DC_PRED as the luma mode.
// If a usable palette is found:
//   - the palette winner's blk_skip, tx_type_map and color map are installed;
//   - the chroma decision cached in intra_search_state is applied, running the
//     chroma search first if it has not been done yet;
//   - this_rd_cost receives the total rate, distortion and RD cost.
// Otherwise this_rd_cost->rdcost is set to INT64_MAX.
// Returns the block's skippable flag: luma skip, ANDed with chroma skip when
// chroma planes exist.
int av1_search_palette_mode(const AV1_COMP *cpi, MACROBLOCK *x,
                            RD_STATS *this_rd_cost, PICK_MODE_CONTEXT *ctx,
                            BLOCK_SIZE bsize, MB_MODE_INFO *const mbmi,
                            PALETTE_MODE_INFO *const pmi,
                            unsigned int *ref_costs_single,
                            IntraModeSearchState *intra_search_state,
                            int64_t best_rd) {
  const AV1_COMMON *const cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *const xd = &x->e_mbd;
  uint8_t *const best_palette_color_map =
      x->palette_buffer->best_palette_color_map;
  uint8_t *const color_map = xd->plane[0].color_index_map;
  // Snapshot taken before the mode fields are overwritten below.
  MB_MODE_INFO best_mbmi_palette = *mbmi;
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  uint8_t best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
  int64_t best_rd_palette = best_rd;
  int64_t best_model_rd_palette = INT64_MAX;
  const int *const intra_mode_cost = x->mbmode_cost[size_group_lookup[bsize]];

  // The palette search runs with an intra DC_PRED luma/chroma mode.
  mbmi->mode = DC_PRED;
  mbmi->uv_mode = UV_DC_PRED;
  mbmi->ref_frame[0] = INTRA_FRAME;
  mbmi->ref_frame[1] = NONE_FRAME;

  RD_STATS rd_stats_y;
  av1_invalid_rd_stats(&rd_stats_y);
  rd_pick_palette_intra_sby(
      cpi, x, bsize, intra_mode_cost[DC_PRED], &best_mbmi_palette,
      best_palette_color_map, &best_rd_palette, &best_model_rd_palette,
      &rd_stats_y.rate, NULL, &rd_stats_y.dist, &rd_stats_y.skip, NULL, ctx,
      best_blk_skip, best_tx_type_map);
  if (rd_stats_y.rate == INT_MAX || pmi->palette_size[0] == 0) {
    this_rd_cost->rdcost = INT64_MAX;
    return 0;
  }

  // Install the palette winner's transform decisions and color indices.
  memcpy(x->blk_skip, best_blk_skip,
         sizeof(best_blk_skip[0]) * bsize_to_num_blk(bsize));
  av1_copy_array(xd->tx_type_map, best_tx_type_map, ctx->num_4x4_blk);
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  memcpy(color_map, best_palette_color_map,
         rows * cols * sizeof(best_palette_color_map[0]));

  int skippable = rd_stats_y.skip;
  int64_t distortion2 = rd_stats_y.dist;
  int rate2 = rd_stats_y.rate + ref_costs_single[INTRA_FRAME];

  if (num_planes > 1) {
    // Chroma search result is computed once and cached in intra_search_state.
    if (intra_search_state->rate_uv_intra == INT_MAX) {
      const TX_SIZE uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
      choose_intra_uv_mode(
          cpi, x, bsize, uv_tx, &intra_search_state->rate_uv_intra,
          &intra_search_state->rate_uv_tokenonly, &intra_search_state->dist_uvs,
          &intra_search_state->skip_uvs, &intra_search_state->mode_uv);
      intra_search_state->pmi_uv = *pmi;
      intra_search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
    }
    mbmi->uv_mode = intra_search_state->mode_uv;
    pmi->palette_size[1] = intra_search_state->pmi_uv.palette_size[1];
    if (pmi->palette_size[1] > 0) {
      memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
             intra_search_state->pmi_uv.palette_colors + PALETTE_MAX_SIZE,
             2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
    }
    mbmi->angle_delta[PLANE_TYPE_UV] = intra_search_state->uv_angle_delta;
    skippable = skippable && intra_search_state->skip_uvs;
    distortion2 += intra_search_state->dist_uvs;
    rate2 += intra_search_state->rate_uv_intra;
  }

  if (skippable) {
    // No coefficients are coded: drop the token-only rates.
    rate2 -= rd_stats_y.rate;
    if (num_planes > 1) rate2 -= intra_search_state->rate_uv_tokenonly;
  }
  rate2 += x->skip_cost[av1_get_skip_context(xd)][skippable ? 1 : 0];

  this_rd_cost->rate = rate2;
  this_rd_cost->dist = distortion2;
  this_rd_cost->rdcost = RDCOST(x->rdmult, rate2, distortion2);
  return skippable;
}
1615 
1616 // Given selected prediction mode, search for the best tx type and size.
intra_block_yrd(const AV1_COMP * const cpi,MACROBLOCK * x,BLOCK_SIZE bsize,const int * bmode_costs,int64_t * best_rd,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,MB_MODE_INFO * best_mbmi,PICK_MODE_CONTEXT * ctx)1617 static AOM_INLINE int intra_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
1618                                       BLOCK_SIZE bsize, const int *bmode_costs,
1619                                       int64_t *best_rd, int *rate,
1620                                       int *rate_tokenonly, int64_t *distortion,
1621                                       int *skippable, MB_MODE_INFO *best_mbmi,
1622                                       PICK_MODE_CONTEXT *ctx) {
1623   MACROBLOCKD *const xd = &x->e_mbd;
1624   MB_MODE_INFO *const mbmi = xd->mi[0];
1625   RD_STATS rd_stats;
1626   // In order to improve txfm search avoid rd based breakouts during winner
1627   // mode evaluation. Hence passing ref_best_rd as a maximum value
1628   av1_pick_uniform_tx_size_type_yrd(cpi, x, &rd_stats, bsize, INT64_MAX);
1629   if (rd_stats.rate == INT_MAX) return 0;
1630   int this_rate_tokenonly = rd_stats.rate;
1631   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
1632     // av1_pick_uniform_tx_size_type_yrd above includes the cost of the tx_size
1633     // in the tokenonly rate, but for intra blocks, tx_size is always coded
1634     // (prediction granularity), so we account for it in the full rate,
1635     // not the tokenonly rate.
1636     this_rate_tokenonly -= tx_size_cost(x, bsize, mbmi->tx_size);
1637   }
1638   const int this_rate =
1639       rd_stats.rate +
1640       intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
1641   const int64_t this_rd = RDCOST(x->rdmult, this_rate, rd_stats.dist);
1642   if (this_rd < *best_rd) {
1643     *best_mbmi = *mbmi;
1644     *best_rd = this_rd;
1645     *rate = this_rate;
1646     *rate_tokenonly = this_rate_tokenonly;
1647     *distortion = rd_stats.dist;
1648     *skippable = rd_stats.skip;
1649     av1_copy_array(ctx->blk_skip, x->blk_skip, ctx->num_4x4_blk);
1650     av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
1651     return 1;
1652   }
1653   return 0;
1654 }
1655 
1656 // With given luma directional intra prediction mode, pick the best angle delta
1657 // Return the RD cost corresponding to the best angle delta.
rd_pick_intra_angle_sby(const AV1_COMP * const cpi,MACROBLOCK * x,int * rate,RD_STATS * rd_stats,BLOCK_SIZE bsize,int mode_cost,int64_t best_rd,int64_t * best_model_rd,int skip_model_rd_for_zero_deg)1658 static int64_t rd_pick_intra_angle_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
1659                                        int *rate, RD_STATS *rd_stats,
1660                                        BLOCK_SIZE bsize, int mode_cost,
1661                                        int64_t best_rd, int64_t *best_model_rd,
1662                                        int skip_model_rd_for_zero_deg) {
1663   MACROBLOCKD *xd = &x->e_mbd;
1664   MB_MODE_INFO *mbmi = xd->mi[0];
1665   assert(!is_inter_block(mbmi));
1666 
1667   int best_angle_delta = 0;
1668   int64_t rd_cost[2 * (MAX_ANGLE_DELTA + 2)];
1669   TX_SIZE best_tx_size = mbmi->tx_size;
1670   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
1671   uint8_t best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
1672 
1673   for (int i = 0; i < 2 * (MAX_ANGLE_DELTA + 2); ++i) rd_cost[i] = INT64_MAX;
1674 
1675   int first_try = 1;
1676   for (int angle_delta = 0; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
1677     for (int i = 0; i < 2; ++i) {
1678       const int64_t best_rd_in =
1679           (best_rd == INT64_MAX) ? INT64_MAX
1680                                  : (best_rd + (best_rd >> (first_try ? 3 : 5)));
1681       const int64_t this_rd = calc_rd_given_intra_angle(
1682           cpi, x, bsize, mode_cost, best_rd_in, (1 - 2 * i) * angle_delta,
1683           MAX_ANGLE_DELTA, rate, rd_stats, &best_angle_delta, &best_tx_size,
1684           &best_rd, best_model_rd, best_tx_type_map, best_blk_skip,
1685           (skip_model_rd_for_zero_deg & !angle_delta));
1686       rd_cost[2 * angle_delta + i] = this_rd;
1687       if (first_try && this_rd == INT64_MAX) return best_rd;
1688       first_try = 0;
1689       if (angle_delta == 0) {
1690         rd_cost[1] = this_rd;
1691         break;
1692       }
1693     }
1694   }
1695 
1696   assert(best_rd != INT64_MAX);
1697   for (int angle_delta = 1; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
1698     for (int i = 0; i < 2; ++i) {
1699       int skip_search = 0;
1700       const int64_t rd_thresh = best_rd + (best_rd >> 5);
1701       if (rd_cost[2 * (angle_delta + 1) + i] > rd_thresh &&
1702           rd_cost[2 * (angle_delta - 1) + i] > rd_thresh)
1703         skip_search = 1;
1704       if (!skip_search) {
1705         calc_rd_given_intra_angle(
1706             cpi, x, bsize, mode_cost, best_rd, (1 - 2 * i) * angle_delta,
1707             MAX_ANGLE_DELTA, rate, rd_stats, &best_angle_delta, &best_tx_size,
1708             &best_rd, best_model_rd, best_tx_type_map, best_blk_skip, 0);
1709       }
1710     }
1711   }
1712 
1713   if (rd_stats->rate != INT_MAX) {
1714     mbmi->tx_size = best_tx_size;
1715     mbmi->angle_delta[PLANE_TYPE_Y] = best_angle_delta;
1716     const int n4 = bsize_to_num_blk(bsize);
1717     memcpy(x->blk_skip, best_blk_skip, sizeof(best_blk_skip[0]) * n4);
1718     av1_copy_array(xd->tx_type_map, best_tx_type_map, n4);
1719   }
1720   return best_rd;
1721 }
1722 
// Evaluates the intra prediction mode stored in mbmi->mode.
// Luma search:
//   - directional modes get an angle-delta search when allowed;
//   - other modes get a uniform tx size/type search;
//   - DC_PRED additionally gets a filter-intra search.
// The luma result is combined with the chroma decision cached in
// intra_search_state; the chroma search runs on first use.
//
// Outputs:
//   rd_stats, rd_stats_y, rd_stats_uv - combined, luma and chroma statistics.
//   intra_search_state->skip_intra_modes - set when later intra modes look
//     hopeless and need not be searched.
//   *best_intra_rd, intra_search_state->best_intra_mode - track the best
//     intra mode seen so far.
//   intra_search_state->best_pred_rd - updated unless disable_skip is set.
//
// Returns the mode's total RD cost, or INT64_MAX when the mode is pruned or
// cannot be coded.
int64_t av1_handle_intra_mode(IntraModeSearchState *intra_search_state,
                              const AV1_COMP *cpi, MACROBLOCK *x,
                              BLOCK_SIZE bsize, int ref_frame_cost,
                              const PICK_MODE_CONTEXT *ctx, int disable_skip,
                              RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                              RD_STATS *rd_stats_uv, int64_t best_rd,
                              int64_t *best_intra_rd, int8_t best_mbmode_skip) {
  const AV1_COMMON *cm = &cpi->common;
  const SPEED_FEATURES *const sf = &cpi->sf;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(mbmi->ref_frame[0] == INTRA_FRAME);
  const PREDICTION_MODE mode = mbmi->mode;
  // Luma mode signaling cost. Note that it already includes ref_frame_cost.
  const int mode_cost =
      x->mbmode_cost[size_group_lookup[bsize]][mode] + ref_frame_cost;
  const int intra_cost_penalty = av1_get_intra_cost_penalty(
      cm->quant_params.base_qindex, cm->quant_params.y_dc_delta_q,
      cm->seq_params.bit_depth);
  const int skip_ctx = av1_get_skip_context(xd);

  // Early exit: known_rate is meant as a lower bound on the final rate of this
  // mode (zero distortion, no coefficient bits). ref_frame_cost is already
  // part of mode_cost. Adding it a second time would overshoot the bound,
  // pruning modes that can still beat best_rd and wrongly setting
  // skip_intra_modes.
  int known_rate = mode_cost;
  if (mode != DC_PRED && mode != PAETH_PRED) known_rate += intra_cost_penalty;
  known_rate += AOMMIN(x->skip_cost[skip_ctx][0], x->skip_cost[skip_ctx][1]);
  const int64_t known_rd = RDCOST(x->rdmult, known_rate, 0);
  if (known_rd > best_rd) {
    intra_search_state->skip_intra_modes = 1;
    return INT64_MAX;
  }

  // Luma search.
  const int is_directional_mode = av1_is_directional_mode(mode);
  if (is_directional_mode && av1_use_angle_delta(bsize) &&
      cpi->oxcf.enable_angle_delta) {
    // Optional HOG-based pruning of directional modes. The skip mask is built
    // only once per intra_search_state.
    if (sf->intra_sf.intra_pruning_with_hog &&
        !intra_search_state->angle_stats_ready) {
      prune_intra_mode_with_hog(x, bsize,
                                cpi->sf.intra_sf.intra_pruning_with_hog_thresh,
                                intra_search_state->directional_mode_skip_mask);
      intra_search_state->angle_stats_ready = 1;
    }
    if (intra_search_state->directional_mode_skip_mask[mode]) return INT64_MAX;
    av1_init_rd_stats(rd_stats_y);
    rd_stats_y->rate = INT_MAX;
    int64_t model_rd = INT64_MAX;
    int rate_dummy;
    rd_pick_intra_angle_sby(cpi, x, &rate_dummy, rd_stats_y, bsize, mode_cost,
                            best_rd, &model_rd, 0);

  } else {
    av1_init_rd_stats(rd_stats_y);
    mbmi->angle_delta[PLANE_TYPE_Y] = 0;
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, best_rd);
  }

  // Pick filter intra modes.
  if (mode == DC_PRED && av1_filter_intra_allowed_bsize(cm, bsize)) {
    int try_filter_intra = 0;
    int64_t best_rd_so_far = INT64_MAX;
    if (rd_stats_y->rate != INT_MAX) {
      // Baseline: plain DC_PRED, including the cost of "no filter intra".
      const int tmp_rate =
          rd_stats_y->rate + x->filter_intra_cost[bsize][0] + mode_cost;
      best_rd_so_far = RDCOST(x->rdmult, tmp_rate, rd_stats_y->dist);
      try_filter_intra = (best_rd_so_far / 2) <= best_rd;
    } else {
      try_filter_intra = !best_mbmode_skip;
    }

    if (try_filter_intra) {
      RD_STATS rd_stats_y_fi;
      int filter_intra_selected_flag = 0;
      TX_SIZE best_tx_size = mbmi->tx_size;
      FILTER_INTRA_MODE best_fi_mode = FILTER_DC_PRED;
      uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
      memcpy(best_blk_skip, x->blk_skip,
             sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
      uint8_t best_tx_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
      av1_copy_array(best_tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
      mbmi->filter_intra_mode_info.use_filter_intra = 1;
      for (FILTER_INTRA_MODE fi_mode = FILTER_DC_PRED;
           fi_mode < FILTER_INTRA_MODES; ++fi_mode) {
        mbmi->filter_intra_mode_info.filter_intra_mode = fi_mode;
        av1_pick_uniform_tx_size_type_yrd(cpi, x, &rd_stats_y_fi, bsize,
                                          best_rd);
        if (rd_stats_y_fi.rate == INT_MAX) continue;
        const int this_rate_tmp =
            rd_stats_y_fi.rate +
            intra_mode_info_cost_y(cpi, x, mbmi, bsize, mode_cost);
        const int64_t this_rd_tmp =
            RDCOST(x->rdmult, this_rate_tmp, rd_stats_y_fi.dist);

        // Stop once a filter mode is far (2x) worse than best_rd.
        if (this_rd_tmp != INT64_MAX && this_rd_tmp / 2 > best_rd) {
          break;
        }
        if (this_rd_tmp < best_rd_so_far) {
          best_tx_size = mbmi->tx_size;
          av1_copy_array(best_tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
          memcpy(best_blk_skip, x->blk_skip,
                 sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
          best_fi_mode = fi_mode;
          *rd_stats_y = rd_stats_y_fi;
          filter_intra_selected_flag = 1;
          best_rd_so_far = this_rd_tmp;
        }
      }

      // Restore the transform decisions of the overall winner.
      mbmi->tx_size = best_tx_size;
      av1_copy_array(xd->tx_type_map, best_tx_type_map, ctx->num_4x4_blk);
      memcpy(x->blk_skip, best_blk_skip,
             sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);

      if (filter_intra_selected_flag) {
        mbmi->filter_intra_mode_info.use_filter_intra = 1;
        mbmi->filter_intra_mode_info.filter_intra_mode = best_fi_mode;
      } else {
        mbmi->filter_intra_mode_info.use_filter_intra = 0;
      }
    }
  }

  if (rd_stats_y->rate == INT_MAX) return INT64_MAX;

  const int mode_cost_y =
      intra_mode_info_cost_y(cpi, x, mbmi, bsize, mode_cost);
  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_uv);
  const int num_planes = av1_num_planes(cm);
  if (num_planes > 1) {
    PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
    const int try_palette =
        cpi->oxcf.enable_palette &&
        av1_allow_palette(cm->features.allow_screen_content_tools,
                          mbmi->sb_type);
    const TX_SIZE uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
    // The chroma search is done once and cached in intra_search_state.
    if (intra_search_state->rate_uv_intra == INT_MAX) {
      // Skip the (expensive) chroma search when luma alone is already clearly
      // worse than best_rd.
      const int rate_y =
          rd_stats_y->skip ? x->skip_cost[skip_ctx][1] : rd_stats_y->rate;
      const int64_t rdy =
          RDCOST(x->rdmult, rate_y + mode_cost_y, rd_stats_y->dist);
      if (best_rd < (INT64_MAX / 2) && rdy > (best_rd + (best_rd >> 2))) {
        intra_search_state->skip_intra_modes = 1;
        return INT64_MAX;
      }
      choose_intra_uv_mode(
          cpi, x, bsize, uv_tx, &intra_search_state->rate_uv_intra,
          &intra_search_state->rate_uv_tokenonly, &intra_search_state->dist_uvs,
          &intra_search_state->skip_uvs, &intra_search_state->mode_uv);
      if (try_palette) intra_search_state->pmi_uv = *pmi;
      intra_search_state->uv_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];

      const int uv_rate = intra_search_state->rate_uv_tokenonly;
      const int64_t uv_dist = intra_search_state->dist_uvs;
      const int64_t uv_rd = RDCOST(x->rdmult, uv_rate, uv_dist);
      if (uv_rd > best_rd) {
        intra_search_state->skip_intra_modes = 1;
        return INT64_MAX;
      }
    }

    // Apply the cached chroma decision. (rd_stats->skip is not derived here:
    // intra blocks are always coded as non-skip, see below.)
    rd_stats_uv->rate = intra_search_state->rate_uv_tokenonly;
    rd_stats_uv->dist = intra_search_state->dist_uvs;
    rd_stats_uv->skip = intra_search_state->skip_uvs;
    mbmi->uv_mode = intra_search_state->mode_uv;
    if (try_palette) {
      pmi->palette_size[1] = intra_search_state->pmi_uv.palette_size[1];
      memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
             intra_search_state->pmi_uv.palette_colors + PALETTE_MAX_SIZE,
             2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
    }
    mbmi->angle_delta[PLANE_TYPE_UV] = intra_search_state->uv_angle_delta;
  }

  rd_stats->rate = rd_stats_y->rate + mode_cost_y;
  if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(bsize)) {
    // av1_pick_uniform_tx_size_type_yrd above includes the cost of the tx_size
    // in the tokenonly rate, but for intra blocks, tx_size is always coded
    // (prediction granularity), so we account for it in the full rate,
    // not the tokenonly rate.
    rd_stats_y->rate -= tx_size_cost(x, bsize, mbmi->tx_size);
  }
  if (num_planes > 1 && xd->is_chroma_ref) {
    const int uv_mode_cost =
        x->intra_uv_mode_cost[is_cfl_allowed(xd)][mode][mbmi->uv_mode];
    rd_stats->rate +=
        rd_stats_uv->rate +
        intra_mode_info_cost_uv(cpi, x, mbmi, bsize, uv_mode_cost);
  }
  if (mode != DC_PRED && mode != PAETH_PRED) {
    rd_stats->rate += intra_cost_penalty;
  }

  // Intra block is always coded as non-skip
  rd_stats->skip = 0;
  rd_stats->dist = rd_stats_y->dist + rd_stats_uv->dist;
  // Add in the cost of the no skip flag.
  rd_stats->rate += x->skip_cost[skip_ctx][0];
  // Calculate the final RD estimate for this mode.
  const int64_t this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
  // Keep record of best intra rd
  if (this_rd < *best_intra_rd) {
    *best_intra_rd = this_rd;
    intra_search_state->best_intra_mode = mode;
  }

  // Stop searching further intra modes when this one is far (1.5x) worse.
  if (sf->intra_sf.skip_intra_in_interframe) {
    if (best_rd < (INT64_MAX / 2) && this_rd > (best_rd + (best_rd >> 1)))
      intra_search_state->skip_intra_modes = 1;
  }

  if (!disable_skip) {
    for (int i = 0; i < REFERENCE_MODES; ++i) {
      intra_search_state->best_pred_rd[i] =
          AOMMIN(intra_search_state->best_pred_rd[i], this_rd);
    }
  }
  return this_rd;
}
1940 
1941 // This function is used only for intra_only frames
av1_rd_pick_intra_sby_mode(const AV1_COMP * const cpi,MACROBLOCK * x,int * rate,int * rate_tokenonly,int64_t * distortion,int * skippable,BLOCK_SIZE bsize,int64_t best_rd,PICK_MODE_CONTEXT * ctx)1942 int64_t av1_rd_pick_intra_sby_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
1943                                    int *rate, int *rate_tokenonly,
1944                                    int64_t *distortion, int *skippable,
1945                                    BLOCK_SIZE bsize, int64_t best_rd,
1946                                    PICK_MODE_CONTEXT *ctx) {
1947   MACROBLOCKD *const xd = &x->e_mbd;
1948   MB_MODE_INFO *const mbmi = xd->mi[0];
1949   assert(!is_inter_block(mbmi));
1950   int64_t best_model_rd = INT64_MAX;
1951   int is_directional_mode;
1952   uint8_t directional_mode_skip_mask[INTRA_MODES] = { 0 };
1953   // Flag to check rd of any intra mode is better than best_rd passed to this
1954   // function
1955   int beat_best_rd = 0;
1956   const int *bmode_costs;
1957   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
1958   const int try_palette =
1959       cpi->oxcf.enable_palette &&
1960       av1_allow_palette(cpi->common.features.allow_screen_content_tools,
1961                         mbmi->sb_type);
1962   uint8_t *best_palette_color_map =
1963       try_palette ? x->palette_buffer->best_palette_color_map : NULL;
1964   const MB_MODE_INFO *above_mi = xd->above_mbmi;
1965   const MB_MODE_INFO *left_mi = xd->left_mbmi;
1966   const PREDICTION_MODE A = av1_above_block_mode(above_mi);
1967   const PREDICTION_MODE L = av1_left_block_mode(left_mi);
1968   const int above_ctx = intra_mode_context[A];
1969   const int left_ctx = intra_mode_context[L];
1970   bmode_costs = x->y_mode_costs[above_ctx][left_ctx];
1971 
1972   mbmi->angle_delta[PLANE_TYPE_Y] = 0;
1973   if (cpi->sf.intra_sf.intra_pruning_with_hog) {
1974     prune_intra_mode_with_hog(x, bsize,
1975                               cpi->sf.intra_sf.intra_pruning_with_hog_thresh,
1976                               directional_mode_skip_mask);
1977   }
1978   mbmi->filter_intra_mode_info.use_filter_intra = 0;
1979   pmi->palette_size[0] = 0;
1980 
1981   // Set params for mode evaluation
1982   set_mode_eval_params(cpi, x, MODE_EVAL);
1983 
1984   MB_MODE_INFO best_mbmi = *mbmi;
1985   av1_zero(x->winner_mode_stats);
1986   x->winner_mode_count = 0;
1987 
1988   /* Y Search for intra prediction mode */
1989   for (int mode_idx = INTRA_MODE_START; mode_idx < INTRA_MODE_END; ++mode_idx) {
1990     RD_STATS this_rd_stats;
1991     int this_rate, this_rate_tokenonly, s;
1992     int64_t this_distortion, this_rd;
1993     mbmi->mode = intra_rd_search_mode_order[mode_idx];
1994     if ((!cpi->oxcf.enable_smooth_intra ||
1995          cpi->sf.intra_sf.disable_smooth_intra) &&
1996         (mbmi->mode == SMOOTH_PRED || mbmi->mode == SMOOTH_H_PRED ||
1997          mbmi->mode == SMOOTH_V_PRED))
1998       continue;
1999     if (!cpi->oxcf.enable_paeth_intra && mbmi->mode == PAETH_PRED) continue;
2000     mbmi->angle_delta[PLANE_TYPE_Y] = 0;
2001 
2002     if (model_intra_yrd_and_prune(cpi, x, bsize, bmode_costs[mbmi->mode],
2003                                   &best_model_rd)) {
2004       continue;
2005     }
2006 
2007     is_directional_mode = av1_is_directional_mode(mbmi->mode);
2008     if (is_directional_mode && directional_mode_skip_mask[mbmi->mode]) continue;
2009     if (is_directional_mode && av1_use_angle_delta(bsize) &&
2010         cpi->oxcf.enable_angle_delta) {
2011       this_rd_stats.rate = INT_MAX;
2012       rd_pick_intra_angle_sby(cpi, x, &this_rate, &this_rd_stats, bsize,
2013                               bmode_costs[mbmi->mode], best_rd, &best_model_rd,
2014                               1);
2015     } else {
2016       av1_pick_uniform_tx_size_type_yrd(cpi, x, &this_rd_stats, bsize, best_rd);
2017     }
2018     this_rate_tokenonly = this_rd_stats.rate;
2019     this_distortion = this_rd_stats.dist;
2020     s = this_rd_stats.skip;
2021 
2022     if (this_rate_tokenonly == INT_MAX) continue;
2023 
2024     if (!xd->lossless[mbmi->segment_id] &&
2025         block_signals_txsize(mbmi->sb_type)) {
2026       // av1_pick_uniform_tx_size_type_yrd above includes the cost of the
2027       // tx_size in the tokenonly rate, but for intra blocks, tx_size is always
2028       // coded (prediction granularity), so we account for it in the full rate,
2029       // not the tokenonly rate.
2030       this_rate_tokenonly -= tx_size_cost(x, bsize, mbmi->tx_size);
2031     }
2032     this_rate =
2033         this_rd_stats.rate +
2034         intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
2035     this_rd = RDCOST(x->rdmult, this_rate, this_distortion);
2036     // Collect mode stats for multiwinner mode processing
2037     const int txfm_search_done = 1;
2038     store_winner_mode_stats(
2039         &cpi->common, x, mbmi, NULL, NULL, NULL, 0, NULL, bsize, this_rd,
2040         cpi->sf.winner_mode_sf.enable_multiwinner_mode_process,
2041         txfm_search_done);
2042     if (this_rd < best_rd) {
2043       best_mbmi = *mbmi;
2044       best_rd = this_rd;
2045       // Setting beat_best_rd flag because current mode rd is better than
2046       // best_rd passed to this function
2047       beat_best_rd = 1;
2048       *rate = this_rate;
2049       *rate_tokenonly = this_rate_tokenonly;
2050       *distortion = this_distortion;
2051       *skippable = s;
2052       memcpy(ctx->blk_skip, x->blk_skip,
2053              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
2054       av1_copy_array(ctx->tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
2055     }
2056   }
2057 
2058   if (try_palette) {
2059     rd_pick_palette_intra_sby(
2060         cpi, x, bsize, bmode_costs[DC_PRED], &best_mbmi, best_palette_color_map,
2061         &best_rd, &best_model_rd, rate, rate_tokenonly, distortion, skippable,
2062         &beat_best_rd, ctx, ctx->blk_skip, ctx->tx_type_map);
2063   }
2064 
2065   if (beat_best_rd && av1_filter_intra_allowed_bsize(&cpi->common, bsize)) {
2066     if (rd_pick_filter_intra_sby(cpi, x, rate, rate_tokenonly, distortion,
2067                                  skippable, bsize, bmode_costs[DC_PRED],
2068                                  &best_rd, &best_model_rd, ctx)) {
2069       best_mbmi = *mbmi;
2070     }
2071   }
2072   // No mode is identified with less rd value than best_rd passed to this
2073   // function. In such cases winner mode processing is not necessary and return
2074   // best_rd as INT64_MAX to indicate best mode is not identified
2075   if (!beat_best_rd) return INT64_MAX;
2076 
2077   // In multi-winner mode processing, perform tx search for few best modes
2078   // identified during mode evaluation. Winner mode processing uses best tx
2079   // configuration for tx search.
2080   if (cpi->sf.winner_mode_sf.enable_multiwinner_mode_process) {
2081     int best_mode_idx = 0;
2082     int block_width, block_height;
2083     uint8_t *color_map_dst = xd->plane[PLANE_TYPE_Y].color_index_map;
2084     av1_get_block_dimensions(bsize, AOM_PLANE_Y, xd, &block_width,
2085                              &block_height, NULL, NULL);
2086 
2087     for (int mode_idx = 0; mode_idx < x->winner_mode_count; mode_idx++) {
2088       *mbmi = x->winner_mode_stats[mode_idx].mbmi;
2089       if (is_winner_mode_processing_enabled(cpi, mbmi, mbmi->mode)) {
2090         // Restore color_map of palette mode before winner mode processing
2091         if (mbmi->palette_mode_info.palette_size[0] > 0) {
2092           uint8_t *color_map_src =
2093               x->winner_mode_stats[mode_idx].color_index_map;
2094           memcpy(color_map_dst, color_map_src,
2095                  block_width * block_height * sizeof(*color_map_src));
2096         }
2097         // Set params for winner mode evaluation
2098         set_mode_eval_params(cpi, x, WINNER_MODE_EVAL);
2099 
2100         // Winner mode processing
2101         // If previous searches use only the default tx type/no R-D optimization
2102         // of quantized coeffs, do an extra search for the best tx type/better
2103         // R-D optimization of quantized coeffs
2104         if (intra_block_yrd(cpi, x, bsize, bmode_costs, &best_rd, rate,
2105                             rate_tokenonly, distortion, skippable, &best_mbmi,
2106                             ctx))
2107           best_mode_idx = mode_idx;
2108       }
2109     }
2110     // Copy color_map of palette mode for final winner mode
2111     if (best_mbmi.palette_mode_info.palette_size[0] > 0) {
2112       uint8_t *color_map_src =
2113           x->winner_mode_stats[best_mode_idx].color_index_map;
2114       memcpy(color_map_dst, color_map_src,
2115              block_width * block_height * sizeof(*color_map_src));
2116     }
2117   } else {
2118     // If previous searches use only the default tx type/no R-D optimization of
2119     // quantized coeffs, do an extra search for the best tx type/better R-D
2120     // optimization of quantized coeffs
2121     if (is_winner_mode_processing_enabled(cpi, mbmi, best_mbmi.mode)) {
2122       // Set params for winner mode evaluation
2123       set_mode_eval_params(cpi, x, WINNER_MODE_EVAL);
2124       *mbmi = best_mbmi;
2125       intra_block_yrd(cpi, x, bsize, bmode_costs, &best_rd, rate,
2126                       rate_tokenonly, distortion, skippable, &best_mbmi, ctx);
2127     }
2128   }
2129   *mbmi = best_mbmi;
2130   av1_copy_array(xd->tx_type_map, ctx->tx_type_map, ctx->num_4x4_blk);
2131   return best_rd;
2132 }
2133