• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "test.h"
#include "../profiling/pthread_everywhere.h"

#include <memory>
#include <vector>

#include "../internal/multi_thread_gemm.h"
21 
22 namespace gemmlowp {
23 
24 class Thread {
25  public:
Thread(BlockingCounter * blocking_counter,int number_of_times_to_decrement)26   Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
27       : blocking_counter_(blocking_counter),
28         number_of_times_to_decrement_(number_of_times_to_decrement),
29         finished_(false),
30         made_the_last_decrement_(false) {
31     pthread_create(&thread_, nullptr, ThreadFunc, this);
32   }
33 
~Thread()34   ~Thread() { Join(); }
35 
Join() const36   bool Join() const {
37     if (!finished_) {
38       pthread_join(thread_, nullptr);
39     }
40     return made_the_last_decrement_;
41   }
42 
43  private:
44   Thread(const Thread& other) = delete;
45 
ThreadFunc()46   void ThreadFunc() {
47     for (int i = 0; i < number_of_times_to_decrement_; i++) {
48       Check(!made_the_last_decrement_);
49       made_the_last_decrement_ = blocking_counter_->DecrementCount();
50     }
51     finished_ = true;
52   }
53 
ThreadFunc(void * ptr)54   static void* ThreadFunc(void* ptr) {
55     static_cast<Thread*>(ptr)->ThreadFunc();
56     return nullptr;
57   }
58 
59   BlockingCounter* const blocking_counter_;
60   const int number_of_times_to_decrement_;
61   pthread_t thread_;
62   bool finished_;
63   bool made_the_last_decrement_;
64 };
65 
test_blocking_counter(BlockingCounter * blocking_counter,int num_threads,int num_decrements_per_thread,int num_decrements_to_wait_for)66 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
67                            int num_decrements_per_thread,
68                            int num_decrements_to_wait_for) {
69   std::vector<Thread*> threads;
70   blocking_counter->Reset(num_decrements_to_wait_for);
71   for (int i = 0; i < num_threads; i++) {
72     threads.push_back(new Thread(blocking_counter, num_decrements_per_thread));
73   }
74   blocking_counter->Wait();
75 
76   int num_threads_that_made_the_last_decrement = 0;
77   for (int i = 0; i < num_threads; i++) {
78     if (threads[i]->Join()) {
79       num_threads_that_made_the_last_decrement++;
80     }
81     delete threads[i];
82   }
83   Check(num_threads_that_made_the_last_decrement == 1);
84 }
85 
test_blocking_counter()86 void test_blocking_counter() {
87   BlockingCounter* blocking_counter = new BlockingCounter;
88 
89   // repeating the entire test sequence ensures that we test
90   // non-monotonic changes.
91   for (int repeat = 1; repeat <= 2; repeat++) {
92     for (int num_threads = 1; num_threads <= 16; num_threads++) {
93       for (int num_decrements_per_thread = 1;
94            num_decrements_per_thread <= 64 * 1024;
95            num_decrements_per_thread *= 4) {
96         test_blocking_counter(blocking_counter, num_threads,
97                               num_decrements_per_thread,
98                               num_threads * num_decrements_per_thread);
99       }
100     }
101   }
102   delete blocking_counter;
103 }
104 
105 }  // end namespace gemmlowp
106 
main()107 int main() { gemmlowp::test_blocking_counter(); }
108