• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <atomic>  // NOLINT
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>
19 
20 #include "../internal/multi_thread_gemm.h"
21 #include "../profiling/pthread_everywhere.h"
22 #include "test.h"
23 
24 namespace gemmlowp {
25 
26 class Thread {
27  public:
Thread(BlockingCounter * blocking_counter,int number_of_times_to_decrement)28   Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
29       : blocking_counter_(blocking_counter),
30         number_of_times_to_decrement_(number_of_times_to_decrement),
31         made_the_last_decrement_(false),
32         finished_(false) {
33 #if defined GEMMLOWP_USE_PTHREAD
34     // Limit the stack size so as not to deplete memory when creating
35     // many threads.
36     pthread_attr_t attr;
37     int err = pthread_attr_init(&attr);
38     if (!err) {
39       size_t stack_size;
40       err = pthread_attr_getstacksize(&attr, &stack_size);
41       if (!err && stack_size > max_stack_size_) {
42         err = pthread_attr_setstacksize(&attr, max_stack_size_);
43       }
44       if (!err) {
45         err = pthread_create(&thread_, &attr, ThreadFunc, this);
46       }
47     }
48     if (err) {
49       std::cerr << "Failed to create a thread.\n";
50       std::abort();
51     }
52 #else
53     pthread_create(&thread_, nullptr, ThreadFunc, this);
54 #endif
55   }
56 
~Thread()57   ~Thread() { Join(); }
58 
Join()59   bool Join() {
60     while (!finished_.load()) {
61     }
62     return made_the_last_decrement_;
63   }
64 
65  private:
66   Thread(const Thread& other) = delete;
67 
ThreadFunc()68   void ThreadFunc() {
69     for (int i = 0; i < number_of_times_to_decrement_; i++) {
70       Check(!made_the_last_decrement_);
71       made_the_last_decrement_ = blocking_counter_->DecrementCount();
72     }
73     finished_.store(true);
74   }
75 
ThreadFunc(void * ptr)76   static void* ThreadFunc(void* ptr) {
77     static_cast<Thread*>(ptr)->ThreadFunc();
78     return nullptr;
79   }
80 
81   static constexpr size_t max_stack_size_ = 256 * 1024;
82   BlockingCounter* const blocking_counter_;
83   const int number_of_times_to_decrement_;
84   pthread_t thread_;
85   bool made_the_last_decrement_;
86   // finished_ is used to manually implement Join() by busy-waiting.
87   // I wanted to use pthread_join / std::thread::join, but the behavior
88   // observed on Android was that pthread_join aborts when the thread has
89   // already joined before calling pthread_join, making that hard to use.
90   // It appeared simplest to just implement this simple spinlock, and that
91   // is good enough as this is just a test.
92   std::atomic<bool> finished_;
93 };
94 
test_blocking_counter(BlockingCounter * blocking_counter,int num_threads,int num_decrements_per_thread,int num_decrements_to_wait_for)95 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
96                            int num_decrements_per_thread,
97                            int num_decrements_to_wait_for) {
98   std::vector<Thread*> threads;
99   blocking_counter->Reset(num_decrements_to_wait_for);
100   for (int i = 0; i < num_threads; i++) {
101     threads.push_back(new Thread(blocking_counter, num_decrements_per_thread));
102   }
103   blocking_counter->Wait();
104 
105   int num_threads_that_made_the_last_decrement = 0;
106   for (int i = 0; i < num_threads; i++) {
107     if (threads[i]->Join()) {
108       num_threads_that_made_the_last_decrement++;
109     }
110     delete threads[i];
111   }
112   Check(num_threads_that_made_the_last_decrement == 1);
113 }
114 
test_blocking_counter()115 void test_blocking_counter() {
116   BlockingCounter* blocking_counter = new BlockingCounter;
117 
118   // repeating the entire test sequence ensures that we test
119   // non-monotonic changes.
120   for (int repeat = 1; repeat <= 2; repeat++) {
121     for (int num_threads = 1; num_threads <= 5; num_threads++) {
122       for (int num_decrements_per_thread = 1;
123            num_decrements_per_thread <= 4 * 1024;
124            num_decrements_per_thread *= 16) {
125         test_blocking_counter(blocking_counter, num_threads,
126                               num_decrements_per_thread,
127                               num_threads * num_decrements_per_thread);
128       }
129     }
130   }
131   delete blocking_counter;
132 }
133 
134 }  // end namespace gemmlowp
135 
main()136 int main() { gemmlowp::test_blocking_counter(); }
137