1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ 17 #define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ 18 19 #include <memory> 20 #include <string> 21 22 #include "absl/synchronization/mutex.h" 23 #include "tensorflow/compiler/xla/client/local_client.h" 24 #include "tensorflow/core/framework/resource_mgr.h" 25 #include "tensorflow/core/lib/core/refcount.h" 26 27 namespace tensorflow { 28 29 extern const char* kXRTCompilationCacheResourceName; 30 31 struct XRTCompilationCacheEntry { XRTCompilationCacheEntryXRTCompilationCacheEntry32 explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable) 33 : executable(executable) {} 34 35 // Returns a non-owned pointer to an immutable executable. get_executableXRTCompilationCacheEntry36 xla::LocalExecutable* get_executable() const { return executable; } 37 38 private: 39 xla::LocalExecutable* executable; 40 }; 41 42 // Base class for a reference to a cached executable. A unique_ptr to a 43 // XRTCompilationCacheEntryRef is returned by the cache Lookup methods below, 44 // and ensures the underlying executable is not garbage-collected until the 45 // client discards the ptr. 46 class XRTCompilationCacheEntryRef { 47 public: 48 virtual ~XRTCompilationCacheEntryRef() = default; 49 50 // Returns a XRTCompilationCacheEntry that should not be used beyond the 51 // lifetime of the XRTCompilationCacheEntryRef. 52 virtual XRTCompilationCacheEntry get() = 0; 53 }; 54 55 // Cache for compiled XLA executables. 56 // TODO(b/112646171) rationalize this with the other compilation caches. 57 // 58 // Each key identifies a unique XLA computation, and the value is executable 59 // generated by compiling the computation. 60 // 61 // When a computation is considered for compilation, the client calls 62 // 63 // auto key = <compute key for computation>; 64 // auto compile_function = <lambda to compile computation into executable>; 65 // int64 uid; 66 // CompileIfKeyAbsent(computation_key, &uid, compile_function); 67 // 68 // where computation_key is the key computed for the computation. On success, 69 // uid contains an identifier that can be used to look up the executable. If the 70 // compiled executable were not present in the cache, compile_function would be 71 // called to generate it. 72 // 73 // The caller is responsible for calling Release(uid) once for every 74 // call to CompileIfKeyAbsent(key, ...) to discard the reference to the 75 // compilation results, after the caller is sure it will not look up the 76 // compiled executables again. 77 // 78 // Subsequently the client can call 79 // 80 // std::unique_ptr<XRTCompilationCacheEntryRef> entry; 81 // Lookup(uid, &entry); 82 // auto proto = entry->get(); 83 // 84 // to access a cached executable. 85 class XRTCompilationCache : public ResourceBase { 86 public: 87 // There is no way in general to discover the size taken by an XLA executable, 88 // so the cache defaults to a specific number of entries to determine when to 89 // start evicting programs. TODO(b/112592410) change this if the XLA API gets 90 // a mechanism to query size. 91 explicit XRTCompilationCache(int max_number_of_entries); 92 ~XRTCompilationCache() override; 93 94 // Ensures there is an entry for key present in the cache. By the time 95 // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache 96 // for key, and that entry will remain valid at least until Release is called 97 // on the returned uid. The first call to CompileIfKeyAbsent with a key that 98 // is not in the cache will evaluate compile_function to compute the value to 99 // use in the entry. Subsequent calls with the same key will block until 100 // compile_function completes. Other cache reads and inserts may proceed on 101 // other threads while compile_function is executing. The caller is 102 // responsible for calling Release(uid) to manually discard its reference to 103 // the compiled program, once the caller will not look up the compiled program 104 // again. 105 // 106 // compile_function should compile the computation represented by key and fill 107 // the xla::LocalExecutable into its passed argument. It should return OK 108 // if and only if compilation succeeds. The executable will be discarded on 109 // non-OK status. 110 Status CompileIfKeyAbsent( 111 const string& key, int64* uid, 112 const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>& 113 compile_function); 114 115 Status Release(int64 uid); 116 117 // Looks up an executable corresponding to uid. On success a pointer to an 118 // EntryRef holding the program is returned in entry. 119 Status Lookup(int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry); 120 121 string DebugString() const override; 122 123 private: 124 // An entry in the compilation cache. The entry is deleted once it has been 125 // marked for eviction from the cache _and_ all looked-up entries have been 126 // released. When the entry is first created, it is uninitialized and a 127 // client-supplied compilation function is run outside the cache's lock to 128 // generate the program to be stored in the entry. Any other client that 129 // requests the entry will block until it has been initialized. Each entry has 130 // a last_use value that set from a monotonically-increasing counter in the 131 // cache whenever the entry is referenced. When the cache becomes full, 132 // entries are marked for eviction in LRU order. 133 struct CompiledSubgraph : public core::RefCounted { 134 ~CompiledSubgraph() override = default; 135 136 XRTCompilationCache* parent = nullptr; // Not owned. 137 bool initialized = false; 138 // The Status returned by the compilation function when the entry is 139 // initialized. This status will be returned to any client that requests the 140 // entry. 141 Status initialization_status; 142 // Counter to keep track of LRU entries for the eviction policy. 143 int64 last_use = -1; 144 // The unique key describing this entry. 145 string key; 146 // The uid describing this entry. 147 int64 uid; 148 // The compiled payload corresponding to the key. 149 std::unique_ptr<xla::LocalExecutable> program; 150 }; 151 152 // Wrapper for a cache entry that holds a reference to the entry until the 153 // wrapper is deleted. This wrapper is the concrete type of 154 // XRTCompilationCacheEntryRef returned by Lookup. 155 class EntryRefImpl : public XRTCompilationCacheEntryRef { 156 public: 157 EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry); 158 ~EntryRefImpl() override; 159 160 XRTCompilationCacheEntry get() override; 161 162 private: 163 XRTCompilationCache* parent_; // Not owned. 164 // A reference to entry_ is acquired in the contructor and released via 165 // parent->DiscardEntryRef in the destructor. 166 CompiledSubgraph* entry_; 167 }; 168 169 // Releases one reference to entry. This is called by the cache when entry is 170 // marked for eviction; or by an EntryRefImpl when it is destroyed. Before the 171 // last reference to entry is released, entry is removed from cache_. 172 void DiscardEntryRef(CompiledSubgraph* entry); 173 void DiscardEntryRefLocked(CompiledSubgraph* entry) 174 EXCLUSIVE_LOCKS_REQUIRED(mu_); 175 176 // Marks the oldest unmarked entry for eviction. Requires that there is at 177 // least one such entry. 178 void MarkOldestEntryForEviction() EXCLUSIVE_LOCKS_REQUIRED(mu_); 179 180 // Updates datastructures to indicate that entry, which had been marked for 181 // eviction, has been looked up. This is called by CompileIfKeyAbsent when an 182 // entry is newly created, or an entry that has been marked for eviction but 183 // not yet evicted is looked up. 184 // 185 // First the entry is unmarked for eviction, i.e. the cache gains a reference 186 // to entry, entry's last_use field is set to be the most recent value of 187 // use_counter_ and entries_by_last_use_ is updated accordingly. 188 // 189 // Next, the size of the cache is examined to see if any other entries need to 190 // be marked for eviction now that entry has been unmarked. While the total 191 // number of unmarked cached entries is greater than max_cache_entries_, 192 // entries are marked for eviction in LRU order. The most recently used entry 193 // is never marked for eviction, so an entry larger than the max cache entries 194 // will remain in the cache until it is replaced by something else. 195 void LookupEntryMarkedForEviction(CompiledSubgraph* entry) 196 EXCLUSIVE_LOCKS_REQUIRED(mu_); 197 198 // Creates a new entry by running initialize_program and places it in the 199 // cache to be looked up by key. The new entry is in the 'marked for eviction' 200 // state (not present in entries_by_last_use_) and the caller is expected to 201 // call LookupEntryMarkedForEviction after InitializeEntry. 202 // 203 // **InitializeEntry releases mu_ during the call to initialize_program.** 204 CompiledSubgraph* InitializeEntry( 205 const string& key, 206 const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>& 207 initialize_program) EXCLUSIVE_LOCKS_REQUIRED(mu_); 208 209 // The maximum number of entries that are stored in the cache before entries 210 // are marked for eviction. 211 const int max_cache_entries_; 212 213 mutable absl::Mutex mu_; 214 // The total number of entries that are stored and not marked for eviction. 215 int cache_entries_ GUARDED_BY(mu_) = 0; 216 // The total number of entries that are marked for eviction. 217 int marked_for_eviction_entries_ GUARDED_BY(mu_) = 0; 218 // The value to assign to the last_use field of the next entry that is looked 219 // up. 220 int64 use_counter_ GUARDED_BY(mu_) = 0; 221 // All the executables that can be looked up in the cache index by key. An 222 // entry is marked for eviction iff it is present in cache_ and not in 223 // entries_by_last_use_. 224 std::unordered_map<string, CompiledSubgraph*> cache_ GUARDED_BY(mu_); 225 // All the executable entries that can be looked up in the cache indexed by 226 // uid. 227 std::unordered_map<int64, CompiledSubgraph*> entries_by_uid_ GUARDED_BY(mu_); 228 // Map from last_use to entry, used to mark entries for eviction in LRU 229 // order. If an entry's last_use counter is not present as a key in 230 // entries_by_last_use_ then the entry has been marked for eviction. 231 std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_); 232 }; 233 234 } // namespace tensorflow 235 236 #endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ 237