1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/inlined_vector.h" 21 #include "absl/types/optional.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 24 #include "tensorflow/compiler/tf2xla/xla_context.h" 25 #include "tensorflow/compiler/xla/client/local_client.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/core/common_runtime/device.h" 28 #include "tensorflow/core/common_runtime/device_mgr.h" 29 #include "tensorflow/core/framework/graph.pb.h" 30 #include "tensorflow/core/framework/op_kernel.h" 31 #include "tensorflow/core/lib/core/threadpool.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 35 namespace tensorflow { 36 37 // The XlaCompilationCache class caches the results of the XlaCompiler class, 38 // which converts a Tensorflow graph into a compiled XLA compilation. 39 // 40 // Since XLA computations must have static shapes, the cache generates a new 41 // XLA computation for each new set of input shapes. 42 // 43 // Currently no cache eviction policy is implemented and the cache grows without 44 // bound. 45 class XlaCompilationCache : public ResourceBase { 46 public: 47 XlaCompilationCache(xla::LocalClient* client, DeviceType device_type); 48 ~XlaCompilationCache() override; 49 50 enum class CompileMode { 51 kLazy, 52 kStrict, 53 kAsync, 54 }; 55 56 enum class CompileState { 57 kUncompiled, 58 kCompiling, 59 kCompiled, 60 }; 61 62 enum class CompileScope { 63 kOp, 64 kFunction, 65 }; 66 67 // Compiles a function into a XlaCompiler::CompilationResult that can be used 68 // to execute an XLA Computation. Compilation results are cached. 69 // `function` is the name of a Tensorflow function to compile. 70 // `args` is a description of the arguments to the computation. 71 // 72 // `compile_mode` controls the behavior of the compilation cache on a cache 73 // miss. If `compile_mode` is `kLazy` then, based on some profitability 74 // heuristics, the compilation cache may decide not to compile the cluster at 75 // this time. In this case it returns null into both `out_compilation_result` 76 // and `out_executable`. If `compile_mode` is `kStrict` then the compilation 77 // cache always attempts the compilation on a cache miss. If compilation mode 78 // is 'kAsync' compilation of the cluster happens in the background while the 79 // fallback path executes. 80 // 81 // The result of compilation is written to `*out_compilation_result`, which 82 // must be non-null. If `out_executable` is non-null, also builds an 83 // xla::LocalExecutable and sets `out_executable` to point to it. The 84 // resulting executable pointer may be null if the computation has no 85 // non-constant outputs. 86 Status Compile(const XlaCompiler::Options& options, 87 const NameAttrList& function, 88 const std::vector<XlaCompiler::Argument>& args, 89 const XlaCompiler::CompileOptions& compile_options, 90 CompileMode compile_mode, 91 const XlaCompiler::CompilationResult** out_compilation_result, 92 xla::LocalExecutable** out_executable); 93 94 // As above, but calls XlaCompiler::CompileSingleOp instead of 95 // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto 96 // in OpKernelContext, then uses MLIR bridge for compilation instead of 97 // XlaCompiler, if possible. 98 Status CompileSingleOp( 99 const XlaCompiler::Options& options, 100 const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx, 101 const XlaCompiler::CompileOptions& compile_options, 102 const XlaCompiler::CompilationResult** out_compilation_result, 103 xla::LocalExecutable** out_executable); 104 client()105 xla::LocalClient* client() const { return client_; } device_type()106 const DeviceType& device_type() const { return device_type_; } 107 108 string DebugString() const override; 109 110 // Describes the types, shapes and any compile-time constant arguments 111 // to a kernel. Key that uniquely identifies a compilation output. 112 struct Signature { 113 string name; 114 115 // List of Tensor types & shapes for compile-time constant arguments to the 116 // compilation, ordered by argument number. 117 absl::InlinedVector<std::pair<DataType, absl::InlinedVector<int64, 4>>, 4> 118 arg_shapes; 119 120 // List of Tensor values for compile-time constant arguments to the 121 // compilation, ordered by argument number. Tensors must be in host memory. 122 absl::InlinedVector<Tensor, 4> arg_values; 123 124 bool operator==(const Signature& other) const; 125 126 struct Hash { 127 uint64 operator()(const Signature& signature) const; 128 }; 129 130 // Returns a human-readable description of the signature. 131 string HumanString() const; 132 }; 133 134 // Builds the signature for a compilation. 135 static StatusOr<Signature> BuildSignature( 136 const NameAttrList& function, 137 absl::Span<const XlaCompiler::Argument> args); 138 139 private: 140 // Common implementation of Compile and CompileSingleOp. The `OpKernelContext` 141 // parameter is always null for the former. 142 Status CompileImpl( 143 const XlaCompiler::CompileOptions& compile_options, 144 const XlaCompiler::Options& options, const NameAttrList& function, 145 const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx, 146 CompileScope scope, CompileMode compile_mode, 147 const XlaCompiler::CompilationResult** out_compilation_result, 148 xla::LocalExecutable** out_executable); 149 150 // Takes `result` which has been compiled from a Tensorflow subgraph to a 151 // XLA computation already, and generates an XLA LocalExecutable `executable`. 152 Status BuildExecutable(const XlaCompiler::Options& options, 153 const XlaCompiler::CompilationResult& result, 154 std::unique_ptr<xla::LocalExecutable>* executable); 155 156 // Determines whether the cluster should be compiled. 157 bool ShouldCompileCluster(CompileMode compile_mode, bool is_megamorphic, 158 bool is_first_execution, 159 int64_t current_request_count, 160 const NameAttrList& function); 161 162 xla::LocalClient* const client_; 163 const DeviceType device_type_; 164 165 // The value associated with a cache entry. 166 struct Entry { 167 mutex mu; 168 169 // The current compilation state for this entry. 170 CompileState compile_state = CompileState::kUncompiled; 171 172 // The number of times a compilation with this signature has been requested. 173 int64 request_count = 0; 174 175 // Did compilation succeed? 176 Status compilation_status TF_GUARDED_BY(mu); 177 178 // Output of the XlaCompiler. 179 XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu); 180 181 // The XLA executable compiled from <computation>. May be null if no 182 // executable has been built. 183 std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu); 184 }; 185 186 Status CompileStrict(Entry* entry, 187 const XlaCompiler::CompileOptions& compile_options, 188 const XlaCompiler::Options& options, 189 const std::vector<XlaCompiler::Argument>& args, 190 const NameAttrList& function, OpKernelContext* ctx, 191 CompileScope scope) 192 TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu); 193 Status CompileAsynchronous(Entry* entry, 194 const XlaCompiler::CompileOptions& compile_options, 195 const XlaCompiler::Options& options, 196 const std::vector<XlaCompiler::Argument>& args, 197 const NameAttrList& function, OpKernelContext* ctx, 198 CompileScope scope); 199 200 mutex compile_cache_mu_; 201 absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ 202 TF_GUARDED_BY(compile_cache_mu_); 203 204 struct ClusterCompileStats { 205 // Number of times the cluster has been (re-)compiled. 206 int64 compile_count = 0; 207 208 // The number of times this cluster has been executed. 209 int64 execution_count = 0; 210 211 // Cumulative time spent compiling the cluster. 212 int64 cumulative_compile_time_us = 0; 213 214 // True if we have decided that this cluster is too dynamic (i.e. its shapes 215 // change too frequently) to profitably JIT compile. Once a cluster is 216 // tagged megamorphic, it stays megamorphic forever. 217 bool is_megamorphic = false; 218 }; 219 220 mutex cluster_compile_stats_mu_; 221 222 // Maps cluster names to compilation statistics for said cluster. 223 absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_ 224 TF_GUARDED_BY(cluster_compile_stats_mu_); 225 226 struct AsyncCompilationState { 227 mutex async_compilation_state_mu; 228 229 // Number of threads for asynchronous compilations. 230 static constexpr int64_t kNumCompilerThreads = 10; 231 232 // Maximum number of ongoing compilations. 233 static constexpr int64_t kMaxNumOngoingCompilations = kNumCompilerThreads; 234 235 // Number of ongoing compilations. 236 int64 num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) = 237 0; 238 239 // Pool of threads for asynchronous compilations. 240 std::unique_ptr<thread::ThreadPool> compiler_threads; 241 AsyncCompilationStateAsyncCompilationState242 AsyncCompilationState() { 243 compiler_threads = absl::make_unique<tensorflow::thread::ThreadPool>( 244 tensorflow::Env::Default(), "async_compiler_threads", 245 kNumCompilerThreads); 246 } 247 248 } async_compilation_state_; 249 250 // The number of times a lazy compilation must be requested for a specific 251 // signature before we attempt to compile it. 252 static constexpr int64_t kDefaultCompilationThreshold = 2; 253 254 TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); 255 }; 256 257 // Creates a single-node graph using the specified node_def as the only op apart 258 // from the arg and retval nodes. 259 StatusOr<std::unique_ptr<Graph>> CreateGraph( 260 const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args, 261 absl::Span<const DataType> result_types); 262 263 // Use XlaCompiler to compile a single op into HLO. 264 Status XlaSingleOpToHlo(XlaCompiler* compiler, 265 const XlaCompiler::Options& options, 266 const std::vector<XlaCompiler::Argument>& args, 267 OpKernelContext* ctx, 268 const XlaCompiler::CompileOptions& compile_options, 269 XlaCompiler::CompilationResult* compilation_result); 270 271 } // namespace tensorflow 272 273 #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 274