• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 #include "av1/qmode_rc/ratectrl_qmode.h"
12 
13 #include <algorithm>
14 #include <cassert>
15 #include <climits>
16 #include <functional>
17 #include <numeric>
18 #include <sstream>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "aom/aom_codec.h"
24 #include "av1/encoder/pass2_strategy.h"
25 #include "av1/encoder/tpl_model.h"
26 
27 namespace aom {
28 
// Returns a value that is safe to divide by: |divisor| itself when its
// magnitude is at least a small epsilon, otherwise the epsilon carrying the
// sign of |divisor| (non-negative inputs map to +epsilon).
static double ModifyDivisor(double divisor) {
  constexpr double kEpsilon = 0.0000001;
  if (divisor < 0) {
    return std::min(divisor, -kEpsilon);
  }
  return std::max(divisor, kEpsilon);
}
36 
GopFrameInvalid()37 GopFrame GopFrameInvalid() {
38   GopFrame gop_frame = {};
39   gop_frame.is_valid = false;
40   gop_frame.coding_idx = -1;
41   gop_frame.order_idx = -1;
42   return gop_frame;
43 }
44 
SetGopFrameByType(GopFrameType gop_frame_type,GopFrame * gop_frame)45 void SetGopFrameByType(GopFrameType gop_frame_type, GopFrame *gop_frame) {
46   gop_frame->update_type = gop_frame_type;
47   switch (gop_frame_type) {
48     case GopFrameType::kRegularKey:
49       gop_frame->is_key_frame = 1;
50       gop_frame->is_arf_frame = 0;
51       gop_frame->is_show_frame = 1;
52       gop_frame->is_golden_frame = 1;
53       gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
54       break;
55     case GopFrameType::kRegularGolden:
56       gop_frame->is_key_frame = 0;
57       gop_frame->is_arf_frame = 0;
58       gop_frame->is_show_frame = 1;
59       gop_frame->is_golden_frame = 1;
60       gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
61       break;
62     case GopFrameType::kRegularArf:
63       gop_frame->is_key_frame = 0;
64       gop_frame->is_arf_frame = 1;
65       gop_frame->is_show_frame = 0;
66       gop_frame->is_golden_frame = 1;
67       gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
68       break;
69     case GopFrameType::kIntermediateArf:
70       gop_frame->is_key_frame = 0;
71       gop_frame->is_arf_frame = 1;
72       gop_frame->is_show_frame = 0;
73       gop_frame->is_golden_frame = gop_frame->layer_depth <= 2 ? 1 : 0;
74       gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
75       break;
76     case GopFrameType::kRegularLeaf:
77       gop_frame->is_key_frame = 0;
78       gop_frame->is_arf_frame = 0;
79       gop_frame->is_show_frame = 1;
80       gop_frame->is_golden_frame = 0;
81       gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
82       break;
83     case GopFrameType::kIntermediateOverlay:
84       gop_frame->is_key_frame = 0;
85       gop_frame->is_arf_frame = 0;
86       gop_frame->is_show_frame = 1;
87       gop_frame->is_golden_frame = 0;
88       gop_frame->encode_ref_mode = EncodeRefMode::kShowExisting;
89       break;
90     case GopFrameType::kOverlay:
91       gop_frame->is_key_frame = 0;
92       gop_frame->is_arf_frame = 0;
93       gop_frame->is_show_frame = 1;
94       gop_frame->is_golden_frame = 0;
95       gop_frame->encode_ref_mode = EncodeRefMode::kOverlay;
96       break;
97   }
98 }
99 
GopFrameBasic(int global_coding_idx_offset,int global_order_idx_offset,int coding_idx,int order_idx,int depth,int display_idx,GopFrameType gop_frame_type)100 GopFrame GopFrameBasic(int global_coding_idx_offset,
101                        int global_order_idx_offset, int coding_idx,
102                        int order_idx, int depth, int display_idx,
103                        GopFrameType gop_frame_type) {
104   GopFrame gop_frame = {};
105   gop_frame.is_valid = true;
106   gop_frame.coding_idx = coding_idx;
107   gop_frame.order_idx = order_idx;
108   gop_frame.display_idx = display_idx;
109   gop_frame.global_coding_idx = global_coding_idx_offset + coding_idx;
110   gop_frame.global_order_idx = global_order_idx_offset + order_idx;
111   gop_frame.layer_depth = depth + kLayerDepthOffset;
112   gop_frame.colocated_ref_idx = -1;
113   gop_frame.update_ref_idx = -1;
114   SetGopFrameByType(gop_frame_type, &gop_frame);
115   return gop_frame;
116 }
117 
118 // This function create gop frames with indices of display order from
119 // order_start to order_end - 1. The function will recursively introduce
120 // intermediate ARF untill maximum depth is met or the number of regular frames
121 // in between two ARFs are less than 3. Than the regular frames will be added
122 // into the gop_struct.
ConstructGopMultiLayer(GopStruct * gop_struct,RefFrameManager * ref_frame_manager,int max_depth,int depth,int order_start,int order_end)123 void ConstructGopMultiLayer(GopStruct *gop_struct,
124                             RefFrameManager *ref_frame_manager, int max_depth,
125                             int depth, int order_start, int order_end) {
126   GopFrame gop_frame;
127   int num_frames = order_end - order_start;
128   const int global_coding_idx_offset = gop_struct->global_coding_idx_offset;
129   const int global_order_idx_offset = gop_struct->global_order_idx_offset;
130   // If there are less than kMinIntervalToAddArf frames, stop introducing ARF
131   if (depth < max_depth && num_frames >= kMinIntervalToAddArf) {
132     int order_mid = (order_start + order_end) / 2;
133     // intermediate ARF
134     gop_frame = GopFrameBasic(
135         global_coding_idx_offset, global_order_idx_offset,
136         static_cast<int>(gop_struct->gop_frame_list.size()), order_mid, depth,
137         gop_struct->display_tracker, GopFrameType::kIntermediateArf);
138     ref_frame_manager->UpdateRefFrameTable(&gop_frame);
139     gop_struct->gop_frame_list.push_back(gop_frame);
140     ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
141                            order_start, order_mid);
142     // show existing intermediate ARF
143     gop_frame =
144         GopFrameBasic(global_coding_idx_offset, global_order_idx_offset,
145                       static_cast<int>(gop_struct->gop_frame_list.size()),
146                       order_mid, max_depth, gop_struct->display_tracker,
147                       GopFrameType::kIntermediateOverlay);
148     ref_frame_manager->UpdateRefFrameTable(&gop_frame);
149     gop_struct->gop_frame_list.push_back(gop_frame);
150     ++gop_struct->display_tracker;
151     ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
152                            order_mid + 1, order_end);
153   } else {
154     // regular frame
155     for (int i = order_start; i < order_end; ++i) {
156       gop_frame = GopFrameBasic(
157           global_coding_idx_offset, global_order_idx_offset,
158           static_cast<int>(gop_struct->gop_frame_list.size()), i, max_depth,
159           gop_struct->display_tracker, GopFrameType::kRegularLeaf);
160       ref_frame_manager->UpdateRefFrameTable(&gop_frame);
161       gop_struct->gop_frame_list.push_back(gop_frame);
162       ++gop_struct->display_tracker;
163     }
164   }
165 }
166 
ConstructGop(RefFrameManager * ref_frame_manager,int show_frame_count,bool has_key_frame,int global_coding_idx_offset,int global_order_idx_offset)167 GopStruct ConstructGop(RefFrameManager *ref_frame_manager, int show_frame_count,
168                        bool has_key_frame, int global_coding_idx_offset,
169                        int global_order_idx_offset) {
170   GopStruct gop_struct;
171   gop_struct.show_frame_count = show_frame_count;
172   gop_struct.global_coding_idx_offset = global_coding_idx_offset;
173   gop_struct.global_order_idx_offset = global_order_idx_offset;
174   int order_start = 0;
175   int order_end = show_frame_count - 1;
176 
177   // TODO(jingning): Re-enable the use of pyramid coding structure.
178   bool has_arf_frame = show_frame_count > kMinIntervalToAddArf;
179 
180   gop_struct.display_tracker = 0;
181 
182   GopFrame gop_frame;
183   if (has_key_frame) {
184     const int key_frame_depth = -1;
185     ref_frame_manager->Reset();
186     gop_frame = GopFrameBasic(
187         global_coding_idx_offset, global_order_idx_offset,
188         static_cast<int>(gop_struct.gop_frame_list.size()), order_start,
189         key_frame_depth, gop_struct.display_tracker, GopFrameType::kRegularKey);
190     ref_frame_manager->UpdateRefFrameTable(&gop_frame);
191     gop_struct.gop_frame_list.push_back(gop_frame);
192     order_start++;
193     ++gop_struct.display_tracker;
194   }
195 
196   const int arf_depth = 0;
197   if (has_arf_frame) {
198     // Use multi-layer pyrmaid coding structure.
199     gop_frame = GopFrameBasic(
200         global_coding_idx_offset, global_order_idx_offset,
201         static_cast<int>(gop_struct.gop_frame_list.size()), order_end,
202         arf_depth, gop_struct.display_tracker, GopFrameType::kRegularArf);
203     ref_frame_manager->UpdateRefFrameTable(&gop_frame);
204     gop_struct.gop_frame_list.push_back(gop_frame);
205     ConstructGopMultiLayer(&gop_struct, ref_frame_manager,
206                            ref_frame_manager->MaxRefFrame() - 1, arf_depth + 1,
207                            order_start, order_end);
208     // Overlay
209     gop_frame =
210         GopFrameBasic(global_coding_idx_offset, global_order_idx_offset,
211                       static_cast<int>(gop_struct.gop_frame_list.size()),
212                       order_end, ref_frame_manager->MaxRefFrame() - 1,
213                       gop_struct.display_tracker, GopFrameType::kOverlay);
214     ref_frame_manager->UpdateRefFrameTable(&gop_frame);
215     gop_struct.gop_frame_list.push_back(gop_frame);
216     ++gop_struct.display_tracker;
217   } else {
218     // Use IPPP format.
219     for (int i = order_start; i <= order_end; ++i) {
220       gop_frame = GopFrameBasic(
221           global_coding_idx_offset, global_order_idx_offset,
222           static_cast<int>(gop_struct.gop_frame_list.size()), i, arf_depth + 1,
223           gop_struct.display_tracker, GopFrameType::kRegularLeaf);
224       ref_frame_manager->UpdateRefFrameTable(&gop_frame);
225       gop_struct.gop_frame_list.push_back(gop_frame);
226       ++gop_struct.display_tracker;
227     }
228   }
229 
230   return gop_struct;
231 }
232 
SetRcParam(const RateControlParam & rc_param)233 Status AV1RateControlQMode::SetRcParam(const RateControlParam &rc_param) {
234   std::ostringstream error_message;
235   if (rc_param.max_gop_show_frame_count <
236       std::max(4, rc_param.min_gop_show_frame_count)) {
237     error_message << "max_gop_show_frame_count ("
238                   << rc_param.max_gop_show_frame_count
239                   << ") must be at least 4 and may not be less than "
240                      "min_gop_show_frame_count ("
241                   << rc_param.min_gop_show_frame_count << ")";
242     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
243   }
244   if (rc_param.ref_frame_table_size < 1 || rc_param.ref_frame_table_size > 8) {
245     error_message << "ref_frame_table_size (" << rc_param.ref_frame_table_size
246                   << ") must be in the range [1, 8].";
247     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
248   }
249   if (rc_param.max_ref_frames < 1 || rc_param.max_ref_frames > 7) {
250     error_message << "max_ref_frames (" << rc_param.max_ref_frames
251                   << ") must be in the range [1, 7].";
252     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
253   }
254   if (rc_param.base_q_index < 0 || rc_param.base_q_index > 255) {
255     error_message << "base_q_index (" << rc_param.base_q_index
256                   << ") must be in the range [0, 255].";
257     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
258   }
259   if (rc_param.frame_width < 16 || rc_param.frame_width > 16384 ||
260       rc_param.frame_height < 16 || rc_param.frame_height > 16384) {
261     error_message << "frame_width (" << rc_param.frame_width
262                   << ") and frame_height (" << rc_param.frame_height
263                   << ") must be in the range [16, 16384].";
264     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
265   }
266   rc_param_ = rc_param;
267   return { AOM_CODEC_OK, "" };
268 }
269 
// Threshold for use of the lagging second reference frame. High second ref
// usage may point to a transient event like a flash or occlusion rather than
// a real scene cut.
// The threshold grows linearly with |frame_count_so_far| (the number of frames
// in this key-frame group so far) from 0.085 and is capped at 0.085 + 0.035
// once the count reaches 32.
static double GetSecondRefUsageThreshold(int frame_count_so_far) {
  constexpr int kAdaptUpto = 32;
  constexpr double kMinThreshold = 0.085;
  constexpr double kMaxDelta = 0.035;

  if (frame_count_so_far >= kAdaptUpto) {
    return kMinThreshold + kMaxDelta;
  }
  const double progress =
      static_cast<double>(frame_count_so_far) / (kAdaptUpto - 1);
  return kMinThreshold + progress * kMaxDelta;
}
286 
287 // Slide show transition detection.
288 // Tests for case where there is very low error either side of the current frame
289 // but much higher just for this frame. This can help detect key frames in
290 // slide shows even where the slides are pictures of different sizes.
291 // Also requires that intra and inter errors are very similar to help eliminate
292 // harmful false positives.
293 // It will not help if the transition is a fade or other multi-frame effect.
DetectSlideTransition(const FIRSTPASS_STATS & this_frame,const FIRSTPASS_STATS & last_frame,const FIRSTPASS_STATS & next_frame)294 static bool DetectSlideTransition(const FIRSTPASS_STATS &this_frame,
295                                   const FIRSTPASS_STATS &last_frame,
296                                   const FIRSTPASS_STATS &next_frame) {
297   // Intra / Inter threshold very low
298   constexpr double kVeryLowII = 1.5;
299   // Clean slide transitions we expect a sharp single frame spike in error.
300   constexpr double kErrorSpike = 5.0;
301 
302   // TODO(angiebird): Understand the meaning of these conditions.
303   return (this_frame.intra_error < (this_frame.coded_error * kVeryLowII)) &&
304          (this_frame.coded_error > (last_frame.coded_error * kErrorSpike)) &&
305          (this_frame.coded_error > (next_frame.coded_error * kErrorSpike));
306 }
307 
308 // Check if there is a significant intra/inter error change between the current
309 // frame and its neighbor. If so, we should further test whether the current
310 // frame should be a key frame.
DetectIntraInterErrorChange(const FIRSTPASS_STATS & this_stats,const FIRSTPASS_STATS & last_stats,const FIRSTPASS_STATS & next_stats)311 static bool DetectIntraInterErrorChange(const FIRSTPASS_STATS &this_stats,
312                                         const FIRSTPASS_STATS &last_stats,
313                                         const FIRSTPASS_STATS &next_stats) {
314   // Minimum % intra coding observed in first pass (1.0 = 100%)
315   constexpr double kMinIntraLevel = 0.25;
316   // Minimum ratio between the % of intra coding and inter coding in the first
317   // pass after discounting neutral blocks (discounting neutral blocks in this
318   // way helps catch scene cuts in clips with very flat areas or letter box
319   // format clips with image padding.
320   constexpr double kIntraVsInterRatio = 2.0;
321 
322   const double modified_pcnt_inter =
323       this_stats.pcnt_inter - this_stats.pcnt_neutral;
324   const double pcnt_intra_min =
325       std::max(kMinIntraLevel, kIntraVsInterRatio * modified_pcnt_inter);
326 
327   // In real scene cuts there is almost always a sharp change in the intra
328   // or inter error score.
329   constexpr double kErrorChangeThreshold = 0.4;
330   const double last_this_error_ratio =
331       fabs(last_stats.coded_error - this_stats.coded_error) /
332       ModifyDivisor(this_stats.coded_error);
333 
334   const double this_next_error_ratio =
335       fabs(last_stats.intra_error - this_stats.intra_error) /
336       ModifyDivisor(this_stats.intra_error);
337 
338   // Maximum threshold for the relative ratio of intra error score vs best
339   // inter error score.
340   constexpr double kThisIntraCodedErrorRatioMax = 1.9;
341   const double this_intra_coded_error_ratio =
342       this_stats.intra_error / ModifyDivisor(this_stats.coded_error);
343 
344   // For real scene cuts we expect an improvment in the intra inter error
345   // ratio in the next frame.
346   constexpr double kNextIntraCodedErrorRatioMin = 3.5;
347   const double next_intra_coded_error_ratio =
348       next_stats.intra_error / ModifyDivisor(next_stats.coded_error);
349 
350   double pcnt_intra = 1.0 - this_stats.pcnt_inter;
351   return pcnt_intra > pcnt_intra_min &&
352          this_intra_coded_error_ratio < kThisIntraCodedErrorRatioMax &&
353          (last_this_error_ratio > kErrorChangeThreshold ||
354           this_next_error_ratio > kErrorChangeThreshold ||
355           next_intra_coded_error_ratio > kNextIntraCodedErrorRatioMin);
356 }
357 
358 // Check whether the candidate can be a key frame.
359 // This is a rewrite of test_candidate_kf().
TestCandidateKey(const FirstpassInfo & first_pass_info,int candidate_key_idx,int frames_since_prev_key)360 static bool TestCandidateKey(const FirstpassInfo &first_pass_info,
361                              int candidate_key_idx, int frames_since_prev_key) {
362   const auto &stats_list = first_pass_info.stats_list;
363   const int stats_count = static_cast<int>(stats_list.size());
364   if (candidate_key_idx + 1 >= stats_count || candidate_key_idx - 1 < 0) {
365     return false;
366   }
367   const auto &last_stats = stats_list[candidate_key_idx - 1];
368   const auto &this_stats = stats_list[candidate_key_idx];
369   const auto &next_stats = stats_list[candidate_key_idx + 1];
370 
371   if (frames_since_prev_key < 3) return false;
372   const double second_ref_usage_threshold =
373       GetSecondRefUsageThreshold(frames_since_prev_key);
374   if (this_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;
375   if (next_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;
376 
377   // Hard threshold where the first pass chooses intra for almost all blocks.
378   // In such a case even if the frame is not a scene cut coding a key frame
379   // may be a good option.
380   constexpr double kVeryLowInterThreshold = 0.05;
381   if (this_stats.pcnt_inter < kVeryLowInterThreshold ||
382       DetectSlideTransition(this_stats, last_stats, next_stats) ||
383       DetectIntraInterErrorChange(this_stats, last_stats, next_stats)) {
384     double boost_score = 0.0;
385     double decay_accumulator = 1.0;
386 
387     // We do "-1" because the candidate key is not counted.
388     int stats_after_this_stats = stats_count - candidate_key_idx - 1;
389 
390     // Number of frames required to test for scene cut detection
391     constexpr int kSceneCutKeyTestIntervalMax = 16;
392 
393     // Make sure we have enough stats after the candidate key.
394     const int frames_to_test_after_candidate_key =
395         std::min(kSceneCutKeyTestIntervalMax, stats_after_this_stats);
396 
397     // Examine how well the key frame predicts subsequent frames.
398     int i;
399     for (i = 1; i <= frames_to_test_after_candidate_key; ++i) {
400       // Get the next frame details
401       const auto &stats = stats_list[candidate_key_idx + i];
402 
403       // Cumulative effect of decay in prediction quality.
404       if (stats.pcnt_inter > 0.85) {
405         decay_accumulator *= stats.pcnt_inter;
406       } else {
407         decay_accumulator *= (0.85 + stats.pcnt_inter) / 2.0;
408       }
409 
410       constexpr double kBoostFactor = 12.5;
411       double next_iiratio =
412           (kBoostFactor * stats.intra_error / ModifyDivisor(stats.coded_error));
413       next_iiratio = std::min(next_iiratio, 128.0);
414       double boost_score_increment = decay_accumulator * next_iiratio;
415 
416       // Keep a running total.
417       boost_score += boost_score_increment;
418 
419       // Test various breakout clauses.
420       // TODO(any): Test of intra error should be normalized to an MB.
421       // TODO(angiebird): Investigate the following questions.
422       // Question 1: next_iiratio (intra_error / coded_error) * kBoostFactor
423       // We know intra_error / coded_error >= 1 and kBoostFactor = 12.5,
424       // therefore, (intra_error / coded_error) * kBoostFactor will always
425       // greater than 1.5. Is "next_iiratio < 1.5" always false?
426       // Question 2: Similar to question 1, is "next_iiratio < 3.0" always true?
427       // Question 3: Why do we need to divide 200 with num_mbs_16x16?
428       if ((stats.pcnt_inter < 0.05) || (next_iiratio < 1.5) ||
429           (((stats.pcnt_inter - stats.pcnt_neutral) < 0.20) &&
430            (next_iiratio < 3.0)) ||
431           (boost_score_increment < 3.0) ||
432           (stats.intra_error <
433            (200.0 / static_cast<double>(first_pass_info.num_mbs_16x16)))) {
434         break;
435       }
436     }
437 
438     // If there is tolerable prediction for at least the next 3 frames then
439     // break out else discard this potential key frame and move on
440     const int count_for_tolerable_prediction = 3;
441     if (boost_score > 30.0 && (i > count_for_tolerable_prediction)) {
442       return true;
443     }
444   }
445   return false;
446 }
447 
448 // Compute key frame location from first_pass_info.
GetKeyFrameList(const FirstpassInfo & first_pass_info)449 std::vector<int> GetKeyFrameList(const FirstpassInfo &first_pass_info) {
450   std::vector<int> key_frame_list;
451   key_frame_list.push_back(0);  // The first frame is always a key frame
452   int candidate_key_idx = 1;
453   while (candidate_key_idx <
454          static_cast<int>(first_pass_info.stats_list.size())) {
455     const int frames_since_prev_key = candidate_key_idx - key_frame_list.back();
456     // Check for a scene cut.
457     const bool scenecut_detected = TestCandidateKey(
458         first_pass_info, candidate_key_idx, frames_since_prev_key);
459     if (scenecut_detected) {
460       key_frame_list.push_back(candidate_key_idx);
461     }
462     ++candidate_key_idx;
463   }
464   return key_frame_list;
465 }
466 
467 // initialize GF_GROUP_STATS
InitGFStats(GF_GROUP_STATS * gf_stats)468 static void InitGFStats(GF_GROUP_STATS *gf_stats) {
469   gf_stats->gf_group_err = 0.0;
470   gf_stats->gf_group_raw_error = 0.0;
471   gf_stats->gf_group_skip_pct = 0.0;
472   gf_stats->gf_group_inactive_zone_rows = 0.0;
473 
474   gf_stats->mv_ratio_accumulator = 0.0;
475   gf_stats->decay_accumulator = 1.0;
476   gf_stats->zero_motion_accumulator = 1.0;
477   gf_stats->loop_decay_rate = 1.0;
478   gf_stats->last_loop_decay_rate = 1.0;
479   gf_stats->this_frame_mv_in_out = 0.0;
480   gf_stats->mv_in_out_accumulator = 0.0;
481   gf_stats->abs_mv_in_out_accumulator = 0.0;
482 
483   gf_stats->avg_sr_coded_error = 0.0;
484   gf_stats->avg_pcnt_second_ref = 0.0;
485   gf_stats->avg_new_mv_count = 0.0;
486   gf_stats->avg_wavelet_energy = 0.0;
487   gf_stats->avg_raw_err_stdev = 0.0;
488   gf_stats->non_zero_stdev_count = 0;
489 }
490 
FindRegionIndex(const std::vector<REGIONS> & regions,int frame_idx)491 static int FindRegionIndex(const std::vector<REGIONS> &regions, int frame_idx) {
492   for (int k = 0; k < static_cast<int>(regions.size()); k++) {
493     if (regions[k].start <= frame_idx && regions[k].last >= frame_idx) {
494       return k;
495     }
496   }
497   return -1;
498 }
499 
500 // This function detects a flash through the high relative pcnt_second_ref
501 // score in the frame following a flash frame. The offset passed in should
502 // reflect this.
DetectFlash(const std::vector<FIRSTPASS_STATS> & stats_list,int index)503 static bool DetectFlash(const std::vector<FIRSTPASS_STATS> &stats_list,
504                         int index) {
505   int next_index = index + 1;
506   if (next_index >= static_cast<int>(stats_list.size())) return false;
507   const FIRSTPASS_STATS &next_frame = stats_list[next_index];
508 
509   // What we are looking for here is a situation where there is a
510   // brief break in prediction (such as a flash) but subsequent frames
511   // are reasonably well predicted by an earlier (pre flash) frame.
512   // The recovery after a flash is indicated by a high pcnt_second_ref
513   // compared to pcnt_inter.
514   return next_frame.pcnt_second_ref > next_frame.pcnt_inter &&
515          next_frame.pcnt_second_ref >= 0.5;
516 }
517 
518 #define MIN_SHRINK_LEN 6
519 
520 // This function takes in a suggesting gop interval from cur_start to cur_last,
521 // analyzes firstpass stats and region stats and then return a better gop cut
522 // location.
523 // TODO(b/231517281): Simplify the indices once we have an unit test.
524 // We are using four indices here, order_index, cur_start, cur_last, and
525 // frames_since_key. Ideally, only three indices are needed.
526 // 1) start_index = order_index + cur_start
527 // 2) end_index = order_index + cur_end
528 // 3) key_index
FindBetterGopCut(const std::vector<FIRSTPASS_STATS> & stats_list,const std::vector<REGIONS> & regions_list,int min_gop_show_frame_count,int max_gop_show_frame_count,int order_index,int cur_start,int cur_last,int frames_since_key)529 int FindBetterGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
530                      const std::vector<REGIONS> &regions_list,
531                      int min_gop_show_frame_count, int max_gop_show_frame_count,
532                      int order_index, int cur_start, int cur_last,
533                      int frames_since_key) {
534   // only try shrinking if interval smaller than active_max_gf_interval
535   if (cur_last - cur_start > max_gop_show_frame_count ||
536       cur_start >= cur_last) {
537     return cur_last;
538   }
539   int num_regions = static_cast<int>(regions_list.size());
540   int num_stats = static_cast<int>(stats_list.size());
541   const int min_shrink_int = std::max(MIN_SHRINK_LEN, min_gop_show_frame_count);
542 
543   // find the region indices of where the first and last frame belong.
544   int k_start = FindRegionIndex(regions_list, cur_start + frames_since_key);
545   int k_last = FindRegionIndex(regions_list, cur_last + frames_since_key);
546   if (cur_start + frames_since_key == 0) k_start = 0;
547 
548   int scenecut_idx = -1;
549   // See if we have a scenecut in between
550   for (int r = k_start + 1; r <= k_last; r++) {
551     if (regions_list[r].type == SCENECUT_REGION &&
552         regions_list[r].last - frames_since_key - cur_start >
553             min_gop_show_frame_count) {
554       scenecut_idx = r;
555       break;
556     }
557   }
558 
559   // if the found scenecut is very close to the end, ignore it.
560   if (scenecut_idx >= 0 &&
561       regions_list[num_regions - 1].last - regions_list[scenecut_idx].last <
562           4) {
563     scenecut_idx = -1;
564   }
565 
566   if (scenecut_idx != -1) {
567     // If we have a scenecut, then stop at it.
568     // TODO(bohanli): add logic here to stop before the scenecut and for
569     // the next gop start from the scenecut with GF
570     int is_minor_sc =
571         (regions_list[scenecut_idx].avg_cor_coeff *
572              (1 - stats_list[order_index + regions_list[scenecut_idx].start -
573                              frames_since_key]
574                           .noise_var /
575                       regions_list[scenecut_idx].avg_intra_err) >
576          0.6);
577     cur_last =
578         regions_list[scenecut_idx].last - frames_since_key - !is_minor_sc;
579   } else {
580     int is_last_analysed =
581         (k_last == num_regions - 1) &&
582         (cur_last + frames_since_key == regions_list[k_last].last);
583     int not_enough_regions =
584         k_last - k_start <= 1 + (regions_list[k_start].type == SCENECUT_REGION);
585     // if we are very close to the end, then do not shrink since it may
586     // introduce intervals that are too short
587     if (!(is_last_analysed && not_enough_regions)) {
588       const double arf_length_factor = 0.1;
589       double best_score = 0;
590       int best_j = -1;
591       const int first_frame = regions_list[0].start - frames_since_key;
592       const int last_frame =
593           regions_list[num_regions - 1].last - frames_since_key;
594       // score of how much the arf helps the whole GOP
595       double base_score = 0.0;
596       // Accumulate base_score in
597       for (int j = cur_start + 1; j < cur_start + min_shrink_int; j++) {
598         if (order_index + j >= num_stats) break;
599         base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
600       }
601       int met_blending = 0;   // Whether we have met blending areas before
602       int last_blending = 0;  // Whether the previous frame if blending
603       for (int j = cur_start + min_shrink_int; j <= cur_last; j++) {
604         if (order_index + j >= num_stats) break;
605         base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
606         int this_reg = FindRegionIndex(regions_list, j + frames_since_key);
607         if (this_reg < 0) continue;
608         // A GOP should include at most 1 blending region.
609         if (regions_list[this_reg].type == BLENDING_REGION) {
610           last_blending = 1;
611           if (met_blending) {
612             break;
613           } else {
614             base_score = 0;
615             continue;
616           }
617         } else {
618           if (last_blending) met_blending = 1;
619           last_blending = 0;
620         }
621 
622         // Add the factor of how good the neighborhood is for this
623         // candidate arf.
624         double this_score = arf_length_factor * base_score;
625         double temp_accu_coeff = 1.0;
626         // following frames
627         int count_f = 0;
628         for (int n = j + 1; n <= j + 3 && n <= last_frame; n++) {
629           if (order_index + n >= num_stats) break;
630           temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
631           this_score +=
632               temp_accu_coeff *
633               (1 - stats_list[order_index + n].noise_var /
634                        AOMMAX(regions_list[this_reg].avg_intra_err, 0.001));
635           count_f++;
636         }
637         // preceding frames
638         temp_accu_coeff = 1.0;
639         for (int n = j; n > j - 3 * 2 + count_f && n > first_frame; n--) {
640           if (order_index + n < 0) break;
641           temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
642           this_score +=
643               temp_accu_coeff *
644               (1 - stats_list[order_index + n].noise_var /
645                        AOMMAX(regions_list[this_reg].avg_intra_err, 0.001));
646         }
647 
648         if (this_score > best_score) {
649           best_score = this_score;
650           best_j = j;
651         }
652       }
653 
654       // For blending areas, move one more frame in case we missed the
655       // first blending frame.
656       int best_reg = FindRegionIndex(regions_list, best_j + frames_since_key);
657       if (best_reg < num_regions - 1 && best_reg > 0) {
658         if (regions_list[best_reg - 1].type == BLENDING_REGION &&
659             regions_list[best_reg + 1].type == BLENDING_REGION) {
660           if (best_j + frames_since_key == regions_list[best_reg].start &&
661               best_j + frames_since_key < regions_list[best_reg].last) {
662             best_j += 1;
663           } else if (best_j + frames_since_key == regions_list[best_reg].last &&
664                      best_j + frames_since_key > regions_list[best_reg].start) {
665             best_j -= 1;
666           }
667         }
668       }
669 
670       if (cur_last - best_j < 2) best_j = cur_last;
671       if (best_j > 0 && best_score > 0.1) cur_last = best_j;
672       // if cannot find anything, just cut at the original place.
673     }
674   }
675 
676   return cur_last;
677 }
678 
679 // Function to test for a condition where a complex transition is followed
680 // by a static section. For example in slide shows where there is a fade
681 // between slides. This is to help with more optimal kf and gf positioning.
DetectTransitionToStill(const std::vector<FIRSTPASS_STATS> & stats_list,int next_stats_index,int min_gop_show_frame_count,int frame_interval,int still_interval,double loop_decay_rate,double last_decay_rate)682 static bool DetectTransitionToStill(
683     const std::vector<FIRSTPASS_STATS> &stats_list, int next_stats_index,
684     int min_gop_show_frame_count, int frame_interval, int still_interval,
685     double loop_decay_rate, double last_decay_rate) {
686   // Break clause to detect very still sections after motion
687   // For example a static image after a fade or other transition
688   // instead of a clean scene cut.
689   if (frame_interval > min_gop_show_frame_count && loop_decay_rate >= 0.999 &&
690       last_decay_rate < 0.9) {
691     int stats_count = static_cast<int>(stats_list.size());
692     int stats_left = stats_count - next_stats_index;
693     if (stats_left >= still_interval) {
694       // Look ahead a few frames to see if static condition persists...
695       int j;
696       for (j = 0; j < still_interval; ++j) {
697         const FIRSTPASS_STATS &stats = stats_list[next_stats_index + j];
698         if (stats.pcnt_inter - stats.pcnt_motion < 0.999) break;
699       }
700       // Only if it does do we signal a transition to still.
701       return j == still_interval;
702     }
703   }
704   return false;
705 }
706 
DetectGopCut(const std::vector<FIRSTPASS_STATS> & stats_list,int start_idx,int candidate_cut_idx,int next_key_idx,int flash_detected,int min_gop_show_frame_count,int max_gop_show_frame_count,int frame_width,int frame_height,const GF_GROUP_STATS & gf_stats)707 static int DetectGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
708                         int start_idx, int candidate_cut_idx, int next_key_idx,
709                         int flash_detected, int min_gop_show_frame_count,
710                         int max_gop_show_frame_count, int frame_width,
711                         int frame_height, const GF_GROUP_STATS &gf_stats) {
712   (void)max_gop_show_frame_count;
713   const int candidate_gop_size = candidate_cut_idx - start_idx;
714 
715   if (!flash_detected) {
716     // Break clause to detect very still sections after motion. For example,
717     // a static image after a fade or other transition.
718     if (DetectTransitionToStill(stats_list, start_idx, min_gop_show_frame_count,
719                                 candidate_gop_size, 5, gf_stats.loop_decay_rate,
720                                 gf_stats.last_loop_decay_rate)) {
721       return 1;
722     }
723     const double arf_abs_zoom_thresh = 4.4;
724     // Motion breakout threshold for loop below depends on image size.
725     const double mv_ratio_accumulator_thresh =
726         (frame_height + frame_width) / 4.0;
727     // Some conditions to breakout after min interval.
728     if (candidate_gop_size >= min_gop_show_frame_count &&
729         // If possible don't break very close to a kf
730         (next_key_idx - candidate_cut_idx >= min_gop_show_frame_count) &&
731         (candidate_gop_size & 0x01) &&
732         (gf_stats.mv_ratio_accumulator > mv_ratio_accumulator_thresh ||
733          gf_stats.abs_mv_in_out_accumulator > arf_abs_zoom_thresh)) {
734       return 1;
735     }
736   }
737 
738   // TODO(b/231489624): Check if we need this part.
739   // If almost totally static, we will not use the the max GF length later,
740   // so we can continue for more frames.
741   // if ((candidate_gop_size >= active_max_gf_interval + 1) &&
742   //     !is_almost_static(gf_stats->zero_motion_accumulator,
743   //                       twopass->kf_zeromotion_pct, cpi->ppi->lap_enabled)) {
744   //   return 0;
745   // }
746   return 0;
747 }
748 
749 /*!\brief Determine the length of future GF groups.
750  *
751  * \ingroup gf_group_algo
752  * This function decides the gf group length of future frames in batch
753  *
754  * \param[in]    rc_param         Rate control parameters
755  * \param[in]    stats_list       List of first pass stats
756  * \param[in]    regions_list     List of regions from av1_identify_regions
757  * \param[in]    order_index      Index of current frame in stats_list
758  * \param[in]    frames_since_key Number of frames since the last key frame
759  * \param[in]    frames_to_key    Number of frames to the next key frame
760  *
761  * \return Returns a vector of decided GF group lengths.
762  */
PartitionGopIntervals(const RateControlParam & rc_param,const std::vector<FIRSTPASS_STATS> & stats_list,const std::vector<REGIONS> & regions_list,int order_index,int frames_since_key,int frames_to_key)763 static std::vector<int> PartitionGopIntervals(
764     const RateControlParam &rc_param,
765     const std::vector<FIRSTPASS_STATS> &stats_list,
766     const std::vector<REGIONS> &regions_list, int order_index,
767     int frames_since_key, int frames_to_key) {
768   int i = 0;
769   // If cpi->gf_state.arf_gf_boost_lst is 0, we are starting with a KF or GF.
770   int cur_start = 0;
771   // Each element is the last frame of the previous GOP. If there are n GOPs,
772   // you need n + 1 cuts to find the durations. So cut_pos starts out with -1,
773   // which is the last frame of the previous GOP.
774   std::vector<int> cut_pos(1, -1);
775   int cut_here = 0;
776   GF_GROUP_STATS gf_stats;
777   InitGFStats(&gf_stats);
778   int num_stats = static_cast<int>(stats_list.size());
779 
780   while (i + order_index < num_stats) {
781     // reaches next key frame, break here
782     if (i >= frames_to_key - 1) {
783       cut_here = 2;
784     } else if (i - cur_start >= rc_param.max_gop_show_frame_count) {
785       // reached maximum len, but nothing special yet (almost static)
786       // let's look at the next interval
787       cut_here = 2;
788     } else {
789       // Test for the case where there is a brief flash but the prediction
790       // quality back to an earlier frame is then restored.
791       const int gop_start_idx = cur_start + order_index;
792       const int candidate_gop_cut_idx = i + order_index;
793       const int next_key_idx = frames_to_key + order_index;
794       const bool flash_detected =
795           DetectFlash(stats_list, candidate_gop_cut_idx);
796 
797       // TODO(bohanli): remove redundant accumulations here, or unify
798       // this and the ones in define_gf_group
799       const FIRSTPASS_STATS *stats = &stats_list[candidate_gop_cut_idx];
800       av1_accumulate_next_frame_stats(stats, flash_detected, frames_since_key,
801                                       i, &gf_stats, rc_param.frame_width,
802                                       rc_param.frame_height);
803 
804       // TODO(angiebird): Can we simplify this part? Looks like we are going to
805       // change the gop cut index with FindBetterGopCut() anyway.
806       cut_here = DetectGopCut(
807           stats_list, gop_start_idx, candidate_gop_cut_idx, next_key_idx,
808           flash_detected, rc_param.min_gop_show_frame_count,
809           rc_param.max_gop_show_frame_count, rc_param.frame_width,
810           rc_param.frame_height, gf_stats);
811     }
812 
813     if (!cut_here) {
814       ++i;
815       continue;
816     }
817 
818     // the current last frame in the gf group
819     int original_last = cut_here > 1 ? i : i - 1;
820     int cur_last = FindBetterGopCut(
821         stats_list, regions_list, rc_param.min_gop_show_frame_count,
822         rc_param.max_gop_show_frame_count, order_index, cur_start,
823         original_last, frames_since_key);
824     // only try shrinking if interval smaller than active_max_gf_interval
825     cut_pos.push_back(cur_last);
826 
827     // reset pointers to the shrunken location
828     cur_start = cur_last;
829     int cur_region_idx =
830         FindRegionIndex(regions_list, cur_start + 1 + frames_since_key);
831     if (cur_region_idx >= 0)
832       if (regions_list[cur_region_idx].type == SCENECUT_REGION) cur_start++;
833 
834     // reset accumulators
835     InitGFStats(&gf_stats);
836     i = cur_last + 1;
837 
838     if (cut_here == 2 && i >= frames_to_key) break;
839   }
840 
841   std::vector<int> gf_intervals;
842   // save intervals
843   for (size_t n = 1; n < cut_pos.size(); n++) {
844     gf_intervals.push_back(cut_pos[n] - cut_pos[n - 1]);
845   }
846 
847   return gf_intervals;
848 }
849 
DetermineGopInfo(const FirstpassInfo & firstpass_info)850 StatusOr<GopStructList> AV1RateControlQMode::DetermineGopInfo(
851     const FirstpassInfo &firstpass_info) {
852   const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
853   GopStructList gop_list;
854   RefFrameManager ref_frame_manager(rc_param_.ref_frame_table_size,
855                                     rc_param_.max_ref_frames);
856 
857   // Make a copy of the first pass stats, and analyze them
858   FirstpassInfo fp_info_copy = firstpass_info;
859   av1_mark_flashes(fp_info_copy.stats_list.data(),
860                    fp_info_copy.stats_list.data() + stats_size);
861   av1_estimate_noise(fp_info_copy.stats_list.data(),
862                      fp_info_copy.stats_list.data() + stats_size);
863   av1_estimate_coeff(fp_info_copy.stats_list.data(),
864                      fp_info_copy.stats_list.data() + stats_size);
865 
866   int global_coding_idx_offset = 0;
867   int global_order_idx_offset = 0;
868   std::vector<int> key_frame_list = GetKeyFrameList(fp_info_copy);
869   key_frame_list.push_back(stats_size);  // a sentinel value
870   for (size_t ki = 0; ki + 1 < key_frame_list.size(); ++ki) {
871     int frames_to_key = key_frame_list[ki + 1] - key_frame_list[ki];
872     int key_order_index = key_frame_list[ki];  // The key frame's display order
873 
874     std::vector<REGIONS> regions_list(MAX_FIRSTPASS_ANALYSIS_FRAMES);
875     int total_regions = 0;
876     av1_identify_regions(fp_info_copy.stats_list.data() + key_order_index,
877                          frames_to_key, 0, regions_list.data(), &total_regions);
878     regions_list.resize(total_regions);
879     std::vector<int> gf_intervals = PartitionGopIntervals(
880         rc_param_, fp_info_copy.stats_list, regions_list, key_order_index,
881         /*frames_since_key=*/0, frames_to_key);
882     for (size_t gi = 0; gi < gf_intervals.size(); ++gi) {
883       const bool has_key_frame = gi == 0;
884       const int show_frame_count = gf_intervals[gi];
885       GopStruct gop =
886           ConstructGop(&ref_frame_manager, show_frame_count, has_key_frame,
887                        global_coding_idx_offset, global_order_idx_offset);
888       assert(gop.show_frame_count == show_frame_count);
889       global_coding_idx_offset += static_cast<int>(gop.gop_frame_list.size());
890       global_order_idx_offset += gop.show_frame_count;
891       gop_list.push_back(gop);
892     }
893   }
894   return gop_list;
895 }
896 
CreateTplFrameDepStats(int frame_height,int frame_width,int min_block_size)897 TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
898                                         int min_block_size) {
899   const int unit_rows = (frame_height + min_block_size - 1) / min_block_size;
900   const int unit_cols = (frame_width + min_block_size - 1) / min_block_size;
901   TplFrameDepStats frame_dep_stats;
902   frame_dep_stats.unit_size = min_block_size;
903   frame_dep_stats.unit_stats.resize(unit_rows);
904   for (auto &row : frame_dep_stats.unit_stats) {
905     row.resize(unit_cols);
906   }
907   return frame_dep_stats;
908 }
909 
TplBlockStatsToDepStats(const TplBlockStats & block_stats,int unit_count)910 TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
911                                         int unit_count) {
912   TplUnitDepStats dep_stats = {};
913   dep_stats.intra_cost = block_stats.intra_cost * 1.0 / unit_count;
914   dep_stats.inter_cost = block_stats.inter_cost * 1.0 / unit_count;
915   // In rare case, inter_cost may be greater than intra_cost.
916   // If so, we need to modify inter_cost such that inter_cost <= intra_cost
917   // because it is required by GetPropagationFraction()
918   dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
919   dep_stats.mv = block_stats.mv;
920   dep_stats.ref_frame_index = block_stats.ref_frame_index;
921   return dep_stats;
922 }
923 
924 namespace {
ValidateBlockStats(const TplFrameStats & frame_stats,const TplBlockStats & block_stats,int min_block_size)925 Status ValidateBlockStats(const TplFrameStats &frame_stats,
926                           const TplBlockStats &block_stats,
927                           int min_block_size) {
928   if (block_stats.col >= frame_stats.frame_width ||
929       block_stats.row >= frame_stats.frame_height) {
930     std::ostringstream error_message;
931     error_message << "Block position (" << block_stats.col << ", "
932                   << block_stats.row
933                   << ") is out of range; frame dimensions are "
934                   << frame_stats.frame_width << " x "
935                   << frame_stats.frame_height;
936     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
937   }
938   if (block_stats.col % min_block_size != 0 ||
939       block_stats.row % min_block_size != 0 ||
940       block_stats.width % min_block_size != 0 ||
941       block_stats.height % min_block_size != 0) {
942     std::ostringstream error_message;
943     error_message
944         << "Invalid block position or dimension, must be a multiple of "
945         << min_block_size << "; col = " << block_stats.col
946         << ", row = " << block_stats.row << ", width = " << block_stats.width
947         << ", height = " << block_stats.height;
948     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
949   }
950   return { AOM_CODEC_OK, "" };
951 }
952 
ValidateTplStats(const GopStruct & gop_struct,const TplGopStats & tpl_gop_stats)953 Status ValidateTplStats(const GopStruct &gop_struct,
954                         const TplGopStats &tpl_gop_stats) {
955   constexpr char kAdvice[] =
956       "Do the current RateControlParam settings match those used to generate "
957       "the TPL stats?";
958   if (gop_struct.gop_frame_list.size() !=
959       tpl_gop_stats.frame_stats_list.size()) {
960     std::ostringstream error_message;
961     error_message << "Frame count of GopStruct ("
962                   << gop_struct.gop_frame_list.size()
963                   << ") doesn't match frame count of TPL stats ("
964                   << tpl_gop_stats.frame_stats_list.size() << "). " << kAdvice;
965     return { AOM_CODEC_INVALID_PARAM, error_message.str() };
966   }
967   for (int i = 0; i < static_cast<int>(gop_struct.gop_frame_list.size()); ++i) {
968     const bool is_ref_frame = gop_struct.gop_frame_list[i].update_ref_idx >= 0;
969     const bool has_tpl_stats =
970         !tpl_gop_stats.frame_stats_list[i].block_stats_list.empty();
971     if (is_ref_frame && !has_tpl_stats) {
972       std::ostringstream error_message;
973       error_message << "The frame with global_coding_idx "
974                     << gop_struct.gop_frame_list[i].global_coding_idx
975                     << " is a reference frame, but has no TPL stats. "
976                     << kAdvice;
977       return { AOM_CODEC_INVALID_PARAM, error_message.str() };
978     }
979   }
980   return { AOM_CODEC_OK, "" };
981 }
982 }  // namespace
983 
CreateTplFrameDepStatsWithoutPropagation(const TplFrameStats & frame_stats)984 StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
985     const TplFrameStats &frame_stats) {
986   if (frame_stats.block_stats_list.empty()) {
987     return TplFrameDepStats();
988   }
989   const int min_block_size = frame_stats.min_block_size;
990   const int unit_rows =
991       (frame_stats.frame_height + min_block_size - 1) / min_block_size;
992   const int unit_cols =
993       (frame_stats.frame_width + min_block_size - 1) / min_block_size;
994   TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
995       frame_stats.frame_height, frame_stats.frame_width, min_block_size);
996   for (const TplBlockStats &block_stats : frame_stats.block_stats_list) {
997     Status status =
998         ValidateBlockStats(frame_stats, block_stats, min_block_size);
999     if (!status.ok()) {
1000       return status;
1001     }
1002     const int block_unit_row = block_stats.row / min_block_size;
1003     const int block_unit_col = block_stats.col / min_block_size;
1004     // The block must start within the frame boundaries, but it may extend past
1005     // the right edge or bottom of the frame. Find the number of unit rows and
1006     // columns in the block which are fully within the frame.
1007     const int block_unit_rows = std::min(block_stats.height / min_block_size,
1008                                          unit_rows - block_unit_row);
1009     const int block_unit_cols = std::min(block_stats.width / min_block_size,
1010                                          unit_cols - block_unit_col);
1011     const int unit_count = block_unit_rows * block_unit_cols;
1012     TplUnitDepStats unit_stats =
1013         TplBlockStatsToDepStats(block_stats, unit_count);
1014     for (int r = 0; r < block_unit_rows; r++) {
1015       for (int c = 0; c < block_unit_cols; c++) {
1016         frame_dep_stats.unit_stats[block_unit_row + r][block_unit_col + c] =
1017             unit_stats;
1018       }
1019     }
1020   }
1021 
1022   frame_dep_stats.rdcost = TplFrameDepStatsAccumulateInterCost(frame_dep_stats);
1023 
1024   return frame_dep_stats;
1025 }
1026 
GetRefCodingIdxList(const TplUnitDepStats & unit_dep_stats,const RefFrameTable & ref_frame_table,int * ref_coding_idx_list)1027 int GetRefCodingIdxList(const TplUnitDepStats &unit_dep_stats,
1028                         const RefFrameTable &ref_frame_table,
1029                         int *ref_coding_idx_list) {
1030   int ref_frame_count = 0;
1031   for (int i = 0; i < kBlockRefCount; ++i) {
1032     ref_coding_idx_list[i] = -1;
1033     int ref_frame_index = unit_dep_stats.ref_frame_index[i];
1034     if (ref_frame_index != -1) {
1035       assert(ref_frame_index < static_cast<int>(ref_frame_table.size()));
1036       ref_coding_idx_list[i] = ref_frame_table[ref_frame_index].coding_idx;
1037       ref_frame_count++;
1038     }
1039   }
1040   return ref_frame_count;
1041 }
1042 
// Returns the overlap area of two size x size squares whose top-left corners
// are at (r0, c0) and (r1, c1). Returns 0 when the squares do not overlap.
int GetBlockOverlapArea(int r0, int c0, int r1, int c1, int size) {
  const int overlap_rows = std::min(r0 + size, r1 + size) - std::max(r0, r1);
  const int overlap_cols = std::min(c0 + size, c1 + size) - std::max(c0, c1);
  // Both extents must be non-negative; otherwise two negative extents would
  // multiply into a bogus positive area.
  if (overlap_rows < 0 || overlap_cols < 0) return 0;
  return overlap_rows * overlap_cols;
}
1053 
1054 // TODO(angiebird): Merge TplFrameDepStatsAccumulateIntraCost and
1055 // TplFrameDepStatsAccumulate.
TplFrameDepStatsAccumulateIntraCost(const TplFrameDepStats & frame_dep_stats)1056 double TplFrameDepStatsAccumulateIntraCost(
1057     const TplFrameDepStats &frame_dep_stats) {
1058   auto getIntraCost = [](double sum, const TplUnitDepStats &unit) {
1059     return sum + unit.intra_cost;
1060   };
1061   double sum = 0;
1062   for (const auto &row : frame_dep_stats.unit_stats) {
1063     sum = std::accumulate(row.begin(), row.end(), sum, getIntraCost);
1064   }
1065   return std::max(sum, 1.0);
1066 }
1067 
TplFrameDepStatsAccumulateInterCost(const TplFrameDepStats & frame_dep_stats)1068 double TplFrameDepStatsAccumulateInterCost(
1069     const TplFrameDepStats &frame_dep_stats) {
1070   auto getInterCost = [](double sum, const TplUnitDepStats &unit) {
1071     return sum + unit.inter_cost;
1072   };
1073   double sum = 0;
1074   for (const auto &row : frame_dep_stats.unit_stats) {
1075     sum = std::accumulate(row.begin(), row.end(), sum, getInterCost);
1076   }
1077   return std::max(sum, 1.0);
1078 }
1079 
TplFrameDepStatsAccumulate(const TplFrameDepStats & frame_dep_stats)1080 double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats) {
1081   auto getOverallCost = [](double sum, const TplUnitDepStats &unit) {
1082     return sum + unit.propagation_cost + unit.intra_cost;
1083   };
1084   double sum = 0;
1085   for (const auto &row : frame_dep_stats.unit_stats) {
1086     sum = std::accumulate(row.begin(), row.end(), sum, getOverallCost);
1087   }
1088   return std::max(sum, 1.0);
1089 }
1090 
// This is a generalization of GET_MV_RAWPEL that allows for an arbitrary
// number of fractional bits: subpel_value is converted to full-pel units by
// rounding its magnitude to the nearest integer (halves round away from zero)
// and restoring the sign.
// TODO(angiebird): Add unit test to this function
int GetFullpelValue(int subpel_value, int subpel_bits) {
  const int half_step = (1 << subpel_bits) / 2;
  const bool is_negative = subpel_value < 0;
  const int magnitude = is_negative ? -subpel_value : subpel_value;
  const int rounded_magnitude = (magnitude + half_step) >> subpel_bits;
  return is_negative ? -rounded_magnitude : rounded_magnitude;
}
1101 
GetPropagationFraction(const TplUnitDepStats & unit_dep_stats)1102 double GetPropagationFraction(const TplUnitDepStats &unit_dep_stats) {
1103   assert(unit_dep_stats.intra_cost >= unit_dep_stats.inter_cost);
1104   return (unit_dep_stats.intra_cost - unit_dep_stats.inter_cost) /
1105          ModifyDivisor(unit_dep_stats.intra_cost);
1106 }
1107 
TplFrameDepStatsPropagate(int coding_idx,const RefFrameTable & ref_frame_table,TplGopDepStats * tpl_gop_dep_stats)1108 void TplFrameDepStatsPropagate(int coding_idx,
1109                                const RefFrameTable &ref_frame_table,
1110                                TplGopDepStats *tpl_gop_dep_stats) {
1111   assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
1112   TplFrameDepStats *frame_dep_stats =
1113       &tpl_gop_dep_stats->frame_dep_stats_list[coding_idx];
1114 
1115   if (frame_dep_stats->unit_stats.empty()) return;
1116 
1117   const int unit_size = frame_dep_stats->unit_size;
1118   const int frame_unit_rows =
1119       static_cast<int>(frame_dep_stats->unit_stats.size());
1120   const int frame_unit_cols =
1121       static_cast<int>(frame_dep_stats->unit_stats[0].size());
1122   for (int unit_row = 0; unit_row < frame_unit_rows; ++unit_row) {
1123     for (int unit_col = 0; unit_col < frame_unit_cols; ++unit_col) {
1124       TplUnitDepStats &unit_dep_stats =
1125           frame_dep_stats->unit_stats[unit_row][unit_col];
1126       int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
1127       int ref_frame_count = GetRefCodingIdxList(unit_dep_stats, ref_frame_table,
1128                                                 ref_coding_idx_list);
1129       if (ref_frame_count == 0) continue;
1130       for (int i = 0; i < kBlockRefCount; ++i) {
1131         if (ref_coding_idx_list[i] == -1) continue;
1132         assert(
1133             ref_coding_idx_list[i] <
1134             static_cast<int>(tpl_gop_dep_stats->frame_dep_stats_list.size()));
1135         TplFrameDepStats &ref_frame_dep_stats =
1136             tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
1137         assert(!ref_frame_dep_stats.unit_stats.empty());
1138         const auto &mv = unit_dep_stats.mv[i];
1139         const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
1140         const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
1141         const int ref_pixel_r = unit_row * unit_size + mv_row;
1142         const int ref_pixel_c = unit_col * unit_size + mv_col;
1143         const int ref_unit_row_low =
1144             (unit_row * unit_size + mv_row) / unit_size;
1145         const int ref_unit_col_low =
1146             (unit_col * unit_size + mv_col) / unit_size;
1147 
1148         for (int j = 0; j < 2; ++j) {
1149           for (int k = 0; k < 2; ++k) {
1150             const int ref_unit_row = ref_unit_row_low + j;
1151             const int ref_unit_col = ref_unit_col_low + k;
1152             if (ref_unit_row >= 0 && ref_unit_row < frame_unit_rows &&
1153                 ref_unit_col >= 0 && ref_unit_col < frame_unit_cols) {
1154               const int overlap_area = GetBlockOverlapArea(
1155                   ref_pixel_r, ref_pixel_c, ref_unit_row * unit_size,
1156                   ref_unit_col * unit_size, unit_size);
1157               const double overlap_ratio =
1158                   overlap_area * 1.0 / (unit_size * unit_size);
1159               const double propagation_fraction =
1160                   GetPropagationFraction(unit_dep_stats);
1161               const double propagation_ratio =
1162                   1.0 / ref_frame_count * overlap_ratio * propagation_fraction;
1163               TplUnitDepStats &ref_unit_stats =
1164                   ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
1165               ref_unit_stats.propagation_cost +=
1166                   (unit_dep_stats.intra_cost +
1167                    unit_dep_stats.propagation_cost) *
1168                   propagation_ratio;
1169             }
1170           }
1171         }
1172       }
1173     }
1174   }
1175 }
1176 
GetRefFrameTableList(const GopStruct & gop_struct,const std::vector<LookaheadStats> & lookahead_stats,RefFrameTable ref_frame_table)1177 std::vector<RefFrameTable> AV1RateControlQMode::GetRefFrameTableList(
1178     const GopStruct &gop_struct,
1179     const std::vector<LookaheadStats> &lookahead_stats,
1180     RefFrameTable ref_frame_table) {
1181   if (gop_struct.global_coding_idx_offset == 0) {
1182     // For the first GOP, ref_frame_table need not be initialized. This is fine,
1183     // because the first frame (a key frame) will fully initialize it.
1184     ref_frame_table.assign(rc_param_.ref_frame_table_size, GopFrameInvalid());
1185   } else {
1186     // It's not the first GOP, so ref_frame_table must be valid.
1187     assert(static_cast<int>(ref_frame_table.size()) ==
1188            rc_param_.ref_frame_table_size);
1189     assert(std::all_of(ref_frame_table.begin(), ref_frame_table.end(),
1190                        std::mem_fn(&GopFrame::is_valid)));
1191     // Reset the frame processing order of the initial ref_frame_table.
1192     for (GopFrame &gop_frame : ref_frame_table) gop_frame.coding_idx = -1;
1193   }
1194 
1195   std::vector<RefFrameTable> ref_frame_table_list;
1196   ref_frame_table_list.push_back(ref_frame_table);
1197   for (const GopFrame &gop_frame : gop_struct.gop_frame_list) {
1198     if (gop_frame.is_key_frame) {
1199       ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
1200     } else if (gop_frame.update_ref_idx != -1) {
1201       assert(gop_frame.update_ref_idx <
1202              static_cast<int>(ref_frame_table.size()));
1203       ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
1204     }
1205     ref_frame_table_list.push_back(ref_frame_table);
1206   }
1207 
1208   int gop_size_offset = static_cast<int>(gop_struct.gop_frame_list.size());
1209 
1210   for (const auto &lookahead_stat : lookahead_stats) {
1211     for (GopFrame gop_frame : lookahead_stat.gop_struct->gop_frame_list) {
1212       if (gop_frame.is_key_frame) {
1213         ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
1214       } else if (gop_frame.update_ref_idx != -1) {
1215         assert(gop_frame.update_ref_idx <
1216                static_cast<int>(ref_frame_table.size()));
1217         gop_frame.coding_idx += gop_size_offset;
1218         ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
1219       }
1220       ref_frame_table_list.push_back(ref_frame_table);
1221     }
1222     gop_size_offset +=
1223         static_cast<int>(lookahead_stat.gop_struct->gop_frame_list.size());
1224   }
1225 
1226   return ref_frame_table_list;
1227 }
1228 
ComputeTplGopDepStats(const TplGopStats & tpl_gop_stats,const std::vector<LookaheadStats> & lookahead_stats,const std::vector<RefFrameTable> & ref_frame_table_list)1229 StatusOr<TplGopDepStats> ComputeTplGopDepStats(
1230     const TplGopStats &tpl_gop_stats,
1231     const std::vector<LookaheadStats> &lookahead_stats,
1232     const std::vector<RefFrameTable> &ref_frame_table_list) {
1233   std::vector<const TplFrameStats *> tpl_frame_stats_list_with_lookahead;
1234   for (const auto &tpl_frame_stats : tpl_gop_stats.frame_stats_list) {
1235     tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
1236   }
1237   for (const auto &lookahead_stat : lookahead_stats) {
1238     for (const auto &tpl_frame_stats :
1239          lookahead_stat.tpl_gop_stats->frame_stats_list) {
1240       tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
1241     }
1242   }
1243 
1244   const int frame_count =
1245       static_cast<int>(tpl_frame_stats_list_with_lookahead.size());
1246 
1247   // Create the struct to store TPL dependency stats
1248   TplGopDepStats tpl_gop_dep_stats;
1249 
1250   tpl_gop_dep_stats.frame_dep_stats_list.reserve(frame_count);
1251   for (int coding_idx = 0; coding_idx < frame_count; coding_idx++) {
1252     const StatusOr<TplFrameDepStats> tpl_frame_dep_stats =
1253         CreateTplFrameDepStatsWithoutPropagation(
1254             *tpl_frame_stats_list_with_lookahead[coding_idx]);
1255     if (!tpl_frame_dep_stats.ok()) {
1256       return tpl_frame_dep_stats.status();
1257     }
1258     tpl_gop_dep_stats.frame_dep_stats_list.push_back(
1259         std::move(*tpl_frame_dep_stats));
1260   }
1261 
1262   // Back propagation
1263   for (int coding_idx = frame_count - 1; coding_idx >= 0; coding_idx--) {
1264     auto &ref_frame_table = ref_frame_table_list[coding_idx];
1265     // TODO(angiebird): Handle/test the case where reference frame
1266     // is in the previous GOP
1267     TplFrameDepStatsPropagate(coding_idx, ref_frame_table, &tpl_gop_dep_stats);
1268   }
1269   return tpl_gop_dep_stats;
1270 }
1271 
SetupDeltaQ(const TplFrameDepStats & frame_dep_stats,int frame_width,int frame_height,int base_qindex,double frame_importance)1272 static std::vector<uint8_t> SetupDeltaQ(const TplFrameDepStats &frame_dep_stats,
1273                                         int frame_width, int frame_height,
1274                                         int base_qindex,
1275                                         double frame_importance) {
1276   // TODO(jianj) : Add support to various superblock sizes.
1277   const int sb_size = 64;
1278   const int delta_q_res = 4;
1279   const int num_unit_per_sb = sb_size / frame_dep_stats.unit_size;
1280   const int sb_rows = (frame_height + sb_size - 1) / sb_size;
1281   const int sb_cols = (frame_width + sb_size - 1) / sb_size;
1282   const int unit_rows = (frame_height + frame_dep_stats.unit_size - 1) /
1283                         frame_dep_stats.unit_size;
1284   const int unit_cols =
1285       (frame_width + frame_dep_stats.unit_size - 1) / frame_dep_stats.unit_size;
1286   std::vector<uint8_t> superblock_q_indices;
1287   // Calculate delta_q offset for each superblock.
1288   for (int sb_row = 0; sb_row < sb_rows; ++sb_row) {
1289     for (int sb_col = 0; sb_col < sb_cols; ++sb_col) {
1290       double intra_cost = 0;
1291       double mc_dep_cost = 0;
1292       const int unit_row_start = sb_row * num_unit_per_sb;
1293       const int unit_row_end =
1294           std::min((sb_row + 1) * num_unit_per_sb, unit_rows);
1295       const int unit_col_start = sb_col * num_unit_per_sb;
1296       const int unit_col_end =
1297           std::min((sb_col + 1) * num_unit_per_sb, unit_cols);
1298       // A simplified version of av1_get_q_for_deltaq_objective()
1299       for (int unit_row = unit_row_start; unit_row < unit_row_end; ++unit_row) {
1300         for (int unit_col = unit_col_start; unit_col < unit_col_end;
1301              ++unit_col) {
1302           const TplUnitDepStats &unit_dep_stat =
1303               frame_dep_stats.unit_stats[unit_row][unit_col];
1304           intra_cost += unit_dep_stat.intra_cost;
1305           mc_dep_cost += unit_dep_stat.propagation_cost;
1306         }
1307       }
1308 
1309       double beta = 1.0;
1310       if (mc_dep_cost > 0 && intra_cost > 0) {
1311         const double r0 = 1 / frame_importance;
1312         const double rk = intra_cost / mc_dep_cost;
1313         beta = r0 / rk;
1314         assert(beta > 0.0);
1315       }
1316       int offset = av1_get_deltaq_offset(AOM_BITS_8, base_qindex, beta);
1317       offset = std::min(offset, delta_q_res * 9 - 1);
1318       offset = std::max(offset, -delta_q_res * 9 + 1);
1319       int qindex = offset + base_qindex;
1320       qindex = std::min(qindex, MAXQ);
1321       qindex = std::max(qindex, MINQ);
1322       qindex = av1_adjust_q_from_delta_q_res(delta_q_res, base_qindex, qindex);
1323       superblock_q_indices.push_back(static_cast<uint8_t>(qindex));
1324     }
1325   }
1326 
1327   return superblock_q_indices;
1328 }
1329 
// Maps every distinct value in `qindices` to the centroid nearest to it
// (smallest absolute difference). When several centroids are equally near,
// the one that appears first in `centroids` is chosen.
// `centroids` must not be empty when `qindices` is non-empty.
static std::unordered_map<int, double> FindKMeansClusterMap(
    const std::vector<uint8_t> &qindices,
    const std::vector<double> &centroids) {
  std::unordered_map<int, double> nearest_by_qindex;
  for (const uint8_t qindex : qindices) {
    auto best = centroids.begin();
    for (auto candidate = centroids.begin(); candidate != centroids.end();
         ++candidate) {
      // Strict comparison keeps the earliest centroid on ties.
      if (fabs(*candidate - qindex) < fabs(*best - qindex)) {
        best = candidate;
      }
    }
    // A repeated qindex resolves to the same centroid, so the first entry
    // can simply be kept.
    nearest_by_qindex.emplace(qindex, *best);
  }
  return nearest_by_qindex;
}
1344 
1345 namespace internal {
1346 
KMeans(std::vector<uint8_t> qindices,int k)1347 std::unordered_map<int, int> KMeans(std::vector<uint8_t> qindices, int k) {
1348   std::vector<double> centroids;
1349   // Initialize the centroids with first k qindices
1350   std::unordered_set<int> qindices_set;
1351 
1352   for (const uint8_t qp : qindices) {
1353     if (!qindices_set.insert(qp).second) continue;  // Already added.
1354     centroids.push_back(qp);
1355     if (static_cast<int>(centroids.size()) >= k) break;
1356   }
1357 
1358   std::unordered_map<int, double> intermediate_cluster_map;
1359   while (true) {
1360     // Find the closest centroid for each qindex
1361     intermediate_cluster_map = FindKMeansClusterMap(qindices, centroids);
1362     // For each cluster, calculate the new centroids
1363     std::unordered_map<double, std::vector<int>> centroid_to_qindices;
1364     for (const auto &qindex_centroid : intermediate_cluster_map) {
1365       centroid_to_qindices[qindex_centroid.second].push_back(
1366           qindex_centroid.first);
1367     }
1368     bool centroids_changed = false;
1369     std::vector<double> new_centroids;
1370     for (const auto &cluster : centroid_to_qindices) {
1371       double sum = 0.0;
1372       for (const int qindex : cluster.second) {
1373         sum += qindex;
1374       }
1375       double new_centroid = sum / cluster.second.size();
1376       new_centroids.push_back(new_centroid);
1377       if (new_centroid != cluster.first) centroids_changed = true;
1378     }
1379     if (!centroids_changed) break;
1380     centroids = new_centroids;
1381   }
1382   std::unordered_map<int, int> cluster_map;
1383   for (const auto &qindex_centroid : intermediate_cluster_map) {
1384     cluster_map.insert(
1385         { qindex_centroid.first, static_cast<int>(qindex_centroid.second) });
1386   }
1387   return cluster_map;
1388 }
1389 }  // namespace internal
1390 
GetRDMult(const GopFrame & gop_frame,int q_index)1391 static int GetRDMult(const GopFrame &gop_frame, int q_index) {
1392   // TODO(angiebird):
1393   // 1) Check if these rdmult rules are good in our use case.
1394   // 2) Support high-bit-depth mode
1395   if (gop_frame.is_golden_frame) {
1396     // Assume ARF_UPDATE/GF_UPDATE share the same remult rule.
1397     return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, GF_UPDATE, q_index);
1398   } else if (gop_frame.is_key_frame) {
1399     return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, KF_UPDATE, q_index);
1400   } else {
1401     // Assume LF_UPDATE/OVERLAY_UPDATE/INTNL_OVERLAY_UPDATE/INTNL_ARF_UPDATE
1402     // share the same remult rule.
1403     return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE, q_index);
1404   }
1405 }
1406 
GetGopEncodeInfoWithNoStats(const GopStruct & gop_struct)1407 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithNoStats(
1408     const GopStruct &gop_struct) {
1409   GopEncodeInfo gop_encode_info;
1410   const int frame_count = static_cast<int>(gop_struct.gop_frame_list.size());
1411   for (int i = 0; i < frame_count; i++) {
1412     FrameEncodeParameters param;
1413     const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
1414     // Use constant QP for TPL pass encoding. Keep the functionality
1415     // that allows QP changes across sub-gop.
1416     param.q_index = rc_param_.base_q_index;
1417     param.rdmult = av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE,
1418                                                        rc_param_.base_q_index);
1419     // TODO(jingning): gop_frame is needed in two pass tpl later.
1420     (void)gop_frame;
1421 
1422     if (rc_param_.tpl_pass_index) {
1423       if (gop_frame.update_type == GopFrameType::kRegularGolden ||
1424           gop_frame.update_type == GopFrameType::kRegularKey ||
1425           gop_frame.update_type == GopFrameType::kRegularArf) {
1426         double qstep_ratio = 1 / 3.0;
1427         param.q_index = av1_get_q_index_from_qstep_ratio(
1428             rc_param_.base_q_index, qstep_ratio, AOM_BITS_8);
1429         if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
1430       }
1431     }
1432     gop_encode_info.param_list.push_back(param);
1433   }
1434   return gop_encode_info;
1435 }
1436 
GetGopEncodeInfoWithFp(const GopStruct & gop_struct,const FirstpassInfo & firstpass_info AOM_UNUSED)1437 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithFp(
1438     const GopStruct &gop_struct,
1439     const FirstpassInfo &firstpass_info AOM_UNUSED) {
1440   // TODO(b/260859962): This is currently a placeholder. Should use the fp
1441   // stats to calculate frame-level qp.
1442   return GetGopEncodeInfoWithNoStats(gop_struct);
1443 }
1444 
GetGopEncodeInfoWithTpl(const GopStruct & gop_struct,const TplGopStats & tpl_gop_stats,const std::vector<LookaheadStats> & lookahead_stats,const RefFrameTable & ref_frame_table_snapshot_init)1445 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithTpl(
1446     const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
1447     const std::vector<LookaheadStats> &lookahead_stats,
1448     const RefFrameTable &ref_frame_table_snapshot_init) {
1449   const std::vector<RefFrameTable> ref_frame_table_list = GetRefFrameTableList(
1450       gop_struct, lookahead_stats, ref_frame_table_snapshot_init);
1451 
1452   GopEncodeInfo gop_encode_info;
1453   gop_encode_info.final_snapshot = ref_frame_table_list.back();
1454   StatusOr<TplGopDepStats> gop_dep_stats = ComputeTplGopDepStats(
1455       tpl_gop_stats, lookahead_stats, ref_frame_table_list);
1456   if (!gop_dep_stats.ok()) {
1457     return gop_dep_stats.status();
1458   }
1459   const int frame_count =
1460       static_cast<int>(tpl_gop_stats.frame_stats_list.size());
1461   const int active_worst_quality = rc_param_.base_q_index;
1462   int active_best_quality = rc_param_.base_q_index;
1463   for (int i = 0; i < frame_count; i++) {
1464     FrameEncodeParameters param;
1465     const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
1466 
1467     if (gop_frame.update_type == GopFrameType::kOverlay ||
1468         gop_frame.update_type == GopFrameType::kIntermediateOverlay ||
1469         gop_frame.update_type == GopFrameType::kRegularLeaf) {
1470       param.q_index = rc_param_.base_q_index;
1471     } else if (gop_frame.update_type == GopFrameType::kRegularGolden ||
1472                gop_frame.update_type == GopFrameType::kRegularKey ||
1473                gop_frame.update_type == GopFrameType::kRegularArf) {
1474       const TplFrameDepStats &frame_dep_stats =
1475           gop_dep_stats->frame_dep_stats_list[i];
1476       const double cost_without_propagation =
1477           TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
1478       const double cost_with_propagation =
1479           TplFrameDepStatsAccumulate(frame_dep_stats);
1480       const double frame_importance =
1481           cost_with_propagation / cost_without_propagation;
1482       // Imitate the behavior of av1_tpl_get_qstep_ratio()
1483       const double qstep_ratio = sqrt(1 / frame_importance);
1484       param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
1485                                                        qstep_ratio, AOM_BITS_8);
1486       if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
1487       active_best_quality = param.q_index;
1488 
1489       if (rc_param_.max_distinct_q_indices_per_frame > 1) {
1490         std::vector<uint8_t> superblock_q_indices = SetupDeltaQ(
1491             frame_dep_stats, rc_param_.frame_width, rc_param_.frame_height,
1492             param.q_index, frame_importance);
1493         std::unordered_map<int, int> qindex_centroids = internal::KMeans(
1494             superblock_q_indices, rc_param_.max_distinct_q_indices_per_frame);
1495         for (size_t i = 0; i < superblock_q_indices.size(); ++i) {
1496           const int curr_sb_qindex =
1497               qindex_centroids.find(superblock_q_indices[i])->second;
1498           const int delta_q_res = 4;
1499           const int adjusted_qindex =
1500               param.q_index +
1501               (curr_sb_qindex - param.q_index) / delta_q_res * delta_q_res;
1502           const int rd_mult = GetRDMult(gop_frame, adjusted_qindex);
1503           param.superblock_encode_params.push_back(
1504               { static_cast<uint8_t>(adjusted_qindex), rd_mult });
1505         }
1506       }
1507     } else {
1508       // Intermediate ARFs
1509       assert(gop_frame.layer_depth >= 1);
1510       const int depth_factor = 1 << (gop_frame.layer_depth - 1);
1511       param.q_index =
1512           (active_worst_quality * (depth_factor - 1) + active_best_quality) /
1513           depth_factor;
1514     }
1515     param.rdmult = GetRDMult(gop_frame, param.q_index);
1516     gop_encode_info.param_list.push_back(param);
1517   }
1518   return gop_encode_info;
1519 }
1520 
GetTplPassGopEncodeInfo(const GopStruct & gop_struct,const FirstpassInfo & firstpass_info)1521 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetTplPassGopEncodeInfo(
1522     const GopStruct &gop_struct, const FirstpassInfo &firstpass_info) {
1523   return GetGopEncodeInfoWithFp(gop_struct, firstpass_info);
1524 }
1525 
GetGopEncodeInfo(const GopStruct & gop_struct,const TplGopStats & tpl_gop_stats,const std::vector<LookaheadStats> & lookahead_stats,const FirstpassInfo & firstpass_info AOM_UNUSED,const RefFrameTable & ref_frame_table_snapshot_init)1526 StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfo(
1527     const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
1528     const std::vector<LookaheadStats> &lookahead_stats,
1529     const FirstpassInfo &firstpass_info AOM_UNUSED,
1530     const RefFrameTable &ref_frame_table_snapshot_init) {
1531   // When TPL stats are not valid, use first pass stats.
1532   Status status = ValidateTplStats(gop_struct, tpl_gop_stats);
1533   if (!status.ok()) {
1534     return status;
1535   }
1536 
1537   for (const auto &lookahead_stat : lookahead_stats) {
1538     Status status = ValidateTplStats(*lookahead_stat.gop_struct,
1539                                      *lookahead_stat.tpl_gop_stats);
1540     if (!status.ok()) {
1541       return status;
1542     }
1543   }
1544 
1545   // TODO(b/260859962): Currently firstpass stats are used as an alternative,
1546   // but we could also combine it with tpl results in the future for more
1547   // stable qp determination.
1548   return GetGopEncodeInfoWithTpl(gop_struct, tpl_gop_stats, lookahead_stats,
1549                                  ref_frame_table_snapshot_init);
1550 }
1551 
1552 }  // namespace aom
1553