• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "absl/synchronization/blocking_counter.h"
16 
17 #include <thread>  // NOLINT(build/c++11)
18 #include <tuple>
19 #include <vector>
20 
21 #include "gtest/gtest.h"
22 #include "absl/base/attributes.h"
23 #include "absl/base/config.h"
24 #include "absl/base/internal/tracing.h"
25 #include "absl/time/clock.h"
26 #include "absl/time/time.h"
27 
28 namespace absl {
29 ABSL_NAMESPACE_BEGIN
30 namespace {
31 
PauseAndDecreaseCounter(BlockingCounter * counter,int * done)32 void PauseAndDecreaseCounter(BlockingCounter* counter, int* done) {
33   absl::SleepFor(absl::Seconds(1));
34   *done = 1;
35   counter->DecrementCount();
36 }
37 
TEST(BlockingCounterTest, BasicFunctionality) {
  // Checks that Wait() does not return before every worker has called
  // DecrementCount(). Each worker (PauseAndDecreaseCounter) sleeps for one
  // second, sets its own completion flag, and only then decrements the shared
  // counter, so a flag still at 0 after Wait() returns would mean Wait() woke
  // up too early.
  const int kNumWorkers = 10;
  BlockingCounter counter(kNumWorkers);

  std::vector<std::thread> threads;
  std::vector<int> finished(kNumWorkers, 0);
  threads.reserve(kNumWorkers);

  // Launch the workers. Worker i writes only to finished[i]; `finished` is
  // never resized, so the slot pointers stay valid for the whole test.
  for (int i = 0; i < kNumWorkers; ++i) {
    int* const slot = &finished[i];
    threads.emplace_back(
        [&counter, slot] { PauseAndDecreaseCounter(&counter, slot); });
  }

  // Block until all workers have decremented the counter.
  counter.Wait();

  // Every worker must have marked itself finished before Wait() returned.
  for (const int flag : finished) {
    EXPECT_EQ(1, flag);
  }

  // Reap the worker threads so none are left joinable at scope exit.
  for (std::thread& worker : threads) {
    worker.join();
  }
}
69 
TEST(BlockingCounterTest, WaitZeroInitialCount) {
  // A counter created with a count of zero has nothing to wait for; Wait()
  // must not block here (the test would hang if it did).
  BlockingCounter zero_counter(0);
  zero_counter.Wait();
}
74 
#if GTEST_HAS_DEATH_TEST
// Constructing a BlockingCounter with a negative initial count is expected to
// terminate the process with a diagnostic matching the regex below.
TEST(BlockingCounterTest, WaitNegativeInitialCount) {
  EXPECT_DEATH(
      {
        BlockingCounter negative_counter(-1);
      },
      "BlockingCounter initial_count negative");
}
#endif  // GTEST_HAS_DEATH_TEST
81 
82 }  // namespace
83 
84 #if ABSL_HAVE_ATTRIBUTE_WEAK
85 
86 namespace base_internal {
87 
88 namespace {
89 
90 using TraceRecord = std::tuple<const void*, ObjectKind>;
91 
92 thread_local TraceRecord tls_signal;
93 thread_local TraceRecord tls_wait;
94 thread_local TraceRecord tls_continue;
95 
96 }  // namespace
97 
98 // Strong extern "C" implementation.
99 extern "C" {
100 
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)101 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)(const void* object,
102                                                    ObjectKind kind) {
103   tls_wait = {object, kind};
104 }
105 
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)106 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)(const void* object,
107                                                        ObjectKind kind) {
108   tls_continue = {object, kind};
109 }
110 
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)111 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)(const void* object,
112                                                      ObjectKind kind) {
113   tls_signal = {object, kind};
114 }
115 
116 }  // extern "C"
117 
TEST(BlockingCounterTest, TracesSignal) {
  // Only the decrement that exhausts the count is expected to emit a signal
  // trace; the earlier decrement must leave the slot at its reset value.
  BlockingCounter counter(2);
  const TraceRecord untouched(nullptr, ObjectKind::kUnknown);
  const TraceRecord signaled(&counter, ObjectKind::kBlockingCounter);

  // First decrement (2 -> 1): no signal expected.
  tls_signal = {};
  counter.DecrementCount();
  EXPECT_EQ(tls_signal, untouched);

  // Second decrement (1 -> 0): the counter itself should be traced.
  tls_signal = {};
  counter.DecrementCount();
  EXPECT_EQ(tls_signal, signaled);
}
129 
TEST(BlockingCounterTest, TracesWaitContinue) {
  // Waiting on a counter that has already reached zero should still record
  // both a wait trace and a continue trace naming that counter.
  BlockingCounter counter(1);
  counter.DecrementCount();

  tls_wait = {};
  tls_continue = {};
  counter.Wait();

  const TraceRecord expected(&counter, ObjectKind::kBlockingCounter);
  EXPECT_EQ(tls_wait, expected);
  EXPECT_EQ(tls_continue, expected);
}
140 
141 }  // namespace base_internal
142 
143 #endif  // ABSL_HAVE_ATTRIBUTE_WEAK
144 
145 ABSL_NAMESPACE_END
146 }  // namespace absl
147