1 // Copyright 2021 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/threading_strategy.h"
16
17 #include <memory>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/strings/str_cat.h"
22 #include "gtest/gtest.h"
23 #include "src/frame_scratch_buffer.h"
24 #include "src/obu_parser.h"
25 #include "src/utils/constants.h"
26 #include "src/utils/threadpool.h"
27 #include "src/utils/types.h"
28
29 namespace libgav1 {
30 namespace {
31
32 class ThreadingStrategyTest : public testing::Test {
33 protected:
34 ThreadingStrategy strategy_;
35 ObuFrameHeader frame_header_ = {};
36 };
37
// With as many tiles as threads (32), a tile thread pool and a post filter
// thread pool are expected, but no tile gets a row thread pool.
TEST_F(ThreadingStrategyTest, MaxThreadEnforced) {
  constexpr int kTilesAndThreads = 32;
  frame_header_.tile_info.tile_count = kTilesAndThreads;
  ASSERT_TRUE(strategy_.Reset(frame_header_, kTilesAndThreads));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < kTilesAndThreads; ++tile) {
    EXPECT_EQ(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
47
// With 8 tiles and 8 threads, all threads are expected to go to tile decoding:
// a tile thread pool exists, and no tile has a row thread pool.
TEST_F(ThreadingStrategyTest, UseAllThreadsForTiles) {
  constexpr int kTilesAndThreads = 8;
  frame_header_.tile_info.tile_count = kTilesAndThreads;
  ASSERT_TRUE(strategy_.Reset(frame_header_, kTilesAndThreads));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < kTilesAndThreads; ++tile) {
    EXPECT_EQ(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
57
// With fewer tiles (2) than threads (8), the leftover threads are expected to
// be used for row multi-threading, so every tile has a row thread pool.
TEST_F(ThreadingStrategyTest, RowThreads) {
  constexpr int kTiles = 2;
  constexpr int kThreads = 8;
  frame_header_.tile_info.tile_count = kTiles;
  ASSERT_TRUE(strategy_.Reset(frame_header_, kThreads));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  // Each tile should get 3 row threads.
  for (int tile = 0; tile < kTiles; ++tile) {
    EXPECT_NE(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
68
// Two tiles with an odd thread count (9). As the test name suggests, the
// leftover threads presumably cannot be split evenly between the tiles, but
// both tiles must still end up with a row thread pool.
TEST_F(ThreadingStrategyTest, RowThreadsUnequal) {
  frame_header_.tile_info.tile_count = 2;
  ASSERT_TRUE(strategy_.Reset(frame_header_, /*thread_count=*/9));
  EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
  for (int tile = 0; tile < 2; ++tile) {
    EXPECT_NE(strategy_.row_thread_pool(tile), nullptr);
  }
  EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
}
78
79 // Test a random combination of tile_count and thread_count.
TEST_F(ThreadingStrategyTest,MultipleCalls)80 TEST_F(ThreadingStrategyTest, MultipleCalls) {
81 frame_header_.tile_info.tile_count = 2;
82 ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
83 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
84 for (int i = 0; i < 2; ++i) {
85 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
86 }
87 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
88
89 frame_header_.tile_info.tile_count = 8;
90 ASSERT_TRUE(strategy_.Reset(frame_header_, 8));
91 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
92 // Row threads must have been reset.
93 for (int i = 0; i < 8; ++i) {
94 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
95 }
96 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
97
98 frame_header_.tile_info.tile_count = 8;
99 ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
100 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
101 for (int i = 0; i < 8; ++i) {
102 // See ThreadingStrategy::Reset().
103 #if defined(__ANDROID__)
104 if (i >= 4) {
105 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i;
106 continue;
107 }
108 #endif
109 EXPECT_NE(strategy_.row_thread_pool(i), nullptr) << "i = " << i;
110 }
111 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
112
113 frame_header_.tile_info.tile_count = 4;
114 ASSERT_TRUE(strategy_.Reset(frame_header_, 16));
115 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
116 for (int i = 0; i < 4; ++i) {
117 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
118 }
119 // All the other row threads must be reset.
120 for (int i = 4; i < 8; ++i) {
121 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
122 }
123 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
124
125 frame_header_.tile_info.tile_count = 4;
126 ASSERT_TRUE(strategy_.Reset(frame_header_, 6));
127 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
128 // First two tiles will get 1 thread each.
129 for (int i = 0; i < 2; ++i) {
130 // See ThreadingStrategy::Reset().
131 #if defined(__ANDROID__)
132 if (i == 1) {
133 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i;
134 continue;
135 }
136 #endif
137 EXPECT_NE(strategy_.row_thread_pool(i), nullptr) << "i = " << i;
138 }
139 // All the other row threads must be reset.
140 for (int i = 2; i < 8; ++i) {
141 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i;
142 }
143 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
144
145 ASSERT_TRUE(strategy_.Reset(frame_header_, 1));
146 EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
147 for (int i = 0; i < 8; ++i) {
148 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
149 }
150 EXPECT_EQ(strategy_.post_filter_thread_pool(), nullptr);
151 }
152
153 // Tests the following order of calls (with thread count fixed at 4):
154 // * 1 Tile - 2 Tiles - 1 Tile.
TEST_F(ThreadingStrategyTest,MultipleCalls2)155 TEST_F(ThreadingStrategyTest, MultipleCalls2) {
156 frame_header_.tile_info.tile_count = 1;
157 ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
158 // When there is only one tile, tile thread pool must be nullptr.
159 EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
160 EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
161 for (int i = 1; i < 8; ++i) {
162 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
163 }
164 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
165
166 frame_header_.tile_info.tile_count = 2;
167 ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
168 EXPECT_NE(strategy_.tile_thread_pool(), nullptr);
169 for (int i = 0; i < 2; ++i) {
170 // See ThreadingStrategy::Reset().
171 #if defined(__ANDROID__)
172 if (i == 1) {
173 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr) << "i = " << i;
174 continue;
175 }
176 #endif
177 EXPECT_NE(strategy_.row_thread_pool(i), nullptr);
178 }
179 for (int i = 2; i < 8; ++i) {
180 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
181 }
182 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
183
184 frame_header_.tile_info.tile_count = 1;
185 ASSERT_TRUE(strategy_.Reset(frame_header_, 4));
186 EXPECT_EQ(strategy_.tile_thread_pool(), nullptr);
187 EXPECT_NE(strategy_.row_thread_pool(0), nullptr);
188 for (int i = 1; i < 8; ++i) {
189 EXPECT_EQ(strategy_.row_thread_pool(i), nullptr);
190 }
191 EXPECT_NE(strategy_.post_filter_thread_pool(), nullptr);
192 }
193
VerifyFrameParallel(int thread_count,int tile_count,int tile_columns,int expected_frame_threads,const std::vector<int> & expected_tile_threads)194 void VerifyFrameParallel(int thread_count, int tile_count, int tile_columns,
195 int expected_frame_threads,
196 const std::vector<int>& expected_tile_threads) {
197 ASSERT_EQ(expected_frame_threads, expected_tile_threads.size());
198 ASSERT_GT(thread_count, 1);
199 std::unique_ptr<ThreadPool> frame_thread_pool;
200 FrameScratchBufferPool frame_scratch_buffer_pool;
201 ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
202 thread_count, tile_count, tile_columns, &frame_thread_pool,
203 &frame_scratch_buffer_pool));
204 if (expected_frame_threads == 0) {
205 EXPECT_EQ(frame_thread_pool, nullptr);
206 return;
207 }
208 EXPECT_NE(frame_thread_pool.get(), nullptr);
209 EXPECT_EQ(frame_thread_pool->num_threads(), expected_frame_threads);
210 std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
211 int actual_thread_count = frame_thread_pool->num_threads();
212 for (int i = 0; i < expected_frame_threads; ++i) {
213 SCOPED_TRACE(absl::StrCat("i: ", i));
214 frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
215 ThreadPool* const thread_pool =
216 frame_scratch_buffers.back()->threading_strategy.thread_pool();
217 if (expected_tile_threads[i] > 0) {
218 EXPECT_NE(thread_pool, nullptr);
219 EXPECT_EQ(thread_pool->num_threads(), expected_tile_threads[i]);
220 actual_thread_count += thread_pool->num_threads();
221 } else {
222 EXPECT_EQ(thread_pool, nullptr);
223 }
224 }
225 EXPECT_EQ(thread_count, actual_thread_count);
226 for (auto& frame_scratch_buffer : frame_scratch_buffers) {
227 frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
228 }
229 }
230
TEST(FrameParallelStrategyTest, FrameParallel) {
  // While thread_count <= 3 * tile_count, no frame threads are expected,
  // whatever the number of tile columns.
  for (int thread_count = 2; thread_count <= 6; ++thread_count) {
    for (int tile_columns = 1; tile_columns <= 2; ++tile_columns) {
      VerifyFrameParallel(thread_count, /*tile_count=*/2, tile_columns,
                          /*expected_frame_threads=*/0,
                          /*expected_tile_threads=*/{});
    }
  }

  // Configurations that do produce frame threads. Each entry lists the inputs
  // followed by the expected frame thread count and the expected tile thread
  // count for each frame thread.
  const struct {
    int thread_count;
    int tile_count;
    int tile_columns;
    int expected_frame_threads;
    std::vector<int> expected_tile_threads;
  } test_cases[] = {
      // Equal number of tile threads for each frame thread.
      {8, 1, 1, 4, {1, 1, 1, 1}},
      {12, 2, 2, 4, {2, 2, 2, 2}},
      {18, 2, 2, 6, {2, 2, 2, 2, 2, 2}},
      {16, 3, 3, 4, {3, 3, 3, 3}},
      // Unequal number of tile threads for each frame thread.
      {7, 1, 1, 3, {2, 1, 1}},
      {14, 2, 2, 4, {3, 3, 2, 2}},
      {20, 2, 2, 6, {3, 3, 2, 2, 2, 2}},
      {17, 3, 3, 4, {4, 3, 3, 3}},
  };
  for (const auto& test_case : test_cases) {
    VerifyFrameParallel(test_case.thread_count, test_case.tile_count,
                        test_case.tile_columns,
                        test_case.expected_frame_threads,
                        test_case.expected_tile_threads);
  }
}
273
// Requests more threads than kMaxThreads. Checks that the frame thread pool
// plus every per-frame thread pool handed out by the scratch buffer pool add
// up to no more than kMaxThreads.
TEST(FrameParallelStrategyTest, ThreadCountDoesNotExceedkMaxThreads) {
  std::unique_ptr<ThreadPool> frame_thread_pool;
  FrameScratchBufferPool frame_scratch_buffer_pool;
  ASSERT_TRUE(InitializeThreadPoolsForFrameParallel(
      /*thread_count=*/kMaxThreads + 10, /*tile_count=*/2, /*tile_columns=*/2,
      &frame_thread_pool, &frame_scratch_buffer_pool));
  // ASSERT rather than EXPECT: the pool is dereferenced immediately below, and
  // a null pointer would crash the whole test binary instead of failing.
  ASSERT_NE(frame_thread_pool.get(), nullptr);
  const int frame_thread_count = frame_thread_pool->num_threads();
  std::vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
  frame_scratch_buffers.reserve(frame_thread_count);
  int actual_thread_count = frame_thread_count;
  for (int i = 0; i < frame_thread_count; ++i) {
    SCOPED_TRACE(absl::StrCat("i: ", i));
    frame_scratch_buffers.push_back(frame_scratch_buffer_pool.Get());
    ASSERT_NE(frame_scratch_buffers.back().get(), nullptr);
    ThreadPool* const thread_pool =
        frame_scratch_buffers.back()->threading_strategy.thread_pool();
    if (thread_pool != nullptr) {
      actual_thread_count += thread_pool->num_threads();
    }
  }
  // The exact split between frame threads and tile threads depends on the
  // value of kMaxThreads. So only check that the total number of threads does
  // not exceed kMaxThreads.
  EXPECT_LE(actual_thread_count, kMaxThreads);
  for (auto& frame_scratch_buffer : frame_scratch_buffers) {
    frame_scratch_buffer_pool.Release(std::move(frame_scratch_buffer));
  }
}
300
301 } // namespace
302 } // namespace libgav1
303