1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
5 // Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11 #define EIGEN_USE_THREADS
12 #include "main.h"
13 #include <Eigen/CXX11/ThreadPool>
14
// Reentrant pseudo-random number generator used by the worker threads below.
//
// Visual Studio doesn't implement rand_r(), but its rand() is already thread
// safe, so on that compiler the caller-provided seed is ignored. Everywhere
// else rand_r() is used, which keeps all generator state in *s.
//
// Eigen defines EIGEN_COMP_MSVC_STRICT on every compiler (as 0 when the
// compiler is not strict MSVC), so it must be tested with #if, not #ifdef:
// with #ifdef every platform would take the rand() branch and ignore *s.
//
// @param s  per-caller seed/state; advanced in place on non-MSVC platforms.
// @return   a pseudo-random value in [0, RAND_MAX].
int rand_reentrant(unsigned int* s) {
#if EIGEN_COMP_MSVC_STRICT
  EIGEN_UNUSED_VARIABLE(s);
  return rand();
#else
  return rand_r(s);
#endif
}
25
test_basic_eventcount()26 static void test_basic_eventcount()
27 {
28 MaxSizeVector<EventCount::Waiter> waiters(1);
29 waiters.resize(1);
30 EventCount ec(waiters);
31 EventCount::Waiter& w = waiters[0];
32 ec.Notify(false);
33 ec.Prewait(&w);
34 ec.Notify(true);
35 ec.CommitWait(&w);
36 ec.Prewait(&w);
37 ec.CancelWait(&w);
38 }
39
40 // Fake bounded counter-based queue.
41 struct TestQueue {
42 std::atomic<int> val_;
43 static const int kQueueSize = 10;
44
TestQueueTestQueue45 TestQueue() : val_() {}
46
~TestQueueTestQueue47 ~TestQueue() { VERIFY_IS_EQUAL(val_.load(), 0); }
48
PushTestQueue49 bool Push() {
50 int val = val_.load(std::memory_order_relaxed);
51 for (;;) {
52 VERIFY_GE(val, 0);
53 VERIFY_LE(val, kQueueSize);
54 if (val == kQueueSize) return false;
55 if (val_.compare_exchange_weak(val, val + 1, std::memory_order_relaxed))
56 return true;
57 }
58 }
59
PopTestQueue60 bool Pop() {
61 int val = val_.load(std::memory_order_relaxed);
62 for (;;) {
63 VERIFY_GE(val, 0);
64 VERIFY_LE(val, kQueueSize);
65 if (val == 0) return false;
66 if (val_.compare_exchange_weak(val, val - 1, std::memory_order_relaxed))
67 return true;
68 }
69 }
70
EmptyTestQueue71 bool Empty() { return val_.load(std::memory_order_relaxed) == 0; }
72 };
73
74 const int TestQueue::kQueueSize;
75
76 // A number of producers send messages to a set of consumers using a set of
77 // fake queues. Ensure that it does not crash, consumers don't deadlock and
78 // number of blocked and unblocked threads match.
test_stress_eventcount()79 static void test_stress_eventcount()
80 {
81 const int kThreads = std::thread::hardware_concurrency();
82 static const int kEvents = 1 << 16;
83 static const int kQueues = 10;
84
85 MaxSizeVector<EventCount::Waiter> waiters(kThreads);
86 waiters.resize(kThreads);
87 EventCount ec(waiters);
88 TestQueue queues[kQueues];
89
90 std::vector<std::unique_ptr<std::thread>> producers;
91 for (int i = 0; i < kThreads; i++) {
92 producers.emplace_back(new std::thread([&ec, &queues]() {
93 unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
94 for (int j = 0; j < kEvents; j++) {
95 unsigned idx = rand_reentrant(&rnd) % kQueues;
96 if (queues[idx].Push()) {
97 ec.Notify(false);
98 continue;
99 }
100 EIGEN_THREAD_YIELD();
101 j--;
102 }
103 }));
104 }
105
106 std::vector<std::unique_ptr<std::thread>> consumers;
107 for (int i = 0; i < kThreads; i++) {
108 consumers.emplace_back(new std::thread([&ec, &queues, &waiters, i]() {
109 EventCount::Waiter& w = waiters[i];
110 unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
111 for (int j = 0; j < kEvents; j++) {
112 unsigned idx = rand_reentrant(&rnd) % kQueues;
113 if (queues[idx].Pop()) continue;
114 j--;
115 ec.Prewait(&w);
116 bool empty = true;
117 for (int q = 0; q < kQueues; q++) {
118 if (!queues[q].Empty()) {
119 empty = false;
120 break;
121 }
122 }
123 if (!empty) {
124 ec.CancelWait(&w);
125 continue;
126 }
127 ec.CommitWait(&w);
128 }
129 }));
130 }
131
132 for (int i = 0; i < kThreads; i++) {
133 producers[i]->join();
134 consumers[i]->join();
135 }
136 }
137
// Test entry point: runs the single-threaded smoke test first, then the
// multi-threaded producer/consumer stress test. The argument text of each
// CALL_SUBTEST is kept verbatim because the macro is a test-harness hook whose
// expansion (e.g. how it names/reports the subtest) is defined outside this file.
void test_cxx11_eventcount()
{
  CALL_SUBTEST(test_basic_eventcount());
  CALL_SUBTEST(test_stress_eventcount());
}
143