• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "test.h"
16 
#include <pthread.h>

#include <memory>
#include <vector>
19 
20 #include "../internal/multi_thread_gemm.h"
21 
22 namespace gemmlowp {
23 
24 class Thread {
25  public:
Thread(BlockingCounter * blocking_counter,int number_of_times_to_decrement)26   Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
27       : blocking_counter_(blocking_counter),
28         number_of_times_to_decrement_(number_of_times_to_decrement),
29         made_the_last_decrement_(false) {
30     pthread_create(&thread_, nullptr, ThreadFunc, this);
31   }
32 
~Thread()33   ~Thread() { Join(); }
34 
Join() const35   bool Join() const {
36     pthread_join(thread_, nullptr);
37     return made_the_last_decrement_;
38   }
39 
40  private:
41   Thread(const Thread& other) = delete;
42 
ThreadFunc()43   void ThreadFunc() {
44     for (int i = 0; i < number_of_times_to_decrement_; i++) {
45       Check(!made_the_last_decrement_);
46       made_the_last_decrement_ = blocking_counter_->DecrementCount();
47     }
48   }
49 
ThreadFunc(void * ptr)50   static void* ThreadFunc(void* ptr) {
51     static_cast<Thread*>(ptr)->ThreadFunc();
52     return nullptr;
53   }
54 
55   BlockingCounter* const blocking_counter_;
56   const int number_of_times_to_decrement_;
57   pthread_t thread_;
58   bool made_the_last_decrement_;
59 };
60 
test_blocking_counter(BlockingCounter * blocking_counter,int num_threads,int num_decrements_per_thread,int num_decrements_to_wait_for)61 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
62                            int num_decrements_per_thread,
63                            int num_decrements_to_wait_for) {
64   std::vector<Thread*> threads;
65   blocking_counter->Reset(num_decrements_to_wait_for);
66   for (int i = 0; i < num_threads; i++) {
67     threads.push_back(new Thread(blocking_counter, num_decrements_per_thread));
68   }
69   blocking_counter->Wait();
70 
71   int num_threads_that_made_the_last_decrement = 0;
72   for (int i = 0; i < num_threads; i++) {
73     if (threads[i]->Join()) {
74       num_threads_that_made_the_last_decrement++;
75     }
76     delete threads[i];
77   }
78   Check(num_threads_that_made_the_last_decrement == 1);
79 }
80 
test_blocking_counter()81 void test_blocking_counter() {
82   BlockingCounter* blocking_counter = new BlockingCounter;
83 
84   // repeating the entire test sequence ensures that we test
85   // non-monotonic changes.
86   for (int repeat = 1; repeat <= 2; repeat++) {
87     for (int num_threads = 1; num_threads <= 16; num_threads++) {
88       for (int num_decrements_per_thread = 1;
89            num_decrements_per_thread <= 64 * 1024;
90            num_decrements_per_thread *= 4) {
91         test_blocking_counter(blocking_counter, num_threads,
92                               num_decrements_per_thread,
93                               num_threads * num_decrements_per_thread);
94       }
95     }
96   }
97   delete blocking_counter;
98 }
99 
100 }  // end namespace gemmlowp
101 
main()102 int main() { gemmlowp::test_blocking_counter(); }
103