• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
17 #define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"
26 
27 namespace tensorflow {
28 
// Name associated with the XRTCompilationCache resource. Defined in the
// corresponding .cc file (not visible here); presumably the name under which
// the cache is registered in a ResourceMgr — confirm against callers.
extern const char* kXRTCompilationCacheResourceName;
30 
31 struct XRTCompilationCacheEntry {
XRTCompilationCacheEntryXRTCompilationCacheEntry32   explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable)
33       : executable(executable) {}
34 
35   // Returns a non-owned pointer to an immutable executable.
get_executableXRTCompilationCacheEntry36   xla::LocalExecutable* get_executable() const { return executable; }
37 
38  private:
39   xla::LocalExecutable* executable;
40 };
41 
42 // Base class for a reference to a cached executable. A unique_ptr to a
43 // XRTCompilationCacheEntryRef is returned by the cache Lookup methods below,
44 // and ensures the underlying executable is not garbage-collected until the
45 // client discards the ptr.
46 class XRTCompilationCacheEntryRef {
47  public:
48   virtual ~XRTCompilationCacheEntryRef() = default;
49 
50   // Returns a XRTCompilationCacheEntry that should not be used beyond the
51   // lifetime of the XRTCompilationCacheEntryRef.
52   virtual XRTCompilationCacheEntry get() = 0;
53 };
54 
55 // Cache for compiled XLA executables.
56 // TODO(b/112646171) rationalize this with the other compilation caches.
57 //
58 // Each key identifies a unique XLA computation, and the value is executable
59 // generated by compiling the computation.
60 //
61 // When a computation is considered for compilation, the client calls
62 //
63 // auto key = <compute key for computation>;
64 // auto compile_function = <lambda to compile computation into executable>;
65 // int64 uid;
66 // CompileIfKeyAbsent(computation_key, &uid, compile_function);
67 //
68 // where computation_key is the key computed for the computation. On success,
69 // uid contains an identifier that can be used to look up the executable. If the
70 // compiled executable were not present in the cache, compile_function would be
71 // called to generate it.
72 //
73 // The caller is responsible for calling Release(uid) once for every
74 // call to CompileIfKeyAbsent(key, ...) to discard the reference to the
75 // compilation results, after the caller is sure it will not look up the
76 // compiled executables again.
77 //
78 // Subsequently the client can call
79 //
80 // std::unique_ptr<XRTCompilationCacheEntryRef> entry;
81 // Lookup(uid, &entry);
82 // auto proto = entry->get();
83 //
84 // to access a cached executable.
85 class XRTCompilationCache : public ResourceBase {
86  public:
87   // There is no way in general to discover the size taken by an XLA executable,
88   // so the cache defaults to a specific number of entries to determine when to
89   // start evicting programs. TODO(b/112592410) change this if the XLA API gets
90   // a mechanism to query size.
91   explicit XRTCompilationCache(int max_number_of_entries);
92   ~XRTCompilationCache() override;
93 
94   // Ensures there is an entry for key present in the cache. By the time
95   // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
96   // for key, and that entry will remain valid at least until Release is called
97   // on the returned uid. The first call to CompileIfKeyAbsent with a key that
98   // is not in the cache will evaluate compile_function to compute the value to
99   // use in the entry. Subsequent calls with the same key will block until
100   // compile_function completes. Other cache reads and inserts may proceed on
101   // other threads while compile_function is executing. The caller is
102   // responsible for calling Release(uid) to manually discard its reference to
103   // the compiled program, once the caller will not look up the compiled program
104   // again.
105   //
106   // compile_function should compile the computation represented by key and fill
107   // the xla::LocalExecutable into its passed argument. It should return OK
108   // if and only if compilation succeeds. The executable will be discarded on
109   // non-OK status.
110   Status CompileIfKeyAbsent(
111       const string& key, int64* uid,
112       const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
113           compile_function);
114 
115   Status Release(int64 uid);
116 
117   // Looks up an executable corresponding to uid. On success a pointer to an
118   // EntryRef holding the program is returned in entry.
119   Status Lookup(int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry);
120 
121   string DebugString() const override;
122 
123  private:
124   // An entry in the compilation cache. The entry is deleted once it has been
125   // marked for eviction from the cache _and_ all looked-up entries have been
126   // released. When the entry is first created, it is uninitialized and a
127   // client-supplied compilation function is run outside the cache's lock to
128   // generate the program to be stored in the entry. Any other client that
129   // requests the entry will block until it has been initialized. Each entry has
130   // a last_use value that set from a monotonically-increasing counter in the
131   // cache whenever the entry is referenced. When the cache becomes full,
132   // entries are marked for eviction in LRU order.
133   struct CompiledSubgraph : public core::RefCounted {
134     ~CompiledSubgraph() override = default;
135 
136     XRTCompilationCache* parent = nullptr;  // Not owned.
137     bool initialized = false;
138     // The Status returned by the compilation function when the entry is
139     // initialized. This status will be returned to any client that requests the
140     // entry.
141     Status initialization_status;
142     // Counter to keep track of LRU entries for the eviction policy.
143     int64 last_use = -1;
144     // The unique key describing this entry.
145     string key;
146     // The uid describing this entry.
147     int64 uid;
148     // The compiled payload corresponding to the key.
149     std::unique_ptr<xla::LocalExecutable> program;
150   };
151 
152   // Wrapper for a cache entry that holds a reference to the entry until the
153   // wrapper is deleted. This wrapper is the concrete type of
154   // XRTCompilationCacheEntryRef returned by Lookup.
155   class EntryRefImpl : public XRTCompilationCacheEntryRef {
156    public:
157     EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry);
158     ~EntryRefImpl() override;
159 
160     XRTCompilationCacheEntry get() override;
161 
162    private:
163     XRTCompilationCache* parent_;  // Not owned.
164     // A reference to entry_ is acquired in the contructor and released via
165     // parent->DiscardEntryRef in the destructor.
166     CompiledSubgraph* entry_;
167   };
168 
169   // Releases one reference to entry. This is called by the cache when entry is
170   // marked for eviction; or by an EntryRefImpl when it is destroyed. Before the
171   // last reference to entry is released, entry is removed from cache_.
172   void DiscardEntryRef(CompiledSubgraph* entry);
173   void DiscardEntryRefLocked(CompiledSubgraph* entry)
174       EXCLUSIVE_LOCKS_REQUIRED(mu_);
175 
176   // Marks the oldest unmarked entry for eviction. Requires that there is at
177   // least one such entry.
178   void MarkOldestEntryForEviction() EXCLUSIVE_LOCKS_REQUIRED(mu_);
179 
180   // Updates datastructures to indicate that entry, which had been marked for
181   // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
182   // entry is newly created, or an entry that has been marked for eviction but
183   // not yet evicted is looked up.
184   //
185   // First the entry is unmarked for eviction, i.e. the cache gains a reference
186   // to entry, entry's last_use field is set to be the most recent value of
187   // use_counter_ and entries_by_last_use_ is updated accordingly.
188   //
189   // Next, the size of the cache is examined to see if any other entries need to
190   // be marked for eviction now that entry has been unmarked. While the total
191   // number of unmarked cached entries is greater than max_cache_entries_,
192   // entries are marked for eviction in LRU order. The most recently used entry
193   // is never marked for eviction, so an entry larger than the max cache entries
194   // will remain in the cache until it is replaced by something else.
195   void LookupEntryMarkedForEviction(CompiledSubgraph* entry)
196       EXCLUSIVE_LOCKS_REQUIRED(mu_);
197 
198   // Creates a new entry by running initialize_program and places it in the
199   // cache to be looked up by key. The new entry is in the 'marked for eviction'
200   // state (not present in entries_by_last_use_) and the caller is expected to
201   // call LookupEntryMarkedForEviction after InitializeEntry.
202   //
203   // **InitializeEntry releases mu_ during the call to initialize_program.**
204   CompiledSubgraph* InitializeEntry(
205       const string& key,
206       const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
207           initialize_program) EXCLUSIVE_LOCKS_REQUIRED(mu_);
208 
209   // The maximum number of entries that are stored in the cache before entries
210   // are marked for eviction.
211   const int max_cache_entries_;
212 
213   mutable absl::Mutex mu_;
214   // The total number of entries that are stored and not marked for eviction.
215   int cache_entries_ GUARDED_BY(mu_) = 0;
216   // The total number of entries that are marked for eviction.
217   int marked_for_eviction_entries_ GUARDED_BY(mu_) = 0;
218   // The value to assign to the last_use field of the next entry that is looked
219   // up.
220   int64 use_counter_ GUARDED_BY(mu_) = 0;
221   // All the executables that can be looked up in the cache index by key. An
222   // entry is marked for eviction iff it is present in cache_ and not in
223   // entries_by_last_use_.
224   std::unordered_map<string, CompiledSubgraph*> cache_ GUARDED_BY(mu_);
225   // All the executable entries that can be looked up in the cache indexed by
226   // uid.
227   std::unordered_map<int64, CompiledSubgraph*> entries_by_uid_ GUARDED_BY(mu_);
228   // Map from last_use to entry, used to mark entries for eviction in LRU
229   // order. If an entry's last_use counter is not present as a key in
230   // entries_by_last_use_ then the entry has been marked for eviction.
231   std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_);
232 };
233 
234 }  // namespace tensorflow
235 
236 #endif  // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
237