1 //===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a definition of the ThreadLocalCache class. This class 10 // provides support for defining thread local objects with non-static duration. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H 15 #define MLIR_SUPPORT_THREADLOCALCACHE_H 16 17 #include "mlir/Support/LLVM.h" 18 #include "llvm/ADT/DenseMap.h" 19 #include "llvm/Support/ManagedStatic.h" 20 #include "llvm/Support/Mutex.h" 21 #include "llvm/Support/ThreadLocal.h" 22 23 namespace mlir { 24 /// This class provides support for defining a thread local object with non 25 /// static storage duration. This is very useful for situations in which a data 26 /// cache has very large lock contention. 27 template <typename ValueT> 28 class ThreadLocalCache { 29 /// The type used for the static thread_local cache. This is a map between an 30 /// instance of the non-static cache and a weak reference to an instance of 31 /// ValueT. We use a weak reference here so that the object can be destroyed 32 /// without needing to lock access to the cache itself. 33 struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *, 34 std::weak_ptr<ValueT>> { ~CacheTypeCacheType35 ~CacheType() { 36 // Remove the values of this cache that haven't already expired. 37 for (auto &it : *this) 38 if (std::shared_ptr<ValueT> value = it.second.lock()) 39 it.first->remove(value.get()); 40 } 41 42 /// Clear out any unused entries within the map. This method is not 43 /// thread-safe, and should only be called by the same thread as the cache. clearExpiredEntriesCacheType44 void clearExpiredEntries() { 45 for (auto it = this->begin(), e = this->end(); it != e;) { 46 auto curIt = it++; 47 if (curIt->second.expired()) 48 this->erase(curIt); 49 } 50 } 51 }; 52 53 public: 54 ThreadLocalCache() = default; ~ThreadLocalCache()55 ~ThreadLocalCache() { 56 // No cleanup is necessary here as the shared_pointer memory will go out of 57 // scope and invalidate the weak pointers held by the thread_local caches. 58 } 59 60 /// Return an instance of the value type for the current thread. get()61 ValueT &get() { 62 // Check for an already existing instance for this thread. 63 CacheType &staticCache = getStaticCache(); 64 std::weak_ptr<ValueT> &threadInstance = staticCache[this]; 65 if (std::shared_ptr<ValueT> value = threadInstance.lock()) 66 return *value; 67 68 // Otherwise, create a new instance for this thread. 69 llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex); 70 instances.push_back(std::make_shared<ValueT>()); 71 std::shared_ptr<ValueT> &instance = instances.back(); 72 threadInstance = instance; 73 74 // Before returning the new instance, take the chance to clear out any used 75 // entries in the static map. The cache is only cleared within the same 76 // thread to remove the need to lock the cache itself. 77 staticCache.clearExpiredEntries(); 78 return *instance; 79 } 80 ValueT &operator*() { return get(); } 81 ValueT *operator->() { return &get(); } 82 83 private: 84 ThreadLocalCache(ThreadLocalCache &&) = delete; 85 ThreadLocalCache(const ThreadLocalCache &) = delete; 86 ThreadLocalCache &operator=(const ThreadLocalCache &) = delete; 87 88 /// Return the static thread local instance of the cache type. getStaticCache()89 static CacheType &getStaticCache() { 90 static LLVM_THREAD_LOCAL CacheType cache; 91 return cache; 92 } 93 94 /// Remove the given value entry. This is generally called when a thread local 95 /// cache is destructing. remove(ValueT * value)96 void remove(ValueT *value) { 97 // Erase the found value directly, because it is guaranteed to be in the 98 // list. 99 llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex); 100 auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) { 101 return instance.get() == value; 102 }); 103 assert(it != instances.end() && "expected value to exist in cache"); 104 instances.erase(it); 105 } 106 107 /// Owning pointers to all of the values that have been constructed for this 108 /// object in the static cache. 109 SmallVector<std::shared_ptr<ValueT>, 1> instances; 110 111 /// A mutex used when a new thread instance has been added to the cache for 112 /// this object. 113 llvm::sys::SmartMutex<true> instanceMutex; 114 }; 115 } // end namespace mlir 116 117 #endif // MLIR_SUPPORT_THREADLOCALCACHE_H 118