1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "absl/synchronization/blocking_counter.h"
16
17 #include <thread> // NOLINT(build/c++11)
18 #include <tuple>
19 #include <vector>
20
21 #include "gtest/gtest.h"
22 #include "absl/base/attributes.h"
23 #include "absl/base/config.h"
24 #include "absl/base/internal/tracing.h"
25 #include "absl/time/clock.h"
26 #include "absl/time/time.h"
27
28 namespace absl {
29 ABSL_NAMESPACE_BEGIN
30 namespace {
31
PauseAndDecreaseCounter(BlockingCounter * counter,int * done)32 void PauseAndDecreaseCounter(BlockingCounter* counter, int* done) {
33 absl::SleepFor(absl::Seconds(1));
34 *done = 1;
35 counter->DecrementCount();
36 }
37
TEST(BlockingCounterTest, BasicFunctionality) {
  // Starts several threads that each sleep for one second and then decrement
  // a shared BlockingCounter. The main thread waits on the counter and then
  // verifies that every worker reported completion before Wait() returned.
  const int kNumWorkers = 10;
  BlockingCounter counter(kNumWorkers);

  std::vector<std::thread> workers;
  // One completion flag per worker; each thread writes only its own slot.
  std::vector<int> done(kNumWorkers, 0);

  // `done` is never resized after this point, so the per-slot pointers
  // handed to the threads stay valid.
  workers.reserve(kNumWorkers);
  for (int i = 0; i < kNumWorkers; ++i) {
    int* const flag = &done[i];
    workers.emplace_back(
        [&counter, flag] { PauseAndDecreaseCounter(&counter, flag); });
  }

  // Block until every worker has called DecrementCount().
  counter.Wait();

  // Every worker must have set its flag before decrementing the counter.
  for (const int flag : done) {
    EXPECT_EQ(1, flag);
  }

  for (auto& worker : workers) {
    worker.join();
  }
}
69
// A counter created with an initial count of zero needs no DecrementCount()
// calls: Wait() must return instead of blocking.
TEST(BlockingCounterTest, WaitZeroInitialCount) {
  BlockingCounter zero_counter(0);
  zero_counter.Wait();  // The test would hang here if a zero count blocked.
}
74
#if GTEST_HAS_DEATH_TEST
// Constructing a BlockingCounter with a negative initial count is expected to
// terminate the process with the message below. Despite the test's name,
// Wait() is never reached: the failure happens during construction.
TEST(BlockingCounterTest, WaitNegativeInitialCount) {
  EXPECT_DEATH(BlockingCounter negative_counter(-1),
               "BlockingCounter initial_count negative");
}
#endif  // GTEST_HAS_DEATH_TEST
81
82 } // namespace
83
84 #if ABSL_HAVE_ATTRIBUTE_WEAK
85
86 namespace base_internal {
87
88 namespace {
89
90 using TraceRecord = std::tuple<const void*, ObjectKind>;
91
92 thread_local TraceRecord tls_signal;
93 thread_local TraceRecord tls_wait;
94 thread_local TraceRecord tls_continue;
95
96 } // namespace
97
98 // Strong extern "C" implementation.
99 extern "C" {
100
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)101 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)(const void* object,
102 ObjectKind kind) {
103 tls_wait = {object, kind};
104 }
105
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)106 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)(const void* object,
107 ObjectKind kind) {
108 tls_continue = {object, kind};
109 }
110
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)111 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)(const void* object,
112 ObjectKind kind) {
113 tls_signal = {object, kind};
114 }
115
116 } // extern "C"
117
// Only the DecrementCount() call that brings the count to zero is expected to
// emit a signal trace, tagged with the counter's address and kind.
TEST(BlockingCounterTest, TracesSignal) {
  BlockingCounter counter(2);
  const TraceRecord no_signal(nullptr, ObjectKind::kUnknown);
  const TraceRecord counter_signal(&counter, ObjectKind::kBlockingCounter);

  // 2 -> 1: not the final decrement, so no signal should be recorded.
  tls_signal = {};
  counter.DecrementCount();
  EXPECT_EQ(tls_signal, no_signal);

  // 1 -> 0: the final decrement signals.
  tls_signal = {};
  counter.DecrementCount();
  EXPECT_EQ(tls_signal, counter_signal);
}
129
// Calling Wait() on a counter that has already reached zero should still
// record both a wait trace and a continue trace for that counter.
TEST(BlockingCounterTest, TracesWaitContinue) {
  BlockingCounter counter(1);
  counter.DecrementCount();  // Count is now zero; Wait() should return at once.

  const TraceRecord expected(&counter, ObjectKind::kBlockingCounter);

  tls_wait = {};
  tls_continue = {};
  counter.Wait();

  EXPECT_EQ(tls_wait, expected);
  EXPECT_EQ(tls_continue, expected);
}
140
141 } // namespace base_internal
142
143 #endif // ABSL_HAVE_ATTRIBUTE_WEAK
144
145 ABSL_NAMESPACE_END
146 } // namespace absl
147