// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#define EIGEN_USE_THREADS

#include <iostream>
#include <unordered_set>

#include "main.h"
#include <Eigen/CXX11/ThreadPool>

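// These tests exercise Eigen::ThreadLocal<T, Initializer>: local() returns
// the calling thread's instance of T (running the initializer on it when the
// instance is first created), and ForEach() visits every instance created so
// far together with the id of the thread that owns it. The Counter /
// InitCounter pair below records which thread created each instance, so the
// tests can verify that local() never hands out an instance created by a
// different thread.
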
struct Counter {
  Counter() = default;

  void inc() {
    // Check that mutation happens only in a thread that created this counter.
    VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by);
    counter_value++;
  }
  int value() { return counter_value; }

  std::thread::id created_by;
  int counter_value = 0;
};

struct InitCounter {
  void operator()(Counter& counter) {
    counter.created_by = std::this_thread::get_id();
  }
};

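// Schedules 3 * num_threads tasks that each increment the calling thread's
// Counter once. The 100ms sleep is presumably there to keep every worker
// busy so that the tasks spread evenly across the pool, which is why each
// per-thread counter is expected to end up at exactly 3.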
void test_simple_thread_local() {
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());

  int num_tasks = 3 * num_threads;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();

      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      barrier.Notify();
    });
  }

  barrier.Wait();

  counter.ForEach(
      [](std::thread::id, Counter& cnt) { VERIFY_IS_EQUAL(cnt.value(), 3); });
}

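// With a capacity of zero the lock-free storage has no room at all, so the
// single Counter created here presumably lives in the mutex-guarded spill
// storage (see the spill test below); local() and ForEach() must still work.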
void test_zero_sized_thread_local() {
  Eigen::ThreadLocal<Counter, InitCounter> counter(0, InitCounter());

  Counter& local = counter.local();
  local.inc();

  int total = 0;
  counter.ForEach([&total](std::thread::id, Counter& cnt) {
    total += cnt.value();
    VERIFY_IS_EQUAL(cnt.value(), 1);
  });

  VERIFY_IS_EQUAL(total, 1);
}

// All thread-local values fit into the lock-free storage.
void test_large_number_of_tasks_no_spill() {
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());

  int num_tasks = 10000;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();
      barrier.Notify();
    });
  }

  barrier.Wait();

  int total = 0;
  std::unordered_set<std::thread::id> unique_threads;

  counter.ForEach([&](std::thread::id id, Counter& cnt) {
    total += cnt.value();
    unique_threads.insert(id);
  });

  VERIFY_IS_EQUAL(total, num_tasks);
  // Not every thread in the pool is necessarily woken up to execute the
  // submitted tasks. Also, thread_pool.Schedule() might run a task on the
  // current thread if the queue is full.
  VERIFY_IS_EQUAL(
      unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true);
}

// The lock-free thread-local storage is too small to fit all the unique
// threads, so it spills into a map guarded by a mutex.
void test_large_number_of_tasks_with_spill() {
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<Counter, InitCounter> counter(1, InitCounter());

  int num_tasks = 10000;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();
      barrier.Notify();
    });
  }

  barrier.Wait();

  int total = 0;
  std::unordered_set<std::thread::id> unique_threads;

  counter.ForEach([&](std::thread::id id, Counter& cnt) {
    total += cnt.value();
    unique_threads.insert(id);
  });

  VERIFY_IS_EQUAL(total, num_tasks);
  // Not every thread in the pool is necessarily woken up to execute the
  // submitted tasks. Also, thread_pool.Schedule() might run a task on the
  // current thread if the queue is full.
  VERIFY_IS_EQUAL(
      unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true);
}

EIGEN_DECLARE_TEST(cxx11_tensor_thread_local) {
  CALL_SUBTEST(test_simple_thread_local());
  CALL_SUBTEST(test_zero_sized_thread_local());
  CALL_SUBTEST(test_large_number_of_tasks_no_spill());
  CALL_SUBTEST(test_large_number_of_tasks_with_spill());
}