1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "test.h"
16
17 #include <pthread.h>
18 #include <vector>
19
20 #include "../internal/multi_thread_gemm.h"
21
22 namespace gemmlowp {
23
24 class Thread {
25 public:
Thread(BlockingCounter * blocking_counter,int number_of_times_to_decrement)26 Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
27 : blocking_counter_(blocking_counter),
28 number_of_times_to_decrement_(number_of_times_to_decrement),
29 made_the_last_decrement_(false) {
30 pthread_create(&thread_, nullptr, ThreadFunc, this);
31 }
32
~Thread()33 ~Thread() { Join(); }
34
Join() const35 bool Join() const {
36 pthread_join(thread_, nullptr);
37 return made_the_last_decrement_;
38 }
39
40 private:
41 Thread(const Thread& other) = delete;
42
ThreadFunc()43 void ThreadFunc() {
44 for (int i = 0; i < number_of_times_to_decrement_; i++) {
45 Check(!made_the_last_decrement_);
46 made_the_last_decrement_ = blocking_counter_->DecrementCount();
47 }
48 }
49
ThreadFunc(void * ptr)50 static void* ThreadFunc(void* ptr) {
51 static_cast<Thread*>(ptr)->ThreadFunc();
52 return nullptr;
53 }
54
55 BlockingCounter* const blocking_counter_;
56 const int number_of_times_to_decrement_;
57 pthread_t thread_;
58 bool made_the_last_decrement_;
59 };
60
test_blocking_counter(BlockingCounter * blocking_counter,int num_threads,int num_decrements_per_thread,int num_decrements_to_wait_for)61 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
62 int num_decrements_per_thread,
63 int num_decrements_to_wait_for) {
64 std::vector<Thread*> threads;
65 blocking_counter->Reset(num_decrements_to_wait_for);
66 for (int i = 0; i < num_threads; i++) {
67 threads.push_back(new Thread(blocking_counter, num_decrements_per_thread));
68 }
69 blocking_counter->Wait();
70
71 int num_threads_that_made_the_last_decrement = 0;
72 for (int i = 0; i < num_threads; i++) {
73 if (threads[i]->Join()) {
74 num_threads_that_made_the_last_decrement++;
75 }
76 delete threads[i];
77 }
78 Check(num_threads_that_made_the_last_decrement == 1);
79 }
80
test_blocking_counter()81 void test_blocking_counter() {
82 BlockingCounter* blocking_counter = new BlockingCounter;
83
84 // repeating the entire test sequence ensures that we test
85 // non-monotonic changes.
86 for (int repeat = 1; repeat <= 2; repeat++) {
87 for (int num_threads = 1; num_threads <= 16; num_threads++) {
88 for (int num_decrements_per_thread = 1;
89 num_decrements_per_thread <= 64 * 1024;
90 num_decrements_per_thread *= 4) {
91 test_blocking_counter(blocking_counter, num_threads,
92 num_decrements_per_thread,
93 num_threads * num_decrements_per_thread);
94 }
95 }
96 }
97 delete blocking_counter;
98 }
99
100 } // end namespace gemmlowp
101
main()102 int main() { gemmlowp::test_blocking_counter(); }
103