// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include // NOLINT #include #include #include #include "../internal/multi_thread_gemm.h" #include "../profiling/pthread_everywhere.h" #include "test.h" namespace gemmlowp { class Thread { public: Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement) : blocking_counter_(blocking_counter), number_of_times_to_decrement_(number_of_times_to_decrement), made_the_last_decrement_(false), finished_(false) { #if defined GEMMLOWP_USE_PTHREAD // Limit the stack size so as not to deplete memory when creating // many threads. pthread_attr_t attr; int err = pthread_attr_init(&attr); if (!err) { size_t stack_size; err = pthread_attr_getstacksize(&attr, &stack_size); if (!err && stack_size > max_stack_size_) { err = pthread_attr_setstacksize(&attr, max_stack_size_); } if (!err) { err = pthread_create(&thread_, &attr, ThreadFunc, this); } } if (err) { std::cerr << "Failed to create a thread.\n"; std::abort(); } #else pthread_create(&thread_, nullptr, ThreadFunc, this); #endif } ~Thread() { Join(); } bool Join() { while (!finished_.load()) { } return made_the_last_decrement_; } private: Thread(const Thread& other) = delete; void ThreadFunc() { for (int i = 0; i < number_of_times_to_decrement_; i++) { Check(!made_the_last_decrement_); made_the_last_decrement_ = blocking_counter_->DecrementCount(); } finished_.store(true); } static void* ThreadFunc(void* ptr) { static_cast(ptr)->ThreadFunc(); return nullptr; } static constexpr size_t max_stack_size_ = 256 * 1024; BlockingCounter* const blocking_counter_; const int number_of_times_to_decrement_; pthread_t thread_; bool made_the_last_decrement_; // finished_ is used to manually implement Join() by busy-waiting. // I wanted to use pthread_join / std::thread::join, but the behavior // observed on Android was that pthread_join aborts when the thread has // already joined before calling pthread_join, making that hard to use. // It appeared simplest to just implement this simple spinlock, and that // is good enough as this is just a test. std::atomic finished_; }; void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads, int num_decrements_per_thread, int num_decrements_to_wait_for) { std::vector threads; blocking_counter->Reset(num_decrements_to_wait_for); for (int i = 0; i < num_threads; i++) { threads.push_back(new Thread(blocking_counter, num_decrements_per_thread)); } blocking_counter->Wait(); int num_threads_that_made_the_last_decrement = 0; for (int i = 0; i < num_threads; i++) { if (threads[i]->Join()) { num_threads_that_made_the_last_decrement++; } delete threads[i]; } Check(num_threads_that_made_the_last_decrement == 1); } void test_blocking_counter() { BlockingCounter* blocking_counter = new BlockingCounter; // repeating the entire test sequence ensures that we test // non-monotonic changes. for (int repeat = 1; repeat <= 2; repeat++) { for (int num_threads = 1; num_threads <= 5; num_threads++) { for (int num_decrements_per_thread = 1; num_decrements_per_thread <= 4 * 1024; num_decrements_per_thread *= 16) { test_blocking_counter(blocking_counter, num_threads, num_decrements_per_thread, num_threads * num_decrements_per_thread); } } } delete blocking_counter; } } // end namespace gemmlowp int main() { gemmlowp::test_blocking_counter(); }