• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #ifndef GRPC_SRC_CORE_UTIL_DUAL_REF_COUNTED_H
18 #define GRPC_SRC_CORE_UTIL_DUAL_REF_COUNTED_H
19 
20 #include <grpc/support/port_platform.h>
21 
22 #include <atomic>
23 #include <cstdint>
24 
25 #include "absl/log/check.h"
26 #include "absl/log/log.h"
27 #include "src/core/util/debug_location.h"
28 #include "src/core/util/down_cast.h"
29 #include "src/core/util/orphanable.h"
30 #include "src/core/util/ref_counted.h"
31 #include "src/core/util/ref_counted_ptr.h"
32 
33 namespace grpc_core {
34 
35 // DualRefCounted is an interface for reference-counted objects with two
36 // classes of refs: strong refs (usually just called "refs") and weak refs.
37 // This supports cases where an object needs to start shutting down when
38 // all external callers are done with it (represented by strong refs) but
39 // cannot be destroyed until all internal callbacks are complete
40 // (represented by weak refs).
41 //
42 // Each class of refs can be incremented and decremented independently.
43 // Objects start with 1 strong ref and 0 weak refs at instantiation.
44 // When the strong refcount reaches 0, the object's Orphaned() method is called.
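// When the weak refcount drops to 0 while the strong refcount is also 0, the
// object is destroyed via UnrefBehavior. (Unref() temporarily converts the
// strong ref being released into a weak ref, so releasing the last strong ref
// also destroys the object when no weak refs are outstanding.)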
46 //
// This is used via CRTP (curiously-recurring template pattern), e.g.:
//   class MyClass : public DualRefCounted<MyClass> { ... };
49 //
50 // Impl & UnrefBehavior are as per RefCounted.
51 template <typename Child, typename Impl = PolymorphicRefCount,
52           typename UnrefBehavior = UnrefDelete>
53 class DualRefCounted : public Impl {
54  public:
55   // Not copyable nor movable.
56   DualRefCounted(const DualRefCounted&) = delete;
57   DualRefCounted& operator=(const DualRefCounted&) = delete;
58 
Ref()59   GRPC_MUST_USE_RESULT RefCountedPtr<Child> Ref() {
60     IncrementRefCount();
61     return RefCountedPtr<Child>(static_cast<Child*>(this));
62   }
Ref(const DebugLocation & location,const char * reason)63   GRPC_MUST_USE_RESULT RefCountedPtr<Child> Ref(const DebugLocation& location,
64                                                 const char* reason) {
65     IncrementRefCount(location, reason);
66     return RefCountedPtr<Child>(static_cast<Child*>(this));
67   }
68 
69   template <
70       typename Subclass,
71       std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefAsSubclass()72   RefCountedPtr<Subclass> RefAsSubclass() {
73     IncrementRefCount();
74     return RefCountedPtr<Subclass>(
75         DownCast<Subclass*>(static_cast<Child*>(this)));
76   }
77   template <
78       typename Subclass,
79       std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
RefAsSubclass(const DebugLocation & location,const char * reason)80   RefCountedPtr<Subclass> RefAsSubclass(const DebugLocation& location,
81                                         const char* reason) {
82     IncrementRefCount(location, reason);
83     return RefCountedPtr<Subclass>(
84         DownCast<Subclass*>(static_cast<Child*>(this)));
85   }
86 
Unref()87   void Unref() {
88     // Convert strong ref to weak ref.
89     const uint64_t prev_ref_pair =
90         refs_.fetch_add(MakeRefPair(-1, 1), std::memory_order_acq_rel);
91     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
92 #ifndef NDEBUG
93     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
94     if (trace_ != nullptr) {
95       VLOG(2) << trace_ << ":" << this << " unref " << strong_refs << " -> "
96               << strong_refs - 1 << ", weak_ref " << weak_refs << " -> "
97               << weak_refs + 1;
98     }
99     CHECK_GT(strong_refs, 0u);
100 #endif
101     if (GPR_UNLIKELY(strong_refs == 1)) {
102       Orphaned();
103     }
104     // Now drop the weak ref.
105     WeakUnref();
106   }
Unref(const DebugLocation & location,const char * reason)107   void Unref(const DebugLocation& location, const char* reason) {
108     const uint64_t prev_ref_pair =
109         refs_.fetch_add(MakeRefPair(-1, 1), std::memory_order_acq_rel);
110     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
111 #ifndef NDEBUG
112     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
113     if (trace_ != nullptr) {
114       VLOG(2) << trace_ << ":" << this << " " << location.file() << ":"
115               << location.line() << " unref " << strong_refs << " -> "
116               << strong_refs - 1 << ", weak_ref " << weak_refs << " -> "
117               << weak_refs + 1 << ") " << reason;
118     }
119     CHECK_GT(strong_refs, 0u);
120 #else
121     // Avoid unused-parameter warnings for debug-only parameters
122     (void)location;
123     (void)reason;
124 #endif
125     if (GPR_UNLIKELY(strong_refs == 1)) {
126       Orphaned();
127     }
128     // Now drop the weak ref.
129     WeakUnref(location, reason);
130   }
131 
RefIfNonZero()132   GRPC_MUST_USE_RESULT RefCountedPtr<Child> RefIfNonZero() {
133     uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire);
134     do {
135       const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
136 #ifndef NDEBUG
137       const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
138       if (trace_ != nullptr) {
139         VLOG(2) << trace_ << ":" << this << " ref_if_non_zero " << strong_refs
140                 << " -> " << strong_refs + 1 << " (weak_refs=" << weak_refs
141                 << ")";
142       }
143 #endif
144       if (strong_refs == 0) return nullptr;
145     } while (!refs_.compare_exchange_weak(
146         prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0),
147         std::memory_order_acq_rel, std::memory_order_acquire));
148     return RefCountedPtr<Child>(static_cast<Child*>(this));
149   }
RefIfNonZero(const DebugLocation & location,const char * reason)150   GRPC_MUST_USE_RESULT RefCountedPtr<Child> RefIfNonZero(
151       const DebugLocation& location, const char* reason) {
152     uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire);
153     do {
154       const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
155 #ifndef NDEBUG
156       const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
157       if (trace_ != nullptr) {
158         VLOG(2) << trace_ << ":" << this << " " << location.file() << ":"
159                 << location.line() << " ref_if_non_zero " << strong_refs
160                 << " -> " << strong_refs + 1 << " (weak_refs=" << weak_refs
161                 << ") " << reason;
162       }
163 #else
164       // Avoid unused-parameter warnings for debug-only parameters
165       (void)location;
166       (void)reason;
167 #endif
168       if (strong_refs == 0) return nullptr;
169     } while (!refs_.compare_exchange_weak(
170         prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0),
171         std::memory_order_acq_rel, std::memory_order_acquire));
172     return RefCountedPtr<Child>(static_cast<Child*>(this));
173   }
174 
WeakRef()175   GRPC_MUST_USE_RESULT WeakRefCountedPtr<Child> WeakRef() {
176     IncrementWeakRefCount();
177     return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
178   }
WeakRef(const DebugLocation & location,const char * reason)179   GRPC_MUST_USE_RESULT WeakRefCountedPtr<Child> WeakRef(
180       const DebugLocation& location, const char* reason) {
181     IncrementWeakRefCount(location, reason);
182     return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
183   }
184 
185   template <
186       typename Subclass,
187       std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
WeakRefAsSubclass()188   WeakRefCountedPtr<Subclass> WeakRefAsSubclass() {
189     IncrementWeakRefCount();
190     return WeakRefCountedPtr<Subclass>(
191         DownCast<Subclass*>(static_cast<Child*>(this)));
192   }
193   template <
194       typename Subclass,
195       std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true>
WeakRefAsSubclass(const DebugLocation & location,const char * reason)196   WeakRefCountedPtr<Subclass> WeakRefAsSubclass(const DebugLocation& location,
197                                                 const char* reason) {
198     IncrementWeakRefCount(location, reason);
199     return WeakRefCountedPtr<Subclass>(
200         DownCast<Subclass*>(static_cast<Child*>(this)));
201   }
202 
WeakUnref()203   void WeakUnref() {
204 #ifndef NDEBUG
205     // Grab a copy of the trace flag before the atomic change, since we
206     // will no longer be holding a ref afterwards and therefore can't
207     // safely access it, since another thread might free us in the interim.
208     const char* trace = trace_;
209 #endif
210     const uint64_t prev_ref_pair =
211         refs_.fetch_sub(MakeRefPair(0, 1), std::memory_order_acq_rel);
212 #ifndef NDEBUG
213     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
214     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
215     if (trace != nullptr) {
216       VLOG(2) << trace << ":" << this << " weak_unref " << weak_refs << " -> "
217               << weak_refs - 1 << " (refs=" << strong_refs << ")";
218     }
219     CHECK_GT(weak_refs, 0u);
220 #endif
221     if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
222       unref_behavior_(static_cast<Child*>(this));
223     }
224   }
WeakUnref(const DebugLocation & location,const char * reason)225   void WeakUnref(const DebugLocation& location, const char* reason) {
226 #ifndef NDEBUG
227     // Grab a copy of the trace flag before the atomic change, since we
228     // will no longer be holding a ref afterwards and therefore can't
229     // safely access it, since another thread might free us in the interim.
230     const char* trace = trace_;
231 #endif
232     const uint64_t prev_ref_pair =
233         refs_.fetch_sub(MakeRefPair(0, 1), std::memory_order_acq_rel);
234 #ifndef NDEBUG
235     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
236     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
237     if (trace != nullptr) {
238       VLOG(2) << trace << ":" << this << " " << location.file() << ":"
239               << location.line() << " weak_unref " << weak_refs << " -> "
240               << weak_refs - 1 << " (refs=" << strong_refs << ") " << reason;
241     }
242     CHECK_GT(weak_refs, 0u);
243 #else
244     // Avoid unused-parameter warnings for debug-only parameters
245     (void)location;
246     (void)reason;
247 #endif
248     if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
249       unref_behavior_(static_cast<const Child*>(this));
250     }
251   }
252 
253  protected:
254   // Note: Tracing is a no-op in non-debug builds.
255   explicit DualRefCounted(
256       const char*
257 #ifndef NDEBUG
258           // Leave unnamed if NDEBUG to avoid unused parameter warning
259           trace
260 #endif
261       = nullptr,
262       int32_t initial_refcount = 1)
263       :
264 #ifndef NDEBUG
trace_(trace)265         trace_(trace),
266 #endif
267         refs_(MakeRefPair(initial_refcount, 0)) {
268   }
269 
270   // Ref count has dropped to zero, so the object is now orphaned.
271   virtual void Orphaned() = 0;
272 
273   // Note: Depending on the Impl used, this dtor can be implicitly virtual.
274   ~DualRefCounted() = default;
275 
276  private:
277   // Allow RefCountedPtr<> to access IncrementRefCount().
278   template <typename T>
279   friend class RefCountedPtr;
280   // Allow WeakRefCountedPtr<> to access IncrementWeakRefCount().
281   template <typename T>
282   friend class WeakRefCountedPtr;
283 
284   // First 32 bits are strong refs, next 32 bits are weak refs.
MakeRefPair(uint32_t strong,uint32_t weak)285   static uint64_t MakeRefPair(uint32_t strong, uint32_t weak) {
286     return (static_cast<uint64_t>(strong) << 32) + static_cast<int64_t>(weak);
287   }
GetStrongRefs(uint64_t ref_pair)288   static uint32_t GetStrongRefs(uint64_t ref_pair) {
289     return static_cast<uint32_t>(ref_pair >> 32);
290   }
GetWeakRefs(uint64_t ref_pair)291   static uint32_t GetWeakRefs(uint64_t ref_pair) {
292     return static_cast<uint32_t>(ref_pair & 0xffffffffu);
293   }
294 
IncrementRefCount()295   void IncrementRefCount() {
296 #ifndef NDEBUG
297     const uint64_t prev_ref_pair =
298         refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
299     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
300     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
301     CHECK_NE(strong_refs, 0u);
302     if (trace_ != nullptr) {
303       VLOG(2) << trace_ << ":" << this << " ref " << strong_refs << " -> "
304               << strong_refs + 1 << "; (weak_refs=" << weak_refs << ")";
305     }
306 #else
307     refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
308 #endif
309   }
IncrementRefCount(const DebugLocation & location,const char * reason)310   void IncrementRefCount(const DebugLocation& location, const char* reason) {
311 #ifndef NDEBUG
312     const uint64_t prev_ref_pair =
313         refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
314     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
315     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
316     CHECK_NE(strong_refs, 0u);
317     if (trace_ != nullptr) {
318       VLOG(2) << trace_ << ":" << this << " " << location.file() << ":"
319               << location.line() << " ref " << strong_refs << " -> "
320               << strong_refs + 1 << " (weak_refs=" << weak_refs << ") "
321               << reason;
322     }
323 #else
324     // Use conditionally-important parameters
325     (void)location;
326     (void)reason;
327     refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
328 #endif
329   }
330 
IncrementWeakRefCount()331   void IncrementWeakRefCount() {
332 #ifndef NDEBUG
333     const uint64_t prev_ref_pair =
334         refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
335     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
336     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
337     if (trace_ != nullptr) {
338       VLOG(2) << trace_ << ":" << this << " weak_ref " << weak_refs << " -> "
339               << weak_refs + 1 << "; (refs=" << strong_refs << ")";
340     }
341     if (strong_refs == 0) CHECK_NE(weak_refs, 0u);
342 #else
343     refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
344 #endif
345   }
IncrementWeakRefCount(const DebugLocation & location,const char * reason)346   void IncrementWeakRefCount(const DebugLocation& location,
347                              const char* reason) {
348 #ifndef NDEBUG
349     const uint64_t prev_ref_pair =
350         refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
351     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
352     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
353     if (trace_ != nullptr) {
354       VLOG(2) << trace_ << ":" << this << " " << location.file() << ":"
355               << location.line() << " weak_ref " << weak_refs << " -> "
356               << weak_refs + 1 << " (refs=" << strong_refs << ") " << reason;
357     }
358     if (strong_refs == 0) CHECK_NE(weak_refs, 0u);
359 #else
360     // Use conditionally-important parameters
361     (void)location;
362     (void)reason;
363     refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
364 #endif
365   }
366 
367 #ifndef NDEBUG
368   const char* trace_;
369 #endif
370   std::atomic<uint64_t> refs_{0};
371   GPR_NO_UNIQUE_ADDRESS UnrefBehavior unref_behavior_;
372 };
373 
374 }  // namespace grpc_core
375 
376 #endif  // GRPC_SRC_CORE_UTIL_DUAL_REF_COUNTED_H
377