• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "absl/synchronization/notification.h"
16 
17 #include <thread>  // NOLINT(build/c++11)
18 #include <tuple>
19 #include <vector>
20 
21 #include "gtest/gtest.h"
22 #include "absl/base/attributes.h"
23 #include "absl/base/config.h"
24 #include "absl/base/internal/tracing.h"
25 #include "absl/synchronization/mutex.h"
26 #include "absl/time/time.h"
27 
28 namespace absl {
29 ABSL_NAMESPACE_BEGIN
30 
31 // A thread-safe class that holds a counter.
32 class ThreadSafeCounter {
33  public:
ThreadSafeCounter()34   ThreadSafeCounter() : count_(0) {}
35 
Increment()36   void Increment() {
37     MutexLock lock(&mutex_);
38     ++count_;
39   }
40 
Get() const41   int Get() const {
42     MutexLock lock(&mutex_);
43     return count_;
44   }
45 
WaitUntilGreaterOrEqual(int n)46   void WaitUntilGreaterOrEqual(int n) {
47     MutexLock lock(&mutex_);
48     auto cond = [this, n]() { return count_ >= n; };
49     mutex_.Await(Condition(&cond));
50   }
51 
52  private:
53   mutable Mutex mutex_;
54   int count_;
55 };
56 
57 // Runs the |i|'th worker thread for the tests in BasicTests().  Increments the
58 // |ready_counter|, waits on the |notification|, and then increments the
59 // |done_counter|.
RunWorker(int i,ThreadSafeCounter * ready_counter,Notification * notification,ThreadSafeCounter * done_counter)60 static void RunWorker(int i, ThreadSafeCounter* ready_counter,
61                       Notification* notification,
62                       ThreadSafeCounter* done_counter) {
63   ready_counter->Increment();
64   notification->WaitForNotification();
65   done_counter->Increment();
66 }
67 
68 // Tests that the |notification| properly blocks and awakens threads.  Assumes
69 // that the |notification| is not yet triggered.  If |notify_before_waiting| is
70 // true, the |notification| is triggered before any threads are created, so the
71 // threads never block in WaitForNotification().  Otherwise, the |notification|
72 // is triggered at a later point when most threads are likely to be blocking in
73 // WaitForNotification().
BasicTests(bool notify_before_waiting,Notification * notification)74 static void BasicTests(bool notify_before_waiting, Notification* notification) {
75   EXPECT_FALSE(notification->HasBeenNotified());
76   EXPECT_FALSE(
77       notification->WaitForNotificationWithTimeout(absl::Milliseconds(0)));
78   EXPECT_FALSE(notification->WaitForNotificationWithDeadline(absl::Now()));
79 
80   const absl::Duration delay = absl::Milliseconds(50);
81   const absl::Time start = absl::Now();
82   EXPECT_FALSE(notification->WaitForNotificationWithTimeout(delay));
83   const absl::Duration elapsed = absl::Now() - start;
84 
85   // Allow for a slight early return, to account for quality of implementation
86   // issues on various platforms.
87   absl::Duration slop = absl::Milliseconds(5);
88 #ifdef _MSC_VER
89   // Avoid flakiness on MSVC.
90   slop = absl::Milliseconds(15);
91 #endif
92   EXPECT_LE(delay - slop, elapsed)
93       << "WaitForNotificationWithTimeout returned " << delay - elapsed
94       << " early (with " << slop << " slop), start time was " << start;
95 
96   ThreadSafeCounter ready_counter;
97   ThreadSafeCounter done_counter;
98 
99   if (notify_before_waiting) {
100     notification->Notify();
101   }
102 
103   // Create a bunch of threads that increment the |done_counter| after being
104   // notified.
105   const int kNumThreads = 10;
106   std::vector<std::thread> workers;
107   for (int i = 0; i < kNumThreads; ++i) {
108     workers.push_back(std::thread(&RunWorker, i, &ready_counter, notification,
109                                   &done_counter));
110   }
111 
112   if (!notify_before_waiting) {
113     ready_counter.WaitUntilGreaterOrEqual(kNumThreads);
114 
115     // Workers have not been notified yet, so the |done_counter| should be
116     // unmodified.
117     EXPECT_EQ(0, done_counter.Get());
118 
119     notification->Notify();
120   }
121 
122   // After notifying and then joining the workers, both counters should be
123   // fully incremented.
124   notification->WaitForNotification();  // should exit immediately
125   EXPECT_TRUE(notification->HasBeenNotified());
126   EXPECT_TRUE(notification->WaitForNotificationWithTimeout(absl::Seconds(0)));
127   EXPECT_TRUE(notification->WaitForNotificationWithDeadline(absl::Now()));
128   for (std::thread& worker : workers) {
129     worker.join();
130   }
131   EXPECT_EQ(kNumThreads, ready_counter.Get());
132   EXPECT_EQ(kNumThreads, done_counter.Get());
133 }
134 
// Exercises BasicTests() in both orderings.  Each run gets its own
// Notification because BasicTests() requires one that is not yet notified.
TEST(NotificationTest, SanityTest) {
  Notification notify_late;
  Notification notify_early;

  // Notify() happens after the workers are already waiting.
  BasicTests(/*notify_before_waiting=*/false, &notify_late);
  // Notify() happens before any worker is created.
  BasicTests(/*notify_before_waiting=*/true, &notify_early);
}
140 
141 #if ABSL_HAVE_ATTRIBUTE_WEAK
142 
143 namespace base_internal {
144 
145 namespace {
146 
147 using TraceRecord = std::tuple<const void*, ObjectKind>;
148 
149 thread_local TraceRecord tls_signal;
150 thread_local TraceRecord tls_wait;
151 thread_local TraceRecord tls_continue;
152 thread_local TraceRecord tls_observed;
153 
154 }  // namespace
155 
156 // Strong extern "C" implementation.
157 extern "C" {
158 
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)159 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)(const void* object,
160                                                    ObjectKind kind) {
161   tls_wait = {object, kind};
162 }
163 
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)164 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)(const void* object,
165                                                        ObjectKind kind) {
166   tls_continue = {object, kind};
167 }
168 
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)169 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)(const void* object,
170                                                      ObjectKind kind) {
171   tls_signal = {object, kind};
172 }
173 
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceObserved)174 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceObserved)(const void* object,
175                                                        ObjectKind kind) {
176   tls_observed = {object, kind};
177 }
178 
179 }  // extern "C"
180 
// Notify() must be reported through the signal hook with the notification's
// address and kind.
TEST(NotificationTest, TracesNotify) {
  Notification n;
  const TraceRecord expected(&n, ObjectKind::kNotification);

  tls_signal = TraceRecord();
  n.Notify();
  EXPECT_EQ(tls_signal, expected);
}
187 
// WaitForNotification() on an already-notified Notification must leave both
// the wait and the continue records pointing at that notification.
TEST(NotificationTest, TracesWaitForNotification) {
  Notification n;
  n.Notify();
  const TraceRecord expected(&n, ObjectKind::kNotification);

  tls_continue = {};
  tls_wait = tls_continue;
  n.WaitForNotification();
  EXPECT_EQ(tls_wait, expected);
  EXPECT_EQ(tls_continue, expected);
}
196 
// Checks the tracing performed by WaitForNotificationWithTimeout() both when
// the wait times out and when the notification has already fired.
TEST(NotificationTest, TracesWaitForNotificationWithTimeout) {
  Notification n;
  const TraceRecord traced_n(&n, ObjectKind::kNotification);

  // Timed-out wait: the wait hook sees |n|, while the continue record is
  // expected to carry a null object (still with kind kNotification).
  tls_wait = {};
  tls_continue = {};
  n.WaitForNotificationWithTimeout(absl::Milliseconds(1));
  EXPECT_EQ(tls_wait, traced_n);
  EXPECT_EQ(tls_continue, TraceRecord(nullptr, ObjectKind::kNotification));

  // Wait on a notified Notification: both records must point at |n|.
  n.Notify();
  tls_wait = {};
  tls_continue = {};
  n.WaitForNotificationWithTimeout(absl::Milliseconds(1));
  EXPECT_EQ(tls_wait, traced_n);
  EXPECT_EQ(tls_continue, traced_n);
}
211 
// HasBeenNotified() must only leave an observation record pointing at the
// notification once Notify() has been called.
TEST(NotificationTest, TracesHasBeenNotified) {
  Notification n;
  const TraceRecord reset_state(nullptr, ObjectKind::kUnknown);
  const TraceRecord traced_n(&n, ObjectKind::kNotification);

  // Not yet notified: tls_observed must still equal (nullptr, kUnknown)
  // after the query (NOTE(review): assumes kUnknown is ObjectKind's
  // value-initialized state, which is what the reset below produces).
  tls_observed = {};
  ASSERT_FALSE(n.HasBeenNotified());
  EXPECT_EQ(tls_observed, reset_state);

  // Notified: the observation is recorded with |n|'s address.
  n.Notify();
  tls_observed = {};
  ASSERT_TRUE(n.HasBeenNotified());
  EXPECT_EQ(tls_observed, traced_n);
}
224 
225 }  // namespace base_internal
226 
227 #endif  // ABSL_HAVE_ATTRIBUTE_WEAK
228 
229 ABSL_NAMESPACE_END
230 }  // namespace absl
231