/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/types/optional.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
23 #include "tensorflow/compiler/tf2xla/xla_context.h"
24 #include "tensorflow/compiler/xla/client/local_client.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/core/common_runtime/device.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/thread_annotations.h"
33 
34 namespace tensorflow {
35 
36 // The XlaCompilationCache class caches the results of the XlaCompiler class,
37 // which converts a Tensorflow graph into a compiled XLA compilation.
38 //
39 // Since XLA computations must have static shapes, the cache generates a new
40 // XLA computation for each new set of input shapes.
41 //
42 // Currently no cache eviction policy is implemented and the cache grows without
43 // bound.
44 class XlaCompilationCache : public ResourceBase {
45  public:
46   XlaCompilationCache(xla::LocalClient* client, DeviceType device_type);
47   ~XlaCompilationCache() override;
48 
49   enum class CompileMode {
50     kLazy,
51     kStrict,
52   };
53 
54   // Compiles a function into a XlaCompiler::CompilationResult that can be used
55   // to execute an XLA Computation. Compilation results are cached.
56   // `function` is the name of a Tensorflow function to compile.
57   // `args` is a description of the arguments to the computation.
58   //
59   // `compile_mode` controls the behavior of the compilation cache on a cache
60   // miss.  If `compile_mode` is `kLazy` then, based on some profitability
61   // heuristics, the compilation cache may decide not to compile the cluster at
62   // this time.  In this case it returns null into both `out_compilation_result`
63   // and `out_executable`.  If `compile_mode` is `kStrict` then the compilation
64   // cache always attempts the compilation on a cache miss.
65   //
66   // The result of compilation is written to `*compilation_result`, which must
67   // be non-null. If `executable` is non-null, also builds an
68   // xla::LocalExecutable and sets `executable` to point to it. The resulting
69   // executable pointer may be null if the computation has no non-constant
70   // outputs.
71   Status Compile(const XlaCompiler::Options& options,
72                  const NameAttrList& function,
73                  absl::Span<const XlaCompiler::Argument> args,
74                  const XlaCompiler::CompileOptions& compile_options,
75                  CompileMode compile_mode,
76                  const XlaCompiler::CompilationResult** out_compilation_result,
77                  xla::LocalExecutable** out_executable);
78 
79   // As above, but calls XlaCompiler::CompileSingleOp instead of
80   // XlaCompiler::CompileFunction.
81   Status CompileSingleOp(
82       const XlaCompiler::Options& options,
83       absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
84       const XlaCompiler::CompileOptions& compile_options,
85       const XlaCompiler::CompilationResult** out_compilation_result,
86       xla::LocalExecutable** out_executable);
87 
client()88   xla::LocalClient* client() const { return client_; }
device_type()89   const DeviceType& device_type() const { return device_type_; }
90 
91   string DebugString() const override;
92 
93   // Describes the types, shapes and any compile-time constant arguments
94   // to a kernel. Key that uniquely identifies a compilation output.
95   struct Signature {
96     string name;
97 
98     // List of Tensor types & shapes for compile-time constant arguments to the
99     // compilation, ordered by argument number.
100     std::vector<std::pair<DataType, std::vector<int64>>> arg_shapes;
101 
102     // List of Tensor values for compile-time constant arguments to the
103     // compilation, ordered by argument number. Tensors must be in host memory.
104     std::vector<Tensor> arg_values;
105 
106     bool operator==(const Signature& other) const;
107 
108     struct Hash {
109       uint64 operator()(const Signature& signature) const;
110     };
111 
112     // Returns a human-readable description of the signature.
113     string HumanString() const;
114   };
115 
116   // Builds the signature for a compilation.
117   static xla::StatusOr<Signature> BuildSignature(
118       const NameAttrList& function,
119       absl::Span<const XlaCompiler::Argument> args);
120 
121  private:
122   // Common implementation of Compile and CompileSingleOp.
123   Status CompileImpl(
124       const XlaCompiler::Options& options, const NameAttrList& function,
125       absl::Span<const XlaCompiler::Argument> args,
126       const std::function<Status(XlaCompiler* compiler,
127                                  XlaCompiler::CompilationResult*)>& compile_fn,
128       absl::optional<int64> compile_threshold,
129       const XlaCompiler::CompilationResult** out_compilation_result,
130       xla::LocalExecutable** out_executable);
131 
132   // Takes `result` which has been compiled from a Tensorflow subgraph to a
133   // XLA computation already, and generates an XLA LocalExecutable `executable`.
134   Status BuildExecutable(const XlaCompiler::Options& options,
135                          const XlaCompiler::CompilationResult& result,
136                          std::unique_ptr<xla::LocalExecutable>* executable);
137 
138   xla::LocalClient* const client_;
139   const DeviceType device_type_;
140 
141   // The value associated with a cache entry.
142   struct Entry {
143     mutex mu;
144 
145     // Have we tried compiling this entry?
146     bool compiled = false;
147 
148     // The number of times a compilation with this signature has been requested.
149     int64 request_count = 0;
150 
151     // Did compilation succeed?
152     Status compilation_status GUARDED_BY(mu);
153 
154     // Output of the XlaCompiler.
155     XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
156 
157     // The XLA executable compiled from <computation>. May be null if no
158     // executable has been built.
159     std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
160   };
161 
162   mutex compile_cache_mu_;
163   absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
164       GUARDED_BY(compile_cache_mu_);
165 
166   struct ClusterCompileStats {
167     // Number of times the cluster has been (re-)compiled.
168     int64 compile_count = 0;
169 
170     // The number of times this cluster has been executed.
171     int64 execution_count = 0;
172 
173     // Cumulative time spent compiling the cluster.
174     int64 cumulative_compile_time_us = 0;
175 
176     // True if we have decided that this cluster is too dynamic (i.e. its shapes
177     // change too frequently) to profitably JIT compile.  Once a cluster is
178     // tagged megamorphic, it stays megamorphic forever.
179     bool is_megamorphic = false;
180   };
181 
182   mutex cluster_compile_stats_mu_;
183 
184   // Maps cluster names to compilation statistics for said cluster.
185   absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
186       GUARDED_BY(cluster_compile_stats_mu_);
187 
188   // The number of times a lazy compilation must be requested for a specific
189   // signature before  we attempt to compile it.
190   static constexpr int64 kDefaultCompilationThreshold = 2;
191 
192   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
193 };
194 
195 }  // namespace tensorflow
196 
197 #endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
198