• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
17 #define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
18 
19 #include <memory>
20 #include <string>
21 
22 #include "absl/synchronization/mutex.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xrt/xrt_refptr.h"
26 #include "tensorflow/core/framework/resource_mgr.h"
27 #include "tensorflow/core/lib/core/refcount.h"
28 
29 namespace tensorflow {
30 
// Declared here and defined in another translation unit (hence `extern`).
// NOTE(review): presumably the name under which the XRTCompilationCache
// resource is registered in the ResourceMgr -- confirm against the definition
// and the implementation of GetOrCreateCompilationCache.
31 extern const char* kXRTCompilationCacheResourceName;
32 
33 struct XRTCompilationCacheEntry {
XRTCompilationCacheEntryXRTCompilationCacheEntry34   explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable)
35       : executable(executable) {}
36 
37   // Returns a non-owned pointer to an immutable executable.
get_executableXRTCompilationCacheEntry38   xla::LocalExecutable* get_executable() const { return executable; }
39 
40  private:
41   xla::LocalExecutable* executable;
42 };
43 
44 // Base class for a reference to a cached executable. A unique_ptr to an
45 // XRTCompilationCacheEntryRef is returned by the cache Lookup methods below,
46 // and ensures the underlying executable is not garbage-collected until the
47 // client discards the ptr.
48 class XRTCompilationCacheEntryRef {
49  public:
     // Virtual so that a concrete reference can be destroyed through a pointer
     // to this base class (e.g. the unique_ptr handed out by Lookup).
50   virtual ~XRTCompilationCacheEntryRef() = default;
51 
52   // Returns an XRTCompilationCacheEntry that should not be used beyond the
53   // lifetime of the XRTCompilationCacheEntryRef.
     // Pure virtual: the concrete reference type supplies the implementation.
54   virtual XRTCompilationCacheEntry get() = 0;
55 };
56 
57 // Cache for compiled XLA executables.
58 // TODO(b/112646171) rationalize this with the other compilation caches.
59 //
60 // Each key identifies a unique XLA computation, and the value is executable
61 // generated by compiling the computation.
62 //
63 // When a computation is considered for compilation, the client calls
64 //
65 // auto key = <compute key for computation>;
66 // auto compile_function = <lambda to compile computation into executable>;
67 // int64 uid;
68 // CompileIfKeyAbsent(computation_key, &uid, compile_function);
69 //
70 // where computation_key is the key computed for the computation. On success,
71 // uid contains an identifier that can be used to look up the executable. If the
72 // compiled executable were not present in the cache, compile_function would be
73 // called to generate it.
74 //
75 // The caller is responsible for calling Release(uid) once for every
76 // call to CompileIfKeyAbsent(key, ...) to discard the reference to the
77 // compilation results, after the caller is sure it will not look up the
78 // compiled executables again.
79 //
80 // Subsequently the client can call
81 //
82 // std::unique_ptr<XRTCompilationCacheEntryRef> entry;
83 // Lookup(uid, &entry);
84 // auto proto = entry->get();
85 //
86 // to access a cached executable.
87 class XRTCompilationCache : public ResourceBase {
88  public:
89   // There is no way in general to discover the size taken by an XLA executable,
90   // so the cache defaults to a specific number of entries to determine when to
91   // start evicting programs. TODO(b/112592410) change this if the XLA API gets
92   // a mechanism to query size.
93   explicit XRTCompilationCache(int max_number_of_entries);
94   ~XRTCompilationCache() override;
95 
96   // Ensures there is an entry for key present in the cache. By the time
97   // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
98   // for key, and that entry will remain valid at least until Release is called
99   // on the returned uid. The first call to CompileIfKeyAbsent with a key that
100   // is not in the cache will evaluate compile_function to compute the value to
101   // use in the entry. Subsequent calls with the same key will block until
102   // compile_function completes. Other cache reads and inserts may proceed on
103   // other threads while compile_function is executing. The caller is
104   // responsible for calling Release(uid) to manually discard its reference to
105   // the compiled program, once the caller will not look up the compiled program
106   // again.
107   //
108   // compile_function should compile the computation represented by key and fill
109   // the xla::LocalExecutable into its passed argument. It should return OK
110   // if and only if compilation succeeds. The executable will be discarded on
111   // non-OK status.
112   Status CompileIfKeyAbsent(
113       const string& key, int64* uid,
114       const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
115           compile_function);
116 
117   Status Release(int64_t uid);
118 
119   // Looks up an executable corresponding to uid. On success a pointer to an
120   // EntryRef holding the program is returned in entry.
121   Status Lookup(int64_t uid,
122                 std::unique_ptr<XRTCompilationCacheEntryRef>* entry);
123 
124   string DebugString() const override;
125 
126  private:
127   // An entry in the compilation cache. The entry is deleted once it has been
128   // marked for eviction from the cache _and_ all looked-up entries have been
129   // released. When the entry is first created, it is uninitialized and a
130   // client-supplied compilation function is run outside the cache's lock to
131   // generate the program to be stored in the entry. Any other client that
132   // requests the entry will block until it has been initialized. Each entry has
133   // a last_use value that set from a monotonically-increasing counter in the
134   // cache whenever the entry is referenced. When the cache becomes full,
135   // entries are marked for eviction in LRU order.
136   struct CompiledSubgraph : public core::RefCounted {
137     ~CompiledSubgraph() override = default;
138 
139     XRTCompilationCache* parent = nullptr;  // Not owned.
140     bool initialized = false;
141     // The Status returned by the compilation function when the entry is
142     // initialized. This status will be returned to any client that requests the
143     // entry.
144     Status initialization_status;
145     // Counter to keep track of LRU entries for the eviction policy.
146     int64 last_use = -1;
147     // The unique key describing this entry.
148     string key;
149     // The uid describing this entry.
150     int64 uid;
151     // The compiled payload corresponding to the key.
152     std::unique_ptr<xla::LocalExecutable> program;
153   };
154 
155   // Wrapper for a cache entry that holds a reference to the entry until the
156   // wrapper is deleted. This wrapper is the concrete type of
157   // XRTCompilationCacheEntryRef returned by Lookup.
158   class EntryRefImpl : public XRTCompilationCacheEntryRef {
159    public:
160     EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry);
161     ~EntryRefImpl() override;
162 
163     XRTCompilationCacheEntry get() override;
164 
165    private:
166     XRTCompilationCache* parent_;  // Not owned.
167     // A reference to entry_ is acquired in the contructor and released via
168     // parent->DiscardEntryRef in the destructor.
169     CompiledSubgraph* entry_;
170   };
171 
172   // Releases one reference to entry. This is called by the cache when entry is
173   // marked for eviction; or by an EntryRefImpl when it is destroyed. Before the
174   // last reference to entry is released, entry is removed from cache_.
175   void DiscardEntryRef(CompiledSubgraph* entry);
176   void DiscardEntryRefLocked(CompiledSubgraph* entry)
177       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
178 
179   // Marks the oldest unmarked entry for eviction. Requires that there is at
180   // least one such entry.
181   void MarkOldestEntryForEviction() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
182 
183   // Updates datastructures to indicate that entry, which had been marked for
184   // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
185   // entry is newly created, or an entry that has been marked for eviction but
186   // not yet evicted is looked up.
187   //
188   // First the entry is unmarked for eviction, i.e. the cache gains a reference
189   // to entry, entry's last_use field is set to be the most recent value of
190   // use_counter_ and entries_by_last_use_ is updated accordingly.
191   //
192   // Next, the size of the cache is examined to see if any other entries need to
193   // be marked for eviction now that entry has been unmarked. While the total
194   // number of unmarked cached entries is greater than max_cache_entries_,
195   // entries are marked for eviction in LRU order. The most recently used entry
196   // is never marked for eviction, so an entry larger than the max cache entries
197   // will remain in the cache until it is replaced by something else.
198   void LookupEntryMarkedForEviction(CompiledSubgraph* entry)
199       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
200 
201   // Creates a new entry by running initialize_program and places it in the
202   // cache to be looked up by key. The new entry is in the 'marked for eviction'
203   // state (not present in entries_by_last_use_) and the caller is expected to
204   // call LookupEntryMarkedForEviction after InitializeEntry.
205   //
206   // **InitializeEntry releases mu_ during the call to initialize_program.**
207   CompiledSubgraph* InitializeEntry(
208       const string& key,
209       const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
210           initialize_program) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
211 
212   // The maximum number of entries that are stored in the cache before entries
213   // are marked for eviction.
214   const int max_cache_entries_;
215 
216   mutable absl::Mutex mu_;
217   // The total number of entries that are stored and not marked for eviction.
218   int cache_entries_ TF_GUARDED_BY(mu_) = 0;
219   // The total number of entries that are marked for eviction.
220   int marked_for_eviction_entries_ TF_GUARDED_BY(mu_) = 0;
221   // The value to assign to the last_use field of the next entry that is looked
222   // up.
223   int64 use_counter_ TF_GUARDED_BY(mu_) = 0;
224   // All the executables that can be looked up in the cache index by key. An
225   // entry is marked for eviction iff it is present in cache_ and not in
226   // entries_by_last_use_.
227   std::unordered_map<string, CompiledSubgraph*> cache_ TF_GUARDED_BY(mu_);
228   // All the executable entries that can be looked up in the cache indexed by
229   // uid.
230   std::unordered_map<int64, CompiledSubgraph*> entries_by_uid_
231       TF_GUARDED_BY(mu_);
232   // Map from last_use to entry, used to mark entries for eviction in LRU
233   // order. If an entry's last_use counter is not present as a key in
234   // entries_by_last_use_ then the entry has been marked for eviction.
235   std::map<int64, CompiledSubgraph*> entries_by_last_use_ TF_GUARDED_BY(mu_);
236 };
237 
238 // Looks up or creates an XRTCompilationCache object within the given resource
239 // manager, under the default container. The max_number_of_entries sets the
240 // maximum number of entries within the cache (which will be LRU-evicted).
241 // If max_number_of_entries is set to zero, the size of the cache will be
242 // configured using the TF_XRT_COMPILATION_CACHE_SIZE environment variable.
// Returns the cache wrapped in a RefPtr on success, or a non-OK status.
// NOTE(review): max_number_of_entries is int64_t here while the
// XRTCompilationCache constructor takes an int -- confirm how values that do
// not fit in an int are handled by the implementation.
243 xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache(
244     ResourceMgr* rm, int64_t max_number_of_entries);
245 
246 }  // namespace tensorflow
247 
248 #endif  // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
249