1 // Copyright 2020 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/weight_mask.h"
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <ostream>
20 #include <string>
21 #include <type_traits>
22
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/time/clock.h"
26 #include "absl/time/time.h"
27 #include "gtest/gtest.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 #include "src/utils/constants.h"
31 #include "src/utils/cpu.h"
32 #include "src/utils/memory.h"
33 #include "tests/third_party/libvpx/acm_random.h"
34 #include "tests/utils.h"
35
36 namespace libgav1 {
37 namespace dsp {
38 namespace {
39
// Iteration count for the DISABLED_Speed benchmarks.
constexpr int kNumSpeedTests = 50000;
// Largest prediction block dimension; the test buffers are sized
// kMaxPredictionSize x kMaxPredictionSize.
constexpr int kMaxPredictionSize = 128;
// {min, max} compound prediction values per bitdepth, indexed by
// (bitdepth - 8) >> 1. weight_mask is only used with
// kCompoundPredictionTypeDiffWeighted, and convolve produces the most extreme
// ranges; the 10bpp and 12bpp ranges include kCompoundOffset.
// See src/dsp/convolve.cc and src/dsp/warp.cc.
constexpr int kCompoundPredictionRange[3][2] = {
    {-5132, 9212},  // 8bpp
    {3988, 61532},  // 10bpp
    {3974, 61559},  // 12bpp
};
54
// Returns the expected MD5 digest of the mask buffer produced by the 8bpp
// weight_mask function from the deterministic random input, or nullptr when
// |id| is out of range.
// Entries [0, 18) are for mask_is_inverse = false and [18, 36) for
// mask_is_inverse = true. Within each half, the position is the BlockSize
// value minus 4 (see WeightMaskTest::Test).
const char* GetDigest8bpp(int id) {
  static const char* const kDigest[] = {
      // mask_is_inverse = false.
      "eaca5b6a96dcfe5e44f3926a071b48b3",
      "1d82c75cfdf8e57925eb1d5301647538",
      "25bd455d74fb891b97b133c528f8db60",
      "" /*kBlock16x4: not tested*/,
      "1d82c75cfdf8e57925eb1d5301647538",
      "25bd455d74fb891b97b133c528f8db60",
      "62a08776db35a186406a11ab92dee71c",
      "95131d1dc0e05fcf4bd234d5ce9eea11",
      "25bd455d74fb891b97b133c528f8db60",
      "62a08776db35a186406a11ab92dee71c",
      "95131d1dc0e05fcf4bd234d5ce9eea11",
      "0b3c75272e0fb0747b9850145d340c4c",
      "95131d1dc0e05fcf4bd234d5ce9eea11",
      "0b3c75272e0fb0747b9850145d340c4c",
      "f26c43d4bc823a89c1ed47ab8708bc06",
      "0d99bbf31ecddc1c2d5063a68c0e9375",
      "0d99bbf31ecddc1c2d5063a68c0e9375",
      "5fb8ec5f582f0ebfe519ed55860f67c4",

      // mask_is_inverse = true.
      "96811f3b192828ff679e4c9ad8069d7d",
      "a04dc180c028d55af70240163445523a",
      "8513e3988233d0a7de316a0179bb6139",
      "" /*kBlock16x4: not tested*/,
      "a04dc180c028d55af70240163445523a",
      "8513e3988233d0a7de316a0179bb6139",
      "f7356d42fb44a6ccb41253ba35b8b3c7",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "8513e3988233d0a7de316a0179bb6139",
      "f7356d42fb44a6ccb41253ba35b8b3c7",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "87a2011ac69fb597ca4f71bb3c35ebb0",
      "3d2d61ffc203ee64fe91c9d16168a19d",
      "87a2011ac69fb597ca4f71bb3c35ebb0",
      "97100a3639d567046dc8a99fcb84cb2e",
      "9fabe05a6523da81a45150e19f75acff",
      "9fabe05a6523da81a45150e19f75acff",
      "7c0643e4d02421d06d7ca71822a94e1d",
  };
  constexpr int kNumDigests =
      static_cast<int>(sizeof(kDigest) / sizeof(kDigest[0]));
  // Reject ids that would index past the table. The caller treats nullptr as
  // a test failure instead of reading out of bounds.
  if (id < 0 || id >= kNumDigests) return nullptr;
  return kDigest[id];
}
98
99 #if LIBGAV1_MAX_BITDEPTH >= 10
// Returns the expected MD5 digest of the mask buffer produced by the 10bpp
// weight_mask function from the deterministic random input, or nullptr when
// |id| is out of range.
// Entries [0, 18) are for mask_is_inverse = false and [18, 36) for
// mask_is_inverse = true. Within each half, the position is the BlockSize
// value minus 4 (see WeightMaskTest::Test).
const char* GetDigest10bpp(int id) {
  static const char* const kDigest[] = {
      // mask_is_inverse = false.
      "5ae8d64b65a671301a457b8a73368ab5",
      "61535217f179054d4b76a8d9352a223d",
      "1aa6614773570e7b021cd509849c4180",
      "" /*kBlock16x4: not tested*/,
      "61535217f179054d4b76a8d9352a223d",
      "1aa6614773570e7b021cd509849c4180",
      "f04c2825cfb6408c7778658f71fa176e",
      "e1694ea1f026dac7fe7e86a84482cf86",
      "1aa6614773570e7b021cd509849c4180",
      "f04c2825cfb6408c7778658f71fa176e",
      "e1694ea1f026dac7fe7e86a84482cf86",
      "9c4855d44c013fbddb373b2e9e311080",
      "e1694ea1f026dac7fe7e86a84482cf86",
      "9c4855d44c013fbddb373b2e9e311080",
      "f510e743c3efe3b83374a98ef8a30838",
      "b6e0bd03c521c5f00e90530daa7d4432",
      "b6e0bd03c521c5f00e90530daa7d4432",
      "3270d7f621d488aec5b76bcf121debd0",

      // mask_is_inverse = true.
      "9aa00fcfe21b71e30c5393699122a020",
      "4d8ce33262cf6b5375f363530815189a",
      "428625c51ac1bd4585988f7b36dff1db",
      "" /*kBlock16x4: not tested*/,
      "4d8ce33262cf6b5375f363530815189a",
      "428625c51ac1bd4585988f7b36dff1db",
      "1ef63c06a2d9c42da293fdf924032981",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "428625c51ac1bd4585988f7b36dff1db",
      "1ef63c06a2d9c42da293fdf924032981",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "fe1e6843e6f214939da516dcbea04a79",
      "5dd3f201d755d1c22c126a633bfbb3c0",
      "fe1e6843e6f214939da516dcbea04a79",
      "240187f27389b5e89f9ec6bdbd7d20a7",
      "44925dab01011a98b8ab1f0308fa852a",
      "44925dab01011a98b8ab1f0308fa852a",
      "6d984b2ccfa056278e2130771127a943",
  };
  constexpr int kNumDigests =
      static_cast<int>(sizeof(kDigest) / sizeof(kDigest[0]));
  // Reject ids that would index past the table. The caller treats nullptr as
  // a test failure instead of reading out of bounds.
  if (id < 0 || id >= kNumDigests) return nullptr;
  return kDigest[id];
}
143 #endif // LIBGAV1_MAX_BITDEPTH >= 10
144
145 #if LIBGAV1_MAX_BITDEPTH == 12
// Returns the expected MD5 digest of the mask buffer produced by the 12bpp
// weight_mask function from the deterministic random input, or nullptr when
// |id| is out of range.
// Entries [0, 18) are for mask_is_inverse = false and [18, 36) for
// mask_is_inverse = true. Within each half, the position is the BlockSize
// value minus 4 (see WeightMaskTest::Test).
const char* GetDigest12bpp(int id) {
  static const char* const kDigest[] = {
      // mask_is_inverse = false.
      "57629d3872fd52ff4bbec439c5517ec5",
      "dba421ceeb534756c77167e00ae91a2c",
      "72e8ac1d450ef0c6c6b03e93856d5cc2",
      "" /*kBlock16x4: not tested*/,
      "dba421ceeb534756c77167e00ae91a2c",
      "72e8ac1d450ef0c6c6b03e93856d5cc2",
      "ae573eb368df04e6a0133b4e15471728",
      "ceede597b2729357b15e0d08bb9bb760",
      "72e8ac1d450ef0c6c6b03e93856d5cc2",
      "ae573eb368df04e6a0133b4e15471728",
      "ceede597b2729357b15e0d08bb9bb760",
      "c4976af803d7ad3f92ef26f25b9f3754",
      "ceede597b2729357b15e0d08bb9bb760",
      "c4976af803d7ad3f92ef26f25b9f3754",
      "1d957d49f71bb7f304705a11a597f0cb",
      "9522d5713fb951b79f42d78fbff914cf",
      "9522d5713fb951b79f42d78fbff914cf",
      "422c046013f79a9f46e2c855967570ba",

      // mask_is_inverse = true.
      "a585cca9bc459d10e081bc0eb847b6e3",
      "2fa4ec5f74fad2831d216c51c2cdad5a",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "" /*kBlock16x4: not tested*/,
      "2fa4ec5f74fad2831d216c51c2cdad5a",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "2ddd8c8a1841501964011030e2557e20",
      "97ef2575023dda008711015cf08d7590",
      "d6c9ac69a9eb3059f5bb6e42b486ebcd",
      "2ddd8c8a1841501964011030e2557e20",
      "97ef2575023dda008711015cf08d7590",
      "d69aff1e0d43395ce305c9be0dfb4c89",
      "97ef2575023dda008711015cf08d7590",
      "d69aff1e0d43395ce305c9be0dfb4c89",
      "48786f640191dcbee5b3321672778519",
      "6ad4718230353440b01f2bb78348157e",
      "6ad4718230353440b01f2bb78348157e",
      "ad49bd7af0ea17c84f434c7dfd0a911d",
  };
  constexpr int kNumDigests =
      static_cast<int>(sizeof(kDigest) / sizeof(kDigest[0]));
  // Reject ids that would index past the table. The caller treats nullptr as
  // a test failure instead of reading out of bounds.
  if (id < 0 || id >= kNumDigests) return nullptr;
  return kDigest[id];
}
189 #endif // LIBGAV1_MAX_BITDEPTH == 12
190
// One parameterization of WeightMaskTest: block width and height in pixels,
// plus whether the inverse-mask variant of the function is exercised.
struct WeightMaskTestParam {
  WeightMaskTestParam(int w, int h, bool inverse)
      : width(w), height(h), mask_is_inverse(inverse) {}
  int width;
  int height;
  bool mask_is_inverse;
};

// Formats |param| as "<width>x<height>, mask_is_inverse: <flag>".
std::ostream& operator<<(std::ostream& os, const WeightMaskTestParam& param) {
  os << param.width << "x" << param.height;
  os << ", mask_is_inverse: " << param.mask_is_inverse;
  return os;
}
203
204 template <int bitdepth>
205 class WeightMaskTest : public testing::TestWithParam<WeightMaskTestParam>,
206 public test_utils::MaxAlignedAllocable {
207 public:
208 static_assert(bitdepth >= kBitdepth8 && bitdepth <= LIBGAV1_MAX_BITDEPTH, "");
209 WeightMaskTest() = default;
210 ~WeightMaskTest() override = default;
211
SetUp()212 void SetUp() override {
213 test_utils::ResetDspTable(bitdepth);
214 WeightMaskInit_C();
215 const dsp::Dsp* const dsp = dsp::GetDspTable(bitdepth);
216 ASSERT_NE(dsp, nullptr);
217 const int width_index = FloorLog2(width_) - 3;
218 const int height_index = FloorLog2(height_) - 3;
219 const testing::TestInfo* const test_info =
220 testing::UnitTest::GetInstance()->current_test_info();
221 const char* const test_case = test_info->test_suite_name();
222 if (absl::StartsWith(test_case, "C/")) {
223 } else if (absl::StartsWith(test_case, "NEON/")) {
224 WeightMaskInit_NEON();
225 } else if (absl::StartsWith(test_case, "SSE41/")) {
226 WeightMaskInit_SSE4_1();
227 }
228 func_ = dsp->weight_mask[width_index][height_index][mask_is_inverse_];
229 }
230
231 protected:
232 void SetInputData(bool use_fixed_values, int value_1, int value_2);
233 void Test(int num_runs, bool use_fixed_values, int value_1, int value_2);
234
235 private:
236 const int width_ = GetParam().width;
237 const int height_ = GetParam().height;
238 const bool mask_is_inverse_ = GetParam().mask_is_inverse;
239 using PredType =
240 typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
241 alignas(
242 kMaxAlignment) PredType block_1_[kMaxPredictionSize * kMaxPredictionSize];
243 alignas(
244 kMaxAlignment) PredType block_2_[kMaxPredictionSize * kMaxPredictionSize];
245 uint8_t mask_[kMaxPredictionSize * kMaxPredictionSize] = {};
246 dsp::WeightMaskFunc func_;
247 };
248
249 template <int bitdepth>
SetInputData(const bool use_fixed_values,const int value_1,const int value_2)250 void WeightMaskTest<bitdepth>::SetInputData(const bool use_fixed_values,
251 const int value_1,
252 const int value_2) {
253 if (use_fixed_values) {
254 std::fill(block_1_, block_1_ + kMaxPredictionSize * kMaxPredictionSize,
255 value_1);
256 std::fill(block_2_, block_2_ + kMaxPredictionSize * kMaxPredictionSize,
257 value_2);
258 } else {
259 constexpr int bitdepth_index = (bitdepth - 8) >> 1;
260 libvpx_test::ACMRandom rnd(libvpx_test::ACMRandom::DeterministicSeed());
261 for (int y = 0; y < height_; ++y) {
262 for (int x = 0; x < width_; ++x) {
263 const int min_val = kCompoundPredictionRange[bitdepth_index][0];
264 const int max_val = kCompoundPredictionRange[bitdepth_index][1];
265 block_1_[y * width_ + x] =
266 static_cast<PredType>(rnd(max_val - min_val) + min_val);
267 block_2_[y * width_ + x] =
268 static_cast<PredType>(rnd(max_val - min_val) + min_val);
269 }
270 }
271 }
272 }
273
DimensionsToBlockSize(int width,int height)274 BlockSize DimensionsToBlockSize(int width, int height) {
275 if (width == 4) {
276 if (height == 4) return kBlock4x4;
277 if (height == 8) return kBlock4x8;
278 if (height == 16) return kBlock4x16;
279 return kBlockInvalid;
280 }
281 if (width == 8) {
282 if (height == 4) return kBlock8x4;
283 if (height == 8) return kBlock8x8;
284 if (height == 16) return kBlock8x16;
285 if (height == 32) return kBlock8x32;
286 return kBlockInvalid;
287 }
288 if (width == 16) {
289 if (height == 4) return kBlock16x4;
290 if (height == 8) return kBlock16x8;
291 if (height == 16) return kBlock16x16;
292 if (height == 32) return kBlock16x32;
293 if (height == 64) return kBlock16x64;
294 return kBlockInvalid;
295 }
296 if (width == 32) {
297 if (height == 8) return kBlock32x8;
298 if (height == 16) return kBlock32x16;
299 if (height == 32) return kBlock32x32;
300 if (height == 64) return kBlock32x64;
301 return kBlockInvalid;
302 }
303 if (width == 64) {
304 if (height == 16) return kBlock64x16;
305 if (height == 32) return kBlock64x32;
306 if (height == 64) return kBlock64x64;
307 if (height == 128) return kBlock64x128;
308 return kBlockInvalid;
309 }
310 if (width == 128) {
311 if (height == 64) return kBlock128x64;
312 if (height == 128) return kBlock128x128;
313 return kBlockInvalid;
314 }
315 return kBlockInvalid;
316 }
317
318 template <int bitdepth>
Test(const int num_runs,const bool use_fixed_values,const int value_1,const int value_2)319 void WeightMaskTest<bitdepth>::Test(const int num_runs,
320 const bool use_fixed_values,
321 const int value_1, const int value_2) {
322 if (func_ == nullptr) return;
323 SetInputData(use_fixed_values, value_1, value_2);
324 const absl::Time start = absl::Now();
325 for (int i = 0; i < num_runs; ++i) {
326 func_(block_1_, block_2_, mask_, width_);
327 }
328 const absl::Duration elapsed_time = absl::Now() - start;
329 if (use_fixed_values) {
330 int fixed_value = (value_1 - value_2 == 0) ? 38 : 64;
331 if (mask_is_inverse_) fixed_value = 64 - fixed_value;
332 for (int y = 0; y < height_; ++y) {
333 for (int x = 0; x < width_; ++x) {
334 ASSERT_EQ(static_cast<int>(mask_[y * width_ + x]), fixed_value)
335 << "x: " << x << " y: " << y;
336 }
337 }
338 } else {
339 const int id_offset = mask_is_inverse_ ? kMaxBlockSizes - 4 : 0;
340 const int id = id_offset +
341 static_cast<int>(DimensionsToBlockSize(width_, height_)) - 4;
342 const char* expected_digest = nullptr;
343 switch (bitdepth) {
344 case 8:
345 expected_digest = GetDigest8bpp(id);
346 break;
347 #if LIBGAV1_MAX_BITDEPTH >= 10
348 case 10:
349 expected_digest = GetDigest10bpp(id);
350 break;
351 #endif
352 #if LIBGAV1_MAX_BITDEPTH == 12
353 case 12:
354 expected_digest = GetDigest12bpp(id);
355 break;
356 #endif
357 }
358 ASSERT_NE(expected_digest, nullptr);
359 test_utils::CheckMd5Digest(
360 absl::StrFormat("BlockSize %dx%d", width_, height_).c_str(),
361 "WeightMask", expected_digest, mask_, sizeof(mask_), elapsed_time);
362 }
363 }
364
365 const WeightMaskTestParam weight_mask_test_param[] = {
366 WeightMaskTestParam(8, 8, false), WeightMaskTestParam(8, 16, false),
367 WeightMaskTestParam(8, 32, false), WeightMaskTestParam(16, 8, false),
368 WeightMaskTestParam(16, 16, false), WeightMaskTestParam(16, 32, false),
369 WeightMaskTestParam(16, 64, false), WeightMaskTestParam(32, 8, false),
370 WeightMaskTestParam(32, 16, false), WeightMaskTestParam(32, 32, false),
371 WeightMaskTestParam(32, 64, false), WeightMaskTestParam(64, 16, false),
372 WeightMaskTestParam(64, 32, false), WeightMaskTestParam(64, 64, false),
373 WeightMaskTestParam(64, 128, false), WeightMaskTestParam(128, 64, false),
374 WeightMaskTestParam(128, 128, false), WeightMaskTestParam(8, 8, true),
375 WeightMaskTestParam(8, 16, true), WeightMaskTestParam(8, 32, true),
376 WeightMaskTestParam(16, 8, true), WeightMaskTestParam(16, 16, true),
377 WeightMaskTestParam(16, 32, true), WeightMaskTestParam(16, 64, true),
378 WeightMaskTestParam(32, 8, true), WeightMaskTestParam(32, 16, true),
379 WeightMaskTestParam(32, 32, true), WeightMaskTestParam(32, 64, true),
380 WeightMaskTestParam(64, 16, true), WeightMaskTestParam(64, 32, true),
381 WeightMaskTestParam(64, 64, true), WeightMaskTestParam(64, 128, true),
382 WeightMaskTestParam(128, 64, true), WeightMaskTestParam(128, 128, true),
383 };
384
385 using WeightMaskTest8bpp = WeightMaskTest<8>;
386
TEST_P(WeightMaskTest8bpp,FixedValues)387 TEST_P(WeightMaskTest8bpp, FixedValues) {
388 const int min = kCompoundPredictionRange[0][0];
389 const int max = kCompoundPredictionRange[0][1];
390 Test(1, true, min, min);
391 Test(1, true, min, max);
392 Test(1, true, max, min);
393 Test(1, true, max, max);
394 }
395
TEST_P(WeightMaskTest8bpp,RandomValues)396 TEST_P(WeightMaskTest8bpp, RandomValues) { Test(1, false, -1, -1); }
397
TEST_P(WeightMaskTest8bpp,DISABLED_Speed)398 TEST_P(WeightMaskTest8bpp, DISABLED_Speed) {
399 Test(kNumSpeedTests, false, -1, -1);
400 }
401
402 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest8bpp,
403 testing::ValuesIn(weight_mask_test_param));
404 #if LIBGAV1_ENABLE_NEON
405 INSTANTIATE_TEST_SUITE_P(NEON, WeightMaskTest8bpp,
406 testing::ValuesIn(weight_mask_test_param));
407 #endif
408 #if LIBGAV1_ENABLE_SSE4_1
409 INSTANTIATE_TEST_SUITE_P(SSE41, WeightMaskTest8bpp,
410 testing::ValuesIn(weight_mask_test_param));
411 #endif
412
413 #if LIBGAV1_MAX_BITDEPTH >= 10
414 using WeightMaskTest10bpp = WeightMaskTest<10>;
415
TEST_P(WeightMaskTest10bpp,FixedValues)416 TEST_P(WeightMaskTest10bpp, FixedValues) {
417 const int min = kCompoundPredictionRange[1][0];
418 const int max = kCompoundPredictionRange[1][1];
419 Test(1, true, min, min);
420 Test(1, true, min, max);
421 Test(1, true, max, min);
422 Test(1, true, max, max);
423 }
424
TEST_P(WeightMaskTest10bpp,RandomValues)425 TEST_P(WeightMaskTest10bpp, RandomValues) { Test(1, false, -1, -1); }
426
TEST_P(WeightMaskTest10bpp,DISABLED_Speed)427 TEST_P(WeightMaskTest10bpp, DISABLED_Speed) {
428 Test(kNumSpeedTests, false, -1, -1);
429 }
430
431 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest10bpp,
432 testing::ValuesIn(weight_mask_test_param));
433 #if LIBGAV1_ENABLE_NEON
434 INSTANTIATE_TEST_SUITE_P(NEON, WeightMaskTest10bpp,
435 testing::ValuesIn(weight_mask_test_param));
436 #endif
437 #if LIBGAV1_ENABLE_SSE4_1
438 INSTANTIATE_TEST_SUITE_P(SSE41, WeightMaskTest10bpp,
439 testing::ValuesIn(weight_mask_test_param));
440 #endif
441 #endif // LIBGAV1_MAX_BITDEPTH >= 10
442
443 #if LIBGAV1_MAX_BITDEPTH == 12
444 using WeightMaskTest12bpp = WeightMaskTest<12>;
445
TEST_P(WeightMaskTest12bpp,FixedValues)446 TEST_P(WeightMaskTest12bpp, FixedValues) {
447 const int min = kCompoundPredictionRange[2][0];
448 const int max = kCompoundPredictionRange[2][1];
449 Test(1, true, min, min);
450 Test(1, true, min, max);
451 Test(1, true, max, min);
452 Test(1, true, max, max);
453 }
454
TEST_P(WeightMaskTest12bpp,RandomValues)455 TEST_P(WeightMaskTest12bpp, RandomValues) { Test(1, false, -1, -1); }
456
TEST_P(WeightMaskTest12bpp,DISABLED_Speed)457 TEST_P(WeightMaskTest12bpp, DISABLED_Speed) {
458 Test(kNumSpeedTests, false, -1, -1);
459 }
460
461 INSTANTIATE_TEST_SUITE_P(C, WeightMaskTest12bpp,
462 testing::ValuesIn(weight_mask_test_param));
463 #endif // LIBGAV1_MAX_BITDEPTH == 12
464
465 } // namespace
466 } // namespace dsp
467 } // namespace libgav1
468