/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_RUNTIME_JIT_EXECUTABLE_H_
#define XLA_RUNTIME_JIT_EXECUTABLE_H_

#include <any>
#include <memory>
#include <string>
#include <string_view>

#include "tensorflow/compiler/xla/mlir/transforms/runtime/jit_compiler.h"
#include "tensorflow/compiler/xla/runtime/async_values_cache.h"
#include "tensorflow/compiler/xla/runtime/constraints.h"

namespace xla {
namespace runtime {

// JitExecutable owns a default executable compiled from the MLIR module (if
// the operand constraints allow it) and orchestrates on-demand recompilation
// for specific argument ranks, shapes or values, depending on the operand
// constraints.
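//
// A minimal usage sketch (hypothetical: `module` holds the serialized MLIR
// module, the entrypoint function is named "main", and argument and compiler
// options construction are omitted; this is an illustration, not a prescribed
// pattern):
//
//   JitExecutable::Options opts;
//   opts.specialization = JitExecutable::Specialization::kEnabled;
//
//   llvm::Expected<JitExecutable> jit =
//       JitExecutable::Instantiate(module, "main", opts);
//   if (!jit) { /* handle the llvm::Error */ }
//
//   // Returns the default executable or a specialized one (see below).
//   auto executable = jit->GetExecutable(arguments);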
class JitExecutable {
 public:
  using UserData = std::any;

  // An XLA program can be specialized and recompiled at runtime for concrete
  // input shapes and sometimes values (e.g. a reduction dimension).
  enum class Specialization {
    // Recompile specialized kernels when needed.
    kEnabled,
    // Completely disable specialized kernels (always call the default
    // executable).
    kDisabled,
    // Always use specialized kernels and never call the default executable
    // (only required for getting reproducible results in benchmarks).
    kAlways,
  };

  struct Options {
    // What level of specialization is enabled at runtime.
    Specialization specialization = Specialization::kAlways;

    // Options for the XLA runtime JitCompiler.
    JitCompiler::Options compiler;
  };
58 
59   // We use `llvm::unique_function` to represent compilation task because it
60   // allows to capture move-only values.
61   using CompilationTask = llvm::unique_function<void()>;
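
  // For illustration only (a hypothetical sketch of why `llvm::unique_function`
  // is used instead of `std::function`): the task can own move-only state.
  //
  //   auto state = std::make_unique<CompilationState>();  // hypothetical type
  //   CompilationTask task = [s = std::move(state)] { /* compile using *s */ };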

  // Compilation task runner is called at runtime when a specialization has to
  // be compiled. It receives the `CompilationTask` that does the compilation
  // and updates the internal state of the `JitExecutable`. The caller can use
  // this runner to offload the compilation task to a dedicated thread pool and
  // to add tracing events (e.g. Tensorflow profiler tracing). The runner must
  // call the `CompilationTask`; otherwise callers waiting for the specialized
  // executable will deadlock.
  //
  // The caller can pass arbitrary user data to the `GetExecutable` method, and
  // it will be passed to the runner if recompilation is required. It is
  // guaranteed that the runner will be called in the same thread as
  // `GetExecutable`. A sketch of a custom runner follows the type alias below.
  //
  using CompilationTaskRunner =
      llvm::unique_function<void(size_t, llvm::ArrayRef<ArgumentConstraint>,
                                 ArgumentsRef, CompilationTask, UserData)>;
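
  // A minimal sketch of a custom runner (hypothetical: `thread_pool` is any
  // scheduler with a `Schedule` method; only the requirement to eventually
  // run the task comes from this header):
  //
  //   JitExecutable::CompilationTaskRunner runner =
  //       [&thread_pool](size_t specialization_id,
  //                      llvm::ArrayRef<ArgumentConstraint> constraints,
  //                      ArgumentsRef arguments,
  //                      JitExecutable::CompilationTask task,
  //                      JitExecutable::UserData user_data) {
  //         // Offload compilation; the task must run, or callers waiting on
  //         // the specialized executable will deadlock.
  //         thread_pool.Schedule(std::move(task));
  //       };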

  // Inline compilation task runner runs the compilation task in the caller
  // thread.
  static void InlineCompilationTaskRunner(
      size_t num_specializations,
      llvm::ArrayRef<ArgumentConstraint> constraints, ArgumentsRef arguments,
      CompilationTask task, UserData user_data);

  // Instantiates a JitExecutable from the given MLIR module and entrypoint
  // function.
  static llvm::Expected<JitExecutable> Instantiate(
      std::string_view mlir_module, std::string_view entrypoint, Options opts,
      std::string_view memory_region_name = "",
      CompilationTaskRunner runner = InlineCompilationTaskRunner);

  // Returns the entrypoint operand constraints after resolving them using the
  // statically known information in the entrypoint function signature.
  llvm::ArrayRef<ArgumentConstraint> constraints() const;

  // Returns the default executable that accepts all compatible arguments (the
  // argument ranks and all static dimensions must match the entrypoint
  // operands).
  tfrt::AsyncValuePtr<Executable> DefaultExecutable() const;

  // Returns an executable that may be specialized for the arguments. Can
  // return the default executable if no specialization is required, or if the
  // specialized executable is not yet available.
  //
  // The caller can pass arbitrary data via the `user_data` argument, and it
  // will be available to the compilation task runner. This can be used for
  // tracing, e.g. to track what user-level requests triggered recompilation.
  //
  // Returns an error if the arguments do not match the expected function
  // signature and specialization is not possible (without trying to compile).
  // If specialization is disabled, returns the default executable without
  // checking the arguments (the default executable itself will check arguments
  // when called).
  //
  // Async values holding compilation results (executables) are cached in the
  // JitExecutable, and successive calls with the same arguments are cheap (the
  // definition of "same" depends on the argument type specialization and the
  // chosen hash function, e.g. shaped arguments are compared using their
  // symbolic shape). If compilation fails, the returned async value will hold
  // a compilation error message. Compilation errors are never retried.
  //
  // Note: This function never falls back on the default executable if
  // specialization compilation fails. See the usage sketch below.
  llvm::Expected<tfrt::AsyncValuePtr<Executable>> GetExecutable(
      ArgumentsRef arguments, UserData user_data = {},
      const SpecializationListener* listener = nullptr);
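
  // A minimal usage sketch (hypothetical: `jit` is a JitExecutable and
  // `arguments` is an already constructed ArgumentsRef; waiting on the async
  // value is runtime-specific and shown only schematically):
  //
  //   auto executable = jit.GetExecutable(arguments);
  //   if (!executable) {
  //     // The arguments are incompatible with the function signature.
  //     return executable.takeError();
  //   }
  //   // The returned async value may still be compiling; once it becomes
  //   // available it holds either an Executable or a compilation error.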

  // Returns an async value that becomes ready when all executables owned by
  // this JitExecutable are compiled (no pending compilation tasks).
  tfrt::AsyncValueRef<tfrt::Chain> AllExecutablesCompiled() const;

  // JitExecutable is a move-only type.
  JitExecutable(const JitExecutable&) = delete;
  JitExecutable(JitExecutable&&) = default;

 private:
  JitExecutable(std::string_view mlir_module, std::string_view entrypoint,
                std::string_view memory_region_name, Options opts,
                llvm::ArrayRef<ArgumentConstraint> constraints,
                FunctionType signature,
                llvm::Optional<Executable> default_executable,
                CompilationTaskRunner runner);

  std::string mlir_module_;
  std::string entrypoint_;

  // Name of the memory region where JIT'ed code is compiled to.
  // This allows profilers to correctly label JIT-executed code.
  // Note: this feature might only be available on some platforms, e.g. Linux.
  std::string memory_region_name_;

  Options opts_;

  // Entrypoint operand constraints after resolving them using the statically
  // known information in the entrypoint function signature. If a constraint
  // specified by an argument attribute is known to be statically satisfied by
  // the operand type (e.g. a rank constraint on an operand of statically known
  // rank), then the constraint value for that operand is updated to
  // `kResolved`.
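  //
  // For illustration (hypothetical): an argument declared as `memref<4x?xf32>`
  // with a rank constraint has a statically known rank, so its constraint is
  // stored as `ArgumentConstraint::kResolved`; a value constraint on the same
  // argument could not be resolved statically.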
  llvm::SmallVector<ArgumentConstraint> constraints_;

  // True if any operand has an `ArgumentConstraint::kValue` constraint.
  bool has_value_constraints_;

  // Signature of the compiled module entrypoint function.
  //
  // This function signature is allowed to have operand and result types
  // without a well-defined ABI (e.g. it can have tensors when the compiled
  // module is defined in the Tensorflow dialect), and it corresponds to the
  // kernel definition in one of the high-level dialects (e.g. Tensorflow or
  // mHLO).
  //
  // When the compiled module is prepared for execution, function operands and
  // results are mapped to types with a well-defined ABI (e.g. tensors are
  // mapped to memrefs). See the `signature_` documentation in the `Executable`
  // type.
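  //
  // For illustration (a typical mapping; the exact lowering is defined by the
  // compilation pipeline, not by this header): a `tensor<?x4xf32>` operand in
  // this signature corresponds to a `memref<?x4xf32>` argument in the
  // executable signature.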
  FunctionType signature_;

  // Symbolic shape resolver assigns symbolic dimensions to runtime operands
  // based on the entrypoint function signature.
  SymbolicShapesResolver symbolic_shapes_resolver_;

  // Default executable that was not specialized to any of the arguments.
  AsyncValueRef<Executable> default_executable_;
  bool has_default_executable_;

  // A custom runner for compiling specializations.
  CompilationTaskRunner runner_;

  // Executables specialized for the argument shapes and/or values.
  using Specializations = AsyncValuesCache<llvm::hash_code, Executable>;
  std::unique_ptr<Specializations> specializations_;
};

}  // namespace runtime
}  // namespace xla

#endif  // XLA_RUNTIME_JIT_EXECUTABLE_H_