• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/qmode_rc/ratectrl_qmode.h"
13 
14 #include <algorithm>
15 #include <array>
16 #include <cerrno>
17 #include <cstring>
18 #include <fstream>
19 #include <memory>
20 #include <numeric>
21 #include <random>
22 #include <string>
23 #include <unordered_set>
24 #include <vector>
25 
26 #include "av1/qmode_rc/ducky_encode.h"
27 #include "av1/qmode_rc/reference_manager.h"
28 #include "test/mock_ratectrl_qmode.h"
29 #include "test/video_source.h"
30 #include "third_party/googletest/src/googlemock/include/gmock/gmock.h"
31 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
32 
33 namespace {
34 
35 using ::testing::HasSubstr;
36 
// Number of slots in the reference frame table used by the tests in this file.
37 constexpr int kRefFrameTableSize = 7;
// Dimensions of the clip the golden first-pass stats were generated from
// (see the aomenc command line in ReadFirstpassInfo).
38 constexpr int kFrameWidth = 352;
39 constexpr int kFrameHeight = 288;
// Matches the --limit=250 used when generating the golden first-pass stats.
40 constexpr int kFrameLimit = 250;
41 
42 MATCHER(IsOkStatus, "") {
43   *result_listener << "with code " << arg.code
44                    << " and message: " << arg.message;
45   return arg.ok();
46 }
47 
// Pulls the next whitespace-delimited token out of `stream`, converts it with
// std::strtod and stores the result in `*value`.
// Returns an empty string when the entire token was consumed by the
// conversion; otherwise returns an error message (end of input reached, or
// trailing characters the conversion could not consume). On the latter error
// `*value` still holds whatever prefix strtod managed to parse.
std::string ReadDouble(std::istream &stream, double *value) {
  std::string token;
  stream >> token;
  if (token.empty()) return "Unexpectedly reached end of input";

  char *parse_stop = nullptr;
  *value = std::strtod(token.c_str(), &parse_stop);
  const bool consumed_everything = (*parse_stop == '\0');
  return consumed_everything ? std::string()
                             : "Unexpected characters found: " + token;
}
64 
ReadFirstpassInfo(const std::string & filename,aom::FirstpassInfo * firstpass_info,const int frame_limit)65 void ReadFirstpassInfo(const std::string &filename,
66                        aom::FirstpassInfo *firstpass_info,
67                        const int frame_limit) {
68   // These golden files are generated by the following command line:
69   // ./aomenc --width=352 --height=288 --fps=30/1 --limit=250 --codec=av1
70   // --cpu-used=3 --end-usage=q --cq-level=36 --threads=0 --profile=0
71   // --lag-in-frames=35 --min-q=0 --max-q=63 --auto-alt-ref=1 --passes=2
72   // --kf-max-dist=160 --kf-min-dist=0 --drop-frame=0
73   // --static-thresh=0 --minsection-pct=0 --maxsection-pct=2000
74   // --arnr-maxframes=7
75   // --arnr-strength=5 --sharpness=0 --undershoot-pct=100 --overshoot-pct=100
76   // --frame-parallel=0
77   // --tile-columns=0 -o output.webm hantro_collage_w352h288.yuv
78   // First pass stats are written out in av1_get_second_pass_params right after
79   // calculate_gf_length.
80   std::string path = libaom_test::GetDataPath() + "/" + filename;
81   std::ifstream firstpass_stats_file(path);
82   ASSERT_TRUE(firstpass_stats_file.good())
83       << "Error opening " << path << ": " << std::strerror(errno);
84   firstpass_info->num_mbs_16x16 =
85       (kFrameWidth / 16 + 1) * (kFrameHeight / 16 + 1);
86   std::string newline;
87   int frame_number = 0;
88   while (std::getline(firstpass_stats_file, newline) &&
89          frame_number < frame_limit) {
90     std::istringstream iss(newline);
91     FIRSTPASS_STATS firstpass_stats_input = {};
92     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.frame), "");
93     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.weight), "");
94     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.intra_error), "");
95     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.frame_avg_wavelet_energy),
96               "");
97     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.coded_error), "");
98     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.sr_coded_error), "");
99     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.pcnt_inter), "");
100     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.pcnt_motion), "");
101     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.pcnt_second_ref), "");
102     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.pcnt_neutral), "");
103     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.intra_skip_pct), "");
104     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.inactive_zone_rows), "");
105     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.inactive_zone_cols), "");
106     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.MVr), "");
107     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.mvr_abs), "");
108     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.MVc), "");
109     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.mvc_abs), "");
110     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.MVrv), "");
111     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.MVcv), "");
112     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.mv_in_out_count), "");
113     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.new_mv_count), "");
114     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.duration), "");
115     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.count), "");
116     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.raw_error_stdev), "");
117     iss >> firstpass_stats_input.is_flash;
118     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.noise_var), "");
119     ASSERT_EQ(ReadDouble(iss, &firstpass_stats_input.cor_coeff), "");
120     ASSERT_TRUE(iss.eof()) << "Too many fields on line "
121                            << firstpass_info->stats_list.size() + 1 << "\n"
122                            << newline;
123     firstpass_info->stats_list.push_back(firstpass_stats_input);
124 
125     frame_number++;
126   }
127 }
128 }  // namespace
129 
130 namespace aom {
131 
132 using ::testing::ElementsAre;
133 using ::testing::Field;
134 using ::testing::Return;
135 
// Absolute tolerance used by the EXPECT_NEAR comparisons of accumulated costs.
136 constexpr double kErrorEpsilon = 0.000001;
137 
TestGopDisplayOrder(const GopStruct & gop_struct)138 void TestGopDisplayOrder(const GopStruct &gop_struct) {
139   // Test whether show frames' order indices are sequential
140   int expected_order_idx = 0;
141   int expected_show_frame_count = 0;
142   for (const auto &gop_frame : gop_struct.gop_frame_list) {
143     if (gop_frame.is_show_frame) {
144       EXPECT_EQ(gop_frame.order_idx, expected_order_idx);
145       expected_order_idx++;
146       expected_show_frame_count++;
147     }
148   }
149   EXPECT_EQ(gop_struct.show_frame_count, expected_show_frame_count);
150 }
151 
TestGopGlobalOrderIdx(const GopStruct & gop_struct,int global_order_idx_offset)152 void TestGopGlobalOrderIdx(const GopStruct &gop_struct,
153                            int global_order_idx_offset) {
154   // Test whether show frames' global order indices are sequential
155   EXPECT_EQ(gop_struct.global_order_idx_offset, global_order_idx_offset);
156   int expected_global_order_idx = global_order_idx_offset;
157   for (const auto &gop_frame : gop_struct.gop_frame_list) {
158     if (gop_frame.is_show_frame) {
159       EXPECT_EQ(gop_frame.global_order_idx, expected_global_order_idx);
160       expected_global_order_idx++;
161     }
162   }
163 }
164 
TestGopGlobalCodingIdx(const GopStruct & gop_struct,int global_coding_idx_offset)165 void TestGopGlobalCodingIdx(const GopStruct &gop_struct,
166                             int global_coding_idx_offset) {
167   EXPECT_EQ(gop_struct.global_coding_idx_offset, global_coding_idx_offset);
168   for (const auto &gop_frame : gop_struct.gop_frame_list) {
169     EXPECT_EQ(gop_frame.global_coding_idx,
170               global_coding_idx_offset + gop_frame.coding_idx);
171   }
172 }
173 
TestColocatedShowFrame(const GopStruct & gop_struct)174 void TestColocatedShowFrame(const GopStruct &gop_struct) {
175   // Test whether each non show frame has a colocated show frame
176   int gop_size = static_cast<int>(gop_struct.gop_frame_list.size());
177   for (int gop_idx = 0; gop_idx < gop_size; ++gop_idx) {
178     auto &gop_frame = gop_struct.gop_frame_list[gop_idx];
179     if (gop_frame.is_show_frame == 0) {
180       bool found_colocated_ref_frame = false;
181       for (int i = gop_idx + 1; i < gop_size; ++i) {
182         auto &next_gop_frame = gop_struct.gop_frame_list[i];
183         if (gop_frame.order_idx == next_gop_frame.order_idx) {
184           found_colocated_ref_frame = true;
185           EXPECT_EQ(gop_frame.update_ref_idx, next_gop_frame.colocated_ref_idx);
186           EXPECT_TRUE(next_gop_frame.is_show_frame);
187         }
188         if (gop_frame.update_ref_idx == next_gop_frame.update_ref_idx) {
189           break;
190         }
191       }
192       EXPECT_TRUE(found_colocated_ref_frame);
193     }
194   }
195 }
196 
TestLayerDepth(const GopStruct & gop_struct,int max_layer_depth)197 void TestLayerDepth(const GopStruct &gop_struct, int max_layer_depth) {
198   int gop_size = static_cast<int>(gop_struct.gop_frame_list.size());
199   for (int gop_idx = 0; gop_idx < gop_size; ++gop_idx) {
200     const auto &gop_frame = gop_struct.gop_frame_list[gop_idx];
201     if (gop_frame.is_key_frame) {
202       EXPECT_EQ(gop_frame.layer_depth, 0);
203     }
204 
205     if (gop_frame.is_arf_frame) {
206       EXPECT_LT(gop_frame.layer_depth, max_layer_depth);
207     }
208 
209     if (!gop_frame.is_key_frame && !gop_frame.is_arf_frame) {
210       EXPECT_EQ(gop_frame.layer_depth, max_layer_depth);
211     }
212   }
213 }
214 
TestArfInterval(const GopStruct & gop_struct)215 void TestArfInterval(const GopStruct &gop_struct) {
216   std::vector<int> arf_order_idx_list;
217   for (const auto &gop_frame : gop_struct.gop_frame_list) {
218     if (gop_frame.is_arf_frame) {
219       arf_order_idx_list.push_back(gop_frame.order_idx);
220     }
221   }
222   std::sort(arf_order_idx_list.begin(), arf_order_idx_list.end());
223   int arf_count = static_cast<int>(arf_order_idx_list.size());
224   for (int i = 1; i < arf_count; ++i) {
225     int arf_interval = arf_order_idx_list[i] - arf_order_idx_list[i - 1];
226     EXPECT_GE(arf_interval, kMinArfInterval);
227   }
228 }
229 
230 class RateControlQModeTest : public ::testing::Test {
231  protected:
RateControlQModeTest()232   RateControlQModeTest() {
233     rc_param_.max_gop_show_frame_count = 32;
234     rc_param_.min_gop_show_frame_count = 4;
235     rc_param_.ref_frame_table_size = 7;
236     rc_param_.max_ref_frames = 7;
237     rc_param_.base_q_index = 128;
238     rc_param_.frame_height = kFrameHeight;
239     rc_param_.frame_width = kFrameWidth;
240   }
241 
242   RateControlParam rc_param_ = {};
243 };
244 
// A 16-frame GOP without a key frame, placed at non-zero global coding and
// order offsets, must satisfy every structural check of the Test* helpers.
TEST_F(RateControlQModeTest, ConstructGopARF) {
  const int kShowFrames = 16;
  const bool kWithKeyFrame = false;
  const int kCodingOffset = 5;
  const int kOrderOffset = 20;

  RefFrameManager ref_frame_manager(kRefFrameTableSize, 7);
  const GopStruct gop_struct = ConstructGop(
      &ref_frame_manager, kShowFrames, kWithKeyFrame, kCodingOffset,
      kOrderOffset);

  EXPECT_EQ(gop_struct.show_frame_count, kShowFrames);
  TestGopDisplayOrder(gop_struct);
  TestGopGlobalOrderIdx(gop_struct, kOrderOffset);
  TestGopGlobalCodingIdx(gop_struct, kCodingOffset);
  TestColocatedShowFrame(gop_struct);
  TestLayerDepth(gop_struct, ref_frame_manager.MaxRefFrame());
  TestArfInterval(gop_struct);
}
263 
// A 16-frame GOP that starts with a key frame, placed at non-zero global
// coding and order offsets, must satisfy every structural check of the Test*
// helpers.
TEST_F(RateControlQModeTest, ConstructGopKey) {
  const int kShowFrames = 16;
  const bool kWithKeyFrame = true;
  const int kCodingOffset = 10;
  const int kOrderOffset = 8;

  RefFrameManager ref_frame_manager(kRefFrameTableSize, 7);
  const GopStruct gop_struct = ConstructGop(
      &ref_frame_manager, kShowFrames, kWithKeyFrame, kCodingOffset,
      kOrderOffset);

  EXPECT_EQ(gop_struct.show_frame_count, kShowFrames);
  TestGopDisplayOrder(gop_struct);
  TestGopGlobalOrderIdx(gop_struct, kOrderOffset);
  TestGopGlobalCodingIdx(gop_struct, kCodingOffset);
  TestColocatedShowFrame(gop_struct);
  TestLayerDepth(gop_struct, ref_frame_manager.MaxRefFrame());
  TestArfInterval(gop_struct);
}
282 
// A 2-frame GOP without a key frame must satisfy the structural checks. Its
// non-key, non-ARF frames are expected at depth 1 + kLayerDepthOffset.
TEST_F(RateControlQModeTest, ConstructShortGop) {
  const int kShowFrames = 2;
  const bool kWithKeyFrame = false;
  const int kCodingOffset = 5;
  const int kOrderOffset = 20;

  RefFrameManager ref_frame_manager(kRefFrameTableSize, 7);
  const GopStruct gop_struct = ConstructGop(
      &ref_frame_manager, kShowFrames, kWithKeyFrame, kCodingOffset,
      kOrderOffset);

  EXPECT_EQ(gop_struct.show_frame_count, kShowFrames);
  TestGopDisplayOrder(gop_struct);
  TestGopGlobalOrderIdx(gop_struct, kOrderOffset);
  TestGopGlobalCodingIdx(gop_struct, kCodingOffset);
  TestColocatedShowFrame(gop_struct);
  TestLayerDepth(gop_struct, 1 + kLayerDepthOffset);
  TestArfInterval(gop_struct);
}
301 
CreateToyTplBlockStats(int h,int w,int r,int c,int intra_cost,int inter_cost)302 static TplBlockStats CreateToyTplBlockStats(int h, int w, int r, int c,
303                                             int intra_cost, int inter_cost) {
304   TplBlockStats tpl_block_stats = {};
305   tpl_block_stats.height = h;
306   tpl_block_stats.width = w;
307   tpl_block_stats.row = r;
308   tpl_block_stats.col = c;
309   tpl_block_stats.intra_cost = intra_cost;
310   tpl_block_stats.inter_cost = inter_cost;
311   tpl_block_stats.ref_frame_index = { -1, -1 };
312   return tpl_block_stats;
313 }
314 
CreateToyTplFrameStatsWithDiffSizes(int min_block_size,int max_block_size)315 static TplFrameStats CreateToyTplFrameStatsWithDiffSizes(int min_block_size,
316                                                          int max_block_size) {
317   TplFrameStats frame_stats;
318   const int max_h = max_block_size;
319   const int max_w = max_h;
320   const int count = max_block_size / min_block_size;
321   frame_stats.min_block_size = min_block_size;
322   frame_stats.frame_height = max_h * count;
323   frame_stats.frame_width = max_w * count;
324   frame_stats.rate_dist_present = false;
325   for (int i = 0; i < count; ++i) {
326     for (int j = 0; j < count; ++j) {
327       int h = max_h >> i;
328       int w = max_w >> j;
329       for (int u = 0; u * h < max_h; ++u) {
330         for (int v = 0; v * w < max_w; ++v) {
331           int r = max_h * i + h * u;
332           int c = max_w * j + w * v;
333           int intra_cost = std::rand() % 16;
334           TplBlockStats block_stats =
335               CreateToyTplBlockStats(h, w, r, c, intra_cost, 0);
336           frame_stats.block_stats_list.push_back(block_stats);
337         }
338       }
339     }
340   }
341   return frame_stats;
342 }
343 
AugmentTplFrameStatsWithRefFrames(TplFrameStats * tpl_frame_stats,const std::array<int,kBlockRefCount> & ref_frame_index)344 static void AugmentTplFrameStatsWithRefFrames(
345     TplFrameStats *tpl_frame_stats,
346     const std::array<int, kBlockRefCount> &ref_frame_index) {
347   for (auto &block_stats : tpl_frame_stats->block_stats_list) {
348     block_stats.ref_frame_index = ref_frame_index;
349   }
350 }
AugmentTplFrameStatsWithMotionVector(TplFrameStats * tpl_frame_stats,const std::array<MotionVector,kBlockRefCount> & mv)351 static void AugmentTplFrameStatsWithMotionVector(
352     TplFrameStats *tpl_frame_stats,
353     const std::array<MotionVector, kBlockRefCount> &mv) {
354   for (auto &block_stats : tpl_frame_stats->block_stats_list) {
355     block_stats.mv = mv;
356   }
357 }
358 
CreateToyRefFrameTable(int frame_count)359 static RefFrameTable CreateToyRefFrameTable(int frame_count) {
360   RefFrameTable ref_frame_table(kRefFrameTableSize);
361   EXPECT_LE(frame_count, kRefFrameTableSize);
362   for (int i = 0; i < frame_count; ++i) {
363     ref_frame_table[i] =
364         GopFrameBasic(0, 0, i, i, 0, 0, GopFrameType::kRegularLeaf);
365   }
366   for (int i = frame_count; i < kRefFrameTableSize; ++i) {
367     ref_frame_table[i] = GopFrameInvalid();
368   }
369   return ref_frame_table;
370 }
371 
CreateFullpelMv(int row,int col)372 static MotionVector CreateFullpelMv(int row, int col) {
373   return { row, col, 0 };
374 }
375 
TplFrameStatsAccumulateIntraCost(const TplFrameStats & frame_stats)376 double TplFrameStatsAccumulateIntraCost(const TplFrameStats &frame_stats) {
377   double sum = 0;
378   for (auto &block_stats : frame_stats.block_stats_list) {
379     sum += block_stats.intra_cost;
380   }
381   return std::max(sum, 1.0);
382 }
383 
// Converting frame stats with mixed block sizes into per-unit dependency stats
// must:
// - use min_block_size as the unit size;
// - cover the whole frame;
// - preserve the total intra cost.
TEST_F(RateControlQModeTest, CreateTplFrameDepStats) {
  TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
  StatusOr<TplFrameDepStats> dep_stats =
      CreateTplFrameDepStatsWithoutPropagation(frame_stats);
  ASSERT_THAT(dep_stats.status(), IsOkStatus());

  EXPECT_EQ(frame_stats.min_block_size, dep_stats->unit_size);
  const auto &units = dep_stats->unit_stats;
  const int rows = static_cast<int>(units.size());
  const int cols = static_cast<int>(units[0].size());
  EXPECT_EQ(frame_stats.frame_height, rows * dep_stats->unit_size);
  EXPECT_EQ(frame_stats.frame_width, cols * dep_stats->unit_size);

  EXPECT_NEAR(TplFrameDepStatsAccumulateIntraCost(*dep_stats),
              TplFrameStatsAccumulateIntraCost(frame_stats), kErrorEpsilon);
}
401 
// A block whose row is not a multiple of min_block_size (8) must be rejected
// with an error naming the required multiple.
TEST_F(RateControlQModeTest, BlockRowNotAMultipleOfMinBlockSizeError) {
  TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
  TplBlockStats &last_block = frame_stats.block_stats_list.back();
  last_block.row = 1;

  auto result = CreateTplFrameDepStatsWithoutPropagation(frame_stats);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().message, HasSubstr("must be a multiple of 8"));
}
409 
// Moving the last block 8 rows further down must put it outside the frame, and
// the conversion must report it as out of range.
TEST_F(RateControlQModeTest, BlockPositionOutOfRangeError) {
  TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
  TplBlockStats &last_block = frame_stats.block_stats_list.back();
  last_block.row += 8;

  auto result = CreateTplFrameDepStatsWithoutPropagation(frame_stats);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().message, HasSubstr("out of range"));
}
417 
// GetBlockOverlapArea must return the overlap area between a size x size block
// at (r0, c0) and another at (r1, c1), independent of argument order.
TEST_F(RateControlQModeTest, GetBlockOverlapArea) {
  const int size = 8;
  const int r0 = 8;
  const int c0 = 9;
  struct OverlapCase {
    int r1;
    int c1;
    int expected_overlap;
  };
  const OverlapCase cases[] = {
    { 8, 9, 64 },  { 10, 12, 30 }, { 16, 17, 0 },
    { 10, 5, 24 }, { 8, 100, 0 },  { 100, 9, 0 },
  };
  for (const OverlapCase &tc : cases) {
    const int forward = GetBlockOverlapArea(r0, c0, tc.r1, tc.c1, size);
    const int backward = GetBlockOverlapArea(tc.r1, tc.c1, r0, c0, size);
    EXPECT_EQ(forward, tc.expected_overlap);
    EXPECT_EQ(backward, tc.expected_overlap);
  }
}
432 
// Converting an 8x4 block with unit_count = 2 must give the unit stats an
// intra cost of intra_cost / unit_count. The inter cost must not exceed that
// intra cost, even though the block's inter cost (120) is above its intra
// cost (100).
TEST_F(RateControlQModeTest, TplBlockStatsToDepStats) {
  const int intra_cost = 100;
  const int inter_cost = 120;
  const int unit_count = 2;
  const TplBlockStats block =
      CreateToyTplBlockStats(8, 4, 0, 0, intra_cost, inter_cost);

  const TplUnitDepStats unit_stats = TplBlockStatsToDepStats(block, unit_count);

  EXPECT_NEAR(unit_stats.intra_cost,
              static_cast<double>(intra_cost) / unit_count, kErrorEpsilon);
  EXPECT_LE(unit_stats.inter_cost, unit_stats.intra_cost);
}
446 
// Frame 1 predicts every block from frame 0 with zero motion. After
// propagating frame 1, the accumulated dependency stats of the initially empty
// frame 0 should equal frame 1's total intra cost.
TEST_F(RateControlQModeTest, TplFrameDepStatsPropagateSingleZeroMotion) {
  // cur frame with coding_idx 1 use ref frame with coding_idx 0
  const std::array<int, kBlockRefCount> ref_frame_index = { 0, -1 };
  TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
  AugmentTplFrameStatsWithRefFrames(&frame_stats, ref_frame_index);

  TplGopDepStats gop_dep_stats;
  const int frame_count = 2;
  // ref frame with coding_idx 0
  TplFrameDepStats frame_dep_stats0 =
      CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
                             frame_stats.min_block_size);
  gop_dep_stats.frame_dep_stats_list.push_back(std::move(frame_dep_stats0));

  // cur frame with coding_idx 1. Not const: std::move on a const object
  // silently copies instead of moving.
  StatusOr<TplFrameDepStats> frame_dep_stats1 =
      CreateTplFrameDepStatsWithoutPropagation(frame_stats);
  ASSERT_THAT(frame_dep_stats1.status(), IsOkStatus());
  gop_dep_stats.frame_dep_stats_list.push_back(std::move(*frame_dep_stats1));

  const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
  TplFrameDepStatsPropagate(/*coding_idx=*/1, ref_frame_table, &gop_dep_stats);

  // cur frame with coding_idx 1
  const double expected_propagation_sum =
      TplFrameStatsAccumulateIntraCost(frame_stats);

  // ref frame with coding_idx 0
  const double propagation_sum =
      TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[0]);

  // The propagation_sum between coding_idx 0 and coding_idx 1 should be equal
  // because every block in cur frame has zero motion, use ref frame with
  // coding_idx 0 for prediction, and ref frame itself is empty.
  EXPECT_NEAR(propagation_sum, expected_propagation_sum, kErrorEpsilon);
}
483 
// Frame 2 predicts every block from frames 0 and 1 (compound, zero motion).
// After propagating frame 2, each initially empty reference should have
// accumulated half of frame 2's total intra cost.
TEST_F(RateControlQModeTest, TplFrameDepStatsPropagateCompoundZeroMotion) {
  // cur frame with coding_idx 2 use two ref frames with coding_idx 0 and 1
  const std::array<int, kBlockRefCount> ref_frame_index = { 0, 1 };
  TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
  AugmentTplFrameStatsWithRefFrames(&frame_stats, ref_frame_index);

  TplGopDepStats gop_dep_stats;
  const int frame_count = 3;
  // ref frame with coding_idx 0
  gop_dep_stats.frame_dep_stats_list.push_back(
      CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
                             frame_stats.min_block_size));

  // ref frame with coding_idx 1
  gop_dep_stats.frame_dep_stats_list.push_back(
      CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
                             frame_stats.min_block_size));

  // cur frame with coding_idx 2. Not const: std::move on a const object
  // silently copies instead of moving.
  StatusOr<TplFrameDepStats> frame_dep_stats2 =
      CreateTplFrameDepStatsWithoutPropagation(frame_stats);
  ASSERT_THAT(frame_dep_stats2.status(), IsOkStatus());
  gop_dep_stats.frame_dep_stats_list.push_back(std::move(*frame_dep_stats2));

  const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
  TplFrameDepStatsPropagate(/*coding_idx=*/2, ref_frame_table, &gop_dep_stats);

  // cur frame with coding_idx 2
  const double expected_ref_sum = TplFrameStatsAccumulateIntraCost(frame_stats);

  // ref frame with coding_idx 0
  const double cost_sum0 =
      TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[0]);
  EXPECT_NEAR(cost_sum0, expected_ref_sum * 0.5, kErrorEpsilon);

  // ref frame with coding_idx 1
  const double cost_sum1 =
      TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[1]);
  EXPECT_NEAR(cost_sum1, expected_ref_sum * 0.5, kErrorEpsilon);
}
526 
TEST_F(RateControlQModeTest,TplFrameDepStatsPropagateSingleWithMotion)527 TEST_F(RateControlQModeTest, TplFrameDepStatsPropagateSingleWithMotion) {
528   // cur frame with coding_idx 1 use ref frame with coding_idx 0
529   const std::array<int, kBlockRefCount> ref_frame_index = { 0, -1 };
530   const int min_block_size = 8;
531   TplFrameStats frame_stats =
532       CreateToyTplFrameStatsWithDiffSizes(min_block_size, min_block_size * 2);
533   AugmentTplFrameStatsWithRefFrames(&frame_stats, ref_frame_index);
534 
535   const int mv_row = min_block_size / 2;
536   const int mv_col = min_block_size / 4;
537   const double r_ratio = 1.0 / 2;
538   const double c_ratio = 1.0 / 4;
539   std::array<MotionVector, kBlockRefCount> mv;
540   mv[0] = CreateFullpelMv(mv_row, mv_col);
541   mv[1] = CreateFullpelMv(0, 0);
542   AugmentTplFrameStatsWithMotionVector(&frame_stats, mv);
543 
544   TplGopDepStats gop_dep_stats;
545   const int frame_count = 2;
546   // ref frame with coding_idx 0
547   gop_dep_stats.frame_dep_stats_list.push_back(
548       CreateTplFrameDepStats(frame_stats.frame_height, frame_stats.frame_width,
549                              frame_stats.min_block_size));
550 
551   // cur frame with coding_idx 1
552   const StatusOr<TplFrameDepStats> frame_dep_stats =
553       CreateTplFrameDepStatsWithoutPropagation(frame_stats);
554   ASSERT_THAT(frame_dep_stats.status(), IsOkStatus());
555   gop_dep_stats.frame_dep_stats_list.push_back(std::move(*frame_dep_stats));
556 
557   const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
558   TplFrameDepStatsPropagate(/*coding_idx=*/1, ref_frame_table, &gop_dep_stats);
559 
560   const auto &dep_stats0 = gop_dep_stats.frame_dep_stats_list[0];
561   const auto &dep_stats1 = gop_dep_stats.frame_dep_stats_list[1];
562   const int unit_rows = static_cast<int>(dep_stats0.unit_stats.size());
563   const int unit_cols = static_cast<int>(dep_stats0.unit_stats[0].size());
564   for (int r = 0; r < unit_rows; ++r) {
565     for (int c = 0; c < unit_cols; ++c) {
566       double ref_value = 0;
567       ref_value += (1 - r_ratio) * (1 - c_ratio) *
568                    dep_stats1.unit_stats[r][c].intra_cost;
569       if (r - 1 >= 0) {
570         ref_value += r_ratio * (1 - c_ratio) *
571                      dep_stats1.unit_stats[r - 1][c].intra_cost;
572       }
573       if (c - 1 >= 0) {
574         ref_value += (1 - r_ratio) * c_ratio *
575                      dep_stats1.unit_stats[r][c - 1].intra_cost;
576       }
577       if (r - 1 >= 0 && c - 1 >= 0) {
578         ref_value +=
579             r_ratio * c_ratio * dep_stats1.unit_stats[r - 1][c - 1].intra_cost;
580       }
581       EXPECT_NEAR(dep_stats0.unit_stats[r][c].propagation_cost, ref_value,
582                   kErrorEpsilon);
583     }
584   }
585 }
586 
587 // TODO(jianj): Add tests for non empty lookahead stats.
// Runs three frames, each referencing the previous frame with zero motion,
// through ComputeTplGopDepStats with empty lookahead stats. Then compares the
// last coded frame's dependency total with its own intra cost.
TEST_F(RateControlQModeTest, ComputeTplGopDepStats) {
  const int frame_count = 3;
  TplGopStats tpl_gop_stats;
  std::vector<RefFrameTable> ref_frame_table_list;
  for (int i = 0; i < frame_count; i++) {
    // Use the previous frame as reference
    const std::array<int, kBlockRefCount> ref_frame_index = { i - 1, -1 };
    const int min_block_size = 8;
    TplFrameStats frame_stats =
        CreateToyTplFrameStatsWithDiffSizes(min_block_size, min_block_size * 2);
    AugmentTplFrameStatsWithRefFrames(&frame_stats, ref_frame_index);
    tpl_gop_stats.frame_stats_list.push_back(std::move(frame_stats));

    ref_frame_table_list.push_back(CreateToyRefFrameTable(i));
  }
  const StatusOr<TplGopDepStats> gop_dep_stats =
      ComputeTplGopDepStats(tpl_gop_stats, {}, ref_frame_table_list);
  ASSERT_THAT(gop_dep_stats.status(), IsOkStatus());

  // Only the last coded frame is checked; nothing propagates into it, so its
  // dependency total should equal its own intra cost.
  // NOTE(review): under linear zero-motion propagation, each earlier frame's
  // total would be the running sum of intra costs from itself to the end.
  // That stronger check is not enabled here. Confirm it against
  // ComputeTplGopDepStats before adding it.
  const int last = frame_count - 1;
  const double expected_sum =
      TplFrameStatsAccumulateIntraCost(tpl_gop_stats.frame_stats_list[last]);
  const double sum =
      TplFrameDepStatsAccumulate(gop_dep_stats->frame_dep_stats_list[last]);
  EXPECT_NEAR(sum, expected_sum, kErrorEpsilon);
}
620 
TEST(RefFrameManagerTest,GetRefFrameCount)621 TEST(RefFrameManagerTest, GetRefFrameCount) {
622   const std::vector<int> order_idx_list = { 0, 4, 2, 1, 2, 3, 4 };
623   const std::vector<GopFrameType> type_list = {
624     GopFrameType::kRegularKey,
625     GopFrameType::kRegularArf,
626     GopFrameType::kIntermediateArf,
627     GopFrameType::kRegularLeaf,
628     GopFrameType::kIntermediateOverlay,
629     GopFrameType::kRegularLeaf,
630     GopFrameType::kOverlay
631   };
632   RefFrameManager ref_manager(kRefFrameTableSize, 7);
633   int coding_idx = 0;
634   const int first_leaf_idx = 3;
635   EXPECT_EQ(type_list[first_leaf_idx], GopFrameType::kRegularLeaf);
636   // update reference frame until we see the first kRegularLeaf frame
637   for (; coding_idx <= first_leaf_idx; ++coding_idx) {
638     GopFrame gop_frame =
639         GopFrameBasic(0, 0, coding_idx, order_idx_list[coding_idx], 0, 0,
640                       type_list[coding_idx]);
641     ref_manager.UpdateRefFrameTable(&gop_frame);
642   }
643   EXPECT_EQ(ref_manager.GetRefFrameCount(), 4);
644   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kForward), 2);
645   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kBackward), 1);
646   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kLast), 1);
647   EXPECT_EQ(ref_manager.CurGlobalOrderIdx(), 1);
648 
649   // update reference frame until we see the first kShowExisting frame
650   const int first_show_existing_idx = 4;
651   EXPECT_EQ(type_list[first_show_existing_idx],
652             GopFrameType::kIntermediateOverlay);
653   for (; coding_idx <= first_show_existing_idx; ++coding_idx) {
654     GopFrame gop_frame =
655         GopFrameBasic(0, 0, coding_idx, order_idx_list[coding_idx], 0, 0,
656                       type_list[coding_idx]);
657     ref_manager.UpdateRefFrameTable(&gop_frame);
658   }
659   EXPECT_EQ(ref_manager.GetRefFrameCount(), 4);
660   EXPECT_EQ(ref_manager.CurGlobalOrderIdx(), 2);
661   // After the first kShowExisting, the kIntermediateArf should be moved from
662   // kForward to kLast due to the cur_global_order_idx_ update
663   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kForward), 1);
664   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kBackward), 2);
665   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kLast), 1);
666 
667   const int second_leaf_idx = 5;
668   EXPECT_EQ(type_list[second_leaf_idx], GopFrameType::kRegularLeaf);
669   for (; coding_idx <= second_leaf_idx; ++coding_idx) {
670     GopFrame gop_frame =
671         GopFrameBasic(0, 0, coding_idx, order_idx_list[coding_idx], 0, 0,
672                       type_list[coding_idx]);
673     ref_manager.UpdateRefFrameTable(&gop_frame);
674   }
675   EXPECT_EQ(ref_manager.GetRefFrameCount(), 5);
676   EXPECT_EQ(ref_manager.CurGlobalOrderIdx(), 3);
677   // An additional kRegularLeaf frame is added into kLast
678   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kForward), 1);
679   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kBackward), 2);
680   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kLast), 2);
681 
682   const int first_overlay_idx = 6;
683   EXPECT_EQ(type_list[first_overlay_idx], GopFrameType::kOverlay);
684   for (; coding_idx <= first_overlay_idx; ++coding_idx) {
685     GopFrame gop_frame =
686         GopFrameBasic(0, 0, coding_idx, order_idx_list[coding_idx], 0, 0,
687                       type_list[coding_idx]);
688     ref_manager.UpdateRefFrameTable(&gop_frame);
689   }
690 
691   EXPECT_EQ(ref_manager.GetRefFrameCount(), 5);
692   EXPECT_EQ(ref_manager.CurGlobalOrderIdx(), 4);
693   // After the kOverlay, the kRegularArf should be moved from
694   // kForward to kBackward due to the cur_global_order_idx_ update
695   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kForward), 0);
696   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kBackward), 3);
697   EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kLast), 2);
698 }
699 
TestRefFrameManagerPriority(const RefFrameManager & ref_manager,RefUpdateType type)700 void TestRefFrameManagerPriority(const RefFrameManager &ref_manager,
701                                  RefUpdateType type) {
702   int ref_count = ref_manager.GetRefFrameCountByType(type);
703   int prev_global_order_idx = ref_manager.CurGlobalOrderIdx();
704   // The lower the priority is, the closer the gop_frame.global_order_idx should
705   // be with cur_global_order_idx_, with exception of a base layer ARF.
706   for (int priority = 0; priority < ref_count; ++priority) {
707     GopFrame gop_frame = ref_manager.GetRefFrameByPriority(type, priority);
708     EXPECT_EQ(gop_frame.is_valid, true);
709     if (type == RefUpdateType::kForward) {
710       if (priority == 0) continue;
711       EXPECT_GE(gop_frame.global_order_idx, prev_global_order_idx);
712     } else {
713       EXPECT_LE(gop_frame.global_order_idx, prev_global_order_idx);
714     }
715     prev_global_order_idx = gop_frame.global_order_idx;
716   }
717   GopFrame gop_frame =
718       ref_manager.GetRefFrameByPriority(RefUpdateType::kForward, ref_count);
719   EXPECT_EQ(gop_frame.is_valid, false);
720 }
721 
// Feeds a short GOP (key frame, ARF, intermediate ARF, leaves and overlays)
// into a RefFrameManager in coding order. It checks the priority ordering of
// the kForward list after the first leaf, and of the kBackward and kLast lists
// after the overlay.
TEST(RefFrameManagerTest, GetRefFrameByPriority) {
  const std::vector<int> order_idx_list = { 0, 4, 2, 1, 2, 3, 4 };
  const std::vector<GopFrameType> type_list = {
    GopFrameType::kRegularKey,          GopFrameType::kRegularArf,
    GopFrameType::kIntermediateArf,     GopFrameType::kRegularLeaf,
    GopFrameType::kIntermediateOverlay, GopFrameType::kRegularLeaf,
    GopFrameType::kOverlay
  };
  RefFrameManager ref_manager(kRefFrameTableSize, 7);

  // Pushes frames next_coding_idx..last_coding_idx (inclusive) into
  // ref_manager, leaving next_coding_idx just past last_coding_idx.
  int next_coding_idx = 0;
  const auto update_through = [&](int last_coding_idx) {
    while (next_coding_idx <= last_coding_idx) {
      GopFrame gop_frame = GopFrameBasic(
          0, 0, next_coding_idx, order_idx_list[next_coding_idx], 0, 0,
          type_list[next_coding_idx]);
      ref_manager.UpdateRefFrameTable(&gop_frame);
      ++next_coding_idx;
    }
  };

  // Stop right after the first kRegularLeaf frame.
  const int first_leaf_idx = 3;
  EXPECT_EQ(type_list[first_leaf_idx], GopFrameType::kRegularLeaf);
  update_through(first_leaf_idx);
  EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kForward), 2);
  TestRefFrameManagerPriority(ref_manager, RefUpdateType::kForward);

  // Continue up to and including the kOverlay frame.
  const int first_overlay_idx = 6;
  EXPECT_EQ(type_list[first_overlay_idx], GopFrameType::kOverlay);
  update_through(first_overlay_idx);
  EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kBackward), 3);
  TestRefFrameManagerPriority(ref_manager, RefUpdateType::kBackward);
  EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kLast), 2);
  TestRefFrameManagerPriority(ref_manager, RefUpdateType::kLast);
}
761 
// Checks that GetRefFrameListByPriority() returns the reference frames in
// priority order. For each slot it verifies the expected global order index
// and reference name.
TEST(RefFrameManagerTest, GetRefFrameListByPriority) {
  const std::vector<int> order_idx_list = { 0, 4, 2, 1 };
  const int frame_count = static_cast<int>(order_idx_list.size());
  const std::vector<GopFrameType> type_list = { GopFrameType::kRegularKey,
                                                GopFrameType::kRegularArf,
                                                GopFrameType::kIntermediateArf,
                                                GopFrameType::kRegularLeaf };
  RefFrameManager ref_manager(kRefFrameTableSize, 7);
  for (int coding_idx = 0; coding_idx < frame_count; ++coding_idx) {
    GopFrame gop_frame =
        GopFrameBasic(0, 0, coding_idx, order_idx_list[coding_idx], 0, 0,
                      type_list[coding_idx]);
    ref_manager.UpdateRefFrameTable(&gop_frame);
  }
  EXPECT_EQ(ref_manager.GetRefFrameCount(), frame_count);
  EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kForward), 2);
  EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kBackward), 1);
  EXPECT_EQ(ref_manager.GetRefFrameCountByType(RefUpdateType::kLast), 1);

  const std::vector<ReferenceFrame> ref_frame_list =
      ref_manager.GetRefFrameListByPriority();
  // Expected contents of ref_frame_list, slot by slot.
  const std::vector<int> expected_global_order_idx = { 4, 0, 1, 2 };
  const std::vector<ReferenceName> expected_names = {
    ReferenceName::kAltrefFrame, ReferenceName::kGoldenFrame,
    ReferenceName::kLastFrame, ReferenceName::kBwdrefFrame
  };
  // ASSERT rather than EXPECT: the loop below indexes the expected_* vectors
  // with positions in ref_frame_list. A length mismatch must therefore end the
  // test before any out-of-bounds read.
  ASSERT_EQ(ref_frame_list.size(), expected_global_order_idx.size());
  ASSERT_EQ(ref_frame_list.size(), expected_names.size());
  for (size_t i = 0; i < ref_frame_list.size(); ++i) {
    const ReferenceFrame &ref_frame = ref_frame_list[i];
    const GopFrame gop_frame = ref_manager.GetRefFrameByIndex(ref_frame.index);
    EXPECT_EQ(gop_frame.global_order_idx, expected_global_order_idx[i]);
    EXPECT_EQ(ref_frame.name, expected_names[i]);
  }
}
795 
// Checks GetPrimaryRefFrame()'s selection rules on a reference table holding a
// key frame, an ARF, an intermediate ARF and a leaf at layer depths 0/2/4/6:
//  1. A frame whose layer depth equals a reference's layer depth must get a
//     primary reference at that depth, even when the frame types differ.
//  2. A frame whose layer depth lies between two references' depths must get
//     the reference with the same frame type.
TEST(RefFrameManagerTest, GetPrimaryRefFrame) {
  const std::vector<int> order_idx_list = { 0, 4, 2, 1 };
  const int frame_count = static_cast<int>(order_idx_list.size());
  const std::vector<GopFrameType> type_list = { GopFrameType::kRegularKey,
                                                GopFrameType::kRegularArf,
                                                GopFrameType::kIntermediateArf,
                                                GopFrameType::kRegularLeaf };
  const std::vector<int> layer_depth_list = { 0, 2, 4, 6 };
  RefFrameManager ref_manager(kRefFrameTableSize, 7);
  for (int coding_idx = 0; coding_idx < frame_count; ++coding_idx) {
    GopFrame gop_frame =
        GopFrameBasic(0, 0, coding_idx, order_idx_list[coding_idx],
                      layer_depth_list[coding_idx], 0, type_list[coding_idx]);
    ref_manager.UpdateRefFrameTable(&gop_frame);
  }

  for (int i = 0; i < frame_count; ++i) {
    // Test a frame that shares its layer depth with a reference frame...
    const int layer_depth = layer_depth_list[i];
    // ...but has a different frame type.
    const GopFrameType type = type_list[(i + 1) % frame_count];
    GopFrame gop_frame = GopFrameBasic(0, 0, 0, 0, layer_depth, 0, type);
    gop_frame.ref_frame_list = ref_manager.GetRefFrameListByPriority();
    const ReferenceFrame ref_frame = ref_manager.GetPrimaryRefFrame(gop_frame);
    const GopFrame primary_ref_frame =
        ref_manager.GetRefFrameByIndex(ref_frame.index);
    // GetPrimaryRefFrame should pick the ref_frame with matching layer depth,
    // because that is its first priority.
    EXPECT_EQ(primary_ref_frame.layer_depth, gop_frame.layer_depth);
  }

  // Layer depths strictly between consecutive entries of layer_depth_list.
  const std::vector<int> mid_layer_depth_list = { 1, 3, 5 };
  const int mid_layer_depth_count =
      static_cast<int>(mid_layer_depth_list.size());
  for (int i = 0; i < mid_layer_depth_count; ++i) {
    // Test a frame that shares its frame type with a reference frame...
    const GopFrameType type = type_list[i];
    // ...whose layer depth sits between two reference frames' depths.
    const int layer_depth = mid_layer_depth_list[i];
    GopFrame gop_frame = GopFrameBasic(0, 0, 0, 0, layer_depth, 0, type);
    gop_frame.ref_frame_list = ref_manager.GetRefFrameListByPriority();
    const ReferenceFrame ref_frame = ref_manager.GetPrimaryRefFrame(gop_frame);
    const GopFrame primary_ref_frame =
        ref_manager.GetRefFrameByIndex(ref_frame.index);
    // GetPrimaryRefFrame should pick the ref_frame with matching frame type.
    // The reference with coding_idx i has type type_list[i], so checking the
    // coding_idx confirms this.
    EXPECT_EQ(primary_ref_frame.coding_idx, i);
  }
}
843 
// Key frames detected from the first kFrameLimit frames of the
// "firstpass_stats" file are expected every 30 frames.
TEST_F(RateControlQModeTest, TestKeyframeDetection) {
  FirstpassInfo firstpass_info;
  ASSERT_NO_FATAL_FAILURE(
      ReadFirstpassInfo("firstpass_stats", &firstpass_info, kFrameLimit));
  const auto key_frame_list = GetKeyFrameList(firstpass_info);
  EXPECT_THAT(key_frame_list,
              ElementsAre(0, 30, 60, 90, 120, 150, 180, 210, 240));
}
852 
853 MATCHER_P(GopFrameMatches, expected, "") {
854 #define COMPARE_FIELD(FIELD)                                   \
855   do {                                                         \
856     if (arg.FIELD != expected.FIELD) {                         \
857       *result_listener << "where " #FIELD " is " << arg.FIELD  \
858                        << " but should be " << expected.FIELD; \
859       return false;                                            \
860     }                                                          \
861   } while (0)
862   COMPARE_FIELD(is_valid);
863   COMPARE_FIELD(order_idx);
864   COMPARE_FIELD(coding_idx);
865   COMPARE_FIELD(global_order_idx);
866   COMPARE_FIELD(global_coding_idx);
867   COMPARE_FIELD(is_key_frame);
868   COMPARE_FIELD(is_arf_frame);
869   COMPARE_FIELD(is_show_frame);
870   COMPARE_FIELD(is_golden_frame);
871   COMPARE_FIELD(colocated_ref_idx);
872   COMPARE_FIELD(update_ref_idx);
873   COMPARE_FIELD(layer_depth);
874 #undef COMPARE_FIELD
875 
876   return true;
877 }
878 
879 // Helper for tests which need to set update_ref_idx, but for which the indices
880 // and depth don't matter (other than to allow creating multiple GopFrames which
881 // are distinguishable).
GopFrameUpdateRefIdx(int index,GopFrameType gop_frame_type,int update_ref_idx)882 GopFrame GopFrameUpdateRefIdx(int index, GopFrameType gop_frame_type,
883                               int update_ref_idx) {
884   GopFrame frame =
885       GopFrameBasic(0, 0, index, index, /*depth=*/0, 0, gop_frame_type);
886   frame.update_ref_idx = update_ref_idx;
887   return frame;
888 }
889 
// SetRcParam() must reject a value-initialized RateControlParam.
TEST_F(RateControlQModeTest, TestInvalidRateControlParam) {
  RateControlParam rc_param = {};
  AV1RateControlQMode rc;
  const Status status = rc.SetRcParam(rc_param);
  EXPECT_NE(status.code, AOM_CODEC_OK);
}
895 
// min <= max here, but max_gop_show_frame_count is below the minimum of 4
// quoted in the error message, so SetRcParam() must refuse it.
TEST_F(RateControlQModeTest, TestInvalidMaxGopShowFrameCount) {
  AV1RateControlQMode rc;
  rc_param_.min_gop_show_frame_count = 2;
  rc_param_.max_gop_show_frame_count = 3;
  const Status status = rc.SetRcParam(rc_param_);
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("max_gop_show_frame_count (3) must be at least 4"));
}
904 
// A max_gop_show_frame_count smaller than min_gop_show_frame_count must be
// refused by SetRcParam().
TEST_F(RateControlQModeTest, TestInvalidMinGopShowFrameCount) {
  AV1RateControlQMode rc;
  rc_param_.min_gop_show_frame_count = 9;
  rc_param_.max_gop_show_frame_count = 8;
  const Status status = rc.SetRcParam(rc_param_);
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("may not be less than min_gop_show_frame_count (9)"));
}
913 
// A ref_frame_table_size of 9 lies outside the accepted range and must be
// refused by SetRcParam().
TEST_F(RateControlQModeTest, TestInvalidRefFrameTableSize) {
  AV1RateControlQMode rc;
  rc_param_.ref_frame_table_size = 9;
  const Status status = rc.SetRcParam(rc_param_);
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("ref_frame_table_size (9) must be in the range"));
}
921 
// A max_ref_frames of 8 lies outside the accepted range and must be refused
// by SetRcParam().
TEST_F(RateControlQModeTest, TestInvalidMaxRefFrames) {
  AV1RateControlQMode rc;
  rc_param_.max_ref_frames = 8;
  const Status status = rc.SetRcParam(rc_param_);
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("max_ref_frames (8) must be in the range"));
}
929 
// A base_q_index of 256 lies outside the accepted range and must be refused
// by SetRcParam().
TEST_F(RateControlQModeTest, TestInvalidBaseQIndex) {
  AV1RateControlQMode rc;
  rc_param_.base_q_index = 256;
  const Status status = rc.SetRcParam(rc_param_);
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("base_q_index (256) must be in the range"));
}
937 
// A frame_height of 15 lies outside the accepted range and must be refused by
// SetRcParam().
TEST_F(RateControlQModeTest, TestInvalidFrameHeight) {
  AV1RateControlQMode rc;
  rc_param_.frame_height = 15;
  const Status status = rc.SetRcParam(rc_param_);
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("frame_height (15) must be in the range"));
}
945 
// For the first GOP, GetRefFrameTableList() yields one RefFrameTable snapshot
// before the GOP plus one after each of its frames:
//  * at the start, all slots are invalid;
//  * after frame 0, every slot holds the key frame;
//  * after each later frame, only the slot named by its update_ref_idx is
//    replaced.
TEST_F(RateControlQModeTest, TestGetRefFrameTableListFirstGop) {
  rc_param_.ref_frame_table_size = 3;
  AV1RateControlQMode rc;
  ASSERT_THAT(rc.SetRcParam(rc_param_), IsOkStatus());

  const auto invalid = GopFrameInvalid();
  const auto frame0 = GopFrameUpdateRefIdx(0, GopFrameType::kRegularKey, -1);
  const auto frame1 = GopFrameUpdateRefIdx(1, GopFrameType::kRegularLeaf, 2);
  const auto frame2 = GopFrameUpdateRefIdx(2, GopFrameType::kRegularLeaf, 0);

  GopStruct gop_struct;
  gop_struct.global_coding_idx_offset = 0;  // This is the first GOP.
  gop_struct.gop_frame_list = { frame0, frame1, frame2 };

  // For the first GOP only, the incoming table may be default-constructed,
  // because the key frame replaces all of it anyway.
  const auto ref_frame_table_list =
      rc.GetRefFrameTableList(gop_struct, {}, RefFrameTable());

  const auto matches_invalid = GopFrameMatches(invalid);
  const auto matches_frame0 = GopFrameMatches(frame0);
  const auto matches_frame1 = GopFrameMatches(frame1);
  const auto matches_frame2 = GopFrameMatches(frame2);
  ASSERT_THAT(
      ref_frame_table_list,
      ElementsAre(
          ElementsAre(matches_invalid, matches_invalid, matches_invalid),
          ElementsAre(matches_frame0, matches_frame0, matches_frame0),
          ElementsAre(matches_frame0, matches_frame0, matches_frame1),
          ElementsAre(matches_frame2, matches_frame0, matches_frame1)));
}
975 
// For a GOP after the first, the incoming RefFrameTable comes from earlier
// GOPs. Its entries appear in the returned snapshots with coding_idx reset to
// -1. Each frame of the GOP then replaces the slot named by its update_ref_idx;
// -1 leaves the table unchanged.
TEST_F(RateControlQModeTest, TestGetRefFrameTableListNotFirstGop) {
  rc_param_.ref_frame_table_size = 3;
  AV1RateControlQMode rc;
  ASSERT_THAT(rc.SetRcParam(rc_param_), IsOkStatus());

  const auto previous = GopFrameUpdateRefIdx(0, GopFrameType::kRegularKey, -1);
  const auto frame0 = GopFrameUpdateRefIdx(5, GopFrameType::kRegularLeaf, 2);
  const auto frame1 = GopFrameUpdateRefIdx(6, GopFrameType::kRegularLeaf, -1);
  const auto frame2 = GopFrameUpdateRefIdx(7, GopFrameType::kRegularLeaf, 0);

  GopStruct gop_struct;
  gop_struct.global_coding_idx_offset = 5;  // This is not the first GOP.
  gop_struct.gop_frame_list = { frame0, frame1, frame2 };
  const auto ref_frame_table_list =
      rc.GetRefFrameTableList(gop_struct, {}, RefFrameTable(3, previous));

  // Frames in the initial table should have coding_idx of -1
  // to prevent propagating TPL stats to already coded frames.
  auto expected_previous = previous;
  expected_previous.coding_idx = -1;
  const auto matches_previous = GopFrameMatches(expected_previous);
  const auto matches_frame0 = GopFrameMatches(frame0);
  const auto matches_frame2 = GopFrameMatches(frame2);
  ASSERT_THAT(
      ref_frame_table_list,
      ElementsAre(
          ElementsAre(matches_previous, matches_previous, matches_previous),
          ElementsAre(matches_previous, matches_previous, matches_frame0),
          ElementsAre(matches_previous, matches_previous, matches_frame0),
          ElementsAre(matches_frame2, matches_previous, matches_frame0)));
}
1005 
// Checks the GOP lengths (show_frame_count of each GopStruct) chosen by
// DetermineGopInfo() for the first kFrameLimit frames of "firstpass_stats".
TEST_F(RateControlQModeTest, TestGopIntervals) {
  FirstpassInfo firstpass_info;
  ASSERT_NO_FATAL_FAILURE(
      ReadFirstpassInfo("firstpass_stats", &firstpass_info, kFrameLimit));
  AV1RateControlQMode rc;
  ASSERT_THAT(rc.SetRcParam(rc_param_), IsOkStatus());

  const auto gop_info = rc.DetermineGopInfo(firstpass_info);
  ASSERT_THAT(gop_info.status(), IsOkStatus());
  std::vector<int> gop_interval_list;
  for (const GopStruct &gop_struct : *gop_info) {
    gop_interval_list.push_back(gop_struct.show_frame_count);
  }
  EXPECT_THAT(gop_interval_list,
              ElementsAre(21, 9, 30, 30, 16, 14, 21, 9, 30, 12, 16, 2, 30, 10));
}
1022 
1023 // TODO(b/242892473): Add a test which passes lookahead GOPs.
// End-to-end check of GetGopEncodeInfo() on the first kNumFrames frames.
// It determines GOPs, runs the TPL pass through DuckyEncode, and then checks
// for every GOP that:
//  * no frame's q_index exceeds base_q_index;
//  * no frame in the final reference snapshot lists more than max_ref_frames
//    references.
TEST_F(RateControlQModeTest, TestGetGopEncodeInfo) {
  // Number of frames read from both the firstpass stats and the input video.
  const int kNumFrames = 50;
  FirstpassInfo firstpass_info;
  ASSERT_NO_FATAL_FAILURE(
      ReadFirstpassInfo("firstpass_stats", &firstpass_info, kNumFrames));
  AV1RateControlQMode rc;
  rc_param_.max_gop_show_frame_count = 16;
  rc_param_.max_ref_frames = 3;
  rc_param_.base_q_index = 117;
  ASSERT_THAT(rc.SetRcParam(rc_param_), IsOkStatus());
  const auto gop_info = rc.DetermineGopInfo(firstpass_info);
  ASSERT_THAT(gop_info.status(), IsOkStatus());
  const GopStructList &gop_list = *gop_info;
  const aom_rational_t frame_rate = { 30, 1 };
  const aom::VideoInfo input_video = {
    kFrameWidth, kFrameHeight,
    frame_rate,  AOM_IMG_FMT_I420,
    kNumFrames,  libaom_test::GetDataPath() + "/hantro_collage_w352h288.yuv"
  };
  DuckyEncode ducky_encode(input_video, BLOCK_64X64, rc_param_.max_ref_frames,
                           3, rc_param_.base_q_index);

  // Encode parameters used for the TPL pass, one entry per GOP.
  std::vector<aom::GopEncodeInfo> gop_encode_info_list;
  gop_encode_info_list.reserve(gop_list.size());
  for (const auto &gop_struct : gop_list) {
    const auto gop_encode_info =
        rc.GetTplPassGopEncodeInfo(gop_struct, firstpass_info);
    ASSERT_THAT(gop_encode_info.status(), IsOkStatus());
    gop_encode_info_list.push_back(gop_encode_info.value());
  }

  // Read TPL stats
  std::vector<TplGopStats> tpl_gop_list = ducky_encode.ComputeTplStats(
      firstpass_info.stats_list, gop_list, gop_encode_info_list);
  // The loop below indexes tpl_gop_list by GOP index, so it must hold one
  // entry per GOP. Stop before an out-of-bounds read otherwise.
  ASSERT_EQ(tpl_gop_list.size(), gop_list.size());

  RefFrameTable ref_frame_table;
  for (size_t gop_idx = 0; gop_idx < gop_list.size(); ++gop_idx) {
    const auto gop_encode_info =
        rc.GetGopEncodeInfo(gop_list[gop_idx], tpl_gop_list[gop_idx], {},
                            firstpass_info, ref_frame_table);
    ASSERT_THAT(gop_encode_info.status(), IsOkStatus());
    for (const auto &frame_param : gop_encode_info->param_list) {
      EXPECT_LE(frame_param.q_index, rc_param_.base_q_index);
    }
    // Carry the reference table over to the next GOP.
    ref_frame_table = gop_encode_info->final_snapshot;
    for (const auto &gop_frame : ref_frame_table) {
      EXPECT_LE(static_cast<int>(gop_frame.ref_frame_list.size()),
                rc_param_.max_ref_frames);
    }
  }
}
1075 
// GetGopEncodeInfo() must reject TPL stats whose frame count differs from the
// number of frames in the GopStruct.
TEST_F(RateControlQModeTest, GetGopEncodeInfoWrongGopSize) {
  const int kGopFrameCount = 7;
  const int kTplFrameCount = 5;
  GopStruct gop_struct;
  gop_struct.gop_frame_list.assign(kGopFrameCount, GopFrameInvalid());
  TplGopStats tpl_gop_stats;
  tpl_gop_stats.frame_stats_list.assign(
      kTplFrameCount, CreateToyTplFrameStatsWithDiffSizes(8, 8));

  AV1RateControlQMode rc;
  const auto gop_encode_info =
      rc.GetGopEncodeInfo(gop_struct, tpl_gop_stats, {}, {}, RefFrameTable());
  const Status status = gop_encode_info.status();
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("Frame count of GopStruct (7) doesn't match frame "
                        "count of TPL stats (5)"));
}
1091 
// GetGopEncodeInfo() must report a reference frame whose TPL stats carry no
// block stats.
TEST_F(RateControlQModeTest, GetGopEncodeInfoRefFrameMissingBlockStats) {
  // Frames 0 and 2 are reference frames; frame 1 has update_ref_idx -1.
  GopStruct gop_struct;
  gop_struct.gop_frame_list = {
    GopFrameUpdateRefIdx(0, GopFrameType::kRegularKey, 1),
    GopFrameUpdateRefIdx(1, GopFrameType::kRegularLeaf, -1),
    GopFrameUpdateRefIdx(2, GopFrameType::kRegularLeaf, 2),
  };
  gop_struct.show_frame_count = 3;

  // Every frame starts without TPL block stats; only frame 0 then gets some.
  TplGopStats tpl_gop_stats;
  tpl_gop_stats.frame_stats_list.assign(3, { 8, 176, 144, false, {}, {} });
  tpl_gop_stats.frame_stats_list.front() =
      CreateToyTplFrameStatsWithDiffSizes(8, 8);

  AV1RateControlQMode rc;
  const auto gop_encode_info =
      rc.GetGopEncodeInfo(gop_struct, tpl_gop_stats, {}, {}, RefFrameTable());
  const Status status = gop_encode_info.status();
  EXPECT_EQ(status.code, AOM_CODEC_INVALID_PARAM);
  EXPECT_THAT(status.message,
              HasSubstr("The frame with global_coding_idx 2 is a reference "
                        "frame, but has no TPL stats"));
}
1116 
1117 // MockRateControlQMode is provided for the use of clients of libaom, but it's
1118 // not expected that it will be used in any real libaom tests.
1119 // This simple "toy" test exists solely to verify the integration of gmock into
1120 // the aom build.
TEST_F(RateControlQModeTest,TestMock)1121 TEST_F(RateControlQModeTest, TestMock) {
1122   MockRateControlQMode mock_rc;
1123   EXPECT_CALL(mock_rc,
1124               DetermineGopInfo(Field(&FirstpassInfo::num_mbs_16x16, 1000)))
1125       .WillOnce(Return(aom::Status{ AOM_CODEC_ERROR, "message" }));
1126   FirstpassInfo firstpass_info = {};
1127   firstpass_info.num_mbs_16x16 = 1000;
1128   const auto result = mock_rc.DetermineGopInfo(firstpass_info);
1129   EXPECT_EQ(result.status().code, AOM_CODEC_ERROR);
1130   EXPECT_EQ(result.status().message, "message");
1131 }
1132 
// Draws num_sample_per_cluster random samples around each of several
// well-separated intended centroids. It then runs aom::internal::KMeans with
// one cluster per intended centroid and checks that:
//  (a) that many distinct centroids come back;
//  (b) every value is assigned to its nearest returned centroid.
TEST_F(RateControlQModeTest, TestKMeans) {
  // The distance between intended centroids is designed so each cluster is far
  // enough from others.
  const std::vector<int> centroids_ref = { 16,  48,  80,  112,
                                           144, 176, 208, 240 };
  // Samples are drawn uniformly from [centroid - kSpread, centroid + kSpread].
  // With centroids 32 apart and kSpread 8, the sample ranges of neighbouring
  // clusters do not overlap.
  const int kSpread = 8;
  const int num_sample_per_cluster = 10;
  // One cluster per intended centroid. Derived from the list so the two cannot
  // drift apart.
  const int num_clusters = static_cast<int>(centroids_ref.size());

  std::vector<uint8_t> random_input;
  random_input.reserve(centroids_ref.size() *
                       static_cast<size_t>(num_sample_per_cluster));
  std::default_random_engine generator;
  for (const int centroid : centroids_ref) {
    std::uniform_int_distribution<int> distribution(centroid - kSpread,
                                                    centroid + kSpread);
    for (int i = 0; i < num_sample_per_cluster; ++i) {
      const int random_sample = distribution(generator);
      random_input.push_back(static_cast<uint8_t>(random_sample));
    }
  }
  std::shuffle(random_input.begin(), random_input.end(), generator);
  const std::unordered_map<int, int> kmeans_result =
      aom::internal::KMeans(random_input, num_clusters);

  std::unordered_set<int> found_centroids;
  for (const auto &result : kmeans_result) {
    found_centroids.insert(result.second);
  }
  // Verify there're num_clusters in the k-means result.
  EXPECT_EQ(static_cast<int>(found_centroids.size()), num_clusters);

  // Verify that for each data point, the assigned centroid is the closest one.
  for (const auto &result : kmeans_result) {
    const int distance_from_cluster_centroid =
        abs(result.first - result.second);
    for (const int centroid : found_centroids) {
      if (centroid == result.second) continue;
      const int distance_from_other_cluster_centroid =
          abs(result.first - centroid);
      EXPECT_LE(distance_from_cluster_centroid,
                distance_from_other_cluster_centroid);
    }
  }
}
1173 
1174 }  // namespace aom
1175 
main(int argc,char ** argv)1176 int main(int argc, char **argv) {
1177   ::testing::InitGoogleTest(&argc, argv);
1178   std::srand(0);
1179   return RUN_ALL_TESTS();
1180 }
1181