1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef META_BASE_SHARED_PTR_H
17 #define META_BASE_SHARED_PTR_H
18
19 #include <meta/base/meta_types.h>
20
21 #include "shared_ptr_internals.h"
22
BASE_BEGIN_NAMESPACE()23 BASE_BEGIN_NAMESPACE()
24
25 /**
26 * @brief C++ standard like weak_ptr.
27 */
28 template<typename T>
29 class weak_ptr final : public Internals::PtrCountedBase<T> {
30 public:
31 using element_type = BASE_NS::remove_extent_t<T>;
32 weak_ptr() = default;
33 ~weak_ptr()
34 {
35 if (this->control_) {
36 this->control_->ReleaseWeak();
37 }
38 }
39
40 weak_ptr(nullptr_t) {}
41 weak_ptr(const shared_ptr<T>& p) : Internals::PtrCountedBase<T>(p)
42 {
43 if (this->control_) {
44 this->control_->AddWeak();
45 }
46 }
47 weak_ptr(const weak_ptr& p) noexcept : Internals::PtrCountedBase<T>(p)
48 {
49 if (this->control_) {
50 this->control_->AddWeak();
51 }
52 }
53 weak_ptr(weak_ptr&& p) noexcept : Internals::PtrCountedBase<T>(p)
54 {
55 p.InternalReset();
56 }
57
58 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
59 shared_ptr<T> lock() const
60 {
61 return shared_ptr<T>(*this);
62 }
63 weak_ptr& operator=(weak_ptr&& p) noexcept
64 {
65 if (this != &p) {
66 reset();
67 this->control_ = p.control_;
68 this->pointer_ = p.pointer_;
69 p.InternalReset();
70 }
71 return *this;
72 }
73 weak_ptr& operator=(const weak_ptr& p) noexcept
74 {
75 if (this != &p) {
76 reset();
77 this->control_ = p.control_;
78 this->pointer_ = p.pointer_;
79 if (this->control_) {
80 this->control_->AddWeak();
81 }
82 }
83 return *this;
84 }
85 weak_ptr& operator=(const shared_ptr<T>& p)
86 {
87 reset();
88 this->control_ = p.control_;
89 this->pointer_ = p.pointer_;
90 if (this->control_) {
91 this->control_->AddWeak();
92 }
93 return *this;
94 }
95
96 weak_ptr& operator=(nullptr_t) noexcept
97 {
98 reset();
99 return *this;
100 }
101 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
102 void reset()
103 {
104 if (this->control_) {
105 this->control_->ReleaseWeak();
106 this->InternalReset();
107 }
108 }
109
110 /*"implicit" casting constructors */
111 template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
112 weak_ptr(const shared_ptr<U>& p)
113 // handle casting by using functionality in shared_ptr. (creates an aliased shared_ptr to original.)
114 : weak_ptr(shared_ptr<T>(p))
115 {}
116
117 template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
118 weak_ptr(const weak_ptr<U>& p) : weak_ptr(shared_ptr<T>(p.lock()))
119 {}
120
121 /* "implicit" casting move */
122 template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
123 weak_ptr(weak_ptr<U>&& p) noexcept : weak_ptr(shared_ptr<T>(p.lock()))
124 {
125 p.reset();
126 }
127
128 /* "implicit" casting operators */
129 template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
130 weak_ptr& operator=(const shared_ptr<U>& p)
131 {
132 // handle casting by using functionality in shared_ptr. (creates an aliased shared_ptr to original.)
133 *this = shared_ptr<T>(p);
134 return *this;
135 }
136
137 template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
138 weak_ptr& operator=(const weak_ptr<U>& p)
139 {
140 // first lock the given weak ptr. (to see if it has expired, and to get a pointer that can be cast)
141 *this = shared_ptr<T>(p.lock());
142 return *this;
143 }
144
145 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
146 bool expired() const noexcept
147 {
148 return this->use_count() == 0;
149 }
150
151 private:
152 friend class shared_ptr<T>;
153 template<typename>
154 friend class weak_ptr;
155 };
156
157 /**
158 * @brief C++ standard like shared_ptr with IInterface support for reference counting.
159 */
160 template<typename T>
161 class shared_ptr final : public Internals::PtrCountedBase<T> {
162 public:
163 using element_type = BASE_NS::remove_extent_t<T>;
164 using weak_type = weak_ptr<T>;
165
166 constexpr shared_ptr() noexcept = default;
shared_ptr(nullptr_t)167 constexpr shared_ptr(nullptr_t) noexcept {}
shared_ptr(const shared_ptr & p)168 shared_ptr(const shared_ptr& p) noexcept : Internals::PtrCountedBase<T>(p)
169 {
170 if (this->control_) {
171 this->control_->AddStrongCopy();
172 }
173 }
shared_ptr(shared_ptr && p)174 shared_ptr(shared_ptr&& p) noexcept : Internals::PtrCountedBase<T>(p)
175 {
176 p.InternalReset();
177 }
shared_ptr(T * ptr)178 explicit shared_ptr(T* ptr)
179 {
180 if (ptr) {
181 ConstructBlock(ptr);
182 }
183 }
184
185 template<typename Deleter>
shared_ptr(T * ptr,Deleter deleter)186 shared_ptr(T* ptr, Deleter deleter)
187 {
188 if (ptr) {
189 ConstructBlock(ptr, BASE_NS::move(deleter));
190 }
191 }
192
shared_ptr(const weak_type & p)193 explicit shared_ptr(const weak_type& p) noexcept : Internals::PtrCountedBase<T>(p)
194 {
195 if (this->control_) {
196 if (!this->control_->AddStrongLock()) {
197 this->InternalReset();
198 }
199 }
200 }
201 template<class Y>
shared_ptr(const shared_ptr<Y> & r,T * ptr)202 shared_ptr(const shared_ptr<Y>& r, T* ptr) noexcept : Internals::PtrCountedBase<T>(r.control_)
203 {
204 if (this->control_ && ptr) {
205 this->control_->AddStrongCopy();
206 this->pointer_ = const_cast<deletableType*>(ptr);
207 } else {
208 this->InternalReset();
209 }
210 }
211 template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
shared_ptr(shared_ptr<U> && p)212 shared_ptr(shared_ptr<U>&& p) noexcept : Internals::PtrCountedBase<T>(p.control_)
213 {
214 if (this->control_) {
215 void* ptr = nullptr;
216 if constexpr (BASE_NS::is_same_v<T, const BASE_NS::remove_const_t<U>> ||
217 !META_NS::HasGetInterfaceMethod_v<U>) {
218 ptr = p.get();
219 } else {
220 // make a proper interface cast here.
221 if constexpr (BASE_NS::is_const_v<T>) {
222 ptr = const_cast<void*>(static_cast<const void*>(p->GetInterface(T::UID)));
223 } else {
224 ptr = static_cast<void*>(p->GetInterface(T::UID));
225 }
226 }
227 if (ptr) {
228 this->pointer_ = static_cast<deletableType*>(ptr);
229 p.InternalReset();
230 } else {
231 this->InternalReset();
232 p.reset();
233 }
234 }
235 }
236 template<class U, class = Internals::EnableIfPointerConvertible<U, T>>
shared_ptr(const shared_ptr<U> & p)237 shared_ptr(const shared_ptr<U>& p) noexcept : shared_ptr(shared_ptr<U>(p)) // use the above move constructor
238 {}
239
240 template<class U, class D, class = Internals::EnableIfPointerConvertible<U, T>>
shared_ptr(unique_ptr<U,D> && p)241 shared_ptr(unique_ptr<U, D>&& p) noexcept
242 {
243 if (p) {
244 ConstructBlock(p.release(), BASE_NS::move(p.get_deleter()));
245 }
246 }
247
~shared_ptr()248 ~shared_ptr()
249 {
250 if (this->control_) {
251 this->control_->Release();
252 }
253 }
254 T* operator->() const noexcept
255 {
256 return get();
257 }
258 T& operator*() const noexcept
259 {
260 return *get();
261 }
262 explicit operator bool() const
263 {
264 return get();
265 }
266 bool operator==(const shared_ptr& other) const noexcept
267 {
268 return get() == other.get();
269 }
270 bool operator!=(const shared_ptr& other) const noexcept
271 {
272 return !(*this == other);
273 }
274 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
reset()275 void reset()
276 {
277 if (this->control_) {
278 this->control_->Release();
279 this->InternalReset();
280 }
281 }
282 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
reset(T * ptr)283 void reset(T* ptr)
284 {
285 if (ptr != this->pointer_) {
286 reset();
287 if (ptr) {
288 ConstructBlock(ptr);
289 }
290 }
291 }
292 template<typename Deleter>
293 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
reset(T * ptr,Deleter deleter)294 void reset(T* ptr, Deleter deleter)
295 {
296 if (ptr != this->pointer_) {
297 reset();
298 if (ptr) {
299 ConstructBlock(ptr, BASE_NS::move(deleter));
300 }
301 }
302 }
303
304 shared_ptr& operator=(nullptr_t) noexcept
305 {
306 reset();
307 return *this;
308 }
309 shared_ptr& operator=(const shared_ptr& o)
310 {
311 if (this != &o) {
312 reset();
313 this->control_ = o.control_;
314 this->pointer_ = o.pointer_;
315 if (this->control_) {
316 this->control_->AddStrongCopy();
317 }
318 }
319 return *this;
320 }
321 shared_ptr& operator=(shared_ptr&& o) noexcept
322 {
323 if (this != &o) {
324 reset();
325 this->control_ = o.control_;
326 this->pointer_ = o.pointer_;
327 o.InternalReset();
328 }
329 return *this;
330 }
331 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
swap(shared_ptr & p)332 void swap(shared_ptr& p)
333 {
334 auto tp = p.pointer_;
335 auto tc = p.control_;
336 p.pointer_ = this->pointer_;
337 p.control_ = this->control_;
338 this->pointer_ = tp;
339 this->control_ = tc;
340 }
341 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
get()342 element_type* get() const noexcept
343 {
344 return this->pointer_;
345 }
346
347 // NOLINTNEXTLINE(readability-identifier-naming) to keep std like syntax
weak_count()348 int32_t weak_count() const
349 {
350 if (!this->control_) {
351 return 0;
352 }
353 int32_t c = this->control_->GetWeakCount();
354 if (c > 0) {
355 // strong references are counted as single weak reference for the internal bookkeeping.
356 --c;
357 }
358 return c;
359 }
360
361 private:
362 using deletableType = BASE_NS::remove_const_t<T>;
363
ConstructBlock(T * ptr)364 void ConstructBlock(T* ptr)
365 {
366 static_assert(sizeof(T), "type has to be complete when constructing control block");
367 if constexpr (__is_base_of(CORE_NS::IInterface, T)) {
368 this->control_ = new Internals::RefCountedObjectStorageBlock(ptr);
369 } else {
370 this->control_ = new Internals::StorageBlock(ptr);
371 }
372 this->pointer_ = ptr;
373 }
374 template<typename Deleter>
ConstructBlock(T * ptr,Deleter deleter)375 void ConstructBlock(T* ptr, Deleter deleter)
376 {
377 this->control_ = new Internals::StorageBlockWithDeleter(ptr, BASE_NS::move(deleter));
378 this->pointer_ = ptr;
379 }
380
381 template<typename>
382 friend class weak_ptr;
383 template<typename>
384 friend class shared_ptr;
385 };
386
BASE_END_NAMESPACE()387 BASE_END_NAMESPACE()
388
389 // NOLINTBEGIN(readability-identifier-naming) to keep std like syntax
390 template<class U, class T>
391 BASE_NS::shared_ptr<U> static_pointer_cast(const BASE_NS::shared_ptr<T>& ptr)
392 {
393 if (ptr) {
394 return BASE_NS::shared_ptr<U>(ptr, static_cast<U*>(ptr.get()));
395 }
396 return {};
397 }
398
399 template<class U, class T>
interface_pointer_cast(const BASE_NS::shared_ptr<T> & ptr)400 BASE_NS::shared_ptr<U> interface_pointer_cast(const BASE_NS::shared_ptr<T>& ptr)
401 {
402 static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
403 static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
404 if (ptr) {
405 if constexpr (BASE_NS::is_same_v<U, T>) {
406 // same type.
407 return ptr;
408 } else {
409 return BASE_NS::shared_ptr<U>(ptr, static_cast<U*>(static_cast<void*>(ptr->GetInterface(U::UID))));
410 }
411 }
412 return {};
413 }
414
415 template<class U, class T>
interface_pointer_cast(const BASE_NS::shared_ptr<const T> & ptr)416 BASE_NS::shared_ptr<const U> interface_pointer_cast(const BASE_NS::shared_ptr<const T>& ptr)
417 {
418 static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
419 static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
420 if (ptr) {
421 if constexpr (BASE_NS::is_same_v<U, T>) {
422 // same type.
423 return ptr;
424 } else {
425 return BASE_NS::shared_ptr<const U>(
426 ptr, static_cast<const U*>(static_cast<const void*>(ptr->GetInterface(U::UID))));
427 }
428 }
429 return {};
430 }
431
432 template<class U, class T>
interface_pointer_cast(const BASE_NS::weak_ptr<T> & weak)433 BASE_NS::shared_ptr<U> interface_pointer_cast(const BASE_NS::weak_ptr<T>& weak)
434 {
435 return interface_pointer_cast<U>(weak.lock());
436 }
437
438 template<class U, class T>
interface_pointer_cast(const BASE_NS::weak_ptr<const T> & weak)439 BASE_NS::shared_ptr<const U> interface_pointer_cast(const BASE_NS::weak_ptr<const T>& weak)
440 {
441 return interface_pointer_cast<const U>(weak.lock());
442 }
443
444 template<class U, class T>
interface_cast(const BASE_NS::shared_ptr<T> & ptr)445 U* interface_cast(const BASE_NS::shared_ptr<T>& ptr)
446 {
447 static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
448 static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
449 if (ptr) {
450 if constexpr (BASE_NS::is_same_v<U, T>) {
451 // same type.
452 return ptr.get();
453 } else {
454 return static_cast<U*>(static_cast<void*>(ptr->GetInterface(U::UID)));
455 }
456 }
457 return {};
458 }
459
460 template<class U, class T>
interface_cast(const BASE_NS::shared_ptr<const T> & ptr)461 const U* interface_cast(const BASE_NS::shared_ptr<const T>& ptr)
462 {
463 static_assert(META_NS::HasGetInterfaceMethod_v<T>, "T::GetInterface not defined");
464 static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
465 if (ptr) {
466 if constexpr (BASE_NS::is_same_v<U, T>) {
467 // same type.
468 return ptr.get();
469 } else {
470 return static_cast<const U*>(static_cast<const void*>(ptr->GetInterface(U::UID)));
471 }
472 }
473 return {};
474 }
475
476 template<class U>
interface_cast(CORE_NS::IInterface * ptr)477 U* interface_cast(CORE_NS::IInterface* ptr)
478 {
479 static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
480 if (ptr) {
481 return static_cast<U*>(static_cast<void*>(ptr->GetInterface(U::UID)));
482 }
483 return {};
484 }
485
486 template<class U>
interface_cast(const CORE_NS::IInterface * ptr)487 const U* interface_cast(const CORE_NS::IInterface* ptr)
488 {
489 static_assert(META_NS::HasGetInterfaceMethod_v<U>, "U::GetInterface not defined");
490 if (ptr) {
491 return static_cast<const U*>(static_cast<const void*>(ptr->GetInterface(U::UID)));
492 }
493 return {};
494 }
495 // NOLINTEND(readability-identifier-naming) to keep std like syntax
496
497 template<typename T, typename... Args>
CreateShared(Args &&...args)498 BASE_NS::shared_ptr<T> CreateShared(Args&&... args)
499 {
500 return BASE_NS::shared_ptr<T>(new T(BASE_NS::forward<Args>(args)...));
501 }
502
503 META_TYPE(BASE_NS::shared_ptr<const CORE_NS::IInterface>)
504 META_TYPE(BASE_NS::shared_ptr<CORE_NS::IInterface>)
505 META_TYPE(BASE_NS::weak_ptr<const CORE_NS::IInterface>)
506 META_TYPE(BASE_NS::weak_ptr<CORE_NS::IInterface>)
507
508 #endif
509