/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_RUNTIME_JIT_EXECUTABLE_H_
#define XLA_RUNTIME_JIT_EXECUTABLE_H_

#include <any>
#include <memory>
#include <string>
#include <string_view>

#include "tensorflow/compiler/xla/mlir/transforms/runtime/jit_compiler.h"
#include "tensorflow/compiler/xla/runtime/async_values_cache.h"
#include "tensorflow/compiler/xla/runtime/constraints.h"

namespace xla {
namespace runtime {

// JitExecutable owns a default executable compiled from the MLIR module (if
// the operand constraints allow that), and orchestrates on-demand
// re-compilation for specific argument ranks, shapes or values depending on
// the operand constraints.
class JitExecutable {
 public:
  using UserData = std::any;

  // An XLA program can be specialized and recompiled at runtime for concrete
  // input shapes, and sometimes values (e.g. a reduction dimension).
  enum class Specialization {
    // Recompile specialized kernels when needed.
    kEnabled,
    // Completely disable specialized kernels (always call default executable).
    kDisabled,
    // Always use specialized kernels, and never call default executable (only
    // required for getting reproducible results in benchmarks).
    kAlways,
  };

  struct Options {
    // What level of specialization is enabled at runtime.
    Specialization specialization = Specialization::kAlways;

    // Options for the XLA runtime JitCompiler.
    JitCompiler::Options compiler;
  };

  // We use `llvm::unique_function` to represent the compilation task because
  // it allows capturing move-only values.
  using CompilationTask = llvm::unique_function<void()>;

  // A compilation task runner is called at runtime when specialization
  // compilation is required, with the `CompilationTask` that does the
  // compilation and updates the internal state of the `JitExecutable`. The
  // caller can use this runner to offload the compilation task to a dedicated
  // thread pool and to add tracing events (e.g. Tensorflow profiler tracing).
  // The runner must call the `CompilationTask`; otherwise specialization
  // compilation will deadlock.
  //
  // The caller can pass arbitrary user data to the `GetExecutable` method, and
  // it will be passed to the runner if recompilation is required. It is
  // guaranteed that the runner will be called in the same thread as
  // `GetExecutable`.
  using CompilationTaskRunner =
      llvm::unique_function<void(size_t, llvm::ArrayRef<ArgumentConstraint>,
                                 ArgumentsRef, CompilationTask, UserData)>;
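
  // As a hedged sketch (not part of this header), a custom runner that
  // offloads specialization compilation to a thread pool could look like the
  // snippet below; `GetCompilationThreadPool()` and its `Schedule` method are
  // hypothetical helpers used only for illustration. The contract is that the
  // task must eventually run, otherwise callers waiting on the specialized
  // executable will deadlock.
  //
  //   void PoolCompilationTaskRunner(
  //       size_t num_specializations,
  //       llvm::ArrayRef<ArgumentConstraint> constraints,
  //       ArgumentsRef arguments, JitExecutable::CompilationTask task,
  //       JitExecutable::UserData user_data) {
  //     // Offload the move-only task; running it compiles the specialized
  //     // executable and updates the JitExecutable state.
  //     GetCompilationThreadPool()->Schedule(std::move(task));
  //   }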
  // Inline compilation task runner runs the compilation task in the caller
  // thread.
  static void InlineCompilationTaskRunner(
      size_t num_specializations,
      llvm::ArrayRef<ArgumentConstraint> constraints, ArgumentsRef arguments,
      CompilationTask task, UserData user_data);

  static llvm::Expected<JitExecutable> Instantiate(
      std::string_view mlir_module, std::string_view entrypoint, Options opts,
      std::string_view memory_region_name = "",
      CompilationTaskRunner runner = InlineCompilationTaskRunner);

  // Returns entrypoint operand constraints after resolving them using the
  // statically known information in the entrypoint function signature.
  llvm::ArrayRef<ArgumentConstraint> constraints() const;

  // Returns the default executable that accepts all compatible operands
  // (operand rank and all static dimensions should match the operands).
  tfrt::AsyncValuePtr<Executable> DefaultExecutable() const;

  // Returns an executable that may be specialized for the arguments. Can
  // return the default executable if no specialization is required, or if the
  // specialized executable is not yet available.
  //
  // The caller can pass arbitrary data via the `user_data` argument, and it
  // will be available to the compilation task runner. This can be used for
  // tracing, e.g. to track what user-level requests triggered recompilation.
  //
  // Returns an error if the arguments do not match the expected function
  // signature and specialization is not possible (without trying to compile).
  // If specialization is disabled, returns the default executable without
  // checking the arguments (the default executable itself will check arguments
  // when called).
  //
  // Async values holding compilation results (executables) are cached in the
  // JitExecutable, and successive calls with the same arguments are cheap (the
  // definition of "same" depends on the argument type specialization and the
  // chosen hash function, e.g. shaped arguments are compared using their
  // symbolic shape). If compilation fails, then the returned async value will
  // hold a compilation error message. Compilation errors are never retried.
  //
  // Note: This function never falls back on the default executable if
  // specialization compilation fails.
  llvm::Expected<tfrt::AsyncValuePtr<Executable>> GetExecutable(
      ArgumentsRef arguments, UserData user_data = {},
      const SpecializationListener* listener = nullptr);

  // Returns an async value that becomes ready when all executables owned by
  // this JitExecutable are compiled (no pending compilation tasks).
  tfrt::AsyncValueRef<tfrt::Chain> AllExecutablesCompiled() const;

  // JitExecutable is a move-only type.
  JitExecutable(const JitExecutable&) = delete;
  JitExecutable(JitExecutable&&) = default;
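
  // A hedged usage sketch (not part of this header): instantiate a
  // JitExecutable from an MLIR module and pick an executable for concrete
  // arguments. The `mlir_module` string, the "main" entrypoint name and the
  // `args` container are assumptions for illustration; error handling is
  // elided.
  //
  //   JitExecutable::Options opts;
  //   opts.specialization = JitExecutable::Specialization::kEnabled;
  //
  //   llvm::Expected<JitExecutable> jit_executable =
  //       JitExecutable::Instantiate(mlir_module, "main", std::move(opts));
  //   if (!jit_executable) return jit_executable.takeError();
  //
  //   // Returns the default executable or a (possibly still compiling)
  //   // specialization, depending on the resolved argument constraints.
  //   llvm::Expected<tfrt::AsyncValuePtr<Executable>> executable =
  //       jit_executable->GetExecutable(args);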
 private:
  JitExecutable(std::string_view mlir_module, std::string_view entrypoint,
                std::string_view memory_region_name, Options opts,
                llvm::ArrayRef<ArgumentConstraint> constraints,
                FunctionType signature,
                llvm::Optional<Executable> default_executable,
                CompilationTaskRunner runner);

  std::string mlir_module_;
  std::string entrypoint_;

  // Name of the memory region where JIT'ed code is compiled to.
  // This allows profilers to correctly label JIT-executed code.
  // Note: this feature might only be available on some platforms, e.g. Linux.
  std::string memory_region_name_;

  Options opts_;

  // Entrypoint operand constraints after resolving them using the statically
  // known information in the entrypoint function signature. If a constraint
  // specified by an argument attribute is known to be statically satisfied by
  // the operand type (e.g. a rank constraint with an operand of statically
  // known rank), then the constraint value for that operand will be updated
  // to `kResolved`.
  llvm::SmallVector<ArgumentConstraint> constraints_;

  // True if any of the operands has an `ArgumentConstraint::kValue`
  // constraint.
  bool has_value_constraints_;

  // Signature of the compiled module entrypoint function.
  //
  // This function signature is allowed to have operand and result types
  // without a well-defined ABI (e.g. it can have tensors when the compiled
  // module is defined in the Tensorflow dialect), and it corresponds to the
  // kernel definition in one of the high-level dialects (e.g. Tensorflow or
  // mHLO).
  //
  // When the compiled module is prepared for execution, function operands and
  // results are mapped to types with a well-defined ABI (e.g. tensors are
  // mapped to memrefs). See the `signature_` documentation in the
  // `Executable` type.
  FunctionType signature_;

  // Symbolic shape resolver assigns symbolic dimensions to runtime operands
  // based on the entrypoint function signature.
  SymbolicShapesResolver symbolic_shapes_resolver_;

  // Default executable that was not specialized to any of the arguments.
  AsyncValueRef<Executable> default_executable_;
  bool has_default_executable_;

  // A custom runner for compiling specializations.
  CompilationTaskRunner runner_;

  // Executables specialized for the argument shapes and/or values.
  using Specializations = AsyncValuesCache<llvm::hash_code, Executable>;
  std::unique_ptr<Specializations> specializations_;
};

}  // namespace runtime
}  // namespace xla

#endif  // XLA_RUNTIME_JIT_EXECUTABLE_H_