1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/inlined_vector.h" 21 #include "absl/types/optional.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 24 #include "tensorflow/compiler/tf2xla/xla_context.h" 25 #include "tensorflow/compiler/xla/client/local_client.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/core/common_runtime/device.h" 28 #include "tensorflow/core/common_runtime/device_mgr.h" 29 #include "tensorflow/core/framework/graph.pb.h" 30 #include "tensorflow/core/framework/op_kernel.h" 31 #include "tensorflow/core/lib/core/threadpool.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 35 namespace tensorflow { 36 37 // The XlaCompilationCache class caches the results of the XlaCompiler class, 38 // which converts a Tensorflow graph into a compiled XLA compilation. 39 // 40 // Since XLA computations must have static shapes, the cache generates a new 41 // XLA computation for each new set of input shapes. 42 // 43 // Currently no cache eviction policy is implemented and the cache grows without 44 // bound. 45 class XlaCompilationCache : public ResourceBase { 46 public: 47 XlaCompilationCache(xla::LocalClient* client, DeviceType device_type); 48 ~XlaCompilationCache() override; 49 50 enum class CompileMode { 51 kLazy, 52 kStrict, 53 }; 54 55 // Compiles a function into a XlaCompiler::CompilationResult that can be used 56 // to execute an XLA Computation. Compilation results are cached. 57 // `function` is the name of a Tensorflow function to compile. 58 // `args` is a description of the arguments to the computation. 59 // 60 // `compile_mode` controls the behavior of the compilation cache on a cache 61 // miss. If `compile_mode` is `kLazy` then, based on some profitability 62 // heuristics, the compilation cache may decide not to compile the cluster at 63 // this time. In this case it returns null into both `out_compilation_result` 64 // and `out_executable`. If `compile_mode` is `kStrict` then the compilation 65 // cache always attempts the compilation on a cache miss. 66 // 67 // The result of compilation is written to `*out_compilation_result`, which 68 // must be non-null. If `out_executable` is non-null, also builds an 69 // xla::LocalExecutable and sets `out_executable` to point to it. The 70 // resulting executable pointer may be null if the computation has no 71 // non-constant outputs. 72 Status Compile(const XlaCompiler::Options& options, 73 const NameAttrList& function, 74 absl::Span<const XlaCompiler::Argument> args, 75 const XlaCompiler::CompileOptions& compile_options, 76 CompileMode compile_mode, 77 const XlaCompiler::CompilationResult** out_compilation_result, 78 xla::LocalExecutable** out_executable); 79 80 // As above, but calls XlaCompiler::CompileSingleOp instead of 81 // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto 82 // in OpKernelContext, then uses MLIR bridge for compilation instead of 83 // XlaCompiler, if possible. 84 Status CompileSingleOp( 85 const XlaCompiler::Options& options, 86 absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx, 87 const XlaCompiler::CompileOptions& compile_options, 88 const XlaCompiler::CompilationResult** out_compilation_result, 89 xla::LocalExecutable** out_executable); 90 client()91 xla::LocalClient* client() const { return client_; } device_type()92 const DeviceType& device_type() const { return device_type_; } 93 94 string DebugString() const override; 95 96 // Describes the types, shapes and any compile-time constant arguments 97 // to a kernel. Key that uniquely identifies a compilation output. 98 struct Signature { 99 string name; 100 101 // List of Tensor types & shapes for compile-time constant arguments to the 102 // compilation, ordered by argument number. 103 absl::InlinedVector<std::pair<DataType, absl::InlinedVector<int64, 4>>, 4> 104 arg_shapes; 105 106 // List of Tensor values for compile-time constant arguments to the 107 // compilation, ordered by argument number. Tensors must be in host memory. 108 absl::InlinedVector<Tensor, 4> arg_values; 109 110 bool operator==(const Signature& other) const; 111 112 struct Hash { 113 uint64 operator()(const Signature& signature) const; 114 }; 115 116 // Returns a human-readable description of the signature. 117 string HumanString() const; 118 }; 119 120 // Builds the signature for a compilation. 121 static xla::StatusOr<Signature> BuildSignature( 122 const NameAttrList& function, 123 absl::Span<const XlaCompiler::Argument> args); 124 125 private: 126 // Common implementation of Compile and CompileSingleOp. 127 Status CompileImpl( 128 const XlaCompiler::Options& options, const NameAttrList& function, 129 absl::Span<const XlaCompiler::Argument> args, 130 const std::function<Status(XlaCompiler* compiler, 131 XlaCompiler::CompilationResult*)>& compile_fn, 132 absl::optional<int64> compile_threshold, 133 const XlaCompiler::CompilationResult** out_compilation_result, 134 xla::LocalExecutable** out_executable); 135 136 // Takes `result` which has been compiled from a Tensorflow subgraph to a 137 // XLA computation already, and generates an XLA LocalExecutable `executable`. 138 Status BuildExecutable(const XlaCompiler::Options& options, 139 const XlaCompiler::CompilationResult& result, 140 std::unique_ptr<xla::LocalExecutable>* executable); 141 142 xla::LocalClient* const client_; 143 const DeviceType device_type_; 144 145 // The value associated with a cache entry. 146 struct Entry { 147 mutex mu; 148 149 // Have we tried compiling this entry? 150 bool compiled = false; 151 152 // The number of times a compilation with this signature has been requested. 153 int64 request_count = 0; 154 155 // Did compilation succeed? 156 Status compilation_status TF_GUARDED_BY(mu); 157 158 // Output of the XlaCompiler. 159 XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu); 160 161 // The XLA executable compiled from <computation>. May be null if no 162 // executable has been built. 163 std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu); 164 }; 165 166 mutex compile_cache_mu_; 167 absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ 168 TF_GUARDED_BY(compile_cache_mu_); 169 170 struct ClusterCompileStats { 171 // Number of times the cluster has been (re-)compiled. 172 int64 compile_count = 0; 173 174 // The number of times this cluster has been executed. 175 int64 execution_count = 0; 176 177 // Cumulative time spent compiling the cluster. 178 int64 cumulative_compile_time_us = 0; 179 180 // True if we have decided that this cluster is too dynamic (i.e. its shapes 181 // change too frequently) to profitably JIT compile. Once a cluster is 182 // tagged megamorphic, it stays megamorphic forever. 183 bool is_megamorphic = false; 184 }; 185 186 mutex cluster_compile_stats_mu_; 187 188 // Maps cluster names to compilation statistics for said cluster. 189 absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_ 190 TF_GUARDED_BY(cluster_compile_stats_mu_); 191 192 // The number of times a lazy compilation must be requested for a specific 193 // signature before we attempt to compile it. 194 static constexpr int64 kDefaultCompilationThreshold = 2; 195 196 TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); 197 }; 198 199 } // namespace tensorflow 200 201 #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 202