1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/types/optional.h" 21 #include "absl/types/span.h" 22 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 23 #include "tensorflow/compiler/tf2xla/xla_context.h" 24 #include "tensorflow/compiler/xla/client/local_client.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/core/common_runtime/device.h" 27 #include "tensorflow/core/common_runtime/device_mgr.h" 28 #include "tensorflow/core/framework/graph.pb.h" 29 #include "tensorflow/core/framework/op_kernel.h" 30 #include "tensorflow/core/lib/core/threadpool.h" 31 #include "tensorflow/core/platform/mutex.h" 32 #include "tensorflow/core/platform/thread_annotations.h" 33 34 namespace tensorflow { 35 36 // The XlaCompilationCache class caches the results of the XlaCompiler class, 37 // which converts a Tensorflow graph into a compiled XLA compilation. 38 // 39 // Since XLA computations must have static shapes, the cache generates a new 40 // XLA computation for each new set of input shapes. 41 // 42 // Currently no cache eviction policy is implemented and the cache grows without 43 // bound. 44 class XlaCompilationCache : public ResourceBase { 45 public: 46 XlaCompilationCache(xla::LocalClient* client, DeviceType device_type); 47 ~XlaCompilationCache() override; 48 49 enum class CompileMode { 50 kLazy, 51 kStrict, 52 }; 53 54 // Compiles a function into a XlaCompiler::CompilationResult that can be used 55 // to execute an XLA Computation. Compilation results are cached. 56 // `function` is the name of a Tensorflow function to compile. 57 // `args` is a description of the arguments to the computation. 58 // 59 // `compile_mode` controls the behavior of the compilation cache on a cache 60 // miss. If `compile_mode` is `kLazy` then, based on some profitability 61 // heuristics, the compilation cache may decide not to compile the cluster at 62 // this time. In this case it returns null into both `out_compilation_result` 63 // and `out_executable`. If `compile_mode` is `kStrict` then the compilation 64 // cache always attempts the compilation on a cache miss. 65 // 66 // The result of compilation is written to `*compilation_result`, which must 67 // be non-null. If `executable` is non-null, also builds an 68 // xla::LocalExecutable and sets `executable` to point to it. The resulting 69 // executable pointer may be null if the computation has no non-constant 70 // outputs. 71 Status Compile(const XlaCompiler::Options& options, 72 const NameAttrList& function, 73 absl::Span<const XlaCompiler::Argument> args, 74 const XlaCompiler::CompileOptions& compile_options, 75 CompileMode compile_mode, 76 const XlaCompiler::CompilationResult** out_compilation_result, 77 xla::LocalExecutable** out_executable); 78 79 // As above, but calls XlaCompiler::CompileSingleOp instead of 80 // XlaCompiler::CompileFunction. 81 Status CompileSingleOp( 82 const XlaCompiler::Options& options, 83 absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx, 84 const XlaCompiler::CompileOptions& compile_options, 85 const XlaCompiler::CompilationResult** out_compilation_result, 86 xla::LocalExecutable** out_executable); 87 client()88 xla::LocalClient* client() const { return client_; } device_type()89 const DeviceType& device_type() const { return device_type_; } 90 91 string DebugString() const override; 92 93 // Describes the types, shapes and any compile-time constant arguments 94 // to a kernel. Key that uniquely identifies a compilation output. 95 struct Signature { 96 string name; 97 98 // List of Tensor types & shapes for compile-time constant arguments to the 99 // compilation, ordered by argument number. 100 std::vector<std::pair<DataType, std::vector<int64>>> arg_shapes; 101 102 // List of Tensor values for compile-time constant arguments to the 103 // compilation, ordered by argument number. Tensors must be in host memory. 104 std::vector<Tensor> arg_values; 105 106 bool operator==(const Signature& other) const; 107 108 struct Hash { 109 uint64 operator()(const Signature& signature) const; 110 }; 111 112 // Returns a human-readable description of the signature. 113 string HumanString() const; 114 }; 115 116 // Builds the signature for a compilation. 117 static xla::StatusOr<Signature> BuildSignature( 118 const NameAttrList& function, 119 absl::Span<const XlaCompiler::Argument> args); 120 121 private: 122 // Common implementation of Compile and CompileSingleOp. 123 Status CompileImpl( 124 const XlaCompiler::Options& options, const NameAttrList& function, 125 absl::Span<const XlaCompiler::Argument> args, 126 const std::function<Status(XlaCompiler* compiler, 127 XlaCompiler::CompilationResult*)>& compile_fn, 128 absl::optional<int64> compile_threshold, 129 const XlaCompiler::CompilationResult** out_compilation_result, 130 xla::LocalExecutable** out_executable); 131 132 // Takes `result` which has been compiled from a Tensorflow subgraph to a 133 // XLA computation already, and generates an XLA LocalExecutable `executable`. 134 Status BuildExecutable(const XlaCompiler::Options& options, 135 const XlaCompiler::CompilationResult& result, 136 std::unique_ptr<xla::LocalExecutable>* executable); 137 138 xla::LocalClient* const client_; 139 const DeviceType device_type_; 140 141 // The value associated with a cache entry. 142 struct Entry { 143 mutex mu; 144 145 // Have we tried compiling this entry? 146 bool compiled = false; 147 148 // The number of times a compilation with this signature has been requested. 149 int64 request_count = 0; 150 151 // Did compilation succeed? 152 Status compilation_status GUARDED_BY(mu); 153 154 // Output of the XlaCompiler. 155 XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu); 156 157 // The XLA executable compiled from <computation>. May be null if no 158 // executable has been built. 159 std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu); 160 }; 161 162 mutex compile_cache_mu_; 163 absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ 164 GUARDED_BY(compile_cache_mu_); 165 166 struct ClusterCompileStats { 167 // Number of times the cluster has been (re-)compiled. 168 int64 compile_count = 0; 169 170 // The number of times this cluster has been executed. 171 int64 execution_count = 0; 172 173 // Cumulative time spent compiling the cluster. 174 int64 cumulative_compile_time_us = 0; 175 176 // True if we have decided that this cluster is too dynamic (i.e. its shapes 177 // change too frequently) to profitably JIT compile. Once a cluster is 178 // tagged megamorphic, it stays megamorphic forever. 179 bool is_megamorphic = false; 180 }; 181 182 mutex cluster_compile_stats_mu_; 183 184 // Maps cluster names to compilation statistics for said cluster. 185 absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_ 186 GUARDED_BY(cluster_compile_stats_mu_); 187 188 // The number of times a lazy compilation must be requested for a specific 189 // signature before we attempt to compile it. 190 static constexpr int64 kDefaultCompilationThreshold = 2; 191 192 TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); 193 }; 194 195 } // namespace tensorflow 196 197 #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 198