• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a definition of the ThreadLocalCache class. This class
10 // provides support for defining thread local objects with non-static duration.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15 #define MLIR_SUPPORT_THREADLOCALCACHE_H
16 
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/ThreadLocal.h"
#include <memory>
22 
23 namespace mlir {
24 /// This class provides support for defining a thread local object with non
25 /// static storage duration. This is very useful for situations in which a data
26 /// cache has very large lock contention.
27 template <typename ValueT>
28 class ThreadLocalCache {
29   /// The type used for the static thread_local cache. This is a map between an
30   /// instance of the non-static cache and a weak reference to an instance of
31   /// ValueT. We use a weak reference here so that the object can be destroyed
32   /// without needing to lock access to the cache itself.
33   struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
34                                                 std::weak_ptr<ValueT>> {
~CacheTypeCacheType35     ~CacheType() {
36       // Remove the values of this cache that haven't already expired.
37       for (auto &it : *this)
38         if (std::shared_ptr<ValueT> value = it.second.lock())
39           it.first->remove(value.get());
40     }
41 
42     /// Clear out any unused entries within the map. This method is not
43     /// thread-safe, and should only be called by the same thread as the cache.
clearExpiredEntriesCacheType44     void clearExpiredEntries() {
45       for (auto it = this->begin(), e = this->end(); it != e;) {
46         auto curIt = it++;
47         if (curIt->second.expired())
48           this->erase(curIt);
49       }
50     }
51   };
52 
53 public:
54   ThreadLocalCache() = default;
~ThreadLocalCache()55   ~ThreadLocalCache() {
56     // No cleanup is necessary here as the shared_pointer memory will go out of
57     // scope and invalidate the weak pointers held by the thread_local caches.
58   }
59 
60   /// Return an instance of the value type for the current thread.
get()61   ValueT &get() {
62     // Check for an already existing instance for this thread.
63     CacheType &staticCache = getStaticCache();
64     std::weak_ptr<ValueT> &threadInstance = staticCache[this];
65     if (std::shared_ptr<ValueT> value = threadInstance.lock())
66       return *value;
67 
68     // Otherwise, create a new instance for this thread.
69     llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
70     instances.push_back(std::make_shared<ValueT>());
71     std::shared_ptr<ValueT> &instance = instances.back();
72     threadInstance = instance;
73 
74     // Before returning the new instance, take the chance to clear out any used
75     // entries in the static map. The cache is only cleared within the same
76     // thread to remove the need to lock the cache itself.
77     staticCache.clearExpiredEntries();
78     return *instance;
79   }
80   ValueT &operator*() { return get(); }
81   ValueT *operator->() { return &get(); }
82 
83 private:
84   ThreadLocalCache(ThreadLocalCache &&) = delete;
85   ThreadLocalCache(const ThreadLocalCache &) = delete;
86   ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
87 
88   /// Return the static thread local instance of the cache type.
getStaticCache()89   static CacheType &getStaticCache() {
90     static LLVM_THREAD_LOCAL CacheType cache;
91     return cache;
92   }
93 
94   /// Remove the given value entry. This is generally called when a thread local
95   /// cache is destructing.
remove(ValueT * value)96   void remove(ValueT *value) {
97     // Erase the found value directly, because it is guaranteed to be in the
98     // list.
99     llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
100     auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
101       return instance.get() == value;
102     });
103     assert(it != instances.end() && "expected value to exist in cache");
104     instances.erase(it);
105   }
106 
107   /// Owning pointers to all of the values that have been constructed for this
108   /// object in the static cache.
109   SmallVector<std::shared_ptr<ValueT>, 1> instances;
110 
111   /// A mutex used when a new thread instance has been added to the cache for
112   /// this object.
113   llvm::sys::SmartMutex<true> instanceMutex;
114 };
115 } // end namespace mlir
116 
117 #endif // MLIR_SUPPORT_THREADLOCALCACHE_H
118