/*
 * Copyright (c) 2022, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
#include "av1/qmode_rc/ratectrl_qmode.h"

#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "aom/aom_codec.h"
#include "av1/encoder/pass2_strategy.h"
#include "av1/encoder/tpl_model.h"

namespace aom {

// This is used before division to ensure that the divisor isn't zero or
// too close to zero.
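// For example, ModifyDivisor(0.0) and ModifyDivisor(1e-9) both return
// kEpsilon, while ModifyDivisor(-1e-9) returns -kEpsilon, so the returned
// value always has magnitude of at least kEpsilon.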
static double ModifyDivisor(double divisor) {
  const double kEpsilon = 0.0000001;
  return (divisor < 0 ? std::min(divisor, -kEpsilon)
                      : std::max(divisor, kEpsilon));
}

GopFrame GopFrameInvalid() {
  GopFrame gop_frame = {};
  gop_frame.is_valid = false;
  gop_frame.coding_idx = -1;
  gop_frame.order_idx = -1;
  return gop_frame;
}

void SetGopFrameByType(GopFrameType gop_frame_type, GopFrame *gop_frame) {
  gop_frame->update_type = gop_frame_type;
  switch (gop_frame_type) {
    case GopFrameType::kRegularKey:
      gop_frame->is_key_frame = 1;
      gop_frame->is_arf_frame = 0;
      gop_frame->is_show_frame = 1;
      gop_frame->is_golden_frame = 1;
      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
      break;
    case GopFrameType::kRegularGolden:
      gop_frame->is_key_frame = 0;
      gop_frame->is_arf_frame = 0;
      gop_frame->is_show_frame = 1;
      gop_frame->is_golden_frame = 1;
      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
      break;
    case GopFrameType::kRegularArf:
      gop_frame->is_key_frame = 0;
      gop_frame->is_arf_frame = 1;
      gop_frame->is_show_frame = 0;
      gop_frame->is_golden_frame = 1;
      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
      break;
    case GopFrameType::kIntermediateArf:
      gop_frame->is_key_frame = 0;
      gop_frame->is_arf_frame = 1;
      gop_frame->is_show_frame = 0;
      gop_frame->is_golden_frame = gop_frame->layer_depth <= 2 ? 1 : 0;
      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
      break;
    case GopFrameType::kRegularLeaf:
      gop_frame->is_key_frame = 0;
      gop_frame->is_arf_frame = 0;
      gop_frame->is_show_frame = 1;
      gop_frame->is_golden_frame = 0;
      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
      break;
    case GopFrameType::kIntermediateOverlay:
      gop_frame->is_key_frame = 0;
      gop_frame->is_arf_frame = 0;
      gop_frame->is_show_frame = 1;
      gop_frame->is_golden_frame = 0;
      gop_frame->encode_ref_mode = EncodeRefMode::kShowExisting;
      break;
    case GopFrameType::kOverlay:
      gop_frame->is_key_frame = 0;
      gop_frame->is_arf_frame = 0;
      gop_frame->is_show_frame = 1;
      gop_frame->is_golden_frame = 0;
      gop_frame->encode_ref_mode = EncodeRefMode::kOverlay;
      break;
  }
}

GopFrame GopFrameBasic(int global_coding_idx_offset,
                       int global_order_idx_offset, int coding_idx,
                       int order_idx, int depth, int display_idx,
                       GopFrameType gop_frame_type) {
  GopFrame gop_frame = {};
  gop_frame.is_valid = true;
  gop_frame.coding_idx = coding_idx;
  gop_frame.order_idx = order_idx;
  gop_frame.display_idx = display_idx;
  gop_frame.global_coding_idx = global_coding_idx_offset + coding_idx;
  gop_frame.global_order_idx = global_order_idx_offset + order_idx;
  gop_frame.layer_depth = depth + kLayerDepthOffset;
  gop_frame.colocated_ref_idx = -1;
  gop_frame.update_ref_idx = -1;
  SetGopFrameByType(gop_frame_type, &gop_frame);
  return gop_frame;
}

// This function creates gop frames with display order indices from
// order_start to order_end - 1. It recursively introduces intermediate ARFs
// until the maximum depth is met or the number of regular frames between two
// ARFs is less than 3 (kMinIntervalToAddArf). Then the remaining regular
// frames are added to the gop_struct.
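// For example, with order_start = 0, order_end = 8, and enough remaining
// depth, the call emits an intermediate ARF for display order 4, recurses on
// [0, 4), emits the show-existing overlay for order 4, and then recurses on
// [5, 8).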
void ConstructGopMultiLayer(GopStruct *gop_struct,
                            RefFrameManager *ref_frame_manager, int max_depth,
                            int depth, int order_start, int order_end) {
  GopFrame gop_frame;
  int num_frames = order_end - order_start;
  const int global_coding_idx_offset = gop_struct->global_coding_idx_offset;
  const int global_order_idx_offset = gop_struct->global_order_idx_offset;
  // Stop introducing ARFs once the maximum depth is reached or there are
  // fewer than kMinIntervalToAddArf frames left.
  if (depth < max_depth && num_frames >= kMinIntervalToAddArf) {
    int order_mid = (order_start + order_end) / 2;
    // intermediate ARF
    gop_frame = GopFrameBasic(
        global_coding_idx_offset, global_order_idx_offset,
        static_cast<int>(gop_struct->gop_frame_list.size()), order_mid, depth,
        gop_struct->display_tracker, GopFrameType::kIntermediateArf);
    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
    gop_struct->gop_frame_list.push_back(gop_frame);
    ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
                           order_start, order_mid);
    // show existing intermediate ARF
    gop_frame =
        GopFrameBasic(global_coding_idx_offset, global_order_idx_offset,
                      static_cast<int>(gop_struct->gop_frame_list.size()),
                      order_mid, max_depth, gop_struct->display_tracker,
                      GopFrameType::kIntermediateOverlay);
    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
    gop_struct->gop_frame_list.push_back(gop_frame);
    ++gop_struct->display_tracker;
    ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
                           order_mid + 1, order_end);
  } else {
    // regular frames
    for (int i = order_start; i < order_end; ++i) {
      gop_frame = GopFrameBasic(
          global_coding_idx_offset, global_order_idx_offset,
          static_cast<int>(gop_struct->gop_frame_list.size()), i, max_depth,
          gop_struct->display_tracker, GopFrameType::kRegularLeaf);
      ref_frame_manager->UpdateRefFrameTable(&gop_frame);
      gop_struct->gop_frame_list.push_back(gop_frame);
      ++gop_struct->display_tracker;
    }
  }
}

GopStruct ConstructGop(RefFrameManager *ref_frame_manager, int show_frame_count,
                       bool has_key_frame, int global_coding_idx_offset,
                       int global_order_idx_offset) {
  GopStruct gop_struct;
  gop_struct.show_frame_count = show_frame_count;
  gop_struct.global_coding_idx_offset = global_coding_idx_offset;
  gop_struct.global_order_idx_offset = global_order_idx_offset;
  int order_start = 0;
  int order_end = show_frame_count - 1;

  // TODO(jingning): Re-enable the use of pyramid coding structure.
  bool has_arf_frame = show_frame_count > kMinIntervalToAddArf;

  gop_struct.display_tracker = 0;

  GopFrame gop_frame;
  if (has_key_frame) {
    const int key_frame_depth = -1;
    ref_frame_manager->Reset();
    gop_frame = GopFrameBasic(
        global_coding_idx_offset, global_order_idx_offset,
        static_cast<int>(gop_struct.gop_frame_list.size()), order_start,
        key_frame_depth, gop_struct.display_tracker, GopFrameType::kRegularKey);
    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
    gop_struct.gop_frame_list.push_back(gop_frame);
    order_start++;
    ++gop_struct.display_tracker;
  }

  const int arf_depth = 0;
  if (has_arf_frame) {
    // Use multi-layer pyramid coding structure.
    gop_frame = GopFrameBasic(
        global_coding_idx_offset, global_order_idx_offset,
        static_cast<int>(gop_struct.gop_frame_list.size()), order_end,
        arf_depth, gop_struct.display_tracker, GopFrameType::kRegularArf);
    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
    gop_struct.gop_frame_list.push_back(gop_frame);
    ConstructGopMultiLayer(&gop_struct, ref_frame_manager,
                           ref_frame_manager->MaxRefFrame() - 1, arf_depth + 1,
                           order_start, order_end);
    // Overlay
    gop_frame =
        GopFrameBasic(global_coding_idx_offset, global_order_idx_offset,
                      static_cast<int>(gop_struct.gop_frame_list.size()),
                      order_end, ref_frame_manager->MaxRefFrame() - 1,
                      gop_struct.display_tracker, GopFrameType::kOverlay);
    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
    gop_struct.gop_frame_list.push_back(gop_frame);
    ++gop_struct.display_tracker;
  } else {
    // Use IPPP format.
    for (int i = order_start; i <= order_end; ++i) {
      gop_frame = GopFrameBasic(
          global_coding_idx_offset, global_order_idx_offset,
          static_cast<int>(gop_struct.gop_frame_list.size()), i, arf_depth + 1,
          gop_struct.display_tracker, GopFrameType::kRegularLeaf);
      ref_frame_manager->UpdateRefFrameTable(&gop_frame);
      gop_struct.gop_frame_list.push_back(gop_frame);
      ++gop_struct.display_tracker;
    }
  }

  return gop_struct;
}

Status AV1RateControlQMode::SetRcParam(const RateControlParam &rc_param) {
  std::ostringstream error_message;
  if (rc_param.max_gop_show_frame_count <
      std::max(4, rc_param.min_gop_show_frame_count)) {
    error_message << "max_gop_show_frame_count ("
                  << rc_param.max_gop_show_frame_count
                  << ") must be at least 4 and may not be less than "
                     "min_gop_show_frame_count ("
                  << rc_param.min_gop_show_frame_count << ")";
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  if (rc_param.ref_frame_table_size < 1 || rc_param.ref_frame_table_size > 8) {
    error_message << "ref_frame_table_size (" << rc_param.ref_frame_table_size
                  << ") must be in the range [1, 8].";
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  if (rc_param.max_ref_frames < 1 || rc_param.max_ref_frames > 7) {
    error_message << "max_ref_frames (" << rc_param.max_ref_frames
                  << ") must be in the range [1, 7].";
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  if (rc_param.base_q_index < 0 || rc_param.base_q_index > 255) {
    error_message << "base_q_index (" << rc_param.base_q_index
                  << ") must be in the range [0, 255].";
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  if (rc_param.frame_width < 16 || rc_param.frame_width > 16384 ||
      rc_param.frame_height < 16 || rc_param.frame_height > 16384) {
    error_message << "frame_width (" << rc_param.frame_width
                  << ") and frame_height (" << rc_param.frame_height
                  << ") must be in the range [16, 16384].";
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  rc_param_ = rc_param;
  return { AOM_CODEC_OK, "" };
}

// Threshold for use of the lagging second reference frame. High second ref
// usage may point to a transient event like a flash or occlusion rather than
// a real scene cut.
// We adapt the threshold based on the number of frames in this key-frame
// group so far.
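// For example, the threshold ramps linearly from 0.085 when no frames have
// been seen up to 0.085 + 0.035 = 0.12 at 31 frames, and stays at 0.12 from
// then on.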
static double GetSecondRefUsageThreshold(int frame_count_so_far) {
  const int adapt_upto = 32;
  const double min_second_ref_usage_thresh = 0.085;
  const double second_ref_usage_thresh_max_delta = 0.035;
  if (frame_count_so_far >= adapt_upto) {
    return min_second_ref_usage_thresh + second_ref_usage_thresh_max_delta;
  }
  return min_second_ref_usage_thresh +
         ((double)frame_count_so_far / (adapt_upto - 1)) *
             second_ref_usage_thresh_max_delta;
}

// Slide show transition detection.
// Tests for the case where there is very low error on either side of the
// current frame but much higher error just for this frame. This can help
// detect key frames in slide shows, even where the slides are pictures of
// different sizes.
// Also requires that intra and inter errors are very similar to help
// eliminate harmful false positives.
// It will not help if the transition is a fade or other multi-frame effect.
static bool DetectSlideTransition(const FIRSTPASS_STATS &this_frame,
                                  const FIRSTPASS_STATS &last_frame,
                                  const FIRSTPASS_STATS &next_frame) {
  // Intra / Inter threshold very low
  constexpr double kVeryLowII = 1.5;
  // For clean slide transitions we expect a sharp single frame spike in error.
  constexpr double kErrorSpike = 5.0;

  // TODO(angiebird): Understand the meaning of these conditions.
  return (this_frame.intra_error < (this_frame.coded_error * kVeryLowII)) &&
         (this_frame.coded_error > (last_frame.coded_error * kErrorSpike)) &&
         (this_frame.coded_error > (next_frame.coded_error * kErrorSpike));
}

// Check if there is a significant intra/inter error change between the current
// frame and its neighbor. If so, we should further test whether the current
// frame should be a key frame.
static bool DetectIntraInterErrorChange(const FIRSTPASS_STATS &this_stats,
                                        const FIRSTPASS_STATS &last_stats,
                                        const FIRSTPASS_STATS &next_stats) {
  // Minimum % intra coding observed in first pass (1.0 = 100%)
  constexpr double kMinIntraLevel = 0.25;
  // Minimum ratio between the % of intra coding and inter coding in the first
  // pass after discounting neutral blocks (discounting neutral blocks in this
  // way helps catch scene cuts in clips with very flat areas or letter box
  // format clips with image padding).
  constexpr double kIntraVsInterRatio = 2.0;

  const double modified_pcnt_inter =
      this_stats.pcnt_inter - this_stats.pcnt_neutral;
  const double pcnt_intra_min =
      std::max(kMinIntraLevel, kIntraVsInterRatio * modified_pcnt_inter);

  // In real scene cuts there is almost always a sharp change in the intra
  // or inter error score.
  constexpr double kErrorChangeThreshold = 0.4;
  const double last_this_error_ratio =
      fabs(last_stats.coded_error - this_stats.coded_error) /
      ModifyDivisor(this_stats.coded_error);

  const double this_next_error_ratio =
      fabs(last_stats.intra_error - this_stats.intra_error) /
      ModifyDivisor(this_stats.intra_error);

  // Maximum threshold for the relative ratio of intra error score vs best
  // inter error score.
  constexpr double kThisIntraCodedErrorRatioMax = 1.9;
  const double this_intra_coded_error_ratio =
      this_stats.intra_error / ModifyDivisor(this_stats.coded_error);

  // For real scene cuts we expect an improvement in the intra inter error
  // ratio in the next frame.
  constexpr double kNextIntraCodedErrorRatioMin = 3.5;
  const double next_intra_coded_error_ratio =
      next_stats.intra_error / ModifyDivisor(next_stats.coded_error);

  double pcnt_intra = 1.0 - this_stats.pcnt_inter;
  return pcnt_intra > pcnt_intra_min &&
         this_intra_coded_error_ratio < kThisIntraCodedErrorRatioMax &&
         (last_this_error_ratio > kErrorChangeThreshold ||
          this_next_error_ratio > kErrorChangeThreshold ||
          next_intra_coded_error_ratio > kNextIntraCodedErrorRatioMin);
}

// Check whether the candidate can be a key frame.
// This is a rewrite of test_candidate_kf().
static bool TestCandidateKey(const FirstpassInfo &first_pass_info,
                             int candidate_key_idx, int frames_since_prev_key) {
  const auto &stats_list = first_pass_info.stats_list;
  const int stats_count = static_cast<int>(stats_list.size());
  if (candidate_key_idx + 1 >= stats_count || candidate_key_idx - 1 < 0) {
    return false;
  }
  const auto &last_stats = stats_list[candidate_key_idx - 1];
  const auto &this_stats = stats_list[candidate_key_idx];
  const auto &next_stats = stats_list[candidate_key_idx + 1];

  if (frames_since_prev_key < 3) return false;
  const double second_ref_usage_threshold =
      GetSecondRefUsageThreshold(frames_since_prev_key);
  if (this_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;
  if (next_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;

  // Hard threshold where the first pass chooses intra for almost all blocks.
  // In such a case, even if the frame is not a scene cut, coding a key frame
  // may be a good option.
  constexpr double kVeryLowInterThreshold = 0.05;
  if (this_stats.pcnt_inter < kVeryLowInterThreshold ||
      DetectSlideTransition(this_stats, last_stats, next_stats) ||
      DetectIntraInterErrorChange(this_stats, last_stats, next_stats)) {
    double boost_score = 0.0;
    double decay_accumulator = 1.0;

    // We do "-1" because the candidate key is not counted.
    int stats_after_this_stats = stats_count - candidate_key_idx - 1;

    // Number of frames required to test for scene cut detection
    constexpr int kSceneCutKeyTestIntervalMax = 16;

    // Make sure we have enough stats after the candidate key.
    const int frames_to_test_after_candidate_key =
        std::min(kSceneCutKeyTestIntervalMax, stats_after_this_stats);

    // Examine how well the key frame predicts subsequent frames.
    int i;
    for (i = 1; i <= frames_to_test_after_candidate_key; ++i) {
      // Get the next frame details
      const auto &stats = stats_list[candidate_key_idx + i];

      // Cumulative effect of decay in prediction quality.
      if (stats.pcnt_inter > 0.85) {
        decay_accumulator *= stats.pcnt_inter;
      } else {
        decay_accumulator *= (0.85 + stats.pcnt_inter) / 2.0;
      }

      constexpr double kBoostFactor = 12.5;
      double next_iiratio =
          (kBoostFactor * stats.intra_error / ModifyDivisor(stats.coded_error));
      next_iiratio = std::min(next_iiratio, 128.0);
      double boost_score_increment = decay_accumulator * next_iiratio;

      // Keep a running total.
      boost_score += boost_score_increment;

      // Test various breakout clauses.
      // TODO(any): Test of intra error should be normalized to an MB.
      // TODO(angiebird): Investigate the following questions.
      // Question 1: next_iiratio is (intra_error / coded_error) * kBoostFactor.
      // We know intra_error / coded_error >= 1 and kBoostFactor = 12.5,
      // therefore (intra_error / coded_error) * kBoostFactor will always be
      // greater than 1.5. Is "next_iiratio < 1.5" always false?
      // Question 2: Similar to question 1, is "next_iiratio < 3.0" always
      // false?
      // Question 3: Why do we need to divide 200 by num_mbs_16x16?
      if ((stats.pcnt_inter < 0.05) || (next_iiratio < 1.5) ||
          (((stats.pcnt_inter - stats.pcnt_neutral) < 0.20) &&
           (next_iiratio < 3.0)) ||
          (boost_score_increment < 3.0) ||
          (stats.intra_error <
           (200.0 / static_cast<double>(first_pass_info.num_mbs_16x16)))) {
        break;
      }
    }

    // If there is tolerable prediction for at least the next 3 frames then
    // break out; else discard this potential key frame and move on.
    const int count_for_tolerable_prediction = 3;
    if (boost_score > 30.0 && (i > count_for_tolerable_prediction)) {
      return true;
    }
  }
  return false;
}

// Compute key frame locations from first_pass_info.
std::vector<int> GetKeyFrameList(const FirstpassInfo &first_pass_info) {
  std::vector<int> key_frame_list;
  key_frame_list.push_back(0);  // The first frame is always a key frame
  int candidate_key_idx = 1;
  while (candidate_key_idx <
         static_cast<int>(first_pass_info.stats_list.size())) {
    const int frames_since_prev_key = candidate_key_idx - key_frame_list.back();
    // Check for a scene cut.
    const bool scenecut_detected = TestCandidateKey(
        first_pass_info, candidate_key_idx, frames_since_prev_key);
    if (scenecut_detected) {
      key_frame_list.push_back(candidate_key_idx);
    }
    ++candidate_key_idx;
  }
  return key_frame_list;
}

// Initialize GF_GROUP_STATS.
static void InitGFStats(GF_GROUP_STATS *gf_stats) {
  gf_stats->gf_group_err = 0.0;
  gf_stats->gf_group_raw_error = 0.0;
  gf_stats->gf_group_skip_pct = 0.0;
  gf_stats->gf_group_inactive_zone_rows = 0.0;

  gf_stats->mv_ratio_accumulator = 0.0;
  gf_stats->decay_accumulator = 1.0;
  gf_stats->zero_motion_accumulator = 1.0;
  gf_stats->loop_decay_rate = 1.0;
  gf_stats->last_loop_decay_rate = 1.0;
  gf_stats->this_frame_mv_in_out = 0.0;
  gf_stats->mv_in_out_accumulator = 0.0;
  gf_stats->abs_mv_in_out_accumulator = 0.0;

  gf_stats->avg_sr_coded_error = 0.0;
  gf_stats->avg_pcnt_second_ref = 0.0;
  gf_stats->avg_new_mv_count = 0.0;
  gf_stats->avg_wavelet_energy = 0.0;
  gf_stats->avg_raw_err_stdev = 0.0;
  gf_stats->non_zero_stdev_count = 0;
}

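// Returns the index of the region in `regions` whose [start, last] range
// contains frame_idx, or -1 if no region does.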
static int FindRegionIndex(const std::vector<REGIONS> &regions, int frame_idx) {
  for (int k = 0; k < static_cast<int>(regions.size()); k++) {
    if (regions[k].start <= frame_idx && regions[k].last >= frame_idx) {
      return k;
    }
  }
  return -1;
}

// This function detects a flash through the high relative pcnt_second_ref
// score in the frame following a flash frame. The offset passed in should
// reflect this.
static bool DetectFlash(const std::vector<FIRSTPASS_STATS> &stats_list,
                        int index) {
  int next_index = index + 1;
  if (next_index >= static_cast<int>(stats_list.size())) return false;
  const FIRSTPASS_STATS &next_frame = stats_list[next_index];

  // What we are looking for here is a situation where there is a
  // brief break in prediction (such as a flash) but subsequent frames
  // are reasonably well predicted by an earlier (pre flash) frame.
  // The recovery after a flash is indicated by a high pcnt_second_ref
  // compared to pcnt_inter.
  return next_frame.pcnt_second_ref > next_frame.pcnt_inter &&
         next_frame.pcnt_second_ref >= 0.5;
}

#define MIN_SHRINK_LEN 6

// This function takes in a suggested gop interval from cur_start to cur_last,
// analyzes first pass stats and region stats, and then returns a better gop
// cut location.
// TODO(b/231517281): Simplify the indices once we have a unit test.
// We are using four indices here: order_index, cur_start, cur_last, and
// frames_since_key. Ideally, only three indices are needed.
// 1) start_index = order_index + cur_start
// 2) end_index = order_index + cur_end
// 3) key_index
int FindBetterGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
                     const std::vector<REGIONS> &regions_list,
                     int min_gop_show_frame_count, int max_gop_show_frame_count,
                     int order_index, int cur_start, int cur_last,
                     int frames_since_key) {
  // only try shrinking if the interval is smaller than active_max_gf_interval
  if (cur_last - cur_start > max_gop_show_frame_count ||
      cur_start >= cur_last) {
    return cur_last;
  }
  int num_regions = static_cast<int>(regions_list.size());
  int num_stats = static_cast<int>(stats_list.size());
  const int min_shrink_int = std::max(MIN_SHRINK_LEN, min_gop_show_frame_count);

  // find the region indices of where the first and last frames belong.
  int k_start = FindRegionIndex(regions_list, cur_start + frames_since_key);
  int k_last = FindRegionIndex(regions_list, cur_last + frames_since_key);
  if (cur_start + frames_since_key == 0) k_start = 0;

  int scenecut_idx = -1;
  // See if we have a scenecut in between
  for (int r = k_start + 1; r <= k_last; r++) {
    if (regions_list[r].type == SCENECUT_REGION &&
        regions_list[r].last - frames_since_key - cur_start >
            min_gop_show_frame_count) {
      scenecut_idx = r;
      break;
    }
  }

  // if the found scenecut is very close to the end, ignore it.
  if (scenecut_idx >= 0 &&
      regions_list[num_regions - 1].last - regions_list[scenecut_idx].last <
          4) {
    scenecut_idx = -1;
  }

  if (scenecut_idx != -1) {
    // If we have a scenecut, then stop at it.
    // TODO(bohanli): add logic here to stop before the scenecut and for
    // the next gop start from the scenecut with GF
    int is_minor_sc =
        (regions_list[scenecut_idx].avg_cor_coeff *
             (1 - stats_list[order_index + regions_list[scenecut_idx].start -
                             frames_since_key]
                          .noise_var /
                      regions_list[scenecut_idx].avg_intra_err) >
         0.6);
    cur_last =
        regions_list[scenecut_idx].last - frames_since_key - !is_minor_sc;
  } else {
    int is_last_analysed =
        (k_last == num_regions - 1) &&
        (cur_last + frames_since_key == regions_list[k_last].last);
    int not_enough_regions =
        k_last - k_start <= 1 + (regions_list[k_start].type == SCENECUT_REGION);
    // if we are very close to the end, then do not shrink since it may
    // introduce intervals that are too short
    if (!(is_last_analysed && not_enough_regions)) {
      const double arf_length_factor = 0.1;
      double best_score = 0;
      int best_j = -1;
      const int first_frame = regions_list[0].start - frames_since_key;
      const int last_frame =
          regions_list[num_regions - 1].last - frames_since_key;
      // score of how much the arf helps the whole GOP
      double base_score = 0.0;
      // Accumulate base_score over the frames before the earliest allowed cut.
      for (int j = cur_start + 1; j < cur_start + min_shrink_int; j++) {
        if (order_index + j >= num_stats) break;
        base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
      }
      int met_blending = 0;   // Whether we have met blending areas before
      int last_blending = 0;  // Whether the previous frame is blending
      for (int j = cur_start + min_shrink_int; j <= cur_last; j++) {
        if (order_index + j >= num_stats) break;
        base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
        int this_reg = FindRegionIndex(regions_list, j + frames_since_key);
        if (this_reg < 0) continue;
        // A GOP should include at most 1 blending region.
        if (regions_list[this_reg].type == BLENDING_REGION) {
          last_blending = 1;
          if (met_blending) {
            break;
          } else {
            base_score = 0;
            continue;
          }
        } else {
          if (last_blending) met_blending = 1;
          last_blending = 0;
        }

        // Add the factor of how good the neighborhood is for this
        // candidate arf.
        double this_score = arf_length_factor * base_score;
        double temp_accu_coeff = 1.0;
        // following frames
        int count_f = 0;
        for (int n = j + 1; n <= j + 3 && n <= last_frame; n++) {
          if (order_index + n >= num_stats) break;
          temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
          this_score +=
              temp_accu_coeff *
              (1 - stats_list[order_index + n].noise_var /
                       AOMMAX(regions_list[this_reg].avg_intra_err, 0.001));
          count_f++;
        }
        // preceding frames
        temp_accu_coeff = 1.0;
        for (int n = j; n > j - 3 * 2 + count_f && n > first_frame; n--) {
          if (order_index + n < 0) break;
          temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
          this_score +=
              temp_accu_coeff *
              (1 - stats_list[order_index + n].noise_var /
                       AOMMAX(regions_list[this_reg].avg_intra_err, 0.001));
        }

        if (this_score > best_score) {
          best_score = this_score;
          best_j = j;
        }
      }

      // For blending areas, move one more frame in case we missed the
      // first blending frame.
      int best_reg = FindRegionIndex(regions_list, best_j + frames_since_key);
      if (best_reg < num_regions - 1 && best_reg > 0) {
        if (regions_list[best_reg - 1].type == BLENDING_REGION &&
            regions_list[best_reg + 1].type == BLENDING_REGION) {
          if (best_j + frames_since_key == regions_list[best_reg].start &&
              best_j + frames_since_key < regions_list[best_reg].last) {
            best_j += 1;
          } else if (best_j + frames_since_key == regions_list[best_reg].last &&
                     best_j + frames_since_key > regions_list[best_reg].start) {
            best_j -= 1;
          }
        }
      }

      if (cur_last - best_j < 2) best_j = cur_last;
      if (best_j > 0 && best_score > 0.1) cur_last = best_j;
      // if we cannot find anything, just cut at the original place.
    }
  }

  return cur_last;
}

// Function to test for a condition where a complex transition is followed
// by a static section. For example in slide shows where there is a fade
// between slides. This is to help with more optimal kf and gf positioning.
static bool DetectTransitionToStill(
    const std::vector<FIRSTPASS_STATS> &stats_list, int next_stats_index,
    int min_gop_show_frame_count, int frame_interval, int still_interval,
    double loop_decay_rate, double last_decay_rate) {
  // Break clause to detect very still sections after motion.
  // For example, a static image after a fade or other transition
  // instead of a clean scene cut.
  if (frame_interval > min_gop_show_frame_count && loop_decay_rate >= 0.999 &&
      last_decay_rate < 0.9) {
    int stats_count = static_cast<int>(stats_list.size());
    int stats_left = stats_count - next_stats_index;
    if (stats_left >= still_interval) {
      // Look ahead a few frames to see if the static condition persists...
      int j;
      for (j = 0; j < still_interval; ++j) {
        const FIRSTPASS_STATS &stats = stats_list[next_stats_index + j];
        if (stats.pcnt_inter - stats.pcnt_motion < 0.999) break;
      }
      // Only if it does do we signal a transition to still.
      return j == still_interval;
    }
  }
  return false;
}

static int DetectGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
                        int start_idx, int candidate_cut_idx, int next_key_idx,
                        int flash_detected, int min_gop_show_frame_count,
                        int max_gop_show_frame_count, int frame_width,
                        int frame_height, const GF_GROUP_STATS &gf_stats) {
  (void)max_gop_show_frame_count;
  const int candidate_gop_size = candidate_cut_idx - start_idx;

  if (!flash_detected) {
    // Break clause to detect very still sections after motion. For example,
    // a static image after a fade or other transition.
    if (DetectTransitionToStill(stats_list, start_idx, min_gop_show_frame_count,
                                candidate_gop_size, 5, gf_stats.loop_decay_rate,
                                gf_stats.last_loop_decay_rate)) {
      return 1;
    }
    const double arf_abs_zoom_thresh = 4.4;
    // Motion breakout threshold for the loop below depends on image size.
    const double mv_ratio_accumulator_thresh =
        (frame_height + frame_width) / 4.0;
    // Some conditions to break out after the minimum interval.
    if (candidate_gop_size >= min_gop_show_frame_count &&
        // If possible don't break very close to a kf
        (next_key_idx - candidate_cut_idx >= min_gop_show_frame_count) &&
        (candidate_gop_size & 0x01) &&
        (gf_stats.mv_ratio_accumulator > mv_ratio_accumulator_thresh ||
         gf_stats.abs_mv_in_out_accumulator > arf_abs_zoom_thresh)) {
      return 1;
    }
  }

  // TODO(b/231489624): Check if we need this part.
  // If almost totally static, we will not use the max GF length later,
  // so we can continue for more frames.
  // if ((candidate_gop_size >= active_max_gf_interval + 1) &&
  //     !is_almost_static(gf_stats->zero_motion_accumulator,
  //                       twopass->kf_zeromotion_pct, cpi->ppi->lap_enabled)) {
  //   return 0;
  // }
  return 0;
}

/*!\brief Determine the length of future GF groups.
 *
 * \ingroup gf_group_algo
 * This function decides the gf group length of future frames in batch.
 *
 * \param[in]    rc_param          Rate control parameters
 * \param[in]    stats_list        List of first pass stats
 * \param[in]    regions_list      List of regions from av1_identify_regions
 * \param[in]    order_index       Index of current frame in stats_list
 * \param[in]    frames_since_key  Number of frames since the last key frame
 * \param[in]    frames_to_key     Number of frames to the next key frame
 *
 * \return Returns a vector of decided GF group lengths.
 */
static std::vector<int> PartitionGopIntervals(
    const RateControlParam &rc_param,
    const std::vector<FIRSTPASS_STATS> &stats_list,
    const std::vector<REGIONS> &regions_list, int order_index,
    int frames_since_key, int frames_to_key) {
  int i = 0;
  // If cpi->gf_state.arf_gf_boost_lst is 0, we are starting with a KF or GF.
  int cur_start = 0;
  // Each element is the last frame of the previous GOP. If there are n GOPs,
  // you need n + 1 cuts to find the durations. So cut_pos starts out with -1,
  // which is the last frame of the previous GOP.
  std::vector<int> cut_pos(1, -1);
  int cut_here = 0;
  GF_GROUP_STATS gf_stats;
  InitGFStats(&gf_stats);
  int num_stats = static_cast<int>(stats_list.size());

  while (i + order_index < num_stats) {
    // reaches next key frame, break here
    if (i >= frames_to_key - 1) {
      cut_here = 2;
    } else if (i - cur_start >= rc_param.max_gop_show_frame_count) {
      // reached maximum len, but nothing special yet (almost static)
      // let's look at the next interval
      cut_here = 2;
    } else {
      // Test for the case where there is a brief flash but the prediction
      // quality back to an earlier frame is then restored.
      const int gop_start_idx = cur_start + order_index;
      const int candidate_gop_cut_idx = i + order_index;
      const int next_key_idx = frames_to_key + order_index;
      const bool flash_detected =
          DetectFlash(stats_list, candidate_gop_cut_idx);

      // TODO(bohanli): remove redundant accumulations here, or unify
      // this and the ones in define_gf_group
      const FIRSTPASS_STATS *stats = &stats_list[candidate_gop_cut_idx];
      av1_accumulate_next_frame_stats(stats, flash_detected, frames_since_key,
                                      i, &gf_stats, rc_param.frame_width,
                                      rc_param.frame_height);

      // TODO(angiebird): Can we simplify this part? Looks like we are going to
      // change the gop cut index with FindBetterGopCut() anyway.
      cut_here = DetectGopCut(
          stats_list, gop_start_idx, candidate_gop_cut_idx, next_key_idx,
          flash_detected, rc_param.min_gop_show_frame_count,
          rc_param.max_gop_show_frame_count, rc_param.frame_width,
          rc_param.frame_height, gf_stats);
    }

    if (!cut_here) {
      ++i;
      continue;
    }

    // the current last frame in the gf group
    int original_last = cut_here > 1 ? i : i - 1;
    int cur_last = FindBetterGopCut(
        stats_list, regions_list, rc_param.min_gop_show_frame_count,
        rc_param.max_gop_show_frame_count, order_index, cur_start,
        original_last, frames_since_key);
    // only try shrinking if the interval is smaller than
    // active_max_gf_interval
    cut_pos.push_back(cur_last);

    // reset pointers to the shrunken location
    cur_start = cur_last;
    int cur_region_idx =
        FindRegionIndex(regions_list, cur_start + 1 + frames_since_key);
    if (cur_region_idx >= 0)
      if (regions_list[cur_region_idx].type == SCENECUT_REGION) cur_start++;

    // reset accumulators
    InitGFStats(&gf_stats);
    i = cur_last + 1;

    if (cut_here == 2 && i >= frames_to_key) break;
  }

  std::vector<int> gf_intervals;
  // save intervals
  for (size_t n = 1; n < cut_pos.size(); n++) {
    gf_intervals.push_back(cut_pos[n] - cut_pos[n - 1]);
  }

  return gf_intervals;
}

StatusOr<GopStructList> AV1RateControlQMode::DetermineGopInfo(
    const FirstpassInfo &firstpass_info) {
  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
  GopStructList gop_list;
  RefFrameManager ref_frame_manager(rc_param_.ref_frame_table_size,
                                    rc_param_.max_ref_frames);

  // Make a copy of the first pass stats, and analyze them.
  FirstpassInfo fp_info_copy = firstpass_info;
  av1_mark_flashes(fp_info_copy.stats_list.data(),
                   fp_info_copy.stats_list.data() + stats_size);
  av1_estimate_noise(fp_info_copy.stats_list.data(),
                     fp_info_copy.stats_list.data() + stats_size);
  av1_estimate_coeff(fp_info_copy.stats_list.data(),
                     fp_info_copy.stats_list.data() + stats_size);

  int global_coding_idx_offset = 0;
  int global_order_idx_offset = 0;
  std::vector<int> key_frame_list = GetKeyFrameList(fp_info_copy);
  key_frame_list.push_back(stats_size);  // a sentinel value
  for (size_t ki = 0; ki + 1 < key_frame_list.size(); ++ki) {
    int frames_to_key = key_frame_list[ki + 1] - key_frame_list[ki];
    int key_order_index = key_frame_list[ki];  // The key frame's display order

    std::vector<REGIONS> regions_list(MAX_FIRSTPASS_ANALYSIS_FRAMES);
    int total_regions = 0;
    av1_identify_regions(fp_info_copy.stats_list.data() + key_order_index,
                         frames_to_key, 0, regions_list.data(), &total_regions);
    regions_list.resize(total_regions);
    std::vector<int> gf_intervals = PartitionGopIntervals(
        rc_param_, fp_info_copy.stats_list, regions_list, key_order_index,
        /*frames_since_key=*/0, frames_to_key);
    for (size_t gi = 0; gi < gf_intervals.size(); ++gi) {
      const bool has_key_frame = gi == 0;
      const int show_frame_count = gf_intervals[gi];
      GopStruct gop =
          ConstructGop(&ref_frame_manager, show_frame_count, has_key_frame,
                       global_coding_idx_offset, global_order_idx_offset);
      assert(gop.show_frame_count == show_frame_count);
      global_coding_idx_offset += static_cast<int>(gop.gop_frame_list.size());
      global_order_idx_offset += gop.show_frame_count;
      gop_list.push_back(gop);
    }
  }
  return gop_list;
}

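// Allocates a TplFrameDepStats grid with one entry per min_block_size x
// min_block_size unit covering the frame; partial units at the right and
// bottom edges are rounded up to a full unit.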
TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
                                        int min_block_size) {
  const int unit_rows = (frame_height + min_block_size - 1) / min_block_size;
  const int unit_cols = (frame_width + min_block_size - 1) / min_block_size;
  TplFrameDepStats frame_dep_stats;
  frame_dep_stats.unit_size = min_block_size;
  frame_dep_stats.unit_stats.resize(unit_rows);
  for (auto &row : frame_dep_stats.unit_stats) {
    row.resize(unit_cols);
  }
  return frame_dep_stats;
}

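// Converts a TPL block's stats into per-unit dependency stats by spreading
// the block's intra and inter costs evenly over the unit_count units it
// covers.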
TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
                                        int unit_count) {
  TplUnitDepStats dep_stats = {};
  dep_stats.intra_cost = block_stats.intra_cost * 1.0 / unit_count;
  dep_stats.inter_cost = block_stats.inter_cost * 1.0 / unit_count;
  // In rare cases, inter_cost may be greater than intra_cost.
  // If so, we need to modify inter_cost such that inter_cost <= intra_cost
  // because it is required by GetPropagationFraction().
  dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
  dep_stats.mv = block_stats.mv;
  dep_stats.ref_frame_index = block_stats.ref_frame_index;
  return dep_stats;
}

namespace {
Status ValidateBlockStats(const TplFrameStats &frame_stats,
                          const TplBlockStats &block_stats,
                          int min_block_size) {
  if (block_stats.col >= frame_stats.frame_width ||
      block_stats.row >= frame_stats.frame_height) {
    std::ostringstream error_message;
    error_message << "Block position (" << block_stats.col << ", "
                  << block_stats.row
                  << ") is out of range; frame dimensions are "
                  << frame_stats.frame_width << " x "
                  << frame_stats.frame_height;
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  if (block_stats.col % min_block_size != 0 ||
      block_stats.row % min_block_size != 0 ||
      block_stats.width % min_block_size != 0 ||
      block_stats.height % min_block_size != 0) {
    std::ostringstream error_message;
    error_message
        << "Invalid block position or dimension, must be a multiple of "
        << min_block_size << "; col = " << block_stats.col
        << ", row = " << block_stats.row << ", width = " << block_stats.width
        << ", height = " << block_stats.height;
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  return { AOM_CODEC_OK, "" };
}

Status ValidateTplStats(const GopStruct &gop_struct,
                        const TplGopStats &tpl_gop_stats) {
  constexpr char kAdvice[] =
      "Do the current RateControlParam settings match those used to generate "
      "the TPL stats?";
  if (gop_struct.gop_frame_list.size() !=
      tpl_gop_stats.frame_stats_list.size()) {
    std::ostringstream error_message;
    error_message << "Frame count of GopStruct ("
                  << gop_struct.gop_frame_list.size()
                  << ") doesn't match frame count of TPL stats ("
                  << tpl_gop_stats.frame_stats_list.size() << "). " << kAdvice;
    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
  }
  for (int i = 0; i < static_cast<int>(gop_struct.gop_frame_list.size()); ++i) {
    const bool is_ref_frame = gop_struct.gop_frame_list[i].update_ref_idx >= 0;
    const bool has_tpl_stats =
        !tpl_gop_stats.frame_stats_list[i].block_stats_list.empty();
    if (is_ref_frame && !has_tpl_stats) {
      std::ostringstream error_message;
      error_message << "The frame with global_coding_idx "
                    << gop_struct.gop_frame_list[i].global_coding_idx
                    << " is a reference frame, but has no TPL stats. "
                    << kAdvice;
      return { AOM_CODEC_INVALID_PARAM, error_message.str() };
    }
  }
  return { AOM_CODEC_OK, "" };
}
}  // namespace

StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
    const TplFrameStats &frame_stats) {
  if (frame_stats.block_stats_list.empty()) {
    return TplFrameDepStats();
  }
  const int min_block_size = frame_stats.min_block_size;
  const int unit_rows =
      (frame_stats.frame_height + min_block_size - 1) / min_block_size;
  const int unit_cols =
      (frame_stats.frame_width + min_block_size - 1) / min_block_size;
  TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
      frame_stats.frame_height, frame_stats.frame_width, min_block_size);
  for (const TplBlockStats &block_stats : frame_stats.block_stats_list) {
    Status status =
        ValidateBlockStats(frame_stats, block_stats, min_block_size);
    if (!status.ok()) {
      return status;
    }
    const int block_unit_row = block_stats.row / min_block_size;
    const int block_unit_col = block_stats.col / min_block_size;
    // The block must start within the frame boundaries, but it may extend past
    // the right edge or bottom of the frame. Find the number of unit rows and
    // columns in the block which are fully within the frame.
    const int block_unit_rows = std::min(block_stats.height / min_block_size,
                                         unit_rows - block_unit_row);
    const int block_unit_cols = std::min(block_stats.width / min_block_size,
                                         unit_cols - block_unit_col);
    const int unit_count = block_unit_rows * block_unit_cols;
    TplUnitDepStats unit_stats =
        TplBlockStatsToDepStats(block_stats, unit_count);
    for (int r = 0; r < block_unit_rows; r++) {
      for (int c = 0; c < block_unit_cols; c++) {
        frame_dep_stats.unit_stats[block_unit_row + r][block_unit_col + c] =
            unit_stats;
      }
    }
  }

  frame_dep_stats.rdcost = TplFrameDepStatsAccumulateInterCost(frame_dep_stats);

  return frame_dep_stats;
}

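// Fills ref_coding_idx_list (of length kBlockRefCount) with the coding
// indices of the unit's reference frames, leaving -1 for unused slots, and
// returns the number of valid references found.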
int GetRefCodingIdxList(const TplUnitDepStats &unit_dep_stats,
                        const RefFrameTable &ref_frame_table,
                        int *ref_coding_idx_list) {
  int ref_frame_count = 0;
  for (int i = 0; i < kBlockRefCount; ++i) {
    ref_coding_idx_list[i] = -1;
    int ref_frame_index = unit_dep_stats.ref_frame_index[i];
    if (ref_frame_index != -1) {
      assert(ref_frame_index < static_cast<int>(ref_frame_table.size()));
      ref_coding_idx_list[i] = ref_frame_table[ref_frame_index].coding_idx;
      ref_frame_count++;
    }
  }
  return ref_frame_count;
}

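// Returns the overlap area between two size x size blocks whose top-left
// corners are at (r0, c0) and (r1, c1). For example, with size = 4, blocks at
// (0, 0) and (2, 2) overlap in a 2 x 2 region, so the result is 4.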
int GetBlockOverlapArea(int r0, int c0, int r1, int c1, int size) {
  const int r_low = std::max(r0, r1);
  const int r_high = std::min(r0 + size, r1 + size);
  const int c_low = std::max(c0, c1);
  const int c_high = std::min(c0 + size, c1 + size);
  if (r_high >= r_low && c_high >= c_low) {
    return (r_high - r_low) * (c_high - c_low);
  }
  return 0;
}

// TODO(angiebird): Merge TplFrameDepStatsAccumulateIntraCost and
// TplFrameDepStatsAccumulate.
double TplFrameDepStatsAccumulateIntraCost(
    const TplFrameDepStats &frame_dep_stats) {
  auto getIntraCost = [](double sum, const TplUnitDepStats &unit) {
    return sum + unit.intra_cost;
  };
  double sum = 0;
  for (const auto &row : frame_dep_stats.unit_stats) {
    sum = std::accumulate(row.begin(), row.end(), sum, getIntraCost);
  }
  return std::max(sum, 1.0);
}

double TplFrameDepStatsAccumulateInterCost(
    const TplFrameDepStats &frame_dep_stats) {
  auto getInterCost = [](double sum, const TplUnitDepStats &unit) {
    return sum + unit.inter_cost;
  };
  double sum = 0;
  for (const auto &row : frame_dep_stats.unit_stats) {
    sum = std::accumulate(row.begin(), row.end(), sum, getInterCost);
  }
  return std::max(sum, 1.0);
}

double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats) {
  auto getOverallCost = [](double sum, const TplUnitDepStats &unit) {
    return sum + unit.propagation_cost + unit.intra_cost;
  };
  double sum = 0;
  for (const auto &row : frame_dep_stats.unit_stats) {
    sum = std::accumulate(row.begin(), row.end(), sum, getOverallCost);
  }
  return std::max(sum, 1.0);
}

// This is a generalization of GET_MV_RAWPEL that allows for an arbitrary
// number of fractional bits.
// TODO(angiebird): Add unit test to this function
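// For example, with subpel_bits = 3 (1/8-pel precision),
// GetFullpelValue(9, 3) rounds to 1 and GetFullpelValue(-9, 3) rounds to -1.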
int GetFullpelValue(int subpel_value, int subpel_bits) {
  const int subpel_scale = (1 << subpel_bits);
  const int sign = subpel_value >= 0 ? 1 : -1;
  int fullpel_value = (abs(subpel_value) + subpel_scale / 2) >> subpel_bits;
  fullpel_value *= sign;
  return fullpel_value;
}

double GetPropagationFraction(const TplUnitDepStats &unit_dep_stats) {
  assert(unit_dep_stats.intra_cost >= unit_dep_stats.inter_cost);
  return (unit_dep_stats.intra_cost - unit_dep_stats.inter_cost) /
         ModifyDivisor(unit_dep_stats.intra_cost);
}

void TplFrameDepStatsPropagate(int coding_idx,
                               const RefFrameTable &ref_frame_table,
                               TplGopDepStats *tpl_gop_dep_stats) {
  assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
  TplFrameDepStats *frame_dep_stats =
      &tpl_gop_dep_stats->frame_dep_stats_list[coding_idx];

  if (frame_dep_stats->unit_stats.empty()) return;

  const int unit_size = frame_dep_stats->unit_size;
  const int frame_unit_rows =
      static_cast<int>(frame_dep_stats->unit_stats.size());
  const int frame_unit_cols =
      static_cast<int>(frame_dep_stats->unit_stats[0].size());
  for (int unit_row = 0; unit_row < frame_unit_rows; ++unit_row) {
    for (int unit_col = 0; unit_col < frame_unit_cols; ++unit_col) {
      TplUnitDepStats &unit_dep_stats =
          frame_dep_stats->unit_stats[unit_row][unit_col];
      int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
      int ref_frame_count = GetRefCodingIdxList(unit_dep_stats, ref_frame_table,
                                                ref_coding_idx_list);
      if (ref_frame_count == 0) continue;
      for (int i = 0; i < kBlockRefCount; ++i) {
        if (ref_coding_idx_list[i] == -1) continue;
        assert(
            ref_coding_idx_list[i] <
            static_cast<int>(tpl_gop_dep_stats->frame_dep_stats_list.size()));
        TplFrameDepStats &ref_frame_dep_stats =
            tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
        assert(!ref_frame_dep_stats.unit_stats.empty());
        const auto &mv = unit_dep_stats.mv[i];
        const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
        const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
        const int ref_pixel_r = unit_row * unit_size + mv_row;
        const int ref_pixel_c = unit_col * unit_size + mv_col;
        const int ref_unit_row_low =
            (unit_row * unit_size + mv_row) / unit_size;
        const int ref_unit_col_low =
            (unit_col * unit_size + mv_col) / unit_size;

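        // The motion-compensated unit generally straddles up to four units in
        // the reference frame, so spread the propagated cost over the 2x2
        // neighborhood anchored at (ref_unit_row_low, ref_unit_col_low),
        // weighting each unit by its pixel overlap with the prediction.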
        for (int j = 0; j < 2; ++j) {
          for (int k = 0; k < 2; ++k) {
            const int ref_unit_row = ref_unit_row_low + j;
            const int ref_unit_col = ref_unit_col_low + k;
            if (ref_unit_row >= 0 && ref_unit_row < frame_unit_rows &&
                ref_unit_col >= 0 && ref_unit_col < frame_unit_cols) {
              const int overlap_area = GetBlockOverlapArea(
                  ref_pixel_r, ref_pixel_c, ref_unit_row * unit_size,
                  ref_unit_col * unit_size, unit_size);
              const double overlap_ratio =
                  overlap_area * 1.0 / (unit_size * unit_size);
              const double propagation_fraction =
                  GetPropagationFraction(unit_dep_stats);
              const double propagation_ratio =
                  1.0 / ref_frame_count * overlap_ratio * propagation_fraction;
              TplUnitDepStats &ref_unit_stats =
                  ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
              ref_unit_stats.propagation_cost +=
                  (unit_dep_stats.intra_cost +
                   unit_dep_stats.propagation_cost) *
                  propagation_ratio;
            }
          }
        }
      }
    }
  }
}

std::vector<RefFrameTable> AV1RateControlQMode::GetRefFrameTableList(
    const GopStruct &gop_struct,
    const std::vector<LookaheadStats> &lookahead_stats,
    RefFrameTable ref_frame_table) {
  if (gop_struct.global_coding_idx_offset == 0) {
    // For the first GOP, ref_frame_table need not be initialized. This is
    // fine, because the first frame (a key frame) will fully initialize it.
    ref_frame_table.assign(rc_param_.ref_frame_table_size, GopFrameInvalid());
  } else {
    // It's not the first GOP, so ref_frame_table must be valid.
    assert(static_cast<int>(ref_frame_table.size()) ==
           rc_param_.ref_frame_table_size);
    assert(std::all_of(ref_frame_table.begin(), ref_frame_table.end(),
                       std::mem_fn(&GopFrame::is_valid)));
    // Reset the frame processing order of the initial ref_frame_table.
    for (GopFrame &gop_frame : ref_frame_table) gop_frame.coding_idx = -1;
  }

  std::vector<RefFrameTable> ref_frame_table_list;
  ref_frame_table_list.push_back(ref_frame_table);
  for (const GopFrame &gop_frame : gop_struct.gop_frame_list) {
    if (gop_frame.is_key_frame) {
      ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
    } else if (gop_frame.update_ref_idx != -1) {
      assert(gop_frame.update_ref_idx <
             static_cast<int>(ref_frame_table.size()));
      ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
    }
    ref_frame_table_list.push_back(ref_frame_table);
  }

  int gop_size_offset = static_cast<int>(gop_struct.gop_frame_list.size());

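  // Append the updates from the lookahead GOPs, shifting their coding indices
  // by gop_size_offset so they continue after the frames of the current GOP
  // (and any lookahead GOPs already processed).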
  for (const auto &lookahead_stat : lookahead_stats) {
    for (GopFrame gop_frame : lookahead_stat.gop_struct->gop_frame_list) {
      if (gop_frame.is_key_frame) {
        ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
      } else if (gop_frame.update_ref_idx != -1) {
        assert(gop_frame.update_ref_idx <
               static_cast<int>(ref_frame_table.size()));
        gop_frame.coding_idx += gop_size_offset;
        ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
      }
      ref_frame_table_list.push_back(ref_frame_table);
    }
    gop_size_offset +=
        static_cast<int>(lookahead_stat.gop_struct->gop_frame_list.size());
  }

  return ref_frame_table_list;
}

StatusOr<TplGopDepStats> ComputeTplGopDepStats(
    const TplGopStats &tpl_gop_stats,
    const std::vector<LookaheadStats> &lookahead_stats,
    const std::vector<RefFrameTable> &ref_frame_table_list) {
  std::vector<const TplFrameStats *> tpl_frame_stats_list_with_lookahead;
  for (const auto &tpl_frame_stats : tpl_gop_stats.frame_stats_list) {
    tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
  }
  for (const auto &lookahead_stat : lookahead_stats) {
    for (const auto &tpl_frame_stats :
         lookahead_stat.tpl_gop_stats->frame_stats_list) {
      tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
    }
  }

  const int frame_count =
      static_cast<int>(tpl_frame_stats_list_with_lookahead.size());

  // Create the struct to store TPL dependency stats
  TplGopDepStats tpl_gop_dep_stats;

  tpl_gop_dep_stats.frame_dep_stats_list.reserve(frame_count);
  for (int coding_idx = 0; coding_idx < frame_count; coding_idx++) {
    const StatusOr<TplFrameDepStats> tpl_frame_dep_stats =
        CreateTplFrameDepStatsWithoutPropagation(
            *tpl_frame_stats_list_with_lookahead[coding_idx]);
    if (!tpl_frame_dep_stats.ok()) {
      return tpl_frame_dep_stats.status();
    }
    tpl_gop_dep_stats.frame_dep_stats_list.push_back(
        std::move(*tpl_frame_dep_stats));
  }

  // Back propagation
  for (int coding_idx = frame_count - 1; coding_idx >= 0; coding_idx--) {
    auto &ref_frame_table = ref_frame_table_list[coding_idx];
    // TODO(angiebird): Handle/test the case where the reference frame
    // is in the previous GOP.
    TplFrameDepStatsPropagate(coding_idx, ref_frame_table, &tpl_gop_dep_stats);
  }
  return tpl_gop_dep_stats;
}

static std::vector<uint8_t> SetupDeltaQ(const TplFrameDepStats &frame_dep_stats,
                                        int frame_width, int frame_height,
                                        int base_qindex,
                                        double frame_importance) {
  // TODO(jianj): Add support for various superblock sizes.
  const int sb_size = 64;
  const int delta_q_res = 4;
  const int num_unit_per_sb = sb_size / frame_dep_stats.unit_size;
  const int sb_rows = (frame_height + sb_size - 1) / sb_size;
  const int sb_cols = (frame_width + sb_size - 1) / sb_size;
  const int unit_rows = (frame_height + frame_dep_stats.unit_size - 1) /
                        frame_dep_stats.unit_size;
  const int unit_cols =
      (frame_width + frame_dep_stats.unit_size - 1) / frame_dep_stats.unit_size;
  std::vector<uint8_t> superblock_q_indices;
  // Calculate the delta_q offset for each superblock.
  for (int sb_row = 0; sb_row < sb_rows; ++sb_row) {
    for (int sb_col = 0; sb_col < sb_cols; ++sb_col) {
      double intra_cost = 0;
      double mc_dep_cost = 0;
      const int unit_row_start = sb_row * num_unit_per_sb;
      const int unit_row_end =
          std::min((sb_row + 1) * num_unit_per_sb, unit_rows);
      const int unit_col_start = sb_col * num_unit_per_sb;
      const int unit_col_end =
          std::min((sb_col + 1) * num_unit_per_sb, unit_cols);
      // A simplified version of av1_get_q_for_deltaq_objective()
      for (int unit_row = unit_row_start; unit_row < unit_row_end; ++unit_row) {
        for (int unit_col = unit_col_start; unit_col < unit_col_end;
             ++unit_col) {
          const TplUnitDepStats &unit_dep_stat =
              frame_dep_stats.unit_stats[unit_row][unit_col];
          intra_cost += unit_dep_stat.intra_cost;
          mc_dep_cost += unit_dep_stat.propagation_cost;
        }
      }

1309 double beta = 1.0;
1310 if (mc_dep_cost > 0 && intra_cost > 0) {
1311 const double r0 = 1 / frame_importance;
1312 const double rk = intra_cost / mc_dep_cost;
1313 beta = r0 / rk;
1314 assert(beta > 0.0);
1315 }
1316 int offset = av1_get_deltaq_offset(AOM_BITS_8, base_qindex, beta);
1317 offset = std::min(offset, delta_q_res * 9 - 1);
1318 offset = std::max(offset, -delta_q_res * 9 + 1);
1319 int qindex = offset + base_qindex;
1320 qindex = std::min(qindex, MAXQ);
1321 qindex = std::max(qindex, MINQ);
1322 qindex = av1_adjust_q_from_delta_q_res(delta_q_res, base_qindex, qindex);
1323 superblock_q_indices.push_back(static_cast<uint8_t>(qindex));
1324 }
1325 }
1326
1327 return superblock_q_indices;
1328 }
1329
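// Assigns each superblock q index to its nearest centroid (by absolute
// distance) and returns the resulting qindex -> centroid map.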
static std::unordered_map<int, double> FindKMeansClusterMap(
    const std::vector<uint8_t> &qindices,
    const std::vector<double> &centroids) {
  std::unordered_map<int, double> cluster_map;
  for (const uint8_t qindex : qindices) {
    double nearest_centroid = *std::min_element(
        centroids.begin(), centroids.end(),
        [qindex](const double centroid_a, const double centroid_b) {
          return fabs(centroid_a - qindex) < fabs(centroid_b - qindex);
        });
    cluster_map.insert({ qindex, nearest_centroid });
  }
  return cluster_map;
}

namespace internal {

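// A simple Lloyd-style k-means over the 1-D set of superblock q indices. The
// centroids are seeded with the first k distinct q indices and refined until
// they stop moving. The returned map sends every distinct q index to the
// (truncated) centroid of its cluster.
// For example, KMeans({10, 12, 40, 42}, 2) converges to centroids 11 and 41,
// so 10 and 12 map to 11 while 40 and 42 map to 41.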
std::unordered_map<int, int> KMeans(std::vector<uint8_t> qindices, int k) {
  std::vector<double> centroids;
  // Initialize the centroids with the first k distinct qindices.
  std::unordered_set<int> qindices_set;

  for (const uint8_t qp : qindices) {
    if (!qindices_set.insert(qp).second) continue;  // Already added.
    centroids.push_back(qp);
    if (static_cast<int>(centroids.size()) >= k) break;
  }

  std::unordered_map<int, double> intermediate_cluster_map;
  while (true) {
    // Find the closest centroid for each qindex.
    intermediate_cluster_map = FindKMeansClusterMap(qindices, centroids);
    // For each cluster, calculate the new centroid.
    std::unordered_map<double, std::vector<int>> centroid_to_qindices;
    for (const auto &qindex_centroid : intermediate_cluster_map) {
      centroid_to_qindices[qindex_centroid.second].push_back(
          qindex_centroid.first);
    }
    bool centroids_changed = false;
    std::vector<double> new_centroids;
    for (const auto &cluster : centroid_to_qindices) {
      double sum = 0.0;
      for (const int qindex : cluster.second) {
        sum += qindex;
      }
      double new_centroid = sum / cluster.second.size();
      new_centroids.push_back(new_centroid);
      if (new_centroid != cluster.first) centroids_changed = true;
    }
    if (!centroids_changed) break;
    centroids = new_centroids;
  }
  std::unordered_map<int, int> cluster_map;
  for (const auto &qindex_centroid : intermediate_cluster_map) {
    cluster_map.insert(
        { qindex_centroid.first, static_cast<int>(qindex_centroid.second) });
  }
  return cluster_map;
}
}  // namespace internal

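// Maps a frame's role (golden/ARF, key, or other) to the update type used by
// av1_compute_rd_mult_based_on_qindex(), assuming 8-bit depth for now.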
static int GetRDMult(const GopFrame &gop_frame, int q_index) {
  // TODO(angiebird):
  // 1) Check if these rdmult rules are good in our use case.
  // 2) Support high-bit-depth mode.
  if (gop_frame.is_golden_frame) {
    // Assume ARF_UPDATE/GF_UPDATE share the same rdmult rule.
    return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, GF_UPDATE, q_index);
  } else if (gop_frame.is_key_frame) {
    return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, KF_UPDATE, q_index);
  } else {
    // Assume LF_UPDATE/OVERLAY_UPDATE/INTNL_OVERLAY_UPDATE/INTNL_ARF_UPDATE
    // share the same rdmult rule.
    return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE, q_index);
  }
}

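// Builds encode parameters when no TPL or first pass stats are available,
// e.g. for the stats-gathering TPL pass itself: every frame gets the base q
// index, except that when rc_param_.tpl_pass_index is nonzero, key/golden/ARF
// frames are boosted with a fixed qstep ratio of 1/3.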
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithNoStats(
    const GopStruct &gop_struct) {
  GopEncodeInfo gop_encode_info;
  const int frame_count = static_cast<int>(gop_struct.gop_frame_list.size());
  for (int i = 0; i < frame_count; i++) {
    FrameEncodeParameters param;
    const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
    // Use constant QP for TPL pass encoding. Keep the functionality
    // that allows QP changes across sub-gop.
    param.q_index = rc_param_.base_q_index;
    param.rdmult = av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE,
                                                       rc_param_.base_q_index);
    // TODO(jingning): gop_frame is needed in two pass tpl later.
    (void)gop_frame;

    if (rc_param_.tpl_pass_index) {
      if (gop_frame.update_type == GopFrameType::kRegularGolden ||
          gop_frame.update_type == GopFrameType::kRegularKey ||
          gop_frame.update_type == GopFrameType::kRegularArf) {
        double qstep_ratio = 1 / 3.0;
        param.q_index = av1_get_q_index_from_qstep_ratio(
            rc_param_.base_q_index, qstep_ratio, AOM_BITS_8);
        if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
      }
    }
    gop_encode_info.param_list.push_back(param);
  }
  return gop_encode_info;
}

StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithFp(
    const GopStruct &gop_struct,
    const FirstpassInfo &firstpass_info AOM_UNUSED) {
  // TODO(b/260859962): This is currently a placeholder. Should use the fp
  // stats to calculate frame-level qp.
  return GetGopEncodeInfoWithNoStats(gop_struct);
}

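// Assigns per-frame q indices from TPL statistics. The frame importance is
// the ratio of the propagated cost to the pure intra cost; key/golden/ARF
// frames get a q index derived from qstep_ratio = sqrt(1 / importance),
// overlay and leaf frames keep the base q index, and intermediate ARFs are
// interpolated between the active worst and best quality by layer depth.
// When rc_param_.max_distinct_q_indices_per_frame > 1, key/golden/ARF frames
// also receive per-superblock q indices from SetupDeltaQ, reduced via KMeans.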
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithTpl(
    const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
    const std::vector<LookaheadStats> &lookahead_stats,
    const RefFrameTable &ref_frame_table_snapshot_init) {
  const std::vector<RefFrameTable> ref_frame_table_list = GetRefFrameTableList(
      gop_struct, lookahead_stats, ref_frame_table_snapshot_init);

  GopEncodeInfo gop_encode_info;
  gop_encode_info.final_snapshot = ref_frame_table_list.back();
  StatusOr<TplGopDepStats> gop_dep_stats = ComputeTplGopDepStats(
      tpl_gop_stats, lookahead_stats, ref_frame_table_list);
  if (!gop_dep_stats.ok()) {
    return gop_dep_stats.status();
  }
  const int frame_count =
      static_cast<int>(tpl_gop_stats.frame_stats_list.size());
  const int active_worst_quality = rc_param_.base_q_index;
  int active_best_quality = rc_param_.base_q_index;
  for (int i = 0; i < frame_count; i++) {
    FrameEncodeParameters param;
    const GopFrame &gop_frame = gop_struct.gop_frame_list[i];

    if (gop_frame.update_type == GopFrameType::kOverlay ||
        gop_frame.update_type == GopFrameType::kIntermediateOverlay ||
        gop_frame.update_type == GopFrameType::kRegularLeaf) {
      param.q_index = rc_param_.base_q_index;
    } else if (gop_frame.update_type == GopFrameType::kRegularGolden ||
               gop_frame.update_type == GopFrameType::kRegularKey ||
               gop_frame.update_type == GopFrameType::kRegularArf) {
      const TplFrameDepStats &frame_dep_stats =
          gop_dep_stats->frame_dep_stats_list[i];
      const double cost_without_propagation =
          TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
      const double cost_with_propagation =
          TplFrameDepStatsAccumulate(frame_dep_stats);
      const double frame_importance =
          cost_with_propagation / cost_without_propagation;
      // Imitate the behavior of av1_tpl_get_qstep_ratio().
      const double qstep_ratio = sqrt(1 / frame_importance);
      param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
                                                       qstep_ratio, AOM_BITS_8);
      if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
      active_best_quality = param.q_index;

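      // If more than one distinct q index is allowed per frame, compute a
      // per-superblock q map, cluster it down to at most
      // max_distinct_q_indices_per_frame values, and truncate each
      // superblock's offset from the frame q index to a multiple of
      // delta_q_res.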
      if (rc_param_.max_distinct_q_indices_per_frame > 1) {
        std::vector<uint8_t> superblock_q_indices = SetupDeltaQ(
            frame_dep_stats, rc_param_.frame_width, rc_param_.frame_height,
            param.q_index, frame_importance);
        std::unordered_map<int, int> qindex_centroids = internal::KMeans(
            superblock_q_indices, rc_param_.max_distinct_q_indices_per_frame);
        for (size_t i = 0; i < superblock_q_indices.size(); ++i) {
          const int curr_sb_qindex =
              qindex_centroids.find(superblock_q_indices[i])->second;
          const int delta_q_res = 4;
          const int adjusted_qindex =
              param.q_index +
              (curr_sb_qindex - param.q_index) / delta_q_res * delta_q_res;
          const int rd_mult = GetRDMult(gop_frame, adjusted_qindex);
          param.superblock_encode_params.push_back(
              { static_cast<uint8_t>(adjusted_qindex), rd_mult });
        }
      }
    } else {
      // Intermediate ARFs.
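      // The q index is interpolated between the active worst quality (the
      // base q index) and the active best quality (the q index of the most
      // recent key/golden/ARF frame), weighted by layer depth. For example,
      // with active_worst_quality = 128, active_best_quality = 60, and
      // layer_depth = 3, depth_factor = 4 and q_index = (128 * 3 + 60) / 4 =
      // 111.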
      assert(gop_frame.layer_depth >= 1);
      const int depth_factor = 1 << (gop_frame.layer_depth - 1);
      param.q_index =
          (active_worst_quality * (depth_factor - 1) + active_best_quality) /
          depth_factor;
    }
    param.rdmult = GetRDMult(gop_frame, param.q_index);
    gop_encode_info.param_list.push_back(param);
  }
  return gop_encode_info;
}

StatusOr<GopEncodeInfo> AV1RateControlQMode::GetTplPassGopEncodeInfo(
    const GopStruct &gop_struct, const FirstpassInfo &firstpass_info) {
  return GetGopEncodeInfoWithFp(gop_struct, firstpass_info);
}

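// Entry point for final-pass GOP parameter decisions: validates the TPL stats
// for this GOP and for every lookahead GOP, then delegates to
// GetGopEncodeInfoWithTpl().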
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfo(
    const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
    const std::vector<LookaheadStats> &lookahead_stats,
    const FirstpassInfo &firstpass_info AOM_UNUSED,
    const RefFrameTable &ref_frame_table_snapshot_init) {
  // TPL stats are required here; if they fail validation, return the error
  // status to the caller.
  Status status = ValidateTplStats(gop_struct, tpl_gop_stats);
  if (!status.ok()) {
    return status;
  }

  for (const auto &lookahead_stat : lookahead_stats) {
    Status status = ValidateTplStats(*lookahead_stat.gop_struct,
                                     *lookahead_stat.tpl_gop_stats);
    if (!status.ok()) {
      return status;
    }
  }

  // TODO(b/260859962): Currently firstpass stats are used as an alternative,
  // but we could also combine them with the TPL results in the future for more
  // stable qp determination.
  return GetGopEncodeInfoWithTpl(gop_struct, tpl_gop_stats, lookahead_stats,
                                 ref_frame_table_snapshot_init);
}

}  // namespace aom