1 // 2 // Copyright 2020 gRPC authors. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 #ifndef GRPC_SRC_CORE_UTIL_DUAL_REF_COUNTED_H 18 #define GRPC_SRC_CORE_UTIL_DUAL_REF_COUNTED_H 19 20 #include <grpc/support/port_platform.h> 21 22 #include <atomic> 23 #include <cstdint> 24 25 #include "absl/log/check.h" 26 #include "absl/log/log.h" 27 #include "src/core/util/debug_location.h" 28 #include "src/core/util/down_cast.h" 29 #include "src/core/util/orphanable.h" 30 #include "src/core/util/ref_counted.h" 31 #include "src/core/util/ref_counted_ptr.h" 32 33 namespace grpc_core { 34 35 // DualRefCounted is an interface for reference-counted objects with two 36 // classes of refs: strong refs (usually just called "refs") and weak refs. 37 // This supports cases where an object needs to start shutting down when 38 // all external callers are done with it (represented by strong refs) but 39 // cannot be destroyed until all internal callbacks are complete 40 // (represented by weak refs). 41 // 42 // Each class of refs can be incremented and decremented independently. 43 // Objects start with 1 strong ref and 0 weak refs at instantiation. 44 // When the strong refcount reaches 0, the object's Orphaned() method is called. 45 // When the weak refcount reaches 0, the object is destroyed. 46 // 47 // This will be used by CRTP (curiously-recurring template pattern), e.g.: 48 // class MyClass : public RefCounted<MyClass> { ... }; 49 // 50 // Impl & UnrefBehavior are as per RefCounted. 51 template <typename Child, typename Impl = PolymorphicRefCount, 52 typename UnrefBehavior = UnrefDelete> 53 class DualRefCounted : public Impl { 54 public: 55 // Not copyable nor movable. 56 DualRefCounted(const DualRefCounted&) = delete; 57 DualRefCounted& operator=(const DualRefCounted&) = delete; 58 Ref()59 GRPC_MUST_USE_RESULT RefCountedPtr<Child> Ref() { 60 IncrementRefCount(); 61 return RefCountedPtr<Child>(static_cast<Child*>(this)); 62 } Ref(const DebugLocation & location,const char * reason)63 GRPC_MUST_USE_RESULT RefCountedPtr<Child> Ref(const DebugLocation& location, 64 const char* reason) { 65 IncrementRefCount(location, reason); 66 return RefCountedPtr<Child>(static_cast<Child*>(this)); 67 } 68 69 template < 70 typename Subclass, 71 std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true> RefAsSubclass()72 RefCountedPtr<Subclass> RefAsSubclass() { 73 IncrementRefCount(); 74 return RefCountedPtr<Subclass>( 75 DownCast<Subclass*>(static_cast<Child*>(this))); 76 } 77 template < 78 typename Subclass, 79 std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true> RefAsSubclass(const DebugLocation & location,const char * reason)80 RefCountedPtr<Subclass> RefAsSubclass(const DebugLocation& location, 81 const char* reason) { 82 IncrementRefCount(location, reason); 83 return RefCountedPtr<Subclass>( 84 DownCast<Subclass*>(static_cast<Child*>(this))); 85 } 86 Unref()87 void Unref() { 88 // Convert strong ref to weak ref. 89 const uint64_t prev_ref_pair = 90 refs_.fetch_add(MakeRefPair(-1, 1), std::memory_order_acq_rel); 91 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 92 #ifndef NDEBUG 93 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 94 if (trace_ != nullptr) { 95 VLOG(2) << trace_ << ":" << this << " unref " << strong_refs << " -> " 96 << strong_refs - 1 << ", weak_ref " << weak_refs << " -> " 97 << weak_refs + 1; 98 } 99 CHECK_GT(strong_refs, 0u); 100 #endif 101 if (GPR_UNLIKELY(strong_refs == 1)) { 102 Orphaned(); 103 } 104 // Now drop the weak ref. 105 WeakUnref(); 106 } Unref(const DebugLocation & location,const char * reason)107 void Unref(const DebugLocation& location, const char* reason) { 108 const uint64_t prev_ref_pair = 109 refs_.fetch_add(MakeRefPair(-1, 1), std::memory_order_acq_rel); 110 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 111 #ifndef NDEBUG 112 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 113 if (trace_ != nullptr) { 114 VLOG(2) << trace_ << ":" << this << " " << location.file() << ":" 115 << location.line() << " unref " << strong_refs << " -> " 116 << strong_refs - 1 << ", weak_ref " << weak_refs << " -> " 117 << weak_refs + 1 << ") " << reason; 118 } 119 CHECK_GT(strong_refs, 0u); 120 #else 121 // Avoid unused-parameter warnings for debug-only parameters 122 (void)location; 123 (void)reason; 124 #endif 125 if (GPR_UNLIKELY(strong_refs == 1)) { 126 Orphaned(); 127 } 128 // Now drop the weak ref. 129 WeakUnref(location, reason); 130 } 131 RefIfNonZero()132 GRPC_MUST_USE_RESULT RefCountedPtr<Child> RefIfNonZero() { 133 uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire); 134 do { 135 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 136 #ifndef NDEBUG 137 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 138 if (trace_ != nullptr) { 139 VLOG(2) << trace_ << ":" << this << " ref_if_non_zero " << strong_refs 140 << " -> " << strong_refs + 1 << " (weak_refs=" << weak_refs 141 << ")"; 142 } 143 #endif 144 if (strong_refs == 0) return nullptr; 145 } while (!refs_.compare_exchange_weak( 146 prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), 147 std::memory_order_acq_rel, std::memory_order_acquire)); 148 return RefCountedPtr<Child>(static_cast<Child*>(this)); 149 } RefIfNonZero(const DebugLocation & location,const char * reason)150 GRPC_MUST_USE_RESULT RefCountedPtr<Child> RefIfNonZero( 151 const DebugLocation& location, const char* reason) { 152 uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire); 153 do { 154 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 155 #ifndef NDEBUG 156 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 157 if (trace_ != nullptr) { 158 VLOG(2) << trace_ << ":" << this << " " << location.file() << ":" 159 << location.line() << " ref_if_non_zero " << strong_refs 160 << " -> " << strong_refs + 1 << " (weak_refs=" << weak_refs 161 << ") " << reason; 162 } 163 #else 164 // Avoid unused-parameter warnings for debug-only parameters 165 (void)location; 166 (void)reason; 167 #endif 168 if (strong_refs == 0) return nullptr; 169 } while (!refs_.compare_exchange_weak( 170 prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0), 171 std::memory_order_acq_rel, std::memory_order_acquire)); 172 return RefCountedPtr<Child>(static_cast<Child*>(this)); 173 } 174 WeakRef()175 GRPC_MUST_USE_RESULT WeakRefCountedPtr<Child> WeakRef() { 176 IncrementWeakRefCount(); 177 return WeakRefCountedPtr<Child>(static_cast<Child*>(this)); 178 } WeakRef(const DebugLocation & location,const char * reason)179 GRPC_MUST_USE_RESULT WeakRefCountedPtr<Child> WeakRef( 180 const DebugLocation& location, const char* reason) { 181 IncrementWeakRefCount(location, reason); 182 return WeakRefCountedPtr<Child>(static_cast<Child*>(this)); 183 } 184 185 template < 186 typename Subclass, 187 std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true> WeakRefAsSubclass()188 WeakRefCountedPtr<Subclass> WeakRefAsSubclass() { 189 IncrementWeakRefCount(); 190 return WeakRefCountedPtr<Subclass>( 191 DownCast<Subclass*>(static_cast<Child*>(this))); 192 } 193 template < 194 typename Subclass, 195 std::enable_if_t<std::is_base_of<Child, Subclass>::value, bool> = true> WeakRefAsSubclass(const DebugLocation & location,const char * reason)196 WeakRefCountedPtr<Subclass> WeakRefAsSubclass(const DebugLocation& location, 197 const char* reason) { 198 IncrementWeakRefCount(location, reason); 199 return WeakRefCountedPtr<Subclass>( 200 DownCast<Subclass*>(static_cast<Child*>(this))); 201 } 202 WeakUnref()203 void WeakUnref() { 204 #ifndef NDEBUG 205 // Grab a copy of the trace flag before the atomic change, since we 206 // will no longer be holding a ref afterwards and therefore can't 207 // safely access it, since another thread might free us in the interim. 208 const char* trace = trace_; 209 #endif 210 const uint64_t prev_ref_pair = 211 refs_.fetch_sub(MakeRefPair(0, 1), std::memory_order_acq_rel); 212 #ifndef NDEBUG 213 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 214 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 215 if (trace != nullptr) { 216 VLOG(2) << trace << ":" << this << " weak_unref " << weak_refs << " -> " 217 << weak_refs - 1 << " (refs=" << strong_refs << ")"; 218 } 219 CHECK_GT(weak_refs, 0u); 220 #endif 221 if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) { 222 unref_behavior_(static_cast<Child*>(this)); 223 } 224 } WeakUnref(const DebugLocation & location,const char * reason)225 void WeakUnref(const DebugLocation& location, const char* reason) { 226 #ifndef NDEBUG 227 // Grab a copy of the trace flag before the atomic change, since we 228 // will no longer be holding a ref afterwards and therefore can't 229 // safely access it, since another thread might free us in the interim. 230 const char* trace = trace_; 231 #endif 232 const uint64_t prev_ref_pair = 233 refs_.fetch_sub(MakeRefPair(0, 1), std::memory_order_acq_rel); 234 #ifndef NDEBUG 235 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 236 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 237 if (trace != nullptr) { 238 VLOG(2) << trace << ":" << this << " " << location.file() << ":" 239 << location.line() << " weak_unref " << weak_refs << " -> " 240 << weak_refs - 1 << " (refs=" << strong_refs << ") " << reason; 241 } 242 CHECK_GT(weak_refs, 0u); 243 #else 244 // Avoid unused-parameter warnings for debug-only parameters 245 (void)location; 246 (void)reason; 247 #endif 248 if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) { 249 unref_behavior_(static_cast<const Child*>(this)); 250 } 251 } 252 253 protected: 254 // Note: Tracing is a no-op in non-debug builds. 255 explicit DualRefCounted( 256 const char* 257 #ifndef NDEBUG 258 // Leave unnamed if NDEBUG to avoid unused parameter warning 259 trace 260 #endif 261 = nullptr, 262 int32_t initial_refcount = 1) 263 : 264 #ifndef NDEBUG trace_(trace)265 trace_(trace), 266 #endif 267 refs_(MakeRefPair(initial_refcount, 0)) { 268 } 269 270 // Ref count has dropped to zero, so the object is now orphaned. 271 virtual void Orphaned() = 0; 272 273 // Note: Depending on the Impl used, this dtor can be implicitly virtual. 274 ~DualRefCounted() = default; 275 276 private: 277 // Allow RefCountedPtr<> to access IncrementRefCount(). 278 template <typename T> 279 friend class RefCountedPtr; 280 // Allow WeakRefCountedPtr<> to access IncrementWeakRefCount(). 281 template <typename T> 282 friend class WeakRefCountedPtr; 283 284 // First 32 bits are strong refs, next 32 bits are weak refs. MakeRefPair(uint32_t strong,uint32_t weak)285 static uint64_t MakeRefPair(uint32_t strong, uint32_t weak) { 286 return (static_cast<uint64_t>(strong) << 32) + static_cast<int64_t>(weak); 287 } GetStrongRefs(uint64_t ref_pair)288 static uint32_t GetStrongRefs(uint64_t ref_pair) { 289 return static_cast<uint32_t>(ref_pair >> 32); 290 } GetWeakRefs(uint64_t ref_pair)291 static uint32_t GetWeakRefs(uint64_t ref_pair) { 292 return static_cast<uint32_t>(ref_pair & 0xffffffffu); 293 } 294 IncrementRefCount()295 void IncrementRefCount() { 296 #ifndef NDEBUG 297 const uint64_t prev_ref_pair = 298 refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed); 299 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 300 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 301 CHECK_NE(strong_refs, 0u); 302 if (trace_ != nullptr) { 303 VLOG(2) << trace_ << ":" << this << " ref " << strong_refs << " -> " 304 << strong_refs + 1 << "; (weak_refs=" << weak_refs << ")"; 305 } 306 #else 307 refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed); 308 #endif 309 } IncrementRefCount(const DebugLocation & location,const char * reason)310 void IncrementRefCount(const DebugLocation& location, const char* reason) { 311 #ifndef NDEBUG 312 const uint64_t prev_ref_pair = 313 refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed); 314 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 315 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 316 CHECK_NE(strong_refs, 0u); 317 if (trace_ != nullptr) { 318 VLOG(2) << trace_ << ":" << this << " " << location.file() << ":" 319 << location.line() << " ref " << strong_refs << " -> " 320 << strong_refs + 1 << " (weak_refs=" << weak_refs << ") " 321 << reason; 322 } 323 #else 324 // Use conditionally-important parameters 325 (void)location; 326 (void)reason; 327 refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed); 328 #endif 329 } 330 IncrementWeakRefCount()331 void IncrementWeakRefCount() { 332 #ifndef NDEBUG 333 const uint64_t prev_ref_pair = 334 refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed); 335 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 336 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 337 if (trace_ != nullptr) { 338 VLOG(2) << trace_ << ":" << this << " weak_ref " << weak_refs << " -> " 339 << weak_refs + 1 << "; (refs=" << strong_refs << ")"; 340 } 341 if (strong_refs == 0) CHECK_NE(weak_refs, 0u); 342 #else 343 refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed); 344 #endif 345 } IncrementWeakRefCount(const DebugLocation & location,const char * reason)346 void IncrementWeakRefCount(const DebugLocation& location, 347 const char* reason) { 348 #ifndef NDEBUG 349 const uint64_t prev_ref_pair = 350 refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed); 351 const uint32_t strong_refs = GetStrongRefs(prev_ref_pair); 352 const uint32_t weak_refs = GetWeakRefs(prev_ref_pair); 353 if (trace_ != nullptr) { 354 VLOG(2) << trace_ << ":" << this << " " << location.file() << ":" 355 << location.line() << " weak_ref " << weak_refs << " -> " 356 << weak_refs + 1 << " (refs=" << strong_refs << ") " << reason; 357 } 358 if (strong_refs == 0) CHECK_NE(weak_refs, 0u); 359 #else 360 // Use conditionally-important parameters 361 (void)location; 362 (void)reason; 363 refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed); 364 #endif 365 } 366 367 #ifndef NDEBUG 368 const char* trace_; 369 #endif 370 std::atomic<uint64_t> refs_{0}; 371 GPR_NO_UNIQUE_ADDRESS UnrefBehavior unref_behavior_; 372 }; 373 374 } // namespace grpc_core 375 376 #endif // GRPC_SRC_CORE_UTIL_DUAL_REF_COUNTED_H 377