1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ 17 #define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_ 18 19 #include <memory> 20 #include <string> 21 22 #include "absl/synchronization/mutex.h" 23 #include "tensorflow/compiler/xla/client/local_client.h" 24 #include "tensorflow/compiler/xla/statusor.h" 25 #include "tensorflow/compiler/xrt/xrt_refptr.h" 26 #include "tensorflow/core/framework/resource_mgr.h" 27 #include "tensorflow/core/lib/core/refcount.h" 28 29 namespace tensorflow { 30 31 extern const char* kXRTCompilationCacheResourceName; 32 33 struct XRTCompilationCacheEntry { XRTCompilationCacheEntryXRTCompilationCacheEntry34 explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable) 35 : executable(executable) {} 36 37 // Returns a non-owned pointer to an immutable executable. get_executableXRTCompilationCacheEntry38 xla::LocalExecutable* get_executable() const { return executable; } 39 40 private: 41 xla::LocalExecutable* executable; 42 }; 43 44 // Base class for a reference to a cached executable. A unique_ptr to a 45 // XRTCompilationCacheEntryRef is returned by the cache Lookup methods below, 46 // and ensures the underlying executable is not garbage-collected until the 47 // client discards the ptr. 
48 class XRTCompilationCacheEntryRef { 49 public: 50 virtual ~XRTCompilationCacheEntryRef() = default; 51 52 // Returns a XRTCompilationCacheEntry that should not be used beyond the 53 // lifetime of the XRTCompilationCacheEntryRef. 54 virtual XRTCompilationCacheEntry get() = 0; 55 }; 56 57 // Cache for compiled XLA executables. 58 // TODO(b/112646171) rationalize this with the other compilation caches. 59 // 60 // Each key identifies a unique XLA computation, and the value is executable 61 // generated by compiling the computation. 62 // 63 // When a computation is considered for compilation, the client calls 64 // 65 // auto key = <compute key for computation>; 66 // auto compile_function = <lambda to compile computation into executable>; 67 // int64 uid; 68 // CompileIfKeyAbsent(computation_key, &uid, compile_function); 69 // 70 // where computation_key is the key computed for the computation. On success, 71 // uid contains an identifier that can be used to look up the executable. If the 72 // compiled executable were not present in the cache, compile_function would be 73 // called to generate it. 74 // 75 // The caller is responsible for calling Release(uid) once for every 76 // call to CompileIfKeyAbsent(key, ...) to discard the reference to the 77 // compilation results, after the caller is sure it will not look up the 78 // compiled executables again. 79 // 80 // Subsequently the client can call 81 // 82 // std::unique_ptr<XRTCompilationCacheEntryRef> entry; 83 // Lookup(uid, &entry); 84 // auto proto = entry->get(); 85 // 86 // to access a cached executable. 87 class XRTCompilationCache : public ResourceBase { 88 public: 89 // There is no way in general to discover the size taken by an XLA executable, 90 // so the cache defaults to a specific number of entries to determine when to 91 // start evicting programs. TODO(b/112592410) change this if the XLA API gets 92 // a mechanism to query size. 
93 explicit XRTCompilationCache(int max_number_of_entries); 94 ~XRTCompilationCache() override; 95 96 // Ensures there is an entry for key present in the cache. By the time 97 // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache 98 // for key, and that entry will remain valid at least until Release is called 99 // on the returned uid. The first call to CompileIfKeyAbsent with a key that 100 // is not in the cache will evaluate compile_function to compute the value to 101 // use in the entry. Subsequent calls with the same key will block until 102 // compile_function completes. Other cache reads and inserts may proceed on 103 // other threads while compile_function is executing. The caller is 104 // responsible for calling Release(uid) to manually discard its reference to 105 // the compiled program, once the caller will not look up the compiled program 106 // again. 107 // 108 // compile_function should compile the computation represented by key and fill 109 // the xla::LocalExecutable into its passed argument. It should return OK 110 // if and only if compilation succeeds. The executable will be discarded on 111 // non-OK status. 112 Status CompileIfKeyAbsent( 113 const string& key, int64* uid, 114 const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>& 115 compile_function); 116 117 Status Release(int64_t uid); 118 119 // Looks up an executable corresponding to uid. On success a pointer to an 120 // EntryRef holding the program is returned in entry. 121 Status Lookup(int64_t uid, 122 std::unique_ptr<XRTCompilationCacheEntryRef>* entry); 123 124 string DebugString() const override; 125 126 private: 127 // An entry in the compilation cache. The entry is deleted once it has been 128 // marked for eviction from the cache _and_ all looked-up entries have been 129 // released. 
When the entry is first created, it is uninitialized and a 130 // client-supplied compilation function is run outside the cache's lock to 131 // generate the program to be stored in the entry. Any other client that 132 // requests the entry will block until it has been initialized. Each entry has 133 // a last_use value that set from a monotonically-increasing counter in the 134 // cache whenever the entry is referenced. When the cache becomes full, 135 // entries are marked for eviction in LRU order. 136 struct CompiledSubgraph : public core::RefCounted { 137 ~CompiledSubgraph() override = default; 138 139 XRTCompilationCache* parent = nullptr; // Not owned. 140 bool initialized = false; 141 // The Status returned by the compilation function when the entry is 142 // initialized. This status will be returned to any client that requests the 143 // entry. 144 Status initialization_status; 145 // Counter to keep track of LRU entries for the eviction policy. 146 int64 last_use = -1; 147 // The unique key describing this entry. 148 string key; 149 // The uid describing this entry. 150 int64 uid; 151 // The compiled payload corresponding to the key. 152 std::unique_ptr<xla::LocalExecutable> program; 153 }; 154 155 // Wrapper for a cache entry that holds a reference to the entry until the 156 // wrapper is deleted. This wrapper is the concrete type of 157 // XRTCompilationCacheEntryRef returned by Lookup. 158 class EntryRefImpl : public XRTCompilationCacheEntryRef { 159 public: 160 EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry); 161 ~EntryRefImpl() override; 162 163 XRTCompilationCacheEntry get() override; 164 165 private: 166 XRTCompilationCache* parent_; // Not owned. 167 // A reference to entry_ is acquired in the contructor and released via 168 // parent->DiscardEntryRef in the destructor. 169 CompiledSubgraph* entry_; 170 }; 171 172 // Releases one reference to entry. 
This is called by the cache when entry is 173 // marked for eviction; or by an EntryRefImpl when it is destroyed. Before the 174 // last reference to entry is released, entry is removed from cache_. 175 void DiscardEntryRef(CompiledSubgraph* entry); 176 void DiscardEntryRefLocked(CompiledSubgraph* entry) 177 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 178 179 // Marks the oldest unmarked entry for eviction. Requires that there is at 180 // least one such entry. 181 void MarkOldestEntryForEviction() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 182 183 // Updates datastructures to indicate that entry, which had been marked for 184 // eviction, has been looked up. This is called by CompileIfKeyAbsent when an 185 // entry is newly created, or an entry that has been marked for eviction but 186 // not yet evicted is looked up. 187 // 188 // First the entry is unmarked for eviction, i.e. the cache gains a reference 189 // to entry, entry's last_use field is set to be the most recent value of 190 // use_counter_ and entries_by_last_use_ is updated accordingly. 191 // 192 // Next, the size of the cache is examined to see if any other entries need to 193 // be marked for eviction now that entry has been unmarked. While the total 194 // number of unmarked cached entries is greater than max_cache_entries_, 195 // entries are marked for eviction in LRU order. The most recently used entry 196 // is never marked for eviction, so an entry larger than the max cache entries 197 // will remain in the cache until it is replaced by something else. 198 void LookupEntryMarkedForEviction(CompiledSubgraph* entry) 199 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 200 201 // Creates a new entry by running initialize_program and places it in the 202 // cache to be looked up by key. The new entry is in the 'marked for eviction' 203 // state (not present in entries_by_last_use_) and the caller is expected to 204 // call LookupEntryMarkedForEviction after InitializeEntry. 
205 // 206 // **InitializeEntry releases mu_ during the call to initialize_program.** 207 CompiledSubgraph* InitializeEntry( 208 const string& key, 209 const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>& 210 initialize_program) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 211 212 // The maximum number of entries that are stored in the cache before entries 213 // are marked for eviction. 214 const int max_cache_entries_; 215 216 mutable absl::Mutex mu_; 217 // The total number of entries that are stored and not marked for eviction. 218 int cache_entries_ TF_GUARDED_BY(mu_) = 0; 219 // The total number of entries that are marked for eviction. 220 int marked_for_eviction_entries_ TF_GUARDED_BY(mu_) = 0; 221 // The value to assign to the last_use field of the next entry that is looked 222 // up. 223 int64 use_counter_ TF_GUARDED_BY(mu_) = 0; 224 // All the executables that can be looked up in the cache index by key. An 225 // entry is marked for eviction iff it is present in cache_ and not in 226 // entries_by_last_use_. 227 std::unordered_map<string, CompiledSubgraph*> cache_ TF_GUARDED_BY(mu_); 228 // All the executable entries that can be looked up in the cache indexed by 229 // uid. 230 std::unordered_map<int64, CompiledSubgraph*> entries_by_uid_ 231 TF_GUARDED_BY(mu_); 232 // Map from last_use to entry, used to mark entries for eviction in LRU 233 // order. If an entry's last_use counter is not present as a key in 234 // entries_by_last_use_ then the entry has been marked for eviction. 235 std::map<int64, CompiledSubgraph*> entries_by_last_use_ TF_GUARDED_BY(mu_); 236 }; 237 238 // Looks up or create an XRTCompilationCache object within the given resource 239 // manager, under the default container. The max_number_of_entries sets the 240 // maximum number of entries within the cache (which will be LRU-evicted). 
// If max_number_of_entries is set to zero, the size of the cache will be
// configured using the TF_XRT_COMPILATION_CACHE_SIZE environment variable.
// The result is a StatusOr: it holds a RefPtr to the cache on success and a
// non-OK status otherwise.
xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
    ResourceMgr* rm, int64_t max_number_of_entries);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_