• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/inlined_vector.h"
21 #include "absl/types/optional.h"
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
24 #include "tensorflow/compiler/tf2xla/xla_context.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 
35 namespace tensorflow {
36 
37 // The XlaCompilationCache class caches the results of the XlaCompiler class,
38 // which converts a Tensorflow graph into a compiled XLA compilation.
39 //
40 // Since XLA computations must have static shapes, the cache generates a new
41 // XLA computation for each new set of input shapes.
42 //
43 // Currently no cache eviction policy is implemented and the cache grows without
44 // bound.
45 class XlaCompilationCache : public ResourceBase {
46  public:
47   XlaCompilationCache(xla::LocalClient* client, DeviceType device_type);
48   ~XlaCompilationCache() override;
49 
50   enum class CompileMode {
51     kLazy,
52     kStrict,
53   };
54 
55   // Compiles a function into a XlaCompiler::CompilationResult that can be used
56   // to execute an XLA Computation. Compilation results are cached.
57   // `function` is the name of a Tensorflow function to compile.
58   // `args` is a description of the arguments to the computation.
59   //
60   // `compile_mode` controls the behavior of the compilation cache on a cache
61   // miss.  If `compile_mode` is `kLazy` then, based on some profitability
62   // heuristics, the compilation cache may decide not to compile the cluster at
63   // this time.  In this case it returns null into both `out_compilation_result`
64   // and `out_executable`.  If `compile_mode` is `kStrict` then the compilation
65   // cache always attempts the compilation on a cache miss.
66   //
67   // The result of compilation is written to `*out_compilation_result`, which
68   // must be non-null. If `out_executable` is non-null, also builds an
69   // xla::LocalExecutable and sets `out_executable` to point to it. The
70   // resulting executable pointer may be null if the computation has no
71   // non-constant outputs.
72   Status Compile(const XlaCompiler::Options& options,
73                  const NameAttrList& function,
74                  absl::Span<const XlaCompiler::Argument> args,
75                  const XlaCompiler::CompileOptions& compile_options,
76                  CompileMode compile_mode,
77                  const XlaCompiler::CompilationResult** out_compilation_result,
78                  xla::LocalExecutable** out_executable);
79 
80   // As above, but calls XlaCompiler::CompileSingleOp instead of
81   // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto
82   // in OpKernelContext, then uses MLIR bridge for compilation instead of
83   // XlaCompiler, if possible.
84   Status CompileSingleOp(
85       const XlaCompiler::Options& options,
86       absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
87       const XlaCompiler::CompileOptions& compile_options,
88       const XlaCompiler::CompilationResult** out_compilation_result,
89       xla::LocalExecutable** out_executable);
90 
client()91   xla::LocalClient* client() const { return client_; }
device_type()92   const DeviceType& device_type() const { return device_type_; }
93 
94   string DebugString() const override;
95 
96   // Describes the types, shapes and any compile-time constant arguments
97   // to a kernel. Key that uniquely identifies a compilation output.
98   struct Signature {
99     string name;
100 
101     // List of Tensor types & shapes for compile-time constant arguments to the
102     // compilation, ordered by argument number.
103     absl::InlinedVector<std::pair<DataType, absl::InlinedVector<int64, 4>>, 4>
104         arg_shapes;
105 
106     // List of Tensor values for compile-time constant arguments to the
107     // compilation, ordered by argument number. Tensors must be in host memory.
108     absl::InlinedVector<Tensor, 4> arg_values;
109 
110     bool operator==(const Signature& other) const;
111 
112     struct Hash {
113       uint64 operator()(const Signature& signature) const;
114     };
115 
116     // Returns a human-readable description of the signature.
117     string HumanString() const;
118   };
119 
120   // Builds the signature for a compilation.
121   static xla::StatusOr<Signature> BuildSignature(
122       const NameAttrList& function,
123       absl::Span<const XlaCompiler::Argument> args);
124 
125  private:
126   // Common implementation of Compile and CompileSingleOp.
127   Status CompileImpl(
128       const XlaCompiler::Options& options, const NameAttrList& function,
129       absl::Span<const XlaCompiler::Argument> args,
130       const std::function<Status(XlaCompiler* compiler,
131                                  XlaCompiler::CompilationResult*)>& compile_fn,
132       absl::optional<int64> compile_threshold,
133       const XlaCompiler::CompilationResult** out_compilation_result,
134       xla::LocalExecutable** out_executable);
135 
136   // Takes `result` which has been compiled from a Tensorflow subgraph to a
137   // XLA computation already, and generates an XLA LocalExecutable `executable`.
138   Status BuildExecutable(const XlaCompiler::Options& options,
139                          const XlaCompiler::CompilationResult& result,
140                          std::unique_ptr<xla::LocalExecutable>* executable);
141 
142   xla::LocalClient* const client_;
143   const DeviceType device_type_;
144 
145   // The value associated with a cache entry.
146   struct Entry {
147     mutex mu;
148 
149     // Have we tried compiling this entry?
150     bool compiled = false;
151 
152     // The number of times a compilation with this signature has been requested.
153     int64 request_count = 0;
154 
155     // Did compilation succeed?
156     Status compilation_status TF_GUARDED_BY(mu);
157 
158     // Output of the XlaCompiler.
159     XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu);
160 
161     // The XLA executable compiled from <computation>. May be null if no
162     // executable has been built.
163     std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
164   };
165 
166   mutex compile_cache_mu_;
167   absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
168       TF_GUARDED_BY(compile_cache_mu_);
169 
170   struct ClusterCompileStats {
171     // Number of times the cluster has been (re-)compiled.
172     int64 compile_count = 0;
173 
174     // The number of times this cluster has been executed.
175     int64 execution_count = 0;
176 
177     // Cumulative time spent compiling the cluster.
178     int64 cumulative_compile_time_us = 0;
179 
180     // True if we have decided that this cluster is too dynamic (i.e. its shapes
181     // change too frequently) to profitably JIT compile.  Once a cluster is
182     // tagged megamorphic, it stays megamorphic forever.
183     bool is_megamorphic = false;
184   };
185 
186   mutex cluster_compile_stats_mu_;
187 
188   // Maps cluster names to compilation statistics for said cluster.
189   absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
190       TF_GUARDED_BY(cluster_compile_stats_mu_);
191 
192   // The number of times a lazy compilation must be requested for a specific
193   // signature before  we attempt to compile it.
194   static constexpr int64 kDefaultCompilationThreshold = 2;
195 
196   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
197 };
198 
199 }  // namespace tensorflow
200 
201 #endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
202