1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/runtime/jit_executable.h"
17
18 #include <memory>
19 #include <string_view>
20 #include <utility>
21
22 #include "llvm/ADT/Optional.h"
23 #include "tensorflow/compiler/xla/mlir/utils/runtime/constraints.h"
24
25 namespace xla {
26 namespace runtime {
27
28 using llvm::cast;
29 using llvm::dyn_cast;
30 using llvm::isa;
31
32 using llvm::ArrayRef;
33 using llvm::ErrorOr;
34 using llvm::Expected;
35 using llvm::Optional;
36
37 using tfrt::MakeAvailableAsyncValueRef;
38 using tfrt::MakeErrorAsyncValueRef;
39 using tfrt::MakeStringError;
40
41 using Specialization = JitExecutable::Specialization;
42
IsSpecializationOnly(ArrayRef<ArgumentConstraint> constraints)43 static bool IsSpecializationOnly(ArrayRef<ArgumentConstraint> constraints) {
44 return llvm::any_of(constraints, [](ArgumentConstraint constraint) {
45 return constraint != ArgumentConstraint::kResolved;
46 });
47 }
48
HasValueConstraints(ArrayRef<ArgumentConstraint> constraints)49 static bool HasValueConstraints(ArrayRef<ArgumentConstraint> constraints) {
50 return llvm::any_of(constraints, [](ArgumentConstraint constraint) {
51 return constraint == ArgumentConstraint::kValue;
52 });
53 }
54
55 // Returns true if all function operands have statically known shape.
HasStaticShapeOperands(const FunctionType & signature)56 static bool HasStaticShapeOperands(const FunctionType& signature) {
57 auto is_dynamic = [](ArrayRef<int64_t> sizes) -> bool {
58 return llvm::any_of(sizes, mlir::ShapedType::isDynamic);
59 };
60
61 for (unsigned i = 0; i < signature.num_operands(); ++i) {
62 const Type* type = signature.operand(i);
63
64 // Get the underlying value type from the async value.
65 while (auto* value = dyn_cast<AsyncValueType>(type))
66 type = &value->value_type();
67
68 // Unranked types do not have statically known shape.
69 if (isa<UnrankedTensorType, UnrankedMemrefType>(type)) return false;
70
71 // For ranked memrefs and tensors check known sizes.
72 if (auto* memref = dyn_cast<MemrefType>(type))
73 if (is_dynamic(memref->sizes())) return false;
74 if (auto* tensor = dyn_cast<RankedTensorType>(type))
75 if (is_dynamic(tensor->sizes())) return false;
76
77 // All other types are non-shaped and thus have "statically known shape".
78
79 // TODO(ezhulenev): Run time types might need to support type interfaces or
80 // a hierarchy with a base `ShapedType` so that users can define their own
81 // types that can participate in shape specialization. This becomes
82 // complicated for container-like types (e.g. async value) that might
83 // contain a nested type that is shaped (e.g. memref). For now only the
84 // canonical types can participate in shape specialization.
85 }
86
87 return true;
88 }
89
// Trivial compilation task runner: executes the compilation task synchronously
// in the caller thread. All other parameters (number of specializations,
// constraints, arguments, user data) are ignored.
/*static*/ void JitExecutable::InlineCompilationTaskRunner(
    size_t num_specializations, ArrayRef<ArgumentConstraint> constraints,
    ArgumentsRef arguments, CompilationTask task, UserData user_data) {
  task();
}
95
// Instantiates a JitExecutable from the MLIR module source: sets up the JIT
// compiler, resolves the entrypoint argument constraints and signature, and —
// unless specialization is mandatory — eagerly compiles a default executable.
// Returns an error if the module can't be parsed, constraints can't be
// resolved, or the (optional) default compilation fails.
/*static*/ Expected<JitExecutable> JitExecutable::Instantiate(
    std::string_view mlir_module, std::string_view entrypoint, Options opts,
    std::string_view memory_region_name, CompilationTaskRunner runner) {
  // Try to instantiate compilation context from the mlir source.
  Expected<std::unique_ptr<JitCompiler>> compiler =
      JitCompiler::Instantiate(opts.compiler, mlir_module, entrypoint);
  if (auto err = compiler.takeError()) return std::move(err);

  // Get resolved operands constraints for the entrypoint function.
  auto constraints = GetArgumentsConstraints((*compiler)->entrypoint());
  if (auto err = constraints.takeError()) return std::move(err);

  // Get the entrypoint function signature, it will be later required to
  // compute the specialized function signature from the operands at runtime.
  auto signature = opts.compiler.type_converter.Convert(
      (*compiler)->entrypoint().getFunctionType());
  if (auto err = signature.takeError()) return std::move(err);

  // If all of the operands have static shape, then we can always use default
  // binary for execution (unless specialization is explicitly required by the
  // operands constraints).
  if (HasStaticShapeOperands(*signature) && !IsSpecializationOnly(*constraints))
    opts.specialization = Specialization::kDisabled;

  // Return an error if specialization is explicitly disabled, yet some of
  // the operands have unresolved constraints.
  if (opts.specialization == Specialization::kDisabled &&
      IsSpecializationOnly(*constraints))
    return MakeStringError(
        "compilation options disabled specialization, yet operands have "
        "unresolved constraints: ",
        *constraints);

  // If the module must be specialized, return JitExecutable without a default
  // compiled executable.
  if (opts.specialization == Specialization::kAlways ||
      IsSpecializationOnly(*constraints))
    return JitExecutable(mlir_module, entrypoint, memory_region_name,
                         std::move(opts), std::move(*constraints),
                         std::move(*signature),
                         /*default_executable=*/llvm::None, std::move(runner));

  // Otherwise try to compile the default executable.
  Expected<Executable> executable =
      JitCompiler::Compile(std::move(*compiler), memory_region_name);
  if (auto err = executable.takeError()) return std::move(err);

  return JitExecutable(mlir_module, entrypoint, memory_region_name,
                       std::move(opts), std::move(*constraints),
                       std::move(*signature), std::move(*executable),
                       std::move(runner));
}
148
// Constructs a JitExecutable. If `default_executable` is provided it becomes
// immediately available; otherwise DefaultExecutable() yields an error async
// value and callers must rely on specialized executables.
//
// NOTE(review): the initializer list reads `constraints_` when initializing
// `has_value_constraints_` and `symbolic_shapes_resolver_`; this relies on
// `constraints_` being declared before them in the class — confirm against
// the header.
JitExecutable::JitExecutable(std::string_view mlir_module,
                             std::string_view entrypoint,
                             std::string_view memory_region_name, Options opts,
                             ArrayRef<ArgumentConstraint> constraints,
                             FunctionType signature,
                             Optional<Executable> default_executable,
                             CompilationTaskRunner runner)
    : mlir_module_(mlir_module),
      entrypoint_(entrypoint),
      memory_region_name_(memory_region_name),
      opts_(std::move(opts)),
      constraints_(constraints.begin(), constraints.end()),
      has_value_constraints_(HasValueConstraints(constraints_)),
      signature_(std::move(signature)),
      symbolic_shapes_resolver_(signature_, constraints_),
      has_default_executable_(default_executable.has_value()),
      runner_(std::move(runner)),
      specializations_(std::make_unique<Specializations>()) {
  // Initialize default executable if it is available.
  if (has_default_executable_) {
    default_executable_ =
        MakeAvailableAsyncValueRef<Executable>(std::move(*default_executable));
  } else {
    // Keep an error async value so DefaultExecutable() always returns a
    // valid (if erroneous) pointer.
    default_executable_ =
        MakeErrorAsyncValueRef("default executable is not available");
  }
}
176
DefaultExecutable() const177 AsyncValuePtr<Executable> JitExecutable::DefaultExecutable() const {
178 return default_executable_.AsPtr();
179 }
180
constraints() const181 ArrayRef<ArgumentConstraint> JitExecutable::constraints() const {
182 return constraints_;
183 }
184
185 // Combines `hash` with a hash value computed from a value constrained operands.
CombineWithValueConstraineOperands(llvm::hash_code hash,ArgumentsRef arguments,ArrayRef<ArgumentConstraint> constraints)186 static llvm::hash_code CombineWithValueConstraineOperands(
187 llvm::hash_code hash, ArgumentsRef arguments,
188 ArrayRef<ArgumentConstraint> constraints) {
189 for (int i = 0; i < constraints.size(); ++i) {
190 if (LLVM_LIKELY(constraints[i] != ArgumentConstraint::kValue)) continue;
191
192 // TODO(ezhulenev): Currently we only support value specialization of Tensor
193 // operands (with MemrefDesc run time argument), it should be extended to
194 // support open type and argument hierarchies.
195 const MemrefDesc& memref = cast<MemrefDesc>(arguments[i]);
196 const auto* data = static_cast<uint8_t*>(memref.data());
197 size_t rank = memref.rank();
198 assert(rank == 0 || rank == 1);
199 size_t num_values = rank == 0 ? 1 : memref.size(0);
200 int64_t len = num_values * GetHostSize(memref.dtype());
201 hash = llvm::hash_combine(hash, llvm::hash_combine_range(data, data + len));
202 }
203 return hash;
204 }
205
// TODO(ezhulenev): The fast path should be free of mutex to find the
// pre-compiled specialization. Maybe use atomic pointers (multiple atomic
// pointers?) to keep the most commonly used specialization available without
// doing a lookup in the AsyncValuesCache.
//
// TODO(ezhulenev): The number of specializations should be bounded, ideally we
// should only keep N most common specializations, and for everything else
// fall back on the default executable. However what to do if default executable
// is not available, and the number of specializations is above N?

// Returns an executable that can run the given arguments: the default
// executable when specialization is disabled or not yet ready, or a (possibly
// still-compiling) specialized executable looked up / created in the
// specializations cache keyed by the symbolic shapes (and constrained values)
// hash. Kicks off asynchronous compilation via `runner_` on a cache miss.
Expected<AsyncValuePtr<Executable>> JitExecutable::GetExecutable(
    ArgumentsRef arguments, UserData user_data,
    const SpecializationListener* listener) {
  // Do not try to compile specialized executable if it is explicitly disabled.
  if (opts_.specialization == Specialization::kDisabled)
    return DefaultExecutable();

  // The number of arguments must match the entrypoint signature.
  if (LLVM_UNLIKELY(arguments.size() != signature_.num_operands()))
    return MakeStringError("expected ", signature_.num_operands(),
                           " arguments, got: ", arguments.size());

  // Resolve symbolic shapes hash based on the static and runtime information.
  //
  // We rely on the hash code to find the specialized executable. In case of
  // a collision (practically impossible) incompatible arguments will be
  // rejected by the executable arguments verification.
  ErrorOr<llvm::hash_code> hash =
      symbolic_shapes_resolver_.ResolveHash(arguments);

  // If we failed to resolve the symbolic shapes hash, then we need to verify
  // all the operands to find the mismatch and report it to the user.
  if (LLVM_UNLIKELY(hash.getError())) {
    for (unsigned i = 0; i < arguments.size(); ++i) {
      auto* type = signature_.operand(i);

      // TODO(ezhulenev): Support open shaped type/argument hierarchy.
      auto* memref_arg = dyn_cast<MemrefDesc>(&arguments[i]);
      if (!memref_arg) continue;

      if (auto* memref = dyn_cast<MemrefType>(type)) {
        if (auto err = VerifyMemrefArgument(i, *memref, *memref_arg))
          return std::move(err);

      } else if (auto* tensor = dyn_cast<RankedTensorType>(type)) {
        if (auto err = VerifyMemrefArgument(i, *tensor, *memref_arg))
          return std::move(err);

      } else {
        return MakeStringError("expected shaped operand at #", i,
                               ", got: ", *signature_.operand(i));
      }
    }

    // If per-operand verification found nothing, hash resolution failed for a
    // reason we cannot attribute to a specific operand.
    assert(false && "failed to detect incorrect operand");
    return MakeStringError("failed to resolve symbolic shapes");
  }

  // Combine with a hash value computed from the value constrained operands.
  if (LLVM_UNLIKELY(has_value_constraints_))
    *hash = CombineWithValueConstraineOperands(*hash, arguments, constraints_);

  // Maybe return Executable from the cache.
  if (auto cached = specializations_->Find(*hash)) {
    // Always use specialized kernel if required by the compilation options.
    if (opts_.specialization == Specialization::kAlways) return cached;

    // Fall back on default executable if the specialization is not yet
    // available.
    if (has_default_executable_ && !cached.IsAvailable())
      return DefaultExecutable();

    return cached;
  }

  // Instantiation from the source and specialization are cheap, so we do it in
  // the caller thread. We only use compilation runner for expensive part.

  // Try to instantiate compilation context from the mlir source.
  Expected<std::unique_ptr<JitCompiler>> compiler =
      JitCompiler::Instantiate(opts_.compiler, mlir_module_, entrypoint_);

  if (auto err = compiler.takeError()) {
    // The module was already parsed successfully at Instantiate() time, so a
    // failure here indicates an internal inconsistency.
    assert(false && "parsing mlir module must always succeed at this point");
    return std::move(err);
  }

  // Specialize executable to the concrete operands.
  // NOTE(review): the `symbolic_shapes` error state is not checked before
  // dereference below; presumably Resolve() cannot fail once ResolveHash()
  // succeeded above — confirm.
  ErrorOr<llvm::SmallVector<SymbolicShapesResolver::SymbolicShape>>
      symbolic_shapes = symbolic_shapes_resolver_.Resolve(arguments);
  if (auto err = (*compiler)->Specialize(arguments, *symbolic_shapes,
                                         constraints_, listener)) {
    return MakeStringError("failed to specialize executable: ", err);
  }

  // Allocate a placeholder for the compiled specialization only after we are
  // ready to dispatch the compilation task.
  Specializations::Entry entry = specializations_->Allocate(*hash);

  // We lost the race; some other invocation will do the compilation.
  if (!entry.allocated) return entry.ptr;

  // Get the specialization id from the size of the specializations cache.
  size_t specialization = entry.size - 1;

  // Construct the task that will do the specialized executable compilation.
  auto compile = CompilationTask(
      [compiler = std::move(*compiler), ref = entry.ptr.CopyRef(),
       memory_region_name = memory_region_name_, specialization]() mutable {
        Expected<Executable> executable = JitCompiler::Compile(
            std::move(compiler), memory_region_name, specialization);

        // Set the allocated entry async value state to error or concrete.
        if (auto err = executable.takeError()) {
          ref.SetError(std::move(err));
        } else {
          ref.emplace(std::move(*executable));
        }
      });

  // Offload specialization compilation to the user provided runner.
  runner_(specialization, constraints_, arguments, std::move(compile),
          user_data);

  // Use the default executable while we are compiling a specialized version if
  // this is not explicitly disabled by the compilation options.
  if (opts_.specialization == Specialization::kAlways)
    return entry.ptr;
  else
    return has_default_executable_ ? DefaultExecutable() : entry.ptr;
}
336
AllExecutablesCompiled() const337 AsyncValueRef<Chain> JitExecutable::AllExecutablesCompiled() const {
338 return specializations_->AllAvailable();
339 }
340
341 } // namespace runtime
342 } // namespace xla
343