• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/inlined_vector.h"
21 #include "absl/types/optional.h"
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
24 #include "tensorflow/compiler/tf2xla/xla_context.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 
35 namespace tensorflow {
36 
37 // The XlaCompilationCache class caches the results of the XlaCompiler class,
38 // which converts a Tensorflow graph into a compiled XLA compilation.
39 //
40 // Since XLA computations must have static shapes, the cache generates a new
41 // XLA computation for each new set of input shapes.
42 //
43 // Currently no cache eviction policy is implemented and the cache grows without
44 // bound.
45 class XlaCompilationCache : public ResourceBase {
46  public:
47   XlaCompilationCache(xla::LocalClient* client, DeviceType device_type);
48   ~XlaCompilationCache() override;
49 
50   enum class CompileMode {
51     kLazy,
52     kStrict,
53     kAsync,
54   };
55 
56   enum class CompileState {
57     kUncompiled,
58     kCompiling,
59     kCompiled,
60   };
61 
62   enum class CompileScope {
63     kOp,
64     kFunction,
65   };
66 
67   // Compiles a function into a XlaCompiler::CompilationResult that can be used
68   // to execute an XLA Computation. Compilation results are cached.
69   // `function` is the name of a Tensorflow function to compile.
70   // `args` is a description of the arguments to the computation.
71   //
72   // `compile_mode` controls the behavior of the compilation cache on a cache
73   // miss.  If `compile_mode` is `kLazy` then, based on some profitability
74   // heuristics, the compilation cache may decide not to compile the cluster at
75   // this time.  In this case it returns null into both `out_compilation_result`
76   // and `out_executable`.  If `compile_mode` is `kStrict` then the compilation
77   // cache always attempts the compilation on a cache miss. If compilation mode
78   // is 'kAsync' compilation of the cluster happens in the background while the
79   // fallback path executes.
80   //
81   // The result of compilation is written to `*out_compilation_result`, which
82   // must be non-null. If `out_executable` is non-null, also builds an
83   // xla::LocalExecutable and sets `out_executable` to point to it. The
84   // resulting executable pointer may be null if the computation has no
85   // non-constant outputs.
86   Status Compile(const XlaCompiler::Options& options,
87                  const NameAttrList& function,
88                  const std::vector<XlaCompiler::Argument>& args,
89                  const XlaCompiler::CompileOptions& compile_options,
90                  CompileMode compile_mode,
91                  const XlaCompiler::CompilationResult** out_compilation_result,
92                  xla::LocalExecutable** out_executable);
93 
94   // As above, but calls XlaCompiler::CompileSingleOp instead of
95   // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto
96   // in OpKernelContext, then uses MLIR bridge for compilation instead of
97   // XlaCompiler, if possible.
98   Status CompileSingleOp(
99       const XlaCompiler::Options& options,
100       const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
101       const XlaCompiler::CompileOptions& compile_options,
102       const XlaCompiler::CompilationResult** out_compilation_result,
103       xla::LocalExecutable** out_executable);
104 
client()105   xla::LocalClient* client() const { return client_; }
device_type()106   const DeviceType& device_type() const { return device_type_; }
107 
108   string DebugString() const override;
109 
110   // Describes the types, shapes and any compile-time constant arguments
111   // to a kernel. Key that uniquely identifies a compilation output.
112   struct Signature {
113     string name;
114 
115     // List of Tensor types & shapes for compile-time constant arguments to the
116     // compilation, ordered by argument number.
117     absl::InlinedVector<std::pair<DataType, absl::InlinedVector<int64, 4>>, 4>
118         arg_shapes;
119 
120     // List of Tensor values for compile-time constant arguments to the
121     // compilation, ordered by argument number. Tensors must be in host memory.
122     absl::InlinedVector<Tensor, 4> arg_values;
123 
124     bool operator==(const Signature& other) const;
125 
126     struct Hash {
127       uint64 operator()(const Signature& signature) const;
128     };
129 
130     // Returns a human-readable description of the signature.
131     string HumanString() const;
132   };
133 
134   // Builds the signature for a compilation.
135   static StatusOr<Signature> BuildSignature(
136       const NameAttrList& function,
137       absl::Span<const XlaCompiler::Argument> args);
138 
139  private:
140   // Common implementation of Compile and CompileSingleOp. The `OpKernelContext`
141   // parameter is always null for the former.
142   Status CompileImpl(
143       const XlaCompiler::CompileOptions& compile_options,
144       const XlaCompiler::Options& options, const NameAttrList& function,
145       const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
146       CompileScope scope, CompileMode compile_mode,
147       const XlaCompiler::CompilationResult** out_compilation_result,
148       xla::LocalExecutable** out_executable);
149 
150   // Takes `result` which has been compiled from a Tensorflow subgraph to a
151   // XLA computation already, and generates an XLA LocalExecutable `executable`.
152   Status BuildExecutable(const XlaCompiler::Options& options,
153                          const XlaCompiler::CompilationResult& result,
154                          std::unique_ptr<xla::LocalExecutable>* executable);
155 
156   // Determines whether the cluster should be compiled.
157   bool ShouldCompileCluster(CompileMode compile_mode, bool is_megamorphic,
158                             bool is_first_execution,
159                             int64_t current_request_count,
160                             const NameAttrList& function);
161 
162   xla::LocalClient* const client_;
163   const DeviceType device_type_;
164 
165   // The value associated with a cache entry.
166   struct Entry {
167     mutex mu;
168 
169     // The current compilation state for this entry.
170     CompileState compile_state = CompileState::kUncompiled;
171 
172     // The number of times a compilation with this signature has been requested.
173     int64 request_count = 0;
174 
175     // Did compilation succeed?
176     Status compilation_status TF_GUARDED_BY(mu);
177 
178     // Output of the XlaCompiler.
179     XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu);
180 
181     // The XLA executable compiled from <computation>. May be null if no
182     // executable has been built.
183     std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
184   };
185 
186   Status CompileStrict(Entry* entry,
187                        const XlaCompiler::CompileOptions& compile_options,
188                        const XlaCompiler::Options& options,
189                        const std::vector<XlaCompiler::Argument>& args,
190                        const NameAttrList& function, OpKernelContext* ctx,
191                        CompileScope scope)
192       TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu);
193   Status CompileAsynchronous(Entry* entry,
194                              const XlaCompiler::CompileOptions& compile_options,
195                              const XlaCompiler::Options& options,
196                              const std::vector<XlaCompiler::Argument>& args,
197                              const NameAttrList& function, OpKernelContext* ctx,
198                              CompileScope scope);
199 
200   mutex compile_cache_mu_;
201   absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
202       TF_GUARDED_BY(compile_cache_mu_);
203 
204   struct ClusterCompileStats {
205     // Number of times the cluster has been (re-)compiled.
206     int64 compile_count = 0;
207 
208     // The number of times this cluster has been executed.
209     int64 execution_count = 0;
210 
211     // Cumulative time spent compiling the cluster.
212     int64 cumulative_compile_time_us = 0;
213 
214     // True if we have decided that this cluster is too dynamic (i.e. its shapes
215     // change too frequently) to profitably JIT compile.  Once a cluster is
216     // tagged megamorphic, it stays megamorphic forever.
217     bool is_megamorphic = false;
218   };
219 
220   mutex cluster_compile_stats_mu_;
221 
222   // Maps cluster names to compilation statistics for said cluster.
223   absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
224       TF_GUARDED_BY(cluster_compile_stats_mu_);
225 
226   struct AsyncCompilationState {
227     mutex async_compilation_state_mu;
228 
229     // Number of threads for asynchronous compilations.
230     static constexpr int64_t kNumCompilerThreads = 10;
231 
232     // Maximum number of ongoing compilations.
233     static constexpr int64_t kMaxNumOngoingCompilations = kNumCompilerThreads;
234 
235     // Number of ongoing compilations.
236     int64 num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) =
237         0;
238 
239     // Pool of threads for asynchronous compilations.
240     std::unique_ptr<thread::ThreadPool> compiler_threads;
241 
AsyncCompilationStateAsyncCompilationState242     AsyncCompilationState() {
243       compiler_threads = absl::make_unique<tensorflow::thread::ThreadPool>(
244           tensorflow::Env::Default(), "async_compiler_threads",
245           kNumCompilerThreads);
246     }
247 
248   } async_compilation_state_;
249 
250   // The number of times a lazy compilation must be requested for a specific
251   // signature before  we attempt to compile it.
252   static constexpr int64_t kDefaultCompilationThreshold = 2;
253 
254   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
255 };
256 
257 // Creates a single-node graph using the specified node_def as the only op apart
258 // from the arg and retval nodes.
259 StatusOr<std::unique_ptr<Graph>> CreateGraph(
260     const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
261     absl::Span<const DataType> result_types);
262 
263 // Use XlaCompiler to compile a single op into HLO.
264 Status XlaSingleOpToHlo(XlaCompiler* compiler,
265                         const XlaCompiler::Options& options,
266                         const std::vector<XlaCompiler::Argument>& args,
267                         OpKernelContext* ctx,
268                         const XlaCompiler::CompileOptions& compile_options,
269                         XlaCompiler::CompilationResult* compilation_result);
270 
271 }  // namespace tensorflow
272 
273 #endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
274